In [1]:
import os
import sys

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder
from spatial.train import train
from spatial.predict import test

In [2]:
# equivalent to spatial
# NOTE(review): presumably this means "equivalent to running the `spatial`
# entry point from the terminal" -- confirm.

import hydra
from hydra.experimental import compose, initialize

# Build the Hydra config named "config" from the ../config directory, then run
# the project's test routine with it while the Hydra context is still active.
with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # `output` is unpacked into its individual components in a later cell.
    output = test(cfg_from_terminal)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 1.4626601934432983}
--------------------------------------------------------------------------------


# MAE Imaging

In [3]:
import matplotlib.pyplot as plt

# Unpack the 5-tuple returned by test() in the previous cell.
# Later cells select columns of `inputs` and `gene_expressions` by gene index
# and read column 1 of `celltypes` as the cell-type label.
# NOTE(review): `gene_expressions` is presumably the autoencoder output for
# `inputs` (the later prints label it "AE") -- confirm against spatial.predict.
trainer, l1_losses, inputs, gene_expressions, celltypes = output

We can see in the source code that the cell types are listed as follows:

    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    
So, we keep only the cell types that MESSI uses to evaluate its performance. These would be the indices [1, 2, 6, 7, 8, 9, 10, 11] in the list above.

#### For a single sample.

# TODO: Fix this to be a boxplot over all expressions

In [4]:
# plt.boxplot(l1_losses.reshape(1,-1))

In [5]:
# plt.boxplot(l1_losses.reshape(1,-1)[:,:])

The code below simply identifies the response and non-response genes.

In [6]:
# Identify which MERFISH genes are "non-response" genes (MESSI ligands or
# receptors, plus the blank control probes) and record their gene indices.
# Produces: messi_df, merfish_df, non_response_genes, blank_genes,
# non_response_indeces (all reused by later cells).
import pandas as pd

# get relevant data stuff
df_file = pd.ExcelFile("~/spatial/data/messi.xlsx")
messi_df = pd.read_excel(df_file, "All.Pairs")
merfish_df = pd.read_csv("~/spatial/data/raw/merfish.csv")

# Offset subtracted from a MERFISH column position to get the gene index used
# to select tensor columns in later cells.
# NOTE(review): presumably the number of leading non-expression columns -- confirm.
GENE_COLUMN_OFFSET = 10
# Total number of gene columns assumed when counting the response genes.
NUM_GENES = 160

# Materialise the column list once instead of rebuilding it for every lookup.
merfish_columns = list(merfish_df.columns)


def _gene_index(gene):
    """Return the offset-corrected position of `gene` among the MERFISH columns."""
    return merfish_columns.index(gene) - GENE_COLUMN_OFFSET


# these are the 13 ligands or receptors found in MESSI
non_response_genes = ['Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165',
                      'Glra3', 'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt',
                      'Tac2']
# this list stores the control genes aka "Blank_{int}"
blank_genes = []
# ligands and receptor indexes in MERFISH
non_response_indeces = [_gene_index(gene) for gene in non_response_genes]
all_pairs_columns = [
    "Ligand.ApprovedSymbol",
    "Receptor.ApprovedSymbol",
]

# for column name in the column names above
for column in all_pairs_columns:
    # Hoisted out of the gene loop: the symbols in this column do not change.
    messi_symbols = set(messi_df[column])
    for gene in merfish_columns:
        # The gene name is upper-cased only for the lookup in the MESSI table.
        # The "already recorded" check must use the name exactly as stored in
        # non_response_genes; comparing the upper-cased name against that list
        # never matched, so a gene already seeded above, or present in both the
        # ligand and receptor columns, would have been appended twice.
        if gene.upper() in messi_symbols and gene not in non_response_genes:
            non_response_genes.append(gene)
            non_response_indeces.append(_gene_index(gene))
        # Blank control probes are tracked separately by name, but their
        # indices also go into the non-response index list.
        if gene[:5] == "Blank" and gene not in blank_genes:
            blank_genes.append(gene)
            non_response_indeces.append(_gene_index(gene))

print(non_response_genes)
print(
    f"There are {len(non_response_genes)}"
    " genes recognized as either ligands or receptors (including new ones)."
)

print(f"There are {len(blank_genes)} blank genes.")

# Everything that is neither a ligand/receptor nor a blank is a response gene.
print(
    f"There are {NUM_GENES - len(blank_genes) - len(non_response_genes)}"
    " genes that are treated as response variables."
)

  warn(msg)


['Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165', 'Glra3', 'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt', 'Tac2', 'Bdnf', 'Bmp7', 'Cyr61', 'Fn1', 'Fst', 'Gad1', 'Ntng1', 'Pnoc', 'Selplg', 'Sema3c', 'Sema4d', 'Serpine1', 'Adcyap1', 'Cck', 'Crh', 'Gal', 'Gnrh1', 'Nts', 'Oxt', 'Penk', 'Sst', 'Tac1', 'Trh', 'Ucn3', 'Avpr1a', 'Avpr2', 'Brs3', 'Calcr', 'Cckar', 'Cckbr', 'Crhr1', 'Crhr2', 'Galr1', 'Galr2', 'Grpr', 'Htr2c', 'Igf1r', 'Igf2r', 'Kiss1r', 'Lepr', 'Lpar1', 'Mc4r', 'Npy1r', 'Npy2r', 'Ntsr1', 'Oprd1', 'Oprk1', 'Oprl1', 'Oxtr', 'Pdgfra', 'Prlr', 'Ramp3', 'Rxfp1', 'Slc17a7', 'Slc18a2', 'Tacr1', 'Tacr3', 'Trhr']
There are 71 genes recognized as either ligands or receptors (including new ones).
There are 5 blank genes.
There are 84 genes that are treated as response variables.


The code below generates some MAE values separated by response status.

In [7]:
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F

# Gene (column) index tensors, built once and shared by the original inputs
# and the autoencoder outputs.
non_response_idx = torch.tensor(non_response_indeces)
response_idx = torch.tensor(list(set(range(160)) - set(non_response_indeces)))

# Original expressions, split by response status.
non_response_inputs = inputs.index_select(1, non_response_idx)
response_inputs = inputs.index_select(1, response_idx)
print(f"Initial Response Expression Mean: {response_inputs.mean().item()}")
print(f"Initial Non-Response Expression Mean: {non_response_inputs.mean().item()}")

# Autoencoder outputs, split the same way.
non_response_expressions = gene_expressions.index_select(1, non_response_idx)
response_expressions = gene_expressions.index_select(1, response_idx)
print(f"AE Response Expression Mean: {response_expressions.mean().item()}")
print(f"AE Non-Response Expression Mean: {non_response_expressions.mean().item()}")

# L1 distance between the per-gene means (averaged over rows) of the inputs
# and of the outputs, over all genes.
mean_MAE_L1 = F.l1_loss(inputs.mean(dim=0), gene_expressions.mean(dim=0)).item()
print(f"Average L1 Loss for Mean {mean_MAE_L1}")

### L1 Losses
# response_gene_boxplot = abs(inputs.mean(axis=0) - gene_expressions.mean(axis=0))
# print(pd.DataFrame(response_gene_boxplot).describe())
# plt.boxplot(response_gene_boxplot.reshape(1,-1))
# _ = plt.show()

# The same quantity restricted to the response genes ...
mean_MAE_response_L1 = F.l1_loss(
    response_inputs.mean(dim=0),
    response_expressions.mean(dim=0),
)
print(f"Average L1 Loss for Mean among Response Genes: {mean_MAE_response_L1.item()}")

# ... and to the non-response genes.
mean_MAE_non_response_L1 = F.l1_loss(
    non_response_inputs.mean(dim=0),
    non_response_expressions.mean(dim=0),
)
print(f"Average L1 Loss for Mean among Non-Response Genes: {mean_MAE_non_response_L1.item()}")

Initial Response Expression Mean: 1.7898054122924805
Initial Non-Response Expression Mean: 1.3080828189849854
AE Response Expression Mean: 0.1979086995124817
AE Non-Response Expression Mean: 0.1586369425058365
Average L1 Loss for Mean 1.3821319341659546
Average L1 Loss for Mean among Response Genes: 1.5918967723846436
Average L1 Loss for Mean among Non-Response Genes: 1.1502866744995117


Now we do the same thing, but restricted to the cell types we care about.

In [11]:
# Cell-type labels to keep (the ones listed in the markdown above as used by MESSI).
messi_celltypes = [1, 2, 6, 7, 8, 9, 10, 11]

# Row positions of the cells whose label (column 1 of `celltypes`) is kept.
# Membership is tested against a list on purpose: the labels are compared by
# value (==), which a hash-based container would not do for tensor elements.
cells = []
for row, label in enumerate(celltypes[:, 1]):
    if label in messi_celltypes:
        cells.append(row)

In [9]:
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F

# filtering the inputs this time: keep only the rows selected in `cells`
cell_idx = torch.tensor(cells)
filtered_inputs = inputs.index_select(0, cell_idx)
filtered_gene_expressions = gene_expressions.index_select(0, cell_idx)

# Gene (column) index tensors, built once and shared by inputs and outputs.
non_response_idx = torch.tensor(non_response_indeces)
response_idx = torch.tensor(list(set(range(160)) - set(non_response_indeces)))

# Original expressions of the kept cells, split by response status.
non_response_inputs = filtered_inputs.index_select(1, non_response_idx)
response_inputs = filtered_inputs.index_select(1, response_idx)
print(f"Initial Response Expression Mean: {response_inputs.mean().item()}")
print(f"Initial Non-Response Expression Mean: {non_response_inputs.mean().item()}")

# Autoencoder outputs of the kept cells, split the same way.
non_response_expressions = filtered_gene_expressions.index_select(1, non_response_idx)
response_expressions = filtered_gene_expressions.index_select(1, response_idx)
print(f"AE Response Expression Mean: {response_expressions.mean().item()}")
print(f"AE Non-Response Expression Mean: {non_response_expressions.mean().item()}")

# L1 distance between the per-gene means of the filtered inputs and outputs.
mean_MAE_L1 = F.l1_loss(
    filtered_inputs.mean(dim=0),
    filtered_gene_expressions.mean(dim=0),
).item()
print(f"Average L1 Loss for Mean {mean_MAE_L1}")

# response_gene_boxplot = abs(filtered_inputs.mean(axis=0) - filtered_gene_expressions.mean(axis=0))
# print(pd.DataFrame(response_gene_boxplot).describe())
# plt.boxplot(response_gene_boxplot.reshape(1,-1))
# _ = plt.show()

# The same quantity restricted to the response genes ...
mean_MAE_response_L1 = F.l1_loss(
    response_inputs.mean(dim=0),
    response_expressions.mean(dim=0),
)
print(f"Average L1 Loss for Mean among Response Genes: {mean_MAE_response_L1.item()}")

# ... and to the non-response genes.
mean_MAE_non_response_L1 = F.l1_loss(
    non_response_inputs.mean(dim=0),
    non_response_expressions.mean(dim=0),
)
print(f"Average L1 Loss for Mean among Non-Response Genes: {mean_MAE_non_response_L1.item()}")

Initial Response Expression Mean: 1.6762956380844116
Initial Non-Response Expression Mean: 1.413140892982483
AE Response Expression Mean: 0.20724879205226898
AE Non-Response Expression Mean: 0.17563262581825256
Average L1 Loss for Mean 1.3590658903121948
Average L1 Loss for Mean among Response Genes: 1.469046711921692
Average L1 Loss for Mean among Non-Response Genes: 1.2375082969665527


# Checklist of Things To Do:

- [x] Compare runs on tensorboard.
- [x] Ensure that outputs for every test example get saved (i.e. we should have a list of 31 tensors). <b>No need. The tensor gets concatenated. We could create a tensor of tensors, but it is unclear how helpful that would be.</b>
- [x] Understand which anid, bregma correspond to which mouse and figure out how MESSI dealt with this. <b>Most animals have multiple bregmas, and MESSI uses several animals (16-19) in their demo.</b>
- [x] Filter out outputs based on their respective celltype values (y).
- [ ] Implement XGBoost on the 8 celltypes with pip install xgboost.
- [x] Make batch size hydra parameter.
- [ ] Try multi-batch run with appropriate batch size and n_neighbors.
- [x] Change logger to only log MAE.
- [x] Add nbstripout to poetry.
- [x] Figure out how to label tensorboard folders.
- [x] Fix CUDA_VISIBLE_DEVICES <b>This is "fixed" but setting visible devices to more than 1 makes it "used" as a process even if a process doesn't get set to a GPU. Does being visible open up a process?</b>
- [ ] Explore different neighborhood methodologies.
    - [ ] Nearest Neighbor Graph
    - [ ] Delaunay Triangulation
    - [ ] Gabriel Graph
    - [ ] Urquhart Graph (this will be a very good comparison to Delaunay Triangulation)