# Imports

In [None]:
%load_ext autoreload
%autoreload 2
import h5py
import anndata
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy as sp
import scipy.sparse.linalg
rng=np.random.default_rng()
import tqdm.notebook
import pickle
%matplotlib inline
import sys
import ipywidgets
import sklearn.neighbors
from scipy.sparse import csr_matrix
import requests

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder
from spatial.train import train
from spatial.predict import test

import time
import json
import numpy as np

import torch
import pytorch_lightning as pl

from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedKFold
from sklearn.preprocessing import StandardScaler

from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor

import xgboost as xgb

import hydra
from hydra.experimental import compose, initialize

%matplotlib inline

import copy
import os
from collections import defaultdict
import itertools

import seaborn as sns
from sklearn.cluster import AgglomerativeClustering

from MESSI_for_reproduction.tutorials.context import messi
from messi.data_processing import *
from messi.hme import hme
from messi.gridSearch import gridSearch

# Remote source of the raw MERFISH CSV and the local paths of the files derived from it.
original_url= "https://datadryad.org/stash/downloads/file_stream/67671"
csv_location='../data/spatial/moffit_merfish/original_file.csv'
h5ad_location='../data/spatial/moffit_merfish/original_file.h5ad'
# The '%d' placeholder is filled with the neighbor count or the radius used to build the graph.
connectivity_matrix_template='../data/spatial/moffit_merfish/connectivity_%dneighbors.h5ad'
# NOTE(review): this path is absolute ('/data/...') while the ones above are relative ('../data/...') -- confirm this is intended.
genetypes_location='/data/spatial/moffit_merfish/genetypes.pkl'

# LightGBM Implementation

###### download csv

In [None]:
import requests

# Download first and fail loudly on an HTTP error, so that an error page (or an
# empty file) is never written to csv_location.
response = requests.get(original_url)
response.raise_for_status()

with open(csv_location, "wb") as csvf:
    csvf.write(response.content)

##### munge into hdf5 file

In [None]:
dataframe = pd.read_csv(csv_location)

# The first nine columns become per-cell annotations (obs); the rest become X.
meta_columns = dataframe.keys()[:9]
gene_columns = dataframe.keys()[9:]

# Object-typed annotation columns are stored as fixed-width unicode strings,
# every other column keeps whatever dtype pandas inferred.
dct = {
    colnm: (
        np.require(dataframe[colnm], dtype="U36")
        if dtype.kind == "O"
        else np.require(dataframe[colnm])
    )
    for colnm, dtype in zip(meta_columns, dataframe.dtypes[:9])
}
cellid = dct.pop('Cell_ID')  # used as the obs index rather than as a column

gene_names = np.array(gene_columns, dtype="U80")
expression = np.array(dataframe[gene_columns]).astype(np.float16)

obs_frame = pd.DataFrame(dct, index=cellid)
var_frame = pd.DataFrame(index=gene_names)
ad = anndata.AnnData(X=expression, var=var_frame, obs=obs_frame)

ad.write_h5ad(h5ad_location)

##### supplement hdf5 file with a column indicating "tissue id" for each cell

In [None]:
ad = anndata.read_h5ad(h5ad_location)
animal_ids = np.unique(ad.obs['Animal_ID'])
bregmas = np.unique(ad.obs['Bregma'])
tissue_id = np.zeros(len(ad), dtype=int)
n_tissues = 0

# Walk every (animal, bregma) combination in sorted order and hand the next
# free id to each combination that actually contains cells.
for aid in animal_ids:
    in_animal = ad.obs['Animal_ID'] == aid
    for bregma in bregmas:
        good = in_animal & (ad.obs['Bregma'] == bregma)
        if not good.any():
            continue
        tissue_id[good] = n_tissues
        n_tissues += 1

ad.obs['Tissue_ID'] = tissue_id
ad.write_h5ad(h5ad_location)

##### create global graph 

In [None]:
ad=anndata.read_h5ad(h5ad_location)
radius=70
nneigh=30  # only used when mode == "neighbors"
mode="rad"

if mode not in ("neighbors", "rad"):
    raise ValueError(f"mode must be 'neighbors' or 'rad', got {mode!r}")

# Per-tissue edge lists, expressed in global cell indices; concatenated once at the end.
rows=[]
cols=[]

for tid in tqdm.notebook.tqdm(np.unique(ad.obs['Tissue_ID'])):
    good=ad.obs['Tissue_ID']==tid
    # global row numbers of the cells of this tissue; needed to translate the
    # tissue-local graph indices below, so it must be computed before they are used
    idxs=np.where(good)[0]
    pos=np.array(ad.obs[good][['Centroid_X','Centroid_Y']])

    if mode == "neighbors":
        if nneigh == 0:
            # no neighbors requested: every cell is connected to itself only
            E=sp.sparse.identity(pos.shape[0],format='csr')
        else:
            E=sklearn.neighbors.kneighbors_graph(pos,nneigh,mode='connectivity')
    else:
        # all cells within `radius` of each cell; include_self=True keeps the
        # cell itself in its own neighborhood, as a ball query around the
        # cell's own position would
        E=sklearn.neighbors.radius_neighbors_graph(pos,radius,mode='connectivity',include_self=True)

    E=E.tocoo()
    rows.append(idxs[E.row])
    cols.append(idxs[E.col])

row=np.concatenate(rows) if rows else np.zeros(0,dtype=int)
col=np.concatenate(cols) if cols else np.zeros(0,dtype=int)

E=sp.sparse.coo_matrix((np.ones(len(col)),(row,col)),shape=(len(ad),len(ad))).tocsr()
if mode == "neighbors":
    anndata.AnnData(E).write_h5ad(connectivity_matrix_template%nneigh)
else:
    anndata.AnnData(E).write_h5ad(connectivity_matrix_template%radius)

##### write down ligand/receptor sets

In [None]:
# Gene names treated as ligands.
ligands=np.array(['Cbln1', 'Cxcl14', 'Cbln2', 'Vgf', 'Scg2', 'Cartpt', 'Tac2',
       'Bdnf', 'Bmp7', 'Cyr61', 'Fn1', 'Fst', 'Gad1', 'Ntng1', 'Pnoc',
       'Selplg', 'Sema3c', 'Sema4d', 'Serpine1', 'Adcyap1', 'Cck', 'Crh',
       'Gal', 'Gnrh1', 'Nts', 'Oxt', 'Penk', 'Sst', 'Tac1', 'Trh', 'Ucn3'])

# Gene names treated as receptors.
receptors=np.array(['Crhbp', 'Gabra1', 'Gpr165', 'Glra3', 'Gabrg1', 'Adora2a',
       'Avpr1a', 'Avpr2', 'Brs3', 'Calcr', 'Cckar', 'Cckbr', 'Crhr1',
       'Crhr2', 'Galr1', 'Galr2', 'Grpr', 'Htr2c', 'Igf1r', 'Igf2r',
       'Kiss1r', 'Lepr', 'Lpar1', 'Mc4r', 'Npy1r', 'Npy2r', 'Ntsr1',
       'Oprd1', 'Oprk1', 'Oprl1', 'Oxtr', 'Pdgfra', 'Prlr', 'Ramp3',
       'Rxfp1', 'Slc17a7', 'Slc18a2', 'Tacr1', 'Tacr3', 'Trhr'])

# Gene names treated as responses, i.e. the genes to be predicted.
response_genes=np.array(['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Fos', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2'])
# The 16 cell-class labels, listed in alphabetical order.
cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]

##### run a simple experiment: use ligands and receptors to predict response genes in excitatory cells, with a gradient-boosted regression model

In [None]:
# Load the expression data, the cell-cell graph and the pickled gene-type table.
nneigh = 30
radius = 70
mode = "rad"
ad = anndata.read_h5ad(h5ad_location)

# The graph file name is keyed by the neighbor count or by the radius,
# depending on how the graph was built.
if mode == "neighbors":
    connectivity_matrix = anndata.read_h5ad(connectivity_matrix_template % nneigh).X
elif mode == "rad":
    connectivity_matrix = anndata.read_h5ad(connectivity_matrix_template % radius).X

# Map every gene name to its column position in ad.X.
gene_lookup = dict(zip(ad.var.index, range(len(ad.var.index))))

with open(genetypes_location, 'rb') as f:
    genetypes = pickle.load(f)

In [None]:
# onehot encode cell classes
def oh_encode(lst):
    """One-hot encode a sequence of labels.

    Returns a pair ``(group_names, group_indexes)``: the sorted unique labels
    and a boolean matrix with one row per input element and one column per
    unique label, True where the element equals that label.
    """
    labels = np.array(lst)
    group_names = np.unique(labels)
    # Compare every element against every unique label in one broadcast.
    group_indexes = labels[:, None] == group_names[None, :]
    return group_names, group_indexes
# One-hot encode every cell's class; column i of cell_class_onehots corresponds to cell_classes[i].
cell_classes,cell_class_onehots=oh_encode(ad.obs['Cell_class'])

In [None]:
# a function to construct a prediction problem for a subset of cells

def construct_problem(mask,target_gene,neighbor_genes,self_genes,filter_excitatory=False):
    '''
    Build the covariates and the prediction target for a subset of cells.

    mask -- set of cells (boolean selector over the rows of ad)
    target_gene -- gene to predict
    neighbor_genes -- names of genes which will be read from neighbors
    self_genes -- names of genes which will be read from target cell
    filter_excitatory -- if True, only rows whose Cell_class is 'Excitatory'
                         are returned (neighbor averages are still computed
                         over every cell selected by mask)

    Returns (covariates, predict, feature_names). The columns of covariates
    are, in order: log1p expression of self_genes, neighbor-averaged log1p
    expression of neighbor_genes, neighbor-averaged cell-class one-hots, and
    the Centroid_X / Centroid_Y / Bregma position. feature_names labels those
    columns in the same order.
    '''

    # load subset of data relevant to mask
    local_processed_expression=np.log1p(ad.X[mask].astype(float)) # get expression on subset of cells
    local_edges=connectivity_matrix[mask][:,mask]   # get edges for subset

    selfset_idxs=[gene_lookup[x] for x in self_genes] # column indexes of the self genes
    selfset_exprs = local_processed_expression[:,selfset_idxs]

    neighborset_idxs=[gene_lookup[x] for x in neighbor_genes] # column indexes of the neighbor genes
    neighset_exprs = local_processed_expression[:,neighborset_idxs]

    # number of neighbors of each cell inside the masked sub-graph
    # NOTE: a cell without any neighbor has n_neighs == 0, so the two averages
    # below are not finite for that row.
    n_neighs=(local_edges@np.ones(local_edges.shape[0]))
    neigh_avgs = (local_edges@neighset_exprs) / n_neighs[:,None] # average expression over neighbors
    neigh_cellclass_avgs = (local_edges@cell_class_onehots[mask]) / n_neighs[:,None] # celltype simplex

    positions=np.array(ad.obs[['Centroid_X','Centroid_Y','Bregma']])[mask] # get positions

    covariates=np.c_[selfset_exprs,neigh_avgs,neigh_cellclass_avgs,positions] # collect all covariates
    predict = local_processed_expression[:,gene_lookup[target_gene]] # collect what we're supposed to predict

    # Label the columns in the order they were stacked above. The class labels
    # come from cell_classes, i.e. the labels of the one-hot columns that were
    # actually averaged, so names and columns cannot drift apart.
    feature_names = (
        list(self_genes)
        + [x + " from Neighbors" for x in neighbor_genes]
        + [f"Cell Class {x}" for x in cell_classes]
        + ['Centroid_X','Centroid_Y','Bregma']
    )

    if filter_excitatory:
        excites=(ad.obs['Cell_class']=='Excitatory')[mask] # get the subset of these cells which are excitatory
        covariates=covariates[excites] # subset to excites
        predict=predict[excites]       # subset to excites

    return covariates,predict,feature_names

In [None]:
# Response genes predicted in the experiment below, as a plain list; unlike the
# earlier response_genes array, this list does not contain 'Fos'.
response_genes=['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2']

import time
import json
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor

# Leave-one-animal-out evaluation over animals 1-4: for every held-out animal,
# fit one gradient-boosted regressor per response gene on the other three
# animals and record the mean absolute error on the held-out animal.

all_MAEs = []

time_dict = {}      # held-out animal -> wall-clock seconds spent on that fold
L1_loss_dict = {}   # held-out animal -> MAE averaged over all response genes

cv_animals = [1,2,3,4]

# The feature gene sets are the same for every fold and every target gene.
# Alternatives tried before: oset=neighset (ligands only), or both empty.
neighset=genetypes['ligands']
oset=np.r_[genetypes['ligands'],genetypes['receptors']]

for animal in cv_animals:
    start = time.time()
    MAE_list = []

    # train on the remaining animals of the first four only
    train_animals = [a for a in cv_animals if a != animal]
    print(train_animals)
    train_mask = ad.obs['Animal_ID'].isin(train_animals)
    test_mask = ad.obs['Animal_ID']==animal

    for target_gene in response_genes:
        # construct_problem also returns the feature names, which are not needed here
        trainX,trainY,_=construct_problem(train_mask,target_gene,neighset,oset,True)
        testX,testY,_=construct_problem(test_mask,target_gene,neighset,oset,True)

        print(trainX.shape,trainY.shape)
        print(testX.shape,testY.shape)

        # whiten covariates with the training statistics
        mu=np.mean(trainX,axis=0)
        sig=np.std(trainX,axis=0)
        sig[sig==0]=1.0  # constant columns would otherwise divide by zero
        trainX=(trainX-mu)/sig
        testX=(testX-mu)/sig

        model=HistGradientBoostingRegressor(loss="absolute_error")
        model.fit(trainX,trainY)
        MAE_list.append(np.mean(np.abs(model.predict(testX)-testY)))

    end = time.time()

    fold_MAE = np.mean(MAE_list)
    all_MAEs.append(fold_MAE)
    time_dict[animal] = end - start
    L1_loss_dict[animal] = fold_MAE

print(np.mean(all_MAEs))

CV w/ Standard Scaler

# MESSI Implementation

In [None]:
def _load_messi_datasets(read_data, input_path, meta_per_dataset, genes_list, genes_list_u):
    """Read every (animal_id, bregma) dataset and convert it to numpy arrays plus column lookups."""
    data_sets = []
    for animal_id, bregma in meta_per_dataset:
        hp, hp_cor, hp_genes = read_data(input_path, bregma, animal_id, genes_list, genes_list_u)

        if hp is not None:
            hp_columns = dict(zip(hp.columns, range(0, len(hp.columns))))
            hp_np = hp.to_numpy()
        else:
            hp_columns = None
            hp_np = None

        hp_cor_columns = dict(zip(hp_cor.columns, range(0, len(hp_cor.columns))))
        hp_genes_columns = dict(zip(hp_genes.columns, range(0, len(hp_genes.columns))))
        data_sets.append([hp_np, hp_columns, hp_cor.to_numpy(), hp_cor_columns,
                          hp_genes.to_numpy(), hp_genes_columns])
    return data_sets


def MESSI(sex, behavior, celltype, train_animals):
    """Train a MESSI (hme) model on the MERFISH data and evaluate it on a held-out animal.

    sex, behavior -- passed to the MESSI metadata reader to select the samples
    celltype -- cell type whose cells are modelled
    train_animals -- animal ids used for training; the test animal is always
                     the largest animal id found in the metadata

    The trained model (and the grid-search object, when grid search is
    enabled) is pickled under 'output/'. Returns the mean absolute error
    between the prepared test responses and the model predictions.
    """
    input_path = 'input/'
    output_path = 'output/'
    data_type = 'merfish'
    behavior_no_space = behavior.replace(" ", "_")
    current_cell_type = celltype
    current_cell_type_no_space = current_cell_type.replace(" ", "_")

    grid_search = False
    n_sets = 5  # for example usage only; we recommend 5

    n_classes_0 = 1
    n_epochs = 20  # for example usage only; we recommend using the default 20 n_epochs

    read_in_functions = {'merfish': [read_meta_merfish, read_merfish_data, get_idx_per_dataset_merfish],
                'merfish_cell_line': [read_meta_merfish_cell_line, read_merfish_cell_line_data, get_idx_per_dataset_merfish_cell_line],
                'starmap': [read_meta_starmap_combinatorial, read_starmap_combinatorial, get_idx_per_dataset_starmap_combinatorial]}

    # set data reading functions corresponding to the data type
    if data_type in ['merfish', 'merfish_cell_line', 'starmap']:
        read_meta = read_in_functions[data_type][0]
        read_data = read_in_functions[data_type][1]
        get_idx_per_dataset = read_in_functions[data_type][2]
    else:
        raise NotImplementedError("Now only support processing 'merfish', 'merfish_cell_line' or 'starmap'")

    # read in ligand and receptor lists
    l_u, r_u = get_lr_pairs(input_path=input_path)  # may need to change to the default value

    # read in meta information about the dataset
    meta_all, meta_all_columns, cell_types_dict, genes_list, genes_list_u, \
    response_list_prior, regulator_list_prior = \
        read_meta(input_path, behavior, sex, l_u, r_u)  # TO BE MODIFIED: number of responses

    # get all available animals/samples
    all_animals = list(set(meta_all[:, meta_all_columns['Animal_ID']]))
    print(all_animals)

    test_animals  = [np.max(all_animals)]
    samples_test = np.array(test_animals)
    samples_train = train_animals
    print(f"Test set is {samples_test}")
    print(f"Training set is {samples_train}")

    # number of experts per cell type and test animal
    n_experts_types = {'Inhibitory': {1: 10, 2: 10, 3: 10, 4: 10},
                   'Excitatory': {1: 8, 2: 8, 3: 10, 4: 10},
                   'Astrocyte': {1: 4, 2: 4, 3: 3, 4: 3},
                   'OD Mature 2' : {1: 3, 2: 3, 3: 4, 4: 3},
                   'Endothelial 1': {1: 1, 2: 1, 3: 2, 4: 2},
                   'OD Immature 1': {1: 1, 2: 1, 3: 2, 4: 2},
                   'OD Mature 1': {1: 1, 2: 1, 3: 1, 4: 1},
                   'Microglia': {1: 1, 2: 1, 3: 1, 4: 1}}

    # NOTE(review): the "Excitatory" entry is used whatever `celltype` is --
    # confirm whether n_experts_types[current_cell_type] was intended.
    n_classes_1 = n_experts_types["Excitatory"][test_animals[0]]

    preprocess = 'neighbor_cat'
    top_k_response = None  # for example usage only; we recommend use all responses (i.e. None)
    top_k_regulator = None
    response_type = 'original'  # use raw values to fit the model

    if grid_search:
        condition = f"response_{top_k_response}_l1_{n_classes_0}_l2_grid_search"
    else:
        condition = f"response_{top_k_response}_l1_{n_classes_0}_l2_{n_classes_1}"

    bregma = None
    idx_train, idx_test, idx_train_in_general, \
    idx_test_in_general, idx_train_in_dataset, \
    idx_test_in_dataset, meta_per_dataset_train, \
    meta_per_dataset_test = find_idx_for_train_test(samples_train, samples_test,
                                                    meta_all, meta_all_columns, data_type,
                                                    current_cell_type, get_idx_per_dataset,
                                                    return_in_general = False,
                                                    bregma=bregma)

    datasets_train = _load_messi_datasets(read_data, input_path, meta_per_dataset_train,
                                          genes_list, genes_list_u)
    datasets_test = _load_messi_datasets(read_data, input_path, meta_per_dataset_test,
                                         genes_list, genes_list_u)

    if data_type == 'merfish_rna_seq':
        neighbors_train = None
        neighbors_test = None
    else:
        if data_type == 'merfish':
            dis_filter = 100
        else:
            dis_filter = 1e9

        neighbors_train = get_neighbors_datasets(datasets_train, "Del", k=10, dis_filter=dis_filter, include_self = False)
        neighbors_test = get_neighbors_datasets(datasets_test, "Del", k=10, dis_filter=dis_filter, include_self = False)

    lig_n =  {'name':'regulators_neighbor','helper':preprocess_X_neighbor_per_cell,
                      'feature_list_type': 'regulator_neighbor', 'per_cell':True, 'baseline':False,
                      'standardize': True, 'log':True, 'poly':False}
    rec_s = {'name':'regulators_self','helper':preprocess_X_self_per_cell,
                          'feature_list_type': 'regulator_self', 'per_cell':True, 'baseline':False,
                          'standardize': True, 'log':True, 'poly':False}
    lig_s = {'name':'regulators_neighbor_self','helper':preprocess_X_self_per_cell,
                          'feature_list_type':'regulator_neighbor', 'per_cell':True, 'baseline':False,
                          'standardize': True, 'log':True, 'poly':False}
    type_n =  {'name': 'neighbor_type','helper':preprocess_X_neighbor_type_per_dataset,
                          'feature_list_type':None,'per_cell':False, 'baseline':False,
                          'standardize': True, 'log':False, 'poly':False}
    base_s = {'name':'baseline','helper':preprocess_X_baseline_per_dataset,'feature_list_type':None,
                          'per_cell':False, 'baseline':True, 'standardize': True, 'log':False, 'poly':False}

    if data_type == 'merfish_cell_line':
        feature_types = [lig_n, rec_s, base_s, lig_s]

    else:
        feature_types = [lig_n, rec_s, type_n , base_s, lig_s]

    X_trains, X_tests, regulator_list_neighbor, regulator_list_self  = prepare_features(data_type, datasets_train, datasets_test, meta_per_dataset_train, meta_per_dataset_test,
                     idx_train, idx_test, idx_train_in_dataset, idx_test_in_dataset,neighbors_train, neighbors_test,
                    feature_types, regulator_list_prior, top_k_regulator,
                     genes_list_u, l_u, r_u,cell_types_dict)

    total_regulators = regulator_list_neighbor + regulator_list_self

    log_response = True  # take log transformation of the response genes

    Y_train, Y_train_true, Y_test, Y_test_true, response_list = prepare_responses(data_type, datasets_train,
                                                                                  datasets_test, idx_train_in_general,
                                                                                  idx_test_in_general,
                                                                                  idx_train_in_dataset,
                                                                                  idx_test_in_dataset, neighbors_train,
                                                                                  neighbors_test,
                                                                                  response_type, log_response,
                                                                                  response_list_prior, top_k_response,
                                                                                  genes_list_u, l_u, r_u)

    if grid_search:
        # keep untransformed copies: transform_features below is called
        # without using a return value, so it presumably works in place
        X_trains_gs = copy.deepcopy(X_trains)
        Y_train_gs = copy.copy(Y_train)

    ### Transform and combine different type of features

    # transform features
    transform_features(X_trains, X_tests, feature_types)
    print(f"Minimum value after transformation can below 0: {np.min(X_trains['regulators_self'])}")

    if data_type == 'merfish':
        num_coordinates = 3
    elif data_type == 'starmap' or data_type == 'merfish_cell_line':
        num_coordinates = 2
    else:
        num_coordinates = None

    # NOTE: X_test is only built when the test baseline features are at least
    # 2-dimensional; the prediction step at the end requires it.
    if np.ndim(X_trains['baseline']) > 1 and np.ndim(X_tests['baseline']) > 1:
        X_train, X_train_clf_1, X_train_clf_2 = combine_features(X_trains, preprocess, num_coordinates)
        X_test, X_test_clf_1, X_test_clf_2 = combine_features(X_tests, preprocess, num_coordinates)
    elif np.ndim(X_trains['baseline']) > 1:
        X_train, X_train_clf_1, X_train_clf_2 = combine_features(X_trains, preprocess, num_coordinates)

    print(f"Dimension of X train is: {X_train.shape}")
    print(f"Dimension of Y train is: {Y_train.shape}")

    ## Construct and train MESSI model

    ### set default parameters

    # ------ set parameters ------
    model_name_gates = 'logistic'
    model_name_experts = 'mrots'
    num_response = Y_train.shape[1]

    # default values
    soft_weights = True
    partial_fit_expert = True

    # specify default parameters for MESSI
    model_params = {'n_classes_0': n_classes_0,
                    'n_classes_1': n_classes_1,
                    'model_name_gates': model_name_gates,
                    'model_name_experts': model_name_experts,
                    'num_responses': num_response,
                    'soft_weights': soft_weights,
                    'partial_fit_expert': partial_fit_expert,
                    'n_epochs': n_epochs,
                    'tolerance': 3}

    ### set up directory to save results

    # set up directory for saving the model
    sub_condition = f"{condition}_{model_name_gates}_{model_name_experts}"
    sub_dir = f"{data_type}/{behavior_no_space}/{sex}/{current_cell_type_no_space}/{preprocess}/{sub_condition}"
    current_dir = os.path.join(output_path, sub_dir)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(current_dir, exist_ok=True)

    print(f"Model and validation results (if appliable) saved to: {current_dir}")

    suffix = f"_{test_animals}"

    ### conduct grid search for hyper-parameters if needed

    # search range for number of experts; for example usage only, we recommend 4
    search_range_dict = {'Excitatory': range(4,5), 'U-2_OS': range(1,3), \
                            'STARmap_excitatory': range(1,3)}


    if grid_search:
        # prepare input meta data
        if data_type == 'merfish':
            meta_per_part = [tuple(i) for i in meta_per_dataset_train]
            meta_idx = meta2idx(idx_train_in_dataset, meta_per_part)
        else:
            meta_per_part, meta_idx = combineParts(samples_train, datasets_train, idx_train_in_dataset)

        # prepare parameters list to be tuned
        if data_type == 'merfish_cell_line':
            current_cell_type_data = 'U-2_OS'
        elif data_type == 'starmap':
            current_cell_type_data = 'STARmap_excitatory'
        else:
            current_cell_type_data = "Excitatory"


        params = {'n_classes_1': list(search_range_dict[current_cell_type_data]), 'soft_weights': [True, False],
                  'partial_fit_expert': [True, False]}

        keys, values = zip(*params.items())
        params_list = [dict(zip(keys, v)) for v in itertools.product(*values)]

        new_params_list = []
        for d in params_list:
            if d['n_classes_1'] == 1:
                if d['soft_weights'] and d['partial_fit_expert']:
                    # n_expert = 1, soft or hard are equivalent
                    new_params_list.append(d)
            else:
                if d['soft_weights'] == d['partial_fit_expert']:
                    new_params_list.append(d)
        ratio = 0.2

        # initialize with default values
        model_params_val = model_params.copy()
        model_params_val['n_epochs'] = 1  # increase for validation models to converge
        model_params_val['tolerance'] = 0
        print(f"Default model parameters for validation {model_params_val}")
        model = hme(**model_params_val)

        gs = gridSearch(params, model, ratio, n_sets, new_params_list)
        gs.generate_val_sets(samples_train, meta_per_part)
        gs.runCV(X_trains_gs, Y_train_gs, meta_per_part, meta_idx, feature_types, data_type,
                 preprocess)
        gs.get_best_parameter()
        print(f"Best params from grid search: {gs.best_params}")

        # modify the parameter setting
        for key, value in gs.best_params.items():
            model_params[key] = value

        print(f"Model parameters for training after grid search {model_params}")

        filename = f"validation_results{suffix}.pickle"
        with open(os.path.join(current_dir, filename), 'wb') as f:
            pickle.dump(gs, f)

    ### fit the full data with specified/selected hyperparameter

    # initialise the expert assignment by hierarchical clustering of the responses
    if grid_search and 'n_classes_1' in params:
        model = AgglomerativeClustering(n_clusters=gs.best_params['n_classes_1'])
    else:
        model = AgglomerativeClustering(n_classes_1)

    model = model.fit(Y_train)
    hier_labels = [model.labels_]
    model_params['init_labels_1'] = hier_labels

    # ------ construct MESSI  ------
    model = hme(**model_params)
    # train
    model.train(X_train, X_train_clf_1, X_train_clf_2, Y_train)

    ### save the model

    filename = f"hme_model{suffix}.pickle"
    model_path = os.path.join(current_dir, filename)

    with open(model_path, 'wb') as f:
        pickle.dump(model, f)

    ## Make predictions on the test data

    ### load the saved model

    # reads back the file written just above
    with open(model_path, 'rb') as f:
        saved_model = pickle.load(f)

    ### make predictions

    Y_hat_final = saved_model.predict(X_test, X_test_clf_1, X_test_clf_2)
    mean_abs_error = (abs(Y_test - Y_hat_final).mean(axis=1)).mean()
    print(f"Mean absolute value : {mean_abs_error}")
    print(f"{sex}_{behavior}_{celltype} inference completed!")
    return mean_abs_error

# XGBoost Implementation

In [None]:
# read in merfish dataset and get columns names
import pandas as pd

# get relevant data stuff
df_file = pd.ExcelFile("~/spatial/data/messi.xlsx")
messi_df = pd.read_excel(df_file, "All.Pairs")
merfish_df = pd.read_csv("~/spatial/data/raw/merfish.csv")
merfish_df = merfish_df.drop(['Blank_1', 'Blank_2', 'Blank_3', 'Blank_4', 'Blank_5', 'Fos'], axis=1)

# Gene indexes below are column positions minus the 9 leading non-gene columns.
gene_columns = list(merfish_df.columns[9:])
gene_index = {gene: i for i, gene in enumerate(gene_columns)}
n_genes = len(gene_columns)

# these are the 13 ligands or receptors found in MESSI
non_response_genes = ['Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165',
                      'Glra3', 'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt',
                      'Tac2']

# we will populate all of the non-response genes as being in one or the other
# the ones already filled in come from the existing 13 L/R genes above
ligands = ["Cbln1", "Cxcl14", "Cbln2", "Vgf", "Scg2", "Cartpt", "Tac2"]
receptors = ["Crhbp", "Gabra1", "Gpr165", "Glra3", "Gabrg1", "Adora2a"]

# ligands and receptor indexes in MERFISH
non_response_indeces = [gene_index[gene] for gene in non_response_genes]
ligand_indeces = [gene_index[gene] for gene in ligands]
receptor_indeces = [gene_index[gene] for gene in receptors]
all_pairs_columns = [
    "Ligand.ApprovedSymbol",
    "Receptor.ApprovedSymbol",
]

# this list stores the control genes aka "Blank_{int}"
# (empty whenever the Blank_* columns have been dropped, as they are above)
blank_genes = [gene for gene in gene_columns if gene[:5] == "Blank"]

# Add every MERFISH gene whose upper-cased name appears in the ligand or the
# receptor symbol column. A gene is classified once only: the duplicate check
# compares the name as stored in non_response_genes (original case), so an
# already-known gene, or a gene listed in both columns, is not added twice.
seen_non_response = set(non_response_genes)
for column in all_pairs_columns:
    approved_symbols = set(messi_df[column])  # built once per column for O(1) lookups
    is_ligand_column = column[0] == "L"
    for gene in gene_columns:
        if gene.upper() not in approved_symbols or gene in seen_non_response:
            continue
        seen_non_response.add(gene)
        non_response_genes.append(gene)
        non_response_indeces.append(gene_index[gene])
        if is_ligand_column:
            ligands.append(gene)
            ligand_indeces.append(gene_index[gene])
        else:
            receptors.append(gene)
            receptor_indeces.append(gene_index[gene])

print(non_response_genes)
print(
    "There are "
    + str(len(non_response_genes))
    + " genes recognized as either ligands or receptors (including new ones)."
)

print(
    "There are "
    + str(len(blank_genes))
    + " blank genes."
)

print(
    "There are "
    + str(n_genes - len(blank_genes) - len(non_response_genes))
    + " genes that are treated as response variables."
)

print(
    "There are "
    + str(len(ligands))
    + " ligands."
)

print(
    "There are "
    + str(len(receptors))
    + " receptors."
)

# every gene index that is not a ligand/receptor, in ascending order
response_indeces = sorted(set(range(n_genes)) - set(non_response_indeces))

In [None]:
def get_neighbors(batch_obj):
    """Return, for every node of the batch, the targets of the edges leaving it.

    The result is a list with one tensor per row of ``batch_obj.x``; entry i
    holds the second-row entries of ``batch_obj.edge_index`` for all edges
    whose first-row entry equals i.
    """
    sources = batch_obj.edge_index[0]
    targets = batch_obj.edge_index[1]
    n_nodes = batch_obj.x.shape[0]
    return [targets[sources == node] for node in range(n_nodes)]

In [None]:
def get_ligand_sum(data, neighbors_tensor, ligand_indeces):
    """Sum the ligand columns of ``data`` over each node's neighbors.

    data -- 2-D tensor, one row per node
    neighbors_tensor -- sequence with one index tensor per node, holding the
                        row numbers of that node's neighbors
    ligand_indeces -- column numbers of the ligand genes in ``data``

    Returns a tensor with one row per node and one column per ligand. Nodes
    may have different numbers of neighbors; a node without neighbors gets a
    row of zeros.
    """
    # build the column selector once instead of once per node
    ligand_columns = torch.as_tensor(ligand_indeces, dtype=torch.long, device=data.device)
    if len(neighbors_tensor) == 0:
        return data.new_zeros((0, ligand_columns.numel()))
    # summing per node (rather than stacking first) is what allows ragged neighbor lists
    return torch.stack([
        data.index_select(0, neighbors).index_select(1, ligand_columns).sum(dim=0)
        for neighbors in neighbors_tensor
    ])

In [None]:
import torch.nn.functional as F

def get_celltypes(cell_behavior_tensor, neighbors_tensor, num_classes=None):
    """One-hot encode the cell types of every node's neighbors.

    cell_behavior_tensor -- 1-D integer tensor with the cell-type index of every node
    neighbors_tensor -- sequence with one index tensor per node, holding the
                        positions of that node's neighbors
    num_classes -- width of the one-hot encoding; when omitted it is inferred
                   as the largest cell-type index plus one

    Returns a list with one (n_neighbors, num_classes) one-hot tensor per node.
    """
    if num_classes is None:
        num_classes = int(cell_behavior_tensor.max()) + 1
    return [
        F.one_hot(cell_behavior_tensor.index_select(0, neighbors), num_classes=num_classes)
        for neighbors in neighbors_tensor
    ]

In [None]:
def get_celltype_simplex(cell_behavior_tensor, neighbors_tensor, num_classes=16):
    """Return the cell-type proportions among each cell's neighbours.

    Args:
        cell_behavior_tensor: 1-D integer tensor of cell-type ids, indexed by
            cell.
        neighbors_tensor: sequence with one 1-D index tensor per cell.
        num_classes: number of cell types, i.e. the width of the output
            (defaults to 16, the value this function previously hard-coded).

    Returns:
        Float tensor of shape ``(len(neighbors_tensor), num_classes)``; row
        ``i`` is the mean one-hot encoding of cell ``i``'s neighbours.  A cell
        with no neighbours yields a row of NaN (mean over zero rows).
    """
    print(f"There are {num_classes} different cell types.")

    if len(neighbors_tensor) == 0:
        return torch.zeros((0, num_classes))

    proportions = [
        torch.mean(
            1.0 * F.one_hot(cell_behavior_tensor.index_select(0, neighbors), num_classes=num_classes),
            dim=0,
        )
        for neighbors in neighbors_tensor
    ]
    return torch.stack(proportions, dim=0)

# Table 1

###### XGBoost

In [None]:
# Table 1 / XGBoost: leave-one-animal-out evaluation on all cells.
# For every test animal, one XGBoost regressor is fitted per response gene and
# the mean absolute error on the held-out animal is recorded.
behaviors = ["Naive"]
sexes = ["Female"]

with open('animal_id.json') as json_file:
    animals = json.load(json_file)

time_dict = {}
loss_dict = {}
gene_loss_dict = {}


def _table1_build_xy(graphs):
    """Assemble the (features, responses) tensors for a list of graph batches.

    Per cell the features are: its own non-response gene expression, the
    summed ligand expression of its neighbours, its position, the batch's
    bregma value and the cell-type proportions of its neighbours.
    """
    feature_blocks = []
    response_blocks = []
    for batch_number, batch in enumerate(graphs, start=1):
        x = batch.x
        pos = batch.pos
        # one bregma value per batch, repeated for every cell in it
        bregma = torch.tensor([batch.bregma]*pos.shape[0]).reshape(-1,1)
        neighbors = get_neighbors(batch)
        total_ligands = get_ligand_sum(x, neighbors, ligand_indeces)
        # column 1 of batch.y is used as the cell-type id
        celltype_proportions = get_celltype_simplex(batch.y[:, 1], neighbors)

        feature_blocks.append(
            torch.cat((x[:, non_response_indeces], total_ligands, pos, bregma, celltype_proportions), dim=1)
        )
        response_blocks.append(x[:, response_indeces])
        print(f"Batch: {batch_number}/{len(graphs)}")
    return torch.cat(feature_blocks, dim=0), torch.cat(response_blocks, dim=0)


for behavior in behaviors:
    for sex in sexes:
        try:
            animal_list = animals[behavior][sex]
        except KeyError:
            continue
        # The dataset takes lists of behaviors/sexes.  Keep them in separate
        # names so the loop variables stay usable as dictionary keys on the
        # next pass.
        behavior_filter = [behavior]
        sex_filter = [sex]
        for animal in animal_list:
            trial_run = FilteredMerfishDataset('data', sexes=sex_filter, behaviors=behavior_filter, test_animal=animal)
            print(sex_filter, behavior_filter, animal)
            datalist = trial_run.construct_graphs(3, True)
            print(len(datalist))
            # Timing starts after the training graphs have been constructed.
            start = time.time()

            train_dataset, train_Y = _table1_build_xy(datalist)
            scaler = StandardScaler().fit(train_dataset)
            train_dataset = torch.tensor(scaler.transform(train_dataset))
            assert train_dataset.shape[0] == train_Y.shape[0]

            test_datalist = trial_run.construct_graphs(3, False)
            test_dataset, test_Y = _table1_build_xy(test_datalist)
            # Standardize the test data exactly once, with the TRAINING mean
            # and sd.
            test_dataset = torch.tensor(scaler.transform(test_dataset))
            assert test_dataset.shape[0] == test_Y.shape[0]

            model_list = []
            MAE_list = []
            gene_names = merfish_df.columns[9:]
            train_features = np.array(train_dataset)
            test_features = np.array(test_dataset)
            print(train_dataset.shape)

            # for each response gene in our response matrix....
            for i in range(train_Y.shape[1]):
                y_i_train = train_Y[:, i]
                y_i_test = test_Y[:, i]

                if i == 0 and behavior == "Naive" and animal == 1:
                    print(f"trainY Mean: {torch.mean(y_i_train)}")
                    print(f"testY Mean: {torch.mean(y_i_test)}")

                model = xgb.XGBRegressor(tree_method="gpu_hist", nthread=1, objective="reg:squarederror", eval_metric="mae")
                model.fit(train_features, np.array(y_i_train))
                model_list.append((f"Gene {i}", model))

                test_output = torch.tensor(model.predict(test_features))

                gene_name = gene_names[response_indeces[i]]
                gene_mae = F.l1_loss(test_output, y_i_test).item()
                MAE_list.append(gene_mae)
                print(f"Response Gene {gene_name} MAE: {gene_mae}")
                gene_loss_dict[gene_name] = gene_mae
                print(f"Response Gene: {i+1}/{train_Y.shape[1]}")

            end = time.time()
            # Keys keep the historical list formatting
            # (e.g. "['Female']_['Naive']_1") so result files stay comparable.
            run_key = f"{sex_filter}_{behavior_filter}_{animal}"
            time_dict[run_key] = end-start
            loss_dict[run_key] = float(np.mean(MAE_list))

            # Rewrite the result files after every animal so partial runs are kept.
            with open("XGBoost_time_all.json", "w") as outfile:
                json.dump(time_dict, outfile, indent=4)

            with open("XGBoost_MAE.json", "w") as outfile:
                json.dump(loss_dict, outfile, indent=4)

            print(f"Test animal {animal} CV finished!")

##### LightGBM

In [None]:
# Genes used as regression targets: the loop below fits one model per entry.
response_genes=['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2']

import time
import json
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.preprocessing import StandardScaler

# Leave-one-animal-out evaluation of a histogram gradient-boosting regressor
# (absolute-error loss) over animals 1-4, one model per response gene.
# NOTE(review): this cell sits under "Table 1" but passes
# filter_excitatory=True and writes to the "*_excitatory.json" files, exactly
# like the Table 2 cell - confirm that this is intended.
all_MAEs = []

time_dict = {}
L1_loss_dict = {}

for animal in [1,2,3,4]:
    fold_key = f"Female_Naive_{animal}"
    start = time.time()
    MAE_list = []
    for target_gene in response_genes:
        # neighbour features come from ligands; own-cell features from
        # ligands and receptors
        neighset=genetypes['ligands']
        oset=np.r_[genetypes['ligands'],genetypes['receptors']]

        train_animals = [other for other in [1,2,3,4] if other != animal]
        print(train_animals)

        # train on the remaining animals among the first four, test on the
        # held-out one
        train_mask = (ad.obs['Animal_ID']!=animal)&(ad.obs['Animal_ID']<=4)
        test_mask = (ad.obs['Animal_ID']==animal)
        trainX,trainY=construct_problem(train_mask,target_gene,neighset,oset,filter_excitatory=True)
        testX,testY=construct_problem(test_mask,target_gene,neighset,oset,filter_excitatory=True)

        for features, target in ((trainX, trainY), (testX, testY)):
            print(features.shape,target.shape)

        # whiten covariates with statistics of the training fold only
        scaler = StandardScaler()
        trainX = scaler.fit_transform(trainX)
        testX = scaler.transform(testX)

        model=HistGradientBoostingRegressor(loss="absolute_error")
        model.fit(trainX,trainY)
        abs_errors = np.abs(model.predict(testX)-testY)
        MAE_list.append(np.mean(abs_errors))

    end = time.time()
    fold_mae = np.mean(MAE_list)
    time_dict[fold_key] = end-start
    L1_loss_dict[fold_key] = float(fold_mae)

    # results are rewritten after every fold
    with open("XGBoost_L1_time_excitatory.json", "w") as outfile:
        json.dump(time_dict, outfile, indent=4)

    with open("XGBoost_L1_MAE_excitatory.json", "w") as outfile:
        json.dump(L1_loss_dict, outfile, indent=4)

    all_MAEs.append(fold_mae)

print(np.mean(all_MAEs))

##### deepST

In [None]:
# Table 1 / deepST: train and test once per held-out animal and keep the
# reported test loss.
animal_list = [1,2,3,4]
# test loss keyed by held-out animal
loss_dict = {}

for animal in animal_list:
    with initialize(config_path="../config"):
        overrides_train = {
            "datasets": "FilteredMerfishDataset",
            "gpus": "[4]",
            # select the held-out animal through a Hydra override, the same
            # way the Table 2 deepST cell does
            "datasets.dataset.test_animal": animal,
        }
        overrides_train_list = [f"{k}={v}" for k, v in overrides_train.items()]
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)
        model, trainer = train(cfg_from_terminal)
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        # loss_dict is a dict, so store by key (dicts have no .append)
        loss_dict[animal] = test_results[0]['test_loss']

##### MESSI

In [None]:
import itertools as it

# Load the raw MERFISH table and enumerate every (sex, behavior, cell class)
# combination that occurs in it.
merfish_df = pd.read_csv('../../spatial/data/raw/merfish.csv')
sexes, behaviors, celltypes = (
    merfish_df[column].unique()
    for column in ("Animal_sex", "Behavior", "Cell_class")
)
# Cartesian product, last factor varying fastest
cell_categories = [
    (one_sex, one_behavior, one_celltype)
    for one_sex in sexes
    for one_behavior in behaviors
    for one_celltype in celltypes
]

Run MESSI with the `general` cell-type setting.

In [None]:
import time
import json

# Table 1 / MESSI: leave-one-animal-out over animals 1-4 with the 'general'
# cell-type setting; the reported loss is the mean MAE over the folds.
loss_dict = {}
time_dict = {}
animals = [1,2,3,4]
for sex in ['Female']:
    for behavior in ['Naive']:
        # Skip sex/behavior combinations that have no animals at all.
        filtered_df = merfish_df[(merfish_df['Animal_sex'] == sex) & (merfish_df['Behavior'] == behavior)]
        if len(filtered_df["Animal_ID"].unique()) == 0:
            continue

        result_key = f"{sex}_{behavior}_general"
        # Accumulators must exist before they are added to.
        loss_dict[result_key] = 0.0
        time_dict[result_key] = 0.0
        for animal in animals:
            # every animal except the current one ({animal}, not set(animal):
            # an int is not iterable)
            other_animals = [a for a in animals if a != animal]
            start = time.time()
            MAE = MESSI(sex, behavior, 'general', other_animals)
            end = time.time()
            # total wall-clock time over all folds
            time_dict[result_key] += end-start
            # running mean of the per-fold MAE
            loss_dict[result_key] += MAE/len(animals)

MAE_results = json.dumps(loss_dict, indent=4)
time_results = json.dumps(time_dict, indent=4)

# Table 2

###### XGBoost

In [None]:
# Table 2 / XGBoost: same leave-one-animal-out protocol as Table 1, but
# models are fitted and evaluated only on cells whose cell-type id is 6
# (the id this cell treats as "excitatory").
behaviors = ["Naive"]
sexes = ["Female"]

with open('animal_id.json') as json_file:
    animals = json.load(json_file)

time_dict = {}
loss_excitatory_dict = {}
gene_loss_dict = {}


def _table2_build_xy(graphs, celltype_id=6):
    """Assemble (features, responses) for the cells of one cell type.

    Neighbourhood features are computed on the full graph first; rows are
    restricted to cells whose type (column 1 of ``batch.y``) equals
    ``celltype_id`` afterwards.
    """
    feature_blocks = []
    response_blocks = []
    for batch_number, batch in enumerate(graphs, start=1):
        x = batch.x
        pos = batch.pos
        # one bregma value per batch, repeated for every cell in it
        bregma = torch.tensor([batch.bregma]*pos.shape[0]).reshape(-1,1)
        behavior_and_cell_type = batch.y
        neighbors = get_neighbors(batch)
        total_ligands = get_ligand_sum(x, neighbors, ligand_indeces)
        celltype_proportions = get_celltype_simplex(behavior_and_cell_type[:, 1], neighbors)

        X = torch.cat((x[:, non_response_indeces], total_ligands, pos, bregma, celltype_proportions), dim=1)
        excitatory_cells = (behavior_and_cell_type[:, 1] == celltype_id).nonzero(as_tuple=True)[0]
        feature_blocks.append(torch.index_select(X, 0, excitatory_cells))
        response_blocks.append(torch.index_select(x[:, response_indeces], 0, excitatory_cells))
        print(f"Batch: {batch_number}/{len(graphs)}")
    return torch.cat(feature_blocks, dim=0), torch.cat(response_blocks, dim=0)


for behavior in behaviors:
    for sex in sexes:
        try:
            animal_list = animals[behavior][sex]
        except KeyError:
            continue
        # The dataset takes lists of behaviors/sexes.  Keep them in separate
        # names so the loop variables stay usable as dictionary keys on the
        # next pass.
        behavior_filter = [behavior]
        sex_filter = [sex]
        for animal in animal_list:
            trial_run = FilteredMerfishDataset('data', sexes=sex_filter, behaviors=behavior_filter, test_animal=animal)
            print(sex_filter, behavior_filter, animal)
            datalist = trial_run.construct_graphs(3, True)
            print(len(datalist))
            # Timing starts after the training graphs have been constructed.
            start = time.time()

            train_dataset, train_Y = _table2_build_xy(datalist)
            scaler = StandardScaler().fit(train_dataset)
            train_dataset = torch.tensor(scaler.transform(train_dataset))
            assert train_dataset.shape[0] == train_Y.shape[0]

            test_datalist = trial_run.construct_graphs(3, False)
            test_dataset, test_Y = _table2_build_xy(test_datalist)
            # standardize the test data using TRAINING mean and sd.
            test_dataset = torch.tensor(scaler.transform(test_dataset))
            assert test_dataset.shape[0] == test_Y.shape[0]

            model_list = []
            MAE_list = []
            gene_names = merfish_df.columns[9:]
            train_features = np.array(train_dataset)
            test_features = np.array(test_dataset)

            # for each response gene in our response matrix....
            for i in range(train_Y.shape[1]):
                y_i_train = train_Y[:, i]
                y_i_test = test_Y[:, i]

                if i == 0 and behavior == "Naive" and animal == 1:
                    print(f"trainY Mean: {torch.mean(y_i_train)}")
                    print(f"testY Mean: {torch.mean(y_i_test)}")

                model = xgb.XGBRegressor(tree_method="gpu_hist", nthread=1, objective="reg:squarederror", eval_metric="mae")
                model.fit(train_features, np.array(y_i_train))
                model_list.append((f"Gene {i}", model))

                test_output = torch.tensor(model.predict(test_features))

                gene_name = gene_names[response_indeces[i]]
                gene_mae = F.l1_loss(test_output, y_i_test).item()
                MAE_list.append(gene_mae)
                print(f"Response Gene {gene_name} MAE: {gene_mae}")
                gene_loss_dict[gene_name] = gene_mae
                print(f"Response Gene: {i+1}/{train_Y.shape[1]}")

            end = time.time()
            # Keys keep the historical list formatting
            # (e.g. "['Female']_['Naive']_1") so result files stay comparable.
            run_key = f"{sex_filter}_{behavior_filter}_{animal}"
            time_dict[run_key] = end-start
            loss_excitatory_dict[run_key] = float(np.mean(MAE_list))

            # Rewrite the result files after every animal so partial runs are kept.
            with open("XGBoost_time_excitatory.json", "w") as outfile:
                json.dump(time_dict, outfile, indent=4)

            with open("XGBoost_MAE_excitatory.json", "w") as outfile:
                json.dump(loss_excitatory_dict, outfile, indent=4)

            print(f"Test animal {animal} CV finished!")

##### LightGBM

In [None]:
# Genes used as regression targets: the loop below fits one model per entry.
response_genes=['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2']

import time
import json
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.preprocessing import StandardScaler

# Leave-one-animal-out evaluation of a histogram gradient-boosting regressor
# (absolute-error loss) over animals 1-4, one model per response gene, with
# construct_problem called with filter_excitatory=True.
all_MAEs = []

time_dict = {}
L1_loss_dict = {}

for animal in [1,2,3,4]:
    fold_key = f"Female_Naive_{animal}"
    start = time.time()
    MAE_list = []
    for target_gene in response_genes:
        # neighbour features come from ligands; own-cell features from
        # ligands and receptors
        neighset=genetypes['ligands']
        oset=np.r_[genetypes['ligands'],genetypes['receptors']]

        train_animals = [other for other in [1,2,3,4] if other != animal]
        print(train_animals)

        # train on the remaining animals among the first four, test on the
        # held-out one
        train_mask = (ad.obs['Animal_ID']!=animal)&(ad.obs['Animal_ID']<=4)
        test_mask = (ad.obs['Animal_ID']==animal)
        trainX,trainY=construct_problem(train_mask,target_gene,neighset,oset,filter_excitatory=True)
        testX,testY=construct_problem(test_mask,target_gene,neighset,oset,filter_excitatory=True)

        for features, target in ((trainX, trainY), (testX, testY)):
            print(features.shape,target.shape)

        # whiten covariates with statistics of the training fold only
        scaler = StandardScaler()
        trainX = scaler.fit_transform(trainX)
        testX = scaler.transform(testX)

        model=HistGradientBoostingRegressor(loss="absolute_error")
        model.fit(trainX,trainY)
        abs_errors = np.abs(model.predict(testX)-testY)
        MAE_list.append(np.mean(abs_errors))

    end = time.time()
    fold_mae = np.mean(MAE_list)
    time_dict[fold_key] = end-start
    L1_loss_dict[fold_key] = float(fold_mae)

    # results are rewritten after every fold
    with open("XGBoost_L1_time_excitatory.json", "w") as outfile:
        json.dump(time_dict, outfile, indent=4)

    with open("XGBoost_L1_MAE_excitatory.json", "w") as outfile:
        json.dump(L1_loss_dict, outfile, indent=4)

    all_MAEs.append(fold_mae)

print(np.mean(all_MAEs))

##### deepST

In [None]:
# Table 2 / deepST: train and test once per held-out animal with the model
# restricted to the "Excitatory" cell type, keeping the reported test loss.
animal_list = [1,2,3,4]
# test loss keyed by held-out animal
loss_dict = {}

for animal in animal_list:
    with initialize(config_path="../../config"):
        overrides_train = {
            "datasets": "FilteredMerfishDataset",
            "gpus": "[4]",
            "model.kwargs.celltypes": ["Excitatory"],
            "datasets.dataset.test_animal": animal,
        }
        overrides_train_list = [f"{k}={v}" for k, v in overrides_train.items()]
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)
        model, trainer = train(cfg_from_terminal)
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        # loss_dict is a dict, so store by key (dicts have no .append)
        loss_dict[animal] = test_results[0]['test_loss']

##### MESSI

In [None]:
import time
import json

# Table 2 / MESSI: run MESSI with the 'Excitatory' cell-type setting and
# record its MAE and wall-clock time.
loss_dict = {}
time_dict = {}
for sex in ['Female']:
    for behavior in ['Naive']:
        # Recompute the subset here rather than relying on a `filtered_df`
        # left behind by an earlier cell for a possibly different sex/behavior.
        filtered_df = merfish_df[(merfish_df['Animal_sex'] == sex) & (merfish_df['Behavior'] == behavior)]
        if len(filtered_df["Animal_ID"].unique()) == 0:
            continue
        result_key = f"{sex}_{behavior}_Excitatory"
        start = time.time()
        MAE = MESSI(sex, behavior, 'Excitatory')
        end = time.time()
        time_dict[result_key] = end-start
        loss_dict[result_key] = MAE

MAE_results = json.dumps(loss_dict, indent=4)
time_results = json.dumps(time_dict, indent=4)

# Figure 4

###### XGBoost

In [None]:
# Figure 4 / XGBoost: wall-clock time of the per-gene XGBoost pipeline for
# increasingly large sets of animals (animal 4 is always the test animal).
# Relies on `gene_loss_dict`, `merfish_df` and `F` from earlier cells.
time_dict = {}

slices = [[1,4], [1,2,4], [1,2,4,7], [1,2,3,4,5,7]]


def _figure4_build_xy(graphs):
    """Assemble the (features, responses) tensors for a list of graph batches.

    Per cell the features are: its own non-response gene expression, the
    summed ligand expression of its neighbours, its position, the batch's
    bregma value and the cell-type proportions of its neighbours.
    """
    feature_blocks = []
    response_blocks = []
    for batch_number, batch in enumerate(graphs, start=1):
        x = batch.x
        pos = batch.pos
        # one bregma value per batch, repeated for every cell in it
        bregma = torch.tensor([batch.bregma]*pos.shape[0]).reshape(-1,1)
        neighbors = get_neighbors(batch)
        total_ligands = get_ligand_sum(x, neighbors, ligand_indeces)
        # column 1 of batch.y is used as the cell-type id
        celltype_proportions = get_celltype_simplex(batch.y[:, 1], neighbors)

        feature_blocks.append(
            torch.cat((x[:, non_response_indeces], total_ligands, pos, bregma, celltype_proportions), dim=1)
        )
        response_blocks.append(x[:, response_indeces])
        print(f"Batch: {batch_number}/{len(graphs)}")
    return torch.cat(feature_blocks, dim=0), torch.cat(response_blocks, dim=0)


for slice_idx, animal_slices in enumerate(slices):

    trial_run = FilteredMerfishDataset('data', animals=animal_slices, test_animal=4)
    datalist = trial_run.construct_graphs(3, True)
    # Start the clock here, at the same point as the Table 1/2 XGBoost cells.
    # NOTE(review): the original cell never set `start`; confirm this is the
    # intended start of the timed region.
    start = time.time()

    train_dataset, train_Y = _figure4_build_xy(datalist)
    scaler = StandardScaler().fit(train_dataset)
    train_dataset = torch.tensor(scaler.transform(train_dataset))

    test_datalist = trial_run.construct_graphs(3, False)
    test_dataset, test_Y = _figure4_build_xy(test_datalist)
    # Standardize the test data exactly once, with the TRAINING mean and sd.
    test_dataset = torch.tensor(scaler.transform(test_dataset))

    model_list = []
    MAE_list = []
    gene_names = merfish_df.columns[9:]
    train_features = np.array(train_dataset)
    test_features = np.array(test_dataset)
    print(train_dataset.shape)

    # for each response gene in our response matrix....
    # (a separate index name: reusing the outer loop's index would corrupt
    # the time_dict key below)
    for gene_idx in range(train_Y.shape[1]):
        y_i_train = train_Y[:, gene_idx]
        y_i_test = test_Y[:, gene_idx]

        model = xgb.XGBRegressor(tree_method="gpu_hist", nthread=1, objective="reg:squarederror", eval_metric="mae")
        model.fit(train_features, np.array(y_i_train))
        model_list.append((f"Gene {gene_idx}", model))

        test_output = torch.tensor(model.predict(test_features))

        gene_name = gene_names[response_indeces[gene_idx]]
        gene_mae = F.l1_loss(test_output, y_i_test).item()
        MAE_list.append(gene_mae)
        print(f"Response Gene {gene_name} MAE: {gene_mae}")
        gene_loss_dict[gene_name] = gene_mae
        print(f"Response Gene: {gene_idx+1}/{train_Y.shape[1]}")

    end = time.time()
    time_dict[(slice_idx+1)*12] = end-start

time_dict

##### LightGBM

Construct each problem for the timing comparison manually.

In [None]:
# Build one train/test problem for the gene 'Th': animals 2-4 for training,
# animal 1 for testing.
neighset = genetypes['ligands']
oset = np.r_[genetypes['ligands'], genetypes['receptors']]

train_mask = (ad.obs['Animal_ID']>=2)&(ad.obs['Animal_ID']<=4)
test_mask = (ad.obs['Animal_ID']==1)

trainX, trainY, feature_names = construct_problem(train_mask, 'Th', neighset, oset)
testX, testY, feature_names = construct_problem(test_mask, 'Th', neighset, oset)

for features, target in ((trainX, trainY), (testX, testY)):
    print(features.shape, target.shape)

##### deepST

In [None]:
# Figure 4 / deepST: wall-clock training time for increasingly large sets of
# animals (animal 4 is always the test animal).
time_dict = {}

train_animal_sets = [[1,2,4][:1] + [4], [1,2,4], [1,2,4,7], [1,2,3,4,5,7]]

for i, animal_slices in enumerate(train_animal_sets):
    start = time.time()
    with initialize(config_path="../config"):
        overrides_train = dict([
            ("datasets", "FilteredMerfishDataset"),
            ("gpus", "[4]"),
        ])
        overrides_train_list = ["=".join((k, str(v))) for k, v in overrides_train.items()]
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)
        # point the dataset at this animal subset and the fixed test animal
        for config_key, config_value in (
            ("datasets.dataset.animals", animal_slices),
            ("datasets.dataset.test_animal", 4),
        ):
            OmegaConf.update(cfg_from_terminal, config_key, config_value)
        model, trainer = train(cfg_from_terminal)
    end = time.time()
    elapsed = end-start
    time_dict[(i+1)*12] = elapsed

time_dict

##### MESSI

In [None]:
# Scratch cell: evaluates to a tuple of lists and has no other effect.
# Presumably candidate training-animal sets for the timing cell below - TODO confirm.
[3], [1], [1,3], [1,2], 

In [None]:
# Figure 4 / MESSI: wall-clock time of one MESSI run per training-animal set.
time_dict = {}
# has 6, 12, 18, 24, and 30 slices combined, respectively
train_animal_sets = [[1,2,3]]
sex, behavior = "Female", "Naive"
for set_number, animal_slices in enumerate(train_animal_sets, start=1):
    i = set_number - 1
    start = time.time()
    MAE = MESSI(sex, behavior, 'general', animal_slices)
    end = time.time()
    time_dict[set_number*6] = end-start

print(time_dict)

In [None]:
# Display the timings collected so far.
time_dict

- [ ] Add the timings for 12 and 24 slices, since 36 and 48 are impossible.

# Table 3

##### LightGBM

In [None]:
# Table 3 / LightGBM: for every neighbourhood radius, rebuild the cell-cell
# connectivity matrix, then grid-search a histogram gradient-boosting
# regressor for a few single genes and record the test MAE.
test_loss_rad_dict = {}

# "rad": connect cells within `radius` of each other;
# "neighbors": connect each cell to its `nneigh` nearest neighbours
# (`nneigh` must then already be defined by an earlier cell).
mode="rad"

# The gene-type table does not depend on the radius, so load it once.
with open(genetypes_location,'rb') as f:
    genetypes=pickle.load(f)

# for each radius value....
for radius in range(0, 90, 10):

    ad=anndata.read_h5ad(h5ad_location)
    row=np.zeros(0,dtype=int)
    col=np.zeros(0,dtype=int)

    # Build the graph tissue by tissue so cells of different tissues are
    # never connected.
    for tid in tqdm.notebook.tqdm(np.unique(ad.obs['Tissue_ID'])):
        good=ad.obs['Tissue_ID']==tid
        # global row numbers of this tissue's cells; needed to map the
        # tissue-local edge indices back into the full matrix
        idxs=np.where(good)[0]
        pos=np.array(ad.obs[good][['Centroid_X','Centroid_Y']])
        if mode == "neighbors":
            if nneigh == 0:
                E = csr_matrix(np.eye(pos.shape[0]))
            else:
                E=sklearn.neighbors.kneighbors_graph(pos,nneigh,mode='connectivity')
            E=E.tocoo()
            local_row,local_col=E.row,E.col
        elif mode == "rad":
            p=sp.spatial.cKDTree(pos)
            # one list of tissue-local neighbour indices per cell; each cell
            # is within distance 0 of itself and is therefore in its own list
            # NOTE(review): confirm that self-edges are wanted for radius > 0.
            neighbor_lists=p.query_ball_point(pos, r=radius, return_sorted=False)
            counts=np.array([len(found) for found in neighbor_lists],dtype=int)
            local_row=np.repeat(np.arange(len(neighbor_lists)),counts)
            local_col=np.fromiter(
                itertools.chain.from_iterable(neighbor_lists),
                dtype=int,
                count=int(counts.sum()),
            )
        else:
            raise ValueError(f"unknown connectivity mode: {mode!r}")
        row=np.r_[row,idxs[local_row]]
        col=np.r_[col,idxs[local_col]]

    E=sp.sparse.coo_matrix((np.ones(len(col)),(row,col)),shape=(len(ad),len(ad))).tocsr()

    # NOTE(review): the radius graphs are written with the "%dneighbors"
    # file-name template and can overwrite k-nearest-neighbour graphs that
    # use the same number.
    graph_parameter = nneigh if mode == "neighbors" else radius
    connectivity_path = connectivity_matrix_template%graph_parameter
    anndata.AnnData(E).write_h5ad(connectivity_path)

    # reload the graph from disk, as the downstream cells do
    connectivity_matrix=anndata.read_h5ad(connectivity_path).X
    gene_lookup={x:i for (i,x) in enumerate(ad.var.index)}

    for single_gene in ['Pak3', 'Slc17a8', 'Nnat', 'Th']:

        neighset=genetypes['ligands']
        oset=np.r_[genetypes['ligands'],genetypes['receptors']]

        # animals up to id 30 for training, the rest for testing
        trainX,trainY,feature_names=construct_problem((ad.obs['Animal_ID']<=30), single_gene, neighset,oset)
        testX,testY,feature_names=construct_problem((ad.obs['Animal_ID']>30), single_gene, neighset,oset)

        # whiten covariates
        scaler = StandardScaler().fit(trainX)
        trainX = scaler.transform(trainX)
        testX = scaler.transform(testX)

        for num_nodes in [10,50,100,250,500,1000,2500]:
            for lr in [0.001, 0.01, 0.1]:
                for l2 in [0, 1, 10]:
                    model=HistGradientBoostingRegressor(loss="absolute_error", max_leaf_nodes=num_nodes, learning_rate=lr, l2_regularization=l2)
                    model.fit(trainX,trainY)
                    result_key=(radius, single_gene, num_nodes, lr, l2)
                    test_loss_rad_dict[result_key] = np.mean(np.abs(model.predict(testX)-testY))
                    print(result_key, test_loss_rad_dict[result_key])

In [None]:
# Display the test losses collected by the radius sweep above.
test_loss_rad_dict

##### deepST

In [None]:
# Table 3 / deepST: for every radius and every single response gene, train a
# model, run the validation-only pass, and record the test loss.
test_loss_rad_dict = {}

# for each radius value....
for radius in range(0, 90, 10):

    for single_gene in [[93], [116], [142]]:
        (gene_index,) = single_gene

        # setup framework
        with initialize(config_path="../../config"):
            overrides_train = dict([
                ("datasets", "MerfishDataset"),
                ("gpus", "[5]"),
                ("radius", radius),
                ("model.kwargs.response_genes", single_gene),
                ("training.logger_name", "figure4deepST"),
            ])
            overrides_train_list = ["=".join((k, str(v))) for k, v in overrides_train.items()]
            cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)

            # complete training
            model, trainer = train(cfg_from_terminal)

            # load the model with the lowest validation loss
            validation_setup = train(cfg_from_terminal, validate_only=True)

            # run that model on testing data
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            first_result = test_results[0]
            test_loss_rad_dict[(radius, gene_index)] = first_result['test_loss']

##### deepST General

In [None]:
# deepST "general" sweep: no response-gene override is passed (the config
# defaults apply); after training, the mean absolute difference between
# `inputs` and `gene_expressions` is recorded for a few selected gene columns.
test_loss_rad_dict = {}

# for each radius value....
# NOTE: range(80, 90, 10) yields only radius 80.
for radius in range(80, 90, 10):

    # setup framework
    with initialize(config_path="../../config"):
        overrides_train = {
            "datasets": "MerfishDataset",
            "gpus": "[6]",
            "radius": radius,
            "training.logger_name": "figure4deepST_general"
        }
        overrides_train_list = [f"{k}={v}" for k, v in overrides_train.items()]
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)

        # complete training
        model, trainer = train(cfg_from_terminal)

        # run that model on testing data
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        for single_gene in [[93], [116], [142], [151]]:
            # torch.abs keeps the whole computation in torch; np.abs would
            # round-trip the tensor through NumPy, which fails for tensors
            # that live on a GPU or require grad.
            abs_error = torch.abs((inputs - gene_expressions)[:, single_gene[0]])
            test_loss_rad_dict[(radius, single_gene[0])] = torch.mean(abs_error).item()

In [None]:
# Show the per-gene test losses, keyed by (radius, gene index).
test_loss_rad_dict

In [None]:
# in the event you can't train in one go:
# evaluation-only pass over every radius (no training happens here); results
# are added to the existing `test_loss_rad_dict`, keyed by (radius, gene index).
for radius in range(0, 90, 10):
    # setup framework
    with initialize(config_path="../../config"):
        overrides_train = {
            "datasets": "MerfishDataset",
            "gpus": "[6]",
            "radius": radius,
            "training.logger_name": "figure4deepST_general"
        }
        overrides_train_list = [f"{k}={v}" for k, v in overrides_train.items()]
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)

        # run the model for this configuration on the testing data
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        for single_gene in [[93], [116], [142], [151]]:
            # torch.abs keeps the whole computation in torch; np.abs would
            # round-trip the tensor through NumPy, which fails for tensors
            # that live on a GPU or require grad.
            abs_error = torch.abs((inputs - gene_expressions)[:, single_gene[0]])
            test_loss_rad_dict[(radius, single_gene[0])] = torch.mean(abs_error).item()

# Table 4

- [ ] Store training loss in dictionary.
- [ ] Store validation loss in dictionary.
- [ ] Perform the same analysis for LightGBM

# deepST

In [None]:
# Table 4 (deepST): sweep the radius and collect validation / test losses per
# radius. The train-loss dict is created here but only filled if the
# commented-out line below is re-enabled.
deepST_train_loss_rad_dict = {}
deepST_val_loss_rad_dict = {}
deepST_test_loss_rad_dict = {}

for radius in range(0, 90, 10):

    # command-line style hydra overrides for this radius
    overrides_train = {
        "datasets": "MerfishDataset",
        "gpus": "[5]",
        "radius": radius,
        "training.logger_name": "figure4deepST"
    }
    overrides_train_list = ["{}={}".format(key, value) for key, value in overrides_train.items()]

    with initialize(config_path="../../config"):
        cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)

        # full training run
        model, trainer = train(cfg_from_terminal)
        # uncomment the line below to record the training loss at the END of training
        # deepST_train_loss_rad_dict[radius] = trainer.logged_metrics['train_loss: mae']

        # validate-only pass (loads the model with the lowest validation loss)
        validation_setup, val_trainer = train(cfg_from_terminal, validate_only=True)
        deepST_val_loss_rad_dict[radius] = val_trainer.logged_metrics['val_loss']

        # evaluate on the test split and keep the reported test loss
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        deepST_test_loss_rad_dict[radius] = test_results[0]['test_loss']

In [None]:
# Show the per-radius deepST losses; the sweep cell stores them under the
# deepST_-prefixed names, not the unprefixed ones.
deepST_train_loss_rad_dict, deepST_val_loss_rad_dict, deepST_test_loss_rad_dict

# LightGBM

In [None]:
# Names of the genes to predict; one gradient-boosting model is fit per gene
# and per radius.
response_genes=['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2']

import time
import json
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor

# mean absolute error averaged over all response genes, keyed by radius
lightgbm_train_loss_dict = {}
lightgbm_test_loss_dict = {}

# for each radius value....
for radius in range(0, 90, 10):

    train_loss_list = []
    test_loss_list = []

    ad = anndata.read_h5ad(h5ad_location)

    # "rad": connect cells that lie within `radius` of each other.
    # "neighbors": connect each cell to its `nneigh` nearest neighbours
    # (`nneigh` is not defined in this cell and must exist already).
    mode = "rad"

    # Edge lists are collected per tissue and concatenated once at the end.
    row_chunks = [np.zeros(0, dtype=int)]
    col_chunks = [np.zeros(0, dtype=int)]

    for tid in tqdm.notebook.tqdm(np.unique(ad.obs['Tissue_ID'])):
        good = ad.obs['Tissue_ID'] == tid
        # Global row numbers of this tissue's cells; needed to translate the
        # tissue-local neighbour indices below, so compute them first.
        idxs = np.where(good)[0]
        pos = np.array(ad.obs[good][['Centroid_X', 'Centroid_Y']])

        if mode == "neighbors":
            if nneigh == 0:
                E = csr_matrix(np.eye(pos.shape[0]))
            else:
                E = sklearn.neighbors.kneighbors_graph(pos, nneigh, mode='connectivity')
            E = E.tocoo()
            local_row, local_col = E.row, E.col
        elif mode == "rad":
            p = sp.spatial.cKDTree(pos)
            # One list of tissue-local neighbour indices per cell. Each cell is
            # within distance 0 of itself, so it appears in its own list.
            # NOTE(review): confirm construct_problem expects self-edges.
            neighbor_lists = p.query_ball_point(pos, r=radius, return_sorted=False)
            counts = np.array([len(n) for n in neighbor_lists], dtype=int)
            local_row = np.repeat(np.arange(len(neighbor_lists)), counts)
            local_col = np.concatenate([np.asarray(n, dtype=int) for n in neighbor_lists])
        else:
            raise ValueError(f"unknown connectivity mode {mode!r}")

        # Previously the radius neighbourhoods were computed but never added
        # to row/col, which left the connectivity matrix empty for every radius.
        row_chunks.append(idxs[local_row])
        col_chunks.append(idxs[local_col])

    row = np.concatenate(row_chunks)
    col = np.concatenate(col_chunks)

    E = sp.sparse.coo_matrix((np.ones(len(col)), (row, col)), shape=(len(ad), len(ad))).tocsr()
    if mode == "neighbors":
        anndata.AnnData(E).write_h5ad(connectivity_matrix_template % nneigh)
    if mode == "rad":
        anndata.AnnData(E).write_h5ad(connectivity_matrix_template % radius)

    # load data (`ad` was already read above and is unchanged, so it is not
    # reloaded; the connectivity matrix is read back from the file just written)
    if mode == "neighbors":
        connectivity_matrix = anndata.read_h5ad(connectivity_matrix_template % nneigh).X
    if mode == "rad":
        connectivity_matrix = anndata.read_h5ad(connectivity_matrix_template % radius).X
    gene_lookup = {x: i for (i, x) in enumerate(ad.var.index)}

    # NOTE: pickle.load executes arbitrary code; only use a trusted genetypes file.
    with open(genetypes_location, 'rb') as f:
        genetypes = pickle.load(f)

    for target_gene in response_genes:
        neighset = genetypes['ligands']
        oset = np.r_[genetypes['ligands'], genetypes['receptors']]

        # construct_problem is defined elsewhere in this notebook.
        # NOTE(review): it presumably reads `ad` / `connectivity_matrix` as globals — confirm.
        trainX, trainY, feature_names = construct_problem((ad.obs['Animal_ID'] <= 30), target_gene, neighset, oset, True)
        testX, testY, feature_names = construct_problem((ad.obs['Animal_ID'] > 30), target_gene, neighset, oset, True)

        # whiten covariates using training statistics only
        mu = np.mean(trainX, axis=0)
        sig = np.std(trainX, axis=0)
        # constant columns would otherwise produce 0/0 = NaN
        sig = np.where(sig == 0, 1.0, sig)
        trainX = (trainX - mu) / sig
        testX = (testX - mu) / sig

        model = HistGradientBoostingRegressor(loss="absolute_error")
        model.fit(trainX, trainY)
        train_loss_list.append(np.mean(np.abs(model.predict(trainX) - trainY)))
        test_loss_list.append(np.mean(np.abs(model.predict(testX) - testY)))
        print(f"Radius {radius}, Gene {target_gene} done.")

    lightgbm_train_loss_dict[radius] = np.mean(train_loss_list)
    lightgbm_test_loss_dict[radius] = np.mean(test_loss_list)

# report the dictionaries this cell actually filled
print(lightgbm_train_loss_dict, lightgbm_test_loss_dict)

In [None]:
# Hard-coded losses, keyed by radius.
# NOTE(review): presumably copied from earlier runs of the sweeps — confirm they are current.
lightgbm_train = {0: 0.36687154522253007, 10: 0.36704256779382094, 20: 0.3669286524358801, 30: 0.36704890036396814, 40: 0.3671144034553719, 50: 0.36696370154104097, 60: 0.36694528763440054, 70: 0.3670999470507582, 80: 0.3669915129489699}
lightgbm_test = {0: 0.3930677145164025, 10: 0.3932398374054737, 20: 0.39316476258093197, 30: 0.3931982199864055, 40: 0.3932709117850623, 50: 0.39314322041464844, 60: 0.3932798993305558, 70: 0.39334263072139725, 80: 0.3931518974793018}
deepST_train = {0: 0.3521350920200348, 10: 0.35147982835769653, 20: 0.3492688536643982, 30: 0.3481127917766571, 40: 0.34752157330513, 50: 0.34670382738113403, 60: 0.3451015055179596, 70: 0.3450486958026886, 80: 0.3453604280948639}
deepST_test = {0: 0.3500196635723114, 10: 0.35144540667533875, 20: 0.3483979403972626, 30: 0.3469792604446411, 40: 0.3454304039478302, 50: 0.3444381654262543, 60: 0.34396353363990784, 70: 0.34458568692207336, 80: 0.34435102343559265}

# One two-column frame (Radius, loss) per model/split.
lightgbm_train_df = pd.DataFrame(list(lightgbm_train.items()), columns=['Radius', 'LightGBM Train Loss'])
lightgbm_test_df = pd.DataFrame(list(lightgbm_test.items()), columns=['Radius', 'LightGBM Test Loss'])
# Built from deepST_train; it was previously built from deepST_test, which
# made the train column a copy of the test column.
deepST_train_df = pd.DataFrame(list(deepST_train.items()), columns=['Radius', 'deepST Train Loss'])
deepST_test_df = pd.DataFrame(list(deepST_test.items()), columns=['Radius', 'deepST Test Loss'])

# Join the four frames on Radius into one wide table.
results_df = lightgbm_train_df.merge(lightgbm_test_df, on="Radius").merge(deepST_train_df, on="Radius").merge(deepST_test_df, on="Radius")

In [None]:
# Show the merged per-radius loss table.
results_df

In [None]:
# Plot every loss column of results_df against the radius on shared axes.
sns.set_theme("talk")
sns.set_style("whitegrid")
for loss_column in ["LightGBM Train Loss", "LightGBM Test Loss", "deepST Train Loss", "deepST Test Loss"]:
    # label each line so the legend can tell the four curves apart
    sns.lineplot(data=results_df, x="Radius", y=loss_column, marker='o', label=loss_column)
# Without this the axis keeps the name of the last plotted column only.
plt.ylabel("Loss")
plt.legend()
_ = plt.show()

# 0 vs 60 Radius Table

In [None]:
# Integer indices of the response genes (84 entries, ascending). Later cells
# use them to index `data.columns` and as hydra response-gene overrides.
response_genes = [
    0, 2, 3, 4, 5, 6, 7, 10, 19, 20, 21, 22,
    23, 24, 25, 26, 27, 28, 32, 34, 35, 37, 38, 39,
    40, 41, 42, 43, 44, 52, 53, 54, 55, 58, 63, 64,
    66, 67, 69, 71, 73, 74, 75, 76, 77, 78, 79, 80,
    85, 86, 87, 88, 93, 94, 96, 97, 99, 102, 103, 104,
    106, 110, 112, 113, 114, 116, 118, 119, 120, 121, 122, 123,
    124, 125, 126, 129, 130, 131, 133, 134, 141, 142, 147, 151,
]

In [None]:
# in the event you can't train in one go
test_loss_dict = {}
for response_gene in response_genes:
    for radius in [0, 60]:
        # setup framework
        with initialize(config_path="../../config"):
            overrides_train = {
                "datasets": "MerfishDataset",
                "gpus": "[6]",
                "radius": radius,
                "training.logger_name": "zero_vs_sixty",
                "model.kwargs.response_genes": [response_gene]
            }
            overrides_train_list = [f"{k}={v}" for k, v in overrides_train.items()]
            cfg_from_terminal = compose(config_name="config", overrides=overrides_train_list)

            # complete training
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            test_loss_dict[(radius, response_gene)] = test_results[0]['test_loss']

In [None]:
# Show the test losses, keyed by (radius, response gene index).
test_loss_dict

In [None]:
# Read the raw MERFISH table, remove the blank-control and Fos columns, then
# discard the first nine remaining columns.
# NOTE(review): the first nine columns are presumably metadata — confirm.
dropped_columns = ["Blank_1", "Blank_2", "Blank_3", "Blank_4", "Blank_5", "Fos"]
data = (
    pd.read_csv("../data/raw/merfish.csv")
    .drop(columns=dropped_columns)
    .iloc[:, 9:]
)

In [None]:
# Re-key the losses from (radius, gene index) to (radius, column name in `data`).
test_loss_dict_with_names = {
    (radius, data.columns[gene_index]): loss
    for (radius, gene_index), loss in test_loss_dict.items()
}

In [None]:
# Show the test losses, keyed by (radius, column name taken from `data`).
test_loss_dict_with_names

In [None]:
# Only response genes with index <= 119 take part in the comparison.
compared_genes = [gene for gene in response_genes if gene <= 119]

# (radius-0 loss, radius-60 loss) pairs, one per compared gene
zeros_vs_sixties = [
    (test_loss_dict[(0, gene)], test_loss_dict[(60, gene)])
    for gene in compared_genes
]

# column names for every response gene (not restricted to the compared ones)
names = [data.columns[gene] for gene in response_genes]

# radius-0 loss minus radius-60 loss, keyed by column name
gene_diff_dict = {
    data.columns[gene]: zero_loss - sixty_loss
    for gene, (zero_loss, sixty_loss) in zip(compared_genes, zeros_vs_sixties)
}

In [None]:
# Show, per column name, the radius-0 test loss minus the radius-60 test loss.
gene_diff_dict

In [None]:
import json

# Persist the per-gene loss differences as JSON in the working directory.
serialized_diffs = json.dumps(gene_diff_dict)
with open('gene_diffs.json', 'w') as out_file:
    out_file.write(serialized_diffs)

In [None]:
# Scatter the radius-0 loss (x) against the radius-60 loss (y) for each gene;
# the red line is y = x.
zeros_vs_sixties = np.array(zeros_vs_sixties)
zero_losses = zeros_vs_sixties[:, 0]
sixty_losses = zeros_vs_sixties[:, 1]
plt.scatter(zero_losses, sixty_losses, marker="x")
plt.axline((0, 0), slope=1, color="red")
plt.xlabel("Zeros")
plt.ylabel("Sixties")

In [None]:
# Per-gene loss difference (radius 0 minus radius 60) and its histogram.
differences = [zero_loss - sixty_loss for zero_loss, sixty_loss in zeros_vs_sixties]
_ = plt.hist(differences, bins=100)

In [None]:
# Percentage by which the test loss drops from radius 0 to radius 60, keyed by
# column name; only response genes with index <= 119 are included.
gene_diff_percent_dict = {
    data.columns[gene]: 100*(test_loss_dict[(0, gene)] - test_loss_dict[(60, gene)])/test_loss_dict[(0, gene)]
    for gene in response_genes
    if gene <= 119
}

In [None]:
# Show, per column name, the percentage drop in test loss from radius 0 to radius 60.
gene_diff_percent_dict

In [None]:
# Histogram of the per-gene percentage reduction in loss (radius 0 -> radius 60).
percent_differences = [
    100*(zero_loss - sixty_loss)/zero_loss
    for zero_loss, sixty_loss in zeros_vs_sixties
]
_ = plt.hist(percent_differences, bins=25)
# plt.title("Percent reduction in MAE for genes \n in the MERFISH hypothalamus dataset", fontsize=14)
plt.xlabel("% Reduction in MAE", fontsize=12)
plt.ylabel("Number of Genes", fontsize=12)

# Gene labels at hand-picked data coordinates.
for gene_label, label_xy in [("Ebf3", (3.05, 1.2)), ("Ermn", (4.6, 1.2)), ("Cpne5", (3.6, 1.2))]:
    plt.annotate(gene_label, label_xy, fontsize=10)

# Write the figure to disk before it is shown.
plt.savefig("MLCB.png", dpi=300)
plt.gcf().set_dpi(300)
_ = plt.show()

In [None]:
# Five-number summary of `differences`: min, lower quartile, median, upper quartile, max.
min(differences), np.quantile(differences, 0.25), np.median(differences), np.quantile(differences, 0.75), max(differences)

In [None]:
# NOTE(review): ratio of two hard-coded numbers; presumably a loss difference
# divided by a radius-0 loss — confirm where these values came from.
0.012604475021362305/0.32853397727012634