In [1]:
import os
import torch

os.environ["MKL_NUM_THREADS"]="1"
os.environ["NUMEXPR_NUM_THREADS"]="1"
os.environ["OMP_NUM_THREADS"]="1"

import sys

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder, MonetDense
from spatial.train import train
from spatial.predict import test

# Testing Any Model

In [2]:
# read in merfish dataset and get columns names
import pandas as pd

# get relevant data stuff
df_file = pd.ExcelFile("~/spatial/data/messi.xlsx")
messi_df = pd.read_excel(df_file, "All.Pairs")
merfish_df = pd.read_csv("~/spatial/data/raw/merfish.csv")
merfish_df = merfish_df.drop(['Blank_1', 'Blank_2', 'Blank_3', 'Blank_4', 'Blank_5', 'Fos'], axis=1)

# these are the 13 ligands or receptors found in MESSI
non_response_genes = ['Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165', 
                      'Glra3', 'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt',
                      'Tac2']
# this list stores the control genes aka "Blank_{int}"
blank_genes = []

# we will populate all of the non-response genes as being in one or the other
# the ones already filled in come from the existing 13 L/R genes above
ligands = ["Cbln1", "Cxcl14", "Cbln2", "Vgf", "Scg2", "Cartpt", "Tac2"]
receptors = ["Crhbp", "Gabra1", "Gpr165", "Glra3", "Gabrg1", "Adora2a"]

# ligands and receptor indexes in MERFISH
non_response_indeces = [list(merfish_df.columns).index(gene)-9 for gene in non_response_genes]
ligand_indeces = [list(merfish_df.columns).index(gene)-9 for gene in ligands]
receptor_indeces = [list(merfish_df.columns).index(gene)-9 for gene in receptors]
all_pairs_columns = [
    "Ligand.ApprovedSymbol",
    "Receptor.ApprovedSymbol",
]


# for column name in the column names above
for column in all_pairs_columns:
    for gene in merfish_df.columns:
        if (
            gene.upper() in list(messi_df[column])
            and gene.upper() not in non_response_genes
        ):
            non_response_genes.append(gene)
            non_response_indeces.append(list(merfish_df.columns).index(gene)-9)
            if column[0] == "L":
                ligands.append(gene)
                ligand_indeces.append(list(merfish_df.columns).index(gene)-9)
            else:
                receptors.append(gene)
                receptor_indeces.append(list(merfish_df.columns).index(gene)-9)
        if gene[:5] == "Blank" and gene not in blank_genes:
            blank_genes.append(gene)
            # non_response_indeces.append(list(merfish_df.columns).index(gene)-9)

print(non_response_genes)
print(
    "There are "
    + str(len(non_response_genes))
    + " genes recognized as either ligands or receptors (including new ones)."
)

print(
    "There are "
    + str(len(blank_genes))
    + " blank genes."
)

print(
    "There are "
    + str(155 - len(blank_genes) - len(non_response_genes))
    + " genes that are treated as response variables."
)

print(
    "There are "
    + str(len(ligands))
    + " ligands."
)

print(
    "There are "
    + str(len(receptors))
    + " receptors."
)

response_indeces = list(set(range(155)) - set(non_response_indeces))

  warn(msg)


['Cbln1', 'Cxcl14', 'Crhbp', 'Gabra1', 'Cbln2', 'Gpr165', 'Glra3', 'Gabrg1', 'Adora2a', 'Vgf', 'Scg2', 'Cartpt', 'Tac2', 'Bdnf', 'Bmp7', 'Cyr61', 'Fn1', 'Fst', 'Gad1', 'Ntng1', 'Pnoc', 'Selplg', 'Sema3c', 'Sema4d', 'Serpine1', 'Adcyap1', 'Cck', 'Crh', 'Gal', 'Gnrh1', 'Nts', 'Oxt', 'Penk', 'Sst', 'Tac1', 'Trh', 'Ucn3', 'Avpr1a', 'Avpr2', 'Brs3', 'Calcr', 'Cckar', 'Cckbr', 'Crhr1', 'Crhr2', 'Galr1', 'Galr2', 'Grpr', 'Htr2c', 'Igf1r', 'Igf2r', 'Kiss1r', 'Lepr', 'Lpar1', 'Mc4r', 'Npy1r', 'Npy2r', 'Ntsr1', 'Oprd1', 'Oprk1', 'Oprl1', 'Oxtr', 'Pdgfra', 'Prlr', 'Ramp3', 'Rxfp1', 'Slc17a7', 'Slc18a2', 'Tacr1', 'Tacr3', 'Trhr']
There are 71 genes recognized as either ligands or receptors (including new ones).
There are 0 blank genes.
There are 84 genes that are treated as response variables.
There are 31 ligands.
There are 40 receptors.


In [7]:
import hydra
from hydra.experimental import compose, initialize

test_loss_rad_dict = {}

for rad in range(0,80,10):
    for test_animal in [1,2,3,4]:
        with initialize(config_path="../config"):
            cfg_from_terminal = compose(config_name="config")
            OmegaConf.update(cfg_from_terminal, "training.logger_name", "table2")
            OmegaConf.update(cfg_from_terminal, "radius", rad)
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.test_animal", test_animal)
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            test_loss_rad_dict[(rad, test_animal)] = test_results[0]['test_loss']

ConfigAttributeError: Key 'test_animal' is not in struct
    full_key: datasets.dataset.test_animal
    object_type=dict

In [9]:
test_loss_rad_dict

{(0, 1): 0.33348438143730164,
 (0, 2): 0.32344749569892883,
 (0, 3): 0.3152901828289032,
 (0, 4): 0.3326919674873352,
 (10, 1): 0.33578985929489136,
 (10, 2): 0.32577621936798096,
 (10, 3): 0.3188694715499878,
 (10, 4): 0.33603790402412415,
 (20, 1): 0.3356193006038666,
 (20, 2): 0.32561439275741577,
 (20, 3): 0.31718048453330994,
 (20, 4): 0.3340090215206146,
 (30, 1): 0.3336813747882843,
 (30, 2): 0.32789313793182373,
 (30, 3): 0.31558355689048767,
 (30, 4): 0.33445167541503906,
 (40, 1): 0.33460772037506104,
 (40, 2): 0.32464149594306946,
 (40, 3): 0.3162347376346588,
 (40, 4): 0.33553236722946167,
 (50, 1): 0.33392974734306335,
 (50, 2): 0.3266395628452301,
 (50, 3): 0.31530994176864624,
 (50, 4): 0.33241382241249084,
 (60, 1): 0.33370622992515564,
 (60, 2): 0.3268411457538605,
 (60, 3): 0.31519392132759094,
 (60, 4): 0.3317946195602417,
 (70, 1): 0.3324054479598999,
 (70, 2): 0.32449594140052795,
 (70, 3): 0.3154938519001007,
 (70, 4): 0.33066630363464355}

In [8]:
import hydra
from hydra.experimental import compose, initialize

test_loss_rad_dict = {}

for rad in range(0,90,10):
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config")
        OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
        OmegaConf.update(cfg_from_terminal, "training.logger_name", "gene151")
        OmegaConf.update(cfg_from_terminal, "radius", rad)
        OmegaConf.update(cfg_from_terminal, "model.kwargs.response_genes", [151])
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        test_loss_rad_dict[rad] = test_results[0]['test_loss']

InterpolationKeyError: Interpolation key 'datasets.dataset.test_animal' not found
    full_key: training.filepath
    object_type=dict

In [7]:
test_loss_rad_dict

{0: 0.0499076209962368,
 10: 0.0508580319583416,
 20: 0.05018540844321251,
 30: 0.05055781826376915,
 40: 0.05067247524857521,
 50: 0.05058332160115242,
 60: 0.05033117160201073,
 70: 0.05010917782783508,
 80: 0.05047113820910454}

In [10]:
import hydra
from hydra.experimental import compose, initialize

test_loss_rad_dict_response = {}

for rad in range(0,50,10):
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config")
        OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [256, 256])
        OmegaConf.update(cfg_from_terminal, "training.logger_name", "gene93")
        OmegaConf.update(cfg_from_terminal, "radius", rad)
        OmegaConf.update(cfg_from_terminal, "model.kwargs.response_genes", [93])
        OmegaConf.update(cfg_from_terminal, "training.filepath", f"{cfg_from_terminal.model.name}__{cfg_from_terminal.model.kwargs.hidden_dimensions}__{cfg_from_terminal.radius}__{cfg_from_terminal.training.logger_name}")
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        test_loss_rad_dict_response[rad] = test_results[0]['test_loss']
        break

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  new_batch_obj.x *= torch.tensor(responses)
TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  36.651         	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  33.403         	|1              	|  33.403         	|  91.137         	|
evaluation_step_and_end            	|  0.27882        	|96             	|  26.767         	|  73.033         	|
test_step                          	|  0.27869        	|96             	|  26.754         	|  72.998         	|
get_test_batch                     	|  0.049069       	|97             	|  4.7597         	|  12.98

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.5410650372505188, 'test_loss: mse': 0.5704374313354492}
--------------------------------------------------------------------------------


In [17]:
torch.mean(torch.abs(inputs[:, 93] - gene_expressions[:, 93]))

tensor(0.5411)

In [None]:
test_loss_rad_dict_response

General deepST for Individual Gene Predictions

In [3]:
import hydra
from hydra.experimental import compose, initialize

test_loss_rad_dict_response = {}
test_loss_rad_dict_93 = {}
test_loss_rad_dict_151 = {}

for rad in range(0,80,10):
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config")
        OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
        OmegaConf.update(cfg_from_terminal, "training.logger_name", "False")
        OmegaConf.update(cfg_from_terminal, "radius", rad)
        OmegaConf.update(cfg_from_terminal, "training.filepath", f"{cfg_from_terminal.model.name}__{cfg_from_terminal.model.kwargs.hidden_dimensions}__{cfg_from_terminal.radius}__{cfg_from_terminal.training.logger_name}")
        output = test(cfg_from_terminal)
        trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
        test_loss_rad_dict_response[rad] = test_results[0]['test_loss']
        test_loss_rad_dict_93[rad] = torch.mean(torch.abs(inputs[:, 93] - gene_expressions[:, 93]))
        test_loss_rad_dict_151[rad] = torch.mean(torch.abs(inputs[:, 93] - gene_expressions[:, 151]))

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  new_batch_obj.x *= torch.tensor(responses)
TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  35.988         	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  32.885         	|1              	|  32.885         	|  91.379         	|
evaluation_step_and_end            	|  0.28655        	|96             	|  27.509         	|  76.439         	|
test_step                          	|  0.28641        	|96             	|  27.496         	|  76.403         	|
get_test_batch                     	|  0.035479       	|97             	|  3.4414         	|  9.562

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.3506155014038086, 'test_loss: mse': 0.434160441160202}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing: 0it [00:00, ?it/s]

TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  31.21          	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  31.194         	|1              	|  31.194         	|  99.947         	|
evaluation_step_and_end            	|  0.26957        	|96             	|  25.879         	|  82.917         	|
test_step                          	|  0.26945        	|96             	|  25.867         	|  82.879         	|
get_test_batch                     	|  0.036855       	|97             	|  3.5749         	|  11.454         	|
fetch_next_test_batch           

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.35077738761901855, 'test_loss: mse': 0.43806537985801697}
--------------------------------------------------------------------------------
Filtered Data (109105, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


InstantiationException: Error in call to target 'spatial.merfish_dataset.FilteredMerfishDataset':
IndexError('list index out of range')
full_key: datasets.dataset

In [None]:
test_loss_rad_dict_93, test_loss_rad_dict_151

In [8]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [256, 256])
    OmegaConf.update(cfg_from_terminal, "training.logger_name", "neighbors_large")
    OmegaConf.update(cfg_from_terminal, "radius", 0)
    output = test(cfg_from_terminal)
    trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
    excitatory_cells = (celltypes == 6).nonzero(as_tuple=True)[0]
    MAE_excitatory = torch.abs(torch.index_select((gene_expressions-inputs)[excitatory_cells], 1, torch.tensor(response_indeces))).mean().item()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing: 0it [00:00, ?it/s]

TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  21.973         	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  21.839         	|1              	|  21.839         	|  99.39          	|
evaluation_step_and_end            	|  0.56187        	|31             	|  17.418         	|  79.27          	|
test_step                          	|  0.56175        	|31             	|  17.414         	|  79.253         	|
get_test_batch                     	|  0.10323        	|32             	|  3.3033         	|  15.034         	|
fetch_next_test_batch           

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.08934661746025085,
 'test_loss: mae_response': 0.12857602536678314,
 'test_loss: mse': 3.003896951675415}
--------------------------------------------------------------------------------


In [4]:
MAE_excitatory

0.15488235652446747

In [3]:
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [4]:
test_results[0]['test_loss: mae_response']

0.33515501022338867

In [3]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    output = test(cfg_from_terminal)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (205348, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  rank_zero_warn(
TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  8.7991         	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  8.7592         	|1              	|  8.7592         	|  99.547         	|
evaluation_step_and_end            	|  0.57904        	|12             	|  6.9484         	|  78.968         	|
test_step                          	|  0.56331        	|12             	|  6.7597         	|  76.823         	|
get_test_batch                     	|  0.093008       	|13             	|  1.2091         	|  13.741         	|
fetch_next_tes

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.19637224078178406,
 'test_loss: mae_response': 0.29283836483955383,
 'test_loss: mse': 0.1893213391304016}
--------------------------------------------------------------------------------


# Testing Models with Updates

In [55]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Parenting"])
    output = test(cfg_from_terminal)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (86902, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.2307247370481491,
 'test_loss: mae_response': 0.36967429518699646,
 'test_loss: mse': 0.2517955005168915}
--------------------------------------------------------------------------------


In [56]:
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [57]:
with open('../spatial/non_response.txt', "r") as genes_file:
    features = [int(x) for x in genes_file.read().split(",")]
    response_indeces = torch.tensor(list(set(range(160)) - set(features)))
    genes_file.close()

In [58]:
excitatory_cells = (celltypes == 6).nonzero(as_tuple=True)[0]

In [59]:
import torch

loss = torch.nn.L1Loss()
loss(torch.index_select(inputs[excitatory_cells], 1, response_indeces), torch.index_select(gene_expressions[excitatory_cells], 1, response_indeces))

tensor(0.4182)

In [60]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Virgin Parenting"])
    output = test(cfg_from_terminal)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (109105, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.22909943759441376,
 'test_loss: mae_response': 0.3661682903766632,
 'test_loss: mse': 0.23969760537147522}
--------------------------------------------------------------------------------


In [61]:
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [62]:
excitatory_cells = (celltypes == 6).nonzero(as_tuple=True)[0]

In [63]:
import torch

loss = torch.nn.L1Loss()
loss(torch.index_select(inputs[excitatory_cells], 1, response_indeces), torch.index_select(gene_expressions[excitatory_cells], 1, response_indeces))

tensor(0.4185)

In [64]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # update the behavior to get the model of interest
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", ["Naive"])
    output = test(cfg_from_terminal)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (205348, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.19866959750652313,
 'test_loss: mae_response': 0.3151684105396271,
 'test_loss: mse': 0.18607354164123535}
--------------------------------------------------------------------------------


In [65]:
trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output

In [66]:
excitatory_cells = (celltypes == 6).nonzero(as_tuple=True)[0]

In [67]:
import torch

loss = torch.nn.L1Loss()
loss(torch.index_select(inputs[excitatory_cells], 1, response_indeces), torch.index_select(gene_expressions[excitatory_cells], 1, response_indeces))

tensor(0.3570)