In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import torch
from network import InversionNet3D


def test_inversion_net_3d_shapes():
    """
    Validate the input and output shapes of the InversionNet3D model.

    Builds a fresh ``InversionNet3D``, feeds it a random batch of shape
    ``[batch, 1, 5, 1000, 70]`` and asserts that the output has shape
    ``[batch, 1, 1, 70, 70]``. It then re-runs the forward pass with
    temporary forward hooks on every ``Conv3d`` / ``ConvTranspose3d`` layer
    and prints each such layer's input and output shape for debugging.

    Raises:
        AssertionError: if the model output shape does not match the
            expected ``(batch, 1, 1, 70, 70)``.
    """
    # Seed before building the model so both weight init and input are reproducible.
    torch.manual_seed(42)

    model = InversionNet3D()
    model.eval()

    batch_size = 2
    input_tensor = torch.randn(batch_size, 1, 5, 1000, 70)
    print(f"Input shape: {input_tensor.shape}")

    with torch.no_grad():
        output = model(input_tensor)
    print(f"Output shape: {output.shape}")

    # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
    expected_output_shape = (batch_size, 1, 1, 70, 70)
    assert output.shape == expected_output_shape, (
        f"Expected output shape {expected_output_shape}, got {output.shape}"
    )
    print("Shape validation successful!")

    # Optional: print per-layer shapes for debugging.
    print("\nDetailed shapes through network:")

    def make_hook(layer_name):
        # Factory binds the layer name per hook (avoids late binding of the loop variable).
        def hook(module, inputs, out):
            # `inputs` is the tuple of positional arguments passed to forward().
            print(
                f"{module.__class__.__name__} ({layer_name}): "
                f"Input shape: {inputs[0].shape} | Output shape: {out.shape}"
            )

        return hook

    # Hooks live only on this freshly built model and are always removed in
    # `finally`, so nothing carries over between runs even if forward() raises.
    handles = [
        module.register_forward_hook(make_hook(name))
        for name, module in model.named_modules()
        if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d))
    ]
    try:
        with torch.no_grad():
            model(input_tensor)  # result unused; run only to trigger the hooks
    finally:
        for handle in handles:
            handle.remove()

# Run the shape check: prints input/output and per-layer Conv3d/ConvTranspose3d
# shapes, and raises AssertionError if the output is not (2, 1, 1, 70, 70).
test_inversion_net_3d_shapes()

Input shape: torch.Size([2, 1, 5, 1000, 70])
torch.Size([2, 32, 1, 80, 80])
torch.Size([2, 1, 1, 70, 70])
Output shape: torch.Size([2, 1, 1, 70, 70])
Shape validation successful!

Detailed shapes through network:
Conv3d: Input shape: torch.Size([2, 1, 5, 1000, 70]) | Output shape: torch.Size([2, 32, 5, 500, 70])
Conv3d: Input shape: torch.Size([2, 32, 5, 500, 70]) | Output shape: torch.Size([2, 64, 5, 250, 70])
Conv3d: Input shape: torch.Size([2, 64, 5, 250, 70]) | Output shape: torch.Size([2, 64, 5, 250, 70])
Conv3d: Input shape: torch.Size([2, 64, 5, 250, 70]) | Output shape: torch.Size([2, 64, 5, 125, 70])
Conv3d: Input shape: torch.Size([2, 64, 5, 125, 70]) | Output shape: torch.Size([2, 64, 5, 125, 70])
Conv3d: Input shape: torch.Size([2, 64, 5, 125, 70]) | Output shape: torch.Size([2, 128, 3, 63, 70])
Conv3d: Input shape: torch.Size([2, 128, 3, 63, 70]) | Output shape: torch.Size([2, 128, 3, 63, 70])
Conv3d: Input shape: torch.Size([2, 128, 3, 63, 70]) | Output shape: torch.Size(