In [None]:
# Cloud Environment Setup
import subprocess, sys, os
def install(pkg):
    """Install ``pkg`` with pip for the running interpreter.

    Packages are fetched from the Tsinghua PyPI mirror. A failing pip
    invocation raises ``subprocess.CalledProcessError``.
    """
    mirror = "https://pypi.tuna.tsinghua.edu.cn/simple"
    command = [sys.executable, "-m", "pip", "install", pkg, "-i", mirror]
    subprocess.check_call(command)
# Report the installed PyTorch build, or install the CUDA 11.8 wheel when
# PyTorch is not importable at all.
try:
    import torch
    print(f"PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}")
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "--index-url", "https://download.pytorch.org/whl/cu118"])

# Each pair is (import name, pip distribution name); they differ for sklearn.
for m, p in [("numpy", "numpy"), ("pandas", "pandas"), ("sklearn", "scikit-learn"), ("networkx", "networkx"), ("requests", "requests")]:
    try:
        __import__(m)
    except ImportError:
        # Only a missing module should trigger an install; a bare ``except``
        # would also swallow KeyboardInterrupt and unrelated import-time errors.
        install(p)

# Make modules written into the working directory importable.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
print("Done.")

In [None]:
# Create civae library inline
def _write_module(path, source):
    """Write ``source`` to ``path`` and close the file before returning."""
    # The context manager guarantees the text is flushed and the handle closed
    # before the package is imported, instead of relying on garbage collection.
    with open(path, 'w', encoding='utf-8') as handle:
        handle.write(source)


# Materialize the civae package in the working directory. Each module's source
# is kept as a single string literal and written verbatim.
os.makedirs('civae', exist_ok=True)
# Package entry point: re-exports DAG and CausalIVAE.
_write_module('civae/__init__.py', 'from .dag import DAG\nfrom .model import CausalIVAE\n')
# DAG wrapper around a networkx DiGraph (parents, children, Markov blanket, level).
_write_module('civae/dag.py', 'import numpy as np, networkx as nx\nclass DAG:\n    def __init__(s, adj): s.A, s.d = adj, adj.shape[0]; s.graph = nx.from_numpy_array(s.A, create_using=nx.DiGraph); s.topo_order = list(nx.topological_sort(s.graph)); s.roots = [i for i in range(s.d) if s.graph.in_degree(i)==0]\n    def parents(s, i): return set(s.graph.predecessors(i))\n    def children(s, i): return set(s.graph.successors(i))\n    def markov_blanket(s, i): mb = s.parents(i)|s.children(i); [mb.update(s.parents(c)) for c in s.children(i)]; mb.discard(i); return mb\n    def level(s, i): return len(nx.ancestors(s.graph, i))\n')
# Conditional prior: a root branch and an attention-over-parents branch.
_write_module('civae/prior.py', 'import torch, torch.nn as nn\nclass CausalPrior(nn.Module):\n    def __init__(s, z_dim, u_dim, num_envs, hid=128):\n        super().__init__(); s.z_dim = z_dim; s.env_embed = nn.Embedding(num_envs, hid//4)\n        s.root_net = nn.Sequential(nn.Linear(u_dim+hid//4, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.root_mu, s.root_lv = nn.Linear(hid, z_dim), nn.Linear(hid, z_dim)\n        s.pk, s.pq, s.pv = nn.Linear(z_dim+u_dim, hid), nn.Linear(u_dim, hid), nn.Linear(z_dim+u_dim, hid)\n        s.cond_net = nn.Sequential(nn.Linear(hid+u_dim, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.cond_mu, s.cond_lv = nn.Linear(hid, z_dim), nn.Linear(hid, z_dim)\n    def forward(s, pz, pu, ui, env, is_root):\n        bs, dev = ui.size(0), ui.device; ui = ui.unsqueeze(0).expand(bs,-1) if ui.dim()==1 else ui\n        if is_root: h = s.root_net(torch.cat([ui, s.env_embed(env)], -1)); return s.root_mu(h), s.root_lv(h)\n        if not pz: return torch.zeros(bs, s.z_dim, device=dev), torch.zeros(bs, s.z_dim, device=dev)\n        ps = torch.cat([torch.stack(pz,1), torch.stack([p.unsqueeze(0).expand(bs,-1) if p.dim()==1 else p for p in pu],1)], -1)\n        K, Q, V = s.pk(ps), s.pq(ui).unsqueeze(1), s.pv(ps)\n        h = s.cond_net(torch.cat([torch.bmm(torch.softmax(torch.bmm(Q, K.transpose(1,2))/K.size(-1)**0.5, -1), V).squeeze(1), ui], -1))\n        return s.cond_mu(h), s.cond_lv(h)\n')
# Per-node structural embedding computed from the node's Markov blanket.
_write_module('civae/mb_encoder.py', 'import torch, torch.nn as nn\nclass MBEncoder(nn.Module):\n    def __init__(s, n, dim=64): super().__init__(); s.ne = nn.Embedding(n, dim); s.phi = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)); s.rho = nn.Sequential(nn.Linear(dim*2+4, dim), nn.ReLU(), nn.Linear(dim, dim)); s.empty = nn.Parameter(torch.zeros(dim))\n    def forward(s, dag, dev):\n        u = {}\n        for i in range(dag.d): mb = dag.markov_blanket(i); se = s.ne(torch.tensor(i, device=dev)); ma = s.phi(s.ne(torch.tensor(list(mb), device=dev))).sum(0) if mb else s.empty; u[i] = s.rho(torch.cat([ma, se, torch.tensor([dag.level(i), len(dag.parents(i)), len(dag.children(i)), len(mb)], dtype=torch.float32, device=dev)]))\n        return u\n')
# Encoder, decoder and the CausalIVAE model with its loss. This literal was
# previously broken across two physical lines inside a single-quoted string,
# which is a syntax error; it is a single line again.
_write_module('civae/model.py', 'import torch, torch.nn as nn, torch.nn.functional as F\nfrom .mb_encoder import MBEncoder\nfrom .prior import CausalPrior\nclass CausalEncoder(nn.Module):\n    def __init__(s, in_dim, z_dim, n, u_dim, num_envs, hid=256):\n        super().__init__(); s.z_dim, s.n = z_dim, n; s.env_embed = nn.Embedding(num_envs, hid//4)\n        s.feat_net = nn.Sequential(nn.Linear(in_dim+hid//4, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU())\n        s.node_nets = nn.ModuleList([nn.Sequential(nn.Linear(hid+u_dim+z_dim*n, hid//2), nn.ReLU()) for _ in range(n)])\n        s.mu_heads = nn.ModuleList([nn.Linear(hid//2, z_dim) for _ in range(n)])\n        s.lv_heads = nn.ModuleList([nn.Linear(hid//2, z_dim) for _ in range(n)])\n    def forward(s, x, u, env, dag, topo):\n        bs, dev = x.size(0), x.device; feat = s.feat_net(torch.cat([x, s.env_embed(env)], -1))\n        zs, zm, zl = {}, {}, {}\n        for i in topo: ui = u[i].unsqueeze(0).expand(bs,-1); pa = sorted(dag.parents(i)); pz = torch.cat([zs[j] for j in pa], -1) if pa else torch.zeros(bs, 0, device=dev); h = s.node_nets[i](torch.cat([feat, ui, F.pad(pz, (0, s.z_dim*s.n-pz.size(-1)))], -1)); mu, lv = s.mu_heads[i](h), s.lv_heads[i](h); zs[i] = mu + torch.randn_like(mu)*torch.exp(0.5*lv); zm[i], zl[i] = mu, lv\n        return zs, zm, zl\nclass Decoder(nn.Module):\n    def __init__(s, z_dim, out_dim, hid=256): super().__init__(); s.net = nn.Sequential(nn.Linear(z_dim, hid), nn.ReLU(), nn.Linear(hid, hid), nn.ReLU(), nn.Linear(hid, out_dim))\n    def forward(s, z): return s.net(z)\nclass CausalIVAE(nn.Module):\n    def __init__(s, in_dim, z_dim, dag, num_envs=1, u_dim=64, hid=256):\n        super().__init__(); s.dag, s.z_dim, s.d, s.roots = dag, z_dim, dag.d, set(dag.roots)\n        s.struct_enc = MBEncoder(dag.d, u_dim); s.encoder = CausalEncoder(in_dim, z_dim, dag.d, u_dim, num_envs, hid)\n        s.prior = CausalPrior(z_dim, u_dim, num_envs, hid); s.decoder = Decoder(z_dim*dag.d, in_dim, hid)\n    def forward(s, x, env=None):\n        dev = x.device; env = env if env is not None else torch.zeros(x.size(0), dtype=torch.long, device=dev)\n        u = s.struct_enc(s.dag, dev); zs, zm, zl = s.encoder(x, u, env, s.dag, s.dag.topo_order)\n        s._u, s._env = u, env; return s.decoder(torch.cat([zs[i] for i in range(s.d)], -1)), zs, zm, zl\n    def compute_loss(s, x, xr, zs, zm, zl, env=None):\n        bs, u, env = x.size(0), s._u, env if env is not None else s._env\n        rl = F.mse_loss(xr, x, reduction="sum")/bs; kl = 0\n        for i in s.dag.topo_order: pa = sorted(s.dag.parents(i)); pm, pl = s.prior([zs[j] for j in pa], [u[j] for j in pa], u[i].unsqueeze(0).expand(bs,-1), env, i in s.roots); v1, v2 = torch.exp(zl[i]), torch.exp(pl); kl += (0.5*(pl-zl[i]+v1/v2+(zm[i]-pm)**2/v2-1).sum(-1)).sum()/bs\n        return {"loss": rl+kl, "recon_loss": rl, "kl_loss": kl}\n')
print("Library created.")

In [None]:
"""Experiment 11: Sachs Cell Signaling (Self-contained)"""
import torch, numpy as np, pandas as pd, requests, io
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from civae import DAG, CausalIVAE

# Sachs DAG (11 proteins, 17 edges - from Sachs et al. 2005)
# Protein names; a node's integer id is its position in this list.
SACHS_NODES = [
    'Raf', 'Mek', 'PLCg', 'PIP2', 'PIP3', 'Erk',
    'Akt', 'PKA', 'PKC', 'p38', 'Jnk',
]
# Directed edges as (parent_id, child_id) pairs. They are spelled by protein
# name and resolved to ids so every entry is self-describing.
SACHS_EDGES = [
    (SACHS_NODES.index(parent), SACHS_NODES.index(child))
    for parent, child in (
        ('Raf', 'Mek'),
        ('Mek', 'Erk'),
        ('PLCg', 'PIP2'),
        ('PLCg', 'PKC'),
        ('PIP2', 'PKC'),
        ('PIP3', 'PIP2'),
        ('PIP3', 'Akt'),
        ('PKA', 'Raf'),
        ('PKA', 'Mek'),
        ('PKA', 'Erk'),
        ('PKA', 'Akt'),
        ('PKA', 'p38'),
        ('PKA', 'Jnk'),
        ('PKC', 'Raf'),
        ('PKC', 'Mek'),
        ('PKC', 'p38'),
        ('PKC', 'Jnk'),
    )
]

def create_sachs_dag():
    """Wrap the module-level Sachs edge list in a ``DAG`` object.

    Returns:
        A ``DAG`` built from a square adjacency matrix with one row and
        column per entry of ``SACHS_NODES``.
    """
    num_nodes = len(SACHS_NODES)
    adjacency = np.zeros((num_nodes, num_nodes))
    parent_ids, child_ids = zip(*SACHS_EDGES)
    # adjacency[parent, child] = 1 marks the directed edge parent -> child.
    adjacency[list(parent_ids), list(child_ids)] = 1
    return DAG(adjacency)

def download_sachs():
    """Fetch the Sachs CSV and return its values as a float32 array.

    Any failure while downloading, parsing or converting the file is
    reported on stdout, and the result of ``create_synthetic_sachs()`` is
    returned instead.
    """
    url = "https://raw.githubusercontent.com/ElementAI/causal_discovery_toolbox/master/cdt/data/resources/cyto_full_data.csv"
    print(f"Downloading Sachs data from {url}...")
    try:
        response = requests.get(url, timeout=30)
        # Turn HTTP error statuses into exceptions so they take the fallback path.
        response.raise_for_status()
        csv_buffer = io.StringIO(response.text)
        table = pd.read_csv(csv_buffer)
        print(f"Downloaded: {table.shape}")
        values = table.values
        return values.astype(np.float32)
    except Exception as err:
        # Deliberate best-effort: the experiment still runs without network access.
        print(f"Download failed: {err}. Creating synthetic data...")
        return create_synthetic_sachs()

def create_synthetic_sachs():
    """Sample a synthetic dataset whose dependencies follow ``SACHS_EDGES``.

    Seeds NumPy's global RNG with 42, then fills one column per node in
    topological order. Nodes without parents get standard normal noise;
    every other node gets ``tanh(0.5 * sum(parents)) + 0.5 * noise``.

    Returns:
        A float32 array with 7466 rows and one column per ``SACHS_NODES`` entry.
    """
    np.random.seed(42)
    num_samples = 7466
    num_nodes = len(SACHS_NODES)

    adjacency = np.zeros((num_nodes, num_nodes))
    parent_ids, child_ids = zip(*SACHS_EDGES)
    adjacency[list(parent_ids), list(child_ids)] = 1

    # Kahn's algorithm: repeatedly emit a node whose parents were all emitted.
    # The FIFO queue and ascending child scan fix the order, and the order
    # fixes the sequence of random draws below.
    unmet_parents = adjacency.sum(axis=0).astype(int)
    ready = [node for node in range(num_nodes) if unmet_parents[node] == 0]
    order = []
    while ready:
        current = ready.pop(0)
        order.append(current)
        for child in np.where(adjacency[current])[0]:
            unmet_parents[child] -= 1
            if unmet_parents[child] == 0:
                ready.append(child)

    samples = np.zeros((num_samples, num_nodes))
    for node in order:
        parents = np.where(adjacency[:, node] == 1)[0]
        noise = np.random.randn(num_samples)
        if parents.size == 0:
            samples[:, node] = noise
        else:
            samples[:, node] = np.tanh(samples[:, parents].sum(1) * 0.5) + noise * 0.5
    return samples.astype(np.float32)

def train(dag, x, env, ne, dev, epochs=100, batch_size=128, lr=1e-3, log_every=20):
    """Fit a ``CausalIVAE`` to ``x`` with Adam and return the trained model.

    Args:
        dag: Graph object handed to ``CausalIVAE``.
        x: 2-D NumPy array of observations; ``x.shape[1]`` is the model's
            input dimension.
        env: NumPy array of per-sample environment ids, batched alongside ``x``.
        ne: Number of environments passed to the model.
        dev: Torch device (or device string) the model and data are moved to.
        epochs: Number of passes over the data (default 100).
        batch_size: Mini-batch size (default 128).
        lr: Adam learning rate (default 1e-3).
        log_every: Print a progress line every this many epochs (default 20);
            a falsy value disables progress output.

    Returns:
        The trained ``CausalIVAE`` instance, still in training mode.
    """
    # Positional arguments: z_dim=1 latent dimension per node, u_dim=64.
    m = CausalIVAE(x.shape[1], 1, dag, ne, 64).to(dev)
    opt = optim.Adam(m.parameters(), lr=lr)
    # The whole dataset is moved to the device once, so batches need no transfer.
    dataset = TensorDataset(torch.from_numpy(x).to(dev), torch.from_numpy(env).to(dev))
    ldr = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    m.train()
    for ep in range(epochs):
        for xb, eb in ldr:
            opt.zero_grad()
            xr, z, zm, zl = m(xb, eb)
            m.compute_loss(xb, xr, z, zm, zl, eb)['loss'].backward()
            opt.step()
        if log_every and ep % log_every == 0:
            print(f"Epoch {ep}")
    return m

def run():
    """Run experiment 11 end to end: load data, build the DAG, train.

    Prints a banner, the selected device and the data/graph sizes. The data
    is standardized, each sample gets a uniformly random environment id in
    ``[0, 3)``, and the model is trained. Nothing is returned.
    """
    rule = "=" * 50
    print(rule + "\nEXPERIMENT 11: Sachs Cell Signaling\n" + rule)
    if torch.cuda.is_available():
        dev = 'cuda'
    else:
        dev = 'cpu'
    print(f"Device: {dev}")

    # Load the measurements and the graph.
    raw = download_sachs()
    dag = create_sachs_dag()
    print(f"Data shape: {raw.shape}, DAG nodes: {dag.d}, edges: {len(SACHS_EDGES)}")

    # Standardize each column, then draw one environment label per sample.
    scaler = StandardScaler()
    x = scaler.fit_transform(raw).astype(np.float32)
    num_envs = 3
    env = np.random.randint(0, num_envs, x.shape[0]).astype(np.int64)

    # Fit the model; the returned model is not kept.
    train(dag, x, env, num_envs, dev)
    print("Done.")

run()