In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from tqdm.notebook import tqdm
import datetime
import numpy as np




In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary


class Encoder(nn.Module):
    """Convolutional encoder producing a latent code and a predicted embedding.

    The image passes through four Conv2d -> BatchNorm2d -> LeakyReLU stages
    (the first two use stride 2, the last two stride 1), is globally
    average-pooled to one vector per sample, and is then projected by a
    two-layer MLP. The MLP output is split into two parts.

    Because of the adaptive pooling, the spatial size of the input does not
    affect the output shapes.

    Args:
        latent_dim: width of the first output (``latent``).
        embed_dim: width of the second output (``emb_pred``).
        backbone_channels: channels produced by the last conv stage.
        hidden_dim: width of the hidden MLP layer.
        negative_slope: slope used by every LeakyReLU.
    """

    def __init__(
        self,
        latent_dim=512,
        embed_dim=768,
        backbone_channels=2048,
        hidden_dim=1536,
        negative_slope=0.01
    ):
        super().__init__()
        self.latent_dim_size = latent_dim
        self.embed_dim_size = embed_dim

        # One row per stage: (in_channels, out_channels, kernel_size, stride).
        # Every stage uses padding=1.
        stage_specs = (
            (3, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 3, 1),
            (256, backbone_channels, 3, 1),
        )
        stages = []
        for c_in, c_out, kernel, stride in stage_specs:
            stages.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=kernel, stride=stride, padding=1)
            )
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(negative_slope, inplace=True))
        self.conv = nn.Sequential(*stages)

        # Collapse whatever spatial extent remains to 1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # MLP head; its output holds latent and embedding side by side.
        self.fc1 = nn.Linear(backbone_channels, hidden_dim)
        self.act = nn.LeakyReLU(negative_slope, inplace=True)
        self.fc2 = nn.Linear(hidden_dim, latent_dim + embed_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(latent, emb_pred)``.

        Returns:
            latent:   [B, latent_dim]
            emb_pred: [B, embed_dim]
        """
        batch_size = x.size(0)
        pooled = self.pool(self.conv(x))
        flat = pooled.view(batch_size, -1)
        head_out = self.fc2(self.act(self.fc1(flat)))
        # First latent_dim columns are the latent, the remainder the embedding.
        latent, emb_pred = torch.split(
            head_out, [self.latent_dim_size, self.embed_dim_size], dim=1
        )
        return latent, emb_pred

class Decoder(nn.Module):
    """Decode a latent code, an embedding and an aspect ratio into an RGB image.

    The three inputs are concatenated, projected by a linear layer to a
    ``256 x init_size x init_size`` feature map (``init_size = image_res // 32``)
    and upsampled by five stride-2 transposed convolutions, i.e. by a factor
    of 32 in total. The output is therefore a square image with side
    ``32 * (image_res // 32)``, which equals ``image_res`` only when
    ``image_res`` is a multiple of 32. A final Sigmoid keeps values in [0, 1].

    Args:
        latent_dim: width of the ``latent`` input.
        embed_dim: width of the ``embed`` input.
        image_res: target side length of the output image; must be >= 32.
        negative_slope: slope used by every LeakyReLU.

    Raises:
        ValueError: if ``image_res`` is smaller than 32, because the initial
            feature map would then have zero spatial size.
    """

    def __init__(
        self,
        latent_dim=512,
        embed_dim=768,
        image_res=512,
        negative_slope=0.01
    ):
        super().__init__()
        if image_res < 32:
            raise ValueError(
                f"image_res must be at least 32 (got {image_res}); "
                "the decoder upsamples an (image_res // 32)-sized map by 32x"
            )
        self.input_dim = latent_dim + embed_dim + 1  # +1 for ratio
        self.init_size = image_res // 32

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 256 * self.init_size * self.init_size),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.Unflatten(1, (256, self.init_size, self.init_size)),
        )

        # Each transposed conv (kernel 4, stride 2, padding 1) doubles the
        # spatial size. The sizes in the comments are for the default
        # image_res=512 (init_size=16).
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),  # x2  (16 → 32)
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.ConvTranspose2d(in_channels=128, out_channels=64,  kernel_size=4, stride=2, padding=1),  # x4  (32 → 64)
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.ConvTranspose2d(in_channels=64,  out_channels=32,  kernel_size=4, stride=2, padding=1),  # x8  (64 → 128)
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.ConvTranspose2d(in_channels=32,  out_channels=16,  kernel_size=4, stride=2, padding=1),  # x16 (128 → 256)
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.ConvTranspose2d(in_channels=16,  out_channels=3,   kernel_size=4, stride=2, padding=1),  # x32 (256 → 512)
            nn.Sigmoid()
        )

    def forward(self, latent, embed, ratio):
        """
        latent: [B, latent_dim]
        embed:  [B, embed_dim]
        ratio:  [B, 1] (e.g., torch.tensor([[1.0], [1.33], ...])).
                Also accepted: a [B] tensor, or a single value (Python number
                or one-element tensor) that is shared by the whole batch.
                It is converted to latent's dtype and device.

        Returns a [B, 3, S, S] tensor with values in [0, 1], where
        S = 32 * init_size.
        """
        batch_size = latent.size(0)

        # Bring ratio onto latent's dtype/device so torch.cat cannot promote
        # the concatenated input to a dtype the Linear weights do not have.
        ratio = torch.as_tensor(ratio, dtype=latent.dtype, device=latent.device)
        if ratio.numel() == 1:
            # A single ratio is broadcast to every sample in the batch.
            ratio = ratio.reshape(1, 1).expand(batch_size, 1)
        else:
            # [B] or [B, 1] -> [B, 1]; any other element count raises here.
            ratio = ratio.reshape(batch_size, 1)

        x = torch.cat([latent, embed, ratio], dim=1)
        x = self.fc(x)
        x = self.deconv(x)
        return x

# Shape smoke test: the encoder's adaptive pooling should make the output
# shapes independent of the input's spatial size.
enc = Encoder()

# Random single-image RGB batches: one square, one rectangular.
img1, img2 = torch.randn(1, 3, 128, 128), torch.randn(1, 3, 128, 256)

# Encode both (left to right) and unpack each (latent, emb_pred) pair.
(latent1, emb1), (latent2, emb2) = enc(img1), enc(img2)

print("img1 → latent:", latent1.shape, ", emb:", emb1.shape)
print("img2 → latent:", latent2.shape, ", emb:", emb2.shape)


img1 → latent: torch.Size([1, 512]) , emb: torch.Size([1, 768])
img2 → latent: torch.Size([1, 512]) , emb: torch.Size([1, 768])


In [2]:
# Default Decoder (latent_dim=512, embed_dim=768, image_res=512) on the CPU.
decoder = Decoder().cpu()

# Summarise the decoder using one all-zero dummy sample per forward() argument.
# The trailing ';' stops Jupyter from echoing the object that summary()
# returns; without it the table that summary() already printed is shown twice.
summary(decoder, input_data=[
    torch.zeros((1, 512)),   # latent
    torch.zeros((1, 768)),   # embed
    torch.zeros((1, 1))      # ratio
]);

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 256, 16, 16]         --
|    └─Linear: 2-1                       [-1, 65536]               84,017,152
|    └─LeakyReLU: 2-2                    [-1, 65536]               --
|    └─Unflatten: 2-3                    [-1, 256, 16, 16]         --
├─Sequential: 1-2                        [-1, 3, 512, 512]         --
|    └─ConvTranspose2d: 2-4              [-1, 128, 32, 32]         524,416
|    └─LeakyReLU: 2-5                    [-1, 128, 32, 32]         --
|    └─ConvTranspose2d: 2-6              [-1, 64, 64, 64]          131,136
|    └─LeakyReLU: 2-7                    [-1, 64, 64, 64]          --
|    └─ConvTranspose2d: 2-8              [-1, 32, 128, 128]        32,800
|    └─LeakyReLU: 2-9                    [-1, 32, 128, 128]        --
|    └─ConvTranspose2d: 2-10             [-1, 16, 256, 256]        8,208
|    └─LeakyReLU: 2-11                   [-1, 16, 256, 256] 

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 256, 16, 16]         --
|    └─Linear: 2-1                       [-1, 65536]               84,017,152
|    └─LeakyReLU: 2-2                    [-1, 65536]               --
|    └─Unflatten: 2-3                    [-1, 256, 16, 16]         --
├─Sequential: 1-2                        [-1, 3, 512, 512]         --
|    └─ConvTranspose2d: 2-4              [-1, 128, 32, 32]         524,416
|    └─LeakyReLU: 2-5                    [-1, 128, 32, 32]         --
|    └─ConvTranspose2d: 2-6              [-1, 64, 64, 64]          131,136
|    └─LeakyReLU: 2-7                    [-1, 64, 64, 64]          --
|    └─ConvTranspose2d: 2-8              [-1, 32, 128, 128]        32,800
|    └─LeakyReLU: 2-9                    [-1, 32, 128, 128]        --
|    └─ConvTranspose2d: 2-10             [-1, 16, 256, 256]        8,208
|    └─LeakyReLU: 2-11                   [-1, 16, 256, 256] 

In [None]:
# Default Encoder on the CPU; the same instance is summarised for two input sizes.
model = Encoder().cpu()

# Show forward structure for a square image
print("== Square Input (3×512×512) ==")
summary(model, (3, 512, 512))

# Show forward structure for a rectangular image
print("\n== Rectangular Input (3×512×1024) ==")
# The trailing ';' stops Jupyter from echoing the object returned by this last
# call, which would repeat the table that summary() has already printed.
summary(model, (3, 512, 1024));

== Square Input (3×512×512) ==
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 128, 128]      --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        3,136
|    └─BatchNorm2d: 2-2                  [-1, 64, 256, 256]        128
|    └─LeakyReLU: 2-3                    [-1, 64, 256, 256]        --
|    └─Conv2d: 2-4                       [-1, 128, 128, 128]       131,200
|    └─BatchNorm2d: 2-5                  [-1, 128, 128, 128]       256
|    └─LeakyReLU: 2-6                    [-1, 128, 128, 128]       --
|    └─Conv2d: 2-7                       [-1, 256, 128, 128]       295,168
|    └─BatchNorm2d: 2-8                  [-1, 256, 128, 128]       512
|    └─LeakyReLU: 2-9                    [-1, 256, 128, 128]       --
|    └─Conv2d: 2-10                      [-1, 2048, 128, 128]      4,720,640
|    └─BatchNorm2d: 2-11                 [-1, 2048, 128, 128]      4,096
|    └─LeakyReLU: 2-12      

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 128, 256]      --
|    └─Conv2d: 2-1                       [-1, 64, 256, 512]        3,136
|    └─BatchNorm2d: 2-2                  [-1, 64, 256, 512]        128
|    └─LeakyReLU: 2-3                    [-1, 64, 256, 512]        --
|    └─Conv2d: 2-4                       [-1, 128, 128, 256]       131,200
|    └─BatchNorm2d: 2-5                  [-1, 128, 128, 256]       256
|    └─LeakyReLU: 2-6                    [-1, 128, 128, 256]       --
|    └─Conv2d: 2-7                       [-1, 256, 128, 256]       295,168
|    └─BatchNorm2d: 2-8                  [-1, 256, 128, 256]       512
|    └─LeakyReLU: 2-9                    [-1, 256, 128, 256]       --
|    └─Conv2d: 2-10                      [-1, 2048, 128, 256]      4,720,640
|    └─BatchNorm2d: 2-11                 [-1, 2048, 128, 256]      4,096
|    └─LeakyReLU: 2-12                   [-1, 2048, 128, 25