In [1]:
import os
import torch
import json
import requests
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

os.environ["MKL_NUM_THREADS"]="1"
os.environ["NUMEXPR_NUM_THREADS"]="1"
os.environ["OMP_NUM_THREADS"]="1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"

import sys

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset, SyntheticNonlinear, SyntheticDataset0, SyntheticDataset1, SyntheticDataset2, SyntheticDataset3
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder, MonetDense
from spatial.train import train
from spatial.predict import test

import hydra
from hydra.experimental import compose, initialize

import warnings
warnings.filterwarnings("ignore")

In [2]:
response_indexes = [0,2,3,4,5,6,7,10,19,20,21,22,23,24,25,26,27,28,32,34,35,37,38,39,40,41,42,43,44,52,53,54,55,58,63,64,66,67,69,71,73,74,75,76,77,78,79,80,85,86,87,88,93,94,96,97,99,102,103,104,106,110,112,113,114,116,118,119,120,121,122,123,124,125,126,129,130,131,133,134,141,142,147,151]

In [5]:
import hydra
from hydra.experimental import compose, initialize

test_loss_rad_dict = {}
r2_loss_rad_dict = {}

for rad in [0,25]:
    with initialize(config_path="../../config"):
        try:
            cfg_from_terminal = compose(config_name="config")
            OmegaConf.update(cfg_from_terminal, "model.kwargs.observables_dimension", 71)
            OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [512, 512, 512, 512, 512, 512])
            OmegaConf.update(cfg_from_terminal, "model.kwargs.output_dimension", 84)
            OmegaConf.update(cfg_from_terminal, "optimizer.name", "Adam")
            OmegaConf.update(cfg_from_terminal, "training.logger_name", "table2_FULL_no_celltypes")
            OmegaConf.update(cfg_from_terminal, "training.trainer.strategy", "auto")
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.include_celltypes", False)
            OmegaConf.update(cfg_from_terminal, "model.kwargs.include_skip_connections", True)
            OmegaConf.update(cfg_from_terminal, "radius", rad)
            OmegaConf.update(cfg_from_terminal, "gpus", [2])
            print(cfg_from_terminal.training.filepath)
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            test_loss_rad_dict[rad] = test_results[0]['test_loss']
            r2_loss_rad_dict[rad] = np.corrcoef(inputs[: response_indexes], gene_expressions) ** 2
        except:
            try:
                OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [256, 256, 256, 256, 256, 256])
                print(cfg_from_terminal.training.filepath)
                output = test(cfg_from_terminal)
                trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
                test_loss_rad_dict[rad] = test_results[0]['test_loss']
                r2_loss_rad_dict[rad] = np.corrcoef(inputs[: response_indexes], gene_expressions) ** 2
            except:
                print(f"Model with radius of {rad} micrometers doesn't exist :(")

MonetDense__[512, 512, 512, 512, 512, 512]__0__table2_FULL_no_celltypes__Adam


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:21<00:00,  1.11it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.6 to v2.1.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../output/lightning_logs/checkpoints/MonetDense/MonetDense__[512, 512, 512, 512, 512, 512]__0__table2_FULL_no_celltypes__Adam.ckpt`


MonetDense__[256, 256, 256, 256, 256, 256]__0__table2_FULL_no_celltypes__Adam


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00,  1.07it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.6 to v2.1.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../output/lightning_logs/checkpoints/MonetDense/MonetDense__[256, 256, 256, 256, 256, 256]__0__table2_FULL_no_celltypes__Adam.ckpt`


Model with radius of 0 micrometers doesn't exist :(
MonetDense__[512, 512, 512, 512, 512, 512]__25__table2_FULL_no_celltypes__Adam


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:21<00:00,  1.13it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.6 to v2.1.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../output/lightning_logs/checkpoints/MonetDense/MonetDense__[512, 512, 512, 512, 512, 512]__25__table2_FULL_no_celltypes__Adam.ckpt`


MonetDense__[256, 256, 256, 256, 256, 256]__25__table2_FULL_no_celltypes__Adam


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:22<00:00,  1.06it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.6 to v2.1.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../output/lightning_logs/checkpoints/MonetDense/MonetDense__[256, 256, 256, 256, 256, 256]__25__table2_FULL_no_celltypes__Adam.ckpt`


Model with radius of 25 micrometers doesn't exist :(
