In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """
    A basic residual block with two fully connected layers and a skip connection.
    """
    def __init__(self, in_features, hidden_features=None):
        super(ResidualBlock, self).__init__()
        if hidden_features is None:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.activation(self.fc1(x))
        out = self.fc2(out)
        out += residual  # Residual connection
        out = self.activation(out)
        return out

class Encoder(nn.Module):
    """
    Encoder module that downsamples the input vector using linear layers followed by Residual Blocks.
    """
    def __init__(self, input_dim, hidden_dims):
        super(Encoder, self).__init__()
        layers = []
        self.down_layers = nn.ModuleList()
        last_dim = input_dim
        for h_dim in hidden_dims:
            # Downsample the vector dimension using a linear layer
            self.down_layers.append(nn.Linear(last_dim, h_dim))
            layers.append(ResidualBlock(h_dim))
            last_dim = h_dim
        self.res_blocks = nn.ModuleList(layers)

    def forward(self, x):
        skip_connections = []
        for down, block in zip(self.down_layers, self.res_blocks):
            x = F.relu(down(x))
            x = block(x)
            skip_connections.append(x)
        return x, skip_connections

class Decoder(nn.Module):
    """
    Decoder module that upsamples the encoded vector using linear layers and integrates skip connections.
    """
    def __init__(self, output_dim, hidden_dims):
        super(Decoder, self).__init__()
        self.up_layers = nn.ModuleList()
        self.res_blocks = nn.ModuleList()
        reversed_dims = list(reversed(hidden_dims))
        last_dim = reversed_dims[0]
        for h_dim in reversed_dims[1:]:
            self.up_layers.append(nn.Linear(last_dim, h_dim))
            self.res_blocks.append(ResidualBlock(h_dim))
            last_dim = h_dim
        self.final_layer = nn.Linear(last_dim, output_dim)

    def forward(self, x, skip_connections):
        # Reverse the skip connections to match the decoder order
        skip_connections = skip_connections[::-1]
        for up, block, skip in zip(self.up_layers, self.res_blocks, skip_connections[:-1]):
            x = F.relu(up(x))
            # Combine with corresponding skip connection (concatenation or addition)
            # Here we use element-wise addition assuming matching dimensions.
            x = x + skip  
            x = block(x)
        x = self.final_layer(x)
        return x

class ResidualMLPUnet(nn.Module):
    """
    A simplified Residual MLP UNet for fixed-size vector data.
    """
    def __init__(self, input_dim, hidden_dims):
        super(ResidualMLPUnet, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dims)
        self.decoder = Decoder(input_dim, hidden_dims)

    def forward(self, x):
        # Encode the input vector and store skip connections
        encoded, skip_connections = self.encoder(x)
        # Decode using the skip connections to recover the original dimension
        out = self.decoder(encoded, skip_connections)
        return out

# Example usage:
if __name__ == "__main__":
    # Define input vector dimension and hidden dimensions for the encoder/decoder
    input_dim = 128
    hidden_dims = [256, 512, 256]  # Example configuration: encoder will go from 128->256->512->256

    # Create model instance
    model = ResidualMLPUnet(input_dim=input_dim, hidden_dims=hidden_dims)
    
    # Generate a random input vector batch of size (batch_size, input_dim)
    x = torch.randn(16, input_dim)
    
    # Forward pass
    output = model(x)
    
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)


RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 1

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """
    A basic residual block with two fully connected layers and a skip connection.
    """
    def __init__(self, in_features, hidden_features=None):
        super(ResidualBlock, self).__init__()
        if hidden_features is None:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.activation(self.fc1(x))
        out = self.fc2(out)
        out += residual  # Residual connection
        out = self.activation(out)
        return out


class ResUnet(nn.Module):
    """
    A simplified Residual MLP UNet for fixed-size vector data
    """

    def __init__(self, input_dim, bottleneck_dim, down_up_factor=2):
        """
        Args:
            input_dim (int): The size of the input vector.
            bottleneck_dim (int): The size of the bottleneck vector.
            down_up_factor (int): The factor by which dimensions are divided when down-sampling
                                  and multiplied when up-sampling.
        """
        super().__init__()

        # 1) Generate down-sampling dimensions from input_dim -> bottleneck_dim
        self.down_dims = self._compute_dims(input_dim, bottleneck_dim, down_up_factor)

        # 2) Generate up-sampling dimensions by reversing the down dims
        #    e.g., if down_dims = [input_dim, ..., bottleneck_dim]
        #    then up_dims = [bottleneck_dim, ..., input_dim]
        self.up_dims = list(reversed(self.down_dims))

        # 3) Create down-sampling (encoder) layers and associated residual blocks
        self.down_layers = nn.ModuleList()
        self.down_resblocks = nn.ModuleList()
        for i in range(len(self.down_dims) - 1):
            in_dim = self.down_dims[i]
            out_dim = self.down_dims[i + 1]
            self.down_layers.append(nn.Linear(in_dim, out_dim))
            self.down_resblocks.append(ResidualBlock(out_dim))

        # 4) Create up-sampling (decoder) layers and associated residual blocks
        self.up_layers = nn.ModuleList()
        self.up_resblocks = nn.ModuleList()
        for i in range(len(self.up_dims) - 1):
            in_dim = self.up_dims[i]
            out_dim = self.up_dims[i + 1]
            self.up_layers.append(nn.Linear(in_dim, out_dim))
            self.up_resblocks.append(ResidualBlock(out_dim))

    def _compute_dims(self, start_dim, bottleneck_dim, factor):
        """
        Compute the list of dimensions from start_dim down to bottleneck_dim
        by dividing by 'factor' at each step (rounding up), ensuring it does
        not go below bottleneck_dim.
        """
        dims = [start_dim]
        while dims[-1] > bottleneck_dim:
            # Divide by factor and round up
            next_dim = math.ceil(dims[-1] / factor)
            if next_dim < bottleneck_dim:
                next_dim = bottleneck_dim
            dims.append(next_dim)
        return dims

    def forward(self, x):
        # ---------------------
        # Down-sampling path
        # ---------------------
        skip_connections = []
        for linear_layer, resblock in zip(self.down_layers, self.down_resblocks):
            x = F.relu(linear_layer(x))
            x = resblock(x)
            # Save each "down" output for skip connection
            skip_connections.append(x)

        # ---------------------
        # Up-sampling path
        # ---------------------
        # Note: skip_connections is from smallest to largest in the forward pass,
        # but we need them in reverse order for the up path.
        skip_connections = skip_connections[::-1]

        # We traverse up_layers normally, but each time we add the corresponding
        # skip from the down path (with matching dimension).
        for i, (linear_layer, resblock) in enumerate(zip(self.up_layers, self.up_resblocks)):
            x = F.relu(linear_layer(x))
            # Add skip connection (element-wise addition)
            x = x + skip_connections[i + 1]  # +1 to skip the bottleneck skip
            x = resblock(x)

        # The final output of x should match input_dim in size since
        # self.up_dims[-1] is the original input dimension.
        return x


In [6]:
model = ResidualMLPUnet(input_dim=128, bottleneck_dim=31, down_up_factor=1.5)

[128, 86, 58, 39, 31]