<a href="https://colab.research.google.com/github/username12345678901234567890/1234/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# %% [single-cell] Self-Organizing MoE v3.1 — E=16384 start, Lower S-threshold, Single-reporter, Growth+Prune, Device-safe
import os, sys, json, time, glob, math, random, subprocess
from dataclasses import dataclass
from typing import List, Tuple, Optional

# I/O
# Create the output directories up front; exist_ok makes re-running the cell harmless.
os.makedirs("./checkpoints", exist_ok=True)
os.makedirs("./samples", exist_ok=True)

# ---- LPIPS (optional) ----
# _HAS_LPIPS records whether the perceptual-loss package could be imported.
# If the first import fails, one pip install of a pinned version is attempted
# with the current interpreter; any failure there leaves the flag False.
_HAS_LPIPS=False
try:
    import lpips  # noqa
    _HAS_LPIPS=True
except Exception:
    try:
        # NOTE(review): installs a package at import time; requires network access.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "lpips==0.1.4"])
        import lpips
        _HAS_LPIPS=True
    except Exception:
        # Best effort: the rest of the script checks _HAS_LPIPS before using lpips.
        _HAS_LPIPS=False

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm.auto import tqdm

# Enable cuDNN's benchmark mode (global torch setting).
torch.backends.cudnn.benchmark = True
# Single module-level device used by every helper below: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Device] {device}")

# ---------------- Config ----------------
@dataclass
class Config:
    """All hyperparameters and paths for this script, read through the module-level ``cfg``.

    Fields are grouped by the subsystem that reads them. Several fields
    (``report_K``, ``report_M``, ``review_Kp``, ``review_Mp``, ``dup_cos_thr``)
    are not referenced anywhere in the visible code.
    """
    data_root: str="/images"   # directory scanned recursively for images
    image_size: int=128        # resize / center-crop side length
    epochs: int=20
    batch_size: int=192

    # Colab-safe
    num_workers: int=0
    pin_memory: bool=True

    # Optimizer (SGD for lighter state)
    lr: float=2e-3
    momentum: float=0.9
    weight_decay: float=1e-4
    amp: bool=True             # enables autocast and the GradScaler

    # Latent/model
    z_dim: int=512             # latent width produced by the encoder
    hidden: int=512            # hidden width of the shared MoE MLP

    # Experts / routing
    init_experts_desired: int=16384  # starting expert count (halved on OOM at build time)
    top_k: int=4                     # experts selected per sample
    epsilon_start: float=0.20        # uniform-mixing weight on gate probs, first epoch
    epsilon_end: float=0.05          # uniform-mixing weight on gate probs, last epoch
    tau_high: float=1.8              # gate temperature during the first epochs
    tau_low: float=0.7               # gate temperature reached at the final epoch

    # Capacity
    capacity_alpha: float=1.2        # multiplier on the per-expert load B*k/E

    # Entropy/repulsion
    gate_entropy_lambda_warm: float=0.05   # entropy-term weight for epochs <= 5
    gate_entropy_lambda_cool: float=0.01   # entropy-term weight afterwards
    repulsion_lambda: float=1e-3           # weight of the prototype repulsion term
    repulsion_subset: int=256              # prototypes sampled per repulsion evaluation
    repulsion_sigma: float=0.15            # divisor applied to cosine similarity before exp

    # Per-expert virtual neurons (memory-light)
    lora_rank: int=2
    adapters_per_expert: int=8

    # Pattern reporting (lower threshold; single-reporter)
    report_theta: float=1.2    # novelty score S at or above this triggers growth
    report_K: int=4
    report_M: int=2
    review_Kp: int=6
    review_Mp: int=2
    dup_cos_thr: float=0.90

    # Regulation & credit
    big_share_thr: float=0.05        # usage share that makes an expert a donor candidate
    big_topN: int=5                  # top-N experts by share are always donor candidates
    share_tax_thr: float=0.10        # share above this is taxed from credit
    share_tax_lambda: float=0.2
    neg_interest: float=0.03         # extra multiplier applied to negative credit each epoch
    credit_decay: float=0.98
    credit_false_penalty: float=1.0
    credit_true_reward: float=2.0
    credit_contrib_scale: float=5.0
    credit_diversity_scale: float=1.0
    # NOTE(review): "bankrupcy" is misspelled, but the names are read elsewhere in the file.
    bankrupcy_level: float=10.0      # credit <= -level increments an expert's negative streak
    bankrupcy_streak: int=3          # streak length required before pruning

    # Recruitment
    team_base_Q: int=256             # base adapter budget converted into new experts on growth
    team_alpha_S: float=0.5          # how strongly S above the threshold scales that budget
    recruit_donor_bias: float=-0.20  # added to donors' gate bias after growth
    recruit_new_bias: float=+0.30    # initial gate bias of a newly added expert

    # Pruning
    enable_prune: bool=True
    prune_min_keep: int=1024         # never prune below this many experts
    prune_usage_eps: float=1e-3      # usage EMA at or below this is "unused"
    prune_credit_level: float=10.0   # credit <= -level is required for pruning
    prune_max_per_epoch: int=2048

    # Saving
    out_ckpt_dir: str="./checkpoints"
    out_samples_dir: str="./samples"
    sample_rows: int=4               # sample grid uses sample_rows**2 images
    seed: int=42

# Global configuration instance read by the functions and classes below.
cfg = Config()
# Seed Python's and torch's RNGs from the configured seed.
random.seed(cfg.seed); torch.manual_seed(cfg.seed)

# ---------------- Dataset ----------------
IMG_EXTS=(".png",".jpg",".jpeg",".bmp",".webp")

class PlainImageDataset(Dataset):
    """Unlabelled image dataset over every image file found recursively under ``root``.

    Args:
        root: directory to scan; files are kept when their lower-cased
            extension is in ``IMG_EXTS``.
        size: target side length for ``Resize`` and ``CenterCrop``.
        train: when True, a random horizontal flip and a mild colour jitter
            are inserted between the resize and the crop.

    Raises:
        RuntimeError: if no matching file is found.

    Each item is a normalised RGB tensor (mean 0.5, std 0.5 per channel).
    """
    def __init__(self, root, size=128, train=True):
        pattern = os.path.join(root, "**", "*")
        # Sort the listing: glob order depends on the filesystem, and an
        # unsorted list makes the seeded train/val split irreproducible.
        self.paths = sorted(
            p for p in glob.glob(pattern, recursive=True)
            if os.path.splitext(p)[1].lower() in IMG_EXTS
        )
        if not self.paths:
            raise RuntimeError(f"No images found under {root}.")
        aug = [transforms.Resize(size)]
        if train:
            aug.append(transforms.RandomHorizontalFlip(0.5))
            aug.append(transforms.ColorJitter(0.1, 0.1, 0.05, 0.02))
        aug.append(transforms.CenterCrop(size))
        self.tf = transforms.Compose(aug + [
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # Close the file handle deterministically instead of waiting for GC.
        with Image.open(self.paths[i]) as img:
            return self.tf(img.convert("RGB"))

# Build the dataset once with training augmentations, then carve out a
# seeded validation subset.
full_ds=PlainImageDataset(cfg.data_root, size=cfg.image_size, train=True)
n_total=len(full_ds)
val_ratio=0.02 if n_total>2000 else 0.1
# The 16-image floor must never exceed half the data, otherwise train_len
# goes negative on tiny datasets and random_split fails.
val_len=min(max(16, int(n_total*val_ratio)), n_total//2)
train_len=n_total-val_len
if train_len < cfg.batch_size:
    # drop_last=True would leave the training loader with zero batches.
    raise RuntimeError(
        f"Only {train_len} training images for batch_size={cfg.batch_size}; "
        f"add data or lower cfg.batch_size."
    )
train_ds, val_ds = random_split(full_ds,[train_len,val_len], generator=torch.Generator().manual_seed(cfg.seed))

# Validation must not see random flips / colour jitter: point the validation
# subset at an augmentation-free view over exactly the same file list.
eval_ds=PlainImageDataset(cfg.data_root, size=cfg.image_size, train=False)
eval_ds.paths=full_ds.paths
val_ds.dataset=eval_ds

train_loader=DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                        num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=True)
val_loader=DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                      num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=False)
print(f"[Data] train={len(train_ds)} val={len(val_ds)} batch={cfg.batch_size}")

# ---------------- Model ----------------
def conv_block(in_ch,out_ch,k=3,s=1,p=1):
    """Return a Conv2d -> GroupNorm(8 groups) -> SiLU stack as one nn.Sequential.

    ``k``, ``s`` and ``p`` are the convolution's kernel size, stride and padding.
    """
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
        nn.GroupNorm(num_groups=8, num_channels=out_ch),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """Convolutional encoder mapping an RGB image batch to a ``z_dim`` latent.

    Four of the ten conv blocks use stride 2, so a 128x128 input reaches the
    projection as 512 channels x 8 x 8, which is the size the linear layer is
    hard-wired to accept.
    """
    def __init__(self, z_dim):
        super().__init__()
        base = 64
        stages = [
            conv_block(3, base), conv_block(base, base),
            conv_block(base, base*2, s=2), conv_block(base*2, base*2),
            conv_block(base*2, base*4, s=2), conv_block(base*4, base*4),
            conv_block(base*4, base*8, s=2), conv_block(base*8, base*8),
            conv_block(base*8, base*8, s=2), conv_block(base*8, base*8),
        ]
        self.net = nn.Sequential(*stages)
        self.proj = nn.Linear(base*8*8*8, z_dim)

    def forward(self,x):
        feats = self.net(x)
        flat = feats.view(x.size(0), -1)   # one row per sample
        return self.proj(flat)

class Decoder(nn.Module):
    """Decoder turning a ``z_dim`` latent back into an RGB image in [-1, 1].

    The latent is projected to a 512 x 8 x 8 map, then passed through conv
    blocks interleaved with four nearest-neighbour 2x upsamplings
    (8 -> 128 pixels per side) and a final Tanh.
    """
    def __init__(self,z_dim):
        super().__init__()
        base = 64
        self.fc = nn.Linear(z_dim, base*8*8*8)

        def grow():
            # Nearest-neighbour doubling of the spatial size.
            return nn.Upsample(scale_factor=2, mode='nearest')

        self.up = nn.Sequential(
            conv_block(base*8, base*8),
            grow(),
            conv_block(base*8, base*8),
            grow(),
            conv_block(base*8, base*4),
            grow(),
            conv_block(base*4, base*2),
            grow(),
            conv_block(base*2, base),
            nn.Conv2d(base, 3, 3, 1, 1), nn.Tanh()
        )

    def forward(self,z):
        seed_map = self.fc(z)
        seed_map = seed_map.view(z.size(0), 512, 8, 8)
        return self.up(seed_map)

class ExpertLoRA(nn.Module):
    """Bank of ``W`` rank-``rank`` adapter pairs that perturb a shared two-layer MLP.

    Parameter shapes: ``a1`` [W, r, D], ``b1`` [W, H, r], ``a2`` [W, r, H],
    ``b2`` [W, D, r], with D = ``z_dim`` and H = ``hidden``. All four are
    initialised from a normal distribution scaled by 0.02.
    """
    def __init__(self, z_dim, hidden, W, rank):
        super().__init__()
        self.z_dim, self.hidden, self.W, self.r = z_dim, hidden, W, rank
        init_std = 0.02
        self.a1 = nn.Parameter(torch.randn(W, rank, z_dim) * init_std)   # [W,r,D]
        self.b1 = nn.Parameter(torch.randn(W, hidden, rank) * init_std)  # [W,H,r]
        self.a2 = nn.Parameter(torch.randn(W, rank, hidden) * init_std)  # [W,r,H]
        self.b2 = nn.Parameter(torch.randn(W, z_dim, rank) * init_std)   # [W,D,r]
        self.scale = 1.0

    def forward_batch(self, z, a_idx, h_pre, out_pre):
        """Apply each row's selected adapter on top of the shared activations.

        ``a_idx[n]`` picks the adapter for row ``n``. The first low-rank pair
        adds a delta to ``h_pre`` before SiLU; the second adds a delta to
        ``out_pre``. Everything is cast to ``z``'s dtype.
        """
        dtype = z.dtype

        # First layer: z -> rank -> hidden delta.
        down1 = self.a1[a_idx].to(dtype)
        up1 = self.b1[a_idx].to(dtype)
        low1 = torch.einsum('nd,nrd->nr', z, down1)
        hidden_delta = torch.einsum('nr,nhr->nh', low1, up1) * self.scale
        hidden_act = F.silu(h_pre + hidden_delta).to(dtype)

        # Second layer: hidden -> rank -> latent delta.
        down2 = self.a2[a_idx].to(dtype)
        up2 = self.b2[a_idx].to(dtype)
        low2 = torch.einsum('nh,nrh->nr', hidden_act, down2)
        out_delta = torch.einsum('nr,ndr->nd', low2, up2) * self.scale

        return (out_pre + out_delta).to(dtype)

class MoELatent(nn.Module):
    """Latent-space mixture of LoRA experts sharing one two-layer MLP.

    Routing scores are dot products between the latent and per-expert
    ``prototypes`` (temperature-scaled, plus ``logit_bias``). Inside a chosen
    expert, one adapter row is picked per sample via ``adapter_keys``.

    The expert set can grow (``add_expert``) and shrink (``remove_experts``).
    Both REPLACE the ``prototypes`` / ``adapter_keys`` / ``logit_bias``
    Parameter objects, so an optimizer built earlier no longer references the
    live parameters and must be rebuilt by the caller.
    """
    def __init__(self, z_dim, hidden, init_experts, adapters_per_expert, rank):
        super().__init__()
        # shared MLP
        self.fc1=nn.Linear(z_dim, hidden)
        self.fc2=nn.Linear(hidden, z_dim)
        nn.init.kaiming_uniform_(self.fc1.weight, a=math.sqrt(5))
        nn.init.zeros_(self.fc1.bias); nn.init.zeros_(self.fc2.bias)

        self.z_dim=z_dim; self.hidden=hidden
        # Remember the expert geometry so add_expert builds matching experts
        # without reaching for a global config object.
        self.adapters_per_expert=adapters_per_expert
        self.rank=rank
        self.experts=nn.ModuleList([ExpertLoRA(z_dim,hidden,adapters_per_expert,rank) for _ in range(init_experts)])
        self.prototypes=nn.Parameter(torch.randn(init_experts, z_dim)*0.02)
        self.adapter_keys=nn.Parameter(torch.randn(init_experts, adapters_per_expert, z_dim)*0.02)
        self.logit_bias=nn.Parameter(torch.zeros(init_experts))

    def num_experts(self):
        """Return the current number of experts."""
        return len(self.experts)

    @torch.no_grad()
    def add_expert(self, clone_from: Optional[List[int]]=None, donors_fraction: float=0.7, noise_std: float=0.02, bias_boost: float=0.0):
        """Append one expert on the same device/dtype as the existing parameters.

        Args:
            clone_from: donor expert indices; out-of-range entries are ignored.
            donors_fraction: fraction of adapter rows copied (plus Gaussian
                noise) from randomly chosen donors; the rest keep their
                random initialisation.
            noise_std: std of the noise added to cloned rows, and of the
                random prototype/keys used when there are no valid donors.
            bias_boost: initial gate-logit bias of the new expert.

        With valid donors the new prototype and adapter keys are the donors'
        means. The three routing Parameters are re-created one row longer.
        """
        dev = self.prototypes.device
        dt  = self.prototypes.dtype
        W=self.adapters_per_expert; D=self.z_dim; H=self.hidden; R=self.rank

        new = ExpertLoRA(D,H,W,R).to(device=dev, dtype=dt)

        donors = [d for d in (clone_from or []) if 0<=d<self.num_experts()]
        if donors:
            take = int(W*donors_fraction)
            idxs = torch.randperm(W, device=dev)[:take]
            donor_sel = torch.randint(0, len(donors), (take,), device=dev)
            for wrow,di in zip(idxs.tolist(), donor_sel.tolist()):
                donor = self.experts[donors[di]]
                for pname in ("a1", "b1", "a2", "b2"):
                    dst = getattr(new, pname)[wrow]
                    src = getattr(donor, pname)[wrow].to(dev, dt)
                    dst.copy_(src + noise_std*torch.randn_like(dst))
            proto = self.prototypes.data[donors].mean(dim=0).to(dev, dt).clone()
            keys  = self.adapter_keys.data[donors].mean(dim=0).to(dev, dt).clone()
        else:
            proto=torch.randn(D, device=dev, dtype=dt)*noise_std
            keys=torch.randn(W, D, device=dev, dtype=dt)*noise_std

        self.experts.append(new)
        self.prototypes = nn.Parameter(torch.cat([self.prototypes.data, proto[None,:]], dim=0).to(dev, dt))
        self.adapter_keys = nn.Parameter(torch.cat([self.adapter_keys.data, keys[None,:,:]], dim=0).to(dev, dt))
        self.logit_bias = nn.Parameter(torch.cat([self.logit_bias.data, torch.tensor([bias_boost], device=dev, dtype=dt)], dim=0))

    @torch.no_grad()
    def remove_experts(self, remove_idx: List[int]):
        """Drop the listed experts and return how many were actually removed.

        Indices outside the current range are ignored. Surviving experts keep
        their relative order; the routing Parameters are re-created.
        """
        if not remove_idx: return 0
        E=self.num_experts()
        doomed = set(remove_idx)   # built once: O(1) membership in the scan below
        keep = [i for i in range(E) if i not in doomed]
        if len(keep)==E: return 0
        self.experts = nn.ModuleList([self.experts[i] for i in keep])
        dev=self.prototypes.device; dt=self.prototypes.dtype
        self.prototypes = nn.Parameter(self.prototypes.data[keep].clone().to(dev, dt))
        self.adapter_keys = nn.Parameter(self.adapter_keys.data[keep].clone().to(dev, dt))
        self.logit_bias = nn.Parameter(self.logit_bias.data[keep].clone().to(dev, dt))
        return E - len(keep)

    def capacity_route(self, probs, top_k, capacity):
        """Greedy top-k routing with a per-expert capacity limit.

        Samples are processed in batch order; each walks its candidates
        (up to 4*k highest-probability experts) and takes the first ones that
        still have room.

        Returns:
            assigned_idx: [B, k] long expert ids, -1 for unfilled slots.
            assigned_w:   [B, k] gate probabilities (0 for unfilled slots),
                          still attached to ``probs`` for autograd.
            cap_hit_rate: fraction of the B*k slots whose expert was NOT the
                          natural next-best choice, i.e. at least one
                          better-ranked candidate was skipped because it was
                          full.
        """
        B,E=probs.shape
        k=min(top_k, E)
        cand_k=min(E, max(k*4, k))
        vals, idxs=torch.topk(probs, k=cand_k, dim=-1)
        limit=max(1,capacity)

        # One device->host copy instead of a sync per candidate.
        candidates=idxs.tolist()
        load={}                                   # expert id -> accepted samples
        slot_expert=[[-1]*k for _ in range(B)]
        slot_rank=[[0]*k for _ in range(B)]       # candidate rank feeding each slot
        cap_hit=0
        for b,row in enumerate(candidates):
            slot=0
            for j,e in enumerate(row):
                if slot>=k: break
                if load.get(e,0) < limit:
                    slot_expert[b][slot]=e
                    slot_rank[b][slot]=j
                    load[e]=load.get(e,0)+1
                    # Rank j == slot means nothing was skipped; only j > slot
                    # indicates a capacity-induced fallback.
                    if j>slot: cap_hit+=1
                    slot+=1

        assigned_idx=torch.tensor(slot_expert, dtype=torch.long, device=probs.device).reshape(B,k)
        rank_t=torch.tensor(slot_rank, dtype=torch.long, device=probs.device).reshape(B,k)
        assigned_w=vals.gather(1, rank_t) * (assigned_idx>=0).to(vals.dtype)
        cap_hit_rate = cap_hit/max(1,B*k)
        return assigned_idx, assigned_w, cap_hit_rate

    def pick_adapters(self, z, expert_ids):
        """For each row, return the adapter index whose key has the largest dot product with ``z``."""
        N=z.size(0); dt=z.dtype
        a_idx=torch.zeros(N, dtype=torch.long, device=z.device)
        for u in expert_ids.unique():
            u=int(u.item())
            sel=(expert_ids==u)
            logits = z[sel] @ self.adapter_keys[u].to(dt).t()
            a_idx[sel]=torch.argmax(logits, dim=-1)
        return a_idx

    def forward(self, z, top_k, tau, epsilon, capacity, ban_expert: Optional[int]=None):
        """Route ``z`` through its assigned experts and return the residual update.

        Args:
            z: latent batch, one row per sample.
            top_k: experts per sample.
            tau: softmax temperature (floored at 1e-6).
            epsilon: weight of the uniform distribution mixed into the gate.
            capacity: per-expert sample limit for this batch.
            ban_expert: optional expert id whose logit is forced to -1e9.

        Returns:
            ``(z + z_out, probs, used_mask, assigned_idx, assigned_w, cap_hit_rate)``
            where ``z_out`` is the gate-weighted sum of expert outputs and
            ``used_mask`` is a [B, E] bool table of which expert served which row.
        """
        h_pre = F.silu(self.fc1(z))
        out_pre = self.fc2(h_pre)

        logits=F.linear(z, self.prototypes) / max(tau,1e-6)
        logits = logits + self.logit_bias
        if ban_expert is not None and 0<=ban_expert<self.num_experts():
            logits[:, ban_expert] = -1e9
        probs=F.softmax(logits, dim=-1)
        if epsilon>0:
            probs=(1-epsilon)*probs + epsilon*(torch.ones_like(probs)/probs.size(1))

        assigned_idx, assigned_w, cap_hit_rate = self.capacity_route(probs, top_k, capacity)
        B,E=probs.shape; k=assigned_idx.size(1)
        z_out=torch.zeros_like(z)
        used_mask=torch.zeros(B,E, dtype=torch.bool, device=z.device)

        for j in range(k):
            idx_j=assigned_idx[:,j]; w_j=assigned_w[:,j].unsqueeze(1).to(z.dtype)
            sel=(idx_j>=0)
            if not sel.any(): continue
            rows=torch.nonzero(sel, as_tuple=False).squeeze(1)   # batch rows with a filled slot j
            idx_j_sel=idx_j[sel]; z_sel=z[sel]; h_sel=h_pre[sel]; out_sel=out_pre[sel]
            w_sel=w_j[sel]
            for u in idx_j_sel.unique():
                u=int(u.item())
                sub=(idx_j_sel==u)
                z_sub=z_sel[sub]; h_sub=h_sel[sub]; out_sub=out_sel[sub]
                a_sub=self.pick_adapters(z_sub, torch.full((z_sub.size(0),), u, dtype=torch.long, device=z.device))
                y=self.experts[u].forward_batch(z_sub, a_sub, h_sub, out_sub).to(z.dtype)
                b_idx=rows[sub]
                z_out[b_idx] = z_out[b_idx] + w_sel[sub] * y
                used_mask[b_idx, u]=True

        return z + z_out, probs, used_mask, assigned_idx, assigned_w, cap_hit_rate

class AutoEncoderMoE(nn.Module):
    """Encoder -> latent mixture-of-experts -> decoder, sized from the global ``cfg``."""
    def __init__(self, init_E: int):
        super().__init__()
        self.enc = Encoder(cfg.z_dim)
        self.moe = MoELatent(
            cfg.z_dim, cfg.hidden, init_E,
            cfg.adapters_per_expert, cfg.lora_rank,
        )
        self.dec = Decoder(cfg.z_dim)

    def forward(self, x, top_k, tau, epsilon, capacity, ban_expert: Optional[int]=None):
        """Reconstruct ``x``.

        Returns ``(x_rec, z, z_moe, probs, used_mask, assigned_idx,
        assigned_w, cap_hit_rate)``: the reconstruction, the raw and
        expert-refined latents, then the MoE layer's routing outputs.
        """
        latent = self.enc(x)
        moe_out = self.moe(latent, top_k, tau, epsilon, capacity, ban_expert=ban_expert)
        refined, gate_probs, used, slots, weights, hit_rate = moe_out
        recon = self.dec(refined)
        return recon, latent, refined, gate_probs, used, slots, weights, hit_rate

def denorm(x):
    """Clamp ``x`` to [-1, 1] and rescale it linearly to [0, 1]."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) * 0.5

# ---------------- Loss helpers ----------------
def gate_entropy(probs):
    """Mean Shannon entropy (nats) over rows; probabilities are floored at 1e-8 before the log."""
    safe = probs.clamp_min(1e-8)
    per_row = (-safe * safe.log()).sum(dim=-1)
    return per_row.mean()

def repulsion_loss(protos, m=256, sigma=0.15):
    """Penalty that grows with the cosine similarity between prototype pairs.

    A random subset of at most ``m`` rows is L2-normalised; the result is the
    mean of ``exp(cos / sigma)`` over all off-diagonal pairs. With fewer than
    two prototypes a zero that stays attached to ``protos`` is returned.
    """
    count = protos.size(0)
    if count <= 1:
        return protos.sum() * 0
    chosen = torch.randperm(count, device=protos.device)[:min(m, count)]
    unit = F.normalize(protos[chosen], dim=-1)
    cos = unit @ unit.t()
    off_diag = ~torch.eye(unit.size(0), device=unit.device, dtype=torch.bool)
    pair_sims = cos[off_diag]
    return torch.exp(pair_sims / sigma).mean()

# ---------------- Economy & Pattern Board ----------------
class ExpertStats:
    """Per-expert bookkeeping vectors kept on the module-level ``device``.

    Float32 vectors: ``usage_epoch``, ``usage_ema``, ``contrib_ema``,
    ``credit``. Int32 vector: ``neg_streak``. ``E`` is their common length.
    """
    _FLOAT_FIELDS = ("usage_epoch", "usage_ema", "contrib_ema", "credit")

    def __init__(self, E:int):
        self.reset(E)

    def reset(self,E):
        """(Re)allocate every vector as zeros of length ``E``."""
        self.E = E
        for field in self._FLOAT_FIELDS:
            setattr(self, field, torch.zeros(E, dtype=torch.float32, device=device))
        self.neg_streak = torch.zeros(E, dtype=torch.int32, device=device)

    def on_structure_add(self, new_added:int):
        """Append ``new_added`` zero entries to every vector; no-op when not positive."""
        if new_added <= 0:
            return
        self.E += new_added
        for field in self._FLOAT_FIELDS:
            grown = torch.cat([getattr(self, field), torch.zeros(new_added, device=device)])
            setattr(self, field, grown)
        streak_pad = torch.zeros(new_added, dtype=torch.int32, device=device)
        self.neg_streak = torch.cat([self.neg_streak, streak_pad])

    def on_structure_remove(self, keep_idx: List[int]):
        """Keep only the entries at ``keep_idx``, in the given order."""
        survivors = torch.tensor(keep_idx, dtype=torch.long, device=device)
        self.E = survivors.numel()
        for field in self._FLOAT_FIELDS + ("neg_streak",):
            setattr(self, field, getattr(self, field)[survivors])

    def epoch_decay(self, decay):
        """Fold this epoch's usage counts into the EMA, then zero the counts."""
        self.usage_ema = decay*self.usage_ema + (1-decay)*self.usage_epoch
        self.usage_epoch.zero_()

class EWMA:
    """Exponentially weighted running mean and variance of a scalar stream."""
    def __init__(self, beta=0.98):
        self.beta = beta
        self.mean = None
        self.var = None

    def update(self, x: float):
        """Fold one observation in; the first one sets mean=x and var=1e-6."""
        x = float(x)
        if self.mean is None:
            self.mean, self.var = x, 1e-6
            return
        prev_mean, prev_var = self.mean, self.var
        new_mean = self.beta*prev_mean + (1-self.beta)*x
        # Uses the deviation from both the new and the old mean, plus a tiny floor.
        new_var = self.beta*prev_var + (1-self.beta)*(x-new_mean)*(x-prev_mean)+1e-9
        self.mean, self.var = new_mean, new_var

    def normz(self, x):
        """Return ``(x - mean) / std`` with std floored at 1e-8; 0.0 before any update."""
        if self.mean is None:
            return 0.0
        std = math.sqrt(self.var)
        return (float(x)-self.mean)/max(1e-8, std)

class PatternCandidate:
    """Record for one reported pattern: its signature, reporters, scores and review state."""
    def __init__(self, signature, reporter_eid, S):
        self.signature = signature
        self.reporters = {reporter_eid}   # starts with the single first reporter
        self.S_list = [S]
        self.windows = 1
        self.hits = 1
        # Lifecycle: DETECTED -> REVIEW -> APPROVED/DECLINED
        self.state = "DETECTED"
        self.created_step = 0
        self.last_update_step = 0

class PatternBoard:
    """Holds running statistics, pattern candidates and up to 16 online latent cluster centres."""
    def __init__(self):
        self.S_ewma = EWMA(0.98)
        self.rec_ewma = EWMA(0.98)
        self.js_ewma = EWMA(0.98)
        self.sig_store = []
        self.candidates: List[PatternCandidate] = []
        self.step = 0
        self.approved = 0
        self.declined = 0
        self.cluster_centers = None  # float32
        self.cluster_counts = None

    def signature(self, z_bar, p_bar):
        """Unit-norm concatenation of the unit-normalised ``z_bar`` and ``p_bar`` (detached)."""
        z_unit = F.normalize(z_bar.float(), dim=0)
        p_unit = F.normalize(p_bar.float(), dim=0)
        joined = F.normalize(torch.cat([z_unit, p_unit], dim=0), dim=0)
        return joined.detach()

    @staticmethod
    def js_div(p, q, eps=1e-8):
        """Jensen-Shannon divergence along the last dim, with every term floored at ``eps``."""
        p = p.float()
        q = q.float()
        mid = (0.5*(p+q)).clamp_min(eps)
        p_c = p.clamp_min(eps)
        q_c = q.clamp_min(eps)
        kl_pm = (p_c*(p_c/mid).log()).sum(dim=-1)
        kl_qm = (q_c*(q_c/mid).log()).sum(dim=-1)
        return 0.5*(kl_pm+kl_qm)

    def update_clusters(self, z_bar):
        """Update the nearest centre with ``z_bar`` and return the distance to it.

        The first call creates the first centre and returns 0.0. Afterwards
        the nearest centre becomes the running mean of its members; if the
        distance exceeded 2.0 and fewer than 16 centres exist, ``z_bar`` is
        additionally appended as a new centre.
        """
        with torch.no_grad():
            point = z_bar.detach().float()
            if self.cluster_centers is None:
                self.cluster_centers = point[None,:].to(device=device, dtype=torch.float32)
                self.cluster_counts = torch.ones(1, dtype=torch.float32, device=device)
                return 0.0
            gaps = torch.linalg.vector_norm(self.cluster_centers - point.unsqueeze(0), dim=1)
            nearest = int(torch.argmin(gaps).item())
            nearest_gap = float(gaps[nearest].item())
            members = self.cluster_counts[nearest]
            step = 1.0/(members+1.0)
            old_center = self.cluster_centers[nearest]
            self.cluster_centers[nearest] = (1-step)*old_center + step*point
            self.cluster_counts[nearest] = members+1.0
            if nearest_gap > 2.0 and self.cluster_centers.size(0) < 16:
                self.cluster_centers = torch.cat([self.cluster_centers, point[None,:]], dim=0)
                self.cluster_counts = torch.cat([self.cluster_counts, torch.ones(1, device=device)], dim=0)
            return nearest_gap

# ---------------- Helpers ----------------
def compute_capacity(B, E, k):
    """Per-expert sample limit: ceil(capacity_alpha * B * k / E), never below 1."""
    scaled_load = cfg.capacity_alpha * B * k / max(1, E)
    return max(1, int(math.ceil(scaled_load)))

@torch.no_grad()
def save_samples(epoch, model):
    """Write a grid of validation images above their reconstructions for ``epoch``.

    Takes the first ``sample_rows**2`` images of the first validation batch,
    reconstructs them with tau=1.0 and no epsilon mixing, and saves the grid
    under ``cfg.out_samples_dir``. Leaves the model in train mode.
    """
    model.eval()
    n_show = cfg.sample_rows*cfg.sample_rows
    originals = next(iter(val_loader))[:n_show].to(device)
    cap = compute_capacity(originals.size(0), model.moe.num_experts(), cfg.top_k)
    outputs = model(originals, top_k=cfg.top_k, tau=1.0, epsilon=0.0, capacity=cap)
    recon = outputs[0]
    stacked = torch.cat([denorm(originals), denorm(recon)], dim=0)
    grid = make_grid(stacked, nrow=cfg.sample_rows)
    target = os.path.join(cfg.out_samples_dir, f"epoch_{epoch:03d}.png")
    save_image(grid, target)
    model.train()

def summarize_epoch(epoch, model, stats, train_loss, val_loss, share, hhi, entropy_mean, cap_hit_rate, approved, declined):
    """Print the end-of-epoch report and write it as JSON into the checkpoint directory.

    The report covers the expert count, coverage (experts with non-zero usage
    EMA), the top experts by usage share and by credit, the losses and the
    pattern counters. ``share`` is accepted but not used.
    """
    E = model.moe.num_experts()
    usage = stats.usage_ema.detach().cpu()
    credit = stats.credit.detach().cpu()
    k = min(5, E)

    def _ranked(vec):
        # (expert id, value) pairs for the k largest entries of vec.
        vals, ids = torch.topk(vec, k=k)
        return [(int(ids[j]), float(vals[j])) for j in range(k)]

    top5_share = []
    if usage.sum() > 0 and k > 0:
        top5_share = _ranked(usage/usage.sum())
    top5_credit = _ranked(credit) if k > 0 else []
    cov = (usage>0).float().sum().item()

    print(f"\n[Epoch {epoch}] Experts={E} | NeuronScale={E*cfg.hidden} | Coverage={cov}/{E} | "
          f"HHI={hhi:.4f} | Entropy={entropy_mean:.3f} | CapHit={cap_hit_rate:.3f} | "
          f"TrainLoss={train_loss:.4f} | ValLoss={val_loss:.4f} | "
          f"Patterns: approved={approved} declined={declined}")
    print("Top-5 by SHARE (id, share):")
    for eid, s in top5_share:
        print(f"  - {eid:4d} | {s:.4f}")
    print("Top-5 by CREDIT (id, credit):")
    for eid, c in top5_credit:
        print(f"  - {eid:4d} | {c:.3f}")

    report = {
        "epoch": epoch,
        "experts_total": E,
        "neuron_scale": E*cfg.hidden,
        "coverage": cov,
        "hhi": hhi,
        "entropy_mean": entropy_mean,
        "capacity_hit_rate": cap_hit_rate,
        "top5_share": [{"expert_id": eid, "share": s} for eid, s in top5_share],
        "top5_credit": [{"expert_id": eid, "credit": c} for eid, c in top5_credit],
        "train_loss": float(train_loss),
        "val_loss": float(val_loss),
        "patterns": {"approved": approved, "declined": declined},
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    out_path = os.path.join(cfg.out_ckpt_dir, f"epoch_{epoch:03d}_summary.json")
    with open(out_path, "w") as f:
        json.dump(report, f, indent=2)

def share_hhi_from_usage(usage):
    """Return ``(shares, hhi)``: usage divided by its sum (+1e-6), and the sum of squared shares as a float."""
    total = usage.sum().item() + 1e-6
    shares = usage / total
    concentration = (shares * shares).sum().item()
    return shares, float(concentration)

@torch.no_grad()
def loo_probe_contrib(model, stats, tau, capacity, percept):
    """Leave-one-out probe: estimate how much each heavily used expert helps.

    On up to 64 images from the first validation batch, the loss is measured
    once normally and once per probed expert with that expert banned from
    routing. The positive part of the loss increase is folded into
    ``stats.contrib_ema`` (0.9 / 0.1 EMA) for the up-to-16 experts with the
    largest usage EMA.

    Args:
        model: the autoencoder; left in train mode on exit.
        stats: expert statistics providing ``usage_ema`` and ``contrib_ema``.
        tau: gate temperature used for the probe passes.
        capacity: per-expert capacity used for the probe passes.
        percept: optional perceptual-loss module; when given, BOTH the
            baseline and the banned loss use 0.7*MSE + 0.3*perceptual so the
            difference compares like with like.
    """
    if len(val_loader)==0: return

    usage=stats.usage_ema
    k=min(16, model.moe.num_experts())
    if not (usage.sum()>0 and k>0):
        # Nothing has been used yet, so there is nothing to probe.
        model.train(); return

    batch=next(iter(val_loader))[:64].to(device)

    def _probe_loss(ban):
        # Same objective for the baseline (ban=None) and every banned pass.
        with torch.amp.autocast("cuda", enabled=cfg.amp):
            x_hat, *_ = model(batch, top_k=cfg.top_k, tau=tau, epsilon=0.0,
                              capacity=capacity, ban_expert=ban)
            loss = F.mse_loss(x_hat, batch)
            if percept is not None:
                loss = 0.7*loss + 0.3*percept(denorm(x_hat).float(), denorm(batch).float()).mean()
        return loss

    model.eval()
    try:
        base = _probe_loss(None)
        _, top_idx = torch.topk(usage, k=k)
        for e in top_idx.tolist():
            dL = float((_probe_loss(int(e)) - base).item())
            # Only a loss increase counts as a contribution.
            stats.contrib_ema[e] = 0.9*stats.contrib_ema[e] + 0.1*max(0.0, dL)
    finally:
        # Restore training mode even if a probe pass raises.
        model.train()

# ---------------- Build model (OOM-aware) ----------------
def build_model_oom_aware(desired_E: int):
    """Build the model and its stats, halving the expert count after each out-of-memory failure.

    Starts at ``max(8, desired_E)`` and never goes below 8 experts. Any
    RuntimeError that is not an out-of-memory error, or an OOM at 8 experts,
    is re-raised. Returns ``(model, stats)``.
    """
    n_experts = max(8, desired_E)
    while True:
        try:
            net = AutoEncoderMoE(init_E=n_experts).to(device)
            tracker = ExpertStats(n_experts)
        except RuntimeError as err:
            recoverable = "out of memory" in str(err).lower() and n_experts > 8
            if not recoverable:
                raise
            print(f"[Model] OOM at E={n_experts}. Trying E={n_experts//2} …")
            torch.cuda.empty_cache()
            n_experts = max(8, n_experts//2)
        else:
            print(f"[Model] Built with E={n_experts} experts")
            return net, tracker

# Build the model (expert count may be reduced on OOM) and its per-expert stats.
model, stats = build_model_oom_aware(cfg.init_experts_desired)
# Perceptual loss network in eval mode, only when the lpips package is available.
percept = lpips.LPIPS(net='vgg').to(device).eval() if _HAS_LPIPS else None

# SGD optimizer
# Captures the parameters that exist right now; Parameters created later are
# not part of it unless the optimizer is rebuilt.
opt = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
# Gradient scaler for mixed precision; disabled when cfg.amp is False.
scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp)

# ---------------- Schedules ----------------
def tau_sched(epoch):
    """Gate temperature: ``tau_high`` through epoch 5, then linear towards ``tau_low`` at the last epoch."""
    warm_epochs = 5
    if epoch <= warm_epochs:
        return cfg.tau_high
    progress = (epoch-warm_epochs)/max(1,(cfg.epochs-warm_epochs))
    return cfg.tau_high + (cfg.tau_low-cfg.tau_high)*progress
def eps_sched(epoch):
    """Uniform-mixing weight: linear from ``epsilon_start`` (epoch 1) to ``epsilon_end`` (last epoch)."""
    progress = (epoch-1)/max(1,(cfg.epochs-1))
    span = cfg.epsilon_end-cfg.epsilon_start
    return cfg.epsilon_start + span*progress
def ent_lambda_sched(epoch):
    """Weight of the gate-entropy loss term: the warm value through epoch 5, the cool value afterwards."""
    if epoch <= 5:
        return cfg.gate_entropy_lambda_warm
    return cfg.gate_entropy_lambda_cool

# ---------------- Pattern Board (simple; only novelty used) ----------------
class SimpleBoard:
    """Approval/decline counters plus an exponential moving average of the novelty score S."""
    def __init__(self):
        self.approved = 0
        self.declined = 0
        self.S_ema = None
        self.beta = 0.98

    def upd(self, S):
        """Fold ``S`` into the moving average; the first value initialises it."""
        if self.S_ema is None:
            self.S_ema = S
            return
        self.S_ema = self.beta*self.S_ema+(1-self.beta)*S

    def norm(self,S):
        """Return ``S`` minus the moving average, or 0.0 before any update."""
        return 0.0 if self.S_ema is None else S - self.S_ema

board=SimpleBoard()

# ---------------- Training ----------------
def _rebuild_optimizer(old_opt):
    """Return a fresh SGD optimizer over the model's CURRENT parameters.

    Growing or pruning experts creates new Parameter objects that the old
    optimizer has never seen, so they would silently stop training. Momentum
    state is carried over for every parameter that survived the change.
    """
    new_opt = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
    for p in model.parameters():
        if p in old_opt.state:
            new_opt.state[p] = old_opt.state[p]
    return new_opt

global_step=0
p_avg=None  # EMA of gate distribution

for epoch in range(1, cfg.epochs+1):
    model.train()
    tau=tau_sched(epoch); eps=eps_sched(epoch); ent_l=ent_lambda_sched(epoch)

    running=0.0; cap_hits=0.0; cap_tot=0.0; ent_sum=0.0; ent_cnt=0
    pbar=tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs} (train)")

    for batch in pbar:
        batch=batch.to(device, non_blocking=True)
        B=batch.size(0)
        capacity=compute_capacity(B, model.moe.num_experts(), cfg.top_k)

        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=cfg.amp):
            x_rec, z0, z_moe, probs, used_mask, aidx, aw, cap_hit = model(
                batch, top_k=cfg.top_k, tau=tau, epsilon=eps, capacity=capacity
            )
            mse=F.mse_loss(x_rec, batch)
            loss_main = mse
            if _HAS_LPIPS:
                lp = percept(denorm(x_rec).float(), denorm(batch).float()).mean()
                loss_main = 0.7*mse + 0.3*lp
            H=gate_entropy(probs)
            rep=repulsion_loss(model.moe.prototypes, m=cfg.repulsion_subset, sigma=cfg.repulsion_sigma)
            loss = loss_main + ent_l*H + cfg.repulsion_lambda*rep

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt); scaler.update()

        with torch.no_grad():
            E=model.moe.num_experts()
            stats.usage_epoch[:E]+=used_mask.float().sum(dim=0)
            cap_hits += cap_hit; cap_tot+=1.0
            ent_sum  += float(H.item()); ent_cnt+=1

            # novelty S (lower → easier approval)
            p_bar = probs.mean(dim=0)
            # Defensive: the gate width can change between steps; never mix lengths.
            if p_avg is None or p_avg.numel()!=p_bar.numel():
                p_avg = p_bar.detach()
            # Jensen-Shannon divergence between this batch's mean gate and its EMA.
            p_c = p_bar.clamp_min(1e-8)
            q_c = p_avg.clamp_min(1e-8)
            m_c = (0.5*(p_bar+p_avg)).clamp_min(1e-8)
            js = float((0.5*(p_c*(p_c/m_c).log()).sum() +
                        0.5*(q_c*(q_c/m_c).log()).sum()).item())
            p_avg = 0.98*p_avg + 0.02*p_bar.detach()
            z_bar = z0.mean(dim=0)
            # Distance from the batch-mean latent to its nearest prototype.
            dmin = float((model.moe.prototypes - z_bar).pow(2).sum(dim=1).sqrt().min().item())
            S = 0.50*dmin + 0.30*max(0.0, js) + 0.20*float(H.item())
            board.upd(S)
            Sn = board.norm(S)

            # approve if S >= threshold (single reporter)
            if S >= cfg.report_theta:
                board.approved += 1
                # recruit: donors are the top-N experts by usage share plus
                # every expert above the "big share" threshold
                share_vec = (stats.usage_ema / (stats.usage_ema.sum()+1e-6))
                E = model.moe.num_experts()
                top_vals, top_idx = torch.topk(share_vec, k=min(cfg.big_topN, E))
                enterprises = set(int(i) for i in top_idx.tolist())
                big = torch.nonzero(share_vec>=cfg.big_share_thr, as_tuple=False).view(-1)
                enterprises |= set(int(i) for i in big.tolist())
                donor_pool = list(enterprises) if len(enterprises)>0 else [int(torch.argmax(used_mask.float().sum(dim=0)).item())]
                Q = int(cfg.team_base_Q * (1.0 + cfg.team_alpha_S*max(0.0, S-cfg.report_theta)))
                new_expert_count = max(1, Q // cfg.adapters_per_expert)
                k_choose = max(1, min(len(donor_pool), max(1, len(donor_pool)//2)))
                for _ in range(new_expert_count):
                    donors = random.sample(donor_pool, k=k_choose)
                    model.moe.add_expert(clone_from=donors, donors_fraction=0.7, bias_boost=cfg.recruit_new_bias)
                    stats.on_structure_add(new_added=1)
                for d in donor_pool:
                    model.moe.logit_bias.data[d] += cfg.recruit_donor_bias
                # New experts start with zero mass in the gate EMA; without
                # this the next batch would hit a shape mismatch.
                p_avg = torch.cat([p_avg, p_avg.new_zeros(new_expert_count)])
                # Make the new experts and re-created routing Parameters trainable.
                opt = _rebuild_optimizer(opt)
                print(f"[Growth] Approved pattern S={S:.2f} → +{new_expert_count} experts (E={model.moe.num_experts()})")

        running+=float(loss_main.item())
        pbar.set_postfix(loss=f"{loss_main.item():.4f}", mse=f"{mse.item():.4f}", H=f"{H.item():.3f}",
                         tau=f"{tau:.2f}", eps=f"{eps:.3f}", S=f"{S:.2f}", Sn=f"{Sn:.2f}")
        global_step+=1

    # end epoch usage decay
    stats.epoch_decay(decay=0.9)

    # Validation
    model.eval(); val_loss=0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{cfg.epochs} (val)"):
            batch=batch.to(device, non_blocking=True)
            capacity=compute_capacity(batch.size(0), model.moe.num_experts(), cfg.top_k)
            with torch.amp.autocast("cuda", enabled=cfg.amp):
                x_rec, *_ = model(batch, top_k=cfg.top_k, tau=tau, epsilon=0.0, capacity=capacity)
                loss_v = F.mse_loss(x_rec,batch)
                if _HAS_LPIPS:
                    lp = percept(denorm(x_rec).float(), denorm(batch).float()).mean()
                    loss_v = 0.7*loss_v + 0.3*lp
            val_loss += float(loss_v.item())*batch.size(0)
    val_loss/=max(1,len(val_ds))

    # LOO probe (light)
    loo_probe_contrib(model, stats, tau=tau, capacity=compute_capacity(64, model.moe.num_experts(), cfg.top_k), percept=percept)

    # Share/HHI
    share_vec, hhi = share_hhi_from_usage(stats.usage_ema)
    entropy_mean = ent_sum/max(1,ent_cnt)
    cap_hit_rate = cap_hits/max(1,cap_tot)

    # --------- Credit settle ----------
    with torch.no_grad():
        approved_reward = cfg.credit_true_reward * float(board.approved)
        declined_cost  = cfg.credit_false_penalty * float(board.declined)
        stats.credit += torch.tensor(approved_reward - declined_cost, device=device) * (share_vec / (share_vec.sum()+1e-6))
        stats.credit += cfg.credit_contrib_scale * stats.contrib_ema
        div_weight = (1.0 - (share_vec**2))
        stats.credit += cfg.credit_diversity_scale * float(max(0.0, entropy_mean)) * div_weight
        over = (share_vec - cfg.share_tax_thr).clamp_min(0.0)
        stats.credit -= cfg.share_tax_lambda * over
        stats.credit *= cfg.credit_decay
        neg_mask = (stats.credit<0)
        stats.credit[neg_mask] = stats.credit[neg_mask] * (1.0 + cfg.neg_interest)
        stats.neg_streak[stats.credit<=-cfg.bankrupcy_level] += 1
        stats.neg_streak[stats.credit>-cfg.bankrupcy_level] = 0

    # --------- PRUNE ----------
    if cfg.enable_prune:
        with torch.no_grad():
            E=model.moe.num_experts()
            if E>cfg.prune_min_keep:
                usage = stats.usage_ema
                share = (usage/(usage.sum()+1e-6))
                # Candidates: unused, deeply indebted, and indebted for long enough.
                cand_mask = (usage <= cfg.prune_usage_eps) & (stats.credit <= -cfg.prune_credit_level) & (stats.neg_streak >= cfg.bankrupcy_streak)
                cand_idx = torch.nonzero(cand_mask, as_tuple=False).view(-1).tolist()
                # The 64 largest-share experts are always protected.
                protected = set(torch.topk(share, k=min(64,E)).indices.tolist())
                cand_idx = [i for i in cand_idx if i not in protected]
                max_prune = min(cfg.prune_max_per_epoch, max(0, E - cfg.prune_min_keep))
                cand_idx = cand_idx[:max_prune]
                if cand_idx:
                    doomed = set(cand_idx)
                    keep = [i for i in range(E) if i not in doomed]
                    removed = model.moe.remove_experts(cand_idx)
                    stats.on_structure_remove(keep)
                    # Keep the gate EMA aligned with the surviving experts.
                    if p_avg is not None and p_avg.numel()==E:
                        p_avg = p_avg[torch.tensor(keep, dtype=torch.long, device=p_avg.device)]
                    opt = _rebuild_optimizer(opt)
                    print(f"[Structure] Pruned {removed} experts → E={model.moe.num_experts()}")

    # Save sample + ckpt
    save_samples(epoch, model)
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "stats": {
            "usage_ema": stats.usage_ema.detach().cpu().tolist(),
            "contrib_ema": stats.contrib_ema.detach().cpu().tolist(),
            "credit": stats.credit.detach().cpu().tolist(),
            "neg_streak": stats.neg_streak.detach().cpu().tolist(),
        },
        "config": cfg.__dict__,
    }, os.path.join(cfg.out_ckpt_dir, f"epoch_{epoch:03d}.pt"))

    # Epoch summary (guard the average against an empty training loader)
    summarize_epoch(epoch, model, stats,
                    train_loss=running/max(1,len(train_loader)), val_loss=val_loss,
                    share=share_vec, hhi=hhi, entropy_mean=entropy_mean, cap_hit_rate=cap_hit_rate,
                    approved=board.approved, declined=board.declined)

print("✅ Training complete. See ./checkpoints & ./samples")


SyntaxError: cannot assign to expression (ipython-input-3280588967.py, line 481)