In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class FlexibleCVAE3D(nn.Module):
    def __init__(self, input_shape, latent_dim, conv_layers=3, fc_layers=2):
        super(FlexibleCVAE3D, self).__init__()
        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.conv_layers = conv_layers
        self.fc_layers = fc_layers

        # Calculate dimensions for each layer
        self.encoder_dims = self.calculate_encoder_dims()
        self.decoder_dims = self.calculate_decoder_dims()

        # Build encoder
        self.encoder = self.build_encoder()
        
        # Build decoder
        self.decoder = self.build_decoder()

        print("Encoder dimensions:", self.encoder_dims)
        print("Decoder dimensions:", self.decoder_dims)

    def calculate_encoder_dims(self):
        dims = [self.input_shape]
        channels = [1, 32, 64, 128]  # You can adjust this list as needed
        for i in range(self.conv_layers):
            h = dims[-1][1] // 2
            w = dims[-1][2] // 2
            d = dims[-1][3] // 2 if i < 2 else dims[-1][3]  # Only reduce depth in first two layers
            c = channels[i+1] if i+1 < len(channels) else channels[-1]
            dims.append((c, h, w, d))
        return dims

    def calculate_decoder_dims(self):
        dims = self.encoder_dims[::-1]
        for i in range(len(dims) - 1):
            dims[i+1] = (dims[i+1][0], dims[i][1] * 2, dims[i][2] * 2, dims[i][3] * 2 if i < 2 else dims[i][3])
        return dims

    def build_encoder(self):
        layers = []
        for i in range(self.conv_layers):
            layers.append(nn.Conv3d(self.encoder_dims[i][0], self.encoder_dims[i+1][0], kernel_size=3, stride=2, padding=1))
            layers.append(nn.ReLU())
        
        layers.append(nn.Flatten())
        
        fc_input_dim = math.prod(self.encoder_dims[-1])
        fc_dims = [fc_input_dim] + [512] * (self.fc_layers - 1) + [self.latent_dim * 2]
        
        for i in range(self.fc_layers):
            layers.append(nn.Linear(fc_dims[i], fc_dims[i+1]))
            if i < self.fc_layers - 1:
                layers.append(nn.ReLU())
        
        return nn.Sequential(*layers)

    def build_decoder(self):
        layers = []
        
        fc_output_dim = math.prod(self.decoder_dims[0])
        fc_dims = [self.latent_dim] + [512] * (self.fc_layers - 1) + [fc_output_dim]
        
        for i in range(self.fc_layers):
            layers.append(nn.Linear(fc_dims[i], fc_dims[i+1]))
            layers.append(nn.ReLU())
        
        layers.append(nn.Unflatten(1, self.decoder_dims[0]))
        
        for i in range(self.conv_layers - 1):
            layers.append(nn.ConvTranspose3d(self.decoder_dims[i][0], self.decoder_dims[i+1][0], 
                                             kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(nn.ReLU())
        
        layers.append(nn.ConvTranspose3d(self.decoder_dims[-2][0], self.decoder_dims[-1][0], 
                                         kernel_size=3, stride=2, padding=1, output_padding=1))
        
        return nn.Sequential(*layers)

    def encode(self, x):
        h = self.encoder(x)
        return torch.chunk(h, 2, dim=1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Usage example
input_shape = (1, 40, 40, 200)  # (channels, height, width, depth)
latent_dim = 256
model = FlexibleCVAE3D(input_shape, latent_dim)

# Print model summary
print(model)

# Example forward pass
x = torch.randn(1, *input_shape)
print("Input shape:", x.shape)

# Debug: print shapes at each layer
for name, layer in model.encoder.named_children():
    x = layer(x)
    print(f"After {name}: {x.shape}")

output, mu, logvar = model(torch.randn(1, *input_shape))
print(f"Output shape: {output.shape}")
print(f"Mu shape: {mu.shape}")
print(f"Logvar shape: {logvar.shape}")

Encoder dimensions: [(1, 40, 40, 200), (32, 20, 20, 100), (64, 10, 10, 50), (128, 5, 5, 50)]
Decoder dimensions: [(128, 5, 5, 50), (64, 10, 10, 100), (32, 20, 20, 200), (1, 40, 40, 200)]
FlexibleCVAE3D(
  (encoder): Sequential(
    (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): ReLU()
    (2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (3): ReLU()
    (4): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=160000, out_features=512, bias=True)
    (8): ReLU()
    (9): Linear(in_features=512, out_features=512, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=160000, bias=True)
    (3): ReLU()
    (4): Unflatten(dim=1, unflattened_size=(128, 5, 5, 50))
    (5): ConvTranspose3d(128, 64, kerne

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x80000 and 160000x512)