# Transformer on CITE-seq

In [1]:
import os
import gc

os.environ["NUMBA_CACHE_DIR"] = "/scratch/st-jiaruid-1/yinian/tmp/"  # https://github.com/scverse/scanpy/issues/2113
from os.path import basename, join
from os import makedirs
from pathlib import Path
import yaml

import logging
import anndata as ad
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

import h5py
import hdf5plugin
import tables

import math

from sklearn.preprocessing import binarize
from sklearn.decomposition import TruncatedSVD

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.nn import Transformer

Matplotlib created a temporary config/cache directory at /tmp/pbs.4275416.pbsha.ib.sockeye/matplotlib-c4rg7olh because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cuda = torch.cuda.is_available()

## Load the data

In [3]:
def load_data_as_anndata(filepaths, metadata_path):
    """
    Loads the files in <filepaths> as AnnData objects

    Source: https://github.com/openproblems-bio/neurips_2022_saturn_notebooks/blob/main/notebooks/loading_and_visualizing_all_data.ipynb
    """
    metadata_df = pd.read_csv(metadata_path)
    metadata_df = metadata_df.set_index("cell_id")

    adatas = {}
    chunk_size = 10000
    for name, filepath in filepaths.items():
        filename = basename(filepath)[:-3]
        logging.info(f"Loading {filename}")

        h5_file = h5py.File(filepath)
        h5_data = h5_file[filename]

        features = h5_data["axis0"][:]
        cell_ids = h5_data["axis1"][:]

        features = features.astype(str)
        cell_ids = cell_ids.astype(str)

        technology = metadata_df.loc[cell_ids, "technology"].unique().item()

        sparse_chunks = []
        n_cells = h5_data["block0_values"].shape[0]

        for chunk_indices in np.array_split(np.arange(n_cells), 100):
            chunk = h5_data["block0_values"][chunk_indices]
            sparse_chunk = scipy.sparse.csr_matrix(chunk)
            sparse_chunks.append(sparse_chunk)

        X = scipy.sparse.vstack(sparse_chunks)

        adata = ad.AnnData(
            X=X,
            obs=metadata_df.loc[cell_ids],
            var=pd.DataFrame(index=features),
        )

        adatas[name] = adata

    return adatas

In [30]:
config = yaml.safe_load(Path('/scratch/st-jiaruid-1/yinian/my_jupyter/scRNA-competition/experiments/basic-nn-cite.yaml').read_text())
adatas = load_data_as_anndata(config['paths'], config['metadata'])

In [5]:
x_train = adatas['x']
x_test = adatas['x_test']
y_train = adatas['y']
combined_data = ad.concat([x_train, x_test])

## Generate PCA embeddings of dimension 140

In [6]:
def pca_data(data, dimension):
    pca = TruncatedSVD(n_components=dimension, random_state=42)
    transformed = pca.fit_transform(data.X)
    new_data = ad.AnnData(transformed, data.obs, data.uns)
    return new_data, pca

In [7]:
# pca_combined_data, pca_combined_data_source = pca_data(combined_data, 140)

In [4]:
pca_combined_data = pickle.load(open('/scratch/st-jiaruid-1/yinian/pca_combined_cite.pkl', 'rb'))
y_train = pickle.load(open('/scratch/st-jiaruid-1/yinian/y_cite.pkl', 'rb'))
x_test = pickle.load(open('/scratch/st-jiaruid-1/yinian/x_test_cite.pkl', 'rb'))

In [5]:
# pca_y, pca_y_source = pca_data(y_train, 256)

In [6]:
cell_type_proportions = {}
hidden = sum(pca_combined_data.obs['cell_type'] == 'hidden')
for cell_type in set(pca_combined_data.obs['cell_type']):
    if cell_type != 'hidden':
        cell_type_proportions[cell_type] = sum(pca_combined_data.obs['cell_type'] == cell_type) / (pca_combined_data.shape[0] - hidden)

In [7]:
cell_type_proportions

{'BP': 0.0025323649614294908,
 'MasP': 0.15118971007346366,
 'HSC': 0.3583254632222046,
 'MkP': 0.09026251347669471,
 'NeuP': 0.1790039364485044,
 'MoP': 0.015227620329123868,
 'EryP': 0.20345839148857928}

## Generate input data

In [8]:
def separate_data(data):
    cell_day_dic = {}
    for cell_type in set(data.obs['cell_type']):
        for day in set(data.obs['day']):
            cell_day_data = data[np.logical_and(data.obs['day'] == day, data.obs['cell_type'] == cell_type)]
            if cell_day_data.shape[0] == 0:
                continue
            cell_day_dic[(cell_type, day)] = cell_day_data.obs_names
    return cell_day_dic

In [9]:
separate_data(y_train)

{('BP',
  2): Index(['003fe3679efa', '141cf9566045', '64f09dfa8830', 'ffeb4d94d62c',
        '0c7c3b21ee58', '589f9705d916', 'b1ab2ffa799c', 'dbc954a3de9f',
        '381722e45901', 'a4b1499fa792',
        ...
        '61fe6a252b63', '851226be85f4', '05f331cc2c73', '342f17241c43',
        '9782d8e3306d', '643e33a230ca', '1f00cb777ec4', '4782c74c1aa1',
        'd874d8f3436f', 'c9e91caf1b6e'],
       dtype='object', name='cell_id', length=117),
 ('BP',
  3): Index(['27b3c7ec6711', 'd3db9f743eda', 'e6a4759dfec2', 'f1a8f93a0000',
        'd9dd09f607ff', '352ad554b236', 'f203ef8e9e80', 'cd3ff3d83fcf',
        '5dd9d0cd81c4', '50dcba77ce10', '619e6861fda0', '1d872b054b55',
        '3a9fa86f9c6d', '7e5e2f4ff66c'],
       dtype='object', name='cell_id'),
 ('BP',
  4): Index(['524116849ba7', '125268612718', 'dc18d771eccd', '39a21ec5a70e',
        '9feb13b3d1d7', 'c8acfbe2df71', 'a3a68f9a5861', '0a48ba82bf55',
        'ed3807e6a550', 'e632e9ba8194', 'a1c9a36bed33', 'eca9a8e0d45e',
        '8334ec

In [10]:
def generate_sequence(pca_combined_data, y_train, cell_type, indices):
    seq = []
    for day in (2, 3, 4):
        day_indices = indices[(cell_type, day)]
        seq.append(np.random.choice(day_indices))
    return pca_combined_data[seq, :].X.toarray(), y_train[seq, :].X.toarray()

In [11]:
def generate_train_data(pca_combined_data, y_train, cell_type_proportions, num_samples=1_000_000):
    cell_types = list(cell_type_proportions.keys())
    cell_type_probs = list(cell_type_proportions.values())
    indices = separate_data(y_train)
    x_data = []
    y_data = []
    for i in range(num_samples):
        cell_type = np.random.choice(cell_types, p=cell_type_probs)
        x, y = generate_sequence(pca_combined_data, y_train, cell_type, indices)
        x_data.append(x)
        y_data.append(y)
    return np.stack(x_data, axis=0), np.stack(y_data, axis=0)

In [15]:
# generated_x_train, generated_y_train = generate_train_data(pca_combined_data, y_train, cell_type_proportions, 400_000)

### Make TensorDataset

In [12]:
class SCDataset(Dataset):
    def __init__(self, pca_combined_data, y_train, y_pca, cell_type_proportions, size, technology):
        self.pca_combined_data = pca_combined_data
        self.y_train = y_train
        self.y_pca = y_pca
        self.cell_type_proportions = cell_type_proportions
        if technology == 'multiome':
            self.size = size * 4
        else:
            self.size = size * 3
        
        cell_types = list(cell_type_proportions.keys())
        cell_type_probs = list(cell_type_proportions.values())
        indices = separate_data(y_train)
        self.technology = technology
        if technology == 'multiome':
            days = (2, 3, 4, 7)
        else:
            days = (2, 3, 4)
        
        self.sequences = []        
        for i in range(size):
            cell_type = np.random.choice(cell_types)
            seq = []
            for day in days:
                day_indices = indices[(cell_type, day)]
                seq.append(np.random.choice(day_indices))
            self.sequences.append(seq)
    
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        if self.technology == 'cite':
            y = self.y_train[self.sequences[idx // 3], :].X.toarray()
            x = self.pca_combined_data[self.sequences[idx // 3], :].X.toarray()
            y_0 = self.y_train[self.sequences[idx // 3], :].X.toarray()
            y_orig = self.y_train[self.sequences[idx // 3], :].X.toarray()
#             if idx % 3 == 0:
#                 y_0[0, :] = -2147483647
#             elif idx % 3 == 1:
#                 y_0[1, :] = -2147483647
#             else:
#                 y_0[2, :] = -2147483647
        else:
            y = self.y_pca[self.sequences[idx // 4], :].X.toarray()
            x = self.pca_combined_data[self.sequences[idx // 4], :].X.toarray()
            y_0 = self.y_pca[self.sequences[idx // 4], :].X.toarray()
            y_orig = self.y_train[self.sequences[idx // 4], :].X.toarray()
            if idx % 4 == 0:
                y_0[0, :] = 0
            elif idx % 4 == 1:
                y_0[1, :] = 0
            elif idx % 4 == 2:
                y_0[2, :] = 0
            else:
                y_0[3, :] = 0
        return x * 10, y, y_0, y_orig

In [13]:
dataset = SCDataset(pca_combined_data, y_train, y_train, cell_type_proportions, 5000, "cite")

# dataset = TensorDataset(torch.Tensor(generated_x_train), torch.tensor(generated_y_train))
train_num = int(len(dataset) * 4/5)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_num, len(dataset) - train_num]
)

In [14]:
training_loader = DataLoader(train_dataset, batch_size=1000)
validation_loader = DataLoader(val_dataset, batch_size=1000)

## Transformer model

In [15]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        # Info
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

class SCTransformer(nn.Module):
    def __init__(self, output_dim, input_dim, nhead):
        super().__init__()
#         self.positional_encoder = PositionalEncoding(dim_model=input_dim, dropout_p=0.1, max_len=4)
        self.transformer = Transformer(
            d_model=input_dim,
            nhead=nhead,
            batch_first=True,
            num_encoder_layers=6,
            num_decoder_layers=6
        )
        self.out = nn.Linear(input_dim, output_dim)
    
    def forward(self, src, tgt):
#         src = self.positional_encoder(src)
#         tgt = self.positional_encoder(tgt)
        
        transformer_out = self.transformer(src, tgt)
        out = self.out(transformer_out)
        return out

In [16]:
model = SCTransformer(output_dim=140, input_dim=140, nhead=2)
if cuda:
    model.to('cuda')
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

### Train the Transformer

In [17]:
def correlation_score(y_true, y_pred):
    """Scores the predictions according to the competition rules.

    It is assumed that the predictions are not constant.

    Returns the average of each sample's Pearson correlation coefficient

    Source: https://www.kaggle.com/code/xiafire/lb-t15-msci-multiome-catboostregressor#Predicting
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes are different.")
    corrsum = 0
    for i in range(len(y_true)):
        corrsum += np.corrcoef(y_true[i], y_pred[i])[1, 0]
    return corrsum / len(y_true)

In [18]:
def corrcoeff(y_pred, y_true):
    '''Pearson Correlation Coefficient
    Implementation in Torch, without shifting to cpu, detach, numpy (consumes time)
    '''
    y_true_ = y_true - torch.mean(y_true, 1, keepdim=True)
    y_pred_ = y_pred - torch.mean(y_pred, 1, keepdim=True)

    num = (y_true_ * y_pred_).sum(1, keepdim=True)
    den = torch.sqrt(((y_pred_ ** 2).sum(1, keepdim=True)) * ((y_true_ ** 2).sum(1, keepdim=True)))

    return (num/den).mean()

In [19]:
def train_one_epoch(model, training_loader, epoch_index, loss_fn, optimizer):
    """
    Literally the most basic training epoch
    """
    running_loss = 0.0
    last_loss = 0.0
    model.train()

    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels, masked_labels, labels_orig = data

        if cuda:
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")
            masked_labels = masked_labels.to("cuda")

        optimizer.zero_grad()
#         print(inputs.shape, masked_labels.shape)
        outputs = model(inputs, masked_labels)

        loss = loss_fn(outputs, labels)
        loss.backward()
        
#         res = outputs.detach().cpu().numpy() @ pca_y_source.components_
#         labels_orig = labels_orig.detach().cpu().numpy()

        optimizer.step()

        curr_loss = loss.item()
        running_loss += curr_loss
        print("  batch {} loss: {}".format(i + 1, curr_loss))

        del inputs, labels
        gc.collect()

    return running_loss / (i + 1)

In [53]:
for epoch in range(10):
    avg_loss = train_one_epoch(model, training_loader, epoch, loss_fn, optimizer)
    
    model.eval()
    with torch.no_grad():
        running_vcorr = 0.0
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels, vmaskedlabels, labels_orig = vdata
            if cuda:
                vinputs = vinputs.to("cuda")
                vlabels = vlabels.to("cuda")
                vmaskedlabels = vmaskedlabels.to("cuda")
            voutputs = model(vinputs, vmaskedlabels)
            
#             res = voutputs.detach().cpu().numpy() @ pca_y_source.components_
            
            vcorr = 0
            for j in range(voutputs.shape[1]):
                vcorr += corrcoeff(voutputs[:, j, :], vlabels[:, j, :])
            vcorr /= j + 1

            running_vcorr += vcorr
#             del vinputs, vlabels
            gc.collect()


    print(f"EPOCH: {epoch} MSE loss: {avg_loss}\nCORR: {running_vcorr / (i+1)}")

  batch 1 loss: 7.154778480529785
  batch 2 loss: 7.510797500610352
  batch 3 loss: 6.983055591583252
  batch 4 loss: 7.362465858459473
  batch 5 loss: 6.965112686157227
  batch 6 loss: 7.2250261306762695
  batch 7 loss: 6.813897609710693
  batch 8 loss: 7.0283942222595215
  batch 9 loss: 6.7882609367370605
  batch 10 loss: 7.065382480621338
  batch 11 loss: 7.108630180358887
  batch 12 loss: 7.169112205505371
EPOCH: 0 MSE loss: 7.169112205505371
CORR: 0.7871603965759277
  batch 1 loss: 7.148752212524414
  batch 2 loss: 7.435673713684082
  batch 3 loss: 6.918435573577881
  batch 4 loss: 7.314789295196533
  batch 5 loss: 7.0017499923706055
  batch 6 loss: 7.176348686218262
  batch 7 loss: 6.807710647583008
  batch 8 loss: 6.973664283752441
  batch 9 loss: 6.707352638244629
  batch 10 loss: 6.988345623016357
  batch 11 loss: 7.047585487365723
  batch 12 loss: 7.286639213562012
EPOCH: 1 MSE loss: 7.286639213562012
CORR: 0.7923027873039246
  batch 1 loss: 7.034459114074707
  batch 2 loss: 

KeyboardInterrupt: 

In [54]:
voutputs[:, 0, :]

tensor([[2.7921, 0.4844, 1.1786,  ..., 5.0047, 2.2285, 3.7837],
        [2.7921, 0.4845, 1.1786,  ..., 5.0046, 2.2283, 3.7837],
        [0.3911, 0.4925, 1.0441,  ..., 0.3613, 4.0096, 3.4221],
        ...,
        [0.3911, 0.4925, 1.0441,  ..., 0.3613, 4.0096, 3.4221],
        [0.3911, 0.4925, 1.0441,  ..., 0.3613, 4.0096, 3.4221],
        [0.3911, 0.4925, 1.0441,  ..., 0.3613, 4.0096, 3.4221]],
       device='cuda:0')

In [51]:
vinputs[:,0,:].detach().cpu().numpy()

array([[ 2.5550774e+03, -1.1350066e+02, -7.2097031e+01, ...,
         3.9026688e+01,  9.2112980e+00,  1.6284214e+01],
       [ 2.8380879e+03,  7.7036423e+01,  4.5122791e+01, ...,
         7.0002890e+00, -1.7623823e+01, -2.1650360e+01],
       [ 2.4740940e+03, -7.8775620e+01, -4.1895789e+02, ...,
        -7.1026440e+00, -1.7761303e+01, -8.1611118e+00],
       ...,
       [ 2.6772498e+03, -4.1187103e+01, -1.6395892e+02, ...,
        -3.4901077e+01, -2.9137092e+00,  9.4901829e+00],
       [ 2.6216445e+03, -9.3096552e+00, -1.4173842e+02, ...,
        -1.7808605e+01, -2.4403126e+00, -4.9363903e+01],
       [ 2.4737693e+03, -4.6828751e+01, -3.8871545e+02, ...,
         3.4314991e+01,  8.8688736e+00,  1.6734121e+01]], dtype=float32)

## Generate test data

For each test data point, generate a bunch sequences of that cell type and use the average prediction.

In [21]:
def get_days(day):
    if day == 2:
        return (-1, 3, 4)
    elif day == 3:
        return (2, -1, 4)
    elif day == 4:
        return (2, 3, -1)
    elif day == 7:
        return (2, 3, 4, -1)
    
def gen_one_test_sequence(x_test_idx, cell_type, sequence_days, pca_combined_data, y_train, indices):
    """
    Generate a test sequence of a particular cell tpe incorporating the test data point
    
    Parameters:
    - x_test_idx: the index of the test data point in <pca_combined_data>
    - cell_type: the cell type of the test data point
    - sequence_days: the sequence of days to be taken, -1 means the test data point will be substituted in there
    - pca_combined_data: the train+test dataset
    - indices: indices of each cell grouped by day and cell_type
    """
    seq = []
    y_seq = []
    for day in sequence_days:
        if day == -1:
            seq.append(pca_combined_data[x_test_idx].obs_names[0])
            y_seq.append(np.zeros((1, 140)))
        else:
            day_indices = indices[(cell_type, day)]
            day_choice = np.random.choice(day_indices)
            seq.append(day_choice)
            y_seq.append(y_train[day_choice, :].X.toarray())
    return pca_combined_data[seq, :].X.toarray(), np.concatenate(y_seq)
    

def test_sequence(pca_combined_data, x_test, y_train, idx, num_samples=1000):
    """
    Generate <num_samples> sequence of cells across all days that include the test data point
    at index <idx> of <x_test>.
    """
    cell = x_test[idx]
    cell_type = cell.obs['cell_type'][0]
    sequence_days = get_days(cell.obs['day'][0])
    indices = separate_data(y_train)
    test_set_x = []
    test_set_y = []
    for i in range(num_samples):
        x, y = gen_one_test_sequence(len(y_train) + idx, cell_type, sequence_days, pca_combined_data, y_train, indices)
        test_set_x.append(x)
        test_set_y.append(y)
    return np.stack(test_set_x, axis=0), np.stack(test_set_y, axis=0)

In [38]:
for i in range(0, len(x_test), 1000):
    ts_x, ts_y = test_sequence(pca_combined_data, x_test, y_train, i, num_samples=100)
    ts_x, ts_y = torch.Tensor(ts_x), torch.Tensor(ts_y)
    res = model(ts_x, ts_y).detach().cpu().numpy()
    print(i)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000


In [36]:
np.average(res, axis=0).shape

(3, 140)

In [143]:
x, y, y_mask = val_dataset[10000]
x = torch.Tensor(x).to('cuda')
y = torch.Tensor(y).to('cuda')
y_mask = torch.Tensor(y_mask).to('cuda')
res = model(x, y_mask)
corrcoeff(res, y)

tensor(0.8501, device='cuda:0', grad_fn=<MeanBackward0>)

In [145]:
x, y = ts
x = torch.Tensor(x).to('cuda')
y = torch.Tensor(y).to('cuda')
model(x, y)

tensor([[[-2.3380e-01, -2.5301e-01,  5.6153e-01,  4.7070e+00,  4.0749e+00,
           7.0291e+00, -4.1253e-01,  4.2696e-01,  2.7988e-01, -1.6133e-02,
          -4.2794e-01,  4.3336e-01, -2.8055e-01,  7.1426e-02,  5.1469e+00,
           2.4168e+00,  5.5194e+00, -8.6265e-01,  1.4029e+00, -6.6716e-02,
           2.6619e-01,  2.7669e+00, -9.0417e-02, -1.6963e-01,  1.1509e+01,
          -1.3920e-01,  7.9307e-02, -8.7431e-02,  1.3485e-01, -1.3632e-01,
          -1.1982e-01,  6.0440e-01,  4.8803e-01,  8.1412e-02,  1.1880e-02,
          -1.3467e-01,  5.9267e-01,  1.6394e+01,  2.4749e-01,  7.4268e-03,
           1.7989e-01, -7.5410e-02,  1.1606e-01, -8.2716e-01, -5.0067e-01,
          -1.5443e-01,  9.2380e-01,  1.4570e-01,  3.8661e+00, -1.4493e-01,
           3.9532e-02, -6.7813e-03,  2.3593e+00, -1.2438e-01,  1.4728e+00,
           3.6491e-01, -1.6345e-02,  3.3940e+00,  3.3760e-01,  5.9738e-02,
          -5.7101e-02,  5.3698e-02,  1.3603e-01,  2.0610e-01,  4.5693e-01,
           4.8488e-02,  1