In [1]:
import os

# Limit the native math libraries to a single thread. These variables are
# typically read once, when the OpenMP / MKL / numexpr runtimes initialise,
# so they are set BEFORE torch is imported; setting them afterwards (as this
# cell previously did) may have no effect.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import torch  # noqa: E402  (deliberately imported after the env vars are set)

import sys

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset, SyntheticDataset0, SyntheticDataset1, SyntheticDataset2, SyntheticDataset3
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder, MonetDense
from spatial.train import train
from spatial.predict import test


KeyboardInterrupt



# Testing Any Model

In [None]:
# Load the MESSI ligand/receptor table and the raw MERFISH expression table,
# then seed the ligand / receptor bookkeeping lists.
import pandas as pd

# MESSI workbook: only the "All.Pairs" sheet is used.
df_file = pd.ExcelFile("~/spatial/data/messi.xlsx")
messi_df = pd.read_excel(df_file, sheet_name="All.Pairs")

# MERFISH table with the five control columns and Fos removed.
merfish_df = pd.read_csv("~/spatial/data/raw/merfish.csv").drop(
    columns=['Blank_1', 'Blank_2', 'Blank_3', 'Blank_4', 'Blank_5', 'Fos']
)

# these are the 13 ligands or receptors found in MESSI
non_response_genes = [
    'Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165', 'Glra3',
    'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt', 'Tac2',
]
# this list stores the control genes aka "Blank_{int}"
blank_genes = []

# Every non-response gene is classified as either a ligand or a receptor;
# the entries below are the 13 seed genes split by role.
ligands = ["Cbln1", "Cxcl14", "Cbln2", "Vgf", "Scg2", "Cartpt", "Tac2"]
receptors = ["Crhbp", "Gabra1", "Gpr165", "Glra3", "Gabrg1", "Adora2a"]

# Positions of the seed genes, expressed as (column position - 9).
# NOTE(review): the offset 9 presumably skips leading metadata columns of the
# MERFISH table -- confirm against the dataset layout.
merfish_column_names = list(merfish_df.columns)
non_response_indeces = [
    merfish_column_names.index(name) - 9 for name in non_response_genes
]
ligand_indeces = [merfish_column_names.index(name) - 9 for name in ligands]
receptor_indeces = [merfish_column_names.index(name) - 9 for name in receptors]

# MESSI columns that are scanned for additional ligands / receptors.
all_pairs_columns = ["Ligand.ApprovedSymbol", "Receptor.ApprovedSymbol"]


# Extend the seed lists with every MERFISH gene whose upper-cased name appears
# in one of the MESSI ligand / receptor columns, and collect control columns.
#
# Fix: the "already recognised" guard used to test `gene.upper()` against
# `non_response_genes`, which stores names in their original MERFISH casing,
# so the guard was (almost) always true and genes were appended repeatedly
# (seed genes again, and any gene listed under both MESSI columns twice).
# The guard now compares the gene name exactly as stored.
merfish_column_names = list(merfish_df.columns)  # built once, not per gene

for column in all_pairs_columns:
    # Set membership gives the same answer as the former list scan.
    messi_symbols = set(messi_df[column])
    is_ligand_column = column[0] == "L"

    for gene in merfish_column_names:
        if gene.upper() in messi_symbols and gene not in non_response_genes:
            # Same (column position - 9) convention as the seed indices.
            gene_index = merfish_column_names.index(gene) - 9
            non_response_genes.append(gene)
            non_response_indeces.append(gene_index)
            if is_ligand_column:
                ligands.append(gene)
                ligand_indeces.append(gene_index)
            else:
                receptors.append(gene)
                receptor_indeces.append(gene_index)

        # NOTE(review): Blank_1..Blank_5 are dropped from merfish_df when it is
        # loaded, so only other "Blank"-prefixed columns (if any) can match here.
        if gene[:5] == "Blank" and gene not in blank_genes:
            blank_genes.append(gene)

# Report how the gene columns were partitioned.
print(non_response_genes)

n_non_response = len(non_response_genes)
n_blank = len(blank_genes)
# 155 is the hard-coded total number of gene columns used throughout this cell.
n_response = 155 - n_blank - n_non_response

print(
    f"There are {n_non_response} genes recognized as either ligands or receptors (including new ones)."
)
print(f"There are {n_blank} blank genes.")
print(f"There are {n_response} genes that are treated as response variables.")
print(f"There are {len(ligands)} ligands.")
print(f"There are {len(receptors)} receptors.")

# Every index in 0..154 that was not marked as non-response.
response_indeces = list(set(range(155)) - set(non_response_indeces))

In [None]:
import hydra
from hydra.experimental import compose, initialize

# Test loss for every (radius, test animal) combination, keyed by that pair.
test_loss_rad_dict = {}

for rad in range(0, 80, 10):
    for test_animal in [1, 2, 3, 4]:
        # Config entries that differ from the defaults for this run.
        run_overrides = (
            ("training.logger_name", "table2"),
            ("radius", rad),
            ("datasets.dataset.test_animal", test_animal),
        )
        with initialize(config_path="../config"):
            cfg_from_terminal = compose(config_name="config")
            for cfg_key, cfg_value in run_overrides:
                OmegaConf.update(cfg_from_terminal, cfg_key, cfg_value)
            output = test(cfg_from_terminal)
            (
                trainer,
                l1_losses,
                inputs,
                gene_expressions,
                celltypes,
                test_results,
            ) = output
            test_loss_rad_dict[(rad, test_animal)] = test_results[0]['test_loss']

In [None]:
# Display the collected test losses (the cell's last expression is rendered).
test_loss_rad_dict

In [None]:
import hydra
from hydra.experimental import compose, initialize

# Test loss for every (radius, synthetic experiment) combination.
#
# Fix: results used to be stored under `rad` alone, so each pass of the inner
# loop overwrote the previous one and only synthetic_exp == 3 survived. They
# are now keyed by (rad, synthetic_exp), mirroring the (rad, test_animal)
# keys used by the animal sweep.
test_loss_rad_dict = {}

for rad in range(0, 50, 5):
    for synthetic_exp in range(4):
        with initialize(config_path="../config"):
            cfg_from_terminal = compose(config_name="config")
            OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
            OmegaConf.update(cfg_from_terminal, "training.logger_name", f"synthetic{synthetic_exp}")
            OmegaConf.update(cfg_from_terminal, "radius", rad)
            OmegaConf.update(cfg_from_terminal, "model.kwargs.response_genes", [0])
            # NOTE(review): the dataset override is the same [MerfishDataset]
            # for every synthetic_exp; only the logger name varies. Confirm
            # whether SyntheticDataset0..3 were intended here, and whether the
            # config accepts a class object as a value.
            OmegaConf.update(cfg_from_terminal, "datasets.dataset", [MerfishDataset])
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            test_loss_rad_dict[(rad, synthetic_exp)] = test_results[0]['test_loss']

In [None]:
# Display the collected test losses (the cell's last expression is rendered).
test_loss_rad_dict

In [None]:
import hydra
from hydra.experimental import compose, initialize

# Test loss per radius for the model configured with response gene 93.
test_loss_rad_dict_response = {}

for rad in range(0, 50, 10):
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config")

        # Apply the fixed overrides first; the file path below is built from
        # the values they set.
        for cfg_key, cfg_value in (
            ("model.kwargs.hidden_dimensions", [256, 256]),
            ("training.logger_name", "gene93"),
            ("radius", rad),
            ("model.kwargs.response_genes", [93]),
        ):
            OmegaConf.update(cfg_from_terminal, cfg_key, cfg_value)

        filepath = f"{cfg_from_terminal.model.name}__{cfg_from_terminal.model.kwargs.hidden_dimensions}__{cfg_from_terminal.radius}__{cfg_from_terminal.training.logger_name}"
        OmegaConf.update(cfg_from_terminal, "training.filepath", filepath)

        output = test(cfg_from_terminal)
        (
            trainer,
            l1_losses,
            inputs,
            gene_expressions,
            celltypes,
            test_results,
        ) = output
        test_loss_rad_dict_response[rad] = test_results[0]['test_loss']
        # Unconditional break: only the first radius (0) is ever evaluated.
        break

In [None]:
# Mean absolute difference between column 93 of `inputs` and of `gene_expressions`.
(inputs[:, 93] - gene_expressions[:, 93]).abs().mean()

In [None]:
# Display the per-radius test losses (the cell's last expression is rendered).
test_loss_rad_dict_response

General deepST for Individual Gene Predictions

In [None]:
import hydra
from hydra.experimental import compose, initialize

# Per-radius results for the general model:
#   test_loss_rad_dict_response -- the reported 'test_loss'
#   test_loss_rad_dict_93       -- mean absolute error on column 93
#   test_loss_rad_dict_151      -- mean absolute error on column 151
test_loss_rad_dict_response = {}
test_loss_rad_dict_93 = {}
test_loss_rad_dict_151 = {}

for rad in range(0, 80, 10):
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config")
        OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
        OmegaConf.update(cfg_from_terminal, "training.logger_name", "False")
        OmegaConf.update(cfg_from_terminal, "radius", rad)
        # The file path is derived from the values set just above.
        OmegaConf.update(cfg_from_terminal, "training.filepath", f"{cfg_from_terminal.model.name}__{cfg_from_terminal.model.kwargs.hidden_dimensions}__{cfg_from_terminal.radius}__{cfg_from_terminal.training.logger_name}")
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        test_loss_rad_dict_response[rad] = test_results[0]['test_loss']
        test_loss_rad_dict_93[rad] = torch.mean(torch.abs(inputs[:, 93] - gene_expressions[:, 93]))
        # Fix: this used to subtract prediction column 151 from INPUT column 93
        # (copy-paste slip), which compares two different genes.
        test_loss_rad_dict_151[rad] = torch.mean(torch.abs(inputs[:, 151] - gene_expressions[:, 151]))

In [None]:
# Display both per-radius, per-gene loss dictionaries as a tuple.
test_loss_rad_dict_93, test_loss_rad_dict_151

In [None]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    for cfg_key, cfg_value in (
        ("model.kwargs.hidden_dimensions", [256, 256]),
        ("training.logger_name", "neighbors_large"),
        ("radius", 0),
    ):
        OmegaConf.update(cfg_from_terminal, cfg_key, cfg_value)

    output = test(cfg_from_terminal)
    (
        trainer,
        l1_losses,
        inputs,
        gene_expressions,
        celltypes,
        test_results,
    ) = output

    # Row indices of the cells whose cell-type code equals 6.
    excitatory_cells = torch.nonzero(celltypes == 6, as_tuple=True)[0]

    # Mean absolute (gene_expressions - inputs) over those rows, restricted to
    # the response-gene columns.
    excitatory_residuals = (gene_expressions - inputs)[excitatory_cells]
    response_residuals = excitatory_residuals.index_select(
        1, torch.tensor(response_indeces)
    )
    MAE_excitatory = response_residuals.abs().mean().item()

In [None]:
# Display the scalar computed in the previous cell.
MAE_excitatory

In [None]:
# Unpack the 6-tuple most recently returned by test(...) into named variables.
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [None]:
# Display the 'test_loss: mae_response' entry of the first test-result dict.
# NOTE(review): the sweep cells read the key 'test_loss' instead -- confirm
# which key the test run actually reports.
test_results[0]['test_loss: mae_response']

In [None]:
# equivalent to spatial
# Run test(...) on the composed "config" with no overrides applied.

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # `output` is left in the notebook namespace for later inspection.
    output = test(cfg_from_terminal)

# Testing Models with Updates

In [None]:
# equivalent to spatial
# Run test(...) with the dataset behaviours restricted to ["Parenting"].

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Parenting"])
    # `output` is unpacked and analysed by the cells that follow.
    output = test(cfg_from_terminal)

In [None]:
# Unpack the 6-tuple most recently returned by test(...) into named variables.
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [None]:
# Read the comma-separated non-response gene indices and derive the response
# indices as their complement within 0..159.
with open('../spatial/non_response.txt', "r") as genes_file:
    # Skip empty tokens so a trailing comma or newline does not break int().
    features = [int(x) for x in genes_file.read().split(",") if x.strip()]
# The with-statement closes the file; the former explicit close() was redundant.

# Sorted so the column order is deterministic rather than set-iteration order.
# NOTE(review): this uses 160 columns while the gene-classification cell uses
# range(155) -- confirm which width matches the tensors indexed with this.
response_indeces = torch.tensor(sorted(set(range(160)) - set(features)))

In [None]:
# Row indices of the cells whose cell-type code equals 6.
excitatory_cells = torch.nonzero(celltypes == 6, as_tuple=True)[0]

In [None]:
import torch

# L1 loss between `inputs` and `gene_expressions`, restricted to the selected
# cells (rows) and the response-gene columns.
loss = torch.nn.L1Loss()
selected_inputs = inputs[excitatory_cells].index_select(1, response_indeces)
selected_expressions = gene_expressions[excitatory_cells].index_select(1, response_indeces)
loss(selected_inputs, selected_expressions)

In [None]:
# equivalent to spatial
# Run test(...) with the dataset behaviours restricted to ["Virgin Parenting"].

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Virgin Parenting"])
    # `output` is unpacked and analysed by the cells that follow.
    output = test(cfg_from_terminal)

In [None]:
# Unpack the 6-tuple most recently returned by test(...) into named variables.
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [None]:
# Row indices of the cells whose cell-type code equals 6.
excitatory_cells = torch.nonzero(celltypes == 6, as_tuple=True)[0]

In [None]:
import torch

# L1 loss between `inputs` and `gene_expressions`, restricted to the selected
# cells (rows) and the response-gene columns.
loss = torch.nn.L1Loss()
selected_inputs = inputs[excitatory_cells].index_select(1, response_indeces)
selected_expressions = gene_expressions[excitatory_cells].index_select(1, response_indeces)
loss(selected_inputs, selected_expressions)

In [None]:
# equivalent to spatial
# Run test(...) with the dataset behaviours restricted to ["Naive"].

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Naive"])
    # `output` is unpacked and analysed by the cells that follow.
    output = test(cfg_from_terminal)

In [None]:
# Unpack the 6-tuple most recently returned by test(...) into named variables.
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [None]:
# Row indices of the cells whose cell-type code equals 6.
excitatory_cells = torch.nonzero(celltypes == 6, as_tuple=True)[0]

In [None]:
import torch

# L1 loss between `inputs` and `gene_expressions`, restricted to the selected
# cells (rows) and the response-gene columns.
loss = torch.nn.L1Loss()
selected_inputs = inputs[excitatory_cells].index_select(1, response_indeces)
selected_expressions = gene_expressions[excitatory_cells].index_select(1, response_indeces)
loss(selected_inputs, selected_expressions)