In [161]:
import os
import types
import json
import itertools as it
import pathlib
import copy

import h5py
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
import torch_geometric
from sklearn import neighbors
from scipy.spatial import cKDTree


class MerfishDataset(torch_geometric.data.InMemoryDataset):
    def __init__(
        self,
        root,
        n_neighbors=3,
        train=True,
        log_transform=True,
        neighbor_celltypes=False,
        radius=None,
        non_response_genes_file="/home/roko/spatial/spatial/"
        "non_response_blank_removed.txt",
    ):
        super().__init__(root)

        # non-response genes (columns) in MERFISH
        with open(non_response_genes_file, "r", encoding="utf8") as genes_file:
            self.features = [int(x) for x in genes_file.read().split(",")]
            genes_file.close()

        # response genes (columns in MERFISH)
        self.responses = list(set(range(155)) - set(self.features))

        data_list = self.construct_graphs(
            n_neighbors, train, log_transform, neighbor_celltypes, radius
        )

        with h5py.File(self.merfish_hdf5, "r") as h5f:
            self.gene_names = h5f["gene_names"][:][~self.bad_genes].astype("U")

        self.data, self.slices = self.collate(data_list)

    # from https://datadryad.org/stash/dataset/doi:10.5061/dryad.8t8s248
    url = "https://datadryad.org/stash/downloads/file_stream/67671"

    behavior_types = [
        "Naive",
        "Parenting",
        "Virgin Parenting",
        "Aggression to pup",
        "Aggression to adult",
        "Mating",
    ]
    behavior_lookup = {x: i for (i, x) in enumerate(behavior_types)}
    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    celltype_lookup = {x: i for (i, x) in enumerate(cell_types)}

    bad_genes = np.zeros(161, dtype=bool)
    bad_genes[[12, 13, 14, 15, 16, 144]] = True

    @property
    def raw_file_names(self):
        return ["merfish.csv", "merfish.hdf5"]

    @property
    def merfish_csv(self):
        return os.path.join(self.raw_dir, "merfish.csv")

    @property
    def merfish_hdf5(self):
        return os.path.join(self.raw_dir, "merfish.hdf5")

    def download(self):
        # download csv if necessary
        if not os.path.exists(self.merfish_csv):
            with open(self.merfish_csv, "wb") as csvf:
                csvf.write(requests.get(self.url).content)

        # process csv if necessary
        dataframe = pd.read_csv(self.merfish_csv)

        with h5py.File(self.merfish_hdf5, "w") as h5f:
            # pylint: disable=no-member
            for colnm, dtype in zip(dataframe.keys()[:9], dataframe.dtypes[:9]):
                if dtype.kind == "O":
                    data = np.require(dataframe[colnm], dtype="S36")
                    h5f.create_dataset(colnm, data=data)
                else:
                    h5f.create_dataset(colnm, data=np.require(dataframe[colnm]))

            expression = np.array(dataframe[dataframe.keys()[9:]]).astype(np.float16)
            h5f.create_dataset("expression", data=expression)

            gene_names = np.array(dataframe.keys()[9:], dtype="S80")
            h5f.create_dataset("gene_names", data=gene_names)

    def construct_graph(
        self, data, anid, breg, n_neighbors, log_transform, neighbor_celltypes, radius
    ):
        def get_neighbors(edges, x_shape):
            return [edges[:, edges[0] == i][1] for i in range(x_shape)]

        def get_celltype_simplex(cell_behavior_tensor, neighbors_tensor):
            num_classes = cell_behavior_tensor.max() + 1
            return torch.cat(
                [
                    (
                        torch.mean(
                            1.0
                            * F.one_hot(
                                cell_behavior_tensor.index_select(0, neighbors),
                                num_classes=num_classes,
                            ),
                            dim=0,
                        )
                    ).unsqueeze(0)
                    for neighbors in neighbors_tensor
                ],
                dim=0,
            )

        # get subset of cells in this slice
        good = (data.anids == anid) & (data.bregs == breg)

        # figure out neighborhood structure
        locations_for_this_slice = data.locations[good]

        # only include self edges if n_neighbors is 0
        if n_neighbors == 0:
            edges = np.concatenate(
                [
                    np.c_[np.array([i]), np.array([i])]
                    for i in range(locations_for_this_slice.shape[0])
                ],
                axis=0,
            )
            print(edges)

        else:

            if radius is None:
                nbrs = neighbors.NearestNeighbors(
                    n_neighbors=n_neighbors + 1, algorithm="ball_tree"
                )
                nbrs.fit(locations_for_this_slice)
                _, kneighbors = nbrs.kneighbors(locations_for_this_slice)
                edges = np.concatenate(
                    [
                        np.c_[kneighbors[:, 0], kneighbors[:, i + 1]]
                        for i in range(n_neighbors)
                    ],
                    axis=0,
                )

            else:

                tree = cKDTree(locations_for_this_slice)
                kneighbors = tree.query_ball_point(
                    locations_for_this_slice, r=32, return_sorted=False
                )
                edges = np.concatenate(
                    [
                        np.c_[
                            np.repeat(i, len(kneighbors[i]) - 1),
                            [x for x in kneighbors[i] if x != i],
                        ]
                        for i in range(len(kneighbors))
                    ],
                    axis=0,
                )

        edges = torch.tensor(edges, dtype=torch.long).T

        # remove gene 144.  which is bad.  for some reason.
        subexpression = data.expression[good]
        subexpression = subexpression[:, ~self.bad_genes]

        # get behavior ids
        behavior_ids = np.array([self.behavior_lookup[x] for x in data.behavior[good]])
        celltype_ids = np.array([self.celltype_lookup[x] for x in data.celltypes[good]])
        labelinfo = np.c_[behavior_ids, celltype_ids]

        # make it into a torch geometric data object, add it to the list!

        # if we want to first log transform the data, we do it here
        # make this one return statement only changing x
        predictors_x = torch.tensor(subexpression.astype(np.float32))
        if neighbor_celltypes:
            test_simplex = get_celltype_simplex(
                torch.tensor(labelinfo[:, 1]),
                get_neighbors(edges, predictors_x.shape[0]),
            )
            predictors_x = torch.cat((predictors_x, test_simplex), dim=1)
        if log_transform:
            predictors_x = torch.log1p(predictors_x)

        return torch_geometric.data.Data(
            x=predictors_x,
            edge_index=edges,
            pos=torch.tensor(locations_for_this_slice.astype(np.float32)),
            y=torch.tensor(labelinfo),
            bregma=breg,
        )

    def construct_graphs(
        self,
        n_neighbors,
        train,
        log_transform=True,
        neighbor_celltypes=False,
        radius=None,
    ):
        # load hdf5
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            # pylint: disable=no-member
            
            data = types.SimpleNamespace(
                anids=h5f["Animal_ID"][:],
                bregs=h5f["Bregma"][:],
                expression=h5f["expression"][:],
                locations=np.c_[h5f["Centroid_X"][:], h5f["Centroid_Y"][:]],
                behavior=h5f["Behavior"][:].astype("U"),
                celltypes=h5f["Cell_class"][:].astype("U"),
            )
            
            num_graphs = int(np.ceil(n_neighbors/3))
            
            # print(num_graphs)
            
            data_graphs = []
            
            # for each new graph
            for i in range(1, num_graphs+1):
                
                # subset 1/num_graphs of the data based on quantiles
                graph_filter = np.where(
                    (data.locations[:,0]<=np.quantile(data.locations[:,0], i/num_graphs)) & 
                    (data.locations[:,0]>=np.quantile(data.locations[:,0], (i-1)/num_graphs))
                )[0]

                data = types.SimpleNamespace(
                    anids=h5f["Animal_ID"][graph_filter],
                    bregs=h5f["Bregma"][graph_filter],
                    expression=h5f["expression"][graph_filter, :],
                    locations=np.c_[h5f["Centroid_X"][graph_filter], h5f["Centroid_Y"][graph_filter]],
                    behavior=h5f["Behavior"][graph_filter].astype("U"),
                    celltypes=h5f["Cell_class"][graph_filter].astype("U"),
                )
                
                data_graphs.append(data)
        
        # see if you can update data locations AFTER data was created
        # create a deepcopy and then split the locations

        # store all the slices in this list...
        data_list = []
        # print(len(data_graphs))
        for data in data_graphs:
            
            # get the (animal_id,bregma) pairs that define a unique slice
            unique_slices = np.unique(np.c_[data.anids, data.bregs], axis=0)

            # are we looking at train or test sets?
            unique_slices = unique_slices[:150] if train else unique_slices[150:]
        
            # print(len(unique_slices))
            
            for anid, breg in unique_slices:
                data_list.append(
                    self.construct_graph(
                        data,
                        anid,
                        breg,
                        n_neighbors,
                        log_transform,
                        neighbor_celltypes,
                        radius,
                    )
                )

        return data_list

In [162]:
test = MerfishDataset("../data", n_neighbors=13)

In [163]:
len(test)

128

In [164]:
test[12]

Data(bregma=[1], edge_index=[2, 17615], pos=[1355, 2], x=[1355, 155], y=[1355, 2])

In [165]:
import sys

for n_neighbors in [0,1,2,3,5,8,13]:
    test = MerfishDataset("../data", n_neighbors=13)
    print(n_neighbors, sys.getsizeof(test))

0 48
1 48


KeyboardInterrupt: 

In [151]:
from spatial.merfish_dataset import MerfishDataset

In [152]:
test = MerfishDataset("../data", n_neighbors=3)

In [153]:
len(test)

150

In [154]:
test[12]

Data(bregma=[1], edge_index=[2, 18273], pos=[6091, 2], x=[6091, 155], y=[6091, 2])