# Step 1: Recreating the data example in MESSI's home tutorial notebook. (Get Animals 16-19)

In [1]:
import pandas as pd

In [2]:
# Load the full MERFISH table; the path is relative to the notebook's working directory.
merfish_df = pd.read_csv('../data/raw/merfish.csv')

In [3]:
# Keep only the cells recorded from female animals in the "Parenting" condition.
# NOTE(review): per the section title this should correspond to animals 16-19;
# that is not asserted here.
MESSI_jupyter_example = merfish_df.loc[
    (merfish_df['Animal_sex'] == "Female") & (merfish_df['Behavior'] == "Parenting")
]

In [4]:
# Write the subset next to the raw file. index=False keeps the pandas row index
# out of the CSV, so the column layout stays the same as in the frame; the
# dataset class below treats the first nine columns as metadata and the rest
# as gene expression.
MESSI_jupyter_example.to_csv('../data/raw/merfish_messi.csv', index=False)

# Step 2: Create the torch Geometric object that represents this subset


In [5]:
import os
import types

import h5py
import numpy as np
import pandas as pd
import requests
import torch
import torch_geometric
from sklearn import neighbors


class MerfishDataset(torch_geometric.data.InMemoryDataset):
    """Graph dataset built from the filtered MERFISH table.

    One graph is produced per unique (Animal_ID, Bregma) pair. Nodes are
    cells, node features are gene-expression values, and every cell is linked
    to its ``n_neighbors`` nearest cells by centroid position.
    """

    def __init__(
        self,
        root,
        n_neighbors=3,
        train=True,
        log_transform=True,
        non_response_genes_file="/home/roko/spatial/spatial/non_response.txt",
    ):
        """Read the gene lists, build every graph and collate them in memory.

        Args:
            root: Dataset root directory handed to the base class. The CSV and
                HDF5 files are looked up under ``self.raw_dir``.
            n_neighbors: Number of nearest neighbours each cell is linked to.
            train: If True keep the first 150 unique slices, otherwise keep
                the slices after the first 150.
            log_transform: If True, ``log1p`` is applied to the expression
                values used as node features.
            non_response_genes_file: Path to a text file of comma-separated
                integer column indices of the non-response genes.
                NOTE(review): the default is an absolute, machine-specific
                path; callers on another machine have to pass their own.
        """
        # NOTE(review): presumably the base-class constructor is what calls
        # download() when the files in raw_file_names are missing -- confirm
        # against the torch_geometric Dataset documentation.
        super().__init__(root)

        # non-response genes (columns) in MERFISH
        # (the `with` block closes the file; no explicit close() is needed)
        with open(non_response_genes_file, "r") as genes_file:
            self.features = [int(x) for x in genes_file.read().split(",")]

        # response genes (columns in MERFISH): every one of the 160 columns
        # that is not listed as a non-response gene. The list order is the
        # set's iteration order, which is not guaranteed to be sorted.
        self.responses = list(set(range(160)) - set(self.features))

        data_list = self.construct_graphs(n_neighbors, train, log_transform)

        # gene names with the flagged (bad) gene removed, decoded to str
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            self.gene_names = h5f["gene_names"][:][~self.bad_genes].astype("U")

        self.data, self.slices = self.collate(data_list)

    # from https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248
    url = "https://datadryad.org/stash/downloads/file_stream/67671"

    # behavior label -> integer id (position in behavior_types)
    behavior_types = [
        "Naive",
        "Parenting",
        "Virgin Parenting",
        "Aggression to pup",
        "Aggression to adult",
        "Mating",
    ]
    behavior_lookup = {x: i for (i, x) in enumerate(behavior_types)}

    # cell class label -> integer id (position in cell_types)
    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    celltype_lookup = {x: i for (i, x) in enumerate(cell_types)}

    # boolean mask over the 161 expression columns; True marks a column that
    # is dropped before the graphs are built (only column 144 is flagged)
    bad_genes = np.zeros(161, dtype=bool)
    bad_genes[144] = True

    @property
    def raw_file_names(self):
        """File names (inside ``raw_dir``) this dataset needs."""
        return ["merfish_messi.csv", "merfish_messi.hdf5"]

    # Points at the filtered CSV rather than the full MERFISH file.
    @property
    def merfish_csv(self):
        """Full path of the filtered MERFISH CSV."""
        return os.path.join(self.raw_dir, "merfish_messi.csv")

    # Points at the HDF5 file derived from the filtered CSV.
    @property
    def merfish_hdf5(self):
        """Full path of the HDF5 file derived from the filtered CSV."""
        return os.path.join(self.raw_dir, "merfish_messi.hdf5")

    def download(self):
        """Fetch the CSV if it is absent, then (re)write the HDF5 file from it.

        Raises:
            requests.HTTPError: if the download answers with an error status.
        """
        # download csv if necessary
        # NOTE(review): self.url appears to be the full Dryad MERFISH file, so
        # a fresh download would not be the filtered subset -- create
        # merfish_messi.csv beforehand if the subset is what is wanted.
        if not os.path.exists(self.merfish_csv):
            response = requests.get(self.url)
            # Fail loudly instead of saving an HTTP error page as the CSV.
            response.raise_for_status()
            with open(self.merfish_csv, "wb") as csvf:
                csvf.write(response.content)

        # convert the csv to hdf5 (the hdf5 file is always rewritten)
        dataframe = pd.read_csv(self.merfish_csv)

        with h5py.File(self.merfish_hdf5, "w") as h5f:
            # the first nine columns are stored one dataset per column
            for colnm, dtype in zip(dataframe.keys()[:9], dataframe.dtypes[:9]):
                if dtype.kind == "O":
                    # object (string) columns become fixed-width byte strings
                    data = np.require(dataframe[colnm], dtype="S36")
                    h5f.create_dataset(colnm, data=data)
                else:
                    h5f.create_dataset(colnm, data=np.require(dataframe[colnm]))

            # every remaining column is gene expression, stored as float16
            expression = np.array(dataframe[dataframe.keys()[9:]]).astype(np.float16)
            h5f.create_dataset("expression", data=expression)

            gene_names = np.array(dataframe.keys()[9:], dtype="S80")
            h5f.create_dataset("gene_names", data=gene_names)

    def construct_graph(self, data, anid, breg, n_neighbors, log_transform):
        """Build the graph for one (animal, bregma) slice.

        Args:
            data: Namespace with ``anids``, ``bregs``, ``expression``,
                ``locations``, ``behavior`` and ``celltypes`` arrays, as
                assembled in ``construct_graphs``.
            anid: Animal id of the slice.
            breg: Bregma value of the slice.
            n_neighbors: Number of nearest neighbours to link each cell to.
            log_transform: If True, ``log1p`` is applied to the features.

        Returns:
            A ``torch_geometric.data.Data`` with ``x`` (expression),
            ``edge_index``, ``pos`` (centroids) and ``y`` (one row per cell:
            behavior id, cell-type id).

        Raises:
            KeyError: if a behavior or cell-class label is missing from the
                lookup tables.
        """
        # get subset of cells in this slice
        good = (data.anids == anid) & (data.bregs == breg)

        # figure out neighborhood structure
        locations_for_this_slice = data.locations[good]
        # one extra neighbour is requested because column 0 of the result is
        # expected to be the query cell itself
        nbrs = neighbors.NearestNeighbors(
            n_neighbors=n_neighbors + 1, algorithm="ball_tree"
        )
        nbrs.fit(locations_for_this_slice)
        _, kneighbors = nbrs.kneighbors(locations_for_this_slice)
        # one (column-0 index, i-th neighbour index) pair per cell and neighbour
        edges = np.concatenate(
            [np.c_[kneighbors[:, 0], kneighbors[:, i + 1]] for i in range(n_neighbors)],
            axis=0,
        )
        # transpose to the 2 x num_edges layout
        edges = torch.tensor(edges, dtype=torch.long).T

        # drop the expression column flagged in bad_genes (gene 144)
        subexpression = data.expression[good]
        subexpression = subexpression[:, ~self.bad_genes]

        # get behavior ids and cell-type ids
        behavior_ids = np.array([self.behavior_lookup[x] for x in data.behavior[good]])
        celltype_ids = np.array([self.celltype_lookup[x] for x in data.celltypes[good]])
        labelinfo = np.c_[behavior_ids, celltype_ids]

        # node features, optionally log-transformed
        predictors_x = torch.tensor(subexpression.astype(np.float32))
        if log_transform:
            predictors_x = torch.log1p(predictors_x)

        # make it into a torch geometric data object
        return torch_geometric.data.Data(
            x=predictors_x,
            edge_index=edges,
            pos=torch.tensor(locations_for_this_slice.astype(np.float32)),
            y=torch.tensor(labelinfo),
        )

    def construct_graphs(self, n_neighbors, train, log_transform=True):
        """Build one graph per slice of the requested split.

        Args:
            n_neighbors: Number of nearest neighbours to link each cell to.
            train: If True use the first 150 unique slices, otherwise the
                slices after the first 150.
            log_transform: Forwarded to ``construct_graph``.

        Returns:
            List of ``torch_geometric.data.Data`` objects, one per slice.
        """
        # load hdf5
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            # pylint: disable=no-member
            data = types.SimpleNamespace(
                anids=h5f["Animal_ID"][:],
                bregs=h5f["Bregma"][:],
                expression=h5f["expression"][:],
                locations=np.c_[h5f["Centroid_X"][:], h5f["Centroid_Y"][:]],
                behavior=h5f["Behavior"][:].astype("U"),
                celltypes=h5f["Cell_class"][:].astype("U"),
            )

        # get the (animal_id,bregma) pairs that define a unique slice
        unique_slices = np.unique(np.c_[data.anids, data.bregs], axis=0)

        # are we looking at train or test sets?
        # NOTE: when the file holds 150 or fewer unique slices, train=False
        # selects nothing.
        unique_slices = unique_slices[:150] if train else unique_slices[150:]

        # store all the slices in this list...
        data_list = []
        for anid, breg in unique_slices:
            data_list.append(
                self.construct_graph(data, anid, breg, n_neighbors, log_transform)
            )

        return data_list


In [6]:
# Read the filtered CSV back in.
# NOTE(review): the file was written with index=False, so index_col=0 turns the
# first data column (presumably Cell_ID) into the index -- confirm that is
# intended. The variable is not used again in the visible cells.
new_messi_csv = pd.read_csv('../data/raw/merfish_messi.csv', index_col=0)

In [7]:
# Build the dataset from the filtered files with the default settings
# (n_neighbors=3, train=True, log_transform=True, default non-response gene file).
test = MerfishDataset('../data')

In [8]:
test

MerfishDataset(16)

#### The code below just shows that the filtered HDF5 file contains the same top-level datasets as the original one.

In [9]:
# Open read-only: this cell only inspects the file. Without an explicit mode,
# older h5py releases open in append mode and can create or modify the file.
f = h5py.File('../data/raw/merfish_messi.hdf5', 'r')

In [10]:
# Print the name of every top-level member stored in the HDF5 file.
for key in f:
    print(key)

Animal_ID
Animal_sex
Behavior
Bregma
Cell_ID
Cell_class
Centroid_X
Centroid_Y
Neuron_cluster_ID
expression
gene_names


In [11]:
# Open read-only: this cell only inspects the file. Without an explicit mode,
# older h5py releases open in append mode and can create or modify the file.
f = h5py.File('../data/raw/merfish.hdf5', 'r')

In [12]:
# Print the name of every top-level member stored in the HDF5 file.
for key in f:
    print(key)

Animal_ID
Animal_sex
Behavior
Bregma
Cell_ID
Cell_class
Centroid_X
Centroid_Y
Neuron_cluster_ID
expression
gene_names


# Step 3: Get a Trained and Tested Example Running!

Just use the filtered merfish dataset in the train.py and predict.py

This will be accomplished by creating a class that inherits from MerfishDataset, overrides the methods I changed, and then running a new example with it.

# Step 4: Testing the Trained Example

Make sure that the correct checkpoint file is selected for testing in predict.py.

In [1]:
from spatial import predict, train
from spatial.merfish_dataset import MerfishDataset, FilteredMerfishDataset

In [2]:
import os

# Limit MKL, NumExpr and OpenMP to one thread each.
# NOTE(review): presumably set here so it takes effect before torch is imported below.
os.environ.update(
    {
        "MKL_NUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
        "OMP_NUM_THREADS": "1",
    }
)

import sys
import torch

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split
from torch_geometric.data import DataLoader

from spatial.merfish_dataset import FilteredMerfishDataset, MerfishDataset
from spatial.models.monet_ae import MonetAutoencoder2D, TrivialAutoencoder
from spatial.train import train
from spatial.predict import test

In [3]:
# equivalent to spatial

import hydra
from hydra.experimental import compose, initialize

# Compose the Hydra config named "config" from ../config inside the notebook
# and hand it to spatial.predict.test.
with initialize(config_path="../config"):
    cfg_from_terminal = compose(config_name="config")
    # for now just to keep the code running
    # `output` is unpacked further down into
    # (trainer, l1_losses, inputs, gene_expressions, celltypes).
    output = test(cfg_from_terminal)

For more details, see https://github.com/omry/omegaconf/issues/367
GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.23230068385601044,
 'test_loss: mae_response': 0.36871105432510376,
 'test_loss: mse': 0.2573970556259155}
--------------------------------------------------------------------------------


# Choosing the Correct Filtered Cells

We can see in the source code that the cell types are listed as follows:

    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    
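So, we can keep only the cell types that MESSI uses to evaluate its performance. These would be indices [1, 2, 6, 7, 8, 9, 10, 11], i.e. Astrocyte, Endothelial 1, Excitatory, Inhibitory, Microglia, OD Immature 1, OD Immature 2 and OD Mature 1.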

In [4]:
import matplotlib.pyplot as plt

# Unpack the result of the test run.
# NOTE(review): judging by how they are indexed below, `inputs` and
# `gene_expressions` are (cells x genes) tensors and `celltypes` holds one
# cell-type id per cell -- confirm against spatial.predict.test.
trainer, l1_losses, inputs, gene_expressions, celltypes = output

In [5]:
# Column indices of the response genes, as a tensor usable with index_select.
# NOTE(review): this constructs a whole MerfishDataset just to read its
# `responses` attribute.
responses = torch.tensor(MerfishDataset('../data').responses)

In [6]:
# Row indices of the cells whose type id is 6 ("Excitatory" in the cell_types list above).
excitatory_cells = torch.nonzero(celltypes == 6, as_tuple=True)[0]

In [7]:
# Restrict the inputs to excitatory cells (rows), then to response genes (columns).
deepST_inputs = inputs.index_select(0, excitatory_cells).index_select(1, responses)

In [8]:
import pandas as pd

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing.
os.makedirs("results", exist_ok=True)

df = pd.DataFrame(deepST_inputs.numpy())
df.to_csv("results/inputs.csv", index=False)

In [9]:
# Same selection as for the inputs, applied to the values returned as `gene_expressions`.
deepST_outputs = gene_expressions.index_select(0, excitatory_cells).index_select(1, responses)

In [10]:
# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing.
os.makedirs("results", exist_ok=True)

df = pd.DataFrame(deepST_outputs.numpy())
df.to_csv("results/deepST_outputs.csv", index=False)

In [11]:
inputs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.4199e-02, 4.9734e-03,
         0.0000e+00],
        [0.0000e+00, 9.5009e-01, 9.5009e-01,  ..., 1.4219e-02, 0.0000e+00,
         0.0000e+00],
        [6.1154e-01, 1.2611e+00, 0.0000e+00,  ..., 3.8231e-02, 6.9036e-03,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.6553e-03, 5.8650e-03,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.4980e+00,  ..., 8.1744e-04, 6.8809e-03,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.6887e-03, 0.0000e+00,
         4.6684e-02]])