## This notebook preprocesses the LINCS L1000 dataset and generates the cell-type-specific graphs and other files needed for subsequent model training.

In [None]:
import scanpy as sc 
import pandas as pd
import SEACells
import numpy as np
import matplotlib
from statistics import mode
import matplotlib.pyplot as plt
from statistics import mode
import sys 
sys.path.insert(0, '../')
import utils
from importlib import reload
reload(utils)
from utils import *
import seaborn as sns
# Some plotting aesthetics
%matplotlib inline
sns.set_style('ticks')
matplotlib.rcParams['figure.dpi'] = 300

# Input Data Directory
data_path = "/../Cell-Type-Specific-Graphs/Data/"

# Output Results Directory
save_path_results = "/../Cell-Type-Specific-Graphs/Results/"

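## Import the data
- Download `lincs_full.h5ad.gz` from https://f003.backblazeb2.com/file/chemCPA-datasets/lincs_full.h5ad.gz and decompress it into `data_path` as `lincs_full.h5ad` (the file name read below).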

In [None]:
# Read the full LINCS L1000 AnnData object from the input directory.
input_file = f"{data_path}lincs_full.h5ad"
adata = sc.read(input_file)
adata

In [None]:
# Name the gene (var) index "gene_name"; that is the key later passed to
# create_coexpression_graph. Setting the index name directly works whether or not
# the index already had a name. The previous reset_index()/rename("index")/
# set_index("gene_name") round-trip only worked for an unnamed index: with a named
# index, the rename was a no-op and set_index("gene_name") could raise KeyError.
adata.var.index.name = "gene_name"
adata.var

In [None]:
import re
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9]')


def remove_non_alphanumeric(input_string):
    """Return *input_string* with every character outside ASCII ``[a-zA-Z0-9]`` removed.

    Used to normalise drug names (e.g. "5-FU" -> "5FU") into condition labels.
    """
    return _NON_ALNUM_RE.sub('', input_string)
    
# Harmonise sample metadata into the column names used by the rest of the pipeline.
# Drug names are reduced to ASCII letters/digits. `.astype(object)` keeps the column a
# plain object column, so the "control" label below can be assigned even if `apply`
# on a categorical `pert_iname` returned a Categorical.
adata.obs['condition'] = adata.obs['pert_iname'].apply(remove_non_alphanumeric).astype(object)
# The DMSO vehicle is the control condition.
adata.obs.loc[adata.obs.condition == "DMSO", "condition"] = "control"
print(adata.obs.condition.unique())
adata.obs['cell_type'] = adata.obs['cell_id']
# Dose divided by the largest dose in the dataset.
dose = adata.obs['pert_dose'].astype(float)
adata.obs['dose_val'] = dose / np.max(dose)
# Grouping keys: "<cell line>_<drug>_<normalised dose>" and "<cell line>_<drug>".
adata.obs['cov_drug_dose'] = adata.obs.cell_type.astype(str) + '_' + adata.obs.condition.astype(str) + '_' + adata.obs.dose_val.astype(str)
adata.obs['cov_drug'] = adata.obs.cell_type.astype(str) + '_' + adata.obs.condition.astype(str)
adata.obs['eval_category'] = adata.obs['cov_drug']
# 1 for control samples, 0 otherwise. Compare with "control", not "DMSO": DMSO was
# renamed to "control" above, so comparing with "DMSO" flagged no sample as control.
adata.obs['control'] = (adata.obs['condition'] == 'control').astype(int)
adata

## Plot the number of samples available for each cell line–drug pair (one bar per pair, sorted by sample count)

In [None]:
# Number of distinct cell line–drug pairs, cell lines and drugs (including control).
for key in ("cov_drug", "cell_type", "condition"):
    print(adata.obs[key].nunique())

In [None]:
# Samples per cell line–drug pair, keeping pairs with at least 50 samples
# (largest first), drawn as one bar per pair.
pair_counts = adata.obs["cov_drug"].astype(str).value_counts()
cov_drug = pair_counts.loc[pair_counts >= 50].sort_values(ascending=False)
print(cov_drug.index.nunique())

n = len(cov_drug)
x = np.arange(n)
font_size = 50

fig, ax = plt.subplots(figsize=(n / 40 + 10, 20))  # width grows ~1 inch per 40 bars
ax.bar(x, cov_drug.values, width=4.0, color="#B5345C")
ax.set_ylabel("Number of samples per cell line-drug pairs", fontsize=font_size, labelpad=20)
ax.set_xlabel(f"(cell line-drug pairs) ={n}", fontsize=font_size, labelpad=20)
ax.set_xticks([])  # too many pairs to label individually
ax.tick_params(axis="y", labelsize=font_size)
plt.tight_layout()
plt.savefig("cov_drug_50.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Re-plot the per-pair sample counts with a cut line at `thr` samples, and report the
# size of the dataset that survives the cut.
y = cov_drug.values   # samples per cell line–drug pair, sorted descending
n = len(y)
x = np.arange(n)      # bar positions, recomputed so this cell does not rely on
                      # values left over from the previous cell
font_size = 50

thr = 300  # minimum number of samples a cell line–drug pair must have

# Keep samples whose cell line–drug pair has >= thr samples. The looked-up counts are
# indexed by pair label, so convert them to an explicit positional boolean mask.
pair_sizes = adata.obs.cov_drug.value_counts()
adata_thr = adata[pair_sizes[adata.obs.cov_drug].to_numpy() >= thr]

# Summary of the filtered dataset.
n_cell_types = adata_thr.obs["cell_type"].nunique()
n_cell_line_drug = adata_thr.obs["cov_drug"].nunique()
n_samples = adata_thr.n_obs  # total number of samples
n_conditions = adata_thr.obs["condition"].nunique() - (1 if "control" in adata_thr.obs["condition"].unique() else 0)
dataset_size = adata_thr.shape  # (samples, genes)

print(f"Number of cell lines: {n_cell_types}")
print(f"Number of drugs (excluding control): {n_conditions}")
print(f"Number of cell line-drug pairs: {n_cell_line_drug}")
print(f"Dataset size (cells × genes): {dataset_size}")

fig, ax = plt.subplots(figsize=(n / 40 + 10, 20))
ax.bar(x, y, width=4.0, color="#B5345C")

# Dashed vertical line between the last bar with y >= thr and the next bar.
idx_above = np.flatnonzero(y >= thr)
if idx_above.size:
    k = idx_above[-1]
    dx = np.diff(x).mean() if n > 1 else 1.0
    ax.axvline(x[k] + dx / 2, linestyle="--", linewidth=5)

ax.set_ylabel("Number of samples per cell line–drug pair", fontsize=font_size, labelpad=20)
ax.set_xlabel("(cell line–drug pairs)", fontsize=font_size, labelpad=20)
ax.set_xticks([])  # keep x labels hidden
ax.tick_params(axis='y', labelsize=font_size)
plt.tight_layout()
plt.savefig("cov_drug_300.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Keep only samples whose cell line–drug pair has at least `thr` samples.
pair_sizes = adata.obs.cov_drug.value_counts()
keep = pair_sizes[adata.obs.cov_drug].to_numpy() >= thr
adata = adata[keep].copy()
adata

In [None]:
# Samples per (drug, cell line) combination, with row/column totals under "Total".
# Both keys are cast to str first, so missing values show up as the label "nan".
drug_labels = adata.obs["condition"].astype(str)
cell_labels = adata.obs["cell_type"].astype(str)
ct = pd.crosstab(drug_labels, cell_labels, dropna=False, margins=True, margins_name="Total")
ct

## Download the perturbagen metadata (`GSE92742_Broad_LINCS_pert_info.txt`) from GEO (https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE92742) into `data_path`, and assign a SMILES string to each drug

In [None]:
# Build a pert_id -> SMILES lookup for the perturbagens present in the filtered data,
# using the GEO GSE92742 perturbagen table.
pert_id_unique = adata.obs["pert_id"].dropna().unique().tolist()
bad = {'-666', 'restricted', 'nan'}  # placeholder entries that are not real SMILES

reference_df = pd.read_csv(data_path + "GSE92742_Broad_LINCS_pert_info.txt", delimiter="\t")
# Manual check of the table row for pert_id "ERG-AZD2281".
display(reference_df.loc[reference_df.pert_id == "ERG-AZD2281"])

in_data = reference_df["pert_id"].isin(pert_id_unique)
reference_df = reference_df.loc[in_data, ["pert_id", "canonical_smiles"]]

# Drop rows without a usable SMILES (missing or a placeholder value).
mask = reference_df["canonical_smiles"].notna() & ~reference_df["canonical_smiles"].isin(bad)
reference_df = reference_df.loc[mask].copy()

reference_dict = dict(zip(reference_df["pert_id"], reference_df["canonical_smiles"]))
reference_df

In [None]:
# Attach a SMILES string to every sample via its pert_id.
# `.astype(object)` guarantees a plain object column: Series.map on a categorical
# pert_id can return a Categorical, which would reject the manual assignment below.
adata.obs["SMILES"] = adata.obs["pert_id"].map(reference_dict).astype(object)

# SMILES set by hand for pert_ids the lookup table does not cover.
# NOTE(review): presumably olaparib (AZD2281) — confirm.
manual_smiles = {
    "CMAP-ERG-AZD2281": "O=C(c1cc(Cc2n[nH]c(=O)c3ccccc23)ccc1F)N1CCN(C(=O)C2CC2)CC1",
}
for pid, smi in manual_smiles.items():
    adata.obs.loc[adata.obs.pert_id == pid, "SMILES"] = smi

# Unique pert_ids that still have no SMILES.
miss = adata.obs["SMILES"].isna()
uniq_unmapped = adata.obs.loc[miss, "pert_id"].dropna().unique()
print("Unique unmapped pert_id:", len(uniq_unmapped))
print(uniq_unmapped[:20])  # preview first 20

# Same count via set difference. The manual overrides are included here so that both
# diagnostics agree.
missing_keys = set(adata.obs["pert_id"].dropna().unique()) - set(reference_dict) - set(manual_smiles)
print("Unique unmapped via set-diff:", len(missing_keys))

In [None]:
# Display the AnnData summary after adding the SMILES column.
adata

In [None]:
# Samples per (drug, cell line) combination after SMILES annotation, with totals under
# "Total". Both keys are cast to str first, so missing values appear as the label "nan".
drug_labels = adata.obs["condition"].astype(str)
cell_labels = adata.obs["cell_type"].astype(str)
ct = pd.crosstab(drug_labels, cell_labels, dropna=False, margins=True, margins_name="Total")
ct

## Construct cell-type-specific gene co-expression graphs from the control samples

In [None]:
# Control (DMSO) samples only, used to build the co-expression graphs.
is_control = adata.obs["condition"] == "control"
ctrl_adata = adata[is_control].copy()
ctrl_adata

In [None]:
# Peek at the control expression matrix.
ctrl_adata.X

In [None]:
import os
import networkx as nx
from torch_geometric.utils.convert import from_networkx
from sklearn.metrics import mean_squared_error
from torch_geometric.data import InMemoryDataset, Data, download_url, extract_zip, HeteroData, Batch
from torch_geometric.utils import *
import torch
from torch import nn
# Build one gene co-expression graph per cell line from the control samples, and save
# each graph as <parent_dir>/LINCS/<cell_type>_coexpr_graph.pkl.
cell_type_network = {}
directory = "LINCS"                                    # leaf directory
parent_dir = "/../Cell-Type-Specific-Graphs/graphs/"   # parent directories
path = os.path.join(parent_dir, directory)
print(path)
# exist_ok=True replaces a bare `except:` that printed "File exists!" for *any* error
# (e.g. permission denied) and then carried on.
os.makedirs(path, exist_ok=True)

max_nodes = 0    # largest graph (node count) across cell lines
genes_pos = []   # each graph's `pos`, before conversion to a tensor
for cell_type in ctrl_adata.obs.cell_type.unique():
    print(cell_type)
    cell = Correlation_matrix(ctrl_adata, cell_type, 'cell_type',
                              hv_genes_cells=None, union_HVGs=True)
    # Edge threshold: 99th percentile of the absolute values in cell[0]
    # (presumably the gene-gene correlation matrix — see utils.Correlation_matrix).
    threshold = np.percentile(np.abs(cell[0].values), 99.0)
    print("threshold:", threshold)
    g = create_coexpression_graph(ctrl_adata, cell, cell_type, threshold, 'gene_name')
    max_nodes = max(max_nodes, g.num_nodes)
    cell_type_network[cell_type] = g
    genes_pos.append(g.pos)
    g.pos = torch.tensor(g.pos)
    torch.save(g, os.path.join(path, cell_type + '_coexpr_graph.pkl'))

cell_type_network

## Pair each treated sample with a randomly selected control sample from the same cell line

In [None]:
import numpy as np
import torch
import tqdm as tq
from scipy.spatial.distance import cdist

from scipy import sparse

# Split the filtered data into perturbed ("stimulated") and control samples.
stim_data = adata[adata.obs.condition != 'control'].copy()
ctrl_data = adata[adata.obs.condition == 'control'].copy()


def _dense(mat):
    """Return *mat* as a dense ndarray (scipy sparse matrices are densified)."""
    return mat.toarray() if hasattr(mat, "toarray") else np.asarray(mat)


# Node features for each cell line's graph: per-gene mean and variance of the control
# expression, restricted to the graph's genes (`pos`).
for cell_type in adata.obs.cell_type.unique():
    genes = cell_type_network[cell_type].pos.tolist()
    ctrl_subset = ctrl_data[ctrl_data.obs.cell_type == cell_type, genes].copy().X
    ctrl_tensor = torch.tensor(_dense(ctrl_subset))

    mean = torch.mean(ctrl_tensor, dim=0)
    # NOTE(review): this is the variance (torch.var), not the standard deviation. It is
    # kept as-is because downstream models consume this feature — confirm intent.
    var = torch.var(ctrl_tensor, dim=0)

    cell_type_network[cell_type].x = torch.cat([mean.unsqueeze(1), var.unsqueeze(1)], dim=1)

# Pair every perturbed sample with one randomly drawn control sample of the same cell
# line. The pairs are stored row-aligned in stim_data.layers['ctrl_x'].
# - The layer is filled in a standalone array and attached once at the end.
#   Assigning through `stim_data[rows, :].layers[...]` writes into an AnnData *view*,
#   which is copied on modification, so stim_data itself was never updated.
# - The array is a copy of X, so filling it cannot modify stim_data.X.
# - One seeded generator is shared by all pairs. Re-seeding per pair gave different
#   drugs of the same cell line an identical control draw.
rng = np.random.default_rng(seed=42)
ctrl_layer = np.array(_dense(stim_data.X))  # np.array copies
ctrl_by_type = {}  # cell line -> dense control expression matrix (computed once)

for pair in tq.tqdm(stim_data.obs.cov_drug.unique()):
    stim_rows = (stim_data.obs.cov_drug == pair).to_numpy()
    # Read the cell line from obs rather than parsing the "<cell>_<drug>" key.
    cell_type = stim_data.obs.cell_type[stim_rows].iloc[0]
    if cell_type not in ctrl_by_type:
        ctrl_by_type[cell_type] = _dense(ctrl_data[ctrl_data.obs.cell_type == cell_type].X)
    ctrl_X = ctrl_by_type[cell_type]

    n_stim, n_ctrl = int(stim_rows.sum()), ctrl_X.shape[0]
    if n_ctrl == 0:
        raise ValueError(f"No control samples for cell type {cell_type!r} (needed by {pair!r})")
    # Sample without replacement whenever there are enough controls.
    picks = rng.choice(n_ctrl, size=n_stim, replace=n_ctrl < n_stim)
    ctrl_layer[stim_rows] = ctrl_X[picks]

# Keep the layer in the same (sparse/dense) form as X.
stim_data.layers['ctrl_x'] = sparse.csr_matrix(ctrl_layer) if sparse.issparse(stim_data.X) else ctrl_layer

## Plot the distribution of E-distances between each perturbation and the control, and identify the conditions with the largest and smallest distances

In [None]:
from scperturb import *
# E-distance between each condition's samples and the control samples, computed in PCA
# space with scperturb.edist_to_control. That function reads adata.obsm['X_pca'], so
# compute a PCA if the input file does not already provide one (otherwise this raised
# a KeyError).
if "X_pca" not in adata.obsm:
    sc.pp.pca(adata)
estats = edist_to_control(adata, obs_key='condition', obsm_key='X_pca', dist='sqeuclidean')
estats

In [None]:
# E-distance per condition. NaNs and the control itself (distance 0) are dropped.
dist = estats["distance"].astype(float).dropna()
dist = dist[dist > 0]

d = dist.sort_values(ascending=False)  # largest distance first

plt.figure(figsize=(10, 4))
plt.bar(d.index, d.values)  # one bar per condition, sorted by distance
plt.xlabel(f"Conditions (n = {adata.obs.condition.nunique() - 1})")
plt.xticks(rotation=45, ha='right')
plt.title("Distances from control")
# No plt.legend(): nothing on this plot has a label, so it only produced a warning.
plt.tight_layout()
plt.show()

In [None]:
# Non-control E-distances, largest effect first. Report the extremes and one randomly
# chosen (seeded) condition from each rank band.
dist = estats["distance"].astype(float).dropna()
dist = dist[dist > 0]  # the control's own distance is 0

d = dist.sort_values(ascending=False)
n = len(d)

top_name, top_val = d.index[0], float(d.iloc[0])
low_name, low_val = d.index[-1], float(d.iloc[-1])

# Rank cut points for the bands top 10%, 10–25%, 25–50%, 50–75%, 75–90%, 90–100%.
qs = [0.10, 0.25, 0.50, 0.75, 0.90]
idxs = [int(np.ceil(q * n)) for q in qs]
bounds = [0, *idxs, n]
labels = ["top 10%", "10–25%", "25–50%", "50–75%", "75–90%", "90–100%"]

seed = 123  # fixed seed so the picked conditions are reproducible
rng = np.random.RandomState(seed)

samples = []
for lab, start, stop in zip(labels, bounds, bounds[1:]):
    if stop <= start:
        # Empty band (can happen when n is small).
        samples.append((lab, None, np.nan))
    else:
        pos = rng.choice(np.arange(start, stop))
        samples.append((lab, d.index[pos], float(d.iloc[pos])))

print(f"TOP:     {top_name}  ({top_val:.3f})")
print(f"LOWEST:  {low_name}  ({low_val:.3f})")
for lab, name, val in samples:
    print(f"{lab:8} {name}  ({val:.3f})")

In [None]:
# Hold out the listed drugs for testing; every other sample is used for training.
testing_drugs = ["olaparib", "geldanamycin"]
is_test = adata.obs.condition.isin(testing_drugs)
adata.obs["split"] = "Train"
adata.obs.loc[is_test, "split"] = "Test"
adata.obs.split.value_counts()

In [None]:
# Save the processed AnnData object (with SMILES and train/test split) to data_path.
adata.write_h5ad(f"{data_path}LINCS_L1000_processed.h5ad")

## Map each drug's SMILES string to its chemical descriptor features (standardized using training-set statistics)

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import Chem, DataStructs
from tqdm import tqdm
import pandas as pd

# Precomputed descriptor table (one row per SMILES). Select the rows for drugs in each
# split.
file_path = data_path + "SMILES_feat_all_datasets.csv"
df = pd.read_csv(file_path)

split_smiles = {
    name: adata.obs.loc[adata.obs.split == name].SMILES.unique().tolist()
    for name in ("Train", "Test")
}
train_smiles = split_smiles["Train"]
test_smiles = split_smiles["Test"]

df_train = df.loc[df.SMILES.isin(train_smiles)]
df_test = df.loc[df.SMILES.isin(test_smiles)]
display(df_train)
display(df_test)

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardise the numeric descriptors. The scaler is fitted on the training drugs only
# and the same transform is applied to the test drugs, to avoid test-set leakage.
feat_cols = df_train.columns.drop("SMILES")  # every column except "SMILES"
scaler = StandardScaler().fit(df_train[feat_cols])


def _scaled(frame):
    """Apply the train-fitted scaler to frame[feat_cols], keeping the row index."""
    return pd.DataFrame(scaler.transform(frame[feat_cols]), index=frame.index, columns=feat_cols)


Xtr = _scaled(df_train)
Xte = _scaled(df_test)

# Put SMILES back as the first column.
df_train_scaled = pd.concat([df_train[["SMILES"]], Xtr], axis=1)
df_test_scaled = pd.concat([df_test[["SMILES"]], Xte], axis=1)

# Sanity check: train features should have mean ~0 and std ~1.
print(Xtr.mean().round(3).head())
print(Xtr.std(ddof=0).round(3).head())

In [None]:
# Stack the scaled train and test descriptors into one table, tagging each row with a
# trailing "split" column.
train_scaled = df_train_scaled.assign(split="Train")
test_scaled = df_test_scaled.assign(split="Test")
df_all_scaled = pd.concat([train_scaled, test_scaled], axis=0, ignore_index=True)
df_all_scaled

In [None]:

# Map each drug (condition) to its standardised descriptor vector, looked up through the
# drug's SMILES string.
drug_smiles = adata.obs[['condition', 'SMILES']]
# If a condition appears with several SMILES, the last occurrence wins (dict semantics).
drug_smiles = dict(zip(drug_smiles['condition'], drug_smiles['SMILES']))

# Descriptor rows keyed by SMILES. Select the feature columns by name instead of by
# position (previously `[1:]` then `[:-1]` to skip SMILES and split). If a SMILES is
# listed more than once, the first row is used.
descriptors = (
    df_all_scaled.drop_duplicates("SMILES")
    .drop(columns=["split"])
    .set_index("SMILES")
)

# Fail with a clear message, instead of an IndexError, when a drug has no descriptors.
missing = sorted(
    str(drug) for drug, smile in drug_smiles.items()
    if pd.isna(smile) or smile not in descriptors.index
)
if missing:
    raise KeyError(f"No descriptor features found for drugs: {missing}")

canonical_smiles = {drug: descriptors.loc[smile].tolist() for drug, smile in drug_smiles.items()}

## Construct PyTorch Geometric (PyG) data objects for the cells to enable GNN training.

In [None]:
from importlib import reload
import utils
reload(utils)
from utils import *
import pickle

# Build per-sample graph data objects (presumably PyG `Data`, via utils.create_cells)
# from the perturbed samples, the per-cell-line graphs and the drug descriptor vectors,
# then pickle them into data_path.
cells_train = create_cells(stim_data, cell_type_network, canonical_smiles)
out_file = data_path + 'cells_LINCS.pkl'
with open(out_file, 'wb') as fh:
    pickle.dump(cells_train, fh)
cells_train[0]