In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks')

!pip install -r ../requirements.txt

In [None]:
!cd "/content/drive/My Drive/Colab Notebooks/huggingface_models"
!pip freeze > "/content/drive/My Drive/Colab Notebooks/huggingface_models/requirements.txt"
%cd "/content/drive/My Drive/Colab Notebooks/huggingface_models/Geneformer"
!pip install .

In [3]:
import sys
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks')
sys.path.append("structural_probe_utils")
import numpy as np
import pandas as pd
import pickle
import torch
import anndata as ad

from structural_probe_utils.structural_probe_utils import build_G_from_edges, train_A

In [4]:
# prepare gene embedding

from geneformer import TOKEN_DICTIONARY_FILE
from geneformer import perturber_utils as pu
from geneformer import TranscriptomeTokenizer
with open("/content/drive/My Drive/Colab Notebooks/huggingface_models/Geneformer/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl", "rb") as f:
    token_dict = pickle.load(f)
model = pu.load_model("Pretrained", num_classes = 0, model_directory = "/content/drive/My Drive/Colab Notebooks/huggingface_models/Geneformer/gf-12L-30M-i2048", mode="eval")

In [5]:
len(token_dict)

25426

In [6]:
gene_emb = model.get_input_embeddings().weight.detach().cpu()
gene_names = list(token_dict.keys())
print(gene_emb.shape)
print(len(gene_names))

torch.Size([25426, 512])
25426


In [26]:
# prepare true gene network
STRING_homosapien = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/Structural_Probe_Gene/data/9606.protein.links.v12.0.txt", sep=" ")
STRING_homosapien["protein1"] = STRING_homosapien["protein1"].str.replace("9606.", "", regex=False)
STRING_homosapien["protein2"] = STRING_homosapien["protein2"].str.replace("9606.", "", regex=False)

STRING_homosapien_alias = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/Structural_Probe_Gene/data/9606.protein.aliases.v12.0.txt", sep="\t", header=None)
STRING_homosapien_alias = STRING_homosapien_alias[STRING_homosapien_alias[1].str.startswith("ENSG", na=False)][[0,1]].drop_duplicates()
STRING_homosapien_alias[0] = STRING_homosapien_alias[0].str.replace("9606.", "", regex=False)
ENSP_to_ENSG = dict(zip(STRING_homosapien_alias[0], STRING_homosapien_alias[1]))

# Map both columns using the dictionary
STRING_homosapien["gene1"] = STRING_homosapien["protein1"].map(ENSP_to_ENSG)
STRING_homosapien["gene2"] = STRING_homosapien["protein2"].map(ENSP_to_ENSG)
STRING_gene_interaction = STRING_homosapien.dropna(subset=["gene1", "gene2", "combined_score"])
STRING_gene_interaction_high_conf = STRING_gene_interaction[STRING_gene_interaction["combined_score"]>700] # 700 for highly confident gene interaction

In [27]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import precision_score, recall_score, f1_score

n_splits = 10
n = len(gene_names)
random_state = 42
device = "cpu"

G_full = build_G_from_edges(STRING_gene_interaction_high_conf, gene_names, make_dense = True)

genes = np.array(gene_names)
X_full = np.asarray(gene_emb)  # (n,d)
n = len(genes)

kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

acc_tr_0 = []
acc_tr_1 = []
acc_va_0 = []
acc_va_1 = []
f1_tr = []
f1_va = []
precision_va = []
recall_va = []
for fold_id, (train_idx, val_idx) in enumerate(kf.split(np.arange(n)), start=1):
    # Induced subgraphs and embeddings
    X_tr = X_full[train_idx]
    X_va = X_full[val_idx]
    G_tr = G_full[np.ix_(train_idx, train_idx)]
    G_va = G_full[np.ix_(val_idx,  val_idx)]

    # Train
    A, w, tr_loss, va_loss = train_A(
        X_tr, G_tr, X_va, G_va, steps = 200
    )

    # X_tr = torch.as_tensor(X_tr, dtype=torch.float32, device=device)
    # np.fill_diagonal(G_tr, 0)
    # G_tr = torch.from_numpy((G_tr > 0).astype(np.bool_)).to(device)
    # nt = X_tr.shape[0]
    # I, J = torch.triu_indices(nt, nt, offset=1, device=device)
    # y = G_tr[I, J].float()

    X_va = torch.as_tensor(X_va, dtype=torch.float32, device=device)
    np.fill_diagonal(G_va, 0)
    G_va = torch.from_numpy((G_va > 0).astype(np.bool_)).to(device)
    nv = X_va.shape[0]
    Iv, Jv = torch.triu_indices(nv, nv, offset=1, device=device)
    yv = G_va[Iv, Jv].float()

    with torch.no_grad():
      # delta = X_tr[I] - X_tr[J]
      # z = delta @ A
      # sqdist = (z * z).sum(dim=1)
      # logits = w - sqdist
      # prob = torch.sigmoid(logits).cpu().numpy()
      # pred = (prob > 0.5).astype(int)

      vdelta = X_va[Iv] - X_va[Jv]
      vz = vdelta @ A
      vsqdist = (vz * vz).sum(dim=1)
      vlogits = w - vsqdist
      vprob = torch.sigmoid(vlogits).cpu().numpy()
      vpred = (vprob > 0.5).astype(int)

      # acc_tr_0.append(accuracy_score(y[y == 0], pred[y == 0]))
      # acc_tr_1.append(accuracy_score(y[y == 1], pred[y == 1]))
      acc_va_0.append(accuracy_score(yv[yv == 0], vpred[yv == 0]))
      acc_va_1.append(accuracy_score(yv[yv == 1], vpred[yv == 1]))
      # f1_tr.append(f1_score(y, pred))
      f1_va.append(f1_score(yv, vpred))
      precision_va.append(precision_score(yv, vpred))
      recall_va.append(recall_score(yv, vpred))


In [28]:
print(acc_va_0)
print(acc_va_1)
print(f1_va)
print(precision_va)
print(recall_va)

[0.6913906673109174, 0.6892803873485719, 0.6906040574170152, 0.6993530715619258, 0.6654986802807755, 0.6989786018450383, 0.7163280699490865, 0.7008965961359263, 0.7134069162389907, 0.7216914254094087]
[0.6925520649488175, 0.7280801209372638, 0.6987664146438519, 0.6774135120449194, 0.7409658344283837, 0.6762986470686487, 0.6911516062370844, 0.7003240661777247, 0.6366160681229175, 0.6901633522727273]
[0.007806516173056379, 0.0076161748340077134, 0.006975187636916856, 0.0076378124814926, 0.008279175789788588, 0.008253319214293795, 0.007966289969490344, 0.008429454641383777, 0.007372890531120394, 0.008573193056750259]
[0.003925381734370499, 0.00382810963117808, 0.0035050879665300028, 0.0038405573121027873, 0.004162844650500076, 0.004151994372401799, 0.004006233128139568, 0.004240246194518454, 0.0037079166554354567, 0.0043133869242344205]
[0.6925520649488175, 0.7280801209372638, 0.6987664146438519, 0.6774135120449194, 0.7409658344283837, 0.6762986470686487, 0.6911516062370844, 0.70032406617

In [29]:
G_full.sum() / n**2

np.float64(0.0017366251651114593)

In [14]:
len(set(STRING_gene_interaction_high_conf["gene1"]) | set(STRING_gene_interaction_high_conf["gene2"]))

16185