In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from tqdm.notebook import tqdm
import datetime
import numpy as np




In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary


class UniversalEncoder(nn.Module):
    """Convolutional encoder producing a latent vector and an embedding prediction.

    The backbone is four ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages. The
    first two use a 4x4 kernel with stride 2 and padding 1 (each halves the
    spatial resolution for even sizes); the last two use a 3x3 kernel with
    stride 1 and padding 1 (resolution preserved). Global average pooling
    then collapses whatever spatial extent remains to 1x1, so the two-layer
    fully connected head always receives ``backbone_channels`` features,
    independent of the input height and width.

    Args:
        latent_dim: Number of features in the first returned tensor.
        embed_dim: Number of features in the second returned tensor.
        backbone_channels: Output channels of the last convolution, which is
            also the input width of ``fc1``.
        hidden_dim: Width of the hidden fully connected layer.
        negative_slope: Slope used by every ``LeakyReLU`` in the module.
        in_channels: Number of channels of the input images. Defaults to 3,
            which reproduces the previously hard-coded behaviour.
    """

    def __init__(
        self,
        latent_dim=512,
        embed_dim=768,
        backbone_channels=2048,
        hidden_dim=1536,
        negative_slope=0.01,
        in_channels=3,
    ):
        super().__init__()
        # Remembered so forward() can split the head output into its two parts.
        self.latent_dim_size = latent_dim
        self.embed_dim_size = embed_dim

        # Layer order and attribute names are kept stable so that state_dict
        # keys ("conv.0.weight", ...) stay compatible with saved checkpoints.
        self.conv = nn.Sequential(
            # Stride-2 stage: spatial size is halved.
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope, inplace=True),

            # Stride-2 stage: spatial size is halved again.
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope, inplace=True),

            # Stride-1 stage: spatial size is preserved.
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope, inplace=True),

            # Stride-1 stage: widen to the backbone feature size.
            nn.Conv2d(256, backbone_channels, 3, 1, 1),
            nn.BatchNorm2d(backbone_channels),
            nn.LeakyReLU(negative_slope, inplace=True),
        )

        # Pool to a fixed-size vector regardless of the remaining H x W.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Head: backbone features -> hidden -> concatenated [latent | embedding].
        self.fc1 = nn.Linear(backbone_channels, hidden_dim)
        self.act = nn.LeakyReLU(negative_slope, inplace=True)
        self.fc2 = nn.Linear(hidden_dim, latent_dim + embed_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch laid out as ``(N, in_channels, H, W)``.

        Returns:
            A tuple ``(latent, emb_pred)`` with shapes ``(N, latent_dim)`` and
            ``(N, embed_dim)``.
        """
        feat = self.conv(x)

        # (N, C, 1, 1) -> (N, C). flatten(1) keeps the batch dimension as-is
        # and never has to infer a -1 size, which view(B, -1) cannot do when
        # the batch is empty.
        pooled = self.pool(feat).flatten(1)

        hidden = self.act(self.fc1(pooled))
        out = self.fc2(hidden)

        # First latent_dim columns are the latent, the remaining embed_dim
        # columns are the embedding prediction.
        latent, emb_pred = out.split(
            [self.latent_dim_size, self.embed_dim_size], dim=1
        )
        return latent, emb_pred


# Smoke test: one freshly initialised encoder is fed two inputs of different
# spatial extent, and the resulting output shapes are printed for comparison.
enc = UniversalEncoder()

img1 = torch.randn((1, 3, 128, 128))   # height equals width
img2 = torch.randn((1, 3, 128, 256))   # width is twice the height

# Both forward passes run in the same order as the inputs were created.
(latent1, emb1), (latent2, emb2) = enc(img1), enc(img2)

print("img1 → latent:", latent1.size(), ", emb:", emb1.size())
print("img2 → latent:", latent2.size(), ", emb:", emb2.size())


img1 → latent: torch.Size([1, 512]) , emb: torch.Size([1, 768])
img2 → latent: torch.Size([1, 512]) , emb: torch.Size([1, 768])


In [22]:
# Build a second encoder with default hyper-parameters and keep it on the CPU
# for the layer-by-layer summaries below.
model = UniversalEncoder().cpu()

# Show forward structure for a square image.
# The size tuple is (channels, height, width); the printed output shapes carry
# a leading -1, so the batch dimension appears to be supplied by `summary`.
print("== Square Input (3×512×512) ==")
summary(model, (3, 512, 512))

# Show forward structure for a rectangular image (same height, double width)
# to compare the per-layer output shapes against the square case.
print("\n== Rectangular Input (3×512×1024) ==")
summary(model, (3, 512, 1024))

== Square Input (3×512×512) ==
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 128, 128]      --
|    └─Conv2d: 2-1                       [-1, 64, 256, 256]        3,136
|    └─BatchNorm2d: 2-2                  [-1, 64, 256, 256]        128
|    └─LeakyReLU: 2-3                    [-1, 64, 256, 256]        --
|    └─Conv2d: 2-4                       [-1, 128, 128, 128]       131,200
|    └─BatchNorm2d: 2-5                  [-1, 128, 128, 128]       256
|    └─LeakyReLU: 2-6                    [-1, 128, 128, 128]       --
|    └─Conv2d: 2-7                       [-1, 256, 128, 128]       295,168
|    └─BatchNorm2d: 2-8                  [-1, 256, 128, 128]       512
|    └─LeakyReLU: 2-9                    [-1, 256, 128, 128]       --
|    └─Conv2d: 2-10                      [-1, 2048, 128, 128]      4,720,640
|    └─BatchNorm2d: 2-11                 [-1, 2048, 128, 128]      4,096
|    └─LeakyReLU: 2-12      

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 128, 256]      --
|    └─Conv2d: 2-1                       [-1, 64, 256, 512]        3,136
|    └─BatchNorm2d: 2-2                  [-1, 64, 256, 512]        128
|    └─LeakyReLU: 2-3                    [-1, 64, 256, 512]        --
|    └─Conv2d: 2-4                       [-1, 128, 128, 256]       131,200
|    └─BatchNorm2d: 2-5                  [-1, 128, 128, 256]       256
|    └─LeakyReLU: 2-6                    [-1, 128, 128, 256]       --
|    └─Conv2d: 2-7                       [-1, 256, 128, 256]       295,168
|    └─BatchNorm2d: 2-8                  [-1, 256, 128, 256]       512
|    └─LeakyReLU: 2-9                    [-1, 256, 128, 256]       --
|    └─Conv2d: 2-10                      [-1, 2048, 128, 256]      4,720,640
|    └─BatchNorm2d: 2-11                 [-1, 2048, 128, 256]      4,096
|    └─LeakyReLU: 2-12                   [-1, 2048, 128, 25