In [1]:
import os

# Pin BLAS / OpenMP / numexpr thread pools to a single thread.
# These must be set BEFORE importing torch (or anything that loads MKL/OpenMP):
# the native runtimes typically read their thread-count variables when the
# library is loaded, so setting them after `import torch` may have no effect.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys

import torch
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset, SyntheticDataset0, SyntheticDataset1, SyntheticDataset2, SyntheticDataset3
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder, MonetDense
from spatial.train import train
from spatial.predict import test

# makes the notebook loadable
# NOTE(review): this silences *all* warnings for the whole kernel session,
# including deprecation warnings (e.g. from `hydra.experimental`).
import warnings
warnings.filterwarnings('ignore')

In [2]:
import hydra
from hydra.experimental import compose, initialize

# Gating tolerances swept for the zero-inflated synthetic dataset (config0).
TOLERANCES = [0, 0.25, 0.5, 0.75, 1, 1.5, 2]


def zero_separated_losses(inputs, gene_expressions, gene_idx=0):
    """Split one gene's reconstruction MSE into zero / non-zero targets.

    Args:
        inputs: 2-D tensor of target values (rows x genes), as returned by `test`.
        gene_expressions: 2-D tensor of predictions with the same layout.
        gene_idx: column to evaluate (defaults to the first gene, as before).

    Returns:
        dict with keys "Loss (MSE)" (mean over all rows), "Loss for non-0s"
        (mean over rows whose target is non-zero) and "Loss for 0s" (mean over
        rows whose target is exactly zero). An empty subset yields NaN instead
        of the inf/NaN produced by dividing by a zero count.
    """
    target = inputs[:, gene_idx]
    squared_error = (target - gene_expressions[:, gene_idx]) ** 2
    is_zero = target == 0

    def _masked_mean(mask):
        # Mean over the selected rows only; NaN when no row is selected.
        return squared_error[mask].mean().item() if mask.any() else float("nan")

    return {
        "Loss (MSE)": squared_error.mean().item(),
        "Loss for non-0s": _masked_mean(~is_zero),
        "Loss for 0s": _masked_mean(is_zero),
    }


test_loss_rad_dict = {}
zero_separated_loss_dict = {}

for tol in TOLERANCES:
    with initialize(config_path="../config"):
        cfg_from_terminal = compose(config_name="config0")
        OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
        OmegaConf.update(cfg_from_terminal, "training.logger_name", "synthetic0_zeroinf")
        OmegaConf.update(cfg_from_terminal, "radius", 30)
        OmegaConf.update(cfg_from_terminal, "datasets.dataset.splits", 2)
        OmegaConf.update(cfg_from_terminal, "datasets.dataset.tol", tol)
        OmegaConf.update(cfg_from_terminal, "training.filepath", f"MonetDense__[128, 128]__[0]__30__synthetic0_zeroinf__{tol}")
        output = test(cfg_from_terminal)
    trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
    test_loss_rad_dict[tol] = test_results[0]['test_loss']
    zero_separated_loss_dict[tol] = zero_separated_losses(inputs, gene_expressions)

  0%|                                                                                   | 0/24 [00:00<?, ?it/s]

tensor([0.8755, 0.3365, 0.3365, 0.7885, 0.7885, 0.1823, 0.4700, 0.8755, 0.8755,
        1.0296, 1.0296, 0.4700, 0.7885, 0.5878, 0.4700, 0.3365, 0.4700, 1.0986,
        0.6931, 1.0296, 0.6931, 1.2238, 0.5878, 1.0296, 0.7885, 1.0986, 0.6931,
        1.1632, 0.0000, 1.0986, 0.0000, 1.1632, 1.1632, 1.0986, 0.9555, 0.3365,
        0.4700, 0.0000, 0.6931, 0.6931, 0.8755, 0.7885, 0.9555, 0.7885, 0.4700,
        0.7885, 0.7885, 1.2809, 0.7885, 0.6931, 0.6931, 1.0296, 1.4816, 0.4700,
        0.8755, 0.8755, 0.7885, 1.1632, 0.4700, 0.4700, 0.4700, 0.4700, 0.8755,
        1.2238, 1.3350, 1.6094, 0.8755, 0.6931, 1.3863, 0.4700, 0.6931, 1.0986,
        1.1632, 0.8755, 0.7885, 0.5878, 0.5878, 1.1632, 0.6931, 0.6931, 0.4700,
        1.0986, 0.7885, 0.7885, 1.0986, 0.4700, 1.0986, 1.2809, 1.2238, 1.3863,
        1.0296, 0.9555, 0.7885, 1.0296, 0.8755, 0.1823, 0.9555, 0.9555, 0.6931,
        0.7885, 0.7885, 0.3365, 0.6931, 0.4700, 1.0986, 0.7885, 0.7885, 0.3365,
        1.2809, 1.0296, 1.2238, 0.3365, 

tensor([1.0296, 1.0296, 1.0296, 0.3365, 0.3365, 0.3365, 0.3365, 0.4700, 0.1823,
        0.4700, 0.0000, 0.1823, 0.5878, 0.0000, 0.1823, 0.9555, 0.1823, 0.0000,
        0.4700, 0.0000, 0.6931, 0.6931, 0.6931, 0.3365, 0.4700, 0.0000, 0.7885,
        0.8755, 0.3365, 0.3365, 0.8755, 0.8755, 0.3365, 0.8755, 0.5878, 0.1823,
        0.3365, 0.5878, 0.5878, 0.5878, 0.3365, 0.6931, 0.6931, 0.7885, 0.8755,
        0.4700, 0.1823, 0.4700, 0.3365, 0.3365, 0.7885, 0.4700, 0.4700, 0.5878,
        0.1823, 0.3365, 0.0000, 0.5878, 0.4700, 0.6931, 0.0000, 0.4700, 0.6931,
        0.3365, 0.6931, 0.6931, 0.0000, 0.3365, 0.6931, 0.3365, 0.6931, 0.5878,
        0.7885, 0.8755, 0.4700, 0.4700, 0.4700, 0.4700, 0.5878, 0.1823, 0.5878,
        0.7885, 1.0986, 0.9555, 1.1632, 0.9555, 0.3365, 1.0986, 1.4351, 1.1632,
        1.3350, 0.5878, 1.0296, 0.5878, 1.2809, 0.6931, 0.9555, 0.9555, 0.5878,
        1.2238, 0.7885, 0.7885, 0.5878, 0.8755, 0.4700, 0.7885, 1.2238, 0.6931,
        0.1823, 1.0986, 0.3365, 0.4700, 

tensor([0.8755, 0.7885, 0.1823, 0.1823, 0.0000, 0.1823, 0.3365, 0.0000, 0.7885,
        0.1823, 0.3365, 0.0000, 0.3365, 0.0000, 0.0000, 0.3365, 0.3365, 0.3365,
        0.3365, 0.1823, 0.5878, 0.4700, 0.4700, 0.4700, 0.3365, 0.5878, 0.3365,
        0.6931, 0.6931, 0.6931, 0.3365, 0.6931, 0.3365, 0.4700, 0.5878, 0.6931,
        0.5878, 0.1823, 0.5878, 0.7885, 0.1823, 0.3365, 0.1823, 0.5878, 0.6931,
        1.0986, 0.1823, 0.3365, 0.8755, 1.0296, 0.3365, 0.3365, 0.5878, 0.5878,
        0.5878, 0.1823, 0.3365, 0.7885, 0.3365, 0.1823, 0.5878, 0.3365, 0.5878,
        0.9555, 0.8755, 0.3365, 0.0000, 0.4700, 0.5878, 0.4700, 1.0296, 0.7885,
        0.4700, 0.5878, 1.0296, 0.1823, 0.5878, 1.0296, 1.2238, 0.3365, 1.2238,
        0.7885, 0.8755, 1.5686, 1.4351, 1.1632, 0.9555, 1.0986, 1.2809, 1.0296,
        1.2238, 1.2809, 0.6931, 0.6931, 0.8755, 0.5878, 1.0296, 0.8755, 0.6931,
        0.6931, 0.1823, 1.0986, 0.5878, 0.7885, 0.3365, 0.8755, 0.4700, 0.9555,
        0.8755, 0.9555, 0.1823, 0.5878, 

  4%|███▏                                                                       | 1/24 [00:01<00:36,  1.58s/it]

tensor([0.9555, 0.9555, 0.3365, 0.1823, 0.3365, 0.6931, 0.6931, 0.6931, 0.5878,
        0.8755, 0.9555, 0.7885, 0.9555, 0.6931, 0.9555, 1.2238, 1.2809, 0.5878,
        1.3350, 0.6931, 0.4700, 1.3350, 0.8755, 0.7885, 1.0986, 0.5878, 1.0986,
        0.5878, 0.0000, 0.7885, 0.5878, 1.5261, 1.4351, 1.5686, 1.0296, 1.2238,
        0.7885, 1.5261, 0.5878, 1.4816, 1.3863, 0.6931, 1.0986, 0.5878, 1.2809,
        0.8755, 0.6931, 0.9555, 0.5878, 0.5878, 0.6931, 1.0296, 0.5878, 0.3365,
        0.4700, 0.7885, 0.5878, 0.6931, 0.0000, 0.1823, 0.0000, 0.3365, 1.0986,
        1.5261, 1.2809, 0.5878, 0.0000, 0.9555, 1.1632, 0.0000, 0.7885, 0.3365,
        0.4700, 0.4700, 0.3365, 0.1823, 0.5878, 0.4700, 0.7885, 0.4700, 0.4700,
        1.0296, 1.1632, 1.3350, 1.2809, 1.0986, 1.3350, 1.0296, 1.3350, 1.1632,
        0.8755, 0.7885, 0.8755, 1.3350, 0.8755, 1.0296, 1.0296, 0.9555, 0.3365,
        0.3365, 1.1632, 0.5878, 0.3365, 0.4700, 0.3365, 0.5878, 0.1823, 1.0296,
        0.4700, 0.9555, 0.4700, 0.4700, 

tensor([0.5878, 0.0000, 0.0000, 0.4700, 0.8755, 0.5878, 0.4700, 0.5878, 0.1823,
        0.5878, 0.1823, 0.4700, 0.1823, 0.8755, 0.5878, 1.0986, 0.1823, 0.8755,
        0.7885, 0.1823, 1.0296, 0.5878, 0.5878, 0.7885, 0.7885, 0.5878, 0.4700,
        0.7885, 0.3365, 0.7885, 1.0296, 1.0296, 1.0986, 1.1632, 0.0000, 0.4700,
        1.0296, 0.5878, 0.8755, 0.9555, 0.4700, 1.0296, 0.8755, 0.8755, 0.1823,
        0.7885, 0.1823, 0.6931, 0.7885, 0.8755, 0.4700, 0.5878, 0.4700, 0.1823,
        0.3365, 0.7885, 0.1823, 0.9555, 0.9555, 1.0986, 0.5878, 1.0986, 0.4700,
        1.0296, 1.0986, 0.1823, 1.1632, 1.0296, 0.6931, 1.1632, 1.0986, 0.5878,
        0.4700, 1.0986, 0.3365, 0.6931, 0.6931, 0.5878, 0.6931, 0.3365, 0.4700,
        0.3365, 0.4700, 0.6931, 0.5878, 1.0986, 0.9555, 0.9555, 1.0296, 1.0986,
        1.0986, 0.7885, 1.0296, 0.7885, 0.8755, 0.7885, 1.2238, 0.8755, 0.9555,
        0.3365, 1.0296, 0.5878, 0.5878, 0.3365, 0.4700, 0.3365, 0.6931, 0.4700,
        0.6931, 0.7885, 0.7885, 0.3365, 

tensor([0.6931, 0.1823, 0.0000, 0.5878, 0.1823, 0.4700, 0.5878, 0.9555, 0.4700,
        0.6931, 0.1823, 0.1823, 0.4700, 0.4700, 0.5878, 0.6931, 0.5878, 0.6931,
        0.6931, 0.6931, 0.5878, 0.4700, 0.3365, 0.4700, 0.1823, 0.1823, 0.5878,
        0.5878, 0.1823, 0.1823, 0.5878, 0.8755, 0.9555, 0.8755, 0.3365, 0.8755,
        0.8755, 0.3365, 0.0000, 0.1823, 0.4700, 0.3365, 0.3365, 0.3365, 0.0000,
        0.4700, 1.0296, 0.1823, 0.5878, 0.3365, 1.0296, 1.0986, 0.4700, 0.1823,
        0.4700, 1.2238, 0.8755, 0.7885, 0.8755, 0.8755, 0.8755, 0.3365, 0.3365,
        0.4700, 0.9555, 0.4700, 0.4700, 0.3365, 0.1823, 0.1823, 0.3365, 0.3365,
        0.1823, 0.5878, 0.4700, 0.4700, 0.4700, 0.0000, 0.3365, 0.5878, 0.1823,
        0.7885, 0.6931, 0.6931, 0.6931, 0.7885, 0.8755, 0.1823, 1.3350, 0.1823,
        1.3350, 0.6931, 1.2809, 1.3350, 0.9555, 0.6931, 0.6931, 1.0986, 0.3365,
        0.7885, 0.9555, 0.3365, 0.7885, 0.4700, 0.7885, 0.6931, 0.3365, 0.9555,
        1.0986, 0.1823, 0.6931, 0.3365, 

tensor([0.6931, 0.0000, 0.0000, 0.4700, 0.5878, 0.5878, 0.1823, 0.7885, 0.3365,
        0.8755, 0.3365, 0.8755, 1.0296, 0.6931, 0.4700, 0.9555, 1.0296, 0.8755,
        0.7885, 1.0986, 0.1823, 0.9555, 0.7885, 1.2809, 1.1632, 0.1823, 0.6931,
        0.7885, 1.2809, 0.7885, 0.4700, 1.0296, 0.7885, 1.0986, 0.4700, 0.9555,
        1.0986, 1.2809, 0.8755, 1.1632, 0.3365, 1.2238, 1.0296, 1.0986, 0.9555,
        0.7885, 0.9555, 0.3365, 0.5878, 0.4700, 0.7885, 0.8755, 0.6931, 0.9555,
        0.8755, 0.5878, 1.3350, 0.9555, 0.8755, 0.5878, 1.1632, 1.0296, 0.5878,
        0.4700, 1.2809, 1.2238, 0.5878, 1.2809, 0.5878, 1.1632, 1.4816, 1.1632,
        1.0296, 1.0296, 0.4700, 1.2809, 0.5878, 0.7885, 0.7885, 0.6931, 0.4700,
        1.0986, 1.0296, 0.9555, 0.9555, 0.1823, 0.9555, 0.8755, 1.1632, 1.0986,
        1.0296, 0.1823, 0.1823, 1.1632, 1.0986, 0.7885, 0.9555, 0.9555, 0.8755,
        1.1632, 1.0296, 1.0296, 0.1823, 1.0986, 0.1823, 0.7885, 1.0296, 0.9555,
        0.6931, 0.4700, 0.9555, 0.8755, 

  8%|██████▎                                                                    | 2/24 [00:03<00:33,  1.54s/it]

tensor([0.8755, 0.7885, 0.8755, 0.5878, 0.8755, 0.8755, 0.4700, 0.7885, 0.5878,
        0.7885, 0.5878, 0.0000, 0.8755, 0.4700, 0.8755, 0.4700, 0.7885, 0.1823,
        0.6931, 0.4700, 0.5878, 0.6931, 1.4351, 0.8755, 1.0296, 0.8755, 0.9555,
        1.2809, 1.0296, 1.2809, 1.0986, 0.8755, 0.7885, 1.0296, 0.9555, 0.7885,
        0.7885, 1.3350, 0.8755, 0.6931, 0.5878, 0.9555, 0.7885, 0.6931, 0.8755,
        0.9555, 0.5878, 0.1823, 1.0986, 0.3365, 0.6931, 0.6931, 0.4700, 0.8755,
        0.9555, 1.3863, 0.8755, 1.2238, 1.0296, 0.7885, 0.5878, 0.4700, 0.4700,
        1.0986, 1.2809, 0.6931, 1.2238, 0.5878, 1.2238, 0.4700, 0.4700, 1.1632,
        1.2809, 1.1632, 0.5878, 0.4700, 0.5878, 1.3350, 0.5878, 0.6931, 0.5878,
        0.5878, 0.6931, 0.5878, 0.5878, 1.1632, 1.0296, 0.1823, 0.9555, 1.2809,
        1.1632, 0.8755, 0.4700, 0.7885, 1.3350, 1.0986, 0.8755, 0.5878, 0.6931,
        0.6931, 0.7885, 0.6931, 0.9555, 0.8755, 0.1823, 0.0000, 0.1823, 0.4700,
        0.3365, 0.4700, 0.4700, 0.6931, 

tensor([0.7885, 0.8755, 0.6931, 0.5878, 0.8755, 0.4700, 0.6931, 0.4700, 0.8755,
        0.4700, 0.3365, 0.8755, 0.6931, 0.8755, 0.7885, 0.8755, 0.9555, 0.8755,
        0.8755, 0.6931, 0.6931, 0.1823, 0.5878, 0.7885, 0.5878, 0.6931, 0.4700,
        1.1632, 0.5878, 0.5878, 0.0000, 1.1632, 0.4700, 0.4700, 0.3365, 0.8755,
        0.7885, 0.7885, 0.8755, 0.6931, 1.0296, 0.8755, 0.0000, 1.4816, 1.0986,
        0.3365, 0.0000, 0.3365, 1.0296, 0.1823, 1.0296, 0.5878, 0.4700, 0.6931,
        0.6931, 0.4700, 1.0986, 0.5878, 0.5878, 0.3365, 1.1632, 0.9555, 1.5261,
        1.5261, 1.2238, 1.0986, 0.6931, 0.9555, 0.3365, 0.7885, 0.6931, 0.1823,
        0.3365, 0.0000, 0.6931, 0.4700, 0.0000, 0.1823, 0.4700, 0.6931, 0.1823,
        0.7885, 0.0000, 0.3365, 0.8755, 0.0000, 0.6931, 1.1632, 0.1823, 1.0296,
        1.2809, 0.4700, 1.0296, 1.2809, 0.9555, 0.8755, 0.8755, 0.8755, 0.4700,
        1.2238, 1.0296, 0.7885, 0.4700, 0.5878, 0.7885, 0.5878, 0.4700, 1.0296,
        0.9555, 0.3365, 0.8755, 0.7885, 

tensor([0.6931, 0.6931, 0.6931, 0.0000, 0.4700, 0.4700, 0.6931, 0.4700, 0.4700,
        0.6931, 0.3365, 0.1823, 0.1823, 0.8755, 0.0000, 0.4700, 0.5878, 0.1823,
        0.4700, 0.4700, 0.1823, 0.4700, 0.5878, 0.4700, 0.3365, 0.1823, 0.4700,
        1.2238, 1.2238, 1.2809, 1.1632, 1.0986, 0.8755, 1.1632, 0.3365, 0.1823,
        1.0296, 0.1823, 1.0296, 0.5878, 0.4700, 0.6931, 0.4700, 1.0296, 0.4700,
        1.0986, 0.8755, 0.9555, 0.8755, 0.9555, 0.4700, 1.0986, 1.0986, 0.3365,
        0.3365, 0.4700, 0.3365, 0.1823, 0.4700, 0.1823, 0.8755, 0.1823, 1.0986,
        0.9555, 0.7885, 1.0986, 1.3350, 0.3365, 0.4700, 1.2238, 0.8755, 0.4700,
        0.8755, 0.9555, 0.5878, 0.3365, 0.0000, 0.5878, 0.6931, 0.5878, 0.5878,
        0.8755, 1.3350, 0.7885, 0.4700, 0.5878, 0.3365, 0.4700, 0.8755, 0.6931,
        1.0296, 0.7885, 1.2238, 0.9555, 1.1632, 1.0296, 1.3863, 1.0986, 0.8755,
        0.4700, 0.4700, 0.8755, 0.3365, 0.6931, 0.3365, 0.4700, 0.4700, 0.9555,
        0.6931, 0.5878, 1.0296, 0.6931, 

tensor([0.3365, 0.5878, 0.3365, 0.5878, 0.1823, 0.4700, 0.7885, 0.4700, 0.4700,
        0.6931, 0.5878, 0.4700, 0.1823, 0.8755, 0.3365, 1.1632, 0.0000, 1.1632,
        0.5878, 1.0296, 0.0000, 1.0296, 0.9555, 0.7885, 0.8755, 0.7885, 0.6931,
        0.7885, 0.1823, 0.7885, 0.5878, 1.0296, 1.0296, 0.5878, 0.0000, 1.1632,
        1.1632, 0.4700, 0.8755, 0.8755, 0.9555, 0.6931, 1.0296, 0.9555, 0.5878,
        0.6931, 1.0986, 1.0986, 0.9555, 0.5878, 0.1823, 1.2238, 0.7885, 0.9555,
        0.8755, 1.0296, 0.5878, 0.5878, 0.9555, 0.5878, 0.8755, 1.2238, 1.5686,
        1.2809, 1.5686, 1.0296, 1.1632, 1.2238, 1.2238, 1.3863, 1.1632, 1.2238,
        0.4700, 0.4700, 1.2238, 1.3350, 0.4700, 0.8755, 0.5878, 0.5878, 0.4700,
        0.7885, 0.7885, 0.4700, 1.0296, 1.0986, 0.4700, 0.6931, 0.6931, 0.7885,
        1.6094, 1.0296, 1.1632, 0.6931, 0.6931, 0.6931, 1.6094, 1.0296, 1.4816,
        1.1632, 1.2238, 1.1632, 0.5878, 0.6931, 0.5878, 1.0296, 1.2238, 1.0986,
        1.0296, 0.9555, 1.1632, 1.0296, 

 12%|█████████▍                                                                 | 3/24 [00:04<00:31,  1.48s/it]

tensor([0.3365, 0.5878, 0.8755, 0.7885, 0.6931, 0.7885, 1.0296, 0.5878, 0.5878,
        0.4700, 0.1823, 0.7885, 0.7885, 0.6931, 0.5878, 0.8755, 0.3365, 0.4700,
        0.3365, 0.4700, 0.3365, 1.1632, 0.6931, 1.0296, 0.5878, 1.0296, 1.0986,
        1.3350, 1.3350, 0.5878, 1.2238, 1.1632, 1.3863, 1.4351, 1.0986, 1.0986,
        1.0986, 0.6931, 0.5878, 0.6931, 1.3350, 0.5878, 0.6931, 1.6094, 1.5261,
        0.4700, 0.5878, 0.4700, 1.1632, 0.3365, 0.6931, 0.7885, 0.7885, 0.3365,
        0.5878, 1.1632, 0.4700, 1.1632, 1.0296, 0.5878, 1.5686, 1.2809, 1.0986,
        0.5878, 1.4351, 0.5878, 1.1632, 0.5878, 0.4700, 0.4700, 0.4700, 0.7885,
        1.0986, 0.7885, 0.5878, 0.7885, 0.8755, 0.7885, 1.0296, 1.1632, 1.0296,
        0.5878, 0.5878, 0.9555, 0.9555, 0.4700, 0.3365, 0.5878, 0.3365, 0.6931,
        0.9555, 1.0986, 0.7885, 0.7885, 0.8755, 1.0986, 0.9555, 0.5878, 1.0986,
        0.8755, 0.6931, 0.9555, 0.6931, 0.6931, 0.6931, 0.8755, 0.6931, 1.0296,
        0.5878, 0.7885, 0.4700, 0.5878, 

tensor([1.0986, 1.1632, 0.5878, 0.7885, 0.6931, 0.7885, 0.6931, 0.4700, 1.3350,
        0.8755, 1.3350, 0.6931, 0.4700, 0.7885, 0.5878, 1.0296, 0.6931, 0.4700,
        0.7885, 0.5878, 0.7885, 0.6931, 0.7885, 0.7885, 0.6931, 0.6931, 0.6931,
        0.5878, 0.6931, 0.6931, 0.8755, 0.9555, 0.3365, 0.3365, 0.6931, 0.6931,
        0.5878, 0.0000, 0.4700, 0.3365, 0.6931, 0.6931, 0.3365, 0.3365, 0.5878,
        0.1823, 0.9555, 0.8755, 0.3365, 0.4700, 0.1823, 0.7885, 1.1632, 0.7885,
        0.3365, 0.3365, 0.4700, 0.6931, 0.5878, 0.4700, 0.6931, 0.8755, 0.6931,
        0.6931, 0.7885, 0.1823, 0.1823, 0.5878, 0.5878, 1.3350, 0.4700, 1.3350,
        0.4700, 0.3365, 0.6931, 0.6931, 0.6931, 0.7885, 0.6931, 0.5878, 0.6931,
        0.1823, 0.3365, 0.6931, 0.6931, 0.1823, 0.7885, 0.5878, 1.0986, 0.0000,
        0.5878, 0.7885, 0.5878, 0.8755, 0.5878, 0.4700, 0.3365, 0.1823, 0.3365,
        0.5878, 0.5878, 0.8755, 0.6931, 0.9555, 0.7885, 0.3365, 0.6931, 0.1823,
        0.5878, 0.1823, 0.3365, 0.1823, 

tensor([1.2238, 1.1632, 0.3365, 1.2809, 0.4700, 0.4700, 0.0000, 0.4700, 0.3365,
        1.0296, 0.7885, 0.1823, 0.4700, 0.6931, 0.3365, 0.3365, 0.1823, 0.0000,
        0.5878, 0.4700, 0.5878, 0.5878, 0.6931, 0.4700, 0.1823, 0.5878, 0.6931,
        1.1632, 1.2809, 0.4700, 0.5878, 0.4700, 0.3365, 1.2809, 0.5878, 1.2238,
        0.8755, 0.1823, 0.1823, 0.5878, 0.5878, 0.5878, 1.4351, 0.4700, 0.9555,
        0.6931, 0.1823, 0.6931, 0.0000, 0.7885, 0.5878, 0.1823, 0.0000, 0.1823,
        0.0000, 0.8755, 0.0000, 0.9555, 0.8755, 0.8755, 1.0296, 0.0000, 1.0296,
        0.1823, 1.0986, 1.0296, 0.8755, 0.4700, 0.3365, 0.4700, 0.4700, 1.0296,
        0.3365, 0.3365, 0.4700, 0.4700, 0.4700, 1.0296, 1.1632, 0.5878, 0.1823,
        0.3365, 0.3365, 1.0296, 1.0986, 0.0000, 0.6931, 1.1632, 0.4700, 1.0986,
        0.4700, 0.9555, 1.0296, 1.1632, 0.9555, 0.6931, 0.5878, 0.3365, 0.9555,
        0.3365, 0.4700, 0.1823, 1.0296, 0.1823, 0.5878, 0.6931, 0.6931, 0.4700,
        0.8755, 0.1823, 1.0986, 0.8755, 

tensor([0.9555, 0.1823, 0.1823, 0.1823, 0.5878, 0.8755, 0.5878, 0.7885, 0.7885,
        0.7885, 1.0296, 0.5878, 0.7885, 0.5878, 0.3365, 0.5878, 0.9555, 0.3365,
        0.0000, 0.5878, 0.6931, 1.0986, 0.7885, 0.4700, 0.6931, 0.5878, 0.7885,
        0.7885, 0.3365, 0.5878, 0.7885, 0.9555, 0.4700, 0.6931, 0.5878, 0.3365,
        0.1823, 0.0000, 0.1823, 0.1823, 0.3365, 0.0000, 0.3365, 0.1823, 0.3365,
        0.1823, 0.4700, 0.4700, 0.1823, 0.1823, 0.0000, 0.1823, 0.1823, 0.5878,
        0.0000, 0.0000, 0.0000, 0.6931, 0.4700, 1.1632, 0.4700, 1.0296, 1.1632,
        1.1632, 0.7885, 0.7885, 0.1823, 0.3365, 0.5878, 0.3365, 0.4700, 0.4700,
        0.3365, 0.7885, 0.6931, 0.0000, 0.0000, 0.4700, 0.4700, 0.6931, 0.1823,
        0.6931, 0.8755, 1.1632, 1.1632, 0.0000, 0.6931, 0.5878, 0.7885, 0.1823,
        0.4700, 0.3365, 1.3350, 1.2809, 1.2809, 0.4700, 0.6931, 0.4700, 0.9555,
        0.6931, 0.0000, 0.6931, 0.3365, 0.3365, 0.3365, 0.6931, 0.8755, 0.5878,
        0.8755, 0.1823, 0.8755, 0.6931, 

 17%|████████████▌                                                              | 4/24 [00:05<00:28,  1.42s/it]

tensor([1.2809, 0.7885, 1.0296, 0.4700, 0.3365, 1.5261, 0.7885, 0.4700, 1.3863,
        0.9555, 1.0296, 1.2238, 1.0296, 1.0986, 0.9555, 0.8755, 0.1823, 0.0000,
        0.1823, 0.0000, 0.3365, 1.5261, 1.3863, 1.5261, 1.5261, 0.5878, 0.6931,
        1.2809, 1.0296, 0.6931, 0.7885, 0.5878, 0.7885, 0.7885, 0.8755, 0.8755,
        0.3365, 0.4700, 0.3365, 0.5878, 0.8755, 0.7885, 0.8755, 0.0000, 0.4700,
        0.7885, 0.4700, 1.0986, 0.3365, 0.5878, 0.8755, 1.0296, 0.9555, 0.9555,
        0.8755, 1.0296, 0.7885, 0.9555, 1.4351, 0.7885, 1.2809, 0.6931, 0.8755,
        0.6931, 1.3863, 1.0986, 1.4351, 0.8755, 0.3365, 0.3365, 1.1632, 0.7885,
        0.1823, 1.6094, 0.7885, 1.1632, 0.7885, 0.4700, 0.3365, 0.8755, 0.0000,
        0.5878, 0.6931, 0.7885, 0.9555, 0.7885, 0.1823, 0.8755, 1.0296, 0.6931,
        1.0296, 0.8755, 1.0296, 0.6931, 0.1823, 0.6931, 0.6931, 0.6931, 0.3365,
        0.3365, 0.0000, 0.6931, 0.6931, 0.1823, 0.5878, 0.4700, 0.4700, 0.6931,
        0.7885, 0.6931, 0.4700, 0.1823, 

tensor([0.3365, 0.4700, 0.3365, 0.3365, 0.1823, 0.3365, 0.5878, 0.7885, 0.5878,
        0.4700, 0.0000, 0.7885, 0.4700, 0.3365, 0.7885, 0.5878, 0.3365, 0.6931,
        0.4700, 0.3365, 0.0000, 0.3365, 0.3365, 0.0000, 0.1823, 0.1823, 0.1823,
        0.9555, 0.4700, 0.1823, 0.7885, 0.1823, 0.3365, 0.8755, 0.3365, 0.1823,
        0.7885, 0.7885, 0.1823, 0.6931, 0.3365, 0.6931, 0.1823, 0.7885, 0.0000,
        0.1823, 0.0000, 0.0000, 0.1823, 0.1823, 0.0000, 0.1823, 0.1823, 0.1823,
        0.1823, 0.5878, 0.5878, 0.5878, 0.7885, 0.5878, 1.4351, 1.4351, 0.6931,
        1.0296, 1.3350, 0.9555, 0.5878, 0.6931, 0.5878, 0.7885, 0.3365, 0.8755,
        1.2809, 0.3365, 1.0986, 0.6931, 0.3365, 0.0000, 0.4700, 0.0000, 0.6931,
        1.0986, 0.6931, 0.5878, 0.7885, 0.7885, 1.4816, 1.0296, 1.0986, 0.3365,
        1.0986, 0.4700, 1.1632, 1.4816, 0.6931, 0.7885, 0.8755, 1.1632, 0.3365,
        0.8755, 0.6931, 1.4351, 0.5878, 0.7885, 0.6931, 0.1823, 1.0986, 1.0296,
        0.5878, 0.7885, 0.6931, 0.4700, 

tensor([0.7885, 0.9555, 1.0986, 0.3365, 0.1823, 0.8755, 0.0000, 0.5878, 0.3365,
        0.3365, 0.6931, 0.6931, 0.5878, 0.7885, 0.7885, 0.6931, 0.7885, 0.8755,
        0.8755, 0.4700, 0.5878, 0.5878, 0.4700, 0.4700, 0.8755, 0.1823, 0.7885,
        0.6931, 0.4700, 0.7885, 0.4700, 0.4700, 0.8755, 0.4700, 0.6931, 0.1823,
        0.6931, 0.6931, 1.2809, 0.7885, 0.4700, 0.8755, 0.8755, 0.4700, 0.6931,
        0.8755, 0.9555, 0.5878, 1.2238, 1.0296, 0.7885, 0.8755, 0.5878, 0.4700,
        0.4700, 0.4700, 0.5878, 0.8755, 0.6931, 0.0000, 0.7885, 0.5878, 0.6931,
        0.3365, 1.0296, 0.7885, 1.1632, 0.1823, 0.4700, 0.4700, 0.9555, 0.8755,
        0.8755, 0.6931, 0.7885, 0.3365, 0.5878, 0.3365, 0.9555, 0.4700, 1.3350,
        0.7885, 0.9555, 0.8755, 1.2238, 1.2238, 1.0986, 1.0296, 1.3350, 0.3365,
        1.4816, 0.7885, 0.3365, 0.3365, 0.5878, 0.9555, 0.5878, 0.6931, 0.5878,
        0.4700, 0.7885, 0.1823, 0.0000, 0.6931, 0.4700, 0.0000, 0.1823, 0.4700,
        0.5878, 0.4700, 0.9555, 0.4700, 

tensor([0.4700, 0.4700, 0.1823, 0.0000, 0.0000, 0.5878, 1.0296, 0.8755, 0.9555,
        0.7885, 0.9555, 0.3365, 1.0296, 0.8755, 1.1632, 0.8755, 0.6931, 0.7885,
        1.1632, 0.4700, 1.0296, 0.7885, 0.6931, 0.9555, 0.7885, 1.2809, 0.6931,
        1.2809, 0.5878, 0.3365, 1.0986, 0.3365, 1.2238, 1.0986, 0.9555, 1.0986,
        0.9555, 0.7885, 1.2238, 1.2238, 0.9555, 0.5878, 0.5878, 0.6931, 1.0296,
        0.4700, 0.8755, 0.9555, 0.4700, 0.4700, 1.2809, 0.9555, 0.9555, 0.8755,
        1.2238, 0.4700, 1.0296, 1.0986, 1.3863, 1.6487, 1.0296, 0.9555, 0.8755,
        1.3863, 1.3350, 1.3863, 1.0296, 0.5878, 1.4351, 0.8755, 1.6094, 0.9555,
        1.0296, 1.0986, 1.0986, 1.2238, 0.1823, 0.9555, 1.2238, 0.8755, 0.7885,
        0.6931, 0.8755, 0.8755, 1.0296, 0.6931, 0.6931, 1.2809, 0.9555, 0.5878,
        1.3863, 1.0296, 0.4700, 0.7885, 0.6931, 0.8755, 1.0986, 0.6931, 0.4700,
        0.6931, 0.8755, 0.1823, 0.6931, 0.1823, 1.0986, 0.7885, 0.7885, 0.4700,
        0.4700, 0.4700, 0.1823, 1.0986, 

 21%|███████████████▋                                                           | 5/24 [00:07<00:27,  1.46s/it]

tensor([0.5878, 0.6931, 0.5878, 1.0296, 1.2238, 1.2238, 1.3350, 0.8755, 0.9555,
        1.0986, 0.8755, 0.5878, 0.4700, 1.0296, 0.5878, 0.3365, 1.0986, 1.2238,
        1.2238, 1.2809, 1.1632, 0.8755, 0.9555, 0.9555, 1.2809, 0.9555, 1.0296,
        1.0296, 0.7885, 0.8755, 1.0296, 1.1632, 1.1632, 1.3350, 1.4816, 1.3350,
        0.9555, 1.2809, 1.2238, 1.0296, 1.2809, 1.0986, 1.0986, 1.3350, 1.3350,
        0.8755, 1.0986, 1.4816, 0.6931, 1.4816, 0.5878, 0.1823, 1.3350, 1.0986,
        1.4351, 1.0986, 1.1632, 0.5878, 1.5261, 1.0296, 0.9555, 0.8755, 0.3365,
        0.5878, 1.2238, 1.2238, 1.1632, 1.1632, 1.0986, 0.9555, 1.3350, 1.3863,
        1.2238, 0.9555, 0.3365, 0.3365, 0.4700, 0.0000, 0.3365, 0.3365, 0.8755,
        0.9555, 0.9555, 0.9555, 0.9555, 1.0986, 1.4351, 0.5878, 1.2809, 0.9555,
        0.9555, 0.8755, 0.7885, 0.4700, 0.7885, 0.5878, 1.2809, 0.3365, 0.4700,
        0.4700, 0.4700, 0.0000, 1.2238, 1.1632, 1.2809, 0.5878, 0.6931, 0.9555,
        0.4700, 0.4700, 1.0986, 0.3365, 

tensor([0.8755, 0.5878, 0.0000, 0.1823, 1.1632, 0.7885, 0.8755, 0.6931, 0.4700,
        1.0296, 0.9555, 0.7885, 0.4700, 0.5878, 0.4700, 0.4700, 0.4700, 0.4700,
        0.4700, 0.0000, 0.0000, 0.1823, 0.8755, 0.1823, 0.8755, 0.3365, 0.6931,
        0.5878, 0.5878, 0.5878, 0.5878, 0.3365, 0.1823, 0.9555, 0.5878, 0.6931,
        0.6931, 0.1823, 0.1823, 0.4700, 0.8755, 0.1823, 0.6931, 0.6931, 0.4700,
        0.0000, 0.4700, 0.3365, 0.4700, 0.4700, 0.5878, 0.4700, 0.1823, 0.6931,
        0.4700, 0.6931, 1.0296, 0.6931, 0.1823, 0.5878, 0.7885, 0.4700, 0.6931,
        1.2238, 1.0986, 0.9555, 0.1823, 0.1823, 0.4700, 0.4700, 0.3365, 0.4700,
        0.5878, 0.3365, 0.3365, 0.3365, 0.5878, 0.3365, 0.3365, 0.4700, 0.3365,
        0.3365, 0.4700, 0.4700, 0.7885, 0.8755, 0.9555, 0.6931, 0.1823, 0.4700,
        0.9555, 0.4700, 0.9555, 0.5878, 0.7885, 0.8755, 1.0296, 0.9555, 0.1823,
        0.1823, 0.0000, 0.8755, 0.1823, 0.4700, 0.1823, 0.6931, 0.3365, 1.0296,
        0.7885, 0.8755, 0.5878, 0.1823, 

tensor([0.5878, 0.5878, 0.7885, 0.5878, 0.8755, 0.4700, 0.1823, 0.5878, 0.1823,
        0.5878, 0.3365, 0.4700, 0.4700, 0.4700, 0.4700, 0.0000, 0.4700, 0.1823,
        0.4700, 0.9555, 0.5878, 0.1823, 0.1823, 0.4700, 0.3365, 0.0000, 0.1823,
        0.5878, 0.5878, 0.6931, 0.3365, 0.6931, 0.5878, 0.3365, 0.6931, 0.5878,
        0.3365, 0.1823, 0.0000, 0.3365, 0.6931, 0.5878, 0.5878, 0.5878, 0.0000,
        0.7885, 0.3365, 0.0000, 0.1823, 0.1823, 0.4700, 0.1823, 0.3365, 0.3365,
        0.3365, 0.5878, 0.3365, 0.6931, 0.3365, 0.3365, 0.7885, 1.1632, 0.6931,
        0.3365, 0.9555, 0.7885, 1.0296, 0.8755, 0.4700, 0.4700, 0.3365, 0.6931,
        0.6931, 0.3365, 0.0000, 0.6931, 0.5878, 1.1632, 0.6931, 0.9555, 0.6931,
        0.9555, 0.9555, 0.1823, 0.7885, 0.8755, 0.7885, 0.3365, 0.4700, 0.7885,
        0.4700, 0.6931, 1.0986, 1.0986, 0.7885, 0.9555, 0.4700, 0.9555, 0.1823,
        0.4700, 0.4700, 0.1823, 0.8755, 0.5878, 0.7885, 0.0000, 0.1823, 0.4700,
        0.1823, 0.9555, 0.7885, 0.4700, 

tensor([1.0296, 1.2238, 0.9555, 0.5878, 0.3365, 0.3365, 0.0000, 0.6931, 0.5878,
        0.8755, 0.8755, 0.1823, 1.0296, 1.0986, 0.3365, 0.9555, 0.5878, 0.5878,
        0.9555, 1.2238, 1.2809, 1.2238, 0.0000, 1.0296, 1.0986, 0.6931, 0.9555,
        1.1632, 1.0296, 1.0986, 1.3863, 1.1632, 1.2238, 1.2238, 0.9555, 0.9555,
        0.9555, 0.9555, 0.9555, 0.8755, 1.2809, 1.1632, 0.9555, 1.0296, 0.7885,
        0.0000, 0.4700, 0.9555, 0.3365, 0.3365, 0.5878, 0.5878, 0.3365, 0.8755,
        0.1823, 1.0296, 1.1632, 0.9555, 0.3365, 1.1632, 1.0296, 0.3365, 1.1632,
        1.2238, 1.2238, 1.0296, 0.7885, 0.7885, 0.9555, 0.5878, 1.0296, 0.9555,
        0.4700, 0.6931, 0.5878, 0.5878, 0.1823, 0.7885, 0.8755, 0.5878, 0.7885,
        0.9555, 0.7885, 0.9555, 0.8755, 1.0296, 0.9555, 0.6931, 0.6931, 0.6931,
        0.8755, 0.6931, 0.1823, 1.0296, 0.7885, 0.4700, 0.4700, 1.2238, 0.4700,
        0.4700, 0.1823, 0.1823, 0.4700, 0.1823, 0.4700, 1.2809, 0.6931, 0.3365,
        0.1823, 0.4700, 0.4700, 1.2238, 

 25%|██████████████████▊                                                        | 6/24 [00:08<00:27,  1.50s/it]

tensor([0.9555, 1.0986, 0.0000, 1.3350, 1.0986, 0.7885, 1.4816, 1.0296, 1.0986,
        1.2238, 0.7885, 0.1823, 1.0296, 0.9555, 0.0000, 0.9555, 0.5878, 1.0296,
        1.0296, 0.1823, 0.5878, 0.4700, 0.4700, 0.1823, 0.6931, 0.5878, 0.3365,
        0.5878, 0.5878, 0.9555, 0.9555, 0.8755, 0.5878, 0.6931, 0.8755, 0.8755,
        0.9555, 0.6931, 1.0986, 0.1823, 0.9555, 0.9555, 0.5878, 0.8755, 0.0000,
        0.1823, 0.1823, 0.9555, 0.6931, 0.3365, 0.4700, 0.9555, 1.0986, 0.1823,
        0.5878, 0.8755, 1.0986, 0.3365, 0.6931, 0.3365, 0.6931, 0.4700, 0.4700,
        0.6931, 0.7885, 0.5878, 0.5878, 0.3365, 0.5878, 0.6931, 1.0296, 0.5878,
        0.5878, 0.6931, 0.0000, 0.4700, 0.1823, 0.1823, 0.3365, 0.0000, 0.3365,
        0.1823, 1.0296, 0.1823, 1.2238, 1.0986, 0.4700, 1.2809, 1.2238, 0.9555,
        0.6931, 0.5878, 0.3365, 0.7885, 1.0986, 0.8755, 0.5878, 1.1632, 0.6931,
        0.7885, 1.2238, 0.4700, 0.4700, 0.3365, 0.8755, 0.5878, 0.3365, 0.5878,
        0.5878, 0.9555, 0.1823, 1.1632, 

 25%|██████████████████▊                                                        | 6/24 [00:08<00:26,  1.50s/it]


KeyboardInterrupt: 

In [86]:
# Display the per-tolerance loss breakdown collected by the sweep cell
# (keys: gating tolerance -> {"Loss (MSE)", "Loss for non-0s", "Loss for 0s"}).
zero_separated_loss_dict

{0: {'Loss (MSE)': 0.06835666298866272,
  'Loss for non-0s': 0.06468713283538818,
  'Loss for 0s': 0.1363668292760849},
 0.25: {'Loss (MSE)': 0.07899671792984009,
  'Loss for non-0s': 0.06423668563365936,
  'Loss for 0s': 0.1707461029291153},
 0.5: {'Loss (MSE)': 0.10136043280363083,
  'Loss for non-0s': 0.07536786049604416,
  'Loss for 0s': 0.18052752315998077},
 0.75: {'Loss (MSE)': 0.1291235089302063,
  'Loss for non-0s': 0.10279928147792816,
  'Loss for 0s': 0.17537926137447357},
 1: {'Loss (MSE)': 0.17137663066387177,
  'Loss for non-0s': 0.21866212785243988,
  'Loss for 0s': 0.1372186243534088},
 1.5: {'Loss (MSE)': 0.17394901812076569,
  'Loss for non-0s': 0.4264993667602539,
  'Loss for 0s': 0.08794506639242172},
 2: {'Loss (MSE)': 0.12416359782218933,
  'Loss for non-0s': 0.866157054901123,
  'Loss for 0s': 0.03509766235947609}}

In [6]:
import seaborn as sns
import matplotlib.pyplot as plt

# Reset matplotlib rc settings to the library defaults, undoing any rc changes
# made earlier in this kernel session.
plt.rcParams.update(plt.rcParamsDefault)

In [5]:
# Plot how the overall / zero-target / non-zero-target MSE vary with the gating
# tolerance. Tolerances come from the sweep results, so a partial sweep still plots.
tol_list = sorted(zero_separated_loss_dict)
loss_series = {  # legend label -> key in each per-tolerance result dict
    "Loss": "Loss (MSE)",
    "Loss for 0s": "Loss for 0s",
    "Loss for non-0s": "Loss for non-0s",
}

fig, ax = plt.subplots()
for label, key in loss_series.items():
    ax.plot(tol_list, [zero_separated_loss_dict[tol][key] for tol in tol_list], marker='o', label=label)
ax.legend(loc="upper left")
ax.set(title="Losses", xlabel="Gating Tolerance", ylabel="MSE")

# Save BEFORE showing: with the inline backend plt.show() closes the current
# figure, so saving afterwards wrote an empty image.
fig.savefig("zero_inflated_losses.png")
plt.show()

NameError: name 'plt' is not defined

In [2]:
import hydra
from hydra.experimental import compose, initialize

# Evaluate the config8 (synthetic8) model with radius 0.
# Results go into synth8_* names so this cell no longer resets/overwrites the
# per-tolerance dictionaries (`test_loss_rad_dict`, `zero_separated_loss_dict`)
# that the sweep cell fills and the loss plot reads.
with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config8")
    OmegaConf.update(cfg_from_terminal, "model.kwargs.hidden_dimensions", [128, 128])
    OmegaConf.update(cfg_from_terminal, "training.logger_name", "synthetic8")
    OmegaConf.update(cfg_from_terminal, "radius", 0)
    OmegaConf.update(cfg_from_terminal, "datasets.dataset.splits", 2)
    output = test(cfg_from_terminal)

trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
synth8_test_loss = test_results[0]['test_loss']

# Split the first gene's squared error by whether the target value is exactly 0.
# Masked means replace the original count-rescaling, which divided by zero when
# one of the subsets was empty; an empty subset now reports NaN.
synth8_target = inputs[:, 0]
synth8_sq_err = (synth8_target - gene_expressions[:, 0]) ** 2
synth8_is_zero = synth8_target == 0
synth8_zero_separated_losses = {
    "Loss (MSE)": synth8_sq_err.mean().item(),
    "Loss for non-0s": synth8_sq_err[~synth8_is_zero].mean().item() if (~synth8_is_zero).any() else float("nan"),
    "Loss for 0s": synth8_sq_err[synth8_is_zero].mean().item() if synth8_is_zero.any() else float("nan"),
}

synth8_zero_separated_losses

100%|██████████████████████████████████████████████████████████████████████████| 24/24 [00:26<00:00,  1.12s/it]


/home/roko/spatial/data/raw/synth8.hdf5


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing: 0it [00:00, ?it/s]

TEST Profiler Report

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                          	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.14432787895202637
     test_loss: mse         0.14432787895202637
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


{'Loss (MSE)': 0.14432787895202637,
 'Loss for non-0s': 0.04854154214262962,
 'Loss for 0s': 1.1739932298660278}

In [9]:
# How many predictions for the first gene are exactly zero?
(gene_expressions[:, 0] == 0).sum()

tensor(0)

In [11]:
# Display `inputs` as most recently returned by `test(...)` (the synthetic8 run above).
inputs

tensor([[0.2091, 0.0979, 0.0044,  ..., 0.0429, 0.2568, 0.1001],
        [0.1587, 0.0586, 0.0091,  ..., 0.0725, 0.0080, 0.0110],
        [0.1272, 0.0419, 0.0811,  ..., 0.0620, 0.0398, 0.0980],
        ...,
        [0.0865, 0.0782, 0.1126,  ..., 0.1696, 0.1441, 0.0077],
        [0.1343, 0.0864, 0.0276,  ..., 0.1115, 0.0504, 0.0668],
        [0.2595, 0.1641, 0.0397,  ..., 0.0057, 0.0145, 0.0018]])

In [17]:
# Print tensors in fixed-point notation (this persists for the whole session),
# then show the predictions rounded to 4 decimal places.
torch.set_printoptions(sci_mode=False)
gene_expressions.round(decimals=4)

tensor([[     0.1806,      1.5100,     -2.2698,  ...,      2.4237,
             -1.2052,      2.8728],
        [     0.1502,     -0.8261,     -0.0012,  ...,     -0.3826,
             -1.3962,      1.8240],
        [     0.1346,     -1.9792,      0.9231,  ...,     -1.0029,
             -1.1496,     -0.1602],
        ...,
        [     0.1785,      0.1571,     -0.1246,  ...,      1.8092,
             -3.2476,      0.5719],
        [     0.1861,      0.7607,      0.0453,  ...,     -0.8183,
             -0.5509,      0.1121],
        [     0.2606,     -0.1781,     -1.4081,  ...,     -0.3479,
             -0.2700,      0.2911]])