##### Use this notebook to understand the U-Net building blocks and how they form the U-Net

In [None]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
class SpatialAttention(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    Every pixel of the (H, W) grid becomes one token whose feature vector is
    its C channels; the H*W tokens attend to each other and the result is
    folded back into a (B, C, H, W) map. No residual connection or
    normalization is applied here -- the output replaces the input.

    Args:
        channels: number of input/output channels (the attention embed dim).
        num_heads: number of attention heads; must divide ``channels``.

    Raises:
        ValueError: if ``channels`` is not divisible by ``num_heads``.
    """
    def __init__(self, channels, num_heads: int = 8):
        super().__init__()
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.mha = nn.MultiheadAttention(
            embed_dim=channels,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply self-attention across all spatial positions.

        Args:
            x: feature map of shape (B, C, H, W) with C == ``channels``.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C); token index is h * W + w
        # (same layout as einops "b c h w -> b (h w) c").
        tokens = x.reshape(B, C, H * W).transpose(1, 2)
        # The attention weights are discarded, so don't ask MHA to compute them.
        attended, _ = self.mha(tokens, tokens, tokens, need_weights=False)
        # (B, H*W, C) -> (B, C, H, W): exact inverse of the flattening above.
        return attended.transpose(1, 2).reshape(B, C, H, W)

In [None]:
class DownBlock(nn.Module):
    """
    One encoder (down-sampling) block of the U-Net.

    Pipeline: conv3x3 -> GroupNorm -> SiLU -> (+ time embedding)
    -> conv3x3 -> GroupNorm -> (+ residual) -> SiLU -> optional attention
    -> stride-2 conv3x3 downsample.

    Input:  x (B, in_channels, H, W), t_emb (B, time_emb_dim)
    Output: (B, out_channels, ceil(H/2), ceil(W/2))

    NOTE(review): only the down-sampled tensor is returned, so this block on
    its own does not hand the pre-downsample features to the decoder as a
    skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block; must be divisible by
            ``num_groups`` (GroupNorm requirement).
        time_emb_dim: size of the last dimension of ``t_emb``.
        num_groups: number of groups for both GroupNorm layers.
        use_attention: if True, apply ``SpatialAttention`` before downsampling.
        verbose: if True (default, for this teaching notebook), print the
            tensor shape after every step of ``forward``.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_emb_dim: int = 1280,
        num_groups: int = 32,
        use_attention: bool = False,
        *,
        verbose: bool = True,
    ):
        super().__init__()
        self.verbose = verbose

        # First conv block
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups, out_channels)

        # Project the time embedding to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)

        # Second conv block
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, out_channels)

        # 1x1 conv only when the channel count changes; otherwise pass-through.
        self.residual_conv = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

        # Optional attention
        self.attention = SpatialAttention(out_channels) if use_attention else nn.Identity()

        # kernel 3 / stride 2 / padding 1 halves H and W (rounding up).
        self.downsample = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)

        self.silu = nn.SiLU()

    def _log(self, message: str) -> None:
        """Print ``message`` only when the block was built with ``verbose=True``."""
        if self.verbose:
            print(message)

    def forward(self, x, t_emb):
        """
        Run the block.

        Args:
            x: feature map of shape (B, in_channels, H, W).
            t_emb: time embedding of shape (B, time_emb_dim).

        Returns:
            Tensor of shape (B, out_channels, ceil(H/2), ceil(W/2)).
        """
        residual = x
        self._log(f"The shape of the residual is {residual.shape} and that of the input is {x.shape}")
        # First conv block
        x = self.conv1(x)
        self._log(f"\n After first convolution, the shape of the inputs is \n {x.shape}")
        x = self.norm1(x)
        self._log(f"\n After first group norm, the shape of the inputs is \n {x.shape}")
        x = self.silu(x)
        self._log(f"\n After first activation, the shape of the inputs is \n {x.shape}")

        # Add the time embedding, broadcast over the spatial dims.
        t = self.time_mlp(t_emb)
        self._log(f"\n The time embeddings are originally of shape \n {t_emb.shape}")
        t = t[:, :, None, None]  # (B, C, 1, 1)
        self._log(f"\n The time embeddings are reshaped for broadcasting to \n {t.shape}")
        x = x + t
        self._log(f"\n The time embeddings are added (broadcast) to the inputs to give \n {x.shape}")

        # Second conv block (activation deferred until after the residual add)
        x = self.conv2(x)
        self._log(f"\n After the second convolution, the shape of the inputs is \n {x.shape}")
        x = self.norm2(x)
        self._log(f"\n After the second group norm, the shape of the inputs is \n {x.shape}")

        # Residual connection
        self._log(f"The shape of the residual is {residual.shape}")
        residual = self.residual_conv(residual)
        self._log(f"\n After the residual convolution, the shape of the residual is \n {residual.shape}")
        x = x + residual
        self._log(f"After adding the residual, the shape of the input is {x.shape}")
        x = self.silu(x)
        self._log(f"After the final activation, the shape of the inputs is {x.shape}")

        # Optional attention
        x = self.attention(x)
        self._log(f"After attention, the shape is {x.shape}")

        # Downsample
        x = self.downsample(x)
        self._log(f"After downsampling, the shape of the inputs is {x.shape}")
        return x

In [None]:
# Smoke-test one DownBlock: (B, C, H, W) -> (B, 32, H/2, W/2).
B, C, H, W = 2, 16, 32, 32
x = torch.randn(B, C, H, W)
t_emb_dim = 1280
# Pass the demo's own sizes instead of relying on matching defaults.
down_block = DownBlock(in_channels=C,
                       out_channels=32,
                       time_emb_dim=t_emb_dim)
t_emb = torch.randn(B, t_emb_dim)
out = down_block(x, t_emb)
assert out.shape == (B, 32, H // 2, W // 2)

In [None]:
class UpBlock(nn.Module):
    """
    Up-sampling block for the U-Net decoder.

    ``skip`` is concatenated with ``x`` along the channel axis *before* any
    resampling, so both must have the same spatial size. The transposed-conv
    upsample runs LAST and operates on ``out_channels``.

    Pipeline: concat(x, skip) -> conv3x3 -> GroupNorm -> SiLU
    -> (+ time embedding) -> conv3x3 -> GroupNorm
    -> (+ 1x1-conv residual of the concatenation) -> SiLU
    -> optional attention -> optional 2x upsample.

    Input:  x (B, in_channels, H, W) + skip (B, skip_channels, H, W)
    Output: (B, out_channels, 2H, 2W) if ``do_upsample`` else
            (B, out_channels, H, W)

    Args:
        in_channels: channels of ``x``.
        out_channels: channels produced by the block; must be divisible by
            ``num_groups`` (GroupNorm requirement).
        skip_channels: channels of the encoder skip tensor.
        time_emb_dim: size of the last dimension of ``t_emb``.
        num_groups: number of groups for both GroupNorm layers.
        use_attention: if True, apply ``SpatialAttention`` before upsampling.
        do_upsample: if False, the upsample layer (still constructed) is
            skipped and the spatial size is preserved.
        verbose: if True (default, for this teaching notebook), print the
            tensor shape after every step of ``forward``.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        skip_channels: int,
        time_emb_dim: int = 1280,
        num_groups: int = 32,
        use_attention: bool = False,
        do_upsample: bool = True,
        *,
        verbose: bool = True,
    ):
        super().__init__()
        self.do_upsample = do_upsample
        self.verbose = verbose

        # After concat with skip: in_channels + skip_channels
        total_channels = in_channels + skip_channels

        # First conv block
        self.conv1 = nn.Conv2d(total_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups, out_channels)

        # Project the time embedding to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)

        # Second conv block
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, out_channels)

        # Residual: always a 1x1 projection of the concatenated input.
        self.residual_conv = nn.Conv2d(total_channels, out_channels, 1)

        # Optional attention
        self.attention = SpatialAttention(out_channels) if use_attention else nn.Identity()

        # Upsample LAST, on out_channels: kernel 4 / stride 2 / padding 1
        # maps H x W to exactly 2H x 2W.
        self.upsample = nn.ConvTranspose2d(out_channels, out_channels, 4, stride=2, padding=1)

        self.silu = nn.SiLU()

    def _log(self, message: str) -> None:
        """Print ``message`` only when the block was built with ``verbose=True``."""
        if self.verbose:
            print(message)

    def forward(
        self,
        x: torch.Tensor,
        skip: torch.Tensor,
        t_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        Run the block.

        Args:
            x: decoder features, (B, in_channels, H, W).
            skip: encoder features, (B, skip_channels, H, W) -- same H, W as x.
            t_emb: time embedding, (B, time_emb_dim).

        Returns:
            (B, out_channels, 2H, 2W) if ``do_upsample`` else
            (B, out_channels, H, W).
        """
        self._log(f"The inputs are of shape \n {x.shape}")
        self._log(f"\n The skip connections are of shape \n {skip.shape}")
        self._log(f"\n The time embedding is of shape {t_emb.shape}")
        # Concatenate with skip along channels.
        x = torch.cat([x, skip], dim=1)  # (B, in_channels + skip_channels, H, W)
        self._log(f"\n The concatenated input (input + skip) is of shape \n {x.shape}")

        # Save for residual
        residual = x

        # First conv
        x = self.conv1(x)
        self._log(f"\n After first convolution, the shape of the inputs is \n {x.shape}")
        x = self.norm1(x)
        self._log(f"\n After first group norm, the shape of the inputs is \n {x.shape}")
        x = self.silu(x)
        self._log(f"\n After first silu activation, the shape of the inputs is \n {x.shape}")

        # Time embedding, broadcast over the spatial dims: (B, C, 1, 1)
        t = self.time_mlp(t_emb)[:, :, None, None]
        self._log(f"\n The time embeddings are reshaped to \n {t.shape}")
        x = x + t
        self._log(f"\n After adding the time embeddings/position info, the shape is \n {x.shape}")

        # Second conv (activation deferred until after the residual add)
        x = self.conv2(x)
        self._log(f"\n After second convolution, the shape of the inputs is \n {x.shape}")
        x = self.norm2(x)
        self._log(f"\n After second group norm, the shape of the inputs is \n {x.shape}")

        # Residual
        residual = self.residual_conv(residual)
        self._log(f"\n The residual shape is \n {residual.shape}")
        x = x + residual
        self._log(f"\n After adding the residual, the shape of the inputs is \n {x.shape}")
        x = self.silu(x)
        self._log(f"\n After the second silu activation, the shape of the inputs is \n {x.shape}")

        # Attention
        x = self.attention(x)
        self._log(f"\n After attention, the shape of the inputs is \n {x.shape}")
        # Upsample; report it only when it actually happens.
        if self.do_upsample:
            x = self.upsample(x)  # (B, out_channels, 2H, 2W)
            self._log(f"\n After the upsampling, the shape of the inputs is \n {x.shape}")
        return x


In [None]:
# Test the Up-convolution block.
# The skip tensor is concatenated with x BEFORE upsampling, so it must have
# the same spatial size as x; the block then doubles H and W.
B, in_ch, out_ch, skip_ch = 2, 64, 128, 32
H, W = 16, 16

x = torch.randn(B, in_ch, H, W)
skip = torch.randn(B, skip_ch, H, W)  # same H, W as x
t_emb = torch.randn(B, 1280)

up_block = UpBlock(in_ch, out_ch, skip_ch)
out = up_block(x, skip, t_emb)
assert out.shape == (B, out_ch, 2 * H, 2 * W)