# MAE Downstream Tasks with PyTorch Lightning

Three downstream tasks built on a pretrained MAE encoder:
1. Classification - CIFAR-10
2. Object Detection - Oxford-IIIT Pet
3. Semantic Segmentation - Oxford-IIIT Pet

In [None]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torchmetrics import Accuracy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.patches as patches

# Select the compute device once; later cells move models and tensors onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# The project root is the working directory when it contains "checkpoints/",
# otherwise the parent of the working directory.
PROJECT_ROOT = Path.cwd()
if not (PROJECT_ROOT / "checkpoints").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
WEIGHT_PATH = PROJECT_ROOT / "checkpoints" / "mae_pretrain_vit_base.pth"
LOG_DIR = PROJECT_ROOT / "logs"


TypeError: expected string or bytes-like object

In [None]:

class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, channels)`` input.

    Args:
        dim: Channel width of the input and output tokens.
        num_heads: Number of attention heads; ``dim`` is split evenly across them.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # Scaling factor for the dot products: (per-head width) ** -0.5.
        self.scale = (dim // num_heads) ** -0.5
        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # Reorder to (3, batch, heads, tokens, per_head) and split into q / k / v.
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # Merge the heads back into the channel dimension.
        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_features: Width of the input and of the output.
        hidden_features: Width of the intermediate layer.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.act = nn.GELU()

    def forward(self, x):
        """Expand to the hidden width, apply GELU, and project back."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, hidden_dim)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        fed_forward = self.mlp(self.norm2(x))
        return x + fed_forward

class _PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Defined at module level (not inside ``MAEEncoder.__init__``) so that
    encoder instances can be pickled; a locally defined class cannot be.
    """

    def __init__(self, in_chans, embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x)


class MAEEncoder(nn.Module):
    """ViT-style encoder: patch embedding, CLS token, learned positions, transformer blocks.

    Parameter names (``patch_embed.proj``, ``cls_token``, ``pos_embed``,
    ``blocks.N``, ``norm``) are what ``load_state_dict`` matches against, so
    they must not be renamed.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Token channel width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = _PatchEmbed(in_chans, embed_dim, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode images into tokens.

        Args:
            x: Image batch of shape ``(batch, in_chans, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)``; index 0
            along the token dimension is the CLS token.
        """
        x = self.patch_embed(x)
        # (batch, embed_dim, H', W') -> (batch, H' * W', embed_dim)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

def load_mae_encoder(weight_path):
    """Build an ``MAEEncoder`` and load pretrained weights into it.

    The checkpoint may be a bare state dict or a dict that wraps one under
    ``'model'`` or ``'state_dict'``. Leading ``module.`` / ``model.`` /
    ``net.`` prefixes are stripped from parameter names before loading.

    Args:
        weight_path: Path to the checkpoint file.

    Returns:
        The encoder with every matching parameter loaded. Loading is
        non-strict: extra checkpoint entries are ignored, and a warning is
        issued when encoder parameters are absent from the checkpoint
        (those keep their initial values).
    """
    import warnings

    model = MAEEncoder()
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    ckpt = torch.load(weight_path, map_location="cpu", weights_only=False)
    if 'model' in ckpt:
        state_dict = ckpt['model']
    elif 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    else:
        state_dict = ckpt

    # Strip wrapper prefixes one at a time, in this fixed order.
    for prefix in ["module.", "model.", "net."]:
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        # Without this, a mismatched checkpoint would silently leave the
        # encoder (partly) at its initial values.
        preview = ", ".join(load_result.missing_keys[:5])
        warnings.warn(
            f"{len(load_result.missing_keys)} encoder parameter(s) were not found in "
            f"{weight_path} and keep their initial values (e.g. {preview})"
        )
    return model

# Build the encoder from the pretrained checkpoint. This single instance is
# passed to each downstream model (classifier, detector, segmenter) below.
encoder = load_mae_encoder(WEIGHT_PATH)


# 1. Classification (CIFAR-10)


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# CIFAR-10 DataModule
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 images resized to 224x224.

    The official training set is split 90/10 into train/validation subsets;
    the official test set is used for testing.

    Args:
        data_dir: Directory where CIFAR-10 is stored or downloaded to.
        batch_size: Batch size used by all dataloaders.
        num_workers: Worker processes per dataloader.
        seed: Seed for the train/validation split. A fixed seed makes the
            split identical every time ``setup`` runs, so validation samples
            never end up in the training subset of an earlier call.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=0, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits if they are not present yet."""
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'validate'`` is handled like ``'fit'`` so that a standalone
        ``trainer.validate(...)`` finds ``val_dataset``.
        """
        if stage in ('fit', 'validate') or stage is None:
            full_train = datasets.CIFAR10(self.data_dir, train=True, transform=self.transform)
            train_size = int(0.9 * len(full_train))
            val_size = len(full_train) - train_size
            split_rng = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, val_size], generator=split_rng
            )

        if stage == 'test' or stage is None:
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

# Classification Model
class MAEClassifier(pl.LightningModule):
    """Image classifier: a linear head on the encoder's CLS token.

    Args:
        encoder: Module whose output is indexed as ``[:, 0]`` to obtain a
            768-wide CLS feature per image.
        num_classes: Number of output classes.
        freeze_encoder: When True, ``requires_grad`` is switched off on all
            encoder parameters so that only the head is trained.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, encoder, num_classes=10, freeze_encoder=True, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.head = nn.Linear(768, num_classes)
        self.lr = lr
        self.weight_decay = weight_decay

        self.train_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.encoder(x)
        cls_token = features[:, 0]
        return self.head(cls_token)

    def _loss_and_accuracy(self, batch, metric):
        """Run one batch; return the cross-entropy loss and ``metric``'s value for it."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, metric(preds, y)

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch, self.train_acc)
        self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch, self.val_acc)
        self.log('val/loss', loss, on_epoch=True, prog_bar=True)
        self.log('val/acc', acc, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch, self.test_acc)
        self.log('test/loss', loss)
        self.log('test/acc', acc)
        return loss

    def configure_optimizers(self):
        """Adam over every trainable parameter, with cosine LR annealing.

        Selecting by ``requires_grad`` yields just the head when the encoder
        is frozen, and also includes the encoder when ``freeze_encoder=False``
        (previously the encoder was never optimized in that case).
        """
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val/loss'
            }
        }


In [None]:
# Classification training
data_module = CIFAR10DataModule(data_dir=str(PROJECT_ROOT / "data"), batch_size=32, num_workers=0)
model = MAEClassifier(encoder=encoder, num_classes=10, freeze_encoder=True, lr=1e-3, weight_decay=1e-4)

tensorboard_logger = TensorBoardLogger(LOG_DIR, name="mae_classification")
csv_logger = CSVLogger(LOG_DIR, name="mae_classification")

# The monitored metric name contains "/". With automatic metric-name insertion
# that slash would land in the checkpoint file name and create nested
# directories, so the names are written out by hand instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=LOG_DIR / "checkpoints",
    filename="mae-epoch{epoch:02d}-val_acc{val/acc:.4f}",
    monitor="val/acc",
    mode="max",
    save_top_k=3,
    auto_insert_metric_name=False
)

early_stop_callback = EarlyStopping(monitor="val/loss", patience=20, mode="min", verbose=True)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto",
    devices=1,
    logger=[tensorboard_logger, csv_logger],
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    log_every_n_steps=10,
    enable_progress_bar=True
)

print("\nStarting Classification training...")
trainer.fit(model, data_module)
print("Training completed!")

print("\nEvaluating on validation set...")
trainer.validate(model, data_module)
print("Validation completed!")

print("\nEvaluating on test set...")
trainer.test(model, data_module)
print("Test completed!")


In [None]:
# Classification visualization
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

model = model.to(device)
model.eval()
test_loader = data_module.test_dataloader()

all_preds, all_labels = [], []
sample_images, sample_preds, sample_labels = [], [], []

# Predict over the whole test loader; keep the start of the first batch for the image grid.
with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        x = x.to(device)
        logits = model(x)
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(y.numpy())
        if i == 0:
            sample_images = x[:16].cpu()
            sample_preds = preds[:16]
            sample_labels = y[:16].numpy()

all_preds, all_labels = np.array(all_preds), np.array(all_labels)

fig = plt.figure(figsize=(16, 10))
fig.suptitle('Classification Test Results', fontsize=16, fontweight='bold')
# Rows 0-1 hold the sample images, row 2 holds the confusion matrix.
gs = fig.add_gridspec(3, 6, hspace=0.4, wspace=0.3)

# The grid has 12 image slots; draw fewer if the first batch was smaller.
for idx in range(min(12, len(sample_images))):
    ax = fig.add_subplot(gs[idx // 6, idx % 6])
    img = sample_images[idx].permute(1, 2, 0).numpy()
    # Undo Normalize(mean=0.5, std=0.5) for display.
    img = (img * 0.5 + 0.5).clip(0, 1)
    ax.imshow(img)
    pred_name = class_names[sample_preds[idx]]
    true_name = class_names[sample_labels[idx]]
    color = 'green' if sample_preds[idx] == sample_labels[idx] else 'red'
    ax.set_title(f'Pred: {pred_name}\nTrue: {true_name}', fontsize=9, color=color)
    ax.axis('off')

ax_cm = fig.add_subplot(gs[2, :])
# Pin the label set so the matrix always lines up with the tick labels,
# even if some class never appears in the labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax_cm, cbar_kws={'shrink': 0.8})
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('True')
ax_cm.set_title('Confusion Matrix')

# savefig does not create missing directories.
(PROJECT_ROOT / "logs").mkdir(parents=True, exist_ok=True)
plt.savefig(PROJECT_ROOT / "logs" / "classification_test_results.png", dpi=150, bbox_inches='tight')
plt.show()

accuracy = (all_preds == all_labels).mean() * 100
print(f"\nTest Accuracy: {accuracy:.2f}%")
print(f"Total samples: {len(all_labels)}")


# 2. Object Detection (Oxford-IIIT Pet)


In [None]:
from torchvision.datasets import OxfordIIITPet

class PetDetectionDataset(torch.utils.data.Dataset):
    """Oxford-IIIT Pet samples as (image, bounding box, class index).

    The bounding box is derived from the segmentation trimap: it is the
    tightest box around all pixels that are not background.

    Args:
        root: Dataset root directory (downloaded there if missing).
        split: Dataset split name passed to ``OxfordIIITPet``.
        transform: Image transform returning a ``(C, H, W)`` tensor; defaults
            to resize-to-224, ``ToTensor`` and normalization.
    """

    def __init__(self, root, split='trainval', transform=None):
        self.dataset_seg = OxfordIIITPet(root=root, split=split, target_types='segmentation', download=True)
        self.dataset_cat = OxfordIIITPet(root=root, split=split, target_types='category', download=False)
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.num_classes = 37

    def __len__(self):
        return len(self.dataset_seg)

    def __getitem__(self, idx):
        """Return ``(img_tensor, bbox, class_idx)`` for sample ``idx``.

        ``bbox`` is ``[x_min, y_min, x_max, y_max]`` (float32) in pixel
        coordinates of the transformed image, with exclusive max edges.
        """
        img, mask = self.dataset_seg[idx]
        _, class_idx = self.dataset_cat[idx]

        mask_np = np.array(mask)
        # Trimap value 2 is background (PetSegmentationDataset maps it to 0
        # as well). The former `mask > 0` test matched every trimap pixel, so
        # the box always covered the whole image.
        rows, cols = np.where((mask_np > 0) & (mask_np != 2))
        if len(rows) > 0:
            # +1 turns the last foreground pixel index into an exclusive edge,
            # consistent with the (w, h) fallback below.
            y_min, y_max = rows.min(), rows.max() + 1
            x_min, x_max = cols.min(), cols.max() + 1
        else:
            # No foreground pixels: fall back to the whole image.
            h, w = mask_np.shape
            x_min, y_min, x_max, y_max = 0, 0, w, h

        orig_w, orig_h = img.size
        img_tensor = self.transform(img)
        # Scale to the size the transform actually produced (224x224 with the
        # default transform) rather than a hard-coded constant.
        out_h, out_w = img_tensor.shape[-2:]
        scale_x = out_w / orig_w
        scale_y = out_h / orig_h
        bbox = torch.tensor([
            x_min * scale_x,
            y_min * scale_y,
            x_max * scale_x,
            y_max * scale_y
        ], dtype=torch.float32)

        return img_tensor, bbox, class_idx


In [None]:
# Pet Detection Model
class PetDetector(pl.LightningModule):
    """Single-box detector: box-regression and classification heads on the CLS token.

    Args:
        encoder: Module whose output is indexed as ``[:, 0]`` to obtain a
            768-wide CLS feature per image.
        num_classes: Number of pet categories.
        freeze_encoder: When True, ``requires_grad`` is switched off on all
            encoder parameters so that only the two heads are trained.
        lr: Adam learning rate.
    """

    def __init__(self, encoder, num_classes=37, freeze_encoder=True, lr=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.bbox_head = nn.Sequential(
            nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 4)
        )
        self.class_head = nn.Sequential(
            nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )
        self.lr = lr

    def forward(self, x):
        """Return ``(bbox, class_logits)`` with shapes ``(batch, 4)`` and ``(batch, num_classes)``."""
        features = self.encoder(x)
        cls_token = features[:, 0]
        bbox = self.bbox_head(cls_token)
        cls = self.class_head(cls_token)
        return bbox, cls

    def _losses(self, batch):
        """Run one batch; return ``(bbox_loss, cls_loss, total_loss, class_logits, gt_labels)``."""
        x, gt_boxes, gt_labels = batch
        pred_boxes, pred_cls = self(x)
        # Boxes are regressed directly in the units of the ground-truth boxes.
        bbox_loss = F.mse_loss(pred_boxes, gt_boxes)
        cls_loss = F.cross_entropy(pred_cls, gt_labels)
        return bbox_loss, cls_loss, bbox_loss + cls_loss, pred_cls, gt_labels

    def training_step(self, batch, batch_idx):
        bbox_loss, cls_loss, loss, _, _ = self._losses(batch)
        self.log('train_det/bbox_loss', bbox_loss, prog_bar=True)
        self.log('train_det/cls_loss', cls_loss, prog_bar=True)
        self.log('train_det/total_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        bbox_loss, cls_loss, loss, pred_cls, gt_labels = self._losses(batch)
        preds = torch.argmax(pred_cls, dim=1)
        acc = (preds == gt_labels).float().mean()
        self.log('val_det/bbox_loss', bbox_loss)
        self.log('val_det/cls_loss', cls_loss)
        self.log('val_det/total_loss', loss, prog_bar=True)
        self.log('val_det/acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over every trainable parameter.

        Selecting by ``requires_grad`` yields just the two heads when the
        encoder is frozen, and also includes the encoder when
        ``freeze_encoder=False`` (previously it was never optimized).
        """
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr, weight_decay=1e-4)


In [None]:
# Detection training
from torch.utils.data import random_split

pet_det_dataset = PetDetectionDataset(root=str(PROJECT_ROOT / "data"), split='trainval')
# 80/20 train/validation split of the chosen dataset split.
train_size = int(0.8 * len(pet_det_dataset))
val_size = len(pet_det_dataset) - train_size
pet_det_train, pet_det_val = random_split(pet_det_dataset, [train_size, val_size])

pet_det_train_loader = DataLoader(pet_det_train, batch_size=16, shuffle=True, num_workers=0)
pet_det_val_loader = DataLoader(pet_det_val, batch_size=16, shuffle=False, num_workers=0)

print(f"Train samples: {train_size}, Val samples: {val_size}")

pet_detector = PetDetector(encoder=encoder, num_classes=37, freeze_encoder=True, lr=1e-3)

pet_det_trainer = pl.Trainer(
    max_epochs=50,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger(LOG_DIR, name="pet_detection"),
    callbacks=[
        # The monitored metric name contains "/". With automatic metric-name
        # insertion that slash would land in the checkpoint file name and
        # create nested directories, so the names are written out by hand.
        ModelCheckpoint(dirpath=LOG_DIR / "checkpoints_pet_det",
                       filename="pet-detector-epoch{epoch:02d}-val_total_loss{val_det/total_loss:.4f}",
                       monitor="val_det/total_loss", mode="min", save_top_k=3,
                       auto_insert_metric_name=False),
        EarlyStopping(monitor="val_det/total_loss", patience=10, mode="min", verbose=True),
        LearningRateMonitor(logging_interval='epoch')
    ],
    log_every_n_steps=10,
    enable_progress_bar=True
)

print("\nStarting Pet Detection training...")
pet_det_trainer.fit(pet_detector, pet_det_train_loader, pet_det_val_loader)
print("Pet Detection training completed!")

print("\nEvaluating on validation set...")
pet_det_trainer.validate(pet_detector, pet_det_val_loader)
print("Pet Detection evaluation completed!")


In [None]:
# Detection test visualization
pet_detector = pet_detector.to(device)
pet_detector.eval()

test_loader = DataLoader(pet_det_val, batch_size=8, shuffle=True, num_workers=0)

with torch.no_grad():
    # Only the first (randomly drawn) batch is visualized; see the `break` below.
    for x, gt_boxes, gt_labels in test_loader:
        x = x.to(device)
        pred_boxes, pred_cls = pet_detector(x)

        x = x.cpu()
        pred_boxes = pred_boxes.cpu()
        pred_cls = torch.argmax(pred_cls, dim=1).cpu()

        fig, axes = plt.subplots(2, 4, figsize=(20, 10))
        fig.suptitle('Pet Detection Test Results', fontsize=16, fontweight='bold')
        axes = axes.flatten()

        # Constants used to undo the dataset's Normalize(mean, std) for display.
        norm_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        norm_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

        # The figure has 8 panels; draw fewer if the batch is smaller.
        num_shown = min(8, x.shape[0])
        for idx in range(num_shown):
            ax = axes[idx]
            img = x[idx].numpy()
            img = img * norm_std + norm_mean
            img = np.clip(img, 0, 1).transpose(1, 2, 0)
            ax.imshow(img)

            gt_box = gt_boxes[idx].numpy()
            pred_box = pred_boxes[idx].numpy()

            # Boxes are [x_min, y_min, x_max, y_max]; Rectangle takes (x, y), width, height.
            rect_gt = patches.Rectangle((gt_box[0], gt_box[1]), gt_box[2]-gt_box[0], gt_box[3]-gt_box[1],
                                        linewidth=2, edgecolor='green', facecolor='none', label='GT')
            rect_pred = patches.Rectangle((pred_box[0], pred_box[1]), pred_box[2]-pred_box[0], pred_box[3]-pred_box[1],
                                          linewidth=2, edgecolor='red', facecolor='none', linestyle='--', label='Pred')
            ax.add_patch(rect_gt)
            ax.add_patch(rect_pred)

            gt_class = gt_labels[idx].item()
            pred_class = pred_cls[idx].item()
            color = 'green' if gt_labels[idx] == pred_cls[idx] else 'red'
            ax.set_title(f'GT: class {gt_class}\nPred: class {pred_class}', fontsize=10, color=color)
            ax.axis('off')
            if idx == 0:
                ax.legend(loc='upper right', fontsize=8)

        # Hide panels left empty by a short batch.
        for ax in axes[num_shown:]:
            ax.axis('off')

        plt.tight_layout()
        # savefig does not create missing directories.
        (PROJECT_ROOT / "logs").mkdir(parents=True, exist_ok=True)
        plt.savefig(PROJECT_ROOT / "logs" / "pet_detection_test_results.png", dpi=150, bbox_inches='tight')
        plt.show()
        break


# 3. Semantic Segmentation (Oxford-IIIT Pet)


In [None]:
# Pet Segmentation Dataset
class PetSegmentationDataset(torch.utils.data.Dataset):
    """Oxford-IIIT Pet samples as (image tensor, binary mask).

    Mask value 2 becomes 0 and every remaining positive value becomes 1.

    Args:
        root: Dataset root directory (downloaded there if missing).
        split: Dataset split name passed to ``OxfordIIITPet``.
    """

    def __init__(self, root, split='trainval'):
        self.dataset = OxfordIIITPet(root=root, split=split, target_types='segmentation', download=True)

        target_size = (224, 224)
        image_steps = [
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        # Nearest-neighbour resizing keeps mask values as the original integer labels.
        mask_steps = [
            transforms.Resize(target_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.PILToTensor(),
        ]
        self.img_transform = transforms.Compose(image_steps)
        self.mask_transform = transforms.Compose(mask_steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(img_tensor, mask_tensor)``; the mask is an int64 tensor of 0/1 labels."""
        image, raw_mask = self.dataset[idx]

        labels = self.mask_transform(raw_mask).squeeze(0).long()
        # First send label 2 to 0, then collapse every remaining positive label to 1.
        labels = torch.where(labels == 2, torch.zeros_like(labels), labels)
        labels = torch.where(labels > 0, torch.ones_like(labels), labels)

        return self.img_transform(image), labels


In [None]:
# Pet Segmentation Model
class PetSegmenter(pl.LightningModule):
    """Segmentation model: transposed-convolution decoder on the encoder's patch tokens.

    Args:
        encoder: Module returning ``(batch, 1 + num_patches, 768)`` tokens
            with the CLS token at index 0.
        num_classes: Number of mask classes (channels of the output map).
        freeze_encoder: When True, ``requires_grad`` is switched off on all
            encoder parameters so that only the decoder head is trained.
        lr: Adam learning rate.
    """

    def __init__(self, encoder, num_classes=2, freeze_encoder=True, lr=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Four stride-2 transposed convolutions: each doubles the spatial
        # size, 16x in total.
        self.seg_head = nn.Sequential(
            nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2),
            nn.BatchNorm2d(384), nn.ReLU(),
            nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2),
            nn.BatchNorm2d(192), nn.ReLU(),
            nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2),
            nn.BatchNorm2d(96), nn.ReLU(),
            nn.ConvTranspose2d(96, num_classes, kernel_size=2, stride=2)
        )
        self.lr = lr
        self.num_classes = num_classes

    def forward(self, x):
        """Return per-pixel class logits of shape ``(batch, num_classes, 16*H, 16*W)``."""
        features = self.encoder(x)
        # Drop the CLS token; the rest are patch tokens.
        patch_features = features[:, 1:]
        B, N, C = patch_features.shape
        # Assumes the patch tokens form a square grid.
        H = W = int(N ** 0.5)
        patch_features = patch_features.transpose(1, 2).reshape(B, C, H, W)
        seg_map = self.seg_head(patch_features)
        return seg_map

    def _loss_and_accuracy(self, batch):
        """Run one batch; return the cross-entropy loss and the pixel accuracy."""
        x, mask = batch
        seg_output = self(x)
        loss = F.cross_entropy(seg_output, mask)
        pred_mask = seg_output.argmax(dim=1)
        acc = (pred_mask == mask).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch)
        self.log('train_seg/loss', loss, prog_bar=True)
        self.log('train_seg/acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._loss_and_accuracy(batch)
        self.log('val_seg/loss', loss, prog_bar=True)
        self.log('val_seg/acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over every trainable parameter.

        Selecting by ``requires_grad`` yields just the decoder head when the
        encoder is frozen, and also includes the encoder when
        ``freeze_encoder=False`` (previously it was never optimized).
        """
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr, weight_decay=1e-4)


In [None]:
# Segmentation training
pet_seg_dataset = PetSegmentationDataset(root=str(PROJECT_ROOT / "data"), split='trainval')
# 80/20 train/validation split of the chosen dataset split.
train_size = int(0.8 * len(pet_seg_dataset))
val_size = len(pet_seg_dataset) - train_size
pet_seg_train, pet_seg_val = random_split(pet_seg_dataset, [train_size, val_size])

pet_seg_train_loader = DataLoader(pet_seg_train, batch_size=8, shuffle=True, num_workers=0)
pet_seg_val_loader = DataLoader(pet_seg_val, batch_size=8, shuffle=False, num_workers=0)

print(f"Train samples: {train_size}, Val samples: {val_size}")

pet_segmenter = PetSegmenter(encoder=encoder, num_classes=2, freeze_encoder=True, lr=1e-3)

pet_seg_trainer = pl.Trainer(
    max_epochs=50,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger(LOG_DIR, name="pet_segmentation"),
    callbacks=[
        # The monitored metric name contains "/". With automatic metric-name
        # insertion that slash would land in the checkpoint file name and
        # create nested directories, so the names are written out by hand.
        ModelCheckpoint(dirpath=LOG_DIR / "checkpoints_pet_seg",
                       filename="pet-segmenter-epoch{epoch:02d}-val_loss{val_seg/loss:.4f}",
                       monitor="val_seg/loss", mode="min", save_top_k=3,
                       auto_insert_metric_name=False),
        EarlyStopping(monitor="val_seg/loss", patience=10, mode="min", verbose=True),
        LearningRateMonitor(logging_interval='epoch')
    ],
    log_every_n_steps=10,
    enable_progress_bar=True
)

print("\nStarting Pet Segmentation training...")
pet_seg_trainer.fit(pet_segmenter, pet_seg_train_loader, pet_seg_val_loader)
print("Pet Segmentation training completed!")

print("\nEvaluating on validation set...")
pet_seg_trainer.validate(pet_segmenter, pet_seg_val_loader)
print("Pet Segmentation evaluation completed!")


In [None]:
# Segmentation test visualization
pet_segmenter = pet_segmenter.to(device)
pet_segmenter.eval()

test_loader = DataLoader(pet_seg_val, batch_size=8, shuffle=True, num_workers=0)

with torch.no_grad():
    # Only the first (randomly drawn) batch is visualized; see the `break` below.
    for x, gt_mask in test_loader:
        x = x.to(device)
        pred_seg = pet_segmenter(x)

        x = x.cpu()
        pred_mask = torch.argmax(pred_seg, dim=1).cpu()
        gt_mask = gt_mask.cpu()

        # Rows: input image, ground-truth mask, predicted mask.
        fig, axes = plt.subplots(3, 8, figsize=(24, 9))
        fig.suptitle('Pet Segmentation Test Results', fontsize=16, fontweight='bold')

        # Constants used to undo the dataset's Normalize(mean, std) for display.
        norm_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        norm_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

        # The figure has 8 columns; draw fewer if the batch is smaller.
        num_shown = min(8, x.shape[0])
        for idx in range(num_shown):
            img = x[idx].numpy()
            img = img * norm_std + norm_mean
            img = np.clip(img, 0, 1).transpose(1, 2, 0)

            axes[0, idx].imshow(img)
            axes[0, idx].set_title('Input', fontsize=10)
            axes[0, idx].axis('off')

            axes[1, idx].imshow(gt_mask[idx], cmap='gray', vmin=0, vmax=1)
            axes[1, idx].set_title('GT Mask', fontsize=10)
            axes[1, idx].axis('off')

            axes[2, idx].imshow(pred_mask[idx], cmap='gray', vmin=0, vmax=1)
            acc = ((pred_mask[idx] == gt_mask[idx]).float().mean() * 100).item()
            axes[2, idx].set_title(f'Pred (Acc: {acc:.1f}%)', fontsize=10)
            axes[2, idx].axis('off')

        # Hide columns left empty by a short batch.
        for ax in axes[:, num_shown:].flatten():
            ax.axis('off')

        plt.tight_layout()
        # savefig does not create missing directories.
        (PROJECT_ROOT / "logs").mkdir(parents=True, exist_ok=True)
        plt.savefig(PROJECT_ROOT / "logs" / "pet_segmentation_test_results.png", dpi=150, bbox_inches='tight')
        plt.show()
        break

# NOTE: despite its name, `mean_iou` holds the pixel accuracy (in percent) of
# the single batch shown above, not an IoU. The name is kept unchanged in case
# other cells read it.
mean_iou = ((pred_mask == gt_mask).float().mean() * 100).item()
print(f"Mean Pixel Accuracy: {mean_iou:.2f}%")


# Summary

## Three Downstream Tasks

### 1. Image Classification (CIFAR-10)
- Dataset: 50,000 train (45,000 train / 5,000 val), 10,000 test
- Model: CLS token → Linear(768→10)
- Epochs: 200 (Early Stop: patience=20)
- Log: logs/mae_classification/

### 2. Object Detection (Oxford-IIIT Pet)
- Dataset: `trainval` split of Oxford-IIIT Pet, divided 80/20 into train/val (full dataset: ~7,349 images; auto-download, ~800MB)
- Classes: 37 classes
- Model: CLS token → {Bbox(4), Class(37)}
- Epochs: 50 (Early Stop: patience=10)
- Log: logs/pet_detection/

### 3. Semantic Segmentation (Oxford-IIIT Pet)
- Dataset: ~7,349 images with segmentation masks
- Classes: 2 classes (background, foreground)
- Model: Spatial(768×14×14) → 4 × ConvTranspose2d (2× each, 16× total upsampling) → Mask(2×224×224)
- Epochs: 50 (Early Stop: patience=10)
- Log: logs/pet_segmentation/

## TensorBoard

```bash
tensorboard --logdir=logs
```

Browser: http://localhost:6006

Available tasks:
- mae_classification
- pet_detection
- pet_segmentation


In [None]:
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torchmetrics import Accuracy

# Select the compute device once; later cells move models and tensors onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# The project root is the working directory when it contains "checkpoints/",
# otherwise the parent of the working directory.
PROJECT_ROOT = Path.cwd()
if not (PROJECT_ROOT / "checkpoints").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
WEIGHT_PATH = PROJECT_ROOT / "checkpoints" / "mae_pretrain_vit_base.pth"
LOG_DIR = PROJECT_ROOT / "logs"


In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, channels)`` input.

    Args:
        dim: Channel width of the input and output tokens.
        num_heads: Number of attention heads; ``dim`` is split evenly across them.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # Scaling factor for the dot products: (per-head width) ** -0.5.
        self.scale = (dim // num_heads) ** -0.5
        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # Reorder to (3, batch, heads, tokens, per_head) and split into q / k / v.
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # Merge the heads back into the channel dimension.
        merged = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_features: Width of the input and of the output.
        hidden_features: Width of the intermediate layer.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.act = nn.GELU()

    def forward(self, x):
        """Expand to the hidden width, apply GELU, and project back."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, hidden_dim)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        fed_forward = self.mlp(self.norm2(x))
        return x + fed_forward

class _PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Defined at module level (not inside ``MAEEncoder.__init__``) so that
    encoder instances can be pickled; a locally defined class cannot be.
    """

    def __init__(self, in_chans, embed_dim, patch_size):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x)


class MAEEncoder(nn.Module):
    """ViT-style encoder: patch embedding, CLS token, learned positions, transformer blocks.

    Parameter names (``patch_embed.proj``, ``cls_token``, ``pos_embed``,
    ``blocks.N``, ``norm``) are what ``load_state_dict`` matches against, so
    they must not be renamed.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Token channel width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = _PatchEmbed(in_chans, embed_dim, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode images into tokens.

        Args:
            x: Image batch of shape ``(batch, in_chans, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)``; index 0
            along the token dimension is the CLS token.
        """
        x = self.patch_embed(x)
        # (batch, embed_dim, H', W') -> (batch, H' * W', embed_dim)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

def load_mae_encoder(weight_path):
    """Build an ``MAEEncoder`` and load pretrained weights into it.

    The checkpoint may be a bare state dict or a dict that wraps one under
    ``'model'`` or ``'state_dict'``. Leading ``module.`` / ``model.`` /
    ``net.`` prefixes are stripped from parameter names before loading.

    Args:
        weight_path: Path to the checkpoint file.

    Returns:
        The encoder with every matching parameter loaded. Loading is
        non-strict: extra checkpoint entries are ignored, and a warning is
        issued when encoder parameters are absent from the checkpoint
        (those keep their initial values).
    """
    import warnings

    model = MAEEncoder()
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    ckpt = torch.load(weight_path, map_location="cpu", weights_only=False)
    if 'model' in ckpt:
        state_dict = ckpt['model']
    elif 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    else:
        state_dict = ckpt

    # Strip wrapper prefixes one at a time, in this fixed order.
    for prefix in ["module.", "model.", "net."]:
        state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        # Without this, a mismatched checkpoint would silently leave the
        # encoder (partly) at its initial values.
        preview = ", ".join(load_result.missing_keys[:5])
        warnings.warn(
            f"{len(load_result.missing_keys)} encoder parameter(s) were not found in "
            f"{weight_path} and keep their initial values (e.g. {preview})"
        )
    return model


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 images resized to 224x224.

    The official training set is split 90/10 into train/validation subsets;
    the official test set is used for testing.

    Args:
        data_dir: Directory where CIFAR-10 is stored or downloaded to.
        batch_size: Batch size used by all dataloaders.
        num_workers: Worker processes per dataloader.
        seed: Seed for the train/validation split. A fixed seed makes the
            split identical every time ``setup`` runs, so validation samples
            never end up in the training subset of an earlier call.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=0, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits if they are not present yet."""
        datasets.CIFAR10(self.data_dir, train=True, download=True)
        datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``'validate'`` is handled like ``'fit'`` so that a standalone
        ``trainer.validate(...)`` finds ``val_dataset``.
        """
        if stage in ('fit', 'validate') or stage is None:
            full_train = datasets.CIFAR10(self.data_dir, train=True, transform=self.transform)
            train_size = int(0.9 * len(full_train))
            val_size = len(full_train) - train_size
            split_rng = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, val_size], generator=split_rng
            )

        if stage == 'test' or stage is None:
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                         shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                         shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                         shuffle=False, num_workers=self.num_workers)


In [None]:
class MAEClassifier(pl.LightningModule):
    """Image classifier: an encoder followed by a single linear head.

    The head is applied to the token at index 0 of the encoder output
    (used here as the CLS token).

    Args:
        encoder: Module whose output is indexed as ``features[:, 0]``; the
            head assumes a feature width of 768.
        num_classes: Number of target classes.
        freeze_encoder: If True, ``requires_grad`` is disabled on every
            encoder parameter so only the head is trained.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, encoder, num_classes=10, freeze_encoder=True, lr=1e-3, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.head = nn.Linear(768, num_classes)
        self.lr = lr
        self.weight_decay = weight_decay

        # One metric object per stage so their internal states never mix.
        self.train_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.encoder(x)
        cls_token = features[:, 0]
        return self.head(cls_token)

    def _shared_step(self, batch, metric):
        """Compute the cross-entropy loss and batch accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = metric(preds, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, self.train_acc)

        self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, self.val_acc)

        self.log('val/loss', loss, on_epoch=True, prog_bar=True)
        self.log('val/acc', acc, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, self.test_acc)

        self.log('test/loss', loss)
        self.log('test/acc', acc)
        return loss

    def configure_optimizers(self):
        """Adam over every trainable parameter, with cosine LR annealing."""
        # Optimise whatever still requires gradients: with a frozen encoder
        # this is exactly the head; with freeze_encoder=False it also covers
        # the encoder, which would otherwise never be updated.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val/loss'
            }
        }


In [None]:
# Build the encoder, data module, model, loggers and callbacks, then train
# the CIFAR-10 classifier.
encoder = load_mae_encoder(WEIGHT_PATH)
print("MAE Encoder loaded")

data_module = CIFAR10DataModule(
    data_dir=str(PROJECT_ROOT / "data"),
    batch_size=32,
    num_workers=0
)

model = MAEClassifier(
    encoder=encoder,
    num_classes=10,
    freeze_encoder=True,
    lr=1e-3,
    weight_decay=1e-4
)

tensorboard_logger = TensorBoardLogger(LOG_DIR, name="mae_classification")
csv_logger = CSVLogger(LOG_DIR, name="mae_classification")

# The monitored metric name contains "/". With automatic metric-name insertion
# that slash ends up in the file name and creates a nested "val/" directory,
# so the name is spelled out explicitly and auto-insertion is disabled.
checkpoint_callback = ModelCheckpoint(
    dirpath=LOG_DIR / "checkpoints",
    filename="mae-epoch={epoch:02d}-val_acc={val/acc:.4f}",
    monitor="val/acc",
    mode="max",
    save_top_k=3,
    auto_insert_metric_name=False
)

early_stop_callback = EarlyStopping(
    monitor="val/loss",
    patience=20,
    mode="min",
    verbose=True
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto",
    devices=1,
    logger=[tensorboard_logger, csv_logger],
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
    log_every_n_steps=10,
    enable_progress_bar=True
)

print("\nStarting training...")
trainer.fit(model, data_module)


In [None]:
# Evaluate on the CIFAR-10 test split and report where the logs were written.
print("\nTesting...")
trainer.test(model, data_module)

print(f"\nLogs saved to: {LOG_DIR}")
print(f"TensorBoard: tensorboard --logdir={LOG_DIR}")
# Ask the logger for its run directory instead of assuming "version_0";
# the version number increases every time the notebook is re-run.
print(f"CSV logs: {Path(csv_logger.log_dir) / 'metrics.csv'}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# rglob() yields paths in arbitrary filesystem order, so taking the last
# element does not reliably return the newest run. Select the most recently
# modified metrics file instead, and fail with a clear message if none exists.
metrics_candidates = list((LOG_DIR / 'mae_classification').rglob('metrics.csv'))
if not metrics_candidates:
    raise FileNotFoundError(f"No metrics.csv found under {LOG_DIR / 'mae_classification'}")
metrics_file = max(metrics_candidates, key=lambda path: path.stat().st_mtime)
df = pd.read_csv(metrics_file)

print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Accept both underscore-style and slash-style metric column names.
train_loss_col = 'train_loss_epoch' if 'train_loss_epoch' in df.columns else 'train/loss_epoch'
val_loss_col = 'val_loss' if 'val_loss' in df.columns else 'val/loss'
train_acc_col = 'train_acc_epoch' if 'train_acc_epoch' in df.columns else 'train/acc_epoch'
val_acc_col = 'val_acc' if 'val_acc' in df.columns else 'val/acc'

# Each metric only has values on some rows of the CSV, so rows where the
# selected column is NaN are dropped before plotting.
train_loss = df[['epoch', train_loss_col]].dropna()
val_loss = df[['epoch', val_loss_col]].dropna()
axes[0].plot(train_loss['epoch'], train_loss[train_loss_col], label='Train Loss', marker='o')
axes[0].plot(val_loss['epoch'], val_loss[val_loss_col], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True)

train_acc = df[['epoch', train_acc_col]].dropna()
val_acc = df[['epoch', val_acc_col]].dropna()
axes[1].plot(train_acc['epoch'], train_acc[train_acc_col] * 100, label='Train Acc', marker='o')
axes[1].plot(val_acc['epoch'], val_acc[val_acc_col] * 100, label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training & Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig(PROJECT_ROOT / "logs" / "training_curves.png", dpi=150)
plt.show()

print(f"\nBest Val Acc: {val_acc[val_acc_col].max() * 100:.2f}%")
print(f"Final Train Acc: {train_acc[train_acc_col].iloc[-1] * 100:.2f}%")


# Classification Test 결과 시각화


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Class names indexed by CIFAR-10 label id.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

model = model.to(device)
model.eval()
test_loader = data_module.test_dataloader()

all_preds, all_labels = [], []
sample_images, sample_preds, sample_labels = [], [], []

# Run the whole test set through the model; keep the first 16 items of the
# first batch for the image grid.
with torch.no_grad():
    for batch_no, (images, targets) in enumerate(test_loader):
        images = images.to(device)
        batch_preds = model(images).argmax(dim=1).cpu().numpy()

        all_preds.extend(batch_preds)
        all_labels.extend(targets.numpy())

        if batch_no == 0:
            sample_images = images[:16].cpu()
            sample_preds = batch_preds[:16]
            sample_labels = targets[:16].numpy()

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

fig = plt.figure(figsize=(16, 10))

fig.suptitle('Classification Test Results', fontsize=16, fontweight='bold')

gs = fig.add_gridspec(3, 6, hspace=0.4, wspace=0.3)

# Top two grid rows: 12 sample images titled with predicted / true class.
for idx in range(12):
    grid_row, grid_col = divmod(idx, 6)
    ax = fig.add_subplot(gs[grid_row, grid_col])

    # CHW -> HWC, then undo the (0.5, 0.5) normalisation for display.
    hwc = sample_images[idx].permute(1, 2, 0).numpy()
    img = np.clip(hwc * 0.5 + 0.5, 0, 1)
    ax.imshow(img)

    predicted, actual = sample_preds[idx], sample_labels[idx]
    pred_name = class_names[predicted]
    true_name = class_names[actual]
    if predicted == actual:
        color = 'green'
    else:
        color = 'red'
    ax.set_title(f'Pred: {pred_name}\nTrue: {true_name}', fontsize=9, color=color)
    ax.axis('off')

# Bottom grid row: confusion matrix over the full test set.
ax_cm = fig.add_subplot(gs[2, :])
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax_cm,
    cbar_kws={'shrink': 0.8},
)
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('True')
ax_cm.set_title('Confusion Matrix')

plt.savefig(PROJECT_ROOT / "logs" / "classification_test_results.png", dpi=150, bbox_inches='tight')
plt.show()

accuracy = np.mean(all_preds == all_labels) * 100
print(f"\nTest Accuracy: {accuracy:.2f}%")
print(f"Total samples: {len(all_labels)}")


# TensorBoard 사용법

터미널에서 다음 명령어 실행:
```bash
tensorboard --logdir=logs
```

그 다음 브라우저에서 http://localhost:6006 접속

---

## Lightning 로그 구조:
```
logs/
├── mae_classification/
│   └── version_0/
│       ├── hparams.yaml       # 하이퍼파라미터
│       ├── metrics.csv        # CSV 로그
│       └── events.out.tfevents.xxx  # TensorBoard 로그
└── checkpoints/
    └── mae-epoch=XX-val_acc=0.XXXX.ckpt  # 체크포인트
```

---

## 로그 내용:
- `train/loss`: 배치별 + 에폭별 학습 loss
- `train/acc`: 배치별 + 에폭별 학습 정확도
- `val/loss`: 에폭별 검증 loss
- `val/acc`: 에폭별 검증 정확도
- `lr-Adam`: Learning rate 변화
- `test/loss`, `test/acc`: 테스트 결과


# 전체 요약

## 구현된 3가지 Downstream Tasks (실제 데이터셋):

### 1. Image Classification (CIFAR-10)
- **Dataset**: 50,000 train (45,000 train / 5,000 val), 10,000 test
- **Model**: CLS token → Linear(768→10)
- **Epochs**: 200 (Early Stop: patience=20)
- **Metrics**: `train/loss`, `train/acc`, `val/loss`, `val/acc`
- **Log**: `logs/mae_classification/`
- **Result**: ~79% test accuracy

### 2. Object Detection (Pascal VOC 2012)
- **Dataset**: ~5,717 train, ~5,823 val (실제 이미지)
- **Classes**: 20 classes (aeroplane, bicycle, bird, ...)
- **Model**: CLS token → {Bbox(4), Class(20)}
- **Epochs**: 50 (Early Stop: patience=10)
- **Metrics**: `train_det/bbox_loss`, `train_det/cls_loss`, `val_det/acc`
- **Log**: `logs/voc_detection/`
- **Note**: 가장 큰 객체 하나만 detection (simplified)

### 3. Semantic Segmentation (Pascal VOC 2012)
- **Dataset**: ~1,464 train, ~1,449 val (실제 이미지 + 마스크)
- **Classes**: 21 classes (background + 20 objects)
- **Model**: Spatial(768×14×14) → Upsample(4×) → Mask(21×224×224)
- **Epochs**: 50 (Early Stop: patience=10)
- **Metrics**: `train_seg/loss`, `train_seg/acc`, `val_seg/loss`
- **Log**: `logs/voc_segmentation/`

---

## TensorBoard로 모든 Task 동시에 보기:

```bash
tensorboard --logdir=logs
```

브라우저: http://localhost:6006

왼쪽 메뉴에서 각 task별로 전환 가능:
- `mae_classification`
- `voc_detection`
- `voc_segmentation`

---

## 저장된 파일:

- `logs/classification_test_results.png` - Classification confusion matrix
- `logs/voc_detection_test_results.png` - Detection bbox visualization  
- `logs/voc_segmentation_test_results.png` - Segmentation mask visualization
- `logs/checkpoints/` - Classification best models
- `logs/checkpoints_voc_det/` - Detection best models
- `logs/checkpoints_voc_seg/` - Segmentation best models


In [None]:
import numpy as np

def create_detection_sample(size=224, num_boxes=3, num_classes=10):
    """Generate one synthetic detection sample.

    A random-noise image is created and ``num_boxes`` rectangles are painted
    onto it, each with a single random colour in [0.5, 1.0) per channel.

    Args:
        size: Height and width of the square image in pixels.
        num_boxes: Number of rectangles (ground-truth boxes) to draw.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.

    Returns:
        Tuple ``(img, boxes, labels)``: ``img`` is a float tensor of shape
        (3, size, size); ``boxes`` is a float32 tensor of shape (num_boxes, 4)
        in (x1, y1, x2, y2) pixel coordinates; ``labels`` is an int64 tensor
        of shape (num_boxes,).
    """
    img = torch.rand(3, size, size)
    boxes = []
    labels = []
    for _ in range(num_boxes):
        x1, y1 = (int(v) for v in np.random.randint(0, size - 50, 2))
        w, h = (int(v) for v in np.random.randint(30, 70, 2))
        # Clamp the far corner to the image: x1 + w can exceed `size`, and
        # slicing clips the painted patch, so an unclamped box would describe
        # a region larger than what is actually drawn.
        x2 = min(x1 + w, size)
        y2 = min(y1 + h, size)
        boxes.append([x1, y1, x2, y2])
        labels.append(np.random.randint(0, num_classes))
        img[:, y1:y2, x1:x2] = torch.rand(3, 1, 1) * 0.5 + 0.5
    # reshape / explicit dtype keep shapes and dtypes stable when num_boxes == 0.
    box_tensor = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return img, box_tensor, label_tensor

class DetectionDataset(torch.utils.data.Dataset):
    """Synthetic detection dataset of a fixed nominal length.

    Samples are not stored: every access builds a new random sample via
    ``create_detection_sample``.
    """

    def __init__(self, num_samples=1000):
        # Length reported through __len__.
        self.num_samples = num_samples

    def __len__(self):
        """Return the nominal number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, (boxes, labels))``; ``idx`` is not used."""
        image, box_tensor, label_tensor = create_detection_sample()
        target = (box_tensor, label_tensor)
        return image, target

class MAEDetector(pl.LightningModule):
    """Single-box detector: mean-pooled patch tokens feed two MLP heads.

    One head regresses four box values, the other produces class logits.
    Only the first ground-truth box/label of each sample is used as the
    training target.

    Args:
        encoder: Module whose output is indexed as ``features[:, 1:]`` to
            obtain patch tokens; the heads assume a feature width of 768.
        num_classes: Number of classes predicted by the class head.
        freeze_encoder: If True, ``requires_grad`` is disabled on every
            encoder parameter.
        lr: Adam learning rate.
    """

    def __init__(self, encoder, num_classes=10, freeze_encoder=True, lr=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.bbox_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 4)
        )
        self.class_head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
        self.lr = lr

    def forward(self, x):
        """Return ``(bbox, cls)``: four box values and class logits per image."""
        features = self.encoder(x)
        # Skip the token at index 0 and average the remaining tokens.
        patch_features = features[:, 1:]
        pooled = patch_features.mean(dim=1)
        bbox = self.bbox_head(pooled)
        cls = self.class_head(pooled)
        return bbox, cls

    def _compute_losses(self, batch):
        """Return ``(bbox_loss, cls_loss, total_loss)`` for one batch."""
        x, (gt_boxes, gt_labels) = batch
        pred_boxes, pred_cls = self(x)

        # Only the first ground-truth box and label of each sample are supervised.
        bbox_loss = F.mse_loss(pred_boxes, gt_boxes[:, 0])
        cls_loss = F.cross_entropy(pred_cls, gt_labels[:, 0])
        return bbox_loss, cls_loss, bbox_loss + cls_loss

    def training_step(self, batch, batch_idx):
        bbox_loss, cls_loss, loss = self._compute_losses(batch)

        self.log('train/bbox_loss', bbox_loss, prog_bar=True)
        self.log('train/cls_loss', cls_loss, prog_bar=True)
        self.log('train/total_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        bbox_loss, cls_loss, loss = self._compute_losses(batch)

        self.log('val/bbox_loss', bbox_loss)
        self.log('val/cls_loss', cls_loss)
        self.log('val/total_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over every parameter that still requires gradients."""
        # With a frozen encoder this is exactly the two heads; with
        # freeze_encoder=False it also includes the encoder, which would
        # otherwise never be updated.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr)

print("Object Detection module defined")


In [None]:
def create_segmentation_sample(size=224, num_classes=5):
    """Build one synthetic segmentation pair ``(image, mask)``.

    Starting from a noise image and an all-zero mask, one rectangle per
    class id in ``1 .. num_classes - 1`` is painted in increasing order; a
    later rectangle overwrites an earlier one where they overlap. Each
    rectangle gets a flat, class-dependent colour in the image and its
    class id in the mask.

    Returns:
        ``image`` of shape (3, size, size) and an int64 ``mask`` of shape
        (size, size).
    """
    image = torch.rand(3, size, size)
    label_map = torch.zeros(size, size, dtype=torch.long)

    for class_id in range(1, num_classes):
        left, top = np.random.randint(0, size - 60, 2)
        width, height = np.random.randint(40, 80, 2)
        rows = slice(top, top + height)
        cols = slice(left, left + width)

        label_map[rows, cols] = class_id

        # Flat fill colour that varies with the class id.
        fill = torch.tensor([class_id / num_classes, 0.5, 1 - class_id / num_classes])
        image[:, rows, cols] = fill.view(3, 1, 1)

    return image, label_map

class SegmentationDataset(torch.utils.data.Dataset):
    """Synthetic segmentation dataset of a fixed nominal length.

    Samples are not stored: every access builds a new random sample via
    ``create_segmentation_sample``.
    """

    def __init__(self, num_samples=1000):
        # Length reported through __len__.
        self.num_samples = num_samples

    def __len__(self):
        """Return the nominal number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a new ``(image, mask)`` pair; ``idx`` is not used."""
        sample = create_segmentation_sample()
        return sample

class MAESegmenter(pl.LightningModule):
    """Semantic segmenter: patch tokens reshaped to a grid, then upsampled.

    The decoder is four stride-2 transposed convolutions, i.e. a 16x spatial
    upsampling of the patch grid.

    Args:
        encoder: Module whose output is indexed as ``features[:, 1:]`` to
            obtain patch tokens; the decoder assumes 768 channels.
        num_classes: Number of output channels (per-pixel class logits).
        freeze_encoder: If True, ``requires_grad`` is disabled on every
            encoder parameter.
        lr: Adam learning rate.
    """

    def __init__(self, encoder, num_classes=5, freeze_encoder=True, lr=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])

        self.encoder = encoder
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.seg_head = nn.Sequential(
            nn.ConvTranspose2d(768, 256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2)
        )
        self.lr = lr
        self.num_classes = num_classes

    def forward(self, x):
        """Return per-pixel class logits of shape (B, num_classes, 16*H, 16*W)."""
        features = self.encoder(x)
        # Skip the token at index 0; the rest are laid out as a square grid.
        patch_features = features[:, 1:]
        B, N, C = patch_features.shape
        H = W = int(N ** 0.5)
        # (B, N, C) -> (B, C, H, W) so the convolutional decoder can consume it.
        patch_features = patch_features.transpose(1, 2).reshape(B, C, H, W)
        seg_map = self.seg_head(patch_features)
        return seg_map

    def _shared_step(self, batch):
        """Return ``(loss, pixel_accuracy)`` for one batch."""
        x, mask = batch
        seg_output = self(x)
        loss = F.cross_entropy(seg_output, mask)

        pred_mask = seg_output.argmax(dim=1)
        acc = (pred_mask == mask).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)

        self.log('train/seg_loss', loss, prog_bar=True)
        self.log('train/seg_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)

        self.log('val/seg_loss', loss, prog_bar=True)
        self.log('val/seg_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over every parameter that still requires gradients."""
        # With a frozen encoder this is exactly the decoder head; with
        # freeze_encoder=False it also includes the encoder, which would
        # otherwise never be updated.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr)

print("Semantic Segmentation module defined")


In [None]:
# Train the detector on synthetic data.
print("\n" + "="*60)
print("Object Detection Training")
print("="*60)

det_train_dataset = DetectionDataset(num_samples=800)
det_val_dataset = DetectionDataset(num_samples=200)

det_train_loader = DataLoader(det_train_dataset, batch_size=16, shuffle=True)
det_val_loader = DataLoader(det_val_dataset, batch_size=16, shuffle=False)

detector = MAEDetector(encoder=encoder, num_classes=10, freeze_encoder=True, lr=1e-3)

det_trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger(LOG_DIR, name="mae_detection"),
    callbacks=[
        # The metric name contains "/"; with automatic metric-name insertion
        # that slash would create a nested "val/" directory, so the label is
        # written out explicitly and auto-insertion is turned off.
        ModelCheckpoint(dirpath=LOG_DIR / "checkpoints_det",
                       filename="detector-epoch={epoch:02d}-val_total_loss={val/total_loss:.4f}",
                       monitor="val/total_loss", mode="min", save_top_k=3,
                       auto_insert_metric_name=False),
        EarlyStopping(monitor="val/total_loss", patience=20, mode="min", verbose=True),
        LearningRateMonitor(logging_interval='epoch')
    ],
    log_every_n_steps=10,
    enable_progress_bar=True
)

det_trainer.fit(detector, det_train_loader, det_val_loader)
print("Detection training completed!")


# Detection Test 결과 시각화


In [None]:
import matplotlib.patches as patches

test_dataset = DetectionDataset(num_samples=16)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

detector = detector.to(device)
detector.eval()


def _box_outline(box, **style):
    """Build an unfilled rectangle patch from an (x1, y1, x2, y2) box."""
    return patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                             linewidth=2, facecolor='none', **style)


# Visualise the first 8 samples of the first batch only.
with torch.no_grad():
    for images, (gt_boxes, gt_labels) in test_loader:
        box_out, cls_out = detector(images.to(device))

        images = images.cpu()
        box_out = box_out.cpu()
        cls_out = torch.argmax(cls_out, dim=1).cpu()

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        fig.suptitle('Object Detection Test Results', fontsize=16, fontweight='bold')

        for idx, ax in enumerate(axes.flatten()[:8]):
            # CHW -> HWC for matplotlib.
            ax.imshow(images[idx].permute(1, 2, 0).numpy())

            # Solid green: first ground-truth box; dashed red: prediction.
            ax.add_patch(_box_outline(gt_boxes[idx][0].numpy(),
                                      edgecolor='green', label='GT'))
            ax.add_patch(_box_outline(box_out[idx].numpy(),
                                      edgecolor='red', linestyle='--', label='Pred'))

            gt_label = gt_labels[idx][0].item()
            pred_label = cls_out[idx].item()
            ax.set_title(f'GT: class {gt_label} | Pred: class {pred_label}', fontsize=10)
            ax.axis('off')

            # A single legend on the first panel is enough.
            if idx == 0:
                ax.legend(loc='upper right', fontsize=8)

        plt.tight_layout()
        plt.savefig(PROJECT_ROOT / "logs" / "detection_test_results.png", dpi=150, bbox_inches='tight')
        plt.show()
        break

print("Detection visualization completed!")


In [None]:
# Train the segmenter on synthetic data.
print("\n" + "="*60)
print("Semantic Segmentation Training")
print("="*60)

seg_train_dataset = SegmentationDataset(num_samples=800)
seg_val_dataset = SegmentationDataset(num_samples=200)

seg_train_loader = DataLoader(seg_train_dataset, batch_size=16, shuffle=True)
seg_val_loader = DataLoader(seg_val_dataset, batch_size=16, shuffle=False)

segmenter = MAESegmenter(encoder=encoder, num_classes=5, freeze_encoder=True, lr=1e-3)

seg_trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto",
    devices=1,
    logger=TensorBoardLogger(LOG_DIR, name="mae_segmentation"),
    callbacks=[
        # The metric name contains "/"; with automatic metric-name insertion
        # that slash would create a nested "val/" directory, so the label is
        # written out explicitly and auto-insertion is turned off.
        ModelCheckpoint(dirpath=LOG_DIR / "checkpoints_seg",
                       filename="segmenter-epoch={epoch:02d}-val_seg_loss={val/seg_loss:.4f}",
                       monitor="val/seg_loss", mode="min", save_top_k=3,
                       auto_insert_metric_name=False),
        EarlyStopping(monitor="val/seg_loss", patience=20, mode="min", verbose=True),
        LearningRateMonitor(logging_interval='epoch')
    ],
    log_every_n_steps=10,
    enable_progress_bar=True
)

seg_trainer.fit(segmenter, seg_train_loader, seg_val_loader)
print("Segmentation training completed!")


# 전체 요약

## 구현된 3가지 Downstream Tasks:

### 1. Image Classification (CIFAR-10)
- Dataset: 50,000 train (45,000 train / 5,000 val), 10,000 test
- Model: CLS token → Linear(768→10)
- Epochs: 200 (Early Stop: patience=20)
- Metrics: `train/loss`, `train/acc`, `val/loss`, `val/acc`
- Log: `logs/mae_classification/`

### 2. Object Detection (Synthetic)
- Dataset: 800 train, 200 val (synthetic)
- Model: Patch tokens → Pool → {Bbox(4), Class(10)}
- Epochs: 200 (Early Stop: patience=20)
- Metrics: `train/bbox_loss`, `train/cls_loss`, `val/total_loss`
- Log: `logs/mae_detection/`

### 3. Semantic Segmentation (Synthetic)
- Dataset: 800 train, 200 val (synthetic)
- Model: Spatial(768×14×14) → 4× ConvTranspose2d (stride 2, 16× upsample) → Mask(5×224×224)
- Epochs: 200 (Early Stop: patience=20)
- Metrics: `train/seg_loss`, `train/seg_acc`, `val/seg_loss`
- Log: `logs/mae_segmentation/`

---

## TensorBoard로 모든 Task 동시에 보기:

```bash
tensorboard --logdir=logs
```

브라우저: http://localhost:6006

왼쪽 메뉴에서 각 task별로 전환 가능:
- `mae_classification`
- `mae_detection`
- `mae_segmentation`


# Segmentation Test 결과 시각화


In [None]:
test_dataset = SegmentationDataset(num_samples=8)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

segmenter = segmenter.to(device)
segmenter.eval()


def _mean_iou(pred, target, num_classes):
    """Return the mean intersection-over-union in percent.

    The IoU is computed per class id in ``[0, num_classes)`` and averaged over
    the classes that occur in ``pred`` or ``target``; classes absent from both
    are skipped. Returns 0.0 if no class occurs at all.
    """
    per_class = []
    for class_id in range(num_classes):
        pred_c = pred == class_id
        target_c = target == class_id
        union = (pred_c | target_c).sum().item()
        if union == 0:
            continue
        intersection = (pred_c & target_c).sum().item()
        per_class.append(intersection / union)
    return 100.0 * sum(per_class) / len(per_class) if per_class else 0.0


# Visualise the first (and only) batch: input, ground truth, prediction.
with torch.no_grad():
    for x, gt_mask in test_loader:
        x = x.to(device)
        pred_seg = segmenter(x)

        x = x.cpu()
        pred_mask = torch.argmax(pred_seg, dim=1).cpu()
        gt_mask = gt_mask.cpu()

        fig, axes = plt.subplots(3, 8, figsize=(20, 8))
        fig.suptitle('Semantic Segmentation Test Results', fontsize=16, fontweight='bold')

        for idx in range(8):
            img = x[idx].permute(1, 2, 0).numpy()

            axes[0, idx].imshow(img)
            axes[0, idx].set_title('Input', fontsize=10)
            axes[0, idx].axis('off')

            axes[1, idx].imshow(gt_mask[idx], cmap='tab10', vmin=0, vmax=9)
            axes[1, idx].set_title('GT Mask', fontsize=10)
            axes[1, idx].axis('off')

            axes[2, idx].imshow(pred_mask[idx], cmap='tab10', vmin=0, vmax=9)
            # A real IoU, not the fraction of matching pixels (that is pixel
            # accuracy and is inflated by the large background region).
            iou = _mean_iou(pred_mask[idx], gt_mask[idx], segmenter.num_classes)
            axes[2, idx].set_title(f'Pred (IoU: {iou:.1f}%)', fontsize=10)
            axes[2, idx].axis('off')

        plt.tight_layout()
        plt.savefig(PROJECT_ROOT / "logs" / "segmentation_test_results.png", dpi=150, bbox_inches='tight')
        plt.show()
        break

# Mean IoU over all pixels of the visualised batch.
mean_iou = _mean_iou(pred_mask, gt_mask, segmenter.num_classes)
print(f"Mean IoU: {mean_iou:.2f}%")
print("Segmentation visualization completed!")
