<a href="https://colab.research.google.com/github/username12345678901234567890/1234/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!mkdir -p images && curl -L -o temp.zip "https://www.dropbox.com/scl/fi/ll1vl4ao3qcs8t7x19ej2/anime-face.zip?rlkey=r7rcnruhsxpjvs9jgzm2pygq7&st=t4c8ii3r&dl=1" && unzip -d images temp.zip && rm temp.zip
import os, sys, math, glob, time, json, random, subprocess
from dataclasses import dataclass
from typing import List, Tuple, Optional

# ---------- Setup ----------
# Output directories for checkpoints, sample grids and per-epoch JSON reports
# (the same relative paths are the Config.out_*_dir defaults).
for _out_dir in ("./checkpoints", "./samples", "./logs"):
    os.makedirs(_out_dir, exist_ok=True)

def _pip(pkg):
    try:
        __import__(pkg.split("==")[0].replace("-", "_"))
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

# Pinned third-party dependencies, installed on demand by _pip().
for _requirement in (
    "lpips==0.1.4",
    "diffusers==0.30.3",
    "transformers==4.42.0",
    "accelerate==0.33.0",
    "safetensors==0.4.3",
):
    _pip(_requirement)

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm.auto import tqdm
from diffusers import AutoencoderKL
import lpips

# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
torch.backends.cudnn.benchmark = True
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"[Device] {device}")

# ---------- Config ----------
@dataclass
class Config:
    """All hyperparameters of the run (instantiated once as ``cfg``)."""
    # Data
    # Relative to the working directory: the download cell unzips into
    # ./images (`unzip -d images`). The former "/images" pointed at the
    # filesystem root, where no images exist.
    data_root: str="./images"
    image_size: int=128
    batch_size: int=192
    epochs: int=20
    num_workers: int=0
    pin_memory: bool=True
    seed: int=42

    # Optim/AMP
    lr: float=2e-4
    betas: Tuple[float,float]=(0.9,0.99)
    weight_decay: float=1e-4
    amp: bool=True
    clip_grad_norm: float=0.3

    # SD-VAE
    vae_repo: str="runwayml/stable-diffusion-v1-5"
    vae_subfolder: str="vae"
    vae_scale: float=0.18215            # multiplier applied to the VAE latent
    vae_use_mode: bool=True             # posterior mode (True) vs. a sample
    vae_ft_last_block: bool=True        # unfreeze last encoder down block + mid block
    vae_ft_quant_conv: bool=False       # unfreeze quant_conv

    # Latent projector
    z_dim: int=512
    proj_channels: int=128

    # Sequential NeuroPool
    stages: int=3
    pools_per_stage: int=4
    experts_per_pool: int=32
    adapters_per_expert: int=16
    lora_rank: int=4
    hidden: int=512
    top_k: int=4

    # Routing
    epsilon: float=0.15                 # uniform mixing weight in the gates
    tau_high: float=1.8
    tau_low: float=0.7
    capacity_alpha: float=1.2

    # Regularizers
    repulsion_lambda: float=3e-3
    repulsion_subset: int=256
    repulsion_sigma: float=0.15

    # Loss weights (base before normalization)
    w_lpips: float=0.5
    w_ssim: float=0.3
    w_l1: float=0.15
    w_mse: float=0.05

    # Stable loss scaling (EMA)
    ema_beta_loss: float=0.99
    loss_eps: float=1e-6

    # Loss Diffusion (broadcast to non-selected experts)
    diffuse_enabled: bool=True
    diffuse_lambda: float=0.02
    diffuse_batch_items: int=24         # sampled items per batch for diffusion compute
    diffuse_experts_per_stage: int=8    # sampled non-selected experts per stage

    # Explorer bias (diversity)
    explorer_ratio: float=0.10
    explorer_bias_init: float=0.5
    explorer_bias_decay: float=0.9

    # Pruning (soft; bias suppression instead of deletion)
    prune_warmup_epochs: int=6
    prune_interval: int=2
    prune_share_thresh: float=0.003
    prune_max_per_stage: int=4
    prune_bias_push: float= -1.0        # added to the expert's gate bias (soft mask)

    # Save
    out_ckpt_dir: str="./checkpoints"
    out_samples_dir: str="./samples"
    out_log_dir: str="./logs"
    sample_rows: int=4

cfg = Config()
# Seed Python's RNG and torch's RNG (torch.manual_seed also seeds CUDA devices).
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

# ---------- Dataset ----------
IMG_EXTS=(".png",".jpg",".jpeg",".bmp",".webp")
class ImageFolderDS(Dataset):
    """Every image file found recursively under *root*, as a flat dataset.

    Each item is an RGB tensor of shape [3, size, size], normalized to [-1, 1].

    Args:
        root: directory searched recursively for files whose (lower-cased)
            extension is in IMG_EXTS.
        size: output side; Resize(size) on the shorter side, then CenterCrop.
        train: if True, add a random horizontal flip and a mild ColorJitter.

    Raises:
        RuntimeError: if no image files are found under *root*.
    """
    def __init__(self, root, size=128, train=True):
        # Sorted so the index -> file mapping, and hence the seeded
        # random_split, does not depend on directory-listing order.
        self.paths=sorted(p for p in glob.glob(os.path.join(root,"**","*"), recursive=True)
                          if os.path.splitext(p)[1].lower() in IMG_EXTS)
        if not self.paths:
            raise RuntimeError(f"No images found under {root}")
        aug=[transforms.Resize(size)]
        if train:
            aug += [transforms.RandomHorizontalFlip(0.5),
                    transforms.ColorJitter(0.1,0.1,0.05,0.02)]
        aug += [transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5]*3,[0.5]*3)]
        self.tf=transforms.Compose(aug)
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        # The context manager closes the underlying file immediately instead
        # of leaving the handle open until garbage collection.
        with Image.open(self.paths[i]) as im:
            return self.tf(im.convert("RGB"))

from torch.utils.data import Subset

# Two views of the same files: augmented for training, plain (no flip /
# color jitter) for validation and the sample grids, so val metrics are not
# perturbed by random augmentation.
full=ImageFolderDS(cfg.data_root, size=cfg.image_size, train=True)
full_eval=ImageFolderDS(cfg.data_root, size=cfg.image_size, train=False)
full_eval.paths=full.paths  # guarantee the identical index -> file mapping
val_ratio=0.02 if len(full)>2000 else 0.1
val_len=max(16, int(len(full)*val_ratio)); train_len=len(full)-val_len
if train_len <= 0:
    raise RuntimeError(f"Need more than {val_len} images for a train/val split, found {len(full)}")
train_ds, val_ds = random_split(full,[train_len,val_len], generator=torch.Generator().manual_seed(cfg.seed))
# Same validation indices, served from the non-augmented dataset.
val_ds = Subset(full_eval, val_ds.indices)
train_loader=DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                        num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=True)
val_loader=DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                      num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=False)
print(f"[Data] train={len(train_ds)} val={len(val_ds)} batch={cfg.batch_size}")

# ---------- SD-VAE (Encoder only) ----------
# The full AutoencoderKL is loaded; only its encoder path is used by the model.
vae = AutoencoderKL.from_pretrained(cfg.vae_repo, subfolder=cfg.vae_subfolder).to(device)
# Freeze everything, then unfreeze only the parts enabled in cfg.
vae.requires_grad_(False)
_vae_trainable = []
if cfg.vae_ft_last_block:
    _vae_trainable += [vae.encoder.down_blocks[-1], vae.encoder.mid_block]
if cfg.vae_ft_quant_conv:
    _vae_trainable.append(vae.quant_conv)
for _module in _vae_trainable:
    _module.requires_grad_(True)
# NOTE: the VAE is later registered inside SequentialNeuroPoolModel, so
# model.train() in the training loop switches it back to train mode.
vae.eval()

# ---------- Blocks ----------
def conv_block(in_ch, out_ch, k=3, s=1, p=1):
    """Conv2d -> GroupNorm(8 groups) -> SiLU.

    *out_ch* must be divisible by 8 (GroupNorm requirement).
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, k, s, p),
        nn.GroupNorm(8, out_ch),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)

class LatentProjector(nn.Module):
    """Map a latent feature map [B, c, H, W] to a flat code [B, z_dim].

    Five conv blocks (the third has stride 2), then adaptive average pooling
    to 8x8, flatten, and a linear head.
    """
    def __init__(self, z_dim, c=4, width=128):
        super().__init__()
        body = [conv_block(c, width), conv_block(width, width)]
        body.append(conv_block(width, width, 3, 2, 1))  # stride-2 downsample
        body += [conv_block(width, width), conv_block(width, width)]
        self.net = nn.Sequential(*body)
        grid = 8
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((grid, grid)),
            nn.Flatten(),
            nn.Linear(width * grid * grid, z_dim),
        )
    def forward(self, zmap):
        features = self.net(zmap)
        return self.head(features)

class Decoder(nn.Module):
    """Decode a flat code [B, z_dim] into an image [B, 3, 128, 128] in [-1, 1].

    A linear layer produces an 8x8 map with ``ch*8`` channels; four x2
    nearest-neighbour upsamplings bring it to 128x128, and Tanh bounds the
    output.

    Args:
        z_dim: size of the input code.
        ch: base channel width (default 64, the previously hard-coded value);
            must be a multiple of 8 for the GroupNorm in conv_block.
    """
    def __init__(self, z_dim, ch=64):
        super().__init__()
        self.base_ch=ch*8
        self.fc=nn.Linear(z_dim, self.base_ch*8*8)
        self.up=nn.Sequential(
            conv_block(ch*8,ch*8), nn.Upsample(scale_factor=2,mode='nearest'),
            conv_block(ch*8,ch*8), nn.Upsample(scale_factor=2,mode='nearest'),
            conv_block(ch*8,ch*4), nn.Upsample(scale_factor=2,mode='nearest'),
            conv_block(ch*4,ch*2), nn.Upsample(scale_factor=2,mode='nearest'),
            conv_block(ch*2,ch), nn.Conv2d(ch,3,3,1,1), nn.Tanh()
        )
    def forward(self,z):
        # Channel count derived from ch (was a literal 512) so the reshape
        # always matches fc's output size.
        h=self.fc(z).view(z.size(0),self.base_ch,8,8)
        return self.up(h)

class ExpertLoRA(nn.Module):
    """Bank of W low-rank adapters that perturb a shared two-layer computation.

    Adapter w owns rank-``rank`` factors for both layers:
      a1[w]: [rank, z_dim], b1[w]: [hidden, rank]   (first-layer delta)
      a2[w]: [rank, hidden], b2[w]: [z_dim, rank]   (second-layer delta)
    """
    def __init__(self, z_dim, hidden, W, rank):
        super().__init__()
        self.W = W
        self.r = rank
        init_std = 0.02
        # Creation order a1, b1, a2, b2 fixes the order of RNG draws.
        self.a1 = nn.Parameter(torch.randn(W, rank, z_dim) * init_std)
        self.b1 = nn.Parameter(torch.randn(W, hidden, rank) * init_std)
        self.a2 = nn.Parameter(torch.randn(W, rank, hidden) * init_std)
        self.b2 = nn.Parameter(torch.randn(W, z_dim, rank) * init_std)
        self.scale = 1.0

    def forward_batch(self, z, a_idx, h_pre, out_pre):
        """Apply one adapter per row.

        Args:
            z: [N, z_dim] inputs.
            a_idx: [N] adapter index for each row.
            h_pre: [N, hidden] features the first-layer delta is added to
                before SiLU. NOTE(review): NeuroPoolStage passes
                silu(fc1(z)) here, so SiLU ends up applied twice.
            out_pre: [N, z_dim] output the second-layer delta is added to.

        Returns:
            [N, z_dim] tensor ``out_pre + delta``, cast to z.dtype.
        """
        dtype = z.dtype
        down1 = self.a1[a_idx].to(dtype)
        up1 = self.b1[a_idx].to(dtype)
        low1 = torch.einsum('nd,nrd->nr', z, down1)
        hidden = F.silu(h_pre + torch.einsum('nr,nhr->nh', low1, up1) * self.scale).to(dtype)
        down2 = self.a2[a_idx].to(dtype)
        up2 = self.b2[a_idx].to(dtype)
        low2 = torch.einsum('nh,nrh->nr', hidden, down2)
        delta_out = torch.einsum('nr,ndr->nd', low2, up2) * self.scale
        return (out_pre + delta_out).to(dtype)

# ---------- NeuroPool Stage ----------
class NeuroPoolStage(nn.Module):
    """One routed residual stage: z -> z + sum_k w_k * expert_k(z).

    The P*EPP experts are grouped into P pools of EPP experts (flat expert id
    = pool * EPP + local index). Each row is sent to the argmax pool of the
    pool gate, then to up to ``top_k`` experts of that pool subject to a
    per-expert capacity. Each expert is an ExpertLoRA bank of W adapters; the
    adapter is chosen by argmax similarity between the row and the expert's
    ``adapter_keys``.
    """
    def __init__(self, z_dim, hidden, pools:int, experts_per_pool:int, adapters_per_expert:int, rank:int):
        super().__init__()
        self.z_dim=z_dim; self.hidden=hidden
        self.P=pools; self.EPP=experts_per_pool; self.W=adapters_per_expert; self.rank=rank
        self.E_total=self.P*self.EPP

        # Pool gate: one prototype + bias per pool.
        self.pool_proto=nn.Parameter(torch.randn(self.P, z_dim)*0.02)
        self.pool_bias =nn.Parameter(torch.zeros(self.P))

        # Expert gate and per-expert adapter keys.
        self.expert_proto=nn.Parameter(torch.randn(self.E_total, z_dim)*0.02)
        self.expert_bias =nn.Parameter(torch.zeros(self.E_total))
        self.adapter_keys=nn.Parameter(torch.randn(self.E_total, self.W, z_dim)*0.02)
        self.experts=nn.ModuleList([ExpertLoRA(z_dim, hidden, self.W, rank) for _ in range(self.E_total)])

        # Shared MLP whose hidden features / outputs every expert perturbs.
        self.fc1=nn.Linear(z_dim, hidden)
        self.fc2=nn.Linear(hidden, z_dim)
        nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.fc1.bias); nn.init.zeros_(self.fc2.bias)

        # expert_pool_id[e] = pool that flat expert e belongs to.
        self.register_buffer("expert_pool_id", torch.arange(self.P, dtype=torch.long).repeat_interleave(self.EPP))

        # Updated externally by the training loop each epoch.
        self.register_buffer("explorer_mask", torch.zeros(self.P, dtype=torch.bool))
        self.register_buffer("explorer_bias", torch.zeros(self.P))

    def capacity_route(self, probs, top_k, capacity):
        """Greedy top-k assignment with a per-expert row limit.

        Rows are processed in order; each row walks its ``min(E, 4k)`` most
        probable experts and takes the first k that still have capacity.

        Args:
            probs: [B, E] routing probabilities.
            top_k: experts wanted per row (clipped to E).
            capacity: max rows per expert (at least 1).

        Returns:
            assigned_idx: [B, k] long, local expert ids, -1 for unfilled slots.
            assigned_w: [B, k] the routing probability of each assignment
                (0 for unfilled slots); differentiable w.r.t. *probs*.
            cap_hit_rate: fraction of the B*k "natural" top-k assignments
                (what routing would pick with unlimited capacity) that were
                lost because an expert was full.
        """
        B,E=probs.shape; k=min(top_k,E)
        cand_k=min(E, max(k*4,k))
        vals,idxs=torch.topk(probs, k=cand_k, dim=-1)
        limit=max(1,capacity)
        # Bookkeeping on the host: one device->host copy instead of a
        # synchronizing `.item()` per candidate.
        cand=idxs.tolist()
        load=[0]*E
        picked_e=[[-1]*k for _ in range(B)]
        picked_j=[[0]*k for _ in range(B)]
        lost=0
        for b,row in enumerate(cand):
            slot=0; kept_natural=0
            for j,e in enumerate(row):
                if slot>=k: break
                if load[e]<limit:
                    load[e]+=1
                    picked_e[b][slot]=e
                    picked_j[b][slot]=j
                    slot+=1
                    if j<k: kept_natural+=1
            # Previously every slot with j>0 was counted, so with k>1 the
            # rate was high even when no expert was ever full.
            lost+=k-kept_natural
        dev=probs.device
        assigned_idx=torch.tensor(picked_e, dtype=torch.long, device=dev).view(B,k)
        pos=torch.tensor(picked_j, dtype=torch.long, device=dev).view(B,k)
        assigned_w=vals.gather(1, pos)*(assigned_idx>=0).to(vals.dtype)
        return assigned_idx, assigned_w, lost/max(1,B*k)

    def pick_adapters(self, z, expert_ids):
        """Return [N] adapter indices of expert ``expert_ids[0]`` (argmax key similarity)."""
        dt=z.dtype
        logits = z @ self.adapter_keys[expert_ids[0]].to(dt).t()
        return torch.argmax(logits, dim=-1)

    def forward(self, z, tau, epsilon, top_k, capacity):
        """Route each row of *z* through one pool and up to *top_k* experts.

        Args:
            z: [B, D] input codes.
            tau: gate temperature (logits divided by max(1e-6, tau)).
            epsilon: weight of the uniform distribution mixed into both gates.
            top_k: experts requested per row.
            capacity: per-expert row limit, enforced inside each pool's
                routing call. NOTE(review): the training loop sizes it from the
                whole batch and E_total, not per pool - confirm intent.

        Returns:
          z_out: [B,D] (z + weighted expert outputs), used_mask: [B,E_total],
          pools_used: [B,P] one-hot, cap_hit_rate: float (mean over routed pools),
          pool_probs: [B,P], cache:(z, h_pre, out_pre)
        """
        B,D=z.shape; dt=z.dtype
        h_pre=F.silu(self.fc1(z)); out_pre=self.fc2(h_pre)

        # Pool gate; explorer bias is only added on pools flagged in explorer_mask.
        pool_logits = F.linear(z, self.pool_proto)/max(1e-6,tau) + self.pool_bias
        pool_logits = pool_logits + (self.explorer_bias * self.explorer_mask.float()).to(pool_logits.dtype)
        pool_probs = F.softmax(pool_logits, dim=-1)
        if epsilon>0:
            pool_probs = (1-epsilon)*pool_probs + epsilon*(torch.ones_like(pool_probs)/self.P)
        pool_idx = torch.argmax(pool_probs, dim=-1)
        pools_used = F.one_hot(pool_idx, num_classes=self.P).to(dtype=torch.float32)

        # Expert routing per chosen pool
        z_out=torch.zeros_like(z)
        used_mask=torch.zeros(B, self.E_total, dtype=torch.bool, device=z.device)
        cap_hits=0.0; batches=0

        for p in pool_idx.unique().tolist():
            sel=(pool_idx==p)
            sel_rows=torch.nonzero(sel, as_tuple=False).squeeze(1)  # batch rows in pool p
            z_sel=z[sel]; h_sel=h_pre[sel]; out_sel=out_pre[sel]
            start=p*self.EPP; end=start+self.EPP
            e_ids=torch.arange(start,end, device=z.device)

            exp_logits=(z_sel @ self.expert_proto[e_ids].to(dt).t())/max(1e-6,tau)+self.expert_bias[e_ids].to(dt)
            exp_probs = F.softmax(exp_logits, dim=-1)
            if epsilon>0:
                exp_probs=(1-epsilon)*exp_probs+epsilon*(torch.ones_like(exp_probs)/exp_probs.size(1))

            a_idx, a_w, cap_hit = self.capacity_route(exp_probs, top_k, capacity)
            cap_hits+=cap_hit; batches+=1

            for j in range(a_idx.size(1)):
                idx_j=a_idx[:,j]
                sub=(idx_j>=0)
                if not sub.any(): continue
                w_s=a_w[:,j].unsqueeze(1).to(dt)[sub]
                rows_s=sel_rows[sub]
                idx_j_sel=idx_j[sub] + start          # flat expert ids
                z_s=z_sel[sub]; h_s=h_sel[sub]; out_s=out_sel[sub]
                for e in idx_j_sel.unique().tolist():
                    sub2=(idx_j_sel==e)
                    z_ss=z_s[sub2]
                    aids=self.pick_adapters(z_ss, torch.tensor([e], device=z.device))
                    y=self.experts[e].forward_batch(z_ss, aids, h_s[sub2], out_s[sub2]).to(dt)
                    b_idx=rows_s[sub2]
                    z_out[b_idx]=z_out[b_idx] + w_s[sub2]*y
                    used_mask[b_idx, e]=True

        cap_hit_rate=cap_hits/max(1,batches)
        cache=(z, h_pre, out_pre)
        return z + z_out, used_mask, pools_used, cap_hit_rate, pool_probs, cache

# ---------- Sequential NeuroPool Model ----------
class SequentialNeuroPoolModel(nn.Module):
    """SD-VAE encoder -> LatentProjector -> chained NeuroPoolStages -> Decoder.

    Uses the module-level ``vae`` (registered as a submodule, so it is part of
    state_dict and follows train()/eval()) and the global ``cfg``.
    """
    def __init__(self, stages:int, pools:int, epp:int, adapters:int, rank:int, z_dim:int, hidden:int):
        super().__init__()
        self.vae = vae
        self.projector = LatentProjector(z_dim, c=4, width=cfg.proj_channels)
        self.stages = nn.ModuleList(
            NeuroPoolStage(z_dim, hidden, pools, epp, adapters, rank) for _ in range(stages)
        )
        self.dec = Decoder(z_dim)

    def encode_vae(self, x):
        """Return the VAE latent of *x* (posterior mode or sample per cfg), times cfg.vae_scale."""
        posterior = self.vae.encode(x).latent_dist
        if cfg.vae_use_mode:
            latent = posterior.mode()
        else:
            latent = posterior.sample()
        return latent * cfg.vae_scale

    def forward(self, x, tau, epsilon, top_k, capacities):
        """Reconstruct *x*; ``capacities[i]`` is the expert capacity of stage i.

        Returns (x_rec, z0, z_final, stage_masks, stage_pools, stage_pool_probs,
        stage_deltas, stage_cache, mean_cap_hit); each stage_* is a list with
        one entry per stage, and stage_deltas[i] = stage output - stage input.
        """
        # The encoder builds a graph only in training mode and only when some
        # VAE weights were unfrozen.
        need_vae_grad = self.training and (cfg.vae_ft_last_block or cfg.vae_ft_quant_conv)
        with torch.set_grad_enabled(need_vae_grad):
            zmap = self.encode_vae(x)
        z = self.projector(zmap)

        masks, pools, pool_probs, deltas, caches, hits = [], [], [], [], [], []
        current = z
        for i, stage in enumerate(self.stages):
            nxt, used_mask, pools_used, cap_hit, probs, cache = stage(current, tau, epsilon, top_k, capacities[i])
            masks.append(used_mask)
            pools.append(pools_used)
            pool_probs.append(probs)
            deltas.append(nxt - current)
            caches.append(cache)
            hits.append(cap_hit)
            current = nxt

        x_rec = self.dec(current)
        mean_hit = sum(hits) / max(1, len(hits))
        return x_rec, z, current, masks, pools, pool_probs, deltas, caches, mean_hit

def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping out-of-range inputs first."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) * 0.5

# ---------- Losses ----------
# LPIPS (VGG backbone), used only as a fixed loss network. Gradients still
# flow to its *inputs*; requires_grad_(False) just stops any LPIPS weight from
# accumulating .grad, since these weights are not in the optimizer and
# opt.zero_grad() would never clear them.
percept=lpips.LPIPS(net='vgg').to(device).eval().requires_grad_(False)

def _gauss_kernel(ch, k=11, sig=1.5, dtype=torch.float32, device='cpu'):
    coords=torch.arange(k, dtype=dtype, device=device)-(k-1)/2
    g=torch.exp(-(coords**2)/(2*sig*sig)); g=(g/g.sum()).unsqueeze(0)
    kernel=(g.t()@g); kernel=kernel/kernel.sum()
    kernel=kernel.view(1,1,k,k).repeat(ch,1,1,1)
    return kernel

def ssim_loss(img1, img2, k=11, sig=1.5, eps=1e-6, data_range=1.0):
    """Return ``1 - mean SSIM`` between two image batches (0 for identical images).

    Args:
        img1, img2: [B, C, H, W] tensors with values in [0, data_range].
        k, sig: size and standard deviation of the Gaussian window.
        eps: added to the denominator for numerical safety.
        data_range: dynamic range L of the inputs (1.0 for [0, 1] images);
            sets the constants C1 = (0.01*L)^2 and C2 = (0.03*L)^2.

    Returns:
        0-dim float32 tensor.
    """
    # Always evaluate in float32 with autocast off: the variances are formed
    # as E[x^2] - E[x]^2, a difference of nearly equal numbers that loses most
    # of its precision if autocast runs the convolutions in float16.
    with torch.autocast(device_type=img1.device.type, enabled=False):
        x = img1.float(); y = img2.float()
        B,C,H,W=x.shape
        kernel=_gauss_kernel(C,k,sig,x.dtype,x.device); pad=k//2

        def blur(t):
            return F.conv2d(t, kernel, padding=pad, groups=C)

        mu1=blur(x); mu2=blur(y)
        mu1_sq=mu1*mu1; mu2_sq=mu2*mu2; mu1_mu2=mu1*mu2
        sigma1_sq=blur(x*x)-mu1_sq
        sigma2_sq=blur(y*y)-mu2_sq
        sigma12  =blur(x*y)-mu1_mu2

        C1=(0.01*data_range)**2; C2=(0.03*data_range)**2
        ssim_map=((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)+eps)
        return (1 - ssim_map.mean())

def repulsion_loss(P, m=256, sigma=0.15):
    """Penalize cosine similarity between rows of *P*.

    Up to *m* rows are picked at random; the result is the mean over all
    ordered off-diagonal pairs of exp(cos_sim / sigma). With fewer than two
    rows it returns a zero that stays attached to P's graph.
    """
    n_rows = P.size(0)
    if n_rows <= 1:
        return P.sum() * 0
    chosen = torch.randperm(n_rows, device=P.device)[:min(m, n_rows)]
    unit = F.normalize(P[chosen], dim=-1)
    sims = unit @ unit.t()
    off_diag = ~torch.eye(unit.size(0), device=unit.device, dtype=torch.bool)
    return torch.exp(sims[off_diag] / sigma).mean()

# ---------- Stats ----------
class StageStats:
    """Routing statistics for one stage, allocated on the global ``device``.

    Attributes:
        usage_epoch / usage_ema: per-expert routing counts [E] for the
            current epoch / their EMA across epochs.
        pool_usage_epoch / pool_usage_ema: same, per pool [P].
        task_vec: [E, D] EMA of the delta each expert contributed.
    """
    def __init__(self, E_total:int, P:int, D:int):
        self.E, self.P, self.D = E_total, P, D
        opts = dict(dtype=torch.float32, device=device)
        self.usage_epoch = torch.zeros(E_total, **opts)
        self.usage_ema = torch.zeros(E_total, **opts)
        self.pool_usage_epoch = torch.zeros(P, **opts)
        self.pool_usage_ema = torch.zeros(P, **opts)
        self.task_vec = torch.zeros(E_total, D, **opts)

    def step_decay(self, beta=0.9):
        """Fold the epoch counters into the EMAs, then reset the epoch counters."""
        fresh = 1 - beta
        self.usage_ema = beta * self.usage_ema + fresh * self.usage_epoch
        self.pool_usage_ema = beta * self.pool_usage_ema + fresh * self.pool_usage_epoch
        self.usage_epoch.zero_()
        self.pool_usage_epoch.zero_()

def share_hhi(v):
    """Return (shares, HHI) for a non-negative usage vector *v*.

    shares = v / (sum(v) + 1e-6); HHI = sum(shares**2) as a Python float
    (about 1.0 when one entry holds all usage, 1/len(v) when usage is even).
    """
    total = v.sum().item() + 1e-6
    shares = v / total
    concentration = float((shares * shares).sum().item())
    return shares, concentration

# ---------- Build model (OOM-aware) ----------
def build_model_oom():
    """Build the model and per-stage stats, shrinking the MoE on out-of-memory.

    On a RuntimeError mentioning "out of memory", experts per pool are halved
    while >= 16, then pools per stage are halved while >= 4; when neither can
    shrink further the error is re-raised. Other RuntimeErrors propagate.
    The sizes actually used are written back to cfg so the saved checkpoint
    config matches the model.

    Returns:
        (model, stats): the SequentialNeuroPoolModel on ``device`` and one
        StageStats per stage.
    """
    import gc

    stages=cfg.stages; pools=cfg.pools_per_stage; epp=cfg.experts_per_pool
    adapters=cfg.adapters_per_expert; rank=cfg.lora_rank; z_dim=cfg.z_dim; hidden=cfg.hidden
    while True:
        try:
            model=SequentialNeuroPoolModel(stages,pools,epp,adapters,rank,z_dim,hidden).to(device)
            stats=[StageStats(stage.E_total, stage.P, z_dim) for stage in model.stages]
            print(f"[Model] stages={stages} pools/stage={pools} experts/pool={epp} totalE/stage={pools*epp}")
            cfg.pools_per_stage=pools; cfg.experts_per_pool=epp
            return model, stats
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            if epp>=16:
                epp//=2; print(f"[Model] OOM → experts_per_pool={epp}")
            elif pools>=4:
                pools//=2; print(f"[Model] OOM → pools={pools}")
            else:
                raise
        # The exception (whose traceback references the half-built modules) is
        # released on leaving the except block; collect it and return cached
        # CUDA blocks so the smaller retry does not inherit those allocations.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

model, stats_list = build_model_oom()

# ---------- Optimizers ----------
# Trainable set: projector, decoder, every stage, plus only those VAE
# parameters that were unfrozen by the cfg.vae_ft_* flags.
params = [
    *model.projector.parameters(),
    *model.dec.parameters(),
    *(p for s in model.stages for p in s.parameters()),
]
if cfg.vae_ft_last_block or cfg.vae_ft_quant_conv:
    params += [p for p in model.vae.parameters() if p.requires_grad]
opt=torch.optim.AdamW(params, lr=cfg.lr, betas=cfg.betas, weight_decay=cfg.weight_decay)
# Loss scaling only applies on CUDA; gate it explicitly rather than relying on
# GradScaler's "CUDA is not available" fallback.
scaler=torch.amp.GradScaler("cuda", enabled=cfg.amp and device.type == "cuda")

# ---------- Schedules ----------
def tau_sched(epoch):
    """Routing temperature for a 1-based *epoch*.

    Held at cfg.tau_high through epoch 5, then moved linearly toward
    cfg.tau_low, reaching it at epoch cfg.epochs.
    """
    hold = 5
    if epoch > hold:
        frac = (epoch - hold) / max(1, (cfg.epochs - hold))
        return cfg.tau_high + (cfg.tau_low - cfg.tau_high) * frac
    return cfg.tau_high

def compute_capacity(B, E, k):
    """Per-expert row limit: ceil(cfg.capacity_alpha * B * k / E), at least 1."""
    slots = cfg.capacity_alpha * B * k / max(1, E)
    return max(1, int(math.ceil(slots)))

# ---------- Stable loss EMA ----------
# Running averages of the raw loss terms; the training loop divides each term
# by its EMA so the cfg.w_* weights act on roughly unit-scale values.
loss_ema={"lp":1.0, "ssim":1.0, "l1":1.0, "mse":1.0}
def update_loss_ema(vals:dict):
    """Blend each value of *vals* into loss_ema[key] with decay cfg.ema_beta_loss.

    Raises KeyError for a key not already present in loss_ema.
    """
    beta = cfg.ema_beta_loss
    for name, value in vals.items():
        loss_ema[name] = beta * loss_ema[name] + (1 - beta) * float(value)

# ---------- Loss Diffusion (broadcast) ----------
def diffusion_loss_for_stage(stage:NeuroPoolStage, st:StageStats, cache, delta_batch, used_mask, max_experts:int, max_items:int):
    """Auxiliary loss pulling experts unused on a sample toward the stage's mean update.

    The unit-norm batch-mean of *delta_batch* is the "teacher" direction.
    Up to *max_experts* experts that none of up to *max_items* random rows
    used are run on those rows. Each pays 1 - cos(expert delta, teacher),
    weighted by clamp(cos(task_vec[e], teacher), 0, 1), or by 1 for every
    expert while st.task_vec is still all-zero.

    Args:
        stage: the stage whose experts are evaluated.
        st: that stage's StageStats (task_vec is read, not modified).
        cache: (z_in [B,D], h_pre [B,H], out_pre [B,D]) from stage.forward.
            NOTE(review): not detached by the caller, so this loss also
            back-propagates into fc1 and everything upstream of z_in.
        delta_batch: [B, D] stage output minus input.
        used_mask: [B, E_total] bool, experts used by each row.
        max_experts, max_items: sampling budgets.

    Returns:
        0-dim tensor: cfg.diffuse_lambda * mean weighted misalignment, or 0.
    """
    if not cfg.diffuse_enabled: return delta_batch.new_tensor(0.0)
    z_in, h_pre, out_pre = cache   # [B,D], [B,H], [B,D]
    B=z_in.size(0); D=z_in.size(1)
    with torch.no_grad():
        # teacher direction: unit-norm mean delta over the batch
        tnorm = F.normalize(delta_batch.mean(dim=0), dim=0, eps=1e-6)
    # choose subset of items & experts (non-selected experts only)
    idx_items = torch.randperm(B, device=z_in.device)[:min(max_items, B)]
    used_e = used_mask[idx_items].any(dim=0)     # [E]
    cand_e = (~used_e).nonzero(as_tuple=False).squeeze(1)
    if cand_e.numel()==0: return delta_batch.new_tensor(0.0)
    pick_e = cand_e[torch.randperm(cand_e.numel(), device=z_in.device)[:min(max_experts, cand_e.numel())]]

    # Similarity gate using task_vec. `.any()` instead of `sum().abs() > 0`:
    # a populated task_vec whose entries happen to cancel must still count.
    if st.task_vec.any():
        sim = F.cosine_similarity(F.normalize(st.task_vec[pick_e], dim=-1, eps=1e-6),
                                  tnorm.view(1,-1).expand(pick_e.numel(), D), dim=-1).clamp(0,1)
        gate = sim.detach()
    else:
        gate = torch.ones(pick_e.numel(), device=z_in.device)

    # Loop-invariant: the same sampled rows feed every picked expert.
    z_s = z_in[idx_items]; h_s = h_pre[idx_items]; out_s = out_pre[idx_items]
    teacher_col = tnorm.view(-1,1)

    loss_sum = tnorm.new_tensor(0.0)
    cnt = 0
    for e_i, g in zip(pick_e.tolist(), gate.tolist()):
        if g<=0: continue
        aids = stage.pick_adapters(z_s, torch.tensor([e_i], device=z_in.device))
        y = stage.experts[e_i].forward_batch(z_s, aids, h_s, out_s)          # [m,D]
        delta_e = y - out_s
        # alignment loss: 1 - cos(delta_e, tnorm)
        dnorm = F.normalize(delta_e, dim=-1, eps=1e-6)
        align = 1.0 - (dnorm @ teacher_col).squeeze(1).clamp(-1+1e-6, 1-1e-6)
        loss_sum = loss_sum + g * align.mean()
        cnt += 1
    if cnt==0: return delta_batch.new_tensor(0.0)
    return cfg.diffuse_lambda * (loss_sum / cnt)

# ---------- Soft Pruning (bias suppression) ----------
def soft_prune(epoch, model, stats_list):
    """Lower the gate bias of each stage's least-used experts (nothing is deleted).

    Active from epoch cfg.prune_warmup_epochs on, on epochs divisible by
    cfg.prune_interval. Per stage, k = min(cfg.prune_max_per_stage,
    max(1, E // 100)) experts with the smallest usage_ema share (ties in
    torch.topk order) get cfg.prune_bias_push added to expert_bias.
    NOTE(review): the push accumulates across rounds, so an expert that stays
    unused keeps sinking.
    """
    if epoch < cfg.prune_warmup_epochs:
        return
    if epoch % cfg.prune_interval != 0:
        return
    for stage, st in zip(model.stages, stats_list):
        share, _ = share_hhi(st.usage_ema)
        n_experts = stage.E_total
        n_prune = min(cfg.prune_max_per_stage, max(1, n_experts // 100))
        # Full descending sort; the trailing n_prune ids have the lowest shares.
        _, ranked = torch.topk(share, k=min(n_experts, n_experts))
        victims = ranked[-n_prune:].tolist()
        with torch.no_grad():
            for e in victims:
                stage.expert_bias[e].add_(cfg.prune_bias_push)

# ---------- Train ----------
def save_samples(epoch, model):
    """Save a grid of validation inputs followed by their reconstructions.

    Uses the first cfg.sample_rows**2 images of the first validation batch,
    reconstructs them with tau=1.0 and epsilon=0.0, writes
    ``epoch_XXX.png`` into cfg.out_samples_dir, and leaves *model* in train mode.
    """
    n_show = cfg.sample_rows * cfg.sample_rows
    model.eval()
    first_batch = next(iter(val_loader))
    imgs = first_batch[:n_show].to(device)
    caps = [compute_capacity(imgs.size(0), st.E, cfg.top_k) for st in stats_list]
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=cfg.amp):
        rec, *_ = model(imgs, tau=1.0, epsilon=0.0, top_k=cfg.top_k, capacities=caps)
    panels = torch.cat([denorm(imgs), denorm(rec)], dim=0)
    grid = make_grid(panels, nrow=cfg.sample_rows)
    save_image(grid, os.path.join(cfg.out_samples_dir, f"epoch_{epoch:03d}.png"))
    model.train()

def log_epoch(epoch, report):
    """Print an epoch summary plus per-stage routing stats, and dump *report*
    as JSON to ``cfg.out_log_dir/epoch_XXX_report.json``."""
    print(f"\n[Epoch {epoch}] TrainMain={report['train_main']:.4f} | Val={report['val']:.4f} | CapHit={report['cap_hit']:.3f}")
    stage_entries = [(name, entry) for name, entry in report.items() if name.startswith("stage_")]
    for name, entry in stage_entries:
        print(f"  [{name}] Pools={entry['pools']} Experts={entry['experts']} | HHI={entry['hhi']:.4f} | TopShare={entry['top_share'][:3]}")
    report_path = os.path.join(cfg.out_log_dir, f"epoch_{epoch:03d}_report.json")
    with open(report_path, "w") as fh:
        json.dump(report, fh, indent=2)

# Report dict of every finished epoch, in order (also written to disk by log_epoch).
per_epoch_reports=[]
opt.zero_grad(set_to_none=True)

# Mixed precision only has an effect on CUDA; gate it explicitly.
use_amp = cfg.amp and device.type == "cuda"

for epoch in range(1, cfg.epochs+1):
    model.train()
    tau=tau_sched(epoch); eps=cfg.epsilon

    # Explorer bias: flag the least-used pools (by pool_usage_ema) of each
    # stage and refresh their decaying additive gate bias.
    with torch.no_grad():
        for stage, st in zip(model.stages, stats_list):
            P=stage.P; k_exp=max(1,int(P*cfg.explorer_ratio))
            idx=torch.argsort(st.pool_usage_ema)[:k_exp]
            stage.explorer_mask.zero_(); stage.explorer_mask[idx]=True
            stage.explorer_bias = stage.explorer_bias*cfg.explorer_bias_decay + cfg.explorer_bias_init*(stage.explorer_mask.float())

    running_main=0.0; running_lp=0.0; running_ssim=0.0; running_l1=0.0; running_mse=0.0
    rep_sum=0.0; cap_sum=0.0; cap_cnt=0

    pbar=tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs} (train)")
    for batch in pbar:
        batch=batch.to(device, non_blocking=True)
        B=batch.size(0)
        capacities=[compute_capacity(B, st.E, cfg.top_k) for st in stats_list]

        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            x_rec, z0, zL, stage_masks, stage_pools, stage_pool_probs, stage_deltas, stage_cache, cap_hit = \
                model(batch, tau=tau, epsilon=eps, top_k=cfg.top_k, capacities=capacities)

            den_gt=denorm(batch).float(); den_rec=denorm(x_rec).float()
            # LPIPS expects [-1, 1] inputs by default; normalize=True tells it
            # these are [0, 1] images (previously fed unconverted).
            lp = percept(den_rec, den_gt, normalize=True).mean()
            ssim = ssim_loss(den_rec, den_gt)
            l1 = F.l1_loss(x_rec, batch)
            mse = F.mse_loss(x_rec, batch)

            # Update EMA & normalize losses (stabilize scale)
            update_loss_ema({"lp":lp.item(), "ssim":ssim.item(), "l1":l1.item(), "mse":mse.item()})
            lp_n   = lp   / (loss_ema["lp"]   + cfg.loss_eps)
            ssim_n = ssim / (loss_ema["ssim"] + cfg.loss_eps)
            l1_n   = l1   / (loss_ema["l1"]   + cfg.loss_eps)
            mse_n  = mse  / (loss_ema["mse"]  + cfg.loss_eps)

            loss_main = cfg.w_lpips*lp_n + cfg.w_ssim*ssim_n + cfg.w_l1*l1_n + cfg.w_mse*mse_n

            # Repulsion (pool/expert prototypes)
            rep=0.0
            for stage in model.stages:
                rep = rep + repulsion_loss(stage.pool_proto, m=cfg.repulsion_subset, sigma=cfg.repulsion_sigma)
                rep = rep + repulsion_loss(stage.expert_proto, m=cfg.repulsion_subset, sigma=cfg.repulsion_sigma)

            # Loss Diffusion (broadcast) per stage
            diff_loss = z0.new_tensor(0.0)
            if cfg.diffuse_enabled:
                max_items = min(cfg.diffuse_batch_items, B)
                max_exps  = cfg.diffuse_experts_per_stage
                for stg, st, cache, dlt, msk in zip(model.stages, stats_list, stage_cache, stage_deltas, stage_masks):
                    diff_loss = diff_loss + diffusion_loss_for_stage(stg, st, cache, dlt.detach(), msk.detach(), max_exps, max_items)

            loss = loss_main + cfg.repulsion_lambda*rep + diff_loss
            loss = torch.nan_to_num(loss, nan=0.0, posinf=1e4, neginf=-1e4)

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(params, cfg.clip_grad_norm)
        scaler.step(opt); scaler.update()

        running_main+=float(loss_main.item())
        running_lp+=float(lp.item()); running_ssim+=float(ssim.item())
        running_l1+=float(l1.item()); running_mse+=float(mse.item())
        rep_sum+=float(rep); cap_sum+=cap_hit; cap_cnt+=1

        # Stats update. This runs outside autocast, but under AMP the stage
        # deltas were produced in float16; cast to float32 before the matmul
        # with the float32 mask (mixed-dtype matmul raises a RuntimeError).
        with torch.no_grad():
            for (mask, pools, stage, st, delta) in zip(stage_masks, stage_pools, model.stages, stats_list, stage_deltas):
                mask_f=mask.float()
                counts=mask_f.sum(0)
                st.usage_epoch[:stage.E_total]+=counts
                st.pool_usage_epoch[:stage.P]+=pools.float().sum(0)
                if counts.sum()>0:
                    sum_e=delta.float().transpose(0,1)@mask_f    # [D,E]
                    mean_e=(sum_e/(counts+1e-6)).transpose(0,1)  # [E,D]
                    use_idx=torch.nonzero(counts>0, as_tuple=False).squeeze(1)
                    st.task_vec[use_idx]=0.9*st.task_vec[use_idx] + 0.1*mean_e[use_idx]

        pbar.set_postfix(loss=f"{loss_main.item():.4f}", lpips=f"{lp.item():.3f}",
                         ssim=f"{ssim.item():.3f}", L1=f"{l1.item():.3f}",
                         tau=f"{tau:.2f}", cap=f"{cap_hit:.3f}")

    for st in stats_list: st.step_decay(beta=0.9)

    # Soft pruning (bias suppression)
    soft_prune(epoch, model, stats_list)

    # Validation
    model.eval(); val_loss=0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{cfg.epochs} (val)"):
            batch=batch.to(device, non_blocking=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                x_rec, *_ = model(batch, tau=tau, epsilon=0.0, top_k=cfg.top_k,
                                  capacities=[compute_capacity(batch.size(0), st.E, cfg.top_k) for st in stats_list])
                den_gt=denorm(batch).float(); den_rec=denorm(x_rec).float()
                lp=percept(den_rec, den_gt, normalize=True).mean()
                ssim=ssim_loss(den_rec, den_gt)
                l1=F.l1_loss(x_rec, batch); mse=F.mse_loss(x_rec, batch)
                # same EMA-normalized combination as training, for a comparable metric
                lp_n=lp/(loss_ema["lp"]+cfg.loss_eps)
                ssim_n=ssim/(loss_ema["ssim"]+cfg.loss_eps)
                l1_n=l1/(loss_ema["l1"]+cfg.loss_eps)
                mse_n=mse/(loss_ema["mse"]+cfg.loss_eps)
                loss_v = cfg.w_lpips*lp_n + cfg.w_ssim*ssim_n + cfg.w_l1*l1_n + cfg.w_mse*mse_n
            val_loss += float(loss_v.item())*batch.size(0)
    val_loss/=max(1,len(val_ds))

    # Save sample & checkpoint
    save_samples(epoch, model)
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "config": cfg.__dict__,
        "loss_ema": loss_ema,
        "stats": {
            f"stage_{i+1}":{
                "usage_ema": st.usage_ema.detach().cpu().tolist(),
                "pool_usage_ema": st.pool_usage_ema.detach().cpu().tolist(),
                "task_vec": st.task_vec.detach().cpu().tolist(),
            } for i,st in enumerate(stats_list)
        }
    }, os.path.join(cfg.out_ckpt_dir, f"epoch_{epoch:03d}.pt"))

    # Reporting (guard against an empty train loader)
    n_train=max(1, len(train_loader))
    report={
        "epoch": epoch,
        "train_main": running_main/n_train,
        "train_lpips": running_lp/n_train,
        "train_ssim": running_ssim/n_train,
        "train_l1": running_l1/n_train,
        "train_mse": running_mse/n_train,
        "train_repulsion": rep_sum/n_train,
        "cap_hit": cap_sum/max(1,cap_cnt),
        "val": val_loss,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }
    for i,(stage, st) in enumerate(zip(model.stages, stats_list), start=1):
        share,hhi=share_hhi(st.usage_ema)
        topk=min(5, stage.E_total)
        if topk>0:
            vals, ids = torch.topk(share, k=topk)
            top=[(int(ids[j]), float(vals[j])) for j in range(topk)]
        else:
            top=[]
        report[f"stage_{i}"]={"pools":int(stage.P), "experts":int(stage.E_total),
                              "hhi":hhi, "top_share":top}
    log_epoch(epoch, report)
    per_epoch_reports.append(report)

print("✅ Training complete. See ./checkpoints, ./samples, ./logs")


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100    17  100    17    0     0     29      0 --:--:-- --:--:-- --:--:--    29
100   496    0   496    0     0    468      0 --:--:--  0:00:01 --:--:--  2084
100 8297k  100 8297k    0     0  1918k      0  0:00:04  0:00:04 --:--:-- 3663k
Archive:  temp.zip
  inflating: images/0.jpg            
  inflating: images/1.jpg            
  inflating: images/10.jpg           
  inflating: images/100.jpg          
  inflating: images/1000.jpg         
  inflating: images/1001.jpg         
  inflating: images/1002.jpg         
  inflating: images/1003.jpg         
  inflating: images/1004.jpg         
  inflating: images/1005.jpg         
  inflating: images/1006.jpg         
  inflating: images/1007.jpg         
  inflating: images/1008.jpg         
  inflating: images/1009.jpg         
  inflating: images/101.jpg          
  inflating: ima