# Transformer on CITE-seq

In [1]:
import os
import gc

os.environ["NUMBA_CACHE_DIR"] = "/scratch/st-jiaruid-1/yinian/tmp/"  # https://github.com/scverse/scanpy/issues/2113
from os.path import basename, join
from os import makedirs
from pathlib import Path
import yaml

import logging
import anndata as ad
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import scipy

import h5py
import hdf5plugin
import tables

import math

from sklearn.preprocessing import binarize
from sklearn.decomposition import TruncatedSVD

import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.nn import Transformer

Matplotlib created a temporary config/cache directory at /tmp/pbs.4287913.pbsha.ib.sockeye/matplotlib-8odo4t_q because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# True when a CUDA device is usable; later cells check this flag before moving
# the model and the batches to the GPU.
cuda = torch.cuda.is_available()

## Load the data

In [3]:
def load_data_as_anndata(filepaths, metadata_path):
    """
    Loads the files in <filepaths> as AnnData objects

    Every file must be an HDF5 file containing a group named after the file
    (file name without its extension) that holds the datasets "axis0"
    (feature names), "axis1" (cell ids) and "block0_values" (the dense
    cells x features matrix).

    Parameters:
    - filepaths: dict mapping a short name to the path of an HDF5 file
    - metadata_path: path of the metadata CSV; it must contain the columns
      "cell_id" and "technology"

    Returns a dict mapping each name in <filepaths> to an AnnData whose X is a
    sparse CSR matrix, whose obs are the metadata rows of the file's cells and
    whose var index holds the feature names.

    Raises ValueError if the cells of one file do not all share exactly one
    technology.

    Source: https://github.com/openproblems-bio/neurips_2022_saturn_notebooks/blob/main/notebooks/loading_and_visualizing_all_data.ipynb
    """
    metadata_df = pd.read_csv(metadata_path)
    metadata_df = metadata_df.set_index("cell_id")

    adatas = {}
    # The dense matrix is converted to sparse piecewise so that only a fraction
    # of it is ever held in memory as a dense array.
    n_chunks = 100
    for name, filepath in filepaths.items():
        filename = Path(filepath).stem
        logging.info(f"Loading {filename}")

        # Open read-only and close the handle as soon as the data is copied out.
        with h5py.File(filepath, "r") as h5_file:
            h5_data = h5_file[filename]

            features = h5_data["axis0"][:].astype(str)
            cell_ids = h5_data["axis1"][:].astype(str)

            values = h5_data["block0_values"]
            n_cells = values.shape[0]

            sparse_chunks = []
            for chunk_indices in np.array_split(np.arange(n_cells), n_chunks):
                if chunk_indices.size == 0:
                    # Fewer cells than chunks: nothing to read for this chunk.
                    continue
                # The indices are consecutive, so read them as one contiguous
                # slice instead of an index array.
                start = int(chunk_indices[0])
                stop = int(chunk_indices[-1]) + 1
                sparse_chunks.append(scipy.sparse.csr_matrix(values[start:stop]))

        # `.item()` raises ValueError unless the file holds exactly one
        # technology, so this doubles as a sanity check on the metadata join.
        technology = metadata_df.loc[cell_ids, "technology"].unique().item()
        logging.info(f"{filename}: {n_cells} cells, technology {technology}")

        if sparse_chunks:
            X = scipy.sparse.vstack(sparse_chunks)
        else:
            X = scipy.sparse.csr_matrix((0, len(features)))

        adata = ad.AnnData(
            X=X,
            obs=metadata_df.loc[cell_ids],
            var=pd.DataFrame(index=features),
        )

        adatas[name] = adata

    return adatas

In [4]:
# Read the experiment config (YAML) and load every dataset listed under its
# 'paths' key, annotated with the metadata CSV given under 'metadata'.
# NOTE(review): the path is absolute and cluster-specific, and the config file
# name says "multiome" although the notebook title says CITE-seq — confirm
# which technology this run is meant for.
config = yaml.safe_load(Path('/scratch/st-jiaruid-1/yinian/my_jupyter/scRNA-competition/experiments/basic-nn-multiome.yaml').read_text())
adatas = load_data_as_anndata(config['paths'], config['metadata'])

In [5]:
# Pick the datasets out by their config keys.
# presumably 'x' / 'x_test' are the model inputs and 'y' the training targets —
# confirm against the config file.
x_train = adatas['x']
x_test = adatas['x_test']
y_train = adatas['y']
# Stack the train and test inputs (train rows first) into a single AnnData.
combined_data = ad.concat([x_train, x_test])

In [6]:
# Share of every annotated cell type among all cells whose type is not 'hidden'.
# A single vectorised value_counts replaces one Python-level pass over the
# whole column per cell type.
cell_type_counts = combined_data.obs['cell_type'].value_counts()
hidden = int(cell_type_counts.get('hidden', 0))
num_labelled_cells = combined_data.shape[0] - hidden
cell_type_proportions = {
    cell_type: count / num_labelled_cells
    for cell_type, count in cell_type_counts.items()
    # count > 0 skips categories that are declared but never observed.
    if cell_type != 'hidden' and count > 0
}

In [7]:
# Display the proportions computed above.
cell_type_proportions

{'MoP': 0.019454040890298466,
 'EryP': 0.16173944233637272,
 'NeuP': 0.20351701874610637,
 'MkP': 0.12479469898623775,
 'MasP': 0.15798266976270034,
 'BP': 0.005342545921353193,
 'HSC': 0.32716958335693114}

## Generate input data

In [8]:
def separate_data(data):
    """
    Group the observation names of <data> by (cell_type, day).

    Parameters:
    - data: object exposing `.obs` with the columns 'cell_type' and 'day',
      boolean-mask row selection, `.shape` and `.obs_names`

    Returns a dict mapping every (cell_type, day) pair that has at least one
    cell to the obs_names of those cells; empty combinations are left out.
    """
    obs = data.obs
    groups = {}
    for cell_type in set(obs['cell_type']):
        for day in set(obs['day']):
            in_group = np.logical_and(obs['day'] == day, obs['cell_type'] == cell_type)
            subset = data[in_group]
            if subset.shape[0] != 0:
                groups[(cell_type, day)] = subset.obs_names
    return groups

In [9]:
def generate_sequence(pca_combined_data, y_train, cell_type, indices, days=(2, 3, 4)):
    """
    Draw one random cell of <cell_type> for every day in <days> and return the
    matching input and target rows.

    Parameters:
    - pca_combined_data: AnnData-like object holding the inputs, indexable by obs name
    - y_train: AnnData-like object holding the targets, indexable by obs name
    - cell_type: the cell type every sampled cell must have
    - indices: dict mapping (cell_type, day) to the obs names of that group,
      as returned by separate_data
    - days: the days making up the sequence, in order; defaults to (2, 3, 4),
      pass e.g. (2, 3, 4, 7) for data that covers a fourth day

    Returns a pair of dense arrays (inputs, targets), one row per day.

    Raises KeyError if <indices> has no entry for one of the (cell_type, day) pairs.
    """
    seq = [np.random.choice(indices[(cell_type, day)]) for day in days]
    return pca_combined_data[seq, :].X.toarray(), y_train[seq, :].X.toarray()

In [10]:
def generate_train_data(pca_combined_data, y_train, cell_type_proportions, num_samples=1_000_000):
    """
    Materialise <num_samples> random training sequences.

    For every sample a cell type is drawn with the probabilities given by
    <cell_type_proportions>, then generate_sequence picks one cell per day for
    that type.

    Returns a pair of arrays (inputs, targets) with the samples stacked along
    a new leading axis.
    """
    type_names = list(cell_type_proportions.keys())
    type_weights = list(cell_type_proportions.values())
    groups = separate_data(y_train)

    samples = []
    for _ in range(num_samples):
        sampled_type = np.random.choice(type_names, p=type_weights)
        samples.append(generate_sequence(pca_combined_data, y_train, sampled_type, groups))

    stacked_inputs = np.stack([inputs for inputs, _ in samples], axis=0)
    stacked_targets = np.stack([targets for _, targets in samples], axis=0)
    return stacked_inputs, stacked_targets

In [11]:
# generated_x_train, generated_y_train = generate_train_data(pca_combined_data, y_train, cell_type_proportions, 400_000)

### Make TensorDataset

In [12]:
class SCDataset(Dataset):
    """
    Dataset of randomly generated cell sequences, one cell per measured day.

    Every item is built on the fly: a cell type is drawn according to
    <cell_type_proportions>, then one random cell of that type is drawn for
    each day. The item index is only used to choose which day of the sequence
    is masked in the decoder input, so the same index yields different cells
    on every access.

    Parameters:
    - pca_combined_data: AnnData-like object with the inputs, indexable by obs name
    - y_train: AnnData-like object with the targets, indexable by obs name
    - y_pca: stored on the instance but not used by this class
    - cell_type_proportions: dict mapping cell type to sampling probability
    - size: nominal number of items reported by len()
    - technology: 'multiome' selects days (2, 3, 4, 7); anything else selects (2, 3, 4)
    """

    # Sentinel written over the target row of the masked day.
    MASK_VALUE = -2147483647

    def __init__(self, pca_combined_data, y_train, y_pca, cell_type_proportions, size, technology):
        self.pca_combined_data = pca_combined_data
        self.y_train = y_train
        self.y_pca = y_pca
        self.cell_type_proportions = cell_type_proportions
        self.size = size

        self.cell_types = list(cell_type_proportions.keys())
        self.cell_type_probs = list(cell_type_proportions.values())
        self.indices = separate_data(y_train)
        self.technology = technology
        if technology == 'multiome':
            self.days = (2, 3, 4, 7)
        else:
            self.days = (2, 3, 4)

    def __len__(self):
        """Return the nominal dataset size given at construction."""
        return self.size

    def __getitem__(self, idx):
        """
        Return (x, y, y_0) for one random sequence.

        x holds the input rows and y the target rows, one row per day; y_0 is
        a copy of y in which the row at position idx % len(days) is overwritten
        with MASK_VALUE.

        Raises KeyError if the sampled cell type has no cells on one of the days.
        """
        cell_type = np.random.choice(self.cell_types, p=self.cell_type_probs)
        seq = [np.random.choice(self.indices[(cell_type, day)]) for day in self.days]

        x = self.pca_combined_data[seq, :].X.toarray()
        y = self.y_train[seq, :].X.toarray()

        # Rotate the masked position with the index so that every day is
        # masked equally often, whatever the number of days.
        y_0 = y.copy()
        y_0[idx % len(self.days), :] = self.MASK_VALUE
        return x, y, y_0

In [13]:
# Build the on-the-fly sequence dataset: inputs come from combined_data and
# targets from y_train (y_train is also passed for the y_pca slot); 130000 is
# the nominal number of items and "multiome" selects the four-day sequences.
dataset = SCDataset(combined_data, y_train, y_train, cell_type_proportions, 130000, "multiome")

# dataset = TensorDataset(torch.Tensor(generated_x_train), torch.tensor(generated_y_train))
# 80/20 split of the item indices into train and validation subsets.
# NOTE(review): SCDataset.__getitem__ samples its cells at random and uses the
# index only to pick the masked day, so both subsets draw from the same pool
# of cells — the validation subset is not held out.
train_num = int(len(dataset) * 4/5)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_num, len(dataset) - train_num]
)

In [14]:
# Batches of 1000 sequences each; no shuffling is requested because the dataset
# already draws its cells at random on every access.
training_loader = DataLoader(train_dataset, batch_size=1000)
validation_loader = DataLoader(val_dataset, batch_size=1000)

## Transformer model

In [15]:
# class PositionalEncoding(nn.Module):
#     def __init__(self, dim_model, dropout_p, max_len):
#         super().__init__()
#         # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
#         # max_len determines how far the position can have an effect on a token (window)
        
#         # Info
#         self.dropout = nn.Dropout(dropout_p)
        
#         # Encoding - From formula
#         pos_encoding = torch.zeros(max_len, dim_model)
#         positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
#         division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
#         # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
#         pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
#         # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
#         pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
#         # Saving buffer (same as parameter without gradients needed)
#         pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
#         self.register_buffer("pos_encoding",pos_encoding)
        
#     def forward(self, token_embedding: torch.tensor) -> torch.tensor:
#         # Residual connection + pos encoding
#         return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

class SCTransformer(nn.Module):
    """
    Encoder-decoder transformer predicting a (mean, dispersion) pair per target feature.

    Source and target rows are first projected to <model_dim> by two linear
    layers, passed through a 6+6 layer batch-first nn.Transformer, and then
    mapped back to <tgt_dim> by two separate linear heads.

    Parameters:
    - src_dim: number of features of each source row
    - tgt_dim: number of features of each target row
    - model_dim: width of the transformer
    - nhead: number of attention heads
    - dropout: dropout rate inside the transformer
    """

    def __init__(self, src_dim, tgt_dim, model_dim, nhead=2, dropout=0.1):
        super().__init__()
        # No positional encoding is applied to either sequence.
        transformer_kwargs = dict(
            d_model=model_dim,
            nhead=nhead,
            batch_first=True,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dropout=dropout,
        )
        self.transformer = Transformer(**transformer_kwargs)
        # Input projections into the transformer width.
        self.src_processing = nn.Linear(src_dim, model_dim)
        self.tgt_processing = nn.Linear(tgt_dim, model_dim)
        # Output heads back to the target feature space.
        self.means = nn.Linear(model_dim, tgt_dim)
        self.dispersions = nn.Linear(model_dim, tgt_dim)

    def forward(self, src, tgt):
        """
        Return a tensor whose last axis holds [mean, dispersion] for every
        target position and feature; both are exponentiated and therefore positive.
        """
        decoded = self.transformer(self.src_processing(src), self.tgt_processing(tgt))
        log_params = torch.stack((self.means(decoded), self.dispersions(decoded)), dim=-1)
        return log_params.exp()

In [16]:
class NegativeBinomialLoss(nn.Module):
    """
    Mean negative log-likelihood of counts under a negative binomial distribution.

    y_pred carries two channels on its last axis: the predicted mean and the
    dispersion theta. theta is capped at 1e6 and <eps> guards the logarithms
    and divisions against zero.
    """

    def forward(self, y_pred, y_true, eps=1e-10):
        mu, theta = y_pred.unbind(dim=-1)
        theta = theta.clamp(max=1e6)

        # Normalising terms built from log-gamma functions.
        gamma_part = (
            torch.lgamma(theta + eps)
            + torch.lgamma(y_true + 1.0)
            - torch.lgamma(y_true + theta + eps)
        )
        # Terms that depend on the ratio between mean and dispersion.
        log_ratio = torch.log(theta + eps) - torch.log(mu + eps)
        ratio_part = (theta + y_true) * torch.log1p(mu / (theta + eps)) + (
            y_true * log_ratio
        )
        return (gamma_part + ratio_part).mean()

In [17]:
# Instantiate the model with 228942 source features, 23418 target features and
# a 1024-wide transformer — presumably the column counts of combined_data and
# y_train; TODO confirm.
model = SCTransformer(228942, 23418, model_dim=1024, nhead=2)
if cuda:
    model.to('cuda')
# Negative-binomial loss on the predicted (mean, dispersion) pairs, optimised with Adam.
loss_fn = NegativeBinomialLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)

In [18]:
# model.load_state_dict(torch.load('/scratch/st-jiaruid-1/yinian/my_jupyter/output/transformer/transformer_state4'))

### Train the Transformer

In [19]:
def correlation_score(y_true, y_pred):
    """Average row-wise Pearson correlation between targets and predictions.

    This is the competition metric. The predictions are assumed not to be
    constant within a row.

    Raises ValueError when the two inputs differ in shape.

    Source: https://www.kaggle.com/code/xiafire/lb-t15-msci-multiome-catboostregressor#Predicting
    """
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes are different.")
    num_rows = len(y_true)
    row_correlations = (
        np.corrcoef(y_true[row], y_pred[row])[1, 0] for row in range(num_rows)
    )
    return sum(row_correlations) / num_rows

In [20]:
def corrcoeff(y_pred, y_true):
    '''Mean row-wise Pearson correlation coefficient of two 2-D tensors.

    Computed entirely in torch, so nothing is detached or moved to the CPU.
    '''
    true_centered = y_true - y_true.mean(dim=1, keepdim=True)
    pred_centered = y_pred - y_pred.mean(dim=1, keepdim=True)

    covariance = (true_centered * pred_centered).sum(dim=1, keepdim=True)
    pred_sq_sum = (pred_centered ** 2).sum(dim=1, keepdim=True)
    true_sq_sum = (true_centered ** 2).sum(dim=1, keepdim=True)

    per_row = covariance / torch.sqrt(pred_sq_sum * true_sq_sum)
    return per_row.mean()

In [21]:
def train_one_epoch(model, training_loader, epoch_index, loss_fn, optimizer):
    """
    Run one pass over <training_loader> and return the mean batch loss.

    Parameters:
    - model: module called as model(inputs, masked_labels)
    - training_loader: iterable of (inputs, labels, masked_labels) batches
    - epoch_index: accepted for the callers' convenience; not used
    - loss_fn: callable taking (outputs, labels) and returning a scalar tensor
    - optimizer: optimizer holding the parameters of <model>

    Batches are moved to the device the model's parameters live on, so the
    function works for a CPU or GPU model alike.

    Raises ValueError if <training_loader> yields no batches.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    num_batches = 0
    model.train()

    for i, data in enumerate(training_loader):
        # Every batch is an (input, label, decoder input with one day masked) triple.
        inputs, labels, masked_labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)
        masked_labels = masked_labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs, masked_labels)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        curr_loss = loss.item()
        running_loss += curr_loss
        num_batches += 1
        print("  batch {} loss: {}".format(i + 1, curr_loss))

        # Drop the batch tensors eagerly to keep memory bounded.
        del inputs, labels, masked_labels
        gc.collect()

    if num_batches == 0:
        raise ValueError("training_loader yielded no batches")
    return running_loss / num_batches

In [36]:
# Train for epochs 3..9 (the range starts at 3, i.e. this continues an earlier
# run); after every epoch, score the validation set and write one checkpoint.
for epoch in range(3, 10):
    avg_loss = train_one_epoch(model, training_loader, epoch, loss_fn, optimizer)

    model.eval()
    running_vcorr = 0.0
    num_val_batches = 0
    with torch.no_grad():
        for i, vdata in enumerate(validation_loader):
            vinputs, vlabels, vmaskedlabels = vdata
            if cuda:
                vinputs = vinputs.to("cuda")
                vlabels = vlabels.to("cuda")
                vmaskedlabels = vmaskedlabels.to("cuda")
            voutputs = model(vinputs, vmaskedlabels)

            # Average the correlation over the sequence positions, comparing
            # the predicted means (channel 0) with the true targets.
            vcorr = 0
            for j in range(voutputs.shape[1]):
                vcorr += corrcoeff(voutputs[:, j, :, 0], vlabels[:, j, :])
            vcorr /= voutputs.shape[1]

            # Accumulate a plain float so the summary prints as a number.
            running_vcorr += float(vcorr)
            num_val_batches += 1
            gc.collect()

    # One checkpoint per epoch, written after validation rather than once per
    # validation batch.
    torch.save(model.state_dict(), f'/scratch/st-jiaruid-1/yinian/my_jupyter/output/transformer/multi_transformer_state{epoch+1}')

    # The training loss is the negative binomial loss, not MSE.
    print(f"EPOCH: {epoch} NB loss: {avg_loss}\nCORR: {running_vcorr / max(num_val_batches, 1)}")

  batch 1 loss: 0.7209523320198059
  batch 2 loss: 0.7175948023796082
  batch 3 loss: 0.7194600701332092
  batch 4 loss: 0.7199166417121887
  batch 5 loss: 0.7172943949699402
  batch 6 loss: 0.7174321413040161
  batch 7 loss: 0.7152698040008545
  batch 8 loss: 0.7176268696784973
  batch 9 loss: 0.7152054309844971
  batch 10 loss: 0.7156948447227478
  batch 11 loss: 0.7141604423522949
  batch 12 loss: 0.7167066335678101
  batch 13 loss: 0.714340090751648
  batch 14 loss: 0.7179343700408936
  batch 15 loss: 0.7167294025421143
  batch 16 loss: 0.7178642153739929
  batch 17 loss: 0.7180213332176208
  batch 18 loss: 0.7170811891555786
  batch 19 loss: 0.7130610942840576
  batch 20 loss: 0.717948317527771
  batch 21 loss: 0.7188989520072937
  batch 22 loss: 0.7183571457862854
  batch 23 loss: 0.7177048325538635
  batch 24 loss: 0.7153449058532715
  batch 25 loss: 0.7152818441390991
  batch 26 loss: 0.7166620492935181
  batch 27 loss: 0.717425525188446
  batch 28 loss: 0.7180020213127136
  ba

  batch 19 loss: 0.7161287069320679
  batch 20 loss: 0.7173123955726624
  batch 21 loss: 0.7148120403289795
  batch 22 loss: 0.7126074433326721
  batch 23 loss: 0.713154673576355
  batch 24 loss: 0.7147067189216614
  batch 25 loss: 0.7127820253372192
  batch 26 loss: 0.7137017250061035
  batch 27 loss: 0.7124190330505371
  batch 28 loss: 0.7162707448005676
  batch 29 loss: 0.7124528884887695
  batch 30 loss: 0.7153272032737732
  batch 31 loss: 0.7124327421188354
  batch 32 loss: 0.7151191234588623
  batch 33 loss: 0.7146115899085999
  batch 34 loss: 0.7114636898040771
  batch 35 loss: 0.7129424214363098
  batch 36 loss: 0.7120429277420044
  batch 37 loss: 0.7127799987792969
  batch 38 loss: 0.7134936451911926
  batch 39 loss: 0.7158252000808716
  batch 40 loss: 0.7157915234565735
  batch 41 loss: 0.7160847187042236
  batch 42 loss: 0.7134150862693787
  batch 43 loss: 0.7136036157608032
  batch 44 loss: 0.7113150954246521
  batch 45 loss: 0.7131789326667786
  batch 46 loss: 0.7115775346

  batch 36 loss: 0.7161281108856201
  batch 37 loss: 0.7108278274536133
  batch 38 loss: 0.7091253399848938
  batch 39 loss: 0.7133564949035645
  batch 40 loss: 0.7140417098999023
  batch 41 loss: 0.7182787656784058
  batch 42 loss: 0.7107496857643127
  batch 43 loss: 0.7087056040763855
  batch 44 loss: 0.7125942707061768
  batch 45 loss: 0.7165704965591431
  batch 46 loss: 0.7112354636192322
  batch 47 loss: 0.7083609104156494
  batch 48 loss: 0.7129514813423157
  batch 49 loss: 0.7136248350143433
  batch 50 loss: 0.7103448510169983
  batch 51 loss: 0.7121352553367615
  batch 52 loss: 0.7144088745117188
  batch 53 loss: 0.7117387056350708
  batch 54 loss: 0.7115164995193481
  batch 55 loss: 0.7135868668556213
  batch 56 loss: 0.71024489402771
  batch 57 loss: 0.7149139642715454
  batch 58 loss: 0.712788462638855
  batch 59 loss: 0.711177408695221
  batch 60 loss: 0.7121233940124512
  batch 61 loss: 0.7124159932136536
  batch 62 loss: 0.7139756679534912
  batch 63 loss: 0.7116121053695

  batch 53 loss: 0.7100746035575867
  batch 54 loss: 0.7080588340759277
  batch 55 loss: 0.7111784815788269
  batch 56 loss: 0.7080315947532654
  batch 57 loss: 0.7112199664115906
  batch 58 loss: 0.7078735828399658
  batch 59 loss: 0.7092387080192566
  batch 60 loss: 0.7092772722244263
  batch 61 loss: 0.711731493473053
  batch 62 loss: 0.709252119064331
  batch 63 loss: 0.710884153842926
  batch 64 loss: 0.7084475755691528
  batch 65 loss: 0.7120559215545654
  batch 66 loss: 0.7102577090263367
  batch 67 loss: 0.7098103165626526
  batch 68 loss: 0.7108051180839539
  batch 69 loss: 0.7062185406684875
  batch 70 loss: 0.7094378471374512
  batch 71 loss: 0.7110808491706848
  batch 72 loss: 0.7103355526924133
  batch 73 loss: 0.7110134363174438
  batch 74 loss: 0.7083894610404968
  batch 75 loss: 0.7102844715118408
  batch 76 loss: 0.7091796398162842
  batch 77 loss: 0.7091083526611328
  batch 78 loss: 0.7122970819473267
  batch 79 loss: 0.708784282207489
  batch 80 loss: 0.7072103619575

In [26]:
# Copy the predictions of the last validation batch (left in `voutputs` by the
# training cell) to the CPU for inspection.
out = voutputs.cpu()

In [27]:
# Keep channel 0 of the last axis, i.e. the predicted means (the model stacks
# [mean, dispersion] there). NOTE(review): `out` is overwritten in place, so
# this cell cannot be run twice in a row.
out = out[:, :, :, 0]

In [34]:
# Recompute, on the CPU, the validation correlation of the last batch: the
# per-position correlations are averaged over the sequence positions.
vcorr = 0
for j in range(voutputs.shape[1]):
    vcorr += corrcoeff(out[:, j, :], vlabels[:, j, :].cpu())
vcorr /= j + 1

In [35]:
# Display the correlation of the last validation batch.
vcorr

tensor(0.6307)

In [83]:
# Save the current weights under a fixed name (separate from the per-epoch
# "multi_transformer_state<N>" checkpoints written by the training loop).
torch.save(model.state_dict(), '/scratch/st-jiaruid-1/yinian/my_jupyter/output/transformer/transformer_state')

## Generate test data

For each test data point, generate a bunch of sequences of that cell type and use the average prediction.

In [21]:
def get_days(day):
    """
    Return the day sequence to build around a test cell measured on <day>.

    The position of the test cell itself is marked with -1; the other entries
    are the days from which training cells are sampled.

    NOTE(review): days 2, 3 and 4 give three-day sequences while day 7 gives a
    four-day one — confirm this matches the sequence length used in training.

    Raises ValueError for any day other than 2, 3, 4 or 7.
    """
    sequences_by_day = {
        2: (-1, 3, 4),
        3: (2, -1, 4),
        4: (2, 3, -1),
        7: (2, 3, 4, -1),
    }
    try:
        return sequences_by_day[day]
    except (KeyError, TypeError):
        # Fail here with a clear message instead of handing None to the caller.
        raise ValueError(f"Unsupported day {day!r}; expected one of 2, 3, 4 or 7") from None
    
def gen_one_test_sequence(x_test_idx, cell_type, sequence_days, pca_combined_data, y_train, indices, mask_value=0.0):
    """
    Generate a test sequence of a particular cell type incorporating the test data point
    
    Parameters:
    - x_test_idx: the index of the test data point in <pca_combined_data>
    - cell_type: the cell type of the test data point
    - sequence_days: the sequence of days to be taken, -1 means the test data point will be substituted in there
    - pca_combined_data: the train+test dataset
    - y_train: the training targets, indexable by obs name
    - indices: indices of each cell grouped by day and cell_type
    - mask_value: value filling the target row of the test data point, whose
      targets are unknown; defaults to 0.
      NOTE(review): SCDataset masks the unknown day with -2147483647 during
      training — confirm which placeholder the model should see at inference.

    Returns a pair of dense arrays (inputs, targets) with one row per entry of
    <sequence_days>.
    """
    seq = []
    y_seq = []
    # Size the placeholder row from the targets instead of assuming a fixed width.
    n_targets = y_train.shape[1]
    for day in sequence_days:
        if day == -1:
            seq.append(pca_combined_data.obs_names[x_test_idx])
            y_seq.append(np.full((1, n_targets), mask_value, dtype=float))
        else:
            day_indices = indices[(cell_type, day)]
            day_choice = np.random.choice(day_indices)
            seq.append(day_choice)
            y_seq.append(y_train[day_choice, :].X.toarray())
    return pca_combined_data[seq, :].X.toarray(), np.concatenate(y_seq)
    

def test_sequence(pca_combined_data, x_test, y_train, idx, num_samples=1000, indices=None):
    """
    Generate <num_samples> sequence of cells across all days that include the test data point
    at index <idx> of <x_test>.

    Parameters:
    - pca_combined_data: the train+test dataset; the test cell is looked up at
      row len(y_train) + idx, so the test rows must directly follow the
      len(y_train) training rows
    - x_test: the test dataset, whose obs holds the columns 'cell_type' and 'day'
    - y_train: the training targets
    - idx: position of the test cell within <x_test>
    - num_samples: number of sequences to generate
    - indices: optional precomputed result of separate_data(y_train); pass it
      when calling this function repeatedly to avoid regrouping the cells on
      every call

    Returns a pair of arrays (inputs, targets) with the sequences stacked
    along a new leading axis.
    """
    # Positional access via .iloc: the obs index holds cell ids, not integers.
    cell_type = x_test.obs['cell_type'].iloc[idx]
    sequence_days = get_days(x_test.obs['day'].iloc[idx])
    if indices is None:
        indices = separate_data(y_train)
    test_set_x = []
    test_set_y = []
    for _ in range(num_samples):
        x, y = gen_one_test_sequence(len(y_train) + idx, cell_type, sequence_days, pca_combined_data, y_train, indices)
        test_set_x.append(x)
        test_set_y.append(y)
    return np.stack(test_set_x, axis=0), np.stack(test_set_y, axis=0)

In [38]:
# Run the model on every 1000th test cell, using 100 random sequences per cell.
# The inputs come from combined_data, the dataset the model was trained on
# (the name pca_combined_data is never defined in this notebook).
# NOTE(review): `res` is overwritten on every iteration and never stored.
inference_device = next(model.parameters()).device
model.eval()
with torch.no_grad():
    for i in range(0, len(x_test), 1000):
        ts_x, ts_y = test_sequence(combined_data, x_test, y_train, i, num_samples=100)
        # Move the inputs to the model's device before the forward pass.
        ts_x = torch.Tensor(ts_x).to(inference_device)
        ts_y = torch.Tensor(ts_y).to(inference_device)
        res = model(ts_x, ts_y).cpu().numpy()
        print(i)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000


In [23]:
# NOTE(review): `dis` and `x` are not defined in the visible part of this
# notebook, so this cell fails on a fresh kernel. It exponentiates
# dis.log_prob(x) — presumably turning log-probabilities into probabilities.
torch.exp(dis.log_prob(x))

tensor([[5.9049e-01, 4.2220e-07, 1.9486e-05,  ..., 2.9525e-01, 2.9229e-06,
         4.1334e-03],
        [7.4402e-04, 7.4402e-04, 7.4402e-04,  ..., 1.9486e-05, 2.9525e-01,
         2.9525e-01],
        [4.2220e-07, 7.4402e-04, 2.9525e-01,  ..., 1.2400e-04, 8.8574e-02,
         4.2220e-07],
        ...,
        [7.4402e-04, 4.1334e-03, 2.9229e-06,  ..., 2.9229e-06, 1.9486e-05,
         5.9049e-01],
        [2.9525e-01, 2.9229e-06, 1.2400e-04,  ..., 2.9525e-01, 2.9525e-01,
         2.0667e-02],
        [2.9525e-01, 7.4402e-04, 4.1334e-03,  ..., 5.9049e-01, 4.1334e-03,
         4.2220e-07]])