## The purpose of this notebook is to attempt to build a graph with no neighbors while still cooperating with the PyG framework.

In [26]:
import os
import types

import h5py
import numpy as np
import pandas as pd
import requests
import torch
import torch_geometric
from sklearn import neighbors


class MerfishDataset(torch_geometric.data.InMemoryDataset):
    def __init__(
        self,
        root,
        n_neighbors=3,
        train=True,
        log_transform=True,
        non_response_genes_file="/home/roko/spatial/spatial/"
        "non_response_blank_removed.txt",
    ):
        super().__init__(root)

        # non-response genes (columns) in MERFISH
        with open(non_response_genes_file, "r", encoding="utf8") as genes_file:
            self.features = [int(x) for x in genes_file.read().split(",")]
            genes_file.close()

        # response genes (columns in MERFISH)
        self.responses = list(set(range(155)) - set(self.features))

        data_list = self.construct_graphs(n_neighbors, train, log_transform)

        with h5py.File(self.merfish_hdf5, "r") as h5f:
            self.gene_names = h5f["gene_names"][:][~self.bad_genes].astype("U")

        self.data, self.slices = self.collate(data_list)

    # from https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248
    url = "https://datadryad.org/stash/downloads/file_stream/67671"

    behavior_types = [
        "Naive",
        "Parenting",
        "Virgin Parenting",
        "Aggression to pup",
        "Aggression to adult",
        "Mating",
    ]
    behavior_lookup = {x: i for (i, x) in enumerate(behavior_types)}
    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    celltype_lookup = {x: i for (i, x) in enumerate(cell_types)}

    bad_genes = np.zeros(161, dtype=bool)
    bad_genes[[12, 13, 14, 15, 16, 144]] = True

    @property
    def raw_file_names(self):
        return ["merfish.csv", "merfish.hdf5"]

    @property
    def merfish_csv(self):
        return os.path.join(self.raw_dir, "merfish.csv")

    @property
    def merfish_hdf5(self):
        return os.path.join(self.raw_dir, "merfish.hdf5")

    def download(self):
        # download csv if necessary
        if not os.path.exists(self.merfish_csv):
            with open(self.merfish_csv, "wb") as csvf:
                csvf.write(requests.get(self.url).content)

        # process csv if necessary
        dataframe = pd.read_csv(self.merfish_csv)

        with h5py.File(self.merfish_hdf5, "w") as h5f:
            # pylint: disable=no-member
            for colnm, dtype in zip(dataframe.keys()[:9], dataframe.dtypes[:9]):
                if dtype.kind == "O":
                    data = np.require(dataframe[colnm], dtype="S36")
                    h5f.create_dataset(colnm, data=data)
                else:
                    h5f.create_dataset(colnm, data=np.require(dataframe[colnm]))

            expression = np.array(dataframe[dataframe.keys()[9:]]).astype(np.float16)
            h5f.create_dataset("expression", data=expression)

            gene_names = np.array(dataframe.keys()[9:], dtype="S80")
            h5f.create_dataset("gene_names", data=gene_names)

    def construct_graph(self, data, anid, breg, n_neighbors, log_transform):
        # get subset of cells in this slice
        good = (data.anids == anid) & (data.bregs == breg)

        # figure out neighborhood structure
        locations_for_this_slice = data.locations[good]
# LINES THAT WERE CHANGED
###############################################################################################        
        if n_neighbors == 0:
            edges = np.concatenate(
                [np.c_[np.array([i]), np.array([i])] for i in range(locations_for_this_slice.shape[0])],
                axis=0
            )
            print(edges)
        
        else:
        
            nbrs = neighbors.NearestNeighbors(
                n_neighbors=n_neighbors + 1, algorithm="ball_tree"
            )
            nbrs.fit(locations_for_this_slice)
            _, kneighbors = nbrs.kneighbors(locations_for_this_slice)
            edges = np.concatenate(
                [np.c_[kneighbors[:, 0], kneighbors[:, i + 1]] for i in range(n_neighbors)],
                axis=0,
            )
        
        edges = torch.tensor(edges, dtype=torch.long).T
############################################################################################### 
        # remove gene 144.  which is bad.  for some reason.
        subexpression = data.expression[good]
        subexpression = subexpression[:, ~self.bad_genes]

        # get behavior ids
        behavior_ids = np.array([self.behavior_lookup[x] for x in data.behavior[good]])
        celltype_ids = np.array([self.celltype_lookup[x] for x in data.celltypes[good]])
        labelinfo = np.c_[behavior_ids, celltype_ids]

        # make it into a torch geometric data object, add it to the list!

        # if we want to first log transform the data, we do it here
        # make this one return statement only changing x
        predictors_x = torch.tensor(subexpression.astype(np.float32))
        if log_transform:
            predictors_x = torch.log1p(predictors_x)

        return torch_geometric.data.Data(
            x=predictors_x,
            edge_index=edges,
            pos=torch.tensor(locations_for_this_slice.astype(np.float32)),
            y=torch.tensor(labelinfo),
            bregma=breg,
        )

    def construct_graphs(self, n_neighbors, train, log_transform=True):
        # load hdf5
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            # pylint: disable=no-member
            data = types.SimpleNamespace(
                anids=h5f["Animal_ID"][:],
                bregs=h5f["Bregma"][:],
                expression=h5f["expression"][:],
                locations=np.c_[h5f["Centroid_X"][:], h5f["Centroid_Y"][:]],
                behavior=h5f["Behavior"][:].astype("U"),
                celltypes=h5f["Cell_class"][:].astype("U"),
            )

        # get the (animal_id,bregma) pairs that define a unique slice
        unique_slices = np.unique(np.c_[data.anids, data.bregs], axis=0)

        # are we looking at train or test sets?
        unique_slices = unique_slices[:150] if train else unique_slices[150:]

        # store all the slices in this list...
        data_list = []
        for anid, breg in unique_slices:
            data_list.append(
                self.construct_graph(data, anid, breg, n_neighbors, log_transform)
            )

        return data_list


class FilteredMerfishDataset(MerfishDataset):
    def __init__(
        self,
        root,
        n_neighbors=3,
        train=True,
        log_transform=True,
        non_response_genes_file="/home/roko/spatial/spatial/"
        "non_response_blank_removed.txt",
        sexes=None,
        behaviors=None,
        test_animal=None,
    ):
        self.root = root
        self.sexes = sexes
        self.behaviors = behaviors
        self.test_animal = test_animal
        original_csv_file = super().merfish_csv
        new_df = pd.read_csv(original_csv_file)
        print(f"Original Data {new_df.shape}")
        if self.sexes is not None:
            new_df = new_df[new_df["Animal_sex"].isin(self.sexes)]
        if self.behaviors is not None:
            new_df = new_df[new_df["Behavior"].isin(self.behaviors)]
        if new_df.shape[0] == 0:
            raise ValueError("Dataframe has no rows. Cannot build graph.")
        new_df.to_csv(self.root + "/raw/merfish_messi.csv", index=False)
        print(f"Filtered Data {new_df.shape}")
        # print("Filtered csv file created!")
        MerfishDataset.download(self)
        super().__init__(
            root,
            n_neighbors=n_neighbors,
            train=train,
            log_transform=log_transform,
            non_response_genes_file=non_response_genes_file,
        )
        # print("Filtered hdf5 file created!")

    #     @property
    #     def raw_file_names(self):
    #         return ["merfish_messi.csv", "merfish_messi.hdf5"]

    # THIS LINE WAS EDITED TO SHOW NEW FILE
    @property
    def merfish_csv(self):
        return os.path.join(self.raw_dir, "merfish_messi.csv")

    # THIS LINE WAS EDITED TO SHOW NEW FILE
    @property
    def merfish_hdf5(self):
        return os.path.join(self.raw_dir, "merfish_messi.hdf5")

    def construct_graphs(self, n_neighbors, train, log_transform=True):
        print(self.merfish_hdf5)
        # load hdf5
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            # pylint: disable=no-member
            data = types.SimpleNamespace(
                anids=h5f["Animal_ID"][:],
                bregs=h5f["Bregma"][:],
                expression=h5f["expression"][:],
                locations=np.c_[h5f["Centroid_X"][:], h5f["Centroid_Y"][:]],
                behavior=h5f["Behavior"][:].astype("U"),
                celltypes=h5f["Cell_class"][:].astype("U"),
            )

        anid_to_bregma_count = {
            1: 12,
            2: 12,
            3: 6,
            4: 5,
            5: 6,
            6: 6,
            7: 12,
            8: 6,
            9: 6,
            10: 6,
            11: 6,
            12: 4,
            13: 4,
            14: 4,
            15: 4,
            16: 4,
            17: 4,
            18: 4,
            19: 4,
            20: 4,
            21: 4,
            22: 4,
            23: 4,
            24: 4,
            25: 4,
            26: 4,
            27: 2,
            28: 4,
            29: 4,
            30: 4,
        }

        # get the (animal_id,bregma) pairs that define a unique slice
        unique_slices = np.unique(np.c_[data.anids, data.bregs], axis=0)

        # are we looking at train or test sets?

        # if we want a specific animals
        if self.test_animal is not None:
            # we need to find which of the slices
            sorted_anids = np.sort(np.unique(data.anids))
            slices_before_test_anid = 0
            for anid in sorted_anids:
                if anid != self.test_animal:
                    slices_before_test_anid += anid_to_bregma_count[anid]
                else:
                    break

            mask_train = np.ones(unique_slices.shape[0], dtype=bool)
            mask_train[
                slices_before_test_anid : (
                    slices_before_test_anid + anid_to_bregma_count[self.test_animal]
                )
            ] = 0
            unique_slices = (
                unique_slices[(1 - mask_train).astype("bool")]
                if not train
                else unique_slices[mask_train]
            )
        else:
            min_animal = anid_to_bregma_count[np.min(data.anids)]
            unique_slices = (
                unique_slices[min_animal:] if train else unique_slices[:min_animal]
            )

        # store all the slices in this list...
        data_list = []
        for anid, breg in unique_slices:
            data_list.append(
                self.construct_graph(data, anid, breg, n_neighbors, log_transform)
            )

        return data_list

In [27]:
test_trial = FilteredMerfishDataset("../data", n_neighbors=0, sexes=["Female"], behaviors=["Naive"])

Original Data (1027848, 170)
Filtered Data (205348, 170)
../data/raw/merfish_messi.hdf5
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6088 6088]
 [6089 6089]
 [6090 6090]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6260 6260]
 [6261 6261]
 [6262 6262]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6325 6325]
 [6326 6326]
 [6327 6327]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6132 6132]
 [6133 6133]
 [6134 6134]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5816 5816]
 [5817 5817]
 [5818 5818]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5690 5690]
 [5691 5691]
 [5692 5692]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5674 5674]
 [5675 5675]
 [5676 5676]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5378 5378]
 [5379 5379]
 [5380 5380]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5724 5724]
 [5725 5725]
 [5726 5726]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5281 5281]
 [5282 5282]
 [5283 5283]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5358 5358]
 [5359 5359]
 [

In [28]:
test_trial[0].edge_index

tensor([[   0,    1,    2,  ..., 6088, 6089, 6090],
        [   0,    1,    2,  ..., 6088, 6089, 6090]])

In [1]:
import pandas as pd
import json
import time

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

# from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder
from spatial.train import train
from spatial.predict import test

import torch

import hydra
from hydra.experimental import compose, initialize

behaviors = ["Naive"]
sexes = ["Female"]

with open('animal_id.json') as json_file:
    animals = json.load(json_file)

loss_dict = {}
time_dict = {}
loss_excitatory_dict = {}
loss_inhibitory_dict = {}

for behavior in behaviors:
    for sex in sexes:
        try:
            animal_list = animals[behavior][sex]
        except KeyError:
            continue
        behavior = [behavior]
        sex = [sex]
        # print(behavior, sex, animal_list)
        start = time.time()
        with initialize(config_path="../config"):
            cfg_from_terminal = compose(config_name="config")
            # update the behavior to get the model of interest
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.behaviors", behavior)
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.sexes", sex)
            OmegaConf.update(cfg_from_terminal, "datasets.dataset.test_animal", 1)
            OmegaConf.update(cfg_from_terminal, "n_neighbors", 0)
            model = train(cfg_from_terminal)
            output = test(cfg_from_terminal)
            trainer, l1_losses, inputs, gene_expressions, celltypes, test_results = output
            MAE = test_results[0]['test_loss: mae_response']
            excitatory_cells = (celltypes == 6).nonzero(as_tuple=True)[0]
            MAE_excitatory = torch.abs(torch.index_select((gene_expressions-inputs)[excitatory_cells], 1, torch.tensor(model.responses))).mean().item()
            inhibitory_cells = (celltypes == 7).nonzero(as_tuple=True)[0]
            MAE_inhibitory = torch.abs(torch.index_select((gene_expressions-inputs)[inhibitory_cells], 1, torch.tensor(model.responses))).mean().item()
        end = time.time()
#             time_dict[f"{sex}_{behavior}_{animal}"] = end-start
#             loss_dict[f"{sex}_{behavior}_{animal}"] = MAE
#             loss_excitatory_dict[f"{sex}_{behavior}_{animal}"] = MAE_excitatory
#             loss_inhibitory_dict[f"{sex}_{behavior}_{animal}"] = MAE_inhibitory

#             with open("deepST_MAE.json", "w") as outfile:
#                 json.dump(loss_dict, outfile, indent=4)

#             with open("deepST_time.json", "w") as outfile:
#                 json.dump(time_dict, outfile, indent=4)

#             with open("deepST_MAE_excitatory.json", "w") as outfile:
#                 json.dump(loss_excitatory_dict, outfile, indent=4)

#             with open("deepST_MAE_inhibitory.json", "w") as outfile:
#                 json.dump(loss_inhibitory_dict, outfile, indent=4)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information


Original Data (1027848, 170)
Filtered Data (205348, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6088 6088]
 [6089 6089]
 [6090 6090]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6260 6260]
 [6261 6261]
 [6262 6262]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6325 6325]
 [6326 6326]
 [6327 6327]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6132 6132]
 [6133 6133]
 [6134 6134]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5816 5816]
 [5817 5817]
 [5818 5818]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5690 5690]
 [5691 5691]
 [5692 5692]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5674 5674]
 [5675 5675]
 [5676 5676]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5378 5378]
 [5379 5379]
 [5380 5380]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5724 5724]
 [5725 5725]
 [5726 5726]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5281 5281]
 [5282 5282]
 [5283 5283]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5358 5358]

  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name            | Type                    | Params
------------------------------------------------------------
0 | encoder_network | DenseReluGMMConvNetwork | 2.8 M 
1 | decoder_network | DenseReluGMMConvNetwork | 2.8 M 
------------------------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.048    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 9, global step 209: val_loss reached 0.43063 (best 0.43063), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 419: val_loss reached 0.35653 (best 0.35653), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 29, global step 629: val_loss reached 0.32034 (best 0.32034), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 39, global step 839: val_loss reached 0.29829 (best 0.29829), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 49, global step 1049: val_loss reached 0.27206 (best 0.27206), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 59, global step 1259: val_loss reached 0.25861 (best 0.25861), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 69, global step 1469: val_loss reached 0.24626 (best 0.24626), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 79, global step 1679: val_loss reached 0.23887 (best 0.23887), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 89, global step 1889: val_loss reached 0.23275 (best 0.23275), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 99, global step 2099: val_loss reached 0.22632 (best 0.22632), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 109, global step 2309: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 119, global step 2519: val_loss reached 0.22269 (best 0.22269), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 129, global step 2729: val_loss reached 0.22115 (best 0.22115), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 139, global step 2939: val_loss reached 0.22018 (best 0.22018), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 149, global step 3149: val_loss reached 0.21664 (best 0.21664), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 159, global step 3359: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 169, global step 3569: val_loss reached 0.21327 (best 0.21327), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 179, global step 3779: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 189, global step 3989: val_loss reached 0.21189 (best 0.21189), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 199, global step 4199: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 209, global step 4409: val_loss reached 0.21189 (best 0.21189), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 219, global step 4619: val_loss reached 0.21156 (best 0.21156), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 229, global step 4829: val_loss reached 0.21082 (best 0.21082), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 239, global step 5039: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 249, global step 5249: val_loss reached 0.21003 (best 0.21003), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 259, global step 5459: val_loss reached 0.20874 (best 0.20874), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 269, global step 5669: val_loss reached 0.20815 (best 0.20815), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 279, global step 5879: val_loss reached 0.20744 (best 0.20744), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 289, global step 6089: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 299, global step 6299: val_loss reached 0.20720 (best 0.20720), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 309, global step 6509: val_loss reached 0.20583 (best 0.20583), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 319, global step 6719: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 329, global step 6929: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 339, global step 7139: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 349, global step 7349: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 359, global step 7559: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 369, global step 7769: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 379, global step 7979: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 389, global step 8189: val_loss reached 0.20569 (best 0.20569), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 399, global step 8399: val_loss reached 0.20537 (best 0.20537), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 409, global step 8609: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 419, global step 8819: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 429, global step 9029: val_loss reached 0.20501 (best 0.20501), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 439, global step 9239: val_loss reached 0.20474 (best 0.20474), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 449, global step 9449: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 459, global step 9659: val_loss reached 0.20474 (best 0.20474), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 469, global step 9869: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 479, global step 10079: val_loss reached 0.20408 (best 0.20408), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 489, global step 10289: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 499, global step 10499: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 509, global step 10709: val_loss reached 0.20328 (best 0.20328), saving model to "/home/roko/spatial/output/lightning_logs/checkpoints/MonetAutoencoder2D/MonetAutoencoder2D__155__[200, 200]__155__0__['Female']__['Naive']__0.001__25__grid_search_filtered-v1.ckpt" as top True


Validating: 0it [00:00, ?it/s]

Epoch 519, global step 10919: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 529, global step 11129: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 539, global step 11339: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 549, global step 11549: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 559, global step 11759: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 569, global step 11969: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 579, global step 12179: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 589, global step 12389: val_loss was not in top True


Validating: 0it [00:00, ?it/s]

Epoch 599, global step 12599: val_loss was not in top True
FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  1045.2         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  1.7339         	|600            	|  1040.3         	|  99.53          	|
run_training_batch                 	|  0.052638       	|12600          	|  663.24         	|  63.454         	|
optimizer_step_with_closure_0      	|  0.034276       	|12600          	|  431.88         	|  41.319         	|
training_step_and_backward         	|  0.029738       	|12600          	|  374.69       

Original Data (1027848, 170)
Filtered Data (205348, 170)
/home/roko/spatial/data/raw/merfish_messi.hdf5
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6506 6506]
 [6507 6507]
 [6508 6508]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6409 6409]
 [6410 6410]
 [6411 6411]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6504 6504]
 [6505 6505]
 [6506 6506]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6602 6602]
 [6603 6603]
 [6604 6604]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6182 6182]
 [6183 6183]
 [6184 6184]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6151 6151]
 [6152 6152]
 [6153 6153]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6108 6108]
 [6109 6109]
 [6110 6110]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6141 6141]
 [6142 6142]
 [6143 6143]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5796 5796]
 [5797 5797]
 [5798 5798]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [6064 6064]
 [6065 6065]
 [6066 6066]]
[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5575 5575]

  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


[[   0    0]
 [   1    1]
 [   2    2]
 ...
 [5581 5581]
 [5582 5582]
 [5583 5583]]


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  rank_zero_warn(
TEST Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  10.487         	|  100 %          	|
---------------------------------------------------------------------------------------------------------------------------------------
run_test_evaluation                	|  10.441         	|1              	|  10.441         	|  99.56          	|
evaluation_step_and_end            	|  0.70293        	|12             	|  8.4352         	|  80.435         	|
test_step                          	|  0.68798        	|12             	|  8.2557         	|  78.724         	|
get_test_batch                     	|  0.10407        	|13             	|  1.3529         	|  12.901         	|
fetch_next_tes

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.20122438669204712,
 'test_loss: mae_response': 0.32917433977127075,
 'test_loss: mse': 0.20576705038547516}
--------------------------------------------------------------------------------
