# Get the names and indexes of the response genes.

In [1]:
import numpy as np
import pandas as pd
import os
import json
import matplotlib.pyplot as plt

import os
import torch

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.models.monet_ae import MonetDense
from spatial.predict import test

import hydra
from hydra.experimental import compose, initialize

# Pin the MKL, NumExpr and OpenMP thread-count variables to one thread each.
# NOTE(review): these are assigned after numpy and torch were imported above;
# thread-count variables are typically read when the native libraries
# initialise, so confirm that they actually take effect in this position.
os.environ["MKL_NUM_THREADS"]="1"
os.environ["NUMEXPR_NUM_THREADS"]="1"
os.environ["OMP_NUM_THREADS"]="1"
# Restrict the CUDA devices visible to this process to physical GPUs 4, 5 and 6.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"

In [2]:
# Names of the 84 response genes, in alphabetical order. The evaluation loop
# further down iterates over this array once per gene.
response_genes=np.array(['Ace2', 'Aldh1l1', 'Amigo2', 'Ano3', 'Aqp4', 'Ar', 'Arhgap36',
       'Baiap2', 'Ccnd2', 'Cd24a', 'Cdkn1a', 'Cenpe', 'Chat', 'Coch',
       'Col25a1', 'Cplx3', 'Cpne5', 'Creb3l1', 'Cspg5', 'Cyp19a1',
       'Cyp26a1', 'Dgkk', 'Ebf3', 'Egr2', 'Ermn', 'Esr1', 'Etv1',
       'Fbxw13', 'Fezf1', 'Gbx2', 'Gda', 'Gem', 'Gjc3', 'Greb1',
       'Irs4', 'Isl1', 'Klf4', 'Krt90', 'Lmod1', 'Man1a', 'Mbp', 'Mki67',
       'Mlc1', 'Myh11', 'Ndnf', 'Ndrg1', 'Necab1', 'Nnat', 'Nos1',
       'Npas1', 'Nup62cl', 'Omp', 'Onecut2', 'Opalin', 'Pak3', 'Pcdh11x',
       'Pgr', 'Plin3', 'Pou3f2', 'Rgs2', 'Rgs5', 'Rnd3', 'Scgn',
       'Serpinb1b', 'Sgk1', 'Slc15a3', 'Slc17a6', 'Slc17a8', 'Slco1a4',
       'Sln', 'Sox4', 'Sox6', 'Sox8', 'Sp9', 'Synpr', 'Syt2', 'Syt4',
       'Sytl4', 'Th', 'Tiparp', 'Tmem108', 'Traf4', 'Ttn', 'Ttyh2'])

In [3]:
# Read the raw CSV into a DataFrame; the path is relative to the notebook's
# working directory.
data = pd.read_csv("../data/raw/merfish.csv")

In [4]:
# Discard the Fos column and the five Blank_* columns.
data = data.drop(
    columns=["Fos", "Blank_1", "Blank_2", "Blank_3", "Blank_4", "Blank_5"],
)

In [5]:
# Keep only the columns from position 9 onward (the first nine are dropped).
data = data.iloc[:, 9:]

In [None]:
# Display the frame's shape and column index as a sanity check.
data.shape, data.columns # should start with "Ace2"

In [7]:
# Column positions (84 of them) of the response genes. Used below both to look
# up gene names in `data.columns` and to select columns of the model inputs.
response_indexes = [0,2,3,4,5,6,7,10,19,20,21,22,23,24,25,26,27,28,32,34,35,37,38,39,40,41,42,43,44,52,53,54,55,58,63,64,66,67,69,71,73,74,75,76,77,78,79,80,85,86,87,88,93,94,96,97,99,102,103,104,106,110,112,113,114,116,118,119,120,121,122,123,124,125,126,129,130,131,133,134,141,142,147,151]

In [8]:
# Value written to the "radius" config key for the spatially ignorant baseline.
IGNORANT_RADIUS = 0
# Values written to the "radius" config key for the spatially aware runs
# (the failure message printed further down reports them in micrometers).
AWARE_RADIUS = [5,10,15,20,25,30,35,40,45,50]

# Load the test losses for the spatially ignorant baseline.

In [None]:
# Test loss per radius; the baseline entry is added here and the spatially
# aware entries are added by a later cell.
test_loss_rad_dict = {}

rad = IGNORANT_RADIUS

# Config overrides for the baseline run, applied in this order on top of the
# composed "config" file.
baseline_overrides = {
    "paths.data": "../data",
    "paths.root": "/nfs/turbo/lsa-regier/scratch/roko",
    "model.kwargs.observables_dimension": 87,
    "model.kwargs.hidden_dimensions": [512, 512, 512, 512, 512, 512],
    "model.kwargs.output_dimension": 84,
    "optimizer.name": "Adam",
    "model.kwargs.kernel_size": 15,
    "training.logger_name": "WITH_CELLTYPES",
    "training.trainer.strategy": "auto",
    "datasets.dataset.include_celltypes": True,
    "model.kwargs.include_skip_connections": True,
    "radius": rad,
    "gpus": [2],
}

with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    for config_key, config_value in baseline_overrides.items():
        OmegaConf.update(cfg_from_terminal, config_key, config_value)
    print(cfg_from_terminal.training.filepath)
    output = test(cfg_from_terminal)
    # Keep the baseline inputs and predictions around: the per-gene comparison
    # against the spatially aware models reads them later.
    (
        trainer,
        l1_losses,
        inputs_BASE,
        gene_expressions_BASE,
        celltypes,
        test_results_BASE,
    ) = output
    test_loss_rad_dict[rad] = test_results_BASE[0]['test_loss']

# Load the test losses for the spatially aware models.

In [11]:
def calc_r2(truth, predictions):
    """Return the coefficient of determination (R^2) of predictions vs. truth.

    Computed as ``1 - SS_res / SS_tot``, where ``SS_res`` is the sum of squared
    residuals and ``SS_tot`` is the sum of squared deviations of ``truth`` from
    its own mean, both taken over all elements.

    Args:
        truth: torch tensor of observed values.
        predictions: torch tensor of predicted values, broadcastable against
            ``truth``.

    Returns:
        The R^2 score as a Python float. ``nan`` is returned when ``truth`` has
        zero total variance (``SS_tot == 0``), where R^2 is undefined; this
        case previously raised ``ZeroDivisionError``.
    """
    ss_res = torch.sum((truth - predictions) ** 2).item()
    ss_tot = torch.sum((truth - torch.mean(truth)) ** 2).item()
    if ss_tot == 0:
        # Constant (or empty) truth: avoid dividing by zero.
        return float("nan")
    return 1 - ss_res / ss_tot

In [12]:
# Load previously saved per-gene losses; the evaluation cell below adds its
# results to this dict. JSON object keys are always strings, so the top-level
# keys of the loaded dict are str (e.g. "25"), not int.
with open("0v60_w_celltypes.json", "r") as f:
    loss_dict = json.load(f)

In [None]:
# Evaluate the spatially aware model at every radius and record, per response
# gene, the MSE and R^2 of both the baseline and the spatial model.
r2_dict = {}
# Ground-truth response-gene columns of the baseline run (same for all radii).
inputs_BASE_responses = inputs_BASE[:, response_indexes]

# Hidden-layer layouts to try, in order, when running the model for a radius.
hidden_dimension_candidates = (
    [512, 512, 512, 512, 512, 512],
    [256, 256, 256, 256, 256, 256],
)

for rad in AWARE_RADIUS:
    output = None
    with initialize(config_path="../config"):
        for hidden_dimensions in hidden_dimension_candidates:
            try:
                cfg_from_terminal = compose(config_name="config")
                OmegaConf.update(cfg_from_terminal, "paths.data", "../data")
                OmegaConf.update(cfg_from_terminal, "paths.root", "/nfs/turbo/lsa-regier/scratch/roko")
                OmegaConf.update(cfg_from_terminal, "model.kwargs.observables_dimension", 87)
                OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", hidden_dimensions)
                OmegaConf.update(cfg_from_terminal, "model.kwargs.output_dimension", 84)
                OmegaConf.update(cfg_from_terminal, "optimizer.name", "Adam")
                OmegaConf.update(cfg_from_terminal, "model.kwargs.kernel_size", 15)
                OmegaConf.update(cfg_from_terminal, "training.logger_name", "WITH_CELLTYPES")
                OmegaConf.update(cfg_from_terminal, "training.trainer.strategy", "auto")
                OmegaConf.update(cfg_from_terminal, "datasets.dataset.include_celltypes", True)
                OmegaConf.update(cfg_from_terminal, "model.kwargs.include_skip_connections", True)
                OmegaConf.update(cfg_from_terminal, "radius", rad)
                OmegaConf.update(cfg_from_terminal, "gpus", [2])
                print(cfg_from_terminal.training.filepath)
                output = test(cfg_from_terminal)
                break
            except Exception as err:
                # Best effort: report why this layout failed and try the next
                # one. `Exception` (not a bare except) lets KeyboardInterrupt
                # still stop the cell.
                print(f"Radius {rad}, hidden dimensions {hidden_dimensions}: {err!r}")

    if output is None:
        # No layout worked. Skip this radius instead of computing losses from
        # the previous radius's tensors and filing them under this radius.
        print(f"Model with radius of {rad} micrometers doesn't exist :(")
        continue

    trainer, l1_losses, inputs_SPATIAL, gene_expressions_SPATIAL, celltypes, test_results_SPATIAL = output
    test_loss_rad_dict[rad] = test_results_SPATIAL[0]['test_loss']
    inputs_SPATIAL_responses = inputs_SPATIAL[:, response_indexes]

    # Use string keys: loss_dict was loaded from JSON (whose keys are strings)
    # and the plotting cells index it with e.g. "25". An int key would create
    # a second entry that json.dump serialises under the same name.
    rad_key = str(rad)
    loss_dict[rad_key] = {}
    r2_dict[rad_key] = {}
    for i, gene_index in enumerate(response_indexes):
        current_gene = data.columns[gene_index]
        base_mse = torch.mean((inputs_BASE_responses[:, i] - gene_expressions_BASE[:, i]) ** 2).item()
        spatial_mse = torch.mean((inputs_SPATIAL_responses[:, i] - gene_expressions_SPATIAL[:, i]) ** 2).item()
        diff = spatial_mse - base_mse
        r2_dict[rad_key][current_gene] = {
            "base": calc_r2(inputs_BASE_responses[:, i], gene_expressions_BASE[:, i]),
            "spatial": calc_r2(inputs_SPATIAL_responses[:, i], gene_expressions_SPATIAL[:, i]),
        }
        loss_dict[rad_key][current_gene] = {
            "base": base_mse,
            "spatial": spatial_mse,
            # Negative diff means the spatial model has the lower MSE.
            "diff": diff,
            "percent_diff": (diff / base_mse) * 100.0,
        }

# Get the losses for each response gene in both cases.

In [15]:
# loss_dict

In [16]:
# Write the per-gene losses back to the same file they were loaded from,
# replacing its previous contents.
with open("0v60_w_celltypes.json", "w") as file:
    json.dump(loss_dict, file)

In [17]:
# r2_dict

In [18]:
# Save the per-gene R^2 scores; the R^2 heatmap cells below read this file.
with open("r2_0v60_w_celltypes.json", "w") as file:
    json.dump(r2_dict, file)

# Plot the results.

In [19]:
import matplotlib.pyplot as plt

In [20]:
# Per-gene percent loss reduction at radius 25: the stored percent_diff with
# its sign flipped, so positive means the spatial model has the lower loss.
percent_differences = [
    -gene_entry["percent_diff"] for gene_entry in loss_dict["25"].values()
]

In [None]:
# Histogram of the per-gene percent loss reductions.
# NOTE(review): the bin count is len(loss_dict), i.e. the number of top-level
# (radius) keys in loss_dict, not a property of the plotted values. Confirm
# this is the intended binning.
plt.hist(percent_differences, bins=len(loss_dict))
plt.title("% loss reduction in spatial model relative to baseline")
plt.xlabel("Percent Reduction in Test Loss")
plt.show()

# Annotated Plot

In [None]:
# Histogram of the per-gene percent loss reductions with two genes labelled;
# the figure is also saved to disk.
plt.hist(percent_differences, bins=len(loss_dict))
plt.title("% Loss reduction in spatial model relative to baseline.")
plt.xlabel("Percent Reduction in Test Loss")
# NOTE(review): the x positions are hard-coded numbers (presumably the Nnat and
# Mbp percent reductions from an earlier run, shifted left to centre the
# label). They will not follow re-computed losses; confirm before reuse.
plt.annotate("Nnat", xy=(39.5866-2, 1.25))
plt.annotate("Mbp", xy=(24.8099-1.3, 1.25))
plt.savefig("0VSspatial.png", dpi=300)
plt.show()

# Heatmap

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Reload the saved per-gene losses (top-level keys are radius strings).
with open("0v60_w_celltypes.json", "r") as file:
    loss_dict = json.load(file)

# Heatmap rows: radii from 25 down to 0 in steps of 5. Columns: response genes.
radius_values = list(range(25, -1, -5))
genes = [data.columns[i] for i in response_indexes]

# Percent loss reduction (negated percent_diff), transposed to (radii, genes).
# Radius 0 is the baseline itself, so its row is fixed to 0.
loss_array = np.array([
    [-loss_dict[str(rad)][gene]["percent_diff"] if rad != 0 else 0 for rad in radius_values] for gene in genes
]).T

# Set the figure size based on the dimensions of the data.
cell_width = 0.5
cell_height = 2.5
fig, ax = plt.subplots(figsize=(cell_width*len(genes), cell_height*len(radius_values)))

# Symmetric colour limits so that 0 sits at the centre of the diverging map.
color_limit = np.max(np.abs(loss_array))
im = ax.imshow(loss_array, cmap='seismic', vmin=-color_limit, vmax=color_limit)

# Add a colorbar for reference. The plotted quantity is the percent loss
# reduction, so it is labelled as such (it previously read "Loss values").
cbar = ax.figure.colorbar(im, ax=ax, fraction=0.005)
cbar.ax.set_ylabel("% Loss Reduction", rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(genes)))
ax.set_xticklabels(genes)
ax.set_yticks(np.arange(len(radius_values)))
ax.set_yticklabels(list(radius_values))


# Adjust font size
ax.tick_params(axis='both', which='major', labelsize=10)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Write each value into its cell with 2 decimal places: red text for a
# negative value (spatial model is worse), green otherwise.
for (i, j), value in np.ndenumerate(loss_array):
    text_color = 'red' if value < 0 else 'green'
    text = ax.text(j, i, f"{value:.2f}",
                   ha="center", va="center", color=text_color, fontsize=8)

# ax.set_title("Predictions Across $r$ Values", fontsize=25)
fig.tight_layout()
plt.savefig("spatial_horizontal_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Reload the saved per-gene losses (top-level keys are radius strings).
with open("0v60_w_celltypes.json", "r") as file:
    loss_dict = json.load(file)

# Define the radii and the full gene list before anything uses them, so the
# cell does not depend on values left behind by earlier cells or by a previous
# run of this cell.
radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# Percent loss reduction for every gene (rows) and radius (columns). Radius 0
# is the baseline itself, so its column is fixed to 0.
full_loss_array = np.array([
    [-loss_dict[str(rad)][gene]["percent_diff"] if rad != 0 else 0 for rad in radius_values] for gene in genes
])

# One symmetric colour scale shared by all four panels.
color_limit = np.max(np.abs(full_loss_array))

fig, axes = plt.subplots(2, 2, figsize=(12,18))

n = len(genes)

# Each panel shows one consecutive quarter of the genes.
for quarter, ax in enumerate(axes.ravel(), start=1):
    start, stop = n * (quarter - 1) // 4, n * quarter // 4
    quarter_genes = genes[start:stop]
    loss_array = full_loss_array[start:stop]

    im = ax.imshow(loss_array, cmap='seismic', vmin=-color_limit, vmax=color_limit)

    # Add a colorbar for reference with adjusted fraction to make it thinner
    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
    cbar.ax.set_ylabel("% Loss Reduction", fontsize=20, rotation=-90, va="bottom")

    ax.set_xticks(np.arange(len(radius_values)))
    ax.set_xticklabels(list(radius_values))
    ax.set_yticks(np.arange(len(quarter_genes)))
    ax.set_yticklabels(quarter_genes)

    # Adjust font size
    ax.tick_params(axis='both', which='major', labelsize=10)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Write each value into its cell with 2 decimal places: red text for a
    # negative value (spatial model is worse), green otherwise.
    for (i, j), value in np.ndenumerate(loss_array):
        text_color = 'red' if value < 0 else 'green'
        text = ax.text(j, i, f"{value:.2f}",
                       ha="center", va="center", color=text_color, fontsize=10)

# ax.set_title("Predictions Across $r$ Values", fontsize=25)
fig.tight_layout()
plt.savefig("spatial_vertical_full_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Reload the saved per-gene losses (top-level keys are radius strings).
with open("0v60_w_celltypes.json", "r") as file:
    loss_dict = json.load(file)

radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# Percent loss reduction per gene (rows) and radius (columns); the radius-0
# column is the baseline and is fixed to 0.
loss_array = np.array([
    [0 if rad == 0 else -loss_dict[str(rad)][gene]["percent_diff"] for rad in radius_values]
    for gene in genes
])

# Keep a gene when its absolute values sum to more than THRESHOLD_ABS or its
# signed values sum to less than THRESHOLD_MIN, then restrict both the gene
# list and the array to those rows.
THRESHOLD_ABS = 20
THRESHOLD_MIN = 0
kept_rows = [
    row_index
    for row_index, row in enumerate(loss_array)
    if sum(np.abs(value) for value in row) > THRESHOLD_ABS or sum(row) < THRESHOLD_MIN
]
genes = [genes[row_index] for row_index in kept_rows]
loss_array = loss_array[kept_rows]

# Figure size scales with the number of radii and of retained genes.
cell_width = 2.5
cell_height = 0.5
fig, ax = plt.subplots(
    figsize=(cell_width * len(radius_values), cell_height * len(genes))
)

# Symmetric colour limits around zero.
color_limit = np.max(np.abs(loss_array))
im = ax.imshow(loss_array, cmap='seismic', vmin=-color_limit, vmax=color_limit)

# Thin colorbar next to the heatmap.
cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
cbar.ax.set_ylabel("% Loss Reduction", fontsize=20, rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(radius_values)))
ax.set_xticklabels(list(radius_values))
ax.set_yticks(np.arange(len(genes)))
ax.set_yticklabels(genes)

ax.tick_params(axis='both', which='major', labelsize=10)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Print every value in its cell (2 decimals): red when negative, else green.
for (i, j), value in np.ndenumerate(loss_array):
    text_color = 'red' if value < 0 else 'green'
    text = ax.text(j, i, f"{value:.2f}",
                   ha="center", va="center", color=text_color, fontsize=10)

plt.savefig("spatial_vertical_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Load the saved per-gene R^2 scores (top-level keys are radius strings).
with open("r2_0v60_w_celltypes.json", "r") as file:
    r2_dict = json.load(file)

fig, axes = plt.subplots(2, 2, figsize=(12,18))

radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# R^2 for every gene (rows) and radius (columns). The radius-0 column holds the
# baseline R^2, which is stored under the "base" key of each radius entry; it
# is read from the "5" entry here.
full_r2_array = np.array([
    [r2_dict[str(rad)][gene]["spatial"] if rad != 0 else r2_dict["5"][gene]["base"] for rad in radius_values] for gene in genes
])

n = len(genes)

# Each panel shows one consecutive quarter of the genes.
for quarter, ax in enumerate(axes.ravel(), start=1):
    start, stop = n * (quarter - 1) // 4, n * quarter // 4
    quarter_genes = genes[start:stop]
    r2_array = full_r2_array[start:stop]

    im = ax.imshow(r2_array, cmap='coolwarm', vmin=0, vmax=1)

    # Add a colorbar for reference with adjusted fraction to make it thinner
    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
    cbar.ax.set_ylabel(r"$R^2$", fontsize=20, rotation=-90, va="bottom")

    ax.set_xticks(np.arange(len(radius_values)))
    ax.set_xticklabels(list(radius_values))
    ax.set_yticks(np.arange(len(quarter_genes)))
    ax.set_yticklabels(quarter_genes)

    # Adjust font size
    ax.tick_params(axis='both', which='major', labelsize=10)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Write each value into its cell with 2 decimal places: red text when the
    # R^2 is below the gene's baseline (column 0), white otherwise.
    for (i, j), value in np.ndenumerate(r2_array):
        text_color = 'red' if value < r2_array[i, 0] else 'white'
        text = ax.text(j, i, f"{value:.2f}",
                       ha="center", va="center", color=text_color, fontsize=10)

# ax.set_title("Predictions Across $r$ Values", fontsize=25)
fig.tight_layout()
# Saved under its own name: the single-panel R^2 cell below writes
# "spatial_vertical_r2_w_celltypes.png", which used to overwrite this figure.
plt.savefig("spatial_vertical_r2_full_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Load the saved per-gene R^2 scores (top-level keys are radius strings).
with open("r2_0v60_w_celltypes.json", "r") as file:
    r2_dict = json.load(file)

radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# R^2 per gene (rows) and radius (columns); the radius-0 column is the
# baseline R^2 taken from the "base" field of the "5" entry.
r2_array = np.array([
    [r2_dict["5"][gene]["base"] if rad == 0 else r2_dict[str(rad)][gene]["spatial"] for rad in radius_values]
    for gene in genes
])

# Figure size scales with the number of radii and genes.
cell_width = 2.5
cell_height = 0.5
fig, ax = plt.subplots(
    figsize=(cell_width * len(radius_values), cell_height * len(genes))
)

im = ax.imshow(r2_array, vmin=0, vmax=1, cmap="coolwarm")

# Thin colorbar next to the heatmap.
cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
cbar.ax.set_ylabel(r"$R^2$", fontsize=20, rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(radius_values)))
ax.set_xticklabels(list(radius_values))
ax.set_yticks(np.arange(len(genes)))
ax.set_yticklabels(genes)

ax.tick_params(axis='both', which='major', labelsize=10)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Print every value in its cell (2 decimals): red when it falls below the
# gene's baseline in column 0, white otherwise.
for (i, j), value in np.ndenumerate(r2_array):
    text_color = 'red' if value < r2_array[i, 0] else 'white'
    text = ax.text(j, i, f"{value:.2f}",
                   ha="center", va="center", color=text_color, fontsize=10)

plt.savefig("spatial_vertical_r2_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Reload the saved per-gene losses (top-level keys are radius strings).
with open("0v60_w_celltypes.json", "r") as file:
    loss_dict = json.load(file)

radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# Absolute MSE improvement per gene (rows) and radius (columns):
# -(spatial - base) = base MSE - spatial MSE. Radius 0 is the baseline itself,
# so its column is fixed to 0.
loss_array = np.array([
    [- (loss_dict[str(rad)][gene]["spatial"] - loss_dict[str(rad)][gene]["base"]) if rad != 0 else 0 for rad in radius_values] for gene in genes
])

# Keep a gene when its absolute values sum to more than THRESHOLD_ABS or its
# signed values sum to less than THRESHOLD_MIN, then rebuild the array for
# the retained genes only.
THRESHOLD_ABS = 0.03
THRESHOLD_MIN = 0
genes = [
    gene
    for gene, row in zip(genes, loss_array)
    if sum(np.abs(x) for x in row) > THRESHOLD_ABS or sum(row) < THRESHOLD_MIN
]
loss_array = np.array([
    [- (loss_dict[str(rad)][gene]["spatial"] - loss_dict[str(rad)][gene]["base"]) if rad != 0 else 0 for rad in radius_values] for gene in genes
])

# Set the figure size based on the dimensions of the data.
cell_width = 2.5
cell_height = 0.5
fig, ax = plt.subplots(figsize=(cell_width*len(radius_values), cell_height*len(genes)))

# Symmetric colour limits so that 0 sits at the centre of the diverging map.
color_limit = np.max(np.abs(loss_array))
im = ax.imshow(loss_array, cmap='seismic', vmin=-color_limit, vmax=color_limit)

# Add a colorbar for reference with adjusted fraction to make it thinner.
# The values are absolute MSE differences (baseline minus spatial), not
# percentages, so the label states that (it previously read
# "% Loss Reduction").
cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
cbar.ax.set_ylabel(r"$MSE_{0} - MSE_{r}$", fontsize=20, rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(radius_values)))
ax.set_xticklabels(list(radius_values))
ax.set_yticks(np.arange(len(genes)))
ax.set_yticklabels(genes)

# Adjust font size
ax.tick_params(axis='both', which='major', labelsize=10)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Write each value into its cell with 2 decimal places: red text for a
# negative value (spatial model is worse), green otherwise.
for (i, j), value in np.ndenumerate(loss_array):
    text_color = 'red' if value < 0 else 'green'
    text = ax.text(j, i, f"{value:.2f}",
                   ha="center", va="center", color=text_color, fontsize=10)

# ax.set_title("Predictions Across $r$ Values", fontsize=25)
# fig.tight_layout()
plt.savefig("spatial_vertical_LRT_w_celltypes.png", dpi=600)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# Reload the saved per-gene losses (top-level keys are radius strings).
with open("0v60_w_celltypes.json", "r") as file:
    loss_dict = json.load(file)

radius_values = list(range(0, 26, 5))
genes = [data.columns[i] for i in response_indexes]

# Absolute MSE improvement per gene (rows) and radius (columns), for all
# response genes: -(spatial - base) = base MSE - spatial MSE. Radius 0 is the
# baseline itself, so its column is fixed to 0.
loss_array = np.array([
    [- (loss_dict[str(rad)][gene]["spatial"] - loss_dict[str(rad)][gene]["base"]) if rad != 0 else 0 for rad in radius_values] for gene in genes
])

# Set the figure size based on the dimensions of the data.
cell_width = 2.5
cell_height = 0.5
fig, ax = plt.subplots(figsize=(cell_width*len(radius_values), cell_height*len(genes)))

# Symmetric colour limits so that 0 sits at the centre of the diverging map.
color_limit = np.max(np.abs(loss_array))
im = ax.imshow(loss_array, cmap='seismic', vmin=-color_limit, vmax=color_limit)

# Add a colorbar for reference with adjusted fraction to make it thinner.
# The plotted value is baseline MSE minus spatial MSE, so the label is
# MSE_0 - MSE_r (it previously had the two terms the wrong way round).
cbar = ax.figure.colorbar(im, ax=ax, fraction=0.05)
cbar.ax.set_ylabel(r"$MSE_{0} - MSE_{r}$", fontsize=20, rotation=-90, va="bottom")

ax.set_xticks(np.arange(len(radius_values)))
ax.set_xticklabels(list(radius_values))
ax.set_yticks(np.arange(len(genes)))
ax.set_yticklabels(genes)

# Adjust font size
ax.tick_params(axis='both', which='major', labelsize=10)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Write each value into its cell with 2 decimal places: red text for a
# negative value (spatial model is worse), green otherwise.
for (i, j), value in np.ndenumerate(loss_array):
    text_color = 'red' if value < 0 else 'green'
    text = ax.text(j, i, f"{value:.2f}",
                   ha="center", va="center", color=text_color, fontsize=10)

# ax.set_title("Predictions Across $r$ Values", fontsize=25)
# fig.tight_layout()
plt.savefig("spatial_vertical_LRT_full_w_celltypes.png", dpi=600)
plt.show()