In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class WeightedSum(nn.Module):
    """Combine several tensors using learned, softmax-normalised mixing weights.

    The module holds one learnable logit per input tensor. ``forward`` turns
    the logits into a probability vector with a softmax and returns
    ``sum_i softmax(logits)_i * tensors[i]``.

    Args:
        num_tensors: number of tensors ``forward`` expects.
        tensor_shape: stored as ``self.tensor_shape``. It is not read anywhere
            in this class, so no input-shape validation is performed.
    """

    def __init__(self, num_tensors, tensor_shape):
        super().__init__()
        self.num_tensors = num_tensors
        self.tensor_shape = tensor_shape

        # One logit per input tensor.
        # NOTE(review): the randn values are overwritten by the Xavier init
        # below; the draw still advances torch's global RNG.
        self.weights = nn.Parameter(torch.randn(num_tensors))

        # xavier_uniform_ cannot compute fans for tensors with fewer than
        # 2 dims, so initialise a (1, num_tensors) view instead. The view shares
        # storage with self.weights, so the parameter is overwritten in place.
        weights_as_row = self.weights.unsqueeze(0)
        nn.init.xavier_uniform_(weights_as_row)

    def forward(self, *tensors):
        """Return the softmax-weighted sum of ``tensors``.

        Args:
            *tensors: exactly ``num_tensors`` tensors. They are combined with
                ``*`` and ``+``, so they must be broadcast-compatible.

        Returns:
            ``sum_i p_i * tensors[i]``, where ``p = softmax(self.weights)``.

        Raises:
            AssertionError: if the tensor count differs from ``num_tensors``
                (skipped when Python runs with ``-O``, which strips asserts).
        """
        assert len(tensors) == self.num_tensors, "Number of input tensors must match num_tensors"
        mixing = F.softmax(self.weights, dim=0)

        # Accumulate left to right starting from 0, in the same order as
        # built-in sum() would.
        total = 0
        for coeff, tensor in zip(mixing, tensors):
            total = total + coeff * tensor
        return total

In [3]:
# Seed torch's global RNG: both the WeightedSum weight initialisation and the
# random example inputs below draw from it, so without a seed every
# Restart & Run All produces different numbers.
SEED = 0
torch.manual_seed(SEED)

B, T, D = 4, 5, 6  # dimensions of each example input tensor
num_tensors = 3

model = WeightedSum(num_tensors=num_tensors, tensor_shape=(B, T, D))

tensors = [torch.randn(B, T, D) for _ in range(num_tensors)]

output = model(*tensors)
# All inputs share one shape, so the weighted sum must keep that shape.
assert output.shape == (B, T, D), f"unexpected output shape {tuple(output.shape)}"
print(output.shape)  # -> (B, T, D)

torch.Size([4, 5, 6])
