# Transformer on CITE-seq

In [1]:
import os
import gc

os.environ["NUMBA_CACHE_DIR"] = "/scratch/st-jiaruid-1/yinian/tmp/"  # https://github.com/scverse/scanpy/issues/2113
from os.path import basename, join
from os import makedirs
from pathlib import Path
import yaml

import logging
import anndata as ad
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

import h5py
import hdf5plugin
import tables

from sklearn.preprocessing import binarize
from sklearn.decomposition import TruncatedSVD

Matplotlib created a temporary config/cache directory at /tmp/pbs.4263223.pbsha.ib.sockeye/matplotlib-246i220g because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [121]:
cuda = torch.cuda.is_available()

## Load the data

In [2]:
def load_data_as_anndata(filepaths, metadata_path):
    """
    Loads the files in <filepaths> as AnnData objects

    Source: https://github.com/openproblems-bio/neurips_2022_saturn_notebooks/blob/main/notebooks/loading_and_visualizing_all_data.ipynb
    """
    metadata_df = pd.read_csv(metadata_path)
    metadata_df = metadata_df.set_index("cell_id")

    adatas = {}
    chunk_size = 10000
    for name, filepath in filepaths.items():
        filename = basename(filepath)[:-3]
        logging.info(f"Loading {filename}")

        h5_file = h5py.File(filepath)
        h5_data = h5_file[filename]

        features = h5_data["axis0"][:]
        cell_ids = h5_data["axis1"][:]

        features = features.astype(str)
        cell_ids = cell_ids.astype(str)

        technology = metadata_df.loc[cell_ids, "technology"].unique().item()

        sparse_chunks = []
        n_cells = h5_data["block0_values"].shape[0]

        for chunk_indices in np.array_split(np.arange(n_cells), 100):
            chunk = h5_data["block0_values"][chunk_indices]
            sparse_chunk = scipy.sparse.csr_matrix(chunk)
            sparse_chunks.append(sparse_chunk)

        X = scipy.sparse.vstack(sparse_chunks)

        adata = ad.AnnData(
            X=X,
            obs=metadata_df.loc[cell_ids],
            var=pd.DataFrame(index=features),
        )

        adatas[name] = adata

    return adatas

In [3]:
config = yaml.safe_load(Path('/scratch/st-jiaruid-1/yinian/my_jupyter/scRNA-competition/experiments/basic-nn-cite.yaml').read_text())
adatas = load_data_as_anndata(config['paths'], config['metadata'])

In [4]:
x_train = adatas['x']
x_test = adatas['x_test']
y_train = adatas['y']
combined_data = ad.concat([x_train, x_test])

## Generate PCA embeddings of dimension 140

In [7]:
def pca_data(data, dimension):
    pca = TruncatedSVD(n_components=dimension, random_state=42)
    transformed = pca.fit_transform(data.X)
    new_data = ad.AnnData(transformed, data.obs, data.uns)
    return new_data

In [8]:
pca_combined_data = pca_data(combined_data, 140)

In [23]:
cell_type_proportions = {}
for cell_type in set(pca_combined_data.obs['cell_type']):
    cell_type_proportions[cell_type] = sum(pca_combined_data.obs['cell_type'] == cell_type) / pca_combined_data.shape[0]

## Generate input data

In [41]:
def separate_data(data):
    cell_day_dic = {}
    for cell_type in set(data.obs['cell_type']):
        for day in set(data.obs['day']):
            cell_day_data = data[np.logical_and(data.obs['day'] == day, data.obs['cell_type'] == cell_type)]
            if cell_day_data.shape[0] == 0:
                continue
            cell_day_dic[(cell_type, day)] = cell_day_data.obs_names
    return cell_day_dic

In [107]:
def generate_sequence(pca_combined_data, y_train, cell_type, indices):
    seq = []
    for day in (2, 3, 4):
        day_indices = indices[(cell_type, day)]
        seq.append(np.random.choice(day_indices))
    return pca_combined_data[seq, :].X.toarray(), y_train[seq, :].X.toarray()

In [110]:
def generate_train_data(pca_combined_data, y_train, cell_type_proportions, num_samples=100_000):
    cell_types = list(cell_type_proportions.keys())
    cell_type_probs = list(cell_type_proportions.values())
    indices = separate_data(y_train)
    x_data = []
    y_data = []
    for i in range(num_samples):
        cell_type = np.random.choice(cell_types, p=cell_type_probs)
        x, y = generate_sequence(pca_combined_data, y_train, cell_type, indices)
        x_data.append(x)
        y_data.append(y)
    return np.stack(x_data, axis=0), np.stack(y_data, axis=0)

In [111]:
generated_x_train, generated_y_train = generate_train_data(pca_combined_data, y_train, cell_type_proportions, 25000)

## Transformer model

In [92]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import Transformer

  from .autonotebook import tqdm as notebook_tqdm


In [122]:
model = Transformer(d_model=140, nhead=2, batch_first=True)
if cuda:
    model.to('cuda')
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

### Make TensorDataset

In [116]:
dataset = TensorDataset(torch.Tensor(generated_x_train), torch.tensor(generated_y_train))
train_num = int(len(dataset) * 9/10)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_num, len(dataset) - train_num]
)

In [119]:
training_loader = DataLoader(train_dataset, batch_size=1000)
validation_loader = DataLoader(val_dataset, batch_size=1000)

### Train the Transformer

In [130]:
def train_one_epoch(model, training_loader, epoch_index, loss_fn, optimizer):
    """
    Literally the most basic training epoch
    """
    running_loss = 0.0
    last_loss = 0.0

    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        if cuda:
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")

        optimizer.zero_grad()

        outputs = model(inputs, labels)

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        last_loss = running_loss
        print("  batch {} loss: {}".format(i + 1, running_loss))
        running_loss = 0.0
        del inputs, labels
        gc.collect()

    return last_loss

In [131]:
for epoch in range(10):
    avg_loss = train_one_epoch(model, training_loader, epoch, loss_fn, optimizer)
    print(avg_loss)

  batch 1 loss: 16.50493812561035
  batch 2 loss: 13.168618202209473
  batch 3 loss: 12.00985050201416
  batch 4 loss: 11.608171463012695
  batch 5 loss: 11.750839233398438
  batch 6 loss: 11.373592376708984
  batch 7 loss: 11.413345336914062
  batch 8 loss: 11.69194221496582
  batch 9 loss: 11.454431533813477
  batch 10 loss: 11.55285930633545
  batch 11 loss: 11.570755958557129
  batch 12 loss: 11.609855651855469
  batch 13 loss: 11.196542739868164
  batch 14 loss: 11.43812370300293
  batch 15 loss: 11.33919620513916
  batch 16 loss: 11.715832710266113
  batch 17 loss: 11.545110702514648
  batch 18 loss: 11.280266761779785
  batch 19 loss: 11.530559539794922
  batch 20 loss: 11.57193374633789
  batch 21 loss: 11.424890518188477
  batch 22 loss: 11.449894905090332
  batch 23 loss: 11.020182609558105
11.020182609558105
  batch 1 loss: 11.327436447143555
  batch 2 loss: 11.384275436401367
  batch 3 loss: 11.347406387329102
  batch 4 loss: 11.229406356811523
  batch 5 loss: 11.4549140930

8.644827842712402


In [132]:
def corrcoeff(y_pred, y_true):
    '''Pearson Correlation Coefficient
    Implementation in Torch, without shifting to cpu, detach, numpy (consumes time)
    '''
    y_true_ = y_true - torch.mean(y_true, 1, keepdim=True)
    y_pred_ = y_pred - torch.mean(y_pred, 1, keepdim=True)

    num = (y_true_ * y_pred_).sum(1, keepdim=True)
    den = torch.sqrt(((y_pred_ ** 2).sum(1, keepdim=True)) * ((y_true_ ** 2).sum(1, keepdim=True)))

    return (num/den).mean()

In [133]:
running_vcorr = 0.0
for i, vdata in enumerate(validation_loader):
    vinputs, vlabels = vdata
    if cuda:
        vinputs = vinputs.to("cuda")
        vlabels = vlabels.to("cuda")
    voutputs = model(vinputs, vlabels)
    vcorr = corrcoeff(voutputs, vlabels)
    running_vcorr += vcorr
    del vinputs, vlabels
    gc.collect()

running_vcorr / (i+1)

tensor(0.1059, device='cuda:0', grad_fn=<DivBackward0>)

## Generate test data

For each test data point, generate a bunch sequences of that cell type and use the average prediction.