<a href="https://colab.research.google.com/github/username12345678901234567890/1234/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %% [single-cell train] Plasticity 기반 Self-Organizing MoE - 실전용 (Colab A100 40GB)
# ---------------------------------------
# 설정
# ---------------------------------------
import os, sys, json, time, glob, math, random
from dataclasses import dataclass
from typing import List, Tuple
# Output directories for checkpoints and sample grids.
# NOTE(review): hard-coded here; keep in sync with Config.out_ckpt_dir / Config.out_samples_dir below.
os.makedirs("./checkpoints", exist_ok=True)
os.makedirs("./samples", exist_ok=True)

# Install LPIPS on the fly (e.g. in Colab); skip it if that is not possible.
# _HAS_LPIPS records whether the perceptual loss is available; when False the
# training loop uses an MSE-only loss.
try:
    import lpips  # noqa
    _HAS_LPIPS = True
except Exception:
    try:
        # Without internet access this fails -> MSE only.
        # NOTE: `!pip ...` is IPython/Jupyter shell syntax, so this cell runs only inside a
        # notebook, not as a plain .py file. If the install does not succeed, the import
        # below raises and is caught by the outer handler.
        !pip -q install lpips==0.1.4
        import lpips  # noqa
        _HAS_LPIPS = True
    except Exception:
        _HAS_LPIPS = False

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm import tqdm

# ---------------------------------------
# Config
# ---------------------------------------
@dataclass
class Config:
    """All hyper-parameters and paths for the plasticity-based self-organizing MoE run.

    Every field has a default; override by keyword, e.g. ``Config(epochs=5)``.
    Defaults are tuned for a single A100 40GB (Colab).
    """

    # ---- data & optimisation ---------------------------------------------
    data_root: str = "/images"            # image folder, searched recursively
    image_size: int = 128
    epochs: int = 20
    batch_size: int = 192                 # recommended for A100 40GB
    num_workers: int = 4
    lr: float = 2e-4
    betas: Tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 1e-4
    amp: bool = True                      # mixed precision (autocast + GradScaler)

    # ---- latent / model --------------------------------------------------
    z_dim: int = 512
    expert_hidden: int = 512

    # ---- MoE routing -----------------------------------------------------
    init_experts: int = 32
    max_experts: int = 256
    top_k: int = 8
    epsilon_explore: float = 0.1          # random-routing exploration rate, annealed toward 0.01 over epochs
    gate_tau_start: float = 2.0           # gate softmax temperature at the first epoch
    gate_tau_end: float = 0.7             # ... and at the last epoch

    # ---- plasticity / statistics -----------------------------------------
    ema_decay: float = 0.9                # per-epoch EMA decay for usage/contribution stats
    baseline_grad_eps: float = 0.005      # not referenced by the training cell

    # ---- pruning / spin-off ----------------------------------------------
    prune_usage_thresh: float = 0.01      # epoch-EMA usage share below which an expert is a prune candidate
    prune_contrib_thresh: float = 1e-4    # epoch-EMA contribution below which an expert is a prune candidate
    prune_patience_epochs: int = 3
    spin_usage_thresh: float = 0.35       # split (spin off) an expert whose usage share exceeds this
    spin_interval_epochs: int = 2
    spin_noise_std: float = 0.02

    # ---- outputs ---------------------------------------------------------
    out_ckpt_dir: str = "./checkpoints"
    out_samples_dir: str = "./samples"
    sample_rows: int = 4
    seed: int = 42

cfg = Config()

# Reproducibility: seed Python's `random` and PyTorch (torch.manual_seed covers all devices).
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[Device] {device}")

# ---------------------------------------
# Dataset
# ---------------------------------------
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

class PlainImageDataset(Dataset):
    """Unlabelled image dataset over every image file found (recursively) under ``root``.

    Each item is an RGB image resized (shorter side) and center-cropped to ``size`` and
    normalized to [-1, 1]. With ``train=True`` a random horizontal flip and a mild color
    jitter are applied before cropping.

    Raises:
        RuntimeError: if no file with an extension in ``IMG_EXTS`` exists under ``root``.
    """

    def __init__(self, root: str, size: int = 128, train: bool = True):
        candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
        # Sorted so the file order -- and therefore any seeded random_split over this
        # dataset -- is reproducible; glob's own order depends on the filesystem.
        # Directories whose names happen to end in an image extension are skipped.
        self.paths = sorted(
            p for p in candidates
            if os.path.splitext(p)[1].lower() in IMG_EXTS and os.path.isfile(p)
        )
        if not self.paths:
            raise RuntimeError(f"No images found under {root}. Put 128x128 images in that folder (recursively).")
        # Resize(shorter side) -> [train only: HFlip -> ColorJitter] -> CenterCrop -> tensor in [-1, 1]
        pipeline = [transforms.Resize(size)]
        if train:
            pipeline += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.02),
            ]
        pipeline += [
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3),  # [-1, 1]
        ]
        self.tf = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Close the file handle immediately (matters with many DataLoader workers);
        # convert() returns a fully loaded image that outlives the file.
        with Image.open(self.paths[idx]) as img:
            rgb = img.convert("RGB")
        return self.tf(rgb)

full_ds = PlainImageDataset(cfg.data_root, size=cfg.image_size, train=True)
n_images = len(full_ds)
if n_images < 2:
    raise RuntimeError(f"Need at least 2 images for a train/val split, found {n_images} under {cfg.data_root}.")
val_ratio = 0.02 if n_images > 2000 else 0.1
# At least 16 validation images, but always leave at least one image for training.
val_len = min(max(16, int(n_images * val_ratio)), n_images - 1)
train_len = n_images - val_len
train_ds, val_ds = random_split(full_ds, [train_len, val_len], generator=torch.Generator().manual_seed(cfg.seed))

# Validation must not be augmented: re-wrap the same validation indices over an eval-mode
# copy of the dataset (no flip / color jitter) that shares the exact same file list.
eval_ds = PlainImageDataset(cfg.data_root, size=cfg.image_size, train=False)
eval_ds.paths = full_ds.paths
val_ds = torch.utils.data.Subset(eval_ds, val_ds.indices)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=device.type == "cuda",
                          # Keep a short final batch when the split is smaller than one batch;
                          # otherwise the loader would yield nothing.
                          drop_last=train_len >= cfg.batch_size)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=device.type == "cuda", drop_last=False)
print(f"[Data] train={len(train_ds)} val={len(val_ds)} batch={cfg.batch_size}")

# ---------------------------------------
# Model: Conv AE + Latent MoE
# ---------------------------------------
def conv_block(in_ch, out_ch, k=3, s=1, p=1):
    """Build ``Conv2d(k, stride=s, padding=p) -> GroupNorm(8 groups) -> SiLU``.

    ``out_ch`` must be divisible by 8 because of the GroupNorm.
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
        nn.GroupNorm(num_groups=8, num_channels=out_ch),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """Convolutional encoder: image batch -> ``z_dim`` latent vectors.

    Five stages of two ``conv_block``s each. The first stage keeps the resolution; every later
    stage starts with a stride-2 conv (128 -> 64 -> 32 -> 16 -> 8). The final 512x8x8 map is
    flattened and linearly projected to ``z_dim``.
    NOTE: the projection size assumes 128x128 inputs.
    """

    def __init__(self, z_dim=512):
        super().__init__()
        ch = 64
        widths = (ch, ch * 2, ch * 4, ch * 8, ch * 8)
        layers = []
        prev = 3
        for stage, width in enumerate(widths):
            layers.append(conv_block(prev, width, s=1 if stage == 0 else 2))
            layers.append(conv_block(width, width))
            prev = width
        self.net = nn.Sequential(*layers)
        self.proj = nn.Linear(ch * 8 * 8 * 8, z_dim)

    def forward(self, x):
        features = self.net(x)
        return self.proj(features.view(features.size(0), -1))

class Decoder(nn.Module):
    """Convolutional decoder: ``z_dim`` latent vectors -> 3x128x128 images in (-1, 1).

    A linear layer expands z to a 512x8x8 feature map. Four nearest-neighbour 2x upsampling
    stages (8 -> 16 -> 32 -> 64 -> 128), each followed by a ``conv_block``, are then applied,
    and a final 3x3 conv plus Tanh produce the RGB output.
    """

    def __init__(self, z_dim=512):
        super().__init__()
        base = 64
        self.fc = nn.Linear(z_dim, base * 8 * 8 * 8)  # 8x8x512
        layers = [conv_block(base * 8, base * 8)]
        stage_channels = (
            (base * 8, base * 8),   # -> 16
            (base * 8, base * 4),   # -> 32
            (base * 4, base * 2),   # -> 64
            (base * 2, base),       # -> 128
        )
        for c_in, c_out in stage_channels:
            layers += [nn.Upsample(scale_factor=2, mode='nearest'), conv_block(c_in, c_out)]
        layers += [nn.Conv2d(base, 3, 3, 1, 1), nn.Tanh()]
        self.up = nn.Sequential(*layers)

    def forward(self, z):
        seed_map = self.fc(z)
        return self.up(seed_map.view(seed_map.size(0), 512, 8, 8))

class ExpertMLP(nn.Module):
    """One expert: a two-layer SiLU MLP mapping a ``z_dim`` latent to a ``z_dim`` update."""

    def __init__(self, z_dim=512, hidden=512):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden)
        self.fc2 = nn.Linear(hidden, z_dim)
        # Re-apply the Kaiming-uniform scheme (a=sqrt(5)) that nn.Linear uses for its weight,
        # and start both biases at zero.
        nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.fc1.bias); nn.init.zeros_(self.fc2.bias)

    def forward(self, z):
        h = F.silu(self.fc1(z))
        return self.fc2(h)

class MoELatent(nn.Module):
    """Residual mixture of experts applied to a latent vector, with a resizable expert pool.

    Gating: ``probs = softmax(z @ prototypes.T / tau)`` over the E experts. Each sample is
    routed to its top-k experts. With probability ``epsilon`` per sample (and only if E > k),
    the last slot is replaced by a uniformly random expert. The output is
    ``z + sum_j probs_j * expert_j(z)``; the weights are the un-renormalized softmax probabilities.
    """

    def __init__(self, z_dim=512, hidden=512, init_experts=32):
        super().__init__()
        self.z_dim = z_dim
        self.hidden = hidden
        self.experts = nn.ModuleList([ExpertMLP(z_dim, hidden) for _ in range(init_experts)])
        self.prototypes = nn.Parameter(torch.randn(init_experts, z_dim) * 0.02)  # E x D

    def num_experts(self): return len(self.experts)

    def add_expert(self, clone_idx=None, noise_std=0.02):
        """Append one expert and its prototype row.

        If ``clone_idx`` is a valid index, the new expert and prototype are noisy copies of that
        expert's; otherwise a freshly initialized expert and a small random prototype are used.
        The new expert is moved to the prototypes' device/dtype so it works on a model that
        already lives on the GPU. ``self.prototypes`` is replaced by a new Parameter, so any
        optimizer holding the old one must be rebuilt.
        """
        dev, dtype = self.prototypes.device, self.prototypes.dtype
        new = ExpertMLP(self.z_dim, self.hidden).to(device=dev, dtype=dtype)
        if clone_idx is not None and 0 <= clone_idx < len(self.experts):
            new.load_state_dict(self.experts[clone_idx].state_dict())
            with torch.no_grad():
                for p in new.parameters():
                    p.add_(torch.randn_like(p) * noise_std)
            proto = self.prototypes.data[clone_idx].clone()
            proto.add_(torch.randn_like(proto) * noise_std)
        else:
            proto = torch.randn(self.z_dim, device=dev, dtype=dtype) * noise_std
        self.experts.append(new)
        self.prototypes = nn.Parameter(torch.cat([self.prototypes.data, proto[None, :]], dim=0))

    def remove_experts(self, indices: List[int]):
        """Drop the experts (and prototype rows) at ``indices``; duplicates are ignored.

        ``self.prototypes`` is replaced by a new Parameter, so optimizers must be rebuilt.
        """
        if not indices: return
        drop = set(indices)
        keep = [i for i in range(len(self.experts)) if i not in drop]
        self.experts = nn.ModuleList([self.experts[i] for i in keep])
        with torch.no_grad():
            self.prototypes = nn.Parameter(self.prototypes.data[keep])

    def forward(self, z, top_k=8, tau=1.0, epsilon=0.0):
        """Route ``z`` (B x D) through its top-k experts and add their weighted outputs.

        Returns:
            (z_out, probs, used_mask, topk_idx, topk_vals): z_out is B x D; probs is the
            B x E gate distribution; used_mask is a B x E bool mask of experts actually applied;
            topk_idx / topk_vals are the B x k selected expert ids and their weights.
        """
        # gating via similarity to prototypes
        logits = F.linear(z, self.prototypes) / max(tau, 1e-6)  # [B,E]
        probs = F.softmax(logits, dim=-1)

        E = self.num_experts()
        B = z.size(0)
        k = min(top_k, E)
        topk_vals, topk_idx = torch.topk(probs, k=k, dim=-1)  # [B,k]

        # epsilon exploration: replace the last slot of some samples with a random expert
        if epsilon > 0 and E > k:
            mask = (torch.rand(B, device=z.device) < epsilon)
            if mask.any():
                # Clone before editing: topk saves its index output for backward, so writing
                # into it in place makes loss.backward() fail with an in-place error.
                topk_idx = topk_idx.clone()
                topk_vals = topk_vals.clone()
                rand_idx = torch.randint(0, E, (int(mask.sum()),), device=z.device)
                topk_idx[mask, -1] = rand_idx
                topk_vals[mask, -1] = probs[mask].gather(1, rand_idx.view(-1, 1)).squeeze(1)

        # Per-sample weight of every expert; summed if the random slot re-picked an expert
        # already in the top-k (the original per-slot loop also applied it twice).
        weights = torch.zeros(B, E, dtype=topk_vals.dtype, device=z.device).scatter_add(1, topk_idx, topk_vals)
        used_mask = torch.zeros(B, E, dtype=torch.bool, device=z.device).scatter(
            1, topk_idx, torch.ones_like(topk_idx, dtype=torch.bool))

        # Accumulate in the promoted dtype (fp32 under autocast, as the per-slot sum did).
        out_dtype = torch.promote_types(z.dtype, weights.dtype)
        moe_sum = torch.zeros(z.shape, dtype=out_dtype, device=z.device)
        # One call per selected expert on all rows routed to it (instead of one per slot).
        for u in topk_idx.unique().tolist():
            rows = used_mask[:, u].nonzero(as_tuple=True)[0]
            y = self.experts[u](z.index_select(0, rows))
            moe_sum = moe_sum.index_add(0, rows, (weights[rows, u].unsqueeze(1) * y).to(out_dtype))

        z_out = z + moe_sum
        return z_out, probs, used_mask, topk_idx, topk_vals

class AutoEncoderMoE(nn.Module):
    """Conv auto-encoder whose latent is refined by a residual MoE before decoding.

    ``forward`` returns ``(x_rec, z, z_moe, probs, used_mask, topk_idx, topk_vals)``:
    the reconstruction, the raw encoder latent, the MoE-refined latent, and the MoE
    routing outputs (see ``MoELatent.forward``).
    """

    def __init__(self, z_dim=512, hidden=512, init_experts=32):
        super().__init__()
        # Construction order (enc, moe, dec) fixes both the state_dict layout and RNG use.
        self.enc = Encoder(z_dim=z_dim)
        self.moe = MoELatent(z_dim=z_dim, hidden=hidden, init_experts=init_experts)
        self.dec = Decoder(z_dim=z_dim)

    def forward(self, x, top_k=8, tau=1.0, epsilon=0.0):
        latent = self.enc(x)
        refined, *routing = self.moe(latent, top_k=top_k, tau=tau, epsilon=epsilon)
        return (self.dec(refined), latent, refined, *routing)

# ---------------------------------------
# Utils
# ---------------------------------------
def denorm(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping out-of-range values first."""
    return x.clamp(-1, 1).add(1.0).mul(0.5)

def save_sample_grid(epoch, model, loader, device, out_dir, rows=4):
    """Save ``epoch_XXX.png`` with inputs (top half) and their reconstructions (bottom half).

    Uses the first ``rows*rows`` images of the first batch from ``loader``. Images are
    reconstructed without exploration and laid out ``rows`` per grid row. The model's
    previous train/eval mode is restored afterwards, even if saving fails.
    NOTE(review): uses tau=1.0, not the epoch's scheduled tau used for validation.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            imgs = next(iter(loader))[:rows*rows].to(device)
            rec, *_ = model(imgs, top_k=cfg.top_k, tau=1.0, epsilon=0.0)
            grid = make_grid(torch.cat([denorm(imgs), denorm(rec)], dim=0), nrow=rows)
            save_image(grid, os.path.join(out_dir, f"epoch_{epoch:03d}.png"))
    finally:
        model.train(was_training)

# ---------------------------------------
# Train Setup
# ---------------------------------------
model = AutoEncoderMoE(cfg.z_dim, cfg.expert_hidden, cfg.init_experts).to(device)

# Frozen perceptual metric (LPIPS on VGG features) when the package is available.
percept = lpips.LPIPS(net='vgg').to(device).eval() if _HAS_LPIPS else None
if percept is None:
    print("[Info] LPIPS 미사용(MSE만 적용). pip 설치가 불가하거나 네트워크가 막힌 환경일 수 있습니다.")

opt = torch.optim.AdamW(
    model.parameters(),
    lr=cfg.lr,
    betas=cfg.betas,
    weight_decay=cfg.weight_decay,
)
scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

# Expert Stats (에폭 EMA / 사용/기여 추적)
class ExpertStats:
    """Per-expert routing statistics, accumulated per epoch and smoothed with an EMA.

    All attributes are 1-D tensors of length ``E``:
        usage_epoch / contrib_epoch: accumulators for the current epoch (float32).
        usage_ema / contrib_ema: per-epoch exponential moving averages (float32).
        fail_epochs: consecutive epochs an expert met the pruning criteria (int32).
        quarantine: int32 counters carried across structure changes; not updated by the
            training cell (reserved for fine-tuning).
    """

    # Statistics carried over to surviving experts on a structure change.
    _CARRIED = ("usage_ema", "contrib_ema", "fail_epochs", "quarantine")

    def __init__(self, E: int):
        self.reset(E)

    def reset(self, E):
        """(Re)allocate all statistics as zeros for ``E`` experts on the global ``device``."""
        self.E = E
        dev = device
        self.usage_epoch  = torch.zeros(E, dtype=torch.float32, device=dev)
        self.usage_ema    = torch.zeros(E, dtype=torch.float32, device=dev)
        self.contrib_epoch= torch.zeros(E, dtype=torch.float32, device=dev)
        self.contrib_ema  = torch.zeros(E, dtype=torch.float32, device=dev)
        self.fail_epochs  = torch.zeros(E, dtype=torch.int32 , device=dev)  # pruning patience
        self.quarantine   = torch.zeros(E, dtype=torch.int32 , device=dev)  # (room for fine-tuning)

    def on_structure_change(self, newE, kept=None):
        """Resize/remap the statistics after experts were added or removed.

        Args:
            newE: number of experts after the change.
            kept: for each new slot ``i`` (in order), the old expert index whose EMA, patience
                and quarantine values carry over. Entries that are ``None`` or outside the old
                range mark brand-new experts and start at zero, as do slots past ``len(kept)``.
                If ``kept`` is None, all statistics are reset.
        The per-epoch accumulators always restart at zero.
        """
        if kept is None:
            self.reset(newE); return
        old = {name: getattr(self, name) for name in self._CARRIED}
        old_E = old["usage_ema"].numel()
        dev = old["usage_ema"].device
        self.E = newE
        self.usage_epoch   = torch.zeros(newE, dtype=torch.float32, device=dev)
        self.contrib_epoch = torch.zeros(newE, dtype=torch.float32, device=dev)
        self.usage_ema     = torch.zeros(newE, dtype=torch.float32, device=dev)
        self.contrib_ema   = torch.zeros(newE, dtype=torch.float32, device=dev)
        self.fail_epochs   = torch.zeros(newE, dtype=torch.int32 , device=dev)
        self.quarantine    = torch.zeros(newE, dtype=torch.int32 , device=dev)
        pairs = [(new_i, int(old_i)) for new_i, old_i in enumerate(list(kept)[:newE])
                 if old_i is not None and 0 <= int(old_i) < old_E]
        if not pairs:
            return
        new_idx = torch.tensor([n for n, _ in pairs], dtype=torch.long, device=dev)
        old_idx = torch.tensor([o for _, o in pairs], dtype=torch.long, device=dev)
        for name, prev in old.items():
            getattr(self, name)[new_idx] = prev[old_idx]

    def epoch_decay(self, decay):
        """Fold this epoch's accumulators into the EMAs, then zero the accumulators."""
        self.usage_ema   = decay*self.usage_ema   + (1-decay)*self.usage_epoch
        self.contrib_ema = decay*self.contrib_ema + (1-decay)*self.contrib_epoch
        self.usage_epoch.zero_(); self.contrib_epoch.zero_()

stats = ExpertStats(E=len(model.moe.experts))  # one slot per current expert

# ---------------------------------------
# Helpers: 구조 조정(프루닝/스핀오프), 요약
# ---------------------------------------
def _rebuild_optimizer(model, old_opt):
    """Return a new AdamW over the model's current parameters, keeping surviving params' state.

    Required because add/remove_experts replace ``moe.prototypes`` with a new Parameter and
    add/drop expert parameters. Per-parameter state (step, Adam moments) of parameters that
    still exist (encoder, decoder, kept experts) is carried over instead of being discarded.
    """
    new_opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=cfg.betas, weight_decay=cfg.weight_decay)
    for p in model.parameters():
        if p in old_opt.state:
            new_opt.state[p] = old_opt.state[p]
    return new_opt


def maybe_spin_or_prune(epoch, model, stats, opt):
    """Adapt the expert pool once per epoch: prune idle experts, then maybe spin one off.

    Prune: an expert whose EMA usage share < ``cfg.prune_usage_thresh`` and EMA contribution
    < ``cfg.prune_contrib_thresh`` has its patience counter incremented. It is removed once the
    counter reaches ``cfg.prune_patience_epochs``, but the pool never drops below 4 experts.

    Spin-off: every ``cfg.spin_interval_epochs`` epochs, if the most used expert's share exceeds
    ``cfg.spin_usage_thresh`` (and the pool is below ``cfg.max_experts``), a noisy clone is added.
    NOTE(review): each sample marks an expert at most once in ``used_mask``, so one expert's share
    is roughly bounded by 1/top_k (0.125 for top_k=8); a 0.35 threshold then never fires.

    Returns:
        (opt, note): the optimizer to keep using (rebuilt if the structure changed) and a
        short human-readable description of what changed ("" if nothing).
    """
    # ---- Prune ----
    E = model.moe.num_experts()
    usage_ratio = stats.usage_ema / (stats.usage_ema.sum().item() + 1e-6)
    prune_idx = []
    for i in range(E):
        idle = (usage_ratio[i].item() < cfg.prune_usage_thresh
                and stats.contrib_ema[i].item() < cfg.prune_contrib_thresh)
        if idle:
            stats.fail_epochs[i] += 1
            if stats.fail_epochs[i] >= cfg.prune_patience_epochs and (E - len(prune_idx) > 4):
                prune_idx.append(i)
        else:
            stats.fail_epochs[i] = 0
    note = ""
    if prune_idx:
        pruned = set(prune_idx)
        keep = [j for j in range(E) if j not in pruned]
        model.moe.remove_experts(prune_idx)
        stats.on_structure_change(len(keep), kept=keep)
        opt = _rebuild_optimizer(model, opt)
        note += f"Pruned {len(prune_idx)} | "

    # ---- Spin-off ----
    E = model.moe.num_experts()
    if epoch % cfg.spin_interval_epochs == 0 and E < cfg.max_experts and E > 0:
        # Recompute after pruning so the chosen index refers to the current expert numbering.
        usage_ratio = stats.usage_ema / (stats.usage_ema.sum().item() + 1e-6)
        top_usage, top_idx = torch.topk(usage_ratio[:E], k=1)
        if top_usage.item() > cfg.spin_usage_thresh:
            parent = int(top_idx.item())
            model.moe.add_expert(clone_idx=parent, noise_std=cfg.spin_noise_std)
            # Existing experts keep their statistics; the new slot E starts at zero.
            stats.on_structure_change(E + 1, kept=list(range(E)))
            opt = _rebuild_optimizer(model, opt)
            note += f"Spun-off from expert {parent} -> new {E}"
    return opt, note.strip()

def epoch_summary(epoch, model, stats, train_loss, val_loss):
    """Print an epoch report and write it to ``<cfg.out_ckpt_dir>/epoch_XXX_summary.json``.

    The report contains the expert count, ``NeuronScale`` (experts x ``cfg.expert_hidden``),
    train/val loss, and the (up to) five experts with the highest contribution EMA together
    with their usage EMA. Returns the summary dict that was written.
    """
    n_experts = model.moe.num_experts()
    contrib = stats.contrib_ema.detach().cpu()
    usage = stats.usage_ema.detach().cpu()

    leaders = []
    if n_experts > 0:
        best_vals, best_ids = torch.topk(contrib, k=min(5, n_experts))
        for val, eid in zip(best_vals.tolist(), best_ids.tolist()):
            leaders.append((eid, val, float(usage[eid])))
    neuron_scale = n_experts * cfg.expert_hidden

    print(f"\n[Epoch {epoch}] Experts={n_experts} | NeuronScale={neuron_scale} | TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f}")
    print("Top-5 experts (id, contrib_ema, usage_ema):")
    for eid, c, u in leaders:
        print(f"  - {eid:3d} | contrib={c:.6f} | usage={u:.6f}")

    summary = dict(
        epoch=epoch,
        experts_total=n_experts,
        neuron_scale=neuron_scale,
        top5_by_contrib=[{"expert_id": eid, "contrib_ema": c, "usage_ema": u} for eid, c, u in leaders],
        train_loss=float(train_loss),
        val_loss=float(val_loss),
        timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
    )
    summary_path = os.path.join(cfg.out_ckpt_dir, f"epoch_{epoch:03d}_summary.json")
    with open(summary_path, "w") as fh:
        json.dump(summary, fh, indent=2)
    return summary

# ---------------------------------------
# Train Loop
# ---------------------------------------
# Linear schedules over training progress p in [0, 1]:
#   tau: gate_tau_start -> gate_tau_end (lower temperature = sharper routing later),
#   eps: epsilon_explore -> 0.01 (less random exploration later).
global_step = 0
for epoch in range(1, cfg.epochs+1):
    model.train()
    progress = (epoch-1)/max(1, cfg.epochs-1)
    tau = cfg.gate_tau_start + (cfg.gate_tau_end - cfg.gate_tau_start) * progress
    eps = 0.01 + (cfg.epsilon_explore - 0.01) * (1 - progress)

    running = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs} (train)")
    for batch in pbar:
        batch = batch.to(device, non_blocking=True)

        # Expert-free baseline (encoder -> decoder only), under the same autocast as the main
        # pass, used to estimate how much the expert path improves the reconstruction MSE.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=cfg.amp):
            base_mse = F.mse_loss(model.dec(model.enc(batch)), batch)

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=cfg.amp):
            x_rec, z0, z_moe, probs, used_mask, topk_idx, topk_vals = model(batch, top_k=cfg.top_k, tau=tau, epsilon=eps)
            mse = F.mse_loss(x_rec, batch)
            if _HAS_LPIPS:
                # LPIPS expects inputs in [-1, 1]; x_rec (Tanh) and batch (Normalize 0.5/0.5) already are.
                lp = percept(x_rec, batch).mean()
                loss = 0.7*mse + 0.3*lp
            else:
                loss = mse
                lp = torch.tensor(0.0, device=device)

        # AMP order: backward on the scaled loss, unscale, clip the real gradients, then step.
        # (Clipping before backward would only see the gradients cleared by zero_grad.)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        # Usage / contribution statistics (accumulated over the epoch)
        with torch.no_grad():
            E = model.moe.num_experts()
            if E > 0:
                stats.usage_epoch[:E] += used_mask.float().sum(dim=0)
                # Overall improvement over the baseline, measured on the same metric (MSE).
                improv = (base_mse - mse).clamp_min(0.0)
                # Each expert's share of the total routing weight in this batch.
                per_exp_weight = torch.zeros(E, device=device).index_add_(
                    0, topk_idx.reshape(-1), topk_vals.reshape(-1).float())
                total_weight = per_exp_weight.sum()
                if total_weight > 0:
                    stats.contrib_epoch[:E] += per_exp_weight / total_weight * improv.item()

        running += loss.item()
        global_step += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}", mse=f"{mse.item():.4f}", lpips=f"{lp.item():.4f}", tau=f"{tau:.2f}", eps=f"{eps:.3f}")

    # Fold this epoch's statistics into the EMAs
    stats.epoch_decay(cfg.ema_decay)

    # Validation (no exploration, same loss as training)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{cfg.epochs} (val)"):
            batch = batch.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=cfg.amp):
                x_rec, *_ = model(batch, top_k=cfg.top_k, tau=tau, epsilon=0.0)
                mse = F.mse_loss(x_rec, batch)
                if _HAS_LPIPS:
                    lp = percept(x_rec, batch).mean()
                    loss = 0.7*mse + 0.3*lp
                else:
                    loss = mse
            val_loss += loss.item() * batch.size(0)
    val_loss /= max(1, len(val_ds))

    # Structural adaptation (prune / spin-off); may return a rebuilt optimizer
    opt, note = maybe_spin_or_prune(epoch, model, stats, opt)
    if note:
        print(f"[Structure] {note} | Experts={model.moe.num_experts()}")

    # Save sample grid and checkpoint
    save_sample_grid(epoch, model, val_loader, device, cfg.out_samples_dir, rows=cfg.sample_rows)
    ckpt = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "stats": {
            "usage_ema": stats.usage_ema.detach().cpu().tolist(),
            "contrib_ema": stats.contrib_ema.detach().cpu().tolist(),
            "fail_epochs": stats.fail_epochs.detach().cpu().tolist(),
            "quarantine": stats.quarantine.detach().cpu().tolist(),
        },
        "config": cfg.__dict__,
    }
    torch.save(ckpt, os.path.join(cfg.out_ckpt_dir, f"epoch_{epoch:03d}.pt"))

    # Print / save the epoch summary
    _ = epoch_summary(epoch, model, stats, train_loss=running/max(1, len(train_loader)), val_loss=val_loss)

print("✅ Training complete. Check ./checkpoints and ./samples")


Mounted at /content/drive


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It is not recommended to move them to `cpu` as running them will fail. Please make sure to use an accelerator to run the pipeline in inference, due to the lack of support for`float16` operations on this device in PyTorch. Please, remove the `torch_dtype=torch.float16` argum

Epoch 1/3


  pixel_values = torch.tensor(torch.ByteTensor(torch.ByteStorage.from_buffer(image.tobytes())).float()).view(image.size[1], image.size[0], 3).permute(2, 0, 1) / 255.0
  pixel_values = torch.tensor(torch.ByteTensor(torch.ByteStorage.from_buffer(image.tobytes())).float()).view(image.size[1], image.size[0], 3).permute(2, 0, 1) / 255.0
  with autocast():
