In [None]:
import os
import sys

# Cap the BLAS / OpenMP / numexpr thread pools to a single thread.
# These variables are generally read once, when the native runtimes initialise
# (some at library load time), so they must be set BEFORE importing torch or
# anything that pulls in numpy/pandas; setting them afterwards may be ignored.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import torch

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split

try:  # PyG >= 2.0 moved the graph DataLoader to torch_geometric.loader
    from torch_geometric.loader import DataLoader
except ImportError:  # older PyG releases only provide the legacy location
    from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset, SyntheticDataset0, SyntheticDataset1, SyntheticDataset2, SyntheticDataset3
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder, MonetDense
from spatial.train import train
from spatial.predict import test

In [None]:
import hydra

try:  # Hydra >= 1.1 exposes the Compose API at the top level
    from hydra import compose, initialize
except ImportError:  # Hydra 1.0 only ships it as the (since deprecated) experimental API
    from hydra.experimental import compose, initialize

# Sweep grid: radius values and synthetic-dataset configs (config0 .. config3).
RADII = range(0, 65, 5)
N_SYNTHETIC_EXPERIMENTS = 4

# Maps (radius, synthetic experiment index) -> 'test_loss' reported by `test`.
test_loss_rad_dict = {}
for rad in RADII:
    for synthetic_exp in range(N_SYNTHETIC_EXPERIMENTS):
        # A fresh Hydra context per run; compose() and the test run happen inside it.
        with initialize(config_path="../../config"):
            cfg_from_terminal = compose(config_name=f"config{synthetic_exp}")
            OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
            OmegaConf.update(cfg_from_terminal, "training.logger_name", f"synthetic{synthetic_exp}")
            OmegaConf.update(cfg_from_terminal, "radius", rad)
            OmegaConf.update(cfg_from_terminal, "model.kwargs.response_genes", [0])
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.splits", 2)
            output = test(cfg_from_terminal)
            # Only test_results is used below; the other names stay available for inspection.
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            test_loss_rad_dict[(rad, synthetic_exp)] = test_results[0]['test_loss']

In [None]:
# Inspect the raw (radius, experiment) -> test loss mapping produced by the sweep.
test_loss_rad_dict

In [None]:
# Persist the deepST sweep so later sessions can skip the expensive loop above.
# Off by default: the next cell reloads previously saved results from disk.
SAVE_DEEPST_RESULTS = False

if SAVE_DEEPST_RESULTS:
    import json

    # JSON object keys must be strings, so the (radius, experiment) tuples are
    # stored as their str() form, e.g. "(5, 2)". A separate dict is built so the
    # tuple-keyed in-memory results are left untouched.
    serializable_results = {str(k): v for k, v in test_loss_rad_dict.items()}
    with open('../deepST_synthetic_results.json', 'w') as deepST:
        json.dump(serializable_results, deepST)

In [None]:
import json

# Reload saved deepST results. Keys are the str() of (radius, experiment) tuples,
# e.g. "(0, 1)". NOTE: this replaces any dict computed by the sweep in this session.
deepst_results_path = '../deepST_synthetic_results.json'
with open(deepst_results_path) as results_file:
    test_loss_rad_dict = json.load(results_file)

In [None]:
import json

# Baseline-model results; each key is parsed downstream as the whitespace-separated
# string "<model> <radius> <experiment>", each value is that run's loss.
baseline_results_path = '../LightGBM_synthetic_results.json'
with open(baseline_results_path) as results_file:
    linear_models_dict = json.load(results_file)

In [None]:
# Inspect the baseline-model results loaded from LightGBM_synthetic_results.json.
linear_models_dict

In [None]:
# create dataframe
import ast

import pandas as pd

# Column order of the combined long-format results table.
RESULT_COLUMNS = ['Model', 'Radius', 'Experiment #', 'L1 Loss']


def _parse_deepst_key(key):
    """Split a deepST results key into ``(radius, experiment)`` ints.

    Keys are ``(radius, experiment)`` tuples when the sweep cell filled the dict in
    this session, and their ``str()`` form (e.g. ``"(5, 2)"``) after a JSON
    round-trip. ``ast.literal_eval`` only accepts Python literals, so a tampered
    results file cannot execute code the way ``eval`` could.
    """
    if isinstance(key, str):
        key = ast.literal_eval(key)
    radius, experiment = key
    return int(radius), int(experiment)


def build_results_frame(linear_models_dict, test_loss_rad_dict):
    """Collect baseline-model and deepST losses into one long-format DataFrame.

    Args:
        linear_models_dict: maps ``"<model> <radius> <experiment>"`` strings to a loss.
        test_loss_rad_dict: maps deepST ``(radius, experiment)`` keys (tuple or its
            ``str()`` form) to a loss.

    Returns:
        DataFrame with columns ``RESULT_COLUMNS``; baseline rows first, then deepST
        rows. Radius and experiment are ints. Empty inputs give an empty frame
        that still has those columns.

    Raises:
        ValueError / SyntaxError: if a key does not have the expected shape.
    """
    records = []

    # Baseline models loaded from LightGBM_synthetic_results.json.
    for key, loss in linear_models_dict.items():
        # rsplit keeps model names that contain spaces intact.
        model, radius, experiment = key.rsplit(maxsplit=2)
        records.append({'Model': model, 'Radius': int(radius),
                        'Experiment #': int(experiment), 'L1 Loss': loss})

    # deepST
    for key, loss in test_loss_rad_dict.items():
        radius, experiment = _parse_deepst_key(key)
        records.append({'Model': 'deepST', 'Radius': radius,
                        'Experiment #': experiment, 'L1 Loss': loss})

    # Build once from records: growing a frame with concat in a loop is quadratic
    # and leaves the object dtypes inherited from an empty seed frame.
    return pd.DataFrame(records, columns=RESULT_COLUMNS)


data = build_results_frame(linear_models_dict, test_loss_rad_dict)
data

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# One panel per synthetic experiment on a 2x2 grid (row-major: 0, 1 / 2, 3).
fig, axes = plt.subplots(2, 2, figsize=(18, 18))

# create pointplot
for exp, ax in enumerate(axes.flat):
    new_data = data[data['Experiment #'] == exp]
    synth_results = sns.pointplot(ax=ax, x='Radius', y='L1 Loss', hue='Model', data=new_data)
    # Reference line: deepST's loss at radius 0 for this experiment
    # (.item() raises if there is not exactly one such row).
    baseline = new_data.loc[(new_data["Model"] == "deepST") & (new_data["Radius"] == 0), "L1 Loss"]
    synth_results.axhline(baseline.item(), linestyle='-', linewidth=3)
    synth_results.set_title(f"Experiment #{exp}", fontsize=25)
    synth_results.set_xlabel("Radius", fontsize=20)
    synth_results.set_ylabel("L1 Loss", fontsize=20)
    # Keep seaborn's category labels (the radii actually present in this panel)
    # rather than overwriting them with a hard-coded 0..60 sequence.
    synth_results.tick_params(axis='x', labelsize=15)

fig.savefig('synth_experiments.png')