In [None]:
# Mount Google Drive at /content/drive; later cells read their input files
# from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import numpy as np
import torch
import torch.nn as nn
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

In [None]:
# Load latent embeddings and cohort statistics
# All three arrays are read from the mounted Drive folder.
latents = np.load("/content/drive/MyDrive/latents.npy")
cohort_mu = np.load("/content/drive/MyDrive/cohort_mu.npy")
cohort_sigma = np.load("/content/drive/MyDrive/cohort_sigma.npy")

# Shape sanity check; the recorded run printed (1064, 16), (16,) and (16, 16).
print(f"Latents shape: {latents.shape}")
print(f"Cohort mean shape: {cohort_mu.shape}")
print(f"Cohort covariance shape: {cohort_sigma.shape}")

Latents shape: (1064, 16)
Cohort mean shape: (16,)
Cohort covariance shape: (16, 16)


In [None]:
class NeuroCore:
    """Network of Montbrió–Pazó–Roxin firing-rate nodes coupled through ``W``.

    Every node carries a firing rate ``r`` and a mean potential ``V``.  Both
    derivatives are divided by ``tau``, so ``tau`` acts purely as a time scale
    on the dimensionless equations.

    Args:
        W: square (N, N) coupling matrix; node ``i`` receives ``(W @ r)[i]``.
        tau: time constant dividing both derivatives.
        eta: constant term of the voltage equation.
        Delta: sets the constant drive ``Delta / pi`` of the rate equation.
        J: gain applied to the network input ``W @ r``.
        I_ext: external input added to the voltage equation.
    """

    def __init__(self, W, tau=1.0, eta=-5.0, Delta=1.0, J=15.0, I_ext=0.0):
        self.W = W
        self.N = W.shape[0]
        self.tau = tau
        self.eta = eta
        self.Delta = Delta
        self.J = J
        self.I_ext = I_ext

    def montbrio_equations(self, t, y):
        """Return d/dt of the stacked state ``[r_1..r_N, V_1..V_N]``.

        ``t`` is accepted only because ``solve_ivp`` passes it; the system
        itself does not depend on time.
        """
        N = self.N
        r = y[:N]
        V = y[N:]
        coupling = self.W @ r
        dr_dt = (self.Delta / np.pi + 2 * r * V) / self.tau
        # -(pi * r)^2 is the rate-to-voltage feedback of the Montbrió model.
        # It was missing, which removed the only term through which a node's
        # own rate pulls its potential back down.
        dV_dt = (
            V**2
            + self.eta
            + self.J * coupling
            + self.I_ext
            - (np.pi * r) ** 2
        ) / self.tau
        return np.concatenate([dr_dt, dV_dt])

    def simulate(self, r0=None, V0=None, t_max=100.0, dt=0.1):
        """Integrate the network from ``t = 0`` to ``t_max`` with RK45.

        Args:
            r0: initial rates, length N (defaults to 0.1 for every node).
            V0: initial potentials, length N (defaults to -5.0 for every node).
            t_max: end of the integration interval.
            dt: spacing of the sample times ``np.arange(0, t_max, dt)``.

        Returns:
            ``(t, r, V)`` where ``t`` holds the sample times and ``r`` / ``V``
            are (N, len(t)) arrays.

        Raises:
            RuntimeError: if the solver stops before reaching ``t_max``.
        """
        N = self.N
        t_eval = np.arange(0, t_max, dt)
        if r0 is None: r0 = 0.1 * np.ones(N)
        if V0 is None: V0 = -5.0 * np.ones(N)
        y0 = np.concatenate([r0, V0])
        sol = solve_ivp(self.montbrio_equations, [0, t_max], y0, t_eval=t_eval, method="RK45")
        if not sol.success:
            # A failed solve returns truncated state arrays; surfacing it
            # avoids handing back a time axis that no longer matches them.
            raise RuntimeError(f"ODE integration failed: {sol.message}")
        return sol.t, sol.y[:N], sol.y[N:]


In [None]:
class CrossCoder(nn.Module):
    """One encoder MLP and one decoder MLP per parcellation, sharing a latent size.

    Args:
        parcellation_dims: mapping from parcellation name to the width of its
            input vector.
        latent_dim: width of the latent vector every encoder emits and every
            decoder consumes.
    """

    def __init__(self, parcellation_dims, latent_dim=16):
        super().__init__()
        self.parcellations = list(parcellation_dims)
        self.latent_dim = latent_dim

        # Every encoder is created before any decoder so that parameter
        # creation and registration order stay the same.
        self.encoders = nn.ModuleDict()
        for name, width in parcellation_dims.items():
            self.encoders[name] = self._make_encoder(width, latent_dim)

        self.decoders = nn.ModuleDict()
        for name, width in parcellation_dims.items():
            self.decoders[name] = self._make_decoder(width, latent_dim)

    @staticmethod
    def _make_encoder(input_dim, latent_dim):
        """Stack input_dim -> 1024 -> 512 -> 128 -> latent_dim with dropout after the first two ReLUs."""
        layers = [
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _make_decoder(output_dim, latent_dim):
        """Stack latent_dim -> 128 -> 512 -> 1024 -> output_dim (no dropout)."""
        layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _pick(bank, parcellation):
        """Return the network registered under ``parcellation`` or raise KeyError."""
        if parcellation in bank:
            return bank[parcellation]
        raise KeyError(f"Unknown parcellation key: {parcellation}")

    def encode(self, x, parcellation):
        """Map ``x`` to the latent space with the encoder of ``parcellation``."""
        network = self._pick(self.encoders, parcellation)
        return network(x)

    def decode(self, z, parcellation):
        """Map latent ``z`` back to the input space of ``parcellation``."""
        network = self._pick(self.decoders, parcellation)
        return network(z)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved checkpoint dictionary, mapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load("/content/drive/MyDrive/crosscoder_fixed.pth", map_location=device)

# Rebuild the architecture from the hyperparameters stored in the checkpoint,
# so the layer sizes match the saved weights.
parcellation_dims = checkpoint['parcellation_dims'] # Get from loaded dictionary
latent_dim = checkpoint['latent_dim'] # Get from loaded dictionary
crosscoder = CrossCoder(parcellation_dims, latent_dim).to(device)

# Load the state dictionary into the model
crosscoder.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode (disables the encoders' Dropout layers).
crosscoder.eval()

# The latent vectors decoded later are the `latents` array that an earlier
# cell loaded from latents.npy; nothing else is loaded here.


CrossCoder(
  (encoders): ModuleDict(
    (parc_86): Sequential(
      (0): Linear(in_features=3403, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=1024, out_features=512, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.3, inplace=False)
      (6): Linear(in_features=512, out_features=128, bias=True)
      (7): ReLU()
      (8): Linear(in_features=128, out_features=16, bias=True)
    )
    (parc_129): Sequential(
      (0): Linear(in_features=8256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=1024, out_features=512, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.3, inplace=False)
      (6): Linear(in_features=512, out_features=128, bias=True)
      (7): ReLU()
      (8): Linear(in_features=128, out_features=16, bias=True)
    )
    (parc_234): Sequential(
      (0): Linear(in_features=27261, out_features=1024, bias=True)
      (1): R

In [None]:
import numpy as np
import torch
from scipy.integrate import solve_ivp
import pandas as pd

# ---------------------------
# Reuse CrossCoder and latents from earlier cells
# ---------------------------

# Nothing is read from disk here: `crosscoder` and `latents` must already
# exist in the kernel from the cells above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
crosscoder = crosscoder.to(device)
crosscoder.eval()

print("✅ Loaded latents:", latents.shape)


# ---------------------------
# Helper: upper-triangle vector -> symmetric matrix
# ---------------------------

def vector_to_symmetric_matrix(vec, n_regions=463):
    """Unpack a flat vector into a symmetric ``n_regions`` x ``n_regions`` matrix.

    The leading ``n_regions * (n_regions - 1) / 2`` entries of ``vec`` fill the
    strict upper triangle in row-major order and are mirrored below the
    diagonal.  Entries past that count are ignored; the diagonal is zero.
    """
    rows, cols = np.triu_indices(n_regions, k=1)
    upper = np.zeros((n_regions, n_regions))
    upper[rows, cols] = vec[:rows.size]
    symmetric = upper + upper.T
    np.fill_diagonal(symmetric, 0)
    return symmetric


# ---------------------------
# NeuroCore Simulation Model
# ---------------------------

class NeuroCore:
    """Network of Montbrió–Pazó–Roxin firing-rate nodes coupled through ``W``.

    Every node carries a firing rate ``r`` and a mean potential ``V``.  Both
    derivatives are divided by ``tau``, so ``tau`` acts purely as a time scale
    on the dimensionless equations.

    Args:
        W: square (N, N) coupling matrix; node ``i`` receives ``(W @ r)[i]``.
        tau: time constant dividing both derivatives.
        eta: constant term of the voltage equation.
        Delta: sets the constant drive ``Delta / pi`` of the rate equation.
        J: gain applied to the network input ``W @ r``.
        I_ext: external input added to the voltage equation.
    """

    def __init__(self, W, tau=1.0, eta=-5.0, Delta=1.0, J=15.0, I_ext=0.0):
        self.W = W
        self.N = W.shape[0]
        self.tau = tau
        self.eta = eta
        self.Delta = Delta
        self.J = J
        self.I_ext = I_ext

    def montbrio_equations(self, t, y):
        """Return d/dt of the stacked state ``[r_1..r_N, V_1..V_N]``.

        ``t`` is accepted only because ``solve_ivp`` passes it; the system
        itself does not depend on time.
        """
        N = self.N
        r = y[:N]
        V = y[N:]
        coupling = self.W @ r
        dr_dt = (self.Delta / np.pi + 2 * r * V) / self.tau
        # -(pi * r)^2 is the rate-to-voltage feedback of the Montbrió model.
        # It was missing, which removed the only term through which a node's
        # own rate pulls its potential back down.
        dV_dt = (
            V**2
            + self.eta
            + self.J * coupling
            + self.I_ext
            - (np.pi * r) ** 2
        ) / self.tau
        return np.concatenate([dr_dt, dV_dt])

    def simulate(self, r0=None, V0=None, t_max=100.0, dt=0.1):
        """Integrate the network from ``t = 0`` to ``t_max`` with RK45.

        Args:
            r0: initial rates, length N (defaults to 0.1 for every node).
            V0: initial potentials, length N (defaults to -5.0 for every node).
            t_max: end of the integration interval.
            dt: spacing of the sample times ``np.arange(0, t_max, dt)``.

        Returns:
            ``(t, r, V)`` where ``t`` holds the sample times and ``r`` / ``V``
            are (N, len(t)) arrays.

        Raises:
            RuntimeError: if the solver stops before reaching ``t_max``.
        """
        N = self.N
        t_eval = np.arange(0, t_max, dt)
        if r0 is None: r0 = 0.1 * np.ones(N)
        if V0 is None: V0 = -5.0 * np.ones(N)
        y0 = np.concatenate([r0, V0])
        sol = solve_ivp(self.montbrio_equations, [0, t_max], y0, t_eval=t_eval, method="RK45")
        if not sol.success:
            # A failed solve returns truncated state arrays; surfacing it
            # avoids handing back a time axis that no longer matches them.
            raise RuntimeError(f"ODE integration failed: {sol.message}")
        return sol.t, sol.y[:N], sol.y[N:]


# ---------------------------
# Brain state interpretation logic
# ---------------------------

def interpret_brain_state(mean_r, mean_V):
    """
    Classify averaged dynamics as 'suppressed', 'normal' or 'hyperactive'.

    Returns a ``(state, advice)`` pair of strings.  Inputs that match none of
    the three regions below (for example a low rate combined with a moderate
    potential, or NaN values) are reported as 'normal' with a separate
    fallback advice text.
    """
    rate_low = mean_r < 0.15
    rate_mid = 0.15 <= mean_r < 0.45
    rate_high = mean_r >= 0.45

    potential_low = mean_V < -4.5
    potential_mid = -4.5 <= mean_V < -2.0
    potential_high = mean_V >= -2.0

    # Suppressed: rate and potential are both below their lower thresholds.
    if rate_low and potential_low:
        return (
            "suppressed",
            "Neural activity is reduced — consider increasing excitatory drive or lowering inhibition.",
        )

    # Normal: rate and potential both sit in their middle bands.
    if rate_mid and potential_mid:
        return (
            "normal",
            "Stable network dynamics — no major modulation needed.",
        )

    # Hyperactive: either quantity reaches its upper threshold.
    if rate_high or potential_high:
        return (
            "hyperactive",
            "Overexcitation detected — recommend reducing excitatory coupling or adding inhibitory feedback.",
        )

    # Everything else, i.e. the mixed low/mid combinations.
    return (
        "normal",
        "Dynamics within near-normal range — maintain current parameters.",
    )


# ---------------------------
# Dataset generation loop
# ---------------------------

# Decoder head and matrix size shared by every subject; kept together so the
# key passed to the decoder and the matrix size cannot drift apart.
PARCELLATION_KEY = "parc_463"
N_REGIONS = 463

results = []
failed_ids = []  # indices of latents whose decode/simulation raised

for i, latent in enumerate(latents):
    try:
        z = torch.tensor(latent, dtype=torch.float32).unsqueeze(0).to(device)

        # Decode using CrossCoder. This is inference only, so autograd
        # bookkeeping is switched off instead of detaching afterwards.
        with torch.no_grad():
            decoded = crosscoder.decode(z, PARCELLATION_KEY)
        decoded = decoded.cpu().numpy().flatten()

        # Convert to symmetric connectivity matrix
        W_463 = vector_to_symmetric_matrix(decoded, N_REGIONS)

        # Simulate brain dynamics
        neurocore = NeuroCore(W_463)
        t, r, V = neurocore.simulate(t_max=100.0, dt=0.1)

        # Compute mean features over all nodes and all sample times
        mean_r = np.mean(r)
        mean_V = np.mean(V)

        # Interpret and advise
        state, advice = interpret_brain_state(mean_r, mean_V)

        results.append({
            "brain_id": i,
            "mean_r": mean_r,
            "mean_V": mean_V,
            "state": state,
            "advice": advice
        })

        print(f"✅ Processed brain {i+1}/{len(latents)} — {state}")

    except Exception as e:
        # Best-effort batch: report the failure, remember which subject it
        # was, and carry on with the next one.
        print(f"❌ Error processing brain {i}: {e}")
        failed_ids.append(i)

if failed_ids:
    # Make skipped subjects visible at the end instead of leaving them
    # buried in the per-subject log.
    print(f"⚠️ {len(failed_ids)} of {len(latents)} brains failed: {failed_ids}")


# ---------------------------
# Save dataset as CSV
# ---------------------------

# One row per entry in `results`; brains that raised in the loop above were
# never appended and are therefore absent from the table.
df = pd.DataFrame(results)
# Relative path: the CSV is written to the current working directory.
df.to_csv("neurocore_advice_dataset.csv", index=False)
print("✅ Dataset saved to neurocore_advice_dataset.csv")
print(df.head())


✅ Loaded latents: (1064, 16)
✅ Processed brain 1/1064 — normal
✅ Processed brain 2/1064 — normal
✅ Processed brain 3/1064 — normal
✅ Processed brain 4/1064 — normal
✅ Processed brain 5/1064 — normal
✅ Processed brain 6/1064 — normal
✅ Processed brain 7/1064 — normal
✅ Processed brain 8/1064 — normal
✅ Processed brain 9/1064 — normal
✅ Processed brain 10/1064 — normal
✅ Processed brain 11/1064 — normal
✅ Processed brain 12/1064 — normal
✅ Processed brain 13/1064 — normal
✅ Processed brain 14/1064 — normal
✅ Processed brain 15/1064 — normal
✅ Processed brain 16/1064 — normal
✅ Processed brain 17/1064 — normal
✅ Processed brain 18/1064 — normal
✅ Processed brain 19/1064 — normal
✅ Processed brain 20/1064 — normal
✅ Processed brain 21/1064 — normal
✅ Processed brain 22/1064 — normal
✅ Processed brain 23/1064 — normal
✅ Processed brain 24/1064 — normal
✅ Processed brain 25/1064 — normal
✅ Processed brain 26/1064 — normal
✅ Processed brain 27/1064 — normal
✅ Processed brain 28/1064 — normal
