سهیل حمزه بیگی
شماره دانشجویی: ۴۰۳۴۴۳۰۴۷

In [None]:
# Runtime -> Change runtime type -> GPU
!pip install -q einops

In [None]:
import math
import os
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from einops import rearrange
import torchvision.transforms.functional as TF
from torchvision.transforms import functional as F_tf
import torchvision
import numpy as np
import cv2

In [None]:
def pair(x):
    """Return ``x`` as-is when it is a tuple, otherwise the 2-tuple ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)

def exists(val):
    """Tell whether ``val`` is anything other than ``None``."""
    return not (val is None)


# Patch Embed
class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Args:
        img_size: Image side length, or an ``(H, W)`` tuple.
        patch_size: Patch side length, or an ``(h, w)`` tuple; must divide
            ``img_size`` exactly in both dimensions.
        in_chans: Number of input channels.
        embed_dim: Width of every patch embedding.

    Attributes:
        grid_size: ``(rows, cols)`` of the patch grid for ``img_size``.
        num_patches: ``rows * cols``.
    """

    def __init__(self, img_size=32, patch_size=8, in_chans=3, embed_dim=128):
        super().__init__()
        img_hw = pair(img_size)
        patch_hw = pair(patch_size)
        rows, leftover_h = divmod(img_hw[0], patch_hw[0])
        cols, leftover_w = divmod(img_hw[1], patch_hw[1])
        assert leftover_h == 0 and leftover_w == 0
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        # kernel == stride, so every output location sees exactly one patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.embed_dim = embed_dim

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to tokens (B, N, E) plus the grid (rows, cols)."""
        feat = self.proj(x)  # (B, E, H/ps, W/ps)
        _, _, grid_h, grid_w = feat.shape
        tokens = feat.flatten(2).transpose(1, 2)  # (B, N, E)
        return tokens, (grid_h, grid_w)

In [None]:
# Transformer components (lightweight)
class MLP(nn.Module):
    """Feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        in_dim: Width of the input features.
        hidden_dim: Width of the hidden layer; a falsy value selects ``4 * in_dim``.
        out_dim: Width of the output; a falsy value selects ``in_dim``.
        drop: Dropout probability applied to the output of the second layer.
    """

    def __init__(self, in_dim, hidden_dim=None, out_dim=None, drop=0.0):
        super().__init__()
        width = hidden_dim if hidden_dim else in_dim * 4
        target = out_dim if out_dim else in_dim
        self.fc1 = nn.Linear(in_dim, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, target)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.act(self.fc1(x))
        projected = self.fc2(hidden)
        return self.drop(projected)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each with a residual.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads given to ``nn.MultiheadAttention``.
        mlp_ratio: Hidden width of the MLP expressed as a multiple of ``dim``.
        drop: Dropout probability forwarded to the MLP.
        attn_drop: Dropout probability forwarded to the attention layer.
    """

    def __init__(self, dim, num_heads=4, mlp_ratio=2.0, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), None, drop)

    def forward(self, x):
        """Transform tokens ``x`` of shape (B, N, C); the output has the same shape."""
        # Normalise once and reuse the result as query, key and value. Calling
        # norm1 three times produced the same values at three times the cost.
        normed = self.norm1(x)
        # The attention weights are never used, so do not ask for them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
# Hierarchical encoder
class HierarchicalEncoder(nn.Module):
    """Encode an image at several patch sizes and fuse the per-scale summaries.

    Every scale has its own patch embedding, learned positional embedding and
    stack of transformer blocks. The mean token of each scale is concatenated,
    projected to ``embed_dims[-1]`` and layer-normalised.

    Args:
        img_size: Input image size passed to every ``PatchEmbed``.
        scales: Patch size used at each scale.
        embed_dims: Token width at each scale.
        depth_per_scale: Number of transformer blocks at each scale.
        num_heads: Attention heads at each scale.
        in_chans: Number of image channels.

    All four per-scale sequences must have the same length.
    """

    def __init__(self, img_size=32, scales=(8,4,2), embed_dims=(64,96,128),
                 depth_per_scale=(1,2,2), num_heads=(2,3,4), in_chans=3):
        super().__init__()
        assert len(scales)==len(embed_dims)==len(depth_per_scale)==len(num_heads)
        self.scales = scales
        self.patch_embeds = nn.ModuleList()
        self.pos_embeds = nn.ParameterList()
        self.blocks = nn.ModuleList()
        for patch, width, depth, heads in zip(scales, embed_dims, depth_per_scale, num_heads):
            embed = PatchEmbed(img_size=img_size, patch_size=patch, in_chans=in_chans, embed_dim=width)
            self.patch_embeds.append(embed)
            # One learned position vector per patch of this scale.
            self.pos_embeds.append(nn.Parameter(torch.zeros(1, embed.num_patches, width)))
            # Small transformer stack dedicated to this scale.
            stage = nn.Sequential(
                *[TransformerBlock(width, num_heads=heads, mlp_ratio=2.0) for _ in range(depth)]
            )
            self.blocks.append(stage)

        # Fusion: concatenated per-scale means -> width of the last scale.
        self.fusion_proj = nn.Linear(sum(embed_dims), embed_dims[-1])
        self.norm = nn.LayerNorm(embed_dims[-1])

        # Positional embeddings start from a truncated normal instead of zeros.
        for pos in self.pos_embeds:
            nn.init.trunc_normal_(pos, std=0.02)

    def forward(self, x):
        """Return ``(fused, outs, grids)`` for a batch of images.

        ``fused`` has shape (B, 1, embed_dims[-1]); ``outs`` holds the (B, N_i, C_i)
        tokens of every scale and ``grids`` the matching patch-grid sizes.
        """
        per_scale_tokens = []
        per_scale_grids = []
        summaries = []
        for embed, pos, stage in zip(self.patch_embeds, self.pos_embeds, self.blocks):
            tokens, grid = embed(x)  # (B, N, C)
            tokens = stage(tokens + pos)
            per_scale_tokens.append(tokens)
            per_scale_grids.append(grid)
            summaries.append(tokens.mean(dim=1))  # global average over the patches
        fused = self.fusion_proj(torch.cat(summaries, dim=1))  # (B, sumC) -> (B, C_out)
        return self.norm(fused).unsqueeze(1), per_scale_tokens, per_scale_grids

In [None]:
# Decoder
class SmallDecoder(nn.Module):
    """Lightweight decoder that predicts the average RGB value of every patch.

    Each position starts from the projected fused encoder token. Visible
    positions additionally receive their (projected) encoder token, masked
    positions receive a learned mask token. The sequence then passes through
    ``depth`` transformer blocks and a linear head with three outputs.

    Args:
        enc_dim: Width of the fused encoder token.
        decoder_dim: Working width of the decoder.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        out_patch_size: Patch size of the grid being reconstructed.
        img_size: Image side length; only used to record ``self.gh``.
        token_dim: Width of the per-patch encoder tokens given to ``forward``.
            When ``None`` (default) the width is inferred on the first forward
            pass through ``nn.LazyLinear``. When equal to ``decoder_dim`` the
            tokens are used unchanged.

    The token projection is a registered sub-module, so it is trained, moved
    by ``.to()`` and stored in ``state_dict()``. Previously a new random
    ``nn.Linear`` was created inside every forward call, which made the visible
    token contribution random noise and the output non-deterministic.
    NOTE(review): PyTorch's lazy-module guidance suggests a dry-run forward
    before building the optimizer; pass ``token_dim`` explicitly to avoid
    relying on lazy initialisation.
    """

    def __init__(self, enc_dim, decoder_dim=128, depth=3, num_heads=4, out_patch_size=8, img_size=32,
                 token_dim=None):
        super().__init__()
        self.in_proj = nn.Linear(enc_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.blocks = nn.ModuleList([TransformerBlock(decoder_dim, num_heads=num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(decoder_dim)
        # One prediction per patch of the (img_size // out_patch_size)^2 grid.
        gh = img_size // out_patch_size
        self.gh = gh
        self.out_dim = 3  # average RGB per patch
        self.pred = nn.Linear(decoder_dim, self.out_dim)

        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # Created last so the random initialisation of the layers above is
        # unchanged for a given seed.
        if token_dim is None:
            self.token_proj = nn.LazyLinear(decoder_dim)
        elif token_dim == decoder_dim:
            self.token_proj = nn.Identity()
        else:
            self.token_proj = nn.Linear(token_dim, decoder_dim)

    def forward(self, fused_token, largest_scale_tokens, mask):
        """Predict per-patch RGB means.

        Args:
            fused_token: (B, 1, enc_dim) fused encoder summary.
            largest_scale_tokens: (B, N, token_dim) encoder tokens of the grid
                being reconstructed.
            mask: (B, N) boolean tensor, ``True`` where a patch is masked.

        Returns:
            Tensor of shape (B, N * 3).
        """
        B = fused_token.size(0)
        N = largest_scale_tokens.size(1)
        dec_in = self.in_proj(fused_token).expand(-1, N, -1)  # (B, N, dec_dim)
        vt = self.token_proj(largest_scale_tokens)  # (B, N, dec_dim)
        mask = mask.to(device=vt.device, dtype=torch.bool)
        # Masked positions get the mask token, visible ones their own token.
        extra = torch.where(mask.unsqueeze(-1), self.mask_token.to(vt.dtype), vt)
        x = dec_in + extra
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        preds = self.pred(x)  # (B, N, 3)
        return preds.reshape(B, -1)  # (B, N*3)

In [None]:
# MAE
class DPHMAE_CIFAR(nn.Module):
    """Masked auto-encoder: ``HierarchicalEncoder`` followed by ``SmallDecoder``.

    Args:
        img_size: Input image side length.
        scales: Patch size per encoder scale; ``scales[0]`` is also the patch
            size of the grid that the decoder reconstructs.
        embed_dims: Token width per encoder scale; the last entry is the width
            of the fused token handed to the decoder.
        depth_per_scale: Transformer depth per encoder scale.

    NOTE(review): ``forward`` feeds the full, unmasked image to the encoder and
    the mask is only applied inside the decoder, so the fused token is computed
    from every patch, including the masked ones. Confirm this is intended.
    """

    def __init__(self, img_size=32, scales=(8,4,2), embed_dims=(64,96,128),
                 depth_per_scale=(1,2,2)):
        super().__init__()
        encoder_cfg = dict(img_size=img_size, scales=scales,
                           embed_dims=embed_dims, depth_per_scale=depth_per_scale)
        self.encoder = HierarchicalEncoder(**encoder_cfg)
        self.enc_out_dim = embed_dims[-1]
        decoder_cfg = dict(enc_dim=self.enc_out_dim, decoder_dim=128, depth=3, num_heads=4,
                           out_patch_size=scales[0], img_size=img_size)
        self.decoder = SmallDecoder(**decoder_cfg)
        self.scales = scales

    def forward(self, images, mask):
        """Return per-patch RGB predictions of shape (B, N * 3) for ``images``."""
        fused, per_scale_tokens, _grids = self.encoder(images)
        # The first scale (scales[0]) supplies the per-patch tokens for decoding.
        first_scale_tokens = per_scale_tokens[0]
        return self.decoder(fused, first_scale_tokens, mask)

In [None]:
# Dynamic Masking + Progressive schedule for CIFAR patch grid
class DynamicMaskerCIFAR:
    """Sample patch masks with a linearly increasing mask ratio.

    Patches are drawn without replacement with probability proportional to
    ``1 - importance``, so patches with a low importance score are somewhat
    more likely to be masked.

    Args:
        img_size: Image side length used to derive the patch grid.
        patch_size: Side length of a (square) patch.
        base_mask: Mask ratio at epoch 0.
        target_mask: Mask ratio reached at ``schedule_epochs`` and kept after.
        schedule_epochs: Length of the linear ramp in epochs.
        mode: ``'variance'`` (pixel variance per patch) or ``'edge'`` (mean
            absolute Laplacian per patch). Any other value behaves like
            ``'variance'``.
        device: Device on which the returned masks are placed.
    """

    def __init__(self, img_size=32, patch_size=8, base_mask=0.3, target_mask=0.75,
                 schedule_epochs=100, mode='variance', device='cpu'):
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid = (img_size // patch_size, img_size // patch_size)
        self.N = self.grid[0] * self.grid[1]
        self.base = base_mask
        self.target = target_mask
        self.epochs = schedule_epochs
        self.mode = mode
        self.device = device

    def ratio(self, epoch):
        """Return the mask ratio for ``epoch`` (linear ramp, clamped at the target)."""
        t = min(1.0, epoch / max(1, self.epochs))
        return self.base + (self.target - self.base) * t

    def compute_importance(self, images):
        """Return per-patch importance as a probability distribution of shape (B, N).

        ``images`` is a (B, C, H, W) tensor; H and W are split into
        ``patch_size`` tiles.
        """
        B, C, H, W = images.shape
        ph = self.patch_size
        gh = H // ph
        gw = W // ph
        if self.mode == 'edge':
            # Laplacian response of the channel-averaged image, in pure torch.
            gray = images.mean(dim=1, keepdim=True)  # (B, 1, H, W)
            kernel = torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]],
                                  dtype=gray.dtype, device=gray.device)
            lap = F.conv2d(gray, weight=kernel, padding=1)
            tiles = lap.unfold(2, ph, ph).unfold(3, ph, ph).reshape(B, gh * gw, ph * ph)
            imp = tiles.abs().mean(dim=2)
        else:
            # 'variance' and every unrecognised mode: variance over all pixels
            # and channels of a patch. Only this branch needs the raw patches.
            patches = images.unfold(2, ph, ph).unfold(3, ph, ph)  # (B, C, gh, gw, ph, ph)
            patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, gh * gw, C * ph * ph)
            imp = patches.var(dim=2)  # (B, N)
        # The small constant keeps flat images from producing an all-zero row.
        prob = imp + 1e-6
        return prob / prob.sum(dim=1, keepdim=True)

    def sample_mask(self, images, epoch):
        """Return a (B, N) boolean mask with ``int(N * ratio(epoch))`` True entries per row."""
        B = images.size(0)
        prob = self.compute_importance(images)  # higher -> more important
        inv = 1.0 - prob
        inv = inv / inv.sum(dim=1, keepdim=True)
        # Never ask for more patches than there are to draw from.
        k = max(0, min(int(self.N * self.ratio(epoch)), inv.size(1)))
        # Build the mask next to the sampled indices. Allocating it on
        # self.device failed when that device differed from images.device.
        masks = torch.zeros(B, self.N, dtype=torch.bool, device=images.device)
        if k > 0:
            # One batched draw replaces the per-sample Python loop.
            idx = torch.multinomial(inv, k, replacement=False)  # (B, k)
            masks.scatter_(1, idx, True)
        return masks.to(self.device)

In [None]:
# Training on CIFAR-10
def build_dataloaders(batch_size=128, img_size=32):
    """Build the CIFAR-10 train and test loaders.

    Images are resized to ``img_size`` x ``img_size`` and converted to tensors.
    The dataset is downloaded to ``./data`` when missing. Only the training
    loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    dataset_args = dict(root='./data', download=True, transform=preprocess)
    loader_args = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

    train_set = datasets.CIFAR10(train=True, **dataset_args)
    val_set = datasets.CIFAR10(train=False, **dataset_args)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, **loader_args)
    return train_loader, val_loader

def compute_patch_targets(images, patch_size):
    """Return the per-channel mean of every non-overlapping patch.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of the square patches.

    Returns:
        Tensor of shape (B, N * C), ordered patch by patch (row-major over the
        patch grid) with the channel values of one patch next to each other.
    """
    patch_means = F.avg_pool2d(images, kernel_size=patch_size, stride=patch_size)  # (B, C, gh, gw)
    batch, chans, rows, cols = patch_means.shape
    # Channels last, then flatten grid and channels together.
    return patch_means.movedim(1, -1).reshape(batch, rows * cols * chans)

In [None]:
def train_epoch(model, loader, optimizer, masker, epoch, device):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    For every batch a mask is sampled for ``epoch``, the model predicts the
    per-patch RGB means, and the MSE against the true patch means is taken on
    the masked patches only (on all patches when nothing is masked). The
    returned value is the sample-weighted average of the batch losses.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for batch, _ in loader:
        batch = batch.to(device)
        batch_size = batch.size(0)
        token_shape = (batch_size, masker.N, 3)

        patch_mask = masker.sample_mask(batch, epoch)  # (B, N)
        pred = model(batch, patch_mask).view(token_shape)
        target = compute_patch_targets(batch, masker.patch_size).to(device).view(token_shape)

        patch_mask = patch_mask.to(device)
        if patch_mask.sum() != 0:
            # Usual case: score only the patches the model could not see.
            loss = F.mse_loss(pred[patch_mask], target[patch_mask])
        else:
            loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch_size
        seen += batch_size
    return loss_sum / seen

In [None]:
def validate_epoch(model, loader, masker, device):
    """Evaluate ``model`` on ``loader`` and return the mean masked-patch MSE.

    Masks are sampled at the final mask ratio of the schedule by passing
    ``masker.epochs`` as the epoch. Only the ratio is fixed; the masked
    positions are still drawn at random on every call.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    with torch.no_grad():
        for batch, _ in loader:
            batch = batch.to(device)
            batch_size = batch.size(0)
            token_shape = (batch_size, masker.N, 3)

            patch_mask = masker.sample_mask(batch, masker.epochs)  # final difficulty
            pred = model(batch, patch_mask).view(token_shape)
            target = compute_patch_targets(batch, masker.patch_size).to(device).view(token_shape)

            patch_mask = patch_mask.to(device)
            if patch_mask.sum() != 0:
                loss = F.mse_loss(pred[patch_mask], target[patch_mask])
            else:
                loss = F.mse_loss(pred, target)

            loss_sum += loss.item() * batch_size
            seen += batch_size
    return loss_sum / seen

In [None]:
# Train

def run_cifar_training(epochs=20, batch_size=128, device=None, save_dir='./checkpoints'):
    """Pre-train ``DPHMAE_CIFAR`` on CIFAR-10 and write a checkpoint after every epoch.

    Args:
        epochs: Number of epochs; also the length of the mask-ratio schedule.
        batch_size: Batch size of both data loaders.
        device: Device to train on; a falsy value selects CUDA when available,
            otherwise the CPU.
        save_dir: Directory for the checkpoints (created when missing).
    """
    if not device:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    train_loader, val_loader = build_dataloaders(batch_size=batch_size, img_size=32)
    model = DPHMAE_CIFAR(img_size=32, scales=(8,4,2), embed_dims=(64,96,128), depth_per_scale=(1,2,2))
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
    masker = DynamicMaskerCIFAR(img_size=32, patch_size=8, base_mask=0.3, target_mask=0.75,
                                schedule_epochs=epochs, mode='variance', device=device)
    os.makedirs(save_dir, exist_ok=True)

    epoch = 1
    while epoch <= epochs:
        train_loss = train_epoch(model, train_loader, optimizer, masker, epoch, device)
        val_loss = validate_epoch(model, val_loader, masker, device)
        print(f"Epoch {epoch:02d}  Train Loss: {train_loss:.6f}  Val Loss: {val_loss:.6f}  MaskRatio: {masker.ratio(epoch):.3f}")
        snapshot = {'model': model.state_dict(), 'opt': optimizer.state_dict(), 'epoch': epoch}
        torch.save(snapshot, f"{save_dir}/dph_mae_epoch{epoch}.pth")
        epoch += 1
    print("Training finished.")

# Entry point: when this code runs as the main module, start a 20-epoch
# pre-training run with batch size 256. In an interactive session,
# run_cifar_training(...) can also be called directly with other settings.
if __name__ == '__main__':
    run_cifar_training(epochs=20, batch_size=256)


Device: cuda


100%|██████████| 170M/170M [00:02<00:00, 71.4MB/s]


Epoch 01  Train Loss: 0.052552  Val Loss: 0.024355  MaskRatio: 0.323
Epoch 02  Train Loss: 0.025017  Val Loss: 0.024590  MaskRatio: 0.345
Epoch 03  Train Loss: 0.024149  Val Loss: 0.023974  MaskRatio: 0.367
Epoch 04  Train Loss: 0.023673  Val Loss: 0.023263  MaskRatio: 0.390
Epoch 05  Train Loss: 0.023774  Val Loss: 0.023378  MaskRatio: 0.412
Epoch 06  Train Loss: 0.023455  Val Loss: 0.023396  MaskRatio: 0.435
Epoch 07  Train Loss: 0.023377  Val Loss: 0.023170  MaskRatio: 0.458
Epoch 08  Train Loss: 0.023435  Val Loss: 0.023090  MaskRatio: 0.480
Epoch 09  Train Loss: 0.023337  Val Loss: 0.023269  MaskRatio: 0.502
Epoch 10  Train Loss: 0.023375  Val Loss: 0.023077  MaskRatio: 0.525
Epoch 11  Train Loss: 0.023415  Val Loss: 0.023025  MaskRatio: 0.547
Epoch 12  Train Loss: 0.023262  Val Loss: 0.023110  MaskRatio: 0.570
Epoch 13  Train Loss: 0.023305  Val Loss: 0.023092  MaskRatio: 0.593
Epoch 14  Train Loss: 0.023336  Val Loss: 0.023379  MaskRatio: 0.615
Epoch 15  Train Loss: 0.023243  Va

## Linear Probing & Fine Tuning

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

Device: cuda


In [None]:
import os, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FakeData

In [None]:
def run_linear_probe(checkpoint_path="./checkpoints/dph_mae_epoch20.pth", epochs=5, batch_size=128):
    """Train a linear classifier on top of the pretrained encoder and report accuracy.

    The encoder weights are loaded from ``checkpoint_path`` (key ``'model'``,
    non-strict), and only ``model.fc`` is handed to the optimizer. Train
    accuracy/loss and validation accuracy are printed after every epoch.

    NOTE(review): ``download_imagenette`` and ``LinearProbe`` are not defined in
    the visible part of this file; presumably they come from another cell.
    This function assumes ``download_imagenette()`` returns a directory with
    ``train`` and ``val`` image folders — confirm.
    """
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_root = download_imagenette()
    tf = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor()])

    def open_split(split, shuffle):
        folder = datasets.ImageFolder(os.path.join(data_root, split), transform=tf)
        return folder, DataLoader(folder, batch_size=batch_size, shuffle=shuffle, num_workers=2)

    train_ds, train_loader = open_split("train", True)
    _, val_loader = open_split("val", False)

    # Restore the pretrained model and keep only its encoder.
    base = DPHMAE_CIFAR()
    ckpt = torch.load(checkpoint_path, map_location=device)
    base.load_state_dict(ckpt['model'], strict=False)
    encoder = base.encoder.to(device)

    model = LinearProbe(encoder, enc_dim=128, num_classes=len(train_ds.classes)).to(device)
    opt = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

    for ep in range(1, epochs+1):
        # Training pass
        model.train()
        correct = 0
        total = 0
        loss_sum = 0
        for imgs, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            count = labels.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            total += count
            loss_sum += loss.item() * count
        print(f"[LinearProbe] Epoch {ep} TrainAcc {correct/total:.3f} Loss {loss_sum/total:.4f}")

        # Validation pass
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs = imgs.to(device)
                labels = labels.to(device)
                correct += (model(imgs).argmax(1) == labels).sum().item()
                total += labels.size(0)
        print(f"[LinearProbe] Epoch {ep} ValAcc {correct/total:.3f}")


In [None]:
# Linear probing on Imagenette, using the checkpoint written after pre-training epoch 20.
run_linear_probe("./checkpoints/dph_mae_epoch20.pth", epochs=3, batch_size=128)

[LinearProbe] Epoch 1 TrainAcc 0.175 Loss 2.2368
[LinearProbe] Epoch 1 ValAcc 0.220
[LinearProbe] Epoch 2 TrainAcc 0.224 Loss 2.1463
[LinearProbe] Epoch 2 ValAcc 0.236
[LinearProbe] Epoch 3 TrainAcc 0.243 Loss 2.1122
[LinearProbe] Epoch 3 ValAcc 0.255


In [None]:
# UPerNet Fine-tuning (FakeData segmentation)
def encode_with_pos_interp(encoder, x):
    """Run ``encoder`` on ``x``, resizing positional embeddings when needed.

    For every scale the stored positional embedding is used directly when its
    token count matches the patch grid of ``x``. Otherwise it is treated as a
    square grid, bilinearly resized to the new grid and flattened again.

    Args:
        encoder: Object exposing ``patch_embeds``, ``pos_embeds``, ``blocks``,
            ``fusion_proj`` and ``norm``.
        x: Image batch (B, C, H, W).

    Returns:
        ``(fused, outs, grids)``: the (B, 1, C) fused token, the token tensor of
        each scale and the patch-grid size of each scale.
    """
    device = x.device
    summaries, token_sets, grid_sizes = [], [], []
    for embed, pos, stage in zip(encoder.patch_embeds, encoder.pos_embeds, encoder.blocks):
        tokens, grid = embed(x)  # (B, N_new, C)
        if pos.shape[1] == tokens.shape[1]:
            pos_for_input = pos.to(device)
        else:
            # Stored embedding is assumed to describe a square grid.
            side = int(math.sqrt(pos.shape[1]))
            channels = pos.shape[2]
            new_h, new_w = grid
            pos_grid = pos.reshape(1, side, side, channels).permute(0, 3, 1, 2)
            resized = F.interpolate(pos_grid.to(device), size=(new_h, new_w),
                                    mode='bilinear', align_corners=False)
            pos_for_input = resized.permute(0, 2, 3, 1).reshape(1, new_h * new_w, channels)
        tokens = stage(tokens + pos_for_input)
        summaries.append(tokens.mean(dim=1))
        token_sets.append(tokens)
        grid_sizes.append(grid)
    fused = encoder.fusion_proj(torch.cat(summaries, dim=1).to(device))
    return encoder.norm(fused).unsqueeze(1), token_sets, grid_sizes


# tokens -> feature map
def tokens_to_map(tokens):
    """Reshape tokens (B, N, C) into a square feature map (B, C, s, s) with s*s == N."""
    batch, count, channels = tokens.shape
    side = math.isqrt(count)
    return tokens.permute(0, 2, 1).reshape(batch, channels, side, side)

In [None]:
# FPN + PSP head (UPerNet)
class SimpleFPN(nn.Module):
    """Project every feature map to ``out_ch`` channels and sum them.

    Every level goes through its own 1x1 convolution. Levels whose spatial size
    differs from that of the first input are bilinearly resized to it before
    the sum.
    """

    def __init__(self,in_channels,out_ch=128):
        super().__init__()
        laterals = [nn.Conv2d(channels, out_ch, 1) for channels in in_channels]
        self.lats = nn.ModuleList(laterals)

    def forward(self,feats):
        """Fuse ``feats`` (a list of (B, C_i, H_i, W_i) maps) into one (B, out_ch, H_0, W_0) map."""
        height, width = feats[0].shape[2:]
        fused = 0
        for feat, lateral in zip(feats, self.lats):
            level = lateral(feat)
            if level.shape[2:] != (height, width):
                level = F.interpolate(level, size=(height, width), mode='bilinear', align_corners=False)
            fused = fused + level
        return fused

class PSPModule(nn.Module):
    """Pyramid pooling: pool at several bin sizes, resize, concatenate and fuse.

    Each stage is adaptive average pooling to ``s`` x ``s`` followed by a 1x1
    convolution and ReLU. The stage outputs are bilinearly resized to the input
    resolution, concatenated with the input and reduced to ``out_ch`` channels
    by a 3x3 convolution.
    """

    def __init__(self,in_ch,pool_sizes=(1,2,3,6),out_ch=128):
        super().__init__()
        branches = []
        for bins in pool_sizes:
            branches.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bins),
                nn.Conv2d(in_ch, out_ch, 1),
                nn.ReLU(inplace=True),
            ))
        self.stages = nn.ModuleList(branches)
        concat_ch = in_ch + len(pool_sizes)*out_ch
        self.bottleneck = nn.Conv2d(concat_ch, out_ch, 3, padding=1)

    def forward(self, x):
        """Return a (B, out_ch, H, W) map for an input of shape (B, in_ch, H, W)."""
        height, width = x.shape[2:]
        pyramid = [
            F.interpolate(stage(x), size=(height, width), mode='bilinear', align_corners=False)
            for stage in self.stages
        ]
        return self.bottleneck(torch.cat([x, *pyramid], dim=1))

class UPerNetHead(nn.Module):
    """Segmentation head: FPN fusion, pyramid pooling, conv block and classifier.

    Args:
        in_channels_list: Channel count of each token set passed to ``forward``.
        out_ch: Working channel width of the head.
        num_classes: Number of output classes.
    """

    def __init__(self, in_channels_list, out_ch=128, num_classes=21):
        super().__init__()
        self.fpn = SimpleFPN(in_channels_list, out_ch=out_ch)
        self.psp = PSPModule(out_ch, pool_sizes=(1,2,3,6), out_ch=out_ch)
        refine = [
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv_last = nn.Sequential(*refine)
        self.classifier = nn.Conv2d(out_ch, num_classes, 1)

    def forward(self, outs_tokens, target_size):
        """Return class logits of shape (B, num_classes, *target_size).

        ``outs_tokens`` is a list of (B, N_i, C_i) token tensors; each is turned
        into a square feature map with ``tokens_to_map`` before fusion.
        """
        feature_maps = list(map(tokens_to_map, outs_tokens))
        features = self.conv_last(self.psp(self.fpn(feature_maps)))
        # Upsample the features first, then classify at the target resolution.
        features = F.interpolate(features, size=target_size, mode='bilinear', align_corners=False)
        return self.classifier(features)

In [None]:
# FakeData segmentation dataset
def build_fake_seg_loaders(batch_size=8, img_size=64, num_classes=21, n_train=500, n_val=100):
    """Build train/val loaders of synthetic segmentation data from ``FakeData``.

    Every sample is a random image whose segmentation mask is filled with the
    image's single class label.

    Args:
        batch_size: Batch size of both loaders.
        img_size: Side length of the generated images and masks.
        num_classes: Number of classes to draw labels from.
        n_train: Number of training samples.
        n_val: Number of validation samples.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    tf_img = transforms.Compose([transforms.Resize((img_size,img_size)), transforms.ToTensor()])

    class FakeSeg(torch.utils.data.Dataset):
        def __init__(self, n=500, split='train'):
            # FakeData seeds each sample from its index plus random_offset.
            # With the same offset the val split would replay the first
            # training samples, so the val indices start after the train range.
            offset = 0 if split == 'train' else n_train
            self.ds = FakeData(size=n, image_size=(3,img_size,img_size),
                               num_classes=num_classes, transform=tf_img,
                               random_offset=offset)

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, idx):
            img, target = self.ds[idx]
            # The target is a single class, so the whole mask is filled with it.
            mask = torch.full((img_size,img_size), int(target), dtype=torch.long)
            return img, mask

    tr = DataLoader(FakeSeg(n_train,'train'), batch_size=batch_size, shuffle=True)
    vl = DataLoader(FakeSeg(n_val,'val'), batch_size=batch_size, shuffle=False)
    return tr, vl

In [None]:
# Training loop
def run_finetune_upernet_fake(checkpoint_path="./checkpoints/dph_mae_epoch20.pth",
                              epochs=3, batch_size=8, lr=1e-4, img_size=64):
    """Fine-tune the pretrained encoder with a UPerNet head on synthetic data.

    The encoder weights are loaded non-strictly from ``checkpoint_path`` (from
    its ``'model'`` entry when present, otherwise from the checkpoint itself).
    Encoder and head are trained together with Adam and pixel-wise
    cross-entropy. The final weights are saved to
    ``upernet_finetuned_fake.pth``. The validation loader is built but not used.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    # Restore the pretrained model and keep only its encoder.
    pretrained = DPHMAE_CIFAR()
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    pretrained.load_state_dict(checkpoint.get('model', checkpoint), strict=False)
    encoder = pretrained.encoder.to(device)

    # One input channel count per encoder scale.
    head = UPerNetHead(in_channels_list=[64,96,128], out_ch=128, num_classes=21).to(device)

    class FullSegModel(nn.Module):
        """Encoder (with positional-embedding interpolation) followed by the head."""

        def __init__(self, encoder, head):
            super().__init__()
            self.encoder = encoder
            self.head = head

        def forward(self, x):
            _fused, token_sets, _grids = encode_with_pos_interp(self.encoder, x)
            height, width = x.shape[2], x.shape[3]
            return self.head(token_sets, target_size=(height, width))

    model = FullSegModel(encoder, head).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    tr_loader, vl_loader = build_fake_seg_loaders(batch_size=batch_size, img_size=img_size)

    for ep in range(1, epochs+1):
        model.train()
        total_loss = 0
        n = 0
        for imgs, masks in tr_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            loss = F.cross_entropy(model(imgs), masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count = imgs.size(0)
            total_loss += loss.item()*count
            n += count
        print(f"[Epoch {ep}] Train Loss: {total_loss/n:.4f}")

    torch.save({'model': model.state_dict()}, "upernet_finetuned_fake.pth")
    print("✅ Fine-tuning finished. Model saved as upernet_finetuned_fake.pth")

In [None]:
# Fine-tune the pretrained encoder with the UPerNet head on 64x64 FakeData for 3 epochs.
run_finetune_upernet_fake("./checkpoints/dph_mae_epoch20.pth", epochs=3, batch_size=8, img_size=64)

Device: cuda
[Epoch 1] Train Loss: 3.0836
[Epoch 2] Train Loss: 3.0422
[Epoch 3] Train Loss: 2.9528
✅ Fine-tuning finished. Model saved as upernet_finetuned_fake.pth
