In [None]:
# Cloud Environment Setup
import subprocess, sys, os
def install(pkg, index_url="https://pypi.tuna.tsinghua.edu.cn/simple"):
    """Install ``pkg`` into the running interpreter's environment with pip.

    Args:
        pkg: Requirement string handed to ``pip install`` (e.g. ``"numpy"``).
        index_url: Package index passed to pip via ``-i``. Defaults to the
            mirror this notebook has always used; pass ``None`` (or an empty
            string) to let pip fall back to its own configured index.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Go through ``sys.executable -m pip`` so the package is installed for the
    # interpreter that is running this kernel, not whatever ``pip`` is on PATH.
    cmd = [sys.executable, "-m", "pip", "install", pkg]
    if index_url:
        cmd += ["-i", index_url]
    subprocess.check_call(cmd)
# Make sure PyTorch is importable; if it is missing, install it from the
# cu118 wheel index.
try:
    import torch
    print(f"PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}")
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "--index-url", "https://download.pytorch.org/whl/cu118"])

# (import name, pip distribution name) pairs for the remaining dependencies.
for module_name, dist_name in [("numpy","numpy"),("pandas","pandas"),("sklearn","scikit-learn"),("networkx","networkx"),("requests","requests"),("scipy","scipy")]:
    try:
        __import__(module_name)
    except ImportError:
        # Only a missing package should trigger an install. A bare ``except``
        # would also swallow KeyboardInterrupt and unrelated import-time errors.
        install(dist_name)

# The civae package is written into the working directory by a later cell, so
# the working directory must be importable.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
print("Done.")

In [None]:
# Create civae library inline
# Maps each file of the generated ``civae`` package to its full source text.
# The source strings are kept exactly as authored (one literal per file); the
# only repair is that the model.py literal is now a single unbroken string --
# it was previously split across two physical lines inside a single-quoted
# literal, which Python rejects as an unterminated string.
_CIVAE_FILES = {
    'civae/__init__.py': 'from .dag import DAG\nfrom .model import CausalIVAE\n',
    'civae/dag.py': 'import numpy as np, networkx as nx\nclass DAG:\n    def __init__(s, adj): s.A, s.d = adj, adj.shape[0]; s.graph = nx.from_numpy_array(s.A, create_using=nx.DiGraph); s.topo_order = list(nx.topological_sort(s.graph)); s.roots = [i for i in range(s.d) if s.graph.in_degree(i)==0]\n    def parents(s, i): return set(s.graph.predecessors(i))\n    def children(s, i): return set(s.graph.successors(i))\n    def markov_blanket(s, i): mb = s.parents(i)|s.children(i); [mb.update(s.parents(c)) for c in s.children(i)]; mb.discard(i); return mb\n    def level(s, i): return len(nx.ancestors(s.graph, i))\n',
    'civae/prior.py': 'import torch, torch.nn as nn\nclass CausalPrior(nn.Module):\n    def __init__(s, z_dim, u_dim, num_envs, hid=128):\n        super().__init__(); s.z_dim = z_dim; s.env_embed = nn.Embedding(num_envs, hid//4)\n        s.root_net = nn.Sequential(nn.Linear(u_dim+hid//4, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.root_mu, s.root_lv = nn.Linear(hid, z_dim), nn.Linear(hid, z_dim)\n        s.pk, s.pq, s.pv = nn.Linear(z_dim+u_dim, hid), nn.Linear(u_dim, hid), nn.Linear(z_dim+u_dim, hid)\n        s.cond_net = nn.Sequential(nn.Linear(hid+u_dim, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.cond_mu, s.cond_lv = nn.Linear(hid, z_dim), nn.Linear(hid, z_dim)\n    def forward(s, pz, pu, ui, env, is_root):\n        bs, dev = ui.size(0), ui.device; ui = ui.unsqueeze(0).expand(bs,-1) if ui.dim()==1 else ui\n        if is_root: h = s.root_net(torch.cat([ui, s.env_embed(env)], -1)); return s.root_mu(h), s.root_lv(h)\n        if not pz: return torch.zeros(bs, s.z_dim, device=dev), torch.zeros(bs, s.z_dim, device=dev)\n        ps = torch.cat([torch.stack(pz,1), torch.stack([p.unsqueeze(0).expand(bs,-1) if p.dim()==1 else p for p in pu],1)], -1)\n        K, Q, V = s.pk(ps), s.pq(ui).unsqueeze(1), s.pv(ps)\n        h = s.cond_net(torch.cat([torch.bmm(torch.softmax(torch.bmm(Q, K.transpose(1,2))/K.size(-1)**0.5, -1), V).squeeze(1), ui], -1))\n        return s.cond_mu(h), s.cond_lv(h)\n',
    'civae/mb_encoder.py': 'import torch, torch.nn as nn\nclass MBEncoder(nn.Module):\n    def __init__(s, n, dim=64): super().__init__(); s.ne = nn.Embedding(n, dim); s.phi = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)); s.rho = nn.Sequential(nn.Linear(dim*2+4, dim), nn.ReLU(), nn.Linear(dim, dim)); s.empty = nn.Parameter(torch.zeros(dim))\n    def forward(s, dag, dev):\n        u = {}\n        for i in range(dag.d): mb = dag.markov_blanket(i); se = s.ne(torch.tensor(i, device=dev)); ma = s.phi(s.ne(torch.tensor(list(mb), device=dev))).sum(0) if mb else s.empty; u[i] = s.rho(torch.cat([ma, se, torch.tensor([dag.level(i), len(dag.parents(i)), len(dag.children(i)), len(mb)], dtype=torch.float32, device=dev)]))\n        return u\n',
    'civae/model.py': 'import torch, torch.nn as nn, torch.nn.functional as F\nfrom .mb_encoder import MBEncoder\nfrom .prior import CausalPrior\nclass CausalEncoder(nn.Module):\n    def __init__(s, in_dim, z_dim, n, u_dim, num_envs, hid=256):\n        super().__init__(); s.z_dim, s.n = z_dim, n; s.env_embed = nn.Embedding(num_envs, hid//4)\n        s.feat_net = nn.Sequential(nn.Linear(in_dim+hid//4, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.node_nets = nn.ModuleList([nn.Sequential(nn.Linear(hid+u_dim+z_dim*n, hid//2), nn.ReLU()) for _ in range(n)])\n        s.mu_heads = nn.ModuleList([nn.Linear(hid//2, z_dim) for _ in range(n)])\n        s.lv_heads = nn.ModuleList([nn.Linear(hid//2, z_dim) for _ in range(n)])\n    def forward(s, x, u, env, dag, topo):\n        bs, dev = x.size(0), x.device; feat = s.feat_net(torch.cat([x, s.env_embed(env)], -1))\n        zs, zm, zl = {}, {}, {}\n        for i in topo: ui = u[i].unsqueeze(0).expand(bs,-1); pa = sorted(dag.parents(i)); pz = torch.cat([zs[j] for j in pa], -1) if pa else torch.zeros(bs, 0, device=dev); h = s.node_nets[i](torch.cat([feat, ui, F.pad(pz, (0, s.z_dim*s.n-pz.size(-1)))], -1)); mu, lv = s.mu_heads[i](h), s.lv_heads[i](h); zs[i] = mu + torch.randn_like(mu)*torch.exp(0.5*lv); zm[i], zl[i] = mu, lv\n        return zs, zm, zl\nclass Decoder(nn.Module):\n    def __init__(s, z_dim, out_dim, hid=256): super().__init__(); s.net = nn.Sequential(nn.Linear(z_dim, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU(), nn.Linear(hid, out_dim))\n    def forward(s, z): return s.net(z)\nclass CausalIVAE(nn.Module):\n    def __init__(s, in_dim, z_dim, dag, num_envs=1, u_dim=64, hid=256):\n        super().__init__(); s.dag, s.z_dim, s.d, s.roots = dag, z_dim, dag.d, set(dag.roots)\n        s.struct_enc = MBEncoder(dag.d, u_dim); s.encoder = CausalEncoder(in_dim, z_dim, dag.d, u_dim, num_envs, hid)\n        s.prior = CausalPrior(z_dim, u_dim, num_envs, hid); s.decoder = Decoder(z_dim*dag.d, in_dim, hid)\n    def forward(s, x, env=None):\n        dev = x.device; env = env if env is not None else torch.zeros(x.size(0), dtype=torch.long, device=dev)\n        u = s.struct_enc(s.dag, dev); zs, zm, zl = s.encoder(x, u, env, s.dag, s.dag.topo_order)\n        s._u, s._env = u, env; return s.decoder(torch.cat([zs[i] for i in range(s.d)], -1)), zs, zm, zl\n    def compute_loss(s, x, xr, zs, zm, zl, env=None):\n        bs, u, env = x.size(0), s._u, env if env is not None else s._env\n        rl = F.mse_loss(xr, x, reduction="sum")/bs; kl = 0\n        for i in s.dag.topo_order: pa = sorted(s.dag.parents(i)); pm, pl = s.prior([zs[j] for j in pa], [u[j] for j in pa], u[i].unsqueeze(0).expand(bs,-1), env, i in s.roots); v1, v2 = torch.exp(zl[i]), torch.exp(pl); kl += (0.5*(pl-zl[i]+v1/v2+(zm[i]-pm)**2/v2-1).sum(-1)).sum()/bs\n        return {"loss": rl+kl, "recon_loss": rl, "kl_loss": kl}\n',
}

os.makedirs('civae', exist_ok=True)
for _path, _source in _CIVAE_FILES.items():
    # A context manager guarantees the file is flushed and closed before the
    # next cell tries to import the package.
    with open(_path, 'w', encoding='utf-8') as _fh:
        _fh.write(_source)
print("Library created.")

In [None]:
"""Experiment 11: Sachs Cell Signaling (Fixed Environment Generation)"""
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, pandas as pd, requests, io
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from civae import DAG, CausalIVAE

# === Config ===
# Training hyperparameters shared by every model trained in run().
EPOCHS = 100  # Full passes over the data per model (increased from 50)
BATCH_SIZE = 128  # Mini-batch size used by every DataLoader
LR = 1e-3  # Learning rate handed to each Adam optimizer
HIDDEN_DIM = 256  # Hidden width: default for the baselines, passed to CausalIVAE (increased from 128)

# === Baselines ===
class SimpleVAE(nn.Module):
    """Plain VAE baseline built from an MLP encoder and an MLP decoder."""

    def __init__(self, in_dim, z_dim, hid=HIDDEN_DIM):
        super().__init__()
        # The encoder emits 2*z_dim values per sample; ``encode`` splits them
        # into means (first half) and log-variances (second half).
        self.enc = nn.Sequential(
            nn.Linear(in_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, z_dim * 2),
        )
        self.dec = nn.Sequential(
            nn.Linear(z_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, in_dim),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.enc(x)
        half = stats.size(1) // 2
        return stats[:, :half], stats[:, half:]

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        # Reparameterisation: z = mu + eps * sigma with sigma = exp(logvar / 2).
        noise = torch.randn_like(mu)
        latent = mu + noise * torch.exp(0.5 * logvar)
        return self.dec(latent), mu, logvar

class iVAE(nn.Module):
    """iVAE baseline: encoder and prior are both conditioned on an environment embedding."""

    def __init__(self, in_dim, z_dim, n_env, hid=HIDDEN_DIM):
        super().__init__()
        emb_dim = hid // 2
        self.emb = nn.Embedding(n_env, emb_dim)
        # Encoder sees the observation concatenated with the environment embedding.
        self.enc = nn.Sequential(
            nn.Linear(in_dim + emb_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, z_dim * 2),
        )
        self.dec = nn.Sequential(
            nn.Linear(z_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, in_dim),
        )
        # Prior network maps the environment embedding to prior mean / log-variance.
        self.p_net = nn.Sequential(
            nn.Linear(emb_dim, hid), nn.ReLU(),
            nn.Linear(hid, z_dim * 2),
        )

    @staticmethod
    def _halves(stats):
        """Split the feature axis into (first half, second half)."""
        mid = stats.size(1) // 2
        return stats[:, :mid], stats[:, mid:]

    def encoder(self, x, e):
        """Return ``(raw_output, mu, logvar)`` of the posterior for ``x`` in environment ``e``."""
        raw = self.enc(torch.cat([x, self.emb(e)], -1))
        mu, logvar = self._halves(raw)
        return raw, mu, logvar

    def prior(self, e):
        """Return ``(prior_mu, prior_logvar)`` for environment ids ``e``."""
        return self._halves(self.p_net(self.emb(e)))

    def forward(self, x, e):
        """Return ``(reconstruction, mu, logvar, prior_mu, prior_logvar)``."""
        _, mu, logvar = self.encoder(x, e)
        latent = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        recon = self.dec(latent)
        prior_mu, prior_logvar = self.prior(e)
        return recon, mu, logvar, prior_mu, prior_logvar

class CausalVAE(nn.Module):
    """VAE baseline with a learned linear map over latents, masked by the adjacency matrix."""

    def __init__(self, in_dim, z_dim, adj, hid=HIDDEN_DIM):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(in_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, z_dim * 2),
        )
        self.dec = nn.Sequential(
            nn.Linear(z_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, in_dim),
        )
        # Frozen, transposed adjacency: after the transpose, row k marks the
        # entries of ``adj`` column k, i.e. which latents may feed latent k.
        frozen_mask = torch.from_numpy(adj).float().t()
        self.mask = nn.Parameter(frozen_mask, requires_grad=False)
        # Learnable edge weights, started near zero.
        self.cw = nn.Parameter(torch.randn(z_dim, z_dim) * 0.01)

    def encode(self, x):
        """Return ``(mu, logvar)``, the two halves of the encoder output."""
        stats = self.enc(x)
        half = stats.size(1) // 2
        return stats[:, :half], stats[:, half:]

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar, prior_mean)``."""
        mu, logvar = self.encode(x)
        latent = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        recon = self.dec(latent)
        # Masked linear map of the sampled latents.
        masked_weights = self.cw * self.mask
        prior_mean = F.linear(latent, masked_weights)
        return recon, mu, logvar, prior_mean

# === Metrics ===
def compute_mcc(z_true, z_pred):
    """Mean Correlation Coefficient with Hungarian matching.

    Computes the absolute Pearson correlation between every column of
    ``z_true`` and every column of ``z_pred``, finds the one-to-one column
    assignment that maximises the total absolute correlation, and returns the
    mean over the matched pairs.

    Args:
        z_true: Array of shape (n_samples, d) with the reference variables.
        z_pred: Array of shape (n_samples, k) with the predicted variables.
            If ``k > d`` only the first ``d`` columns are used.

    Returns:
        float: The matched mean absolute correlation, or ``0.0`` when
        ``z_pred`` has fewer columns than ``z_true`` (or ``d`` is 0).
    """
    from scipy.optimize import linear_sum_assignment

    d = z_true.shape[1]
    if d == 0 or z_pred.shape[1] < d:
        return 0.0
    z_pred = z_pred[:, :d]

    # A zero-variance column yields 0/0 = NaN correlations. Silence the
    # resulting warning and count such pairs as uncorrelated, because
    # linear_sum_assignment rejects cost matrices containing NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        corrs = np.corrcoef(z_true.T, z_pred.T)[:d, d:]
    abs_corrs = np.abs(np.nan_to_num(corrs, nan=0.0))

    # The solver minimises cost, so negate to maximise total |correlation|.
    row, col = linear_sum_assignment(-abs_corrs)
    return float(abs_corrs[row, col].mean())

def compute_consistency(z, dag, random_state=0):
    """Structural consistency: how well z follows DAG structure.

    For every node that has parents in ``dag``, an MLP regressor is fitted to
    predict column ``i`` of ``z`` from its parents' columns, and the R^2 of
    that fit (clipped at 0) is recorded. The R^2 is computed on the same data
    the regressor was fitted on. The score is the mean over the scored nodes.

    Args:
        z: Array of shape (n_samples, n_latents); column ``i`` is read as the
            value of DAG node ``i``.
        dag: Object exposing ``d`` (node count) and ``parents(i)``.
        random_state: Seed forwarded to ``MLPRegressor`` so the metric is
            reproducible between runs. Pass ``None`` for unseeded behaviour.

    Returns:
        float: Mean clipped R^2, or ``1.0`` when no node can be scored (no
        node has parents, or ``z`` has too few columns to index them).
    """
    # Gather the scorable (node, parents) pairs first. Sorting the parents
    # fixes the feature column order, which a set does not guarantee.
    scorable = []
    for node in range(dag.d):
        parent_ids = sorted(dag.parents(node))
        if parent_ids and z.shape[1] > max(parent_ids + [node]):
            scorable.append((node, parent_ids))
    if not scorable:
        return 1.0

    from sklearn.neural_network import MLPRegressor
    from sklearn.metrics import r2_score

    scores = []
    for node, parent_ids in scorable:
        features, target = z[:, parent_ids], z[:, node]
        regressor = MLPRegressor(
            (64, 32), max_iter=1000, early_stopping=True, random_state=random_state
        ).fit(features, target)
        # Negative R^2 (worse than predicting the mean) is treated as 0.
        scores.append(max(0.0, r2_score(target, regressor.predict(features))))
    return float(np.mean(scores))

# === Data & DAG ===
# Node names; a node's position in this list is the integer id used in
# SACHS_EDGES and for the adjacency-matrix rows/columns.
# NOTE(review): the code indexes data columns by these ids, so it assumes the
# dataset columns come in this order -- confirm against the CSV header.
SACHS_NODES = ['Raf', 'Mek', 'PLCg', 'PIP2', 'PIP3', 'Erk', 'Akt', 'PKA', 'PKC', 'p38', 'Jnk']
# Directed edges as (parent_id, child_id) pairs: create_sachs_dag() sets
# adj[parent, child] = 1 for each pair.
SACHS_EDGES = [
    (0, 1), (1, 5), (2, 3), (2, 8), (3, 8), (4, 3), (4, 6), 
    (7, 0), (7, 1), (7, 5), (7, 6), (7, 9), (7, 10), 
    (8, 0), (8, 1), (8, 9), (8, 10)
]
def create_sachs_dag():
    """Build the DAG over SACHS_NODES from the edge list SACHS_EDGES."""
    n_nodes = len(SACHS_NODES)
    adjacency = np.zeros((n_nodes, n_nodes))
    # Row index is the parent, column index is the child.
    for parent, child in SACHS_EDGES:
        adjacency[parent, child] = 1
    return DAG(adjacency)

def download_sachs():
    """Fetch the Sachs dataset as a float32 array, with a synthetic fallback.

    Downloads the CSV from the hard-coded URL and returns its values. If the
    request fails, the server answers with an HTTP error status, the payload
    cannot be parsed, or the parsed table does not have one column per entry
    of SACHS_NODES, a seeded standard-normal array of shape
    (7466, len(SACHS_NODES)) is returned instead.

    Returns:
        np.ndarray: float32 array of shape (n_samples, len(SACHS_NODES)).
    """
    url = "https://raw.githubusercontent.com/ElementAI/causal_discovery_toolbox/master/cdt/data/resources/cyto_full_data.csv"
    n_cols = len(SACHS_NODES)
    try:
        response = requests.get(url, timeout=30)
        # Without this check a 404/500 body would be parsed as CSV and the
        # fallback below would never run.
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text)).values.astype(np.float32)
        # The DAG indexes columns by node id, so the table must line up.
        if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] != n_cols:
            raise ValueError(f"unexpected data shape {data.shape}")
        return data
    except Exception as exc:
        # Deliberate best-effort: any failure falls back to synthetic data,
        # but the reason is reported rather than silently swallowed.
        print("Download failed. Creating synthetic...")
        print(f"  Reason: {exc!r}")
        np.random.seed(42)
        return np.random.randn(7466, n_cols).astype(np.float32)

def create_environments(x, dag, num_envs, seed=42):
    """Assign an environment id to every sample, based on a root node's values.

    When the DAG has at least one root, the column of ``x`` for the first
    root is cut at its evenly spaced percentiles into ``num_envs`` bins and
    each sample gets the index of its bin. Without roots, ids are drawn
    uniformly at random.

    Args:
        x: Array of shape (n_samples, n_nodes).
        dag: Object exposing ``roots`` (iterable of node ids).
        num_envs: Number of environments to create.
        seed: Seed applied to NumPy's global RNG before anything else.

    Returns:
        np.ndarray: int64 environment ids, one per row of ``x``.
    """
    np.random.seed(seed)
    root_ids = list(dag.roots)
    root_names = [SACHS_NODES[r] for r in root_ids]
    print(f"  DAG roots: {root_names} (indices: {root_ids})")

    if len(root_ids) == 0:
        env = np.random.randint(0, num_envs, x.shape[0])
    else:
        # Bin on the first root's column (not simply column 0).
        anchor = x[:, root_ids[0]]
        cut_points = np.percentile(anchor, np.linspace(0, 100, num_envs + 1))
        # Drop the 0th and 100th percentiles so ids span 0 .. num_envs - 1.
        inner_cuts = cut_points[1:-1]
        env = np.digitize(anchor, inner_cuts)

    # Report how many samples landed in each environment.
    labels, sizes = np.unique(env, return_counts=True)
    print(f"  Environment distribution: {dict(zip(labels, sizes))}")
    return env.astype(np.int64)

# === Main ===
def _train(model, loader, batch_loss):
    """Minimise ``batch_loss(model, *batch)`` with Adam for EPOCHS passes over ``loader``."""
    opt = optim.Adam(model.parameters(), lr=LR)
    for _ in range(EPOCHS):
        for batch in loader:
            opt.zero_grad()
            batch_loss(model, *batch).backward()
            opt.step()


def _vae_loss(model, xb):
    """Per-sample reconstruction error plus KL to a standard-normal prior."""
    xr, mu, lv = model(xb)
    bs = xb.size(0)
    return F.mse_loss(xr, xb, reduction='sum')/bs - 0.5*torch.sum(1+lv-mu.pow(2)-lv.exp())/bs


def _ivae_loss(model, xb, eb):
    """Per-sample reconstruction error plus KL to the environment-conditioned prior."""
    xr, mu, lv, pm, plv = model(xb, eb)
    bs = xb.size(0)
    rl = F.mse_loss(xr, xb, reduction='sum')/bs
    kl = 0.5*(plv-lv + torch.exp(lv)/(torch.exp(plv)+1e-8) + (mu-pm)**2/(torch.exp(plv)+1e-8) - 1).sum()/bs
    return rl + kl


def _causalvae_loss(model, xb):
    """Per-sample reconstruction error plus KL to a unit-variance prior centred on ``pm``."""
    xr, mu, lv, pm = model(xb)
    bs = xb.size(0)
    rl = F.mse_loss(xr, xb, reduction='sum')/bs
    kl = -0.5*torch.sum(1+lv-(mu-pm).pow(2)-lv.exp())/bs
    return rl + kl


def _civae_loss(model, xb, eb):
    """Total loss as reported by ``CausalIVAE.compute_loss``."""
    xr, z, zm, zl = model(xb, eb)
    return model.compute_loss(xb, xr, z, zm, zl, eb)['loss']


def run(seed=42):
    """Train VAE, iVAE, CausalVAE and C-iVAE on the Sachs data and print their scores.

    Each model is trained with Adam for EPOCHS passes, then its latents for
    the whole dataset are scored with ``compute_mcc`` (against the
    standardised observations) and ``compute_consistency`` (against the DAG).
    A summary table is printed at the end.

    Args:
        seed: Seed for torch's global RNG and for environment creation, so
            repeated runs start from the same weights and batches.
    """
    print("="*60 + "\nEXPERIMENT 11: Sachs (Fixed Root-Based Environment)\n" + "="*60)
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {dev}, Epochs: {EPOCHS}, Hidden: {HIDDEN_DIM}")

    # Seed before any model is built so weight init and shuffling are repeatable.
    torch.manual_seed(seed)

    x_raw = download_sachs()
    dag = create_sachs_dag()
    x = StandardScaler().fit_transform(x_raw).astype(np.float32)
    ne = 5  # Use 5 environments as in full version
    env = create_environments(x, dag, ne, seed)

    # Move the data to the device once and share the loaders between models.
    x_t = torch.from_numpy(x).to(dev)
    env_t = torch.from_numpy(env).to(dev)
    x_loader = DataLoader(TensorDataset(x_t), batch_size=BATCH_SIZE, shuffle=True)
    xe_loader = DataLoader(TensorDataset(x_t, env_t), batch_size=BATCH_SIZE, shuffle=True)

    results = {}

    def record(name, z):
        """Score latents ``z`` for model ``name`` and print the result line."""
        results[name] = {'MCC': compute_mcc(x, z), 'Consistency': compute_consistency(z, dag)}
        print(f"  {name}: MCC={results[name]['MCC']:.4f}, Consistency={results[name]['Consistency']:.4f}")

    # 1. VAE
    print("\n[1/4] Training VAE...")
    m = SimpleVAE(x.shape[1], dag.d).to(dev)
    _train(m, x_loader, _vae_loss)
    m.eval()
    with torch.no_grad():
        z = m.encode(x_t)[0].cpu().numpy()
    record('VAE', z)

    # 2. iVAE
    print("\n[2/4] Training iVAE...")
    m = iVAE(x.shape[1], dag.d, ne).to(dev)
    _train(m, xe_loader, _ivae_loss)
    m.eval()
    with torch.no_grad():
        z = m.encoder(x_t, env_t)[1].cpu().numpy()
    record('iVAE', z)

    # 3. CausalVAE
    print("\n[3/4] Training CausalVAE...")
    m = CausalVAE(x.shape[1], dag.d, dag.A).to(dev)
    _train(m, x_loader, _causalvae_loss)
    m.eval()
    with torch.no_grad():
        z = m.encode(x_t)[0].cpu().numpy()
    record('CausalVAE', z)

    # 4. C-iVAE (one 1-dimensional latent per DAG node, 64-dim structure embedding)
    print("\n[4/4] Training C-iVAE...")
    m = CausalIVAE(x.shape[1], 1, dag, ne, 64, HIDDEN_DIM).to(dev)
    _train(m, xe_loader, _civae_loss)
    m.eval()
    with torch.no_grad():
        _, _, zm, _ = m(x_t, env_t)
    # Score the posterior means, as for the three baselines, rather than the
    # noisy samples. Note the encoder still conditions each node's mean on
    # sampled parent latents, so some randomness remains.
    z = torch.cat([zm[i] for i in range(dag.d)], -1).cpu().numpy()
    record('C-iVAE', z)

    # Summary
    print("\n" + "="*60 + "\nSUMMARY\n" + "="*60)
    print(f"{'Model':<12} {'MCC':>8} {'Consistency':>12}")
    print("-"*34)
    for name in ['VAE', 'iVAE', 'CausalVAE', 'C-iVAE']:
        print(f"{name:<12} {results[name]['MCC']:>8.4f} {results[name]['Consistency']:>12.4f}")
    print("Done.")

# Execute the experiment; run() prints progress and the final summary table.
run()