In [1]:
import torch
import torch.nn as nn

In [2]:
# Define the base network class
class BaseNetwork(nn.Module):
    """A single fully connected (affine) layer.

    Args:
        input_size: Number of features in the last dimension of the input.
        output_size: Number of features produced for each input row.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the layer's affine map (``x @ W.T + b``) to ``x``."""
        layer = self.linear
        return layer(x)

In [3]:
# Define the factory function
def create_network_instances(num_instances, input_size, output_size):
    """Build ``num_instances`` independently initialised ``BaseNetwork`` objects.

    Each call to ``BaseNetwork`` creates its own parameters, so no weights are
    shared between the returned networks.

    Args:
        num_instances: How many networks to create (0 yields an empty list).
        input_size: Input feature count passed to every network.
        output_size: Output feature count passed to every network.

    Returns:
        A list of ``num_instances`` networks, in creation order.
    """
    return [BaseNetwork(input_size, output_size) for _ in range(num_instances)]

In [4]:
# Define the concatenated network class
class ConcatenatedNetwork(nn.Module):
    """Run several independent sub-networks on the same input and concatenate the results.

    The sub-networks are stored in an ``nn.ModuleList`` so that PyTorch
    registers them as child modules. Their weights therefore appear in
    ``parameters()`` and ``state_dict()``, and they follow ``.to(device)``,
    ``.train()`` and ``.eval()``.

    Args:
        num_instances: Number of sub-networks; must be at least 1.
        input_size: Number of features in the last dimension of the input.
        output_size: Number of features each sub-network produces.

    Raises:
        ValueError: If ``num_instances`` is less than 1.
    """

    def __init__(self, num_instances, input_size, output_size):
        super().__init__()
        if num_instances < 1:
            # torch.cat cannot concatenate an empty list; fail here with a clear
            # message instead of on the first forward pass.
            raise ValueError(f"num_instances must be >= 1, got {num_instances}")
        # A plain Python list would hide the sub-networks from nn.Module:
        # no parameters for an optimizer, nothing in state_dict, no device moves.
        self.networks = nn.ModuleList(
            create_network_instances(num_instances, input_size, output_size)
        )

    def forward(self, x):
        """Concatenate the sub-network outputs along the last (feature) dimension.

        For input of shape ``(*, input_size)``, the result has shape
        ``(*, num_instances * output_size)``. For 2-D ``(batch, input_size)``
        input, this is the same as concatenating along ``dim=1``.
        """
        outputs = [network(x) for network in self.networks]
        return torch.cat(outputs, dim=-1)

In [5]:
# Create an instance of the concatenated network.
# Seed torch's global RNG first so that the randomly initialised weights (and any
# random input drawn afterwards) are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Example usage:
num_instances = 3
input_size = 10
output_size = 5

concatenated_network = ConcatenatedNetwork(num_instances, input_size, output_size)

In [6]:
# Random test input: a batch of 2 samples, each with `input_size` features
input_tensor = torch.randn(2, input_size)

In [7]:
# Forward pass through the concatenated network
output = concatenated_network(input_tensor)
# Each of the `num_instances` sub-networks contributes `output_size` features per
# sample, so the concatenated result should be (batch, num_instances * output_size).
assert output.shape == (input_tensor.shape[0], num_instances * output_size), output.shape

In [8]:
# Report the tensor shapes (torch.Size) before and after the forward pass
print(f"Input size: {input_tensor.shape}")
print(f"Output size: {output.shape}")

Input size: torch.Size([2, 10])
Output size: torch.Size([2, 15])
