# Workflow

Alternatively, write a single script that starts with a call to `FilteredMerfishDataset`, extracts the animals with the given behavior and sex, and then runs a cross-validation (CV).

In [1]:
import pandas as pd
import json
import time

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder
from spatial.train import train
from spatial.predict import test

import torch

import hydra
from hydra.experimental import compose, initialize

In [2]:
# Cohort selection for the cross-validation below: the evaluation loop visits
# every (behavior, sex) combination drawn from these two lists.
sexes = ["Female"]
behaviors = [
    "Parenting",
    "Virgin Parenting",
    "Naive",
]

In [3]:
# Read the animal-ID lookup from disk. The evaluation loop indexes it as
# animals[behavior][sex] and iterates over the value it gets back.
with open('animal_id.json') as animal_id_file:
    animals = json.load(animal_id_file)

In [4]:
# Leave-one-animal-out evaluation.
#
# For every (behavior, sex) cohort found in `animals`, each animal in turn is
# set as `datasets.dataset.test_animal`, a model is trained, and the test
# metrics are recorded. Results are keyed by "<[sex]>_<[behavior]>_<animal>"
# and re-written to disk after every animal, so an interrupted run keeps the
# results gathered so far.
loss_dict = {}
time_dict = {}
loss_excitatory_dict = {}
loss_inhibitory_dict = {}

# Integer labels compared against `celltypes`. 6 and 7 are the 0-based
# positions of "Excitatory" and "Inhibitory" in the cell_types list recorded
# in this notebook's markdown -- NOTE(review): confirm the dataset encodes
# cell types with those indices.
EXCITATORY_CELLTYPE = 6
INHIBITORY_CELLTYPE = 7

# Gene list passed to the dataset as `non_response_genes_file`.
NON_RESPONSE_GENES_FILE = "/home/roko/spatial/spatial/ligands_only.txt"

# Output file -> results dict, in the order the files are written.
RESULT_FILES = {
    "deepST_MAE_ligandsOnly.json": loss_dict,
    "deepST_time_ligandsOnly.json": time_dict,
    "deepST_MAE_excitatory_ligandsOnly.json": loss_excitatory_dict,
    "deepST_MAE_inhibitory_ligandsOnly.json": loss_inhibitory_dict,
}

for behavior in behaviors:
    for sex in sexes:
        try:
            animal_list = animals[behavior][sex]
        except KeyError:
            # `animals` has no entry for this behavior/sex combination.
            continue

        # The config is given one-element lists. Use separate names for them:
        # rebinding the loop variables `behavior`/`sex` to lists would make the
        # next inner iteration look up animals[[...]] and fail with
        # "TypeError: unhashable type: 'list'" as soon as `sexes` has more
        # than one entry.
        behavior_filter = [behavior]
        sex_filter = [sex]
        # print(behavior_filter, sex_filter, animal_list)

        for animal in animal_list:
            # Formatting the lists (not the bare strings) keeps the keys
            # identical to the ones produced by earlier runs of this notebook.
            run_key = f"{sex_filter}_{behavior_filter}_{animal}"

            start = time.time()
            with initialize(config_path="../config"):
                cfg_from_terminal = compose(config_name="config")
                # Point the base config at this cohort and held-out animal.
                OmegaConf.update(cfg_from_terminal, "datasets.dataset.non_response_genes_file", NON_RESPONSE_GENES_FILE)
                OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", behavior_filter)
                OmegaConf.update(cfg_from_terminal, "datasets.dataset.sexes", sex_filter)
                OmegaConf.update(cfg_from_terminal, "datasets.dataset.test_animal", animal)

                model = train(cfg_from_terminal)
                output = test(cfg_from_terminal)
                trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

                # Overall response-gene MAE as reported by the test run.
                MAE = test_results[0]['test_loss: mae_response']

                # Per-cell-type MAE: mean absolute difference between
                # `gene_expressions` and `inputs`, restricted to the rows of
                # one cell type and to the columns listed in `model.responses`.
                response_indices = torch.tensor(model.responses)
                residuals = gene_expressions - inputs

                excitatory_cells = (celltypes == EXCITATORY_CELLTYPE).nonzero(as_tuple=True)[0]
                MAE_excitatory = torch.abs(
                    torch.index_select(residuals[excitatory_cells], 1, response_indices)
                ).mean().item()

                inhibitory_cells = (celltypes == INHIBITORY_CELLTYPE).nonzero(as_tuple=True)[0]
                MAE_inhibitory = torch.abs(
                    torch.index_select(residuals[inhibitory_cells], 1, response_indices)
                ).mean().item()
            end = time.time()

            time_dict[run_key] = end - start
            loss_dict[run_key] = MAE
            loss_excitatory_dict[run_key] = MAE_excitatory
            loss_inhibitory_dict[run_key] = MAE_inhibitory

            # Re-write every results file after each animal (checkpointing).
            for result_path, results in RESULT_FILES.items():
                with open(result_path, "w") as outfile:
                    json.dump(results, outfile, indent=4)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (86902, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name            | Type                    | Params
------------------------------------------------------------
0 | encoder_network | DenseReluGMMConvNetwork | 2.8 M 
1 | decoder_network | DenseReluGMMConvNetwork | 2.8 M 
------------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.048    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 9, global step 149: val_loss reached 0.56027 (best 0.56027), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 299: val_loss reached 0.49573 (best 0.49573), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 449: val_loss reached 0.44104 (best 0.44104), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 39, global step 599: val_loss reached 0.42671 (best 0.42671), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 49, global step 749: val_loss reached 0.39717 (best 0.39717), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 59, global step 899: val_loss reached 0.39091 (best 0.39091), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 69, global step 1049: val_loss reached 0.38293 (best 0.38293), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 79, global step 1199: val_loss reached 0.37283 (best 0.37283), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__3__['Female']__['Parenting']__0.001__deepST_CV_16.ckpt" as top True
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


KeyboardInterrupt: 

## Graph Organized by Animal ID Verification

In order for a specific animal to be held out for testing, we need to understand exactly how the slices are stored before graph construction. We can do this by putting a breakpoint in `unique_slices`, observing what happens, and ensuring that the ordering is not random.

    # Cell-type names in list order; a name's 0-based position is its index.
    # NOTE(review): the evaluation loop in this notebook selects cells with
    # `celltypes == 6` and `celltypes == 7`, which line up with "Excitatory"
    # and "Inhibitory" here -- confirm the dataset uses this same ordering.
    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",  # index 6
        "Inhibitory",  # index 7
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]