# Transformer on CITE-seq

In [1]:
import os
import gc

os.environ["NUMBA_CACHE_DIR"] = "/scratch/st-jiaruid-1/yinian/tmp/"  # https://github.com/scverse/scanpy/issues/2113
from os.path import basename, join
from os import makedirs
from pathlib import Path
import yaml

import logging
import anndata as ad
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

import h5py
import hdf5plugin
import tables

from sklearn.preprocessing import binarize
from sklearn.decomposition import TruncatedSVD

Matplotlib created a temporary config/cache directory at /tmp/pbs.4263223.pbsha.ib.sockeye/matplotlib-246i220g because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [121]:
cuda = torch.cuda.is_available()

## Load the data

In [2]:
def load_data_as_anndata(filepaths, metadata_path):
    """
    Loads the files in <filepaths> as AnnData objects

    Source: https://github.com/openproblems-bio/neurips_2022_saturn_notebooks/blob/main/notebooks/loading_and_visualizing_all_data.ipynb
    """
    metadata_df = pd.read_csv(metadata_path)
    metadata_df = metadata_df.set_index("cell_id")

    adatas = {}
    chunk_size = 10000
    for name, filepath in filepaths.items():
        filename = basename(filepath)[:-3]
        logging.info(f"Loading {filename}")

        h5_file = h5py.File(filepath)
        h5_data = h5_file[filename]

        features = h5_data["axis0"][:]
        cell_ids = h5_data["axis1"][:]

        features = features.astype(str)
        cell_ids = cell_ids.astype(str)

        technology = metadata_df.loc[cell_ids, "technology"].unique().item()

        sparse_chunks = []
        n_cells = h5_data["block0_values"].shape[0]

        for chunk_indices in np.array_split(np.arange(n_cells), 100):
            chunk = h5_data["block0_values"][chunk_indices]
            sparse_chunk = scipy.sparse.csr_matrix(chunk)
            sparse_chunks.append(sparse_chunk)

        X = scipy.sparse.vstack(sparse_chunks)

        adata = ad.AnnData(
            X=X,
            obs=metadata_df.loc[cell_ids],
            var=pd.DataFrame(index=features),
        )

        adatas[name] = adata

    return adatas

In [3]:
config = yaml.safe_load(Path('/scratch/st-jiaruid-1/yinian/my_jupyter/scRNA-competition/experiments/basic-nn-cite.yaml').read_text())
adatas = load_data_as_anndata(config['paths'], config['metadata'])

In [4]:
x_train = adatas['x']
x_test = adatas['x_test']
y_train = adatas['y']
combined_data = ad.concat([x_train, x_test])

## Generate PCA embeddings of dimension 140

In [7]:
def pca_data(data, dimension):
    pca = TruncatedSVD(n_components=dimension, random_state=42)
    transformed = pca.fit_transform(data.X)
    new_data = ad.AnnData(transformed, data.obs, data.uns)
    return new_data

In [8]:
pca_combined_data = pca_data(combined_data, 140)

In [23]:
cell_type_proportions = {}
for cell_type in set(pca_combined_data.obs['cell_type']):
    cell_type_proportions[cell_type] = sum(pca_combined_data.obs['cell_type'] == cell_type) / pca_combined_data.shape[0]

## Generate input data

In [41]:
def separate_data(data):
    cell_day_dic = {}
    for cell_type in set(data.obs['cell_type']):
        for day in set(data.obs['day']):
            cell_day_data = data[np.logical_and(data.obs['day'] == day, data.obs['cell_type'] == cell_type)]
            if cell_day_data.shape[0] == 0:
                continue
            cell_day_dic[(cell_type, day)] = cell_day_data.obs_names
    return cell_day_dic

In [107]:
def generate_sequence(pca_combined_data, y_train, cell_type, indices):
    seq = []
    for day in (2, 3, 4):
        day_indices = indices[(cell_type, day)]
        seq.append(np.random.choice(day_indices))
    return pca_combined_data[seq, :].X.toarray(), y_train[seq, :].X.toarray()

In [110]:
def generate_train_data(pca_combined_data, y_train, cell_type_proportions, num_samples=100_000):
    cell_types = list(cell_type_proportions.keys())
    cell_type_probs = list(cell_type_proportions.values())
    indices = separate_data(y_train)
    x_data = []
    y_data = []
    for i in range(num_samples):
        cell_type = np.random.choice(cell_types, p=cell_type_probs)
        x, y = generate_sequence(pca_combined_data, y_train, cell_type, indices)
        x_data.append(x)
        y_data.append(y)
    return np.stack(x_data, axis=0), np.stack(y_data, axis=0)

In [135]:
generated_x_train, generated_y_train = generate_train_data(pca_combined_data, y_train, cell_type_proportions, 200000)

## Transformer model

In [92]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import Transformer

  from .autonotebook import tqdm as notebook_tqdm


In [136]:
model = Transformer(d_model=140, nhead=2, batch_first=True)
if cuda:
    model.to('cuda')
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

### Make TensorDataset

In [137]:
dataset = TensorDataset(torch.Tensor(generated_x_train), torch.tensor(generated_y_train))
train_num = int(len(dataset) * 9/10)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_num, len(dataset) - train_num]
)

In [138]:
training_loader = DataLoader(train_dataset, batch_size=1000)
validation_loader = DataLoader(val_dataset, batch_size=1000)

### Train the Transformer

In [139]:
def train_one_epoch(model, training_loader, epoch_index, loss_fn, optimizer):
    """
    Literally the most basic training epoch
    """
    running_loss = 0.0
    last_loss = 0.0

    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        if cuda:
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")

        optimizer.zero_grad()

        outputs = model(inputs, labels)

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        last_loss = running_loss
        print("  batch {} loss: {}".format(i + 1, running_loss))
        running_loss = 0.0
        del inputs, labels
        gc.collect()

    return last_loss

In [None]:
for epoch in range(10):
    avg_loss = train_one_epoch(model, training_loader, epoch, loss_fn, optimizer)
    
    running_vcorr = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs, vlabels = vdata
        if cuda:
            vinputs = vinputs.to("cuda")
            vlabels = vlabels.to("cuda")
        voutputs = model(vinputs, vlabels)
        vcorr = corrcoeff(voutputs, vlabels)
        running_vcorr += vcorr
        del vinputs, vlabels
        gc.collect()

    print(f"EPOCH: {epoch} MSE loss: {avg_loss}\nCORR: {running_vcorr / (i+1)}")

  batch 1 loss: 2.1027958393096924
  batch 2 loss: 2.2796595096588135
  batch 3 loss: 2.2264785766601562
  batch 4 loss: 2.4719793796539307
  batch 5 loss: 2.183279514312744
  batch 6 loss: 2.0771780014038086
  batch 7 loss: 2.203709125518799
  batch 8 loss: 2.217686176300049
  batch 9 loss: 2.1932010650634766
  batch 10 loss: 2.1574747562408447
  batch 11 loss: 2.1221885681152344
  batch 12 loss: 2.1732869148254395
  batch 13 loss: 2.0254836082458496
  batch 14 loss: 2.0416934490203857
  batch 15 loss: 2.1138007640838623
  batch 16 loss: 2.154921293258667
  batch 17 loss: 2.109858512878418
  batch 18 loss: 2.2187130451202393
  batch 19 loss: 2.159362316131592
  batch 20 loss: 2.0656349658966064
  batch 21 loss: 2.217397928237915
  batch 22 loss: 2.076099157333374
  batch 23 loss: 2.157496929168701
  batch 24 loss: 2.1282711029052734
  batch 25 loss: 2.110861301422119
  batch 26 loss: 2.0100178718566895
  batch 27 loss: 2.150425910949707
  batch 28 loss: 2.1155900955200195
  batch 29 l

  batch 49 loss: 1.812227487564087
  batch 50 loss: 1.7467308044433594
  batch 51 loss: 1.824634075164795
  batch 52 loss: 1.7866535186767578
  batch 53 loss: 1.7878735065460205
  batch 54 loss: 1.74176824092865
  batch 55 loss: 1.8085062503814697
  batch 56 loss: 1.8415099382400513
  batch 57 loss: 1.9297717809677124
  batch 58 loss: 1.8163669109344482
  batch 59 loss: 1.8373198509216309
  batch 60 loss: 1.8594261407852173
  batch 61 loss: 1.6902827024459839
  batch 62 loss: 1.822744607925415
  batch 63 loss: 1.722785234451294
  batch 64 loss: 1.8817871809005737
  batch 65 loss: 1.8456673622131348
  batch 66 loss: 1.7840392589569092
  batch 67 loss: 1.8215981721878052
  batch 68 loss: 1.915765404701233
  batch 69 loss: 1.7714972496032715
  batch 70 loss: 1.804119348526001
  batch 71 loss: 1.8668862581253052
  batch 72 loss: 1.9884092807769775
  batch 73 loss: 1.7129008769989014
  batch 74 loss: 1.70928955078125
  batch 75 loss: 1.7831708192825317
  batch 76 loss: 1.6826738119125366
  

  batch 96 loss: 1.5150212049484253
  batch 97 loss: 1.53345787525177
  batch 98 loss: 1.5371865034103394
  batch 99 loss: 1.572500228881836
  batch 100 loss: 1.401063084602356
  batch 101 loss: 1.615871548652649
  batch 102 loss: 1.5094974040985107
  batch 103 loss: 1.5203776359558105
  batch 104 loss: 1.5648781061172485
  batch 105 loss: 1.604698657989502
  batch 106 loss: 1.483328104019165
  batch 107 loss: 1.5891718864440918
  batch 108 loss: 1.5937179327011108
  batch 109 loss: 1.4367541074752808
  batch 110 loss: 1.4615397453308105
  batch 111 loss: 1.466422438621521
  batch 112 loss: 1.4908201694488525
  batch 113 loss: 1.5050722360610962
  batch 114 loss: 1.4490188360214233
  batch 115 loss: 1.511762022972107
  batch 116 loss: 1.5243432521820068
  batch 117 loss: 1.5118495225906372
  batch 118 loss: 1.57705819606781
  batch 119 loss: 1.4532119035720825
  batch 120 loss: 1.6339677572250366
  batch 121 loss: 1.5807651281356812
  batch 122 loss: 1.4190714359283447
  batch 123 loss

  batch 142 loss: 1.359300971031189
  batch 143 loss: 1.3387407064437866
  batch 144 loss: 1.2568440437316895
  batch 145 loss: 1.2828633785247803
  batch 146 loss: 1.2296931743621826
  batch 147 loss: 1.2596427202224731
  batch 148 loss: 1.3531230688095093
  batch 149 loss: 1.3998326063156128
  batch 150 loss: 1.2585285902023315
  batch 151 loss: 1.2995892763137817
  batch 152 loss: 1.283079743385315
  batch 153 loss: 1.3505940437316895
  batch 154 loss: 1.3007214069366455
  batch 155 loss: 1.3402231931686401
  batch 156 loss: 1.330495834350586
  batch 157 loss: 1.2161544561386108
  batch 158 loss: 1.240110158920288
  batch 159 loss: 1.3589699268341064
  batch 160 loss: 1.2212783098220825
  batch 161 loss: 1.1599435806274414
  batch 162 loss: 1.2534923553466797
  batch 163 loss: 1.2627662420272827
  batch 164 loss: 1.2725586891174316
  batch 165 loss: 1.3267103433609009
  batch 166 loss: 1.303421139717102
  batch 167 loss: 1.232697606086731
  batch 168 loss: 1.3297253847122192
  batch

  batch 6 loss: 1.0426054000854492
  batch 7 loss: 1.1239105463027954
  batch 8 loss: 1.1419004201889038
  batch 9 loss: 1.1464427709579468
  batch 10 loss: 1.0989429950714111
  batch 11 loss: 1.0872212648391724
  batch 12 loss: 1.11024808883667
  batch 13 loss: 1.0074546337127686
  batch 14 loss: 1.0167511701583862
  batch 15 loss: 1.0721708536148071
  batch 16 loss: 1.1420748233795166
  batch 17 loss: 1.0749483108520508
  batch 18 loss: 1.1571592092514038
  batch 19 loss: 1.120277762413025
  batch 20 loss: 1.048801064491272
  batch 21 loss: 1.1778664588928223
  batch 22 loss: 1.043519139289856
  batch 23 loss: 1.128670334815979
  batch 24 loss: 1.1027950048446655
  batch 25 loss: 1.0957064628601074
  batch 26 loss: 1.0061191320419312
  batch 27 loss: 1.127484917640686
  batch 28 loss: 1.0753583908081055
  batch 29 loss: 1.156449317932129
  batch 30 loss: 1.0942732095718384
  batch 31 loss: 1.135563611984253
  batch 32 loss: 1.155699372291565
  batch 33 loss: 1.003970980644226
  batch

  batch 53 loss: 0.9312506914138794
  batch 54 loss: 0.8866769075393677
  batch 55 loss: 0.9404717683792114
  batch 56 loss: 0.9804242849349976
  batch 57 loss: 1.0609875917434692
  batch 58 loss: 0.9347256422042847
  batch 59 loss: 0.9745929837226868
  batch 60 loss: 0.9793203473091125
  batch 61 loss: 0.8616926670074463
  batch 62 loss: 0.9661965370178223
  batch 63 loss: 0.8889434933662415
  batch 64 loss: 0.9987026453018188
  batch 65 loss: 0.9876698851585388
  batch 66 loss: 0.9392924308776855
  batch 67 loss: 0.9643647074699402
  batch 68 loss: 1.0314946174621582
  batch 69 loss: 0.9211907386779785
  batch 70 loss: 0.9505137205123901
  batch 71 loss: 0.9964882731437683
  batch 72 loss: 1.0865650177001953
  batch 73 loss: 0.8821011781692505
  batch 74 loss: 0.8807632923126221
  batch 75 loss: 0.9258272051811218
  batch 76 loss: 0.8544911742210388
  batch 77 loss: 1.0471782684326172
  batch 78 loss: 0.9611110091209412
  batch 79 loss: 0.9271880388259888
  batch 80 loss: 0.935573756

  batch 98 loss: 0.8201615810394287
  batch 99 loss: 0.8587504029273987
  batch 100 loss: 0.7351649403572083
  batch 101 loss: 0.8922380805015564
  batch 102 loss: 0.8174141645431519
  batch 103 loss: 0.8055301308631897
  batch 104 loss: 0.8602185845375061
  batch 105 loss: 0.8720000982284546
  batch 106 loss: 0.7822149991989136
  batch 107 loss: 0.8687703609466553
  batch 108 loss: 0.8799967169761658
  batch 109 loss: 0.7624281644821167
  batch 110 loss: 0.7893801331520081
  batch 111 loss: 0.7793403267860413
  batch 112 loss: 0.8024055361747742
  batch 113 loss: 0.8172909021377563
  batch 114 loss: 0.7746617197990417
  batch 115 loss: 0.8186706304550171
  batch 116 loss: 0.8267825841903687
  batch 117 loss: 0.8218148350715637
  batch 118 loss: 0.860115647315979
  batch 119 loss: 0.7786560654640198
  batch 120 loss: 0.9150099158287048
  batch 121 loss: 0.8725005984306335
  batch 122 loss: 0.7433277368545532
  batch 123 loss: 0.7369624972343445
  batch 124 loss: 0.823130190372467
  bat

  batch 142 loss: 0.7882201671600342
  batch 143 loss: 0.7692995071411133
  batch 144 loss: 0.6925491690635681
  batch 145 loss: 0.7071466445922852
  batch 146 loss: 0.687815248966217
  batch 147 loss: 0.7036543488502502
  batch 148 loss: 0.7840592861175537
  batch 149 loss: 0.8502041697502136
  batch 150 loss: 0.7086080312728882
  batch 151 loss: 0.7465640902519226
  batch 152 loss: 0.7152255773544312
  batch 153 loss: 0.7800861597061157
  batch 154 loss: 0.7292953729629517
  batch 155 loss: 0.7806651592254639
  batch 156 loss: 0.7676160335540771
  batch 157 loss: 0.6758802533149719
  batch 158 loss: 0.6868462562561035
  batch 159 loss: 0.7668842077255249
  batch 160 loss: 0.6778416633605957
  batch 161 loss: 0.6171512603759766
  batch 162 loss: 0.6865448951721191
  batch 163 loss: 0.7078889012336731
  batch 164 loss: 0.6905918121337891
  batch 165 loss: 0.7299857139587402
  batch 166 loss: 0.696053683757782
  batch 167 loss: 0.6590033173561096
  batch 168 loss: 0.7315802574157715
  b

In [132]:
def corrcoeff(y_pred, y_true):
    '''Pearson Correlation Coefficient
    Implementation in Torch, without shifting to cpu, detach, numpy (consumes time)
    '''
    y_true_ = y_true - torch.mean(y_true, 1, keepdim=True)
    y_pred_ = y_pred - torch.mean(y_pred, 1, keepdim=True)

    num = (y_true_ * y_pred_).sum(1, keepdim=True)
    den = torch.sqrt(((y_pred_ ** 2).sum(1, keepdim=True)) * ((y_true_ ** 2).sum(1, keepdim=True)))

    return (num/den).mean()

In [133]:
running_vcorr = 0.0
for i, vdata in enumerate(validation_loader):
    vinputs, vlabels = vdata
    if cuda:
        vinputs = vinputs.to("cuda")
        vlabels = vlabels.to("cuda")
    voutputs = model(vinputs, vlabels)
    vcorr = corrcoeff(voutputs, vlabels)
    running_vcorr += vcorr
    del vinputs, vlabels
    gc.collect()

running_vcorr / (i+1)

tensor(0.1059, device='cuda:0', grad_fn=<DivBackward0>)

## Generate test data

For each test data point, generate a bunch sequences of that cell type and use the average prediction.