In [1]:
"""A graph convolutional autoencoder for MERFISH data."""

import pytorch_lightning as pl
import torch

from spatial.models import base_networks


def calc_pseudo(edge_index, pos):
    """
    Compute polar pseudo-coordinates for every edge of a graph.

    Input:
      - edge_index, a long tensor laid out as (2 x N_edges): row 0 holds the
        start vertex of each edge and row 1 holds the end vertex
      - pos, an (N_vertices x 2) float tensor indicating coordinates of nodes

    Output:
      - pseudo, an (N_edges x 2) float tensor indicating edge-values; column 0
        is the length of the edge and column 1 is its angle as returned by
        atan2 (to be used in graph-convnet)
    """
    start_idx, end_idx = edge_index[0], edge_index[1]

    # displacement from the start vertex to the end vertex of each edge
    offset = pos[end_idx] - pos[start_idx]
    dx = offset[:, 0]
    dy = offset[:, 1]

    length = torch.sqrt(dx ** 2 + dy ** 2)
    angle = torch.atan2(dy, dx)

    # stacking two length-N vectors along dim 1 yields the (N_edges x 2) result
    return torch.stack((length, angle), dim=1)


class BasicAEMixin(pl.LightningModule):
    """
    Mixin implementing

    - loss calculations
    - training_step, validation_step, test_step, configure_optimizers for
      pytorch lightning

    The methods below read ``self.loss_type`` and call ``self(batch)``,
    unpacking the result as a ``(latent, reconstruction)`` pair, so a subclass
    has to provide both.  The reconstruction is always scored against
    ``batch.x``.
    """

    def calc_loss(self, pred, val):
        """
        Summed squared error between a prediction and the target values.

        pred -- tensor of predictions
        val -- tensor of target values

        With ``loss_type == "mse_against_log1pdata"`` the prediction is
        compared against ``log(1 + val)``; with ``loss_type == "mse"`` it is
        compared against ``val`` directly.  Note that the result is a sum over
        all entries, not a mean.

        Raises NotImplementedError for any other ``loss_type``.
        """
        if self.loss_type == "mse_against_log1pdata":
            return torch.sum((pred - torch.log(1 + val)) ** 2)
        elif self.loss_type == "mse":
            return torch.sum((pred - val) ** 2)
        else:
            raise NotImplementedError(self.loss_type)

    def training_step(self, batch, batch_idx):
        """Reconstruct one batch, log the loss as "train_loss" and return it."""
        _, reconstruction = self(batch)
        # calc_loss is a method; calling it as a bare name raised NameError
        loss = self.calc_loss(reconstruction, batch.x)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Reconstruct one batch, log the loss as "val_loss" and return it."""
        _, reconstruction = self(batch)
        loss = self.calc_loss(reconstruction, batch.x)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """
        Score one test batch by delegating to validation_step.

        Because of the delegation the loss is logged under "val_loss".
        """
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, with torch defaults."""
        return torch.optim.Adam(self.parameters())


class TrivialAutoencoder(BasicAEMixin):
    """Autoencoder for graph data that only looks at node values (batch.x)."""

    def __init__(
        self, observables_dimension, hidden_dimensions, latent_dimension, loss_type
    ):
        """
        observables_dimension -- number of values associated with each graph node
        hidden_dimensions -- list of hidden values to associate with each graph node
        latent_dimension -- number of latent values to associate with each graph node
        loss_type -- stored on the instance as ``self.loss_type``
        """
        super().__init__()

        self.loss_type = loss_type

        # encoder goes observables -> hidden layers -> latent
        encoder_widths = [observables_dimension, *hidden_dimensions, latent_dimension]
        self.encoder_network = base_networks.construct_dense_relu_network(
            encoder_widths
        )

        # decoder mirrors the encoder: latent -> hidden layers reversed -> observables
        decoder_widths = [
            latent_dimension,
            *reversed(hidden_dimensions),
            observables_dimension,
        ]
        self.decoder_network = base_networks.construct_dense_relu_network(
            decoder_widths
        )

    def forward(self, batch):
        """Return ``(latent, reconstruction)`` computed from ``batch.x``."""
        latent = self.encoder_network(batch.x)
        reconstruction = self.decoder_network(latent)
        return latent, reconstruction


class MonetAutoencoder2D(BasicAEMixin):
    """Autoencoder for graph data whose nodes are embedded in 2d"""

    def __init__(
        self,
        observables_dimension,
        hidden_dimensions,
        latent_dimension,
        loss_type,
        dim,
        kernel_size,
    ):
        """
        observables_dimension -- number of values associated with each graph node
        hidden_dimensions -- list of hidden values to associate with each graph node
        latent_dimension -- number of latent values to associate with each graph node
        loss_type -- stored on the instance as ``self.loss_type``
        dim -- forwarded as ``dim`` to both DenseReluGMMConvNetwork instances.
            NOTE(review): forward() always supplies the two-column
            pseudo-coordinates from calc_pseudo, so this presumably has to be
            2 -- confirm against DenseReluGMMConvNetwork.
        kernel_size -- forwarded as ``kernel_size`` to both
            DenseReluGMMConvNetwork instances
        """
        super().__init__()

        self.loss_type = loss_type

        # dim and kernel_size used to be accepted but silently replaced by the
        # hard-coded values 2 and 25; they are now passed through as given.
        self.encoder_network = base_networks.DenseReluGMMConvNetwork(
            [observables_dimension] + list(hidden_dimensions) + [latent_dimension],
            dim=dim,
            kernel_size=kernel_size,
        )
        self.decoder_network = base_networks.DenseReluGMMConvNetwork(
            [latent_dimension]
            + list(reversed(hidden_dimensions))
            + [observables_dimension],
            dim=dim,
            kernel_size=kernel_size,
        )

    def forward(self, batch):
        """
        Return ``(latent, reconstruction)`` for one batch.

        Reads ``batch.x``, ``batch.edge_index`` and ``batch.pos``; the same
        edge pseudo-coordinates are fed to the encoder and the decoder.
        """
        pseudo = calc_pseudo(batch.edge_index, batch.pos)
        latent_loadings = self.encoder_network(batch.x, batch.edge_index, pseudo)
        expr_reconstruction = self.decoder_network(
            latent_loadings, batch.edge_index, pseudo
        )
        return latent_loadings, expr_reconstruction


In [1]:
import os
import types

import h5py
import numpy as np
import pandas as pd
import requests
import torch
import torch_geometric
from sklearn import neighbors


class MerfishDataset(torch_geometric.data.InMemoryDataset):
    """
    MERFISH cells arranged as one nearest-neighbor graph per slice.

    A slice is a unique (Animal_ID, Bregma) pair.  For each slice the nodes
    are its cells, ``x`` holds expression (with ``bad_genes`` removed),
    ``pos`` holds the (Centroid_X, Centroid_Y) location, ``y`` holds
    (behavior id, cell type id), and every cell has ``n_neighbors`` outgoing
    edges to its nearest neighbors.
    """

    def __init__(self, root, n_neighbors=3, train=True):
        """
        root -- dataset root directory, passed to InMemoryDataset
        n_neighbors -- number of nearest neighbors each cell is linked to
        train -- if True use the first 150 unique slices, otherwise the rest
        """
        super().__init__(root)

        data_list = self.construct_graphs(n_neighbors, train)

        # gene names for the columns of x: bad genes dropped, bytes -> str
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            self.gene_names = h5f["gene_names"][:][~self.bad_genes].astype("U")

        self.data, self.slices = self.collate(data_list)

    # location the csv is fetched from by download()
    url = "https://datadryad.org/stash/downloads/file_stream/67671"

    # the position of a name in this list is its integer id in y[:, 0]
    behavior_types = [
        "Naive",
        "Parenting",
        "Virgin Parenting",
        "Aggression to pup",
        "Aggression to adult",
        "Mating",
    ]
    behavior_lookup = {x: i for (i, x) in enumerate(behavior_types)}
    # the position of a name in this list is its integer id in y[:, 1]
    cell_types = [
        "Ambiguous",
        "Astrocyte",
        "Endothelial 1",
        "Endothelial 2",
        "Endothelial 3",
        "Ependymal",
        "Excitatory",
        "Inhibitory",
        "Microglia",
        "OD Immature 1",
        "OD Immature 2",
        "OD Mature 1",
        "OD Mature 2",
        "OD Mature 3",
        "OD Mature 4",
        "Pericytes",
    ]
    celltype_lookup = {x: i for (i, x) in enumerate(cell_types)}

    # boolean mask over the 161 expression columns; True marks a column that
    # is dropped from x and from gene_names
    bad_genes = np.zeros(161, dtype=bool)
    bad_genes[144] = True

    @property
    def raw_file_names(self):
        """Names of the raw files this dataset needs inside raw_dir."""
        return ["merfish.csv", "merfish.hdf5"]

    @property
    def merfish_csv(self):
        """Path of the downloaded csv inside raw_dir."""
        return os.path.join(self.raw_dir, "merfish.csv")

    @property
    def merfish_hdf5(self):
        """Path of the hdf5 conversion of the csv inside raw_dir."""
        return os.path.join(self.raw_dir, "merfish.hdf5")

    def download(self):
        """
        Fetch merfish.csv if it is missing, then (re)write merfish.hdf5 from it.

        Raises requests.HTTPError if the server answers with an error status;
        in that case no csv file is created.
        """
        # download csv if necessary
        if not os.path.exists(self.merfish_csv):
            response = requests.get(self.url)
            # Check the status before touching the disk: otherwise the body of
            # an error response would be saved as merfish.csv, and since the
            # file would then exist, later runs would never re-download it.
            response.raise_for_status()
            with open(self.merfish_csv, "wb") as csvf:
                csvf.write(response.content)

        # convert the csv to hdf5 (done every time this method runs)
        dataframe = pd.read_csv(self.merfish_csv)

        with h5py.File(self.merfish_hdf5, "w") as h5f:
            # the first nine columns each become their own dataset
            for colnm, dtype in zip(dataframe.keys()[:9], dataframe.dtypes[:9]):
                if dtype.kind == "O":
                    # object columns are stored as fixed-width byte strings
                    data = np.require(dataframe[colnm], dtype="S36")
                    h5f.create_dataset(colnm, data=data)
                else:
                    h5f.create_dataset(colnm, data=np.require(dataframe[colnm]))

            # every remaining column goes into one float16 matrix
            expression = np.array(dataframe[dataframe.keys()[9:]]).astype(np.float16)
            h5f.create_dataset("expression", data=expression)

            # names of those remaining columns, in matrix column order
            gene_names = np.array(dataframe.keys()[9:], dtype="S80")
            h5f.create_dataset("gene_names", data=gene_names)

    def construct_graph(self, data, anid, breg, n_neighbors):
        """
        Build the graph for the slice identified by (anid, breg).

        data -- namespace with anids, bregs, expression, locations, behavior
            and celltypes arrays, as assembled in construct_graphs
        anid -- animal id of the slice
        breg -- bregma value of the slice
        n_neighbors -- number of outgoing edges per cell

        Returns a torch_geometric.data.Data with x, edge_index, pos and y.
        """
        # get subset of cells in this slice
        good = (data.anids == anid) & (data.bregs == breg)

        # figure out neighborhood structure
        locations_for_this_slice = data.locations[good]
        # one extra neighbor is requested because the query points are the
        # fitted points themselves; column 0 of the result is used as the
        # edge source below (presumably the query cell itself -- this may not
        # hold when two cells share a location)
        nbrs = neighbors.NearestNeighbors(
            n_neighbors=n_neighbors + 1, algorithm="ball_tree"
        )
        nbrs.fit(locations_for_this_slice)
        _, kneighbors = nbrs.kneighbors(locations_for_this_slice)
        edges = np.concatenate(
            [np.c_[kneighbors[:, 0], kneighbors[:, i + 1]] for i in range(n_neighbors)],
            axis=0,
        )
        # transpose to (2 x N_edges): row 0 is the source, row 1 the target
        edges = torch.tensor(edges, dtype=torch.long).T

        # remove gene 144.  which is bad.  for some reason.
        subexpression = data.expression[good]
        subexpression = subexpression[:, ~self.bad_genes]

        # get behavior ids
        behavior_ids = np.array([self.behavior_lookup[x] for x in data.behavior[good]])
        celltype_ids = np.array([self.celltype_lookup[x] for x in data.celltypes[good]])
        labelinfo = np.c_[behavior_ids, celltype_ids]

        # make it into a torch geometric data object, add it to the list!
        return torch_geometric.data.Data(
            x=torch.tensor(subexpression.astype(np.float32)),
            edge_index=edges,
            pos=torch.tensor(locations_for_this_slice.astype(np.float32)),
            y=torch.tensor(labelinfo),
        )

    def construct_graphs(self, n_neighbors, train):
        """
        Load merfish.hdf5 and build one graph per slice of the chosen split.

        n_neighbors -- number of outgoing edges per cell
        train -- if True use the first 150 unique slices, otherwise the rest

        Returns a list of torch_geometric.data.Data objects.
        """
        # load hdf5
        with h5py.File(self.merfish_hdf5, "r") as h5f:
            # pylint: disable=no-member
            data = types.SimpleNamespace(
                anids=h5f["Animal_ID"][:],
                bregs=h5f["Bregma"][:],
                expression=h5f["expression"][:],
                locations=np.c_[h5f["Centroid_X"][:], h5f["Centroid_Y"][:]],
                behavior=h5f["Behavior"][:].astype("U"),
                celltypes=h5f["Cell_class"][:].astype("U"),
            )

        # get the (animal_id,bregma) pairs that define a unique slice
        unique_slices = np.unique(np.c_[data.anids, data.bregs], axis=0)

        # are we looking at train or test sets?
        unique_slices = unique_slices[:150] if train else unique_slices[150:]

        # store all the slices in this list...
        data_list = []
        for anid, breg in unique_slices:
            data_list.append(self.construct_graph(data, anid, breg, n_neighbors))

        return data_list

In [3]:
import pandas as pd

# Load the raw MERFISH csv into a DataFrame.  The path is relative, so it is
# resolved against the current working directory.
merfish = pd.read_csv("../data/raw/merfish.csv")