<a href="https://colab.research.google.com/github/sangjun315/FedAvg-Pytorch/blob/main/2025_Khuggle_Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Diffusion Distillation Challenge Baseline**

**Teacher**: DDPM-style UNet on Tiny-ImageNet (64×64)  
**Student**: Smaller DDPM-style UNet

This notebook includes the following:

1. Tiny-ImageNet data loader / val set flattening
2. Definition of the DDPM Teacher UNet
3. DDIM sampler for the Teacher + FID computation (using pytorch-fid)
4. Student DDPM UNet (architecture fixed)
5. Training code for Teacher → Student epsilon distillation (baseline)
6. Example of Student sampling and FID computation
7. (For the competition) Example of score calculation

## **1. TinyImageNet Download and Unzip**


In [None]:
# Download and unzip Tiny-ImageNet

!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip -O tiny-imagenet-200.zip

!unzip -q tiny-imagenet-200.zip -d .

!ls tiny-imagenet-200
!pip install pytorch-fid

--2025-11-23 08:17:38--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://cs231n.stanford.edu/tiny-imagenet-200.zip [following]
--2025-11-23 08:17:38--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2025-11-23 08:18:08 (7.98 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]

test  train  val  wnids.txt  words.txt
Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Downloading pytorch_fid-0.3.0-py3-none-any.whl (15 kB)
Installing collected packages: pytorch-fid
Successfully installed py

## **2. Set up the environment and import the basic modules**


In [None]:
# 1. Environment setup & basic imports
import os
import math
import time
import sys
import shutil
from pathlib import Path
import subprocess

import torch
import torch.nn as nn
import numpy as np
import random
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

print("Python:", sys.version)
print("PyTorch:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("[info] using device:", device)

def format_time(secs: float) -> str:
    secs = int(secs)
    h = secs // 3600
    m = (secs % 3600) // 60
    s = secs % 60
    return f"{h:02d}:{m:02d}:{s:02d}"


def make_beta_schedule(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alphas_cumprod


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)


Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch: 2.9.0+cu126
[info] using device: cuda
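
To make the linear schedule from `make_beta_schedule` concrete, here is a minimal pure-Python sketch (no PyTorch) that rebuilds the same quantities with the notebook's default values and prints the signal and noise coefficients of q(x_t | x0) at a few timesteps. The helper name `linear_alphas_cumprod` is hypothetical and exists only for this illustration.

```python
import math

def linear_alphas_cumprod(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Pure-Python stand-in for make_beta_schedule (illustration only)."""
    T = num_train_timesteps
    alphas_cumprod, running = [], 1.0
    for i in range(T):
        # Same spacing as torch.linspace: both endpoints are included.
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod

acp = linear_alphas_cumprod()
for t in (0, 250, 500, 999):
    signal = math.sqrt(acp[t])        # coefficient on x0
    noise = math.sqrt(1.0 - acp[t])   # coefficient on eps
    print(f"t={t:4d}  sqrt(acp)={signal:.4f}  sqrt(1-acp)={noise:.4f}")
```

At t=0 the noised sample is almost entirely signal, while at t=999 it is almost pure noise, which is why sampling can start from z ~ N(0, I).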


## **3. Dataset loader**

In [None]:
from torchvision.datasets import ImageFolder

def get_tiny_train_loader(
    data_root: str,
    batch_size: int = 128,
    num_workers: int = 4,
):
    """
    Loader for the Tiny-ImageNet train folder.
    - Expects data_root/tiny-imagenet-200/train/<class>/*.JPEG
    - Input images are resized to 64x64 and normalized to [-1,1].
    """
    train_dir = os.path.join(data_root, "tiny-imagenet-200", "train")
    print("[info] train_dir:", train_dir)

    tfm = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # [0,1]
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5)),  # [-1,1]
    ])

    dataset = ImageFolder(train_dir, transform=tfm)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print("[info] Tiny-ImageNet train size:", len(dataset))
    return loader


def flatten_tiny_val(
    data_root: str,
    out_dir: str = "./tiny_val_flat",
):
    """

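    Copy the Tiny-ImageNet val images into a single flat directory
    (used later as the real-image folder for FID).
    - source: data_root/tiny-imagenet-200/val/images/*.JPEG
    - output: out_dir/val_000000.jpeg ...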
    """
    root_tiny = Path(data_root) / "tiny-imagenet-200"
    val_dir = root_tiny / "val" / "images"
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    cnt = 0
    if val_dir.exists():
        for img_path in val_dir.glob("*.JPEG"):
            dst = out_dir / f"val_{cnt:06d}.jpeg"
            if not dst.exists():
                shutil.copy(str(img_path), str(dst))
            cnt += 1
        print("flattened val images:", cnt, "->", out_dir)
    else:
        print("WARNING: val/images not found under", root_tiny)

    return str(out_dir)

flatten_tiny_val(
    data_root="/content/",  # root directory where Tiny-ImageNet was unzipped
    out_dir="./tiny_val_flat",
)


flattened val images: 10000 -> tiny_val_flat


'tiny_val_flat'
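
The train loader's `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps each channel from [0,1] to [-1,1], and the sampling code later in this notebook inverts it with `x * 0.5 + 0.5` before saving images. A scalar sketch of that round trip (the function names are hypothetical, for illustration only):

```python
def normalize(v, mean=0.5, std=0.5):
    """Per-channel Normalize: (v - mean) / std."""
    return (v - mean) / std

def denormalize(v, mean=0.5, std=0.5):
    """Inverse mapping back to [0, 1], as used before saving images."""
    return v * std + mean

for pixel in (0.0, 0.25, 1.0):
    n = normalize(pixel)
    print(pixel, "->", n, "->", denormalize(n))
```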

## **4. Sinusoidal time embedding & UNet Blocks**

In [None]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) long or float, time index [0..T-1]
        return: (B, dim)
        """
        device = t.device
        half = self.dim // 2
        freqs = torch.exp(
            torch.arange(half, device=device, dtype=torch.float32)
            * -(math.log(10000.0) / (half - 1))
        )
        if t.dtype != torch.float32:
            t = t.float()
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return emb


class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.act1  = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_ch)
        )

        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.act2  = nn.SiLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(self.act1(self.norm1(x)))
        t_h = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_h
        h = self.conv2(self.act2(self.norm2(h)))
        return h + self.skip(x)


class DownBlockT(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(in_ch, out_ch, time_dim)
        self.res2 = ResBlock(out_ch, out_ch, time_dim)
        self.down = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)

    def forward(self, x, t_emb):
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        skip = x
        x = self.down(x)
        return x, skip


class UpBlockT(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.res1 = ResBlock(in_ch, out_ch, time_dim)
        self.res2 = ResBlock(out_ch, out_ch, time_dim)
        self.up   = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(out_ch, out_ch, 3, padding=1)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        x = self.conv(x)
        return x
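
To see what `SinusoidalPosEmb` produces, the following pure-Python sketch computes the same embedding for a single scalar timestep. It is an illustration only; `sinusoidal_embedding` is a hypothetical helper, not part of the notebook.

```python
import math

def sinusoidal_embedding(t, dim):
    """Scalar version of SinusoidalPosEmb.forward for one timestep t."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    freqs = [math.exp(-i * math.log(10000.0) / (half - 1)) for i in range(half)]
    args = [t * f for f in freqs]
    # sin block first, then cos block, matching the torch.cat order above.
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]

emb0 = sinusoidal_embedding(0, 8)
emb5 = sinusoidal_embedding(5, 8)
print(len(emb5), [round(v, 3) for v in emb5])
```

As in the class above, `dim` must be at least 4, because `half - 1` appears in a denominator.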


## **5. Teacher model structure**

In [None]:
class TeacherUNet(nn.Module):
    """
    DDPM-style Teacher UNet for Tiny-ImageNet 64x64 (epsilon prediction).
    Input:  x_t (B,3,64,64), t (B,) int64
    Output: eps_pred (B,3,64,64)
    """
    def __init__(self, img_ch=3, base_ch=128, time_dim=512):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.inc = nn.Conv2d(img_ch, base_ch, 3, padding=1)

        self.down1 = DownBlockT(base_ch,      base_ch*2, time_dim)  # 64->32
        self.down2 = DownBlockT(base_ch*2,    base_ch*4, time_dim)  # 32->16
        self.down3 = DownBlockT(base_ch*4,    base_ch*4, time_dim)  # 16->8
        self.down4 = DownBlockT(base_ch*4,    base_ch*4, time_dim)  # 8->4

        self.mid1 = ResBlock(base_ch*4, base_ch*4, time_dim)
        self.mid2 = ResBlock(base_ch*4, base_ch*4, time_dim)

        self.up4 = UpBlockT(base_ch*4 + base_ch*4, base_ch*4, time_dim)  # 4->8
        self.up3 = UpBlockT(base_ch*4 + base_ch*4, base_ch*4, time_dim)  # 8->16
        self.up2 = UpBlockT(base_ch*4 + base_ch*4, base_ch*2, time_dim)  # 16->32
        self.up1 = UpBlockT(base_ch*2 + base_ch*2, base_ch,   time_dim)  # 32->64

        self.outc = nn.Conv2d(base_ch, img_ch, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        x0 = self.inc(x)
        x1, s1 = self.down1(x0, t_emb)
        x2, s2 = self.down2(x1, t_emb)
        x3, s3 = self.down3(x2, t_emb)
        x4, s4 = self.down4(x3, t_emb)

        m  = self.mid1(x4, t_emb)
        m  = self.mid2(m,  t_emb)

        u4 = self.up4(m,  s4, t_emb)
        u3 = self.up3(u4, s3, t_emb)
        u2 = self.up2(u3, s2, t_emb)
        u1 = self.up1(u2, s1, t_emb)

        out = self.outc(u1)
        return out  # eps prediction
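
The channel and resolution bookkeeping in `TeacherUNet` can be checked without building the model. The sketch below traces channels and spatial size through each stage, assuming the block definitions above (each down block stores its skip before the stride-2 convolution, and each up block upsamples before concatenating). `trace_unet` and `conv_out` are hypothetical helpers written for this illustration.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_unet(base_ch=128, size=64, img_ch=3):
    """Return (layer, output channels, spatial size) after each stage."""
    rows = [("inc", base_ch, size)]
    ch, skips = base_ch, []
    for name, out_ch in [("down1", base_ch * 2), ("down2", base_ch * 4),
                         ("down3", base_ch * 4), ("down4", base_ch * 4)]:
        skips.append((out_ch, size))       # skip is taken before the stride-2 conv
        size = conv_out(size, stride=2)    # 3x3 conv, stride 2, padding 1
        ch = out_ch
        rows.append((name, ch, size))
    rows.append(("mid", ch, size))
    for name, out_ch in [("up4", base_ch * 4), ("up3", base_ch * 4),
                         ("up2", base_ch * 2), ("up1", base_ch)]:
        skip_ch, skip_size = skips.pop()
        size *= 2                          # nearest-neighbour upsample
        assert size == skip_size, "skip and upsampled map must align"
        rows.append((f"{name} (in={ch + skip_ch})", out_ch, size))
        ch = out_ch
    rows.append(("outc", img_ch, size))
    return rows

for row in trace_unet(base_ch=128):
    print(row)
```

The `in=` values printed for the up blocks should match the `in_ch` arguments passed to `UpBlockT` in the constructor.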


## **6. DDIM Sampler**

In [None]:
@torch.no_grad()
def ddim_sample(model, alphas_cumprod, num_train_timesteps, z, steps=50):
    """
    DDIM(eta=0) deterministic sampler.
    - model: epsilon prediction UNet (Teacher or Student)
    - alphas_cumprod: (T,)
    - z: (B,3,64,64) ~ N(0,I)
    """
    device = z.device
    b = z.size(0)
    x = z

    T = num_train_timesteps
    step_indices = torch.linspace(T - 1, 0, steps=steps, device=device).long()

    for i, t in enumerate(step_indices):
        t_batch = t.repeat(b)
        eps = model(x, t_batch)

        alpha_t = alphas_cumprod[t]
        sqrt_alpha_t = alpha_t.sqrt()
        sqrt_one_minus_alpha_t = (1.0 - alpha_t).sqrt()

        x0_pred = (x - sqrt_one_minus_alpha_t * eps) / sqrt_alpha_t
        x0_pred = x0_pred.clamp(-1.0, 1.0)

        if i == steps - 1:
            x = x0_pred
        else:
            t_next = step_indices[i + 1]
            alpha_next = alphas_cumprod[t_next]
            sqrt_alpha_next = alpha_next.sqrt()
            sqrt_one_minus_alpha_next = (1.0 - alpha_next).sqrt()
            x = sqrt_alpha_next * x0_pred + sqrt_one_minus_alpha_next * eps

    return x  # [-1,1] approx x0
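
One way to sanity-check the DDIM update is to run it on a single scalar with an "oracle" noise predictor that already knows the clean value. The sketch below mirrors the loop in `ddim_sample` in pure Python; the oracle, the value `TRUE_X0`, and the helper names are assumptions made for this illustration, and the timestep rounding only approximates `torch.linspace(...).long()`.

```python
import math

def linear_alphas_cumprod(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta) for a linear beta schedule."""
    acp, running = [], 1.0
    for i in range(T):
        running *= 1.0 - (beta_start + (beta_end - beta_start) * i / (T - 1))
        acp.append(running)
    return acp

def ddim_sample_scalar(eps_model, acp, z, steps=10):
    """Scalar DDIM (eta=0) loop mirroring ddim_sample; needs steps >= 2."""
    T = len(acp)
    # Evenly spaced timesteps from T-1 down to 0, truncated to integers.
    step_indices = [int((T - 1) * (steps - 1 - i) / (steps - 1)) for i in range(steps)]
    x = z
    for i, t in enumerate(step_indices):
        eps = eps_model(x, t)
        a_t = acp[t]
        x0_pred = (x - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)
        x0_pred = max(-1.0, min(1.0, x0_pred))  # clamp to [-1, 1]
        if i == steps - 1:
            x = x0_pred
        else:
            a_next = acp[step_indices[i + 1]]
            x = math.sqrt(a_next) * x0_pred + math.sqrt(1.0 - a_next) * eps
    return x

acp = linear_alphas_cumprod()
TRUE_X0 = 0.3  # hypothetical clean value that the oracle knows

def oracle_eps(x, t):
    # Exact noise implied by x_t = sqrt(a) * x0 + sqrt(1 - a) * eps.
    return (x - math.sqrt(acp[t]) * TRUE_X0) / math.sqrt(1.0 - acp[t])

result = ddim_sample_scalar(oracle_eps, acp, z=1.7, steps=10)
print(result)
```

With a perfect noise predictor, every step recovers `TRUE_X0` as `x0_pred`, so the result lands very close to 0.3 regardless of the starting noise `z`. A trained network only approximates this behaviour.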


## **6. Student model structure**

In [None]:
class StudentUNetDDPM(nn.Module):
    """
    Student model: a DDPM-style UNet with the same structure as the Teacher but a smaller base_ch.
    base_ch=64 (the Teacher uses 128)
    Input:  x_t (B,3,64,64), t (B,)
    Output: eps_pred (B,3,64,64)
    """
    def __init__(self, img_ch=3, base_ch=64, time_dim=512):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.inc = nn.Conv2d(img_ch, base_ch, 3, padding=1)

        self.down1 = DownBlockT(base_ch,      base_ch*2, time_dim)  # 64->32
        self.down2 = DownBlockT(base_ch*2,    base_ch*4, time_dim)  # 32->16
        self.down3 = DownBlockT(base_ch*4,    base_ch*4, time_dim)  # 16->8
        self.down4 = DownBlockT(base_ch*4,    base_ch*4, time_dim)  # 8->4

        self.mid1 = ResBlock(base_ch*4, base_ch*4, time_dim)
        self.mid2 = ResBlock(base_ch*4, base_ch*4, time_dim)

        self.up4 = UpBlockT(base_ch*4 + base_ch*4, base_ch*4, time_dim)  # 4->8
        self.up3 = UpBlockT(base_ch*4 + base_ch*4, base_ch*4, time_dim)  # 8->16
        self.up2 = UpBlockT(base_ch*4 + base_ch*4, base_ch*2, time_dim)  # 16->32
        self.up1 = UpBlockT(base_ch*2 + base_ch*2, base_ch,   time_dim)  # 32->64

        self.outc = nn.Conv2d(base_ch, img_ch, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        x0 = self.inc(x)
        x1, s1 = self.down1(x0, t_emb)
        x2, s2 = self.down2(x1, t_emb)
        x3, s3 = self.down3(x2, t_emb)
        x4, s4 = self.down4(x3, t_emb)

        m  = self.mid1(x4, t_emb)
        m  = self.mid2(m,  t_emb)

        u4 = self.up4(m,  s4, t_emb)
        u3 = self.up3(u4, s3, t_emb)
        u2 = self.up2(u3, s2, t_emb)
        u1 = self.up1(u2, s1, t_emb)

        out = self.outc(u1)
        return out  # eps prediction
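
Halving `base_ch` from 128 to 64 halves both the input and output channels of most convolutions, so each such layer holds roughly a quarter of the weights. The sketch below checks this for one 3x3 convolution at the bottleneck width (`base_ch*4`). It counts a single layer only and is not a total parameter count for either model; the top-level time-embedding MLP (`time_dim=512` in both) does not shrink. `conv2d_params` is a hypothetical helper.

```python
def conv2d_params(in_ch, out_ch, kernel=3):
    """Weights plus biases of one nn.Conv2d(in_ch, out_ch, kernel)."""
    return in_ch * out_ch * kernel * kernel + out_ch

teacher_mid = conv2d_params(128 * 4, 128 * 4)  # Teacher, base_ch=128
student_mid = conv2d_params(64 * 4, 64 * 4)    # Student, base_ch=64
print(teacher_mid, student_mid, round(teacher_mid / student_mid, 3))
```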


## **7. Teacher epsilon Wrapper + Student Distillation training loop**

In [None]:
class TeacherEpsWrapper(nn.Module):
    """
    Wrapper that loads TeacherUNet and alphas_cumprod from teacher_tiny.pt.
    - forward(x_t, t) -> eps_teacher
    """
    def __init__(self, ckpt_path: str, device="cuda"):
        super().__init__()
        ckpt = torch.load(ckpt_path, map_location=device)
        self.model = TeacherUNet().to(device)
        self.model.load_state_dict(ckpt["model"])
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.alphas_cumprod = ckpt["alphas_cumprod"].to(device)  # (T,)
        self.num_train_timesteps = ckpt["num_train_timesteps"]
        self.device = device

    @torch.no_grad()
    def forward(self, x_t, t):
        return self.model(x_t, t)


def train_student_ddpm_distill(
    data_root: str,
    teacher_ckpt: str = "teacher_tiny.pt",
    epochs: int = 20,
    batch_size: int = 128,
    lr: float = 2e-4,
    ema_decay: float = 0.999,
    save_path: str = "student_ddpm_tiny.pt",
):
    """
    DDPM Teacher → DDPM Student epsilon distillation (baseline).
    - Draw x0 from the Tiny-ImageNet train set
    - Sample x_t from q(x_t | x0)
    - Match the Student's epsilon to the Teacher's with an L2 (MSE) loss
    """
    train_loader = get_tiny_train_loader(data_root, batch_size)

    teacher = TeacherEpsWrapper(ckpt_path=teacher_ckpt, device=device)
    alphas_cumprod = teacher.alphas_cumprod  # (T,)
    T = teacher.num_train_timesteps

    student = StudentUNetDDPM().to(device)
    ema = StudentUNetDDPM().to(device)
    ema.load_state_dict(student.state_dict())
    for p in ema.parameters():
        p.requires_grad_(False)

    opt = optim.AdamW(student.parameters(), lr=lr)
    mse = nn.MSELoss()

    total_iters = epochs * len(train_loader)
    global_step = 0
    start_time = time.time()

    print(f"[student distill] start | epochs={epochs}, total_iters≈{total_iters}")

    for epoch in range(epochs):
        for it_in_epoch, (x0, _) in enumerate(train_loader, start=1):
            global_step += 1
            student.train()

            x0 = x0.to(device)  # [-1,1]
            b = x0.size(0)

            t = torch.randint(low=0, high=T, size=(b,), device=device, dtype=torch.long)
            eps = torch.randn_like(x0)

            alpha_t = alphas_cumprod[t].view(b, 1, 1, 1)
            sqrt_alpha_t = alpha_t.sqrt()
            sqrt_one_minus_alpha_t = (1.0 - alpha_t).sqrt()
            x_t = sqrt_alpha_t * x0 + sqrt_one_minus_alpha_t * eps

            with torch.no_grad():
                eps_teacher = teacher(x_t, t)

            eps_student = student(x_t, t)
            loss = mse(eps_student, eps_teacher)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            # EMA update
            with torch.no_grad():
                for p_ema, p in zip(ema.parameters(), student.parameters()):
                    p_ema.data.mul_(ema_decay).add_(p.data, alpha=1.0 - ema_decay)

            if global_step % 50 == 0 or global_step == 1:
                elapsed = time.time() - start_time
                progress = global_step / max(total_iters, 1)
                eta = elapsed / max(progress, 1e-8) - elapsed
                print(
                    f"[student distill] epoch {epoch+1}/{epochs} "
                    f"iter {it_in_epoch}/{len(train_loader)} "
                    f"| global {global_step}/{total_iters} "
                    f"({progress*100:5.1f}%) "
                    f"| loss {loss.item():.4f} "
                    f"| elapsed {format_time(elapsed)} "
                    f"| ETA {format_time(eta)}"
                )

    torch.save(
        {
            "student": ema.state_dict(),  # use the EMA weights as the final student
            "num_train_timesteps": T,
        },
        save_path,
    )
    print(f"=> saved EMA student to {save_path}")
    return ema, alphas_cumprod, T
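
The EMA update in the loop above is `ema <- decay * ema + (1 - decay) * param`. A scalar sketch shows how quickly the EMA tracks the student for `ema_decay=0.999`; the weights here are made-up constants, and 2343 is the total iteration count of the 3-epoch run logged in this notebook.

```python
def ema_update(ema_value, value, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * value."""
    return decay * ema_value + (1.0 - decay) * value

ema_w = 0.0      # hypothetical initial weight
target_w = 1.0   # hypothetical constant "trained" weight
for _ in range(2343):
    ema_w = ema_update(ema_w, target_w)

init_fraction = 0.999 ** 2343  # share of the EMA still owed to the initial value
print(round(ema_w, 3), round(init_fraction, 3))
```

Under this simplification, roughly a tenth of the EMA still comes from the initial weights after 2343 updates, so for short runs the saved EMA student may lag the raw student somewhat.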


## **8. Student sample & FID evaluation**

In [None]:
import csv

# ===== Student distillation & FID example =====

data_root = "/content/"  # TODO: Your dataset directory
EPOCHS_OFFICIAL = 3
BATCH_SIZE_OFFICIAL = 128
LEARNING_RATE_OFFICIAL = 2e-4

# Run student distillation
student_ema, student_alphas_cumprod, T = train_student_ddpm_distill(
    data_root=data_root,
    teacher_ckpt="/content/teacher.pt",
    epochs=EPOCHS_OFFICIAL,
    batch_size=BATCH_SIZE_OFFICIAL,
    lr=LEARNING_RATE_OFFICIAL,
    save_path="student_ddpm_tiny.pt",
)

def compute_fid_with_pytorch_fid(fake_dir: str, real_dir: str):
    """
    Call pytorch-fid via subprocess and return the FID value as a float.
    Requires `pip install pytorch-fid`.
    """
    cmd = [sys.executable, "-m", "pytorch_fid", fake_dir, real_dir]
    print("Running:", " ".join(cmd))
    res = subprocess.run(cmd, capture_output=True, text=True)
    if res.returncode != 0:
        print("stderr:", res.stderr)
        raise RuntimeError("FID computation failed")

    fid_val = None
    for line in res.stdout.splitlines():
        if "FID:" in line:
            try:
                fid_val = float(line.strip().split("FID:")[-1])
            except Exception:
                pass
    print(res.stdout)
    return fid_val

# 2) Generate Student samples (e.g., 10,000 images)
@torch.no_grad()
def generate_student_samples_for_fid(
    student_model,
    alphas_cumprod,
    num_train_timesteps,
    out_dir: str = "./student_samples_baseline",
    n_samples: int = 10000,
    batch_size: int = 64,
    steps: int = 10,
    log_interval_sec: float = 60.0,  # log progress roughly every minute
):
    """
    Generate Student samples and measure latency (for FID evaluation).

    Returns:
      - out_dir: directory where the images were saved
      - latency_ms: average wall-clock time per image (ms), including the time spent saving images
    """
    os.makedirs(out_dir, exist_ok=True)
    student_model.eval()

    left = n_samples
    idx = 0  # number of images generated so far

    start_time = time.time()
    last_log_time = start_time

    print(
        f"[student sample] start | n_samples={n_samples}, "
        f"batch_size={batch_size}, steps={steps}"
    )

    while left > 0:
        bs = min(batch_size, left)
        z = torch.randn(bs, 3, 64, 64, device=device)

        # DDIM sampling
        x = ddim_sample(
            student_model,
            alphas_cumprod,
            num_train_timesteps,
            z,
            steps=steps,
        )
        x = (x.clamp(-1, 1) * 0.5 + 0.5)  # [-1,1] -> [0,1]

        # Save images
        for i in range(bs):
            save_image(x[i], os.path.join(out_dir, f"s{idx:06d}.png"))
            idx += 1
        left -= bs

        # ----- Progress log: once per minute, plus once at the end -----
        now = time.time()
        if (now - last_log_time) >= log_interval_sec or idx == n_samples:
            elapsed = now - start_time
            progress = idx / n_samples
            eta = elapsed / progress - elapsed if progress > 0 else 0.0

            print(
                f"[student sample] {idx}/{n_samples} "
                f"({progress*100:5.1f}%) | "
                f"elapsed {format_time(elapsed)} | "
                f"ETA {format_time(eta)}"
            )
            last_log_time = now
        # ------------------------------------------

    total_time = time.time() - start_time
    latency_ms = total_time / n_samples * 1000.0

    print(
        f"=> generated {n_samples} student samples in "
        f"{format_time(total_time)} "
        f"({latency_ms:.2f} ms / image)"
    )

    return out_dir, latency_ms

def log_result_to_csv(
    csv_path: str,
    run_id: str,
    fid: float,
    latency_ms: float,
):
    """
    Record the result in a CSV file.
    - If csv_path does not exist, create it with the header (id,fid,latency_ms),
    - otherwise append one row at the end.
    """
    file_exists = os.path.isfile(csv_path)

    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if not file_exists:
            writer.writerow(["id", "fid", "latency_ms"])
        writer.writerow([run_id, fid, latency_ms])

    print(f"[csv] saved result to {csv_path} (id={run_id})")


fake_dir_student, latency_ms = generate_student_samples_for_fid(
    student_model=student_ema,
    alphas_cumprod=student_alphas_cumprod,
    num_train_timesteps=T,
    out_dir="./student_samples_baseline",
    n_samples=10000,
    batch_size=64,
    steps=10,   # set the number of DDIM sampling steps here
)

# real_dir uses "./tiny_val_flat", created earlier by flatten_tiny_val
fid_student = compute_fid_with_pytorch_fid(fake_dir=fake_dir_student, real_dir="./tiny_val_flat")
print("Student FID (baseline):", fid_student)


log_result_to_csv(
    csv_path="submission.csv",
    run_id=0,
    fid=fid_student,
    latency_ms=latency_ms,
)



[info] train_dir: /content/tiny-imagenet-200/train
[info] Tiny-ImageNet train size: 100000
[student distill] start | epochs=3, total_iters≈2343
[student distill] epoch 1/3 iter 1/781 | global 1/2343 (  0.0%) | loss 1.0586 | elapsed 00:00:04 | ETA 03:13:46
[student distill] epoch 1/3 iter 50/781 | global 50/2343 (  2.1%) | loss 0.0777 | elapsed 00:02:10 | ETA 01:40:06
[student distill] epoch 1/3 iter 100/781 | global 100/2343 (  4.3%) | loss 0.0614 | elapsed 00:04:23 | ETA 01:38:36
[student distill] epoch 1/3 iter 150/781 | global 150/2343 (  6.4%) | loss 0.0400 | elapsed 00:06:36 | ETA 01:36:37
[student distill] epoch 1/3 iter 200/781 | global 200/2343 (  8.5%) | loss 0.0378 | elapsed 00:08:49 | ETA 01:34:32
[student distill] epoch 1/3 iter 250/781 | global 250/2343 ( 10.7%) | loss 0.0285 | elapsed 00:11:02 | ETA 01:32:26
[student distill] epoch 1/3 iter 300/781 | global 300/2343 ( 12.8%) | loss 0.0203 | elapsed 00:13:15 | ETA 01:30:15
[student distill] epoch 1/3 iter 350/781 | globa