# This is an upgraded version of the scBERT

## We are planning to incorporate these things:

### Improvements to the Encoder block
1. Grouped Multi Query Attention
2. RMS Norm in place of LayerNorm for faster training
3. Flash attention 2.0
4. SwiGLU/SiLU in place of ReLU/GLU - done
5. Gene coexpression

### Improvements to improve parameter count while reducing computational cost
1. Mixtral of Experts

### Improvements to training stratergy
2. Improved Token Embeddings
1. Improved masking

### For Faster training
1. Mixed precision training - done
2. Distributed Data Parallel Training - done
3. Faster Data Loading using MultDL
4. Adafactor - https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#optimizer-choice
5. Torch compile - https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#using-torchcompile
6. Data preloading - done

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
%pip install mixture_of_experts
%pip install scanpy
%pip install accelerate

Collecting mixture_of_experts
  Downloading mixture_of_experts-0.2.3-py3-none-any.whl (6.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->mixture_of_experts)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->mixture_of_experts)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->mixture_of_experts)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->mixture_of_experts)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->mixture_of_experts)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->mixture_of_experts)
  Using cached nvidi

In [3]:
import scanpy as sc
from sklearn.model_selection import train_test_split
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mixture_of_experts import MoE
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from functools import reduce
from torch.optim import AdamW
import math
import matplotlib.pyplot as plt
from accelerate import Accelerator
import pickle as pkl
from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
import pandas as pd

In [4]:
accelerator = Accelerator(mixed_precision='fp16')

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
torch.backends.cuda.enable_flash_sdp(True)

In [7]:
class FlashAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h ==0, "d_model is not divisble by h"
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)


    def forward(self, q, k, v):
        query = self.w_q (q) # (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        key = self.w_k(k) # (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        value = self.w_v(v) # (batch, seq_len, d_model) -> (batch, seq_len, d_model)

        # Test code
        # query = q
        # key = k
        # value = v

        # (batch, seq_len, d_model) -> (Batch, seq_len, h, d_k) -> (Batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1,2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1,2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1,2)

        x = F.scaled_dot_product_attention(query,key,value, dropout_p=0.1)

        # (Batch, h, seq_len, d_k) -> (Batch, seq_len, h, d_k) -> (Batch, seq_len, d_model)
        x =  x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # (Batch, seq_len, d_model)  -> (Batch, seq_len, d_model)
        return self.w_o(x)

In [8]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

In [9]:
moe = MoE(
    dim = 200,
    num_experts =  4,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 200 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.SiLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing auxiliary loss
)

In [10]:
class Encoder(nn.Module):
  def __init__(self, local_heads, d_model, hidden_ff_model):
    super().__init__()
    # Embedding dimension = 200, Local_Attention heads = 10
    self.attention = FlashAttentionBlock(d_model= d_model, h=local_heads)
    self.attention_norm = RMSNorm(dim =d_model)
    self.ff_norm = RMSNorm(dim=d_model)
    self.feed_forward = MoE(dim =d_model, num_experts=8, hidden_dim= hidden_ff_model,activation = nn.SiLU, second_policy_train = 'random', second_policy_eval = 'random', second_threshold_train = 0.2, second_threshold_eval = 0.2, capacity_factor_train = 1.25,capacity_factor_eval = 2., loss_coef = 1e-2)


  def forward(self, x):
    x_normed = self.attention_norm(x)
    # print(x_normed.shape)
    r = self.attention(x_normed, x_normed, x_normed)
    h = x + r
    r, _ = self.feed_forward(self.ff_norm(h))
    out = h + r
    return out

In [11]:
class Gene2VecPositionalEmbedding(nn.Module):
    def __init__(self):
        super().__init__()

        gene2vec_weight = np.load('/content/drive/MyDrive/scFasterBERT/data/gene2vec_16906.npy')
        gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0)
        gene2vec_weight = torch.from_numpy(gene2vec_weight)
        self.emb = nn.Embedding.from_pretrained(gene2vec_weight)

    def forward(self, x):
        t = torch.arange(x.shape[1], device = device)
        return self.emb(t)

In [12]:
#max_seq_len =16907
class scBERT2(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(7, 200)
        self.pos_emb = Gene2VecPositionalEmbedding()
        self.layer1 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.layer2 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.layer3 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.layer4 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.layer5 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.layer6 = Encoder(local_heads=10,d_model=200, hidden_ff_model=400)
        self.norm = RMSNorm(200)
        self.classifier = nn.Linear(in_features=200, out_features=7)

    def forward(self, x):
    # x = x.type(torch.int32)
        pos_emb = self.pos_emb(x)
        x = self.token_emb(x.int())
        # print(x.shape)
        x += pos_emb
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.classifier(self.norm(x))
        return x

In [13]:
scbert = scBERT2()

In [14]:
sum(p.numel() for p in scbert.parameters() if p.requires_grad)

8659807

In [15]:
scbert.load_state_dict(torch.load('/content/drive/MyDrive/scFasterBERT/Final_models/scbert2_epoch1.pth'))

<All keys matched successfully>

In [16]:
class Identity(torch.nn.Module):
    def __init__(self, dropout = 0., h_dim = 100, out_dim = 10):
        super(Identity, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, (1, 200))
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=512, bias=True)
        self.act1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(in_features=512, out_features=h_dim, bias=True)
        self.act2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True)

    def forward(self, x):
        x = x[:,None,:,:]
        x = self.conv1(x)
        x = self.act(x)
        x = x.view(x.shape[0],-1)
        x = self.fc1(x)
        x = self.act1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.act2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

# Training

In [17]:
SEED = 2021
EPOCHS = 10
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
GRADIENT_ACCUMULATION = 60
SEQ_LEN = 16907
VALIDATE_EVERY = 1
CLASS = 7
MASK_PROB = 0.15
REPLACE_PROB = 0.9
RANDOM_TOKEN_PROB = 0.
MASK_TOKEN_ID = CLASS - 1
PAD_TOKEN_ID = CLASS - 1
MASK_IGNORE_TOKEN_IDS = [0]
POS_EMBED_USING = True

In [18]:
class SCDataset(Dataset):
    def __init__(self, data, label):
        super().__init__()
        self.data = data
        self.label = label

    def __getitem__(self, index):
        rand_start = random.randint(0, self.data.shape[0]-1)
        full_seq = self.data[rand_start].toarray()[0]
        full_seq[full_seq > (CLASS - 2)] = CLASS - 2
        full_seq = torch.from_numpy(full_seq).long()
        full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device)
        seq_label = self.label[rand_start]
        return full_seq, seq_label

    def __len__(self):
        return self.data.shape[0]


In [19]:
data = sc.read_h5ad('/content/drive/MyDrive/scFasterBERT/Final_models/preprocessed_data.h5ad')
label_dict, label = np.unique(np.array(data.obs['celltype']), return_inverse=True)
with open('label_dict', 'wb') as fp:
    pkl.dump(label_dict, fp)
with open('label', 'wb') as fp:
    pkl.dump(label, fp)
class_num = np.unique(label, return_counts=True)[1].tolist()
class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num])
data = data.X

acc = []
f1 = []
f1w = []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
pred_list = pd.Series(['un'] * data.shape[0])

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)

for index_train, index_val in sss.split(data, label):
   data_train, label_train = data[index_train], label[index_train]
   data_val, label_val = data[index_val], label[index_val]
   train_dataset = SCDataset(data_train, label_train)
   val_dataset = SCDataset(data_val, label_val)



In [20]:
for params in scbert.parameters():
  params.requires_grad=False

In [21]:
scbert.classifier = Identity(dropout=0., h_dim=128, out_dim=label_dict.shape[0])
scbert = scbert.to(device)

In [22]:
optimizer = AdamW(scbert.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_ID, reduction='mean').to(device)
softmax = nn.Softmax(dim=-1)
scaler = torch.cuda.amp.GradScaler()

In [23]:
for i in range(1, EPOCHS+1):
    train_loader.set_epoch(i)
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    for index, (data, labels) in enumerate(train_loader):
        index += 1
        data, labels = data.to(device), labels.to(device)
        if index % GRADIENT_ACCUMULATION != 0:
            with model.no_sync():
                logits = model(data)
                loss = loss_fn(logits, labels)
                loss.backward()
        if index % GRADIENT_ACCUMULATION == 0:
            logits = model(data)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
            optimizer.step()
            optimizer.zero_grad()
        running_loss += loss.item()
        softmax = nn.Softmax(dim=-1)
        final = softmax(logits)
        final = final.argmax(dim=-1)
        pred_num = labels.size(0)
        correct_num = torch.eq(final, labels).sum(dim=-1)
        cum_acc += torch.true_divide(correct_num, pred_num).mean().item()
    epoch_loss = running_loss / index
    epoch_acc = 100 * cum_acc / index
    epoch_loss = get_reduced(epoch_loss, local_rank, 0, world_size)
    epoch_acc = get_reduced(epoch_acc, local_rank, 0, world_size)
    if is_master:
        print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')
    dist.barrier()
    scheduler.step()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        dist.barrier()
        running_loss = 0.0
        predictions = []
        truths = []
        with torch.no_grad():
            for index, (data_v, labels_v) in enumerate(val_loader):
                index += 1
                data_v, labels_v = data_v.to(device), labels_v.to(device)
                logits = model(data_v)
                loss = loss_fn(logits, labels_v)
                running_loss += loss.item()
                softmax = nn.Softmax(dim=-1)
                final_prob = softmax(logits)
                final = final_prob.argmax(dim=-1)
                final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1
                predictions.append(final)
                truths.append(labels_v)
            del data_v, labels_v, logits, final_prob, final
            # gather
            predictions = torch.cat(predictions, dim=0)
            truths = torch.cat(truths, dim=0)
            no_drop = predictions != -1
            predictions = np.array((predictions[no_drop]).cpu())
            truths = np.array((truths[no_drop]).cpu())
            cur_acc = accuracy_score(truths, predictions)
            f1 = f1_score(truths, predictions, average='macro')
            val_loss = running_loss / index
            val_loss = get_reduced(val_loss, local_rank, 0, world_size)
            if is_master:
                print(f'    ==  Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f}  ==')
                print(confusion_matrix(truths, predictions))
                print(classification_report(truths, predictions, target_names=label_dict.tolist(), digits=4))
            if cur_acc > max_acc:
                max_acc = cur_acc
                trigger_times = 0
                save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir)
            else:
                trigger_times += 1
                if trigger_times > PATIENCE:
                    break
    del predictions, truths

NameError: name 'train_loader' is not defined

In [None]:
def train2(model):
  softmax = nn.Softmax(dim=-1)
  train_losses =[]
  train_accuracies = []
  valid_losses =[]
  valid_accuracies = []
  for i in range(1, EPOCHS+1):
      model.train()
      running_loss = 0.0
      cum_acc = 0.0
      for index, data in enumerate(tqdm(train_loader)):
          index += 1
          data = data.to(device)
          data, labels = data_mask(data)
          logits = model(data)
          loss = loss_fn(logits.transpose(1, 2), labels)/ GRADIENT_ACCUMULATION
          accelerator.backward(loss)
          if index % GRADIENT_ACCUMULATION == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2))
            optimizer.step()
            optimizer.zero_grad()
          running_loss += loss.item()
          final = softmax(logits)[..., 1:-1]
          final = final.argmax(dim=-1) + 1
          pred_num = (labels != PAD_TOKEN_ID).sum(dim=-1)
          correct_num = ((labels != PAD_TOKEN_ID) * (final == labels)).sum(dim=-1)
          cum_acc += torch.true_divide(correct_num, pred_num).mean().item()
      epoch_loss = running_loss / index
      epoch_acc = 100 * cum_acc / index
      train_losses.append(epoch_loss)
      train_accuracies.append(epoch_acc)
      print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')

      if i % VALIDATE_EVERY == 0:
          model.eval()
          running_loss = 0.0
          predictions = []
          truths = []
          with torch.no_grad():
              for index, data in enumerate(tqdm(val_loader)):
                  index += 1
                  data = data.to(device)
                  data, labels = data_mask(data)
                  logits = model(data)
                  loss = loss_fn(logits.transpose(1, 2), labels)
                  running_loss += loss.item()
                  softmax = nn.Softmax(dim=-1)
                  final = softmax(logits)[..., 1:-1]
                  final = final.argmax(dim=-1) + 1
                  predictions.append(final)
                  truths.append(labels)
          val_loss = running_loss / index
          correct_num = ((torch.cat(truths, dim=0) != PAD_TOKEN_ID) * (torch.cat(predictions, dim=0) == torch.cat(truths, dim=0))).sum().item()
          val_num = (torch.cat(truths, dim=0) != PAD_TOKEN_ID).sum().item()
          val_acc = 100 * correct_num / val_num
          valid_losses.append(val_loss)
          valid_accuracies.append(val_acc)
          print(f'    ==  Epoch: {i} | Validation Loss: {val_loss:.6f} | Accuracy: {val_acc:6.4f}%  ==')
      torch.save(model.state_dict(), f'scbert2_epoch{i}.pth')

  return train_losses, train_accuracies, valid_losses, valid_accuracies

In [None]:
train_losses, train_accuracies, valid_losses, valid_accuracies = train2(scbert)

In [None]:
def plot_graphs(train_losses, train_accuracies, valid_losses, valid_accuracies):
  plt.plot(train_losses)
  plt.plot(train_accuracies)
  plt.plot(valid_losses)
  plt.plot(valid_accuracies)
  plt.legend(['train_loss', 'train_accuracy', 'valid_loss', 'valid_accuracy'])
  plt.title('Panglao_human')
  plt.show()

In [None]:
plot_graphs(train2.train_losses, train2.train_accuracies, train2.valid_losses, train2.valid_accuracies)