# This is an upgraded version of scBERT

## We are planning to incorporate these things:

### Improvements to the Encoder block
1. Grouped Multi Query Attention
2. RMS Norm in place of LayerNorm for faster training
3. Flash attention 2.0
4. SwiGLU/SiLU in place of ReLU/GLU - done
5. Gene coexpression

### Improvements to increase parameter count while reducing computational cost
1. Sparse Mixture of Experts (as in Mixtral)

### Improvements to training strategy
1. Improved token embeddings
2. Improved masking

### For Faster training
1. Mixed precision training - done
2. Distributed Data Parallel Training - done
3. Faster Data Loading using MultDL
4. Adafactor - https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#optimizer-choice
5. Torch compile - https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#using-torchcompile
6. Data preloading - done

# Installing Necessary libraries

## Installing Flash Attention 2.0

In [None]:
# Install Flash Attention 2 (flash-attn) together with the ninja and packaging helpers.
!pip uninstall -y ninja && pip install ninja
!pip install packaging
!pip install flash-attn --no-build-isolation
# If the machine has less than 96 GB of RAM, limit parallel compile jobs instead:
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation

## HF Accelerate

In [None]:
# Hugging Face Accelerate: used below for fp16 mixed precision, gradient accumulation and multi-process launching.
!pip install accelerate

# Importing necessary libraries

In [None]:
import math
import random
from functools import reduce
from typing import List

import scipy.sparse
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

In [None]:
from accelerate import Accelerator
from accelerate import notebook_launcher
# fp16 mixed precision; gradients are accumulated over 60 micro-batches per
# optimizer step (applied via `accelerator.accumulate` in the training loop).
accelerator = Accelerator(
    mixed_precision='fp16',
    gradient_accumulation_steps=60,
)

# Model architecture

## RMS Norm

In [None]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Unlike LayerNorm there is no mean-centering and no bias: each vector is
    divided by its RMS (plus ``eps`` for stability) and then multiplied by a
    learnable per-feature gain initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain, one entry per feature of the last dimension.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the caller's dtype, then apply the gain.
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight

In [None]:
# Quick sanity check of RMSNorm. `dim` must equal x.shape[-1] (the normalised
# axis), and the input must be floating point: forward() casts its result back
# to the input dtype, so an integer tensor would truncate the normalised values.
x = torch.tensor([[10.0, 0.0, 3.0]])
rmsnorm = RMSNorm(dim=x.shape[-1])
print(x.shape)
rmsnorm(x)

## FeedForward Layer

In [None]:
class FeedForwardBlock(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` produces the SiLU-activated gate, ``w3`` the value it multiplies,
    and ``w2`` projects the ``hidden_dim`` result back to ``args_dim``.
    All three projections are bias-free.
    """

    def __init__(self, args_dim, hidden_dim):
        super().__init__()
        # Registration order (w1, w2, w3) is kept so initialisation is unchanged.
        self.w1 = nn.Linear(args_dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, args_dim, bias=False)  # output projection
        self.w3 = nn.Linear(args_dim, hidden_dim, bias=False)  # value projection

    def forward(self, x) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

In [None]:
# test_ff = FeedForwardBlock(2,1)
# test_ff(x)

## Sparse Mixture of Experts block

In [None]:
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k token routing.

    Each token is sent to the ``num_experts_per_tok`` experts with the highest
    gate logits. Their outputs are summed, weighted by a softmax taken over
    those k logits only.

    Args:
        experts: Expert modules; each must map a ``(tokens, dim)`` tensor to a
            tensor of the same shape.
        gate: Module mapping ``(tokens, dim)`` to ``(tokens, len(experts))``
            routing logits.
        num_experts_per_tok: Number of experts (k) applied to every token.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, inputs: torch.Tensor):
        """Route ``inputs`` of shape ``(..., dim)`` through the experts.

        Any number of leading dimensions is accepted, e.g. ``(tokens, dim)`` or
        ``(batch, seq, dim)``. They are flattened into one token axis for
        routing, and the output is reshaped back to ``inputs.shape``.
        """
        orig_shape = inputs.shape
        tokens = inputs.reshape(-1, orig_shape[-1])

        # One logit per expert for every token: (tokens, num_experts).
        gate_logits = self.gate(tokens)

        # Top-k experts per token; both results are (tokens, k).
        weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok)

        # Softmax AFTER top-k, over the k selected logits only. Each token's
        # mixing weights then sum to 1 whatever the total number of experts or k,
        # which keeps runs with different hyperparameters comparable.
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(inputs.dtype)

        results = torch.zeros_like(tokens)
        for expert_idx, expert in enumerate(self.experts):
            # token_idx: tokens routed to this expert; slot_idx: which of their k picks it was.
            token_idx, slot_idx = torch.where(selected_experts == expert_idx)
            # Experts are called even with zero routed tokens, so every expert stays
            # in the autograd graph (DDP without find_unused_parameters needs this).
            # top-k indices are distinct per token, so token_idx has no duplicates
            # and the indexed in-place add is safe.
            results[token_idx] += weights[token_idx, slot_idx, None] * expert(tokens[token_idx])
        return results.reshape(orig_shape)

## Attention Block

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The query/key/value projections act on each ``head_dim``-sized slice of
    the embedding and are shared across heads. ``fc_out`` mixes the
    concatenated head outputs back to ``embed_size``.

    Args:
        embed_size: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divided by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Project the concatenated heads (heads * head_dim == embed_size) back to embed_size.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, embed_size).
            mask: Optional tensor broadcastable to (batch, heads, seq_len, seq_len).
                Positions where it is False/0 cannot be attended to.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).
        """
        batch, seq_len, _ = x.shape
        # Split the embedding into heads: (batch, seq, heads, head_dim).
        x = x.reshape(batch, seq_len, self.heads, self.head_dim)

        # Project, then move heads before the sequence axis: (batch, heads, seq, head_dim).
        q = self.queries(x).transpose(1, 2)
        k = self.keys(x).transpose(1, 2)
        v = self.values(x).transpose(1, 2)

        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf so a fully
            # masked row yields a finite (uniform) distribution instead of NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)

        out = attn @ v  # (batch, heads, seq, head_dim)
        out = out.transpose(1, 2).reshape(batch, seq_len, self.embed_size)
        return self.fc_out(out)


## Encoder Block

In [None]:
class Encoder(nn.Module):
    """Pre-norm transformer encoder layer with a sparse-MoE feed-forward.

    ``h = x + Attention(RMSNorm(x))`` followed by ``out = h + MoE(RMSNorm(h))``.
    The MoE mixes ``num_experts`` SwiGLU experts, applying ``num_experts_per_tok``
    of them to each token.

    Args:
        embed_size: Model width.
        heads: Number of attention heads.
        hidden_dim: Hidden width of each expert; defaults to ``4 * embed_size``.
        num_experts: Total number of feed-forward experts.
        num_experts_per_tok: Number of experts applied to each token.
    """

    def __init__(self, embed_size, heads, hidden_dim=None, num_experts=8, num_experts_per_tok=2):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * embed_size
        self.attention = SelfAttention(embed_size, heads)
        self.attention_norm = RMSNorm(embed_size)
        self.ff_norm = RMSNorm(embed_size)
        self.feed_forward = MoeLayer(
            experts=[FeedForwardBlock(embed_size, hidden_dim) for _ in range(num_experts)],
            gate=nn.Linear(embed_size, num_experts, bias=False),
            num_experts_per_tok=num_experts_per_tok,
        )

    def forward(self, x, mask=None):
        """Apply the layer to ``x`` of shape (batch, seq_len, embed_size)."""
        normed = self.attention_norm(x)
        attn_out = self.attention(normed) if mask is None else self.attention(normed, mask)
        h = x + attn_out

        # Route tokens as a flat (batch * seq_len, embed_size) matrix, then restore the shape.
        normed = self.ff_norm(h)
        ff_out = self.feed_forward(normed.reshape(-1, normed.shape[-1])).reshape(normed.shape)
        # Residual around the feed-forward output (previously `h + r` dropped it).
        return h + ff_out



## SCBERT 2.0

In [None]:
class scBERT2(nn.Module):
    """scBERT 2.0 model: placeholder, the architecture is not implemented yet.

    NOTE(review): the training loop calls ``loss_fn(logits.transpose(1, 2), labels)``,
    so ``forward`` is expected to return per-token logits of shape
    (batch, seq_len, num_classes).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Fail loudly instead of the previous syntax error / implicit None output.
        raise NotImplementedError("scBERT2.forward is not implemented yet")

# Masking Strategy

In [None]:
# Random Bernoulli mask: True where an independent uniform draw falls below `prob`.
def prob_mask_like(t, prob):
    """Return a bool tensor shaped like ``t`` whose entries are True with probability ``prob``.

    One float32 U[0, 1) sample is drawn per element of ``t`` (on ``t``'s device)
    and compared against ``prob``.
    """
    uniform = torch.empty_like(t, dtype=torch.float).uniform_(0, 1)
    return uniform < prob

# Mark positions holding special/ignored tokens, i.e. positions that must not be masked.
def mask_with_tokens(t, token_ids):
    """Return a bool tensor shaped like ``t``, True wherever ``t`` equals any id in ``token_ids``.

    ``token_ids`` may be any iterable (the caller passes a set). An empty
    iterable gives an all-False mask. A plain loop is used because
    ``functools.reduce`` was called here but never imported.
    """
    hits = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hits |= t == token_id
    return hits

def get_mask_subset_with_prob(mask, prob):
    """Randomly select ``ceil(prob * n)`` of the ``n`` maskable positions in each row.

    Args:
        mask: Bool tensor (batch, seq_len); True marks positions that may be selected.
        prob: Fraction of each row's maskable positions to select.

    Returns:
        Bool tensor (batch, seq_len); True marks selected positions. It is always
        a subset of ``mask``.
    """
    batch, seq_len, device = *mask.shape, mask.device
    # Selection slots reserved per row: enough for a fully maskable row.
    max_masked = math.ceil(prob * seq_len)
    num_tokens = mask.sum(dim=-1, keepdim=True)  # maskable positions per row
    # Slot j of a row is surplus when j >= ceil(num_tokens * prob) for that row.
    # Built directly on `device` (previously via a CPU torch.cat hack plus a copy).
    slot = torch.arange(seq_len, device=device).expand(batch, seq_len)
    mask_excess = (slot >= (num_tokens * prob).ceil())[:, :max_masked]
    # Random scores; non-maskable positions get -1e9 so top-k picks maskable ones first.
    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    # Shift indices by one so column 0 of new_mask can absorb the surplus slots.
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    # Drop the dump column; True marks a selected position.
    return new_mask[:, 1:].bool()

def data_mask(data,
    mask_prob = MASK_PROB,
    replace_prob = REPLACE_PROB,
    num_tokens = None,
    random_token_prob = RANDOM_TOKEN_PROB,
    mask_token_id = MASK_TOKEN_ID,
    pad_token_id = PAD_TOKEN_ID,
    mask_ignore_token_ids = MASK_IGNORE_TOKEN_IDS
):
    """Build masked-language-model inputs and labels from a batch of token ids.

    Steps, in this order (which also fixes the order in which random numbers are drawn):
      1. Positions whose token is in ``mask_ignore_token_ids`` or equals
         ``pad_token_id`` are never selected.
      2. ``get_mask_subset_with_prob`` selects ``ceil(mask_prob * n)`` of the
         remaining ``n`` positions in each row.
      3. If ``random_token_prob > 0``, every position is independently
         overwritten with a random id in ``[0, num_tokens)`` with probability
         ``random_token_prob``, unless the drawn id is itself an ignored/pad id.
      4. Each selected position is replaced by ``mask_token_id`` with
         probability ``replace_prob``; otherwise it keeps its current value.

    Args:
        data: Integer token tensor. It must be 2-D (batch, seq_len), because
            ``get_mask_subset_with_prob`` unpacks exactly two dimensions.
        mask_prob: Fraction of maskable positions per row to select.
        replace_prob: Probability that a selected position becomes ``mask_token_id``.
        num_tokens: Vocabulary size; required when ``random_token_prob > 0``.
        random_token_prob: Probability of random-token replacement (step 3).
        mask_token_id: Id written into selected positions.
        pad_token_id: Id never selected, and the fill value for unselected labels.
        mask_ignore_token_ids: Additional ids that are never selected.

    Returns:
        ``(masked_input, labels)``. ``masked_input`` is the corrupted copy of
        ``data``. ``labels`` holds the original token at selected positions and
        ``pad_token_id`` everywhere else.

    NOTE(review): step 3 is not restricted to the selected positions, nor does it
    exclude pad/ignored positions, unlike BERT's 80/10/10 scheme. Confirm this is
    intended.
    NOTE(review): the default arguments are evaluated when this ``def`` executes,
    so MASK_PROB, REPLACE_PROB, RANDOM_TOKEN_PROB, MASK_TOKEN_ID, PAD_TOKEN_ID
    and MASK_IGNORE_TOKEN_IDS must already be defined. None of them is defined
    in an earlier cell shown here, and reassigning PAD_TOKEN_ID later does not
    change the default captured here.
    """
    mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
    # Never mask [pad] tokens or any other excluded tokens (e.g. [cls], [sep]),
    # and never use these special tokens as random replacement tokens.
    no_mask = mask_with_tokens(data, mask_ignore_token_ids)   # True = ignored token, can never be selected
    mask = get_mask_subset_with_prob(~no_mask, mask_prob)      # True = position selected for prediction
    # get mask indices
    ## mask_indices = torch.nonzero(mask, as_tuple=True)   # get the index of mask(nonzero value of mask matrix)
    # Selected positions get the mask token with probability `replace_prob` (otherwise they keep their value).
    masked_input = data.clone().detach()
    # Optional random-token replacement (step 3 in the docstring).
    if random_token_prob > 0:
        assert num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement'
        random_token_prob = prob_mask_like(data, random_token_prob)       # rebinds the float parameter to a bool tensor: positions to randomise
        random_tokens = torch.randint(0, num_tokens, data.shape, device=data.device)     # one random candidate id per position
        random_no_mask = mask_with_tokens(random_tokens, mask_ignore_token_ids)        # candidates that are special/ignored ids
        random_token_prob &= ~random_no_mask        # never write a special/ignored id
        random_indices = torch.nonzero(random_token_prob, as_tuple=True)        # indices of positions to randomise
        masked_input[random_indices] = random_tokens[random_indices]        # apply the random replacement
    # [mask] input
    replace_prob = prob_mask_like(data, replace_prob)     # rebinds the float parameter to a bool tensor: positions eligible for [mask]
    masked_input = masked_input.masked_fill(mask * replace_prob, mask_token_id)        # selected AND eligible -> mask_token_id
    # Labels: keep the original token only at selected positions; everything else becomes pad_token_id.
    labels = data.masked_fill(~mask, pad_token_id)        # target ids for the MLM loss
    return masked_input, labels

# Embeddings

## Token Embedding

## "Positional Embedding"

# Dataset and DataLoader

## Low RAM

In [None]:

# Path to the panglao_human .h5ad file, relative to the notebook's working directory.
data_path = "../data/panglao_human.h5ad"

In [None]:
class SCDataset(Dataset):
    """Memory-light dataset that reads cells from a backed (on-disk) ``.h5ad`` file.

    ``__getitem__`` ignores ``index`` and returns a uniformly random row drawn
    from ``indices``. ``__len__`` is the number of those indices, so one
    DataLoader pass draws ``len(indices)`` random cells from this split.

    Args:
        file_path: Path of the ``.h5ad`` file, opened read-only in backed mode.
        indices: Row indices of ``X`` that this split may sample from.
    """

    def __init__(self, file_path, indices):
        self.file_path = file_path
        # Open the file that was passed in (previously the global `data_path` was used).
        self.data = sc.read_h5ad(file_path, backed='r')
        self.length = self.data.X.shape[0]  # rows in the whole file, across all splits
        self.indices = indices
        self.indices_len = len(self.indices)

    def __getitem__(self, index):
        rand_start = random.randint(0, self.indices_len - 1)
        data = self.data.X[self.indices[rand_start]]
        # Sparse rows (any scipy sparse type) are densified to 1-D; assumes a (1, n_genes) row — TODO confirm.
        if scipy.sparse.issparse(data):
            data = data.toarray().squeeze(0)

        # Clip values into the top bin (CLASS - 2), convert to int64 tokens and append a trailing 0.
        data[data > (CLASS - 2)] = CLASS - 2
        data = torch.from_numpy(data).long()
        # Returned on CPU: the training loop / accelerator.prepare move batches to the device,
        # and CPU tensors keep DataLoader workers and pin_memory usable.
        data = torch.cat((data, torch.tensor([0])))
        return data

    def __len__(self):
        # Size of this split, not of the whole file (previously every split reported the full file length).
        return self.indices_len

In [None]:
# Contiguous train/validation split over the row indices of the .h5ad file.
total_samples = 1357593  # total number of rows; replace with your dataset's actual length
train_ratio = 0.95

num_train_samples = int(total_samples * train_ratio)
num_valid_samples = total_samples - num_train_samples

# The first `num_train_samples` rows train; the remaining rows validate.
train_indices = [*range(num_train_samples)]
valid_indices = [*range(num_train_samples, total_samples)]

print("Training indices:", len(train_indices))
print("Validation indices:", len(valid_indices))

In [None]:
# Both splits read the same backed file and differ only in which rows they sample.
train_dataset, val_dataset = (SCDataset(data_path, split) for split in (train_indices, valid_indices))

# SCDataset.__getitem__ picks its own random row, so `shuffle` does not change which cells are drawn.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

## High RAM

In [None]:
class SCDataset(Dataset):
    """In-memory dataset over a (cells x genes) sparse expression matrix.

    ``__getitem__`` ignores ``index`` and returns a uniformly random row of
    ``data``, clipped to the top bin and with a trailing 0 appended.

    Args:
        data: Row-indexable sparse matrix whose rows support ``.toarray()``.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __getitem__(self, index):
        rand_start = random.randint(0, self.data.shape[0] - 1)
        full_seq = self.data[rand_start].toarray()[0]
        full_seq[full_seq > (CLASS - 2)] = CLASS - 2
        full_seq = torch.from_numpy(full_seq).long()
        # Stay on CPU: the DataLoader here uses worker processes and pin_memory, which
        # fail on CUDA tensors (CUDA cannot be initialised in forked workers, and only
        # CPU tensors can be pinned). The training loop moves batches to `device`.
        full_seq = torch.cat((full_seq, torch.tensor([0])))
        return full_seq

    def __len__(self):
        return self.data.shape[0]

# Load the AnnData object fully into memory and keep only its expression matrix.
data = sc.read_h5ad('/content/drive/MyDrive/scFasterBERT/data/panglao_human.h5ad').X

In [None]:
# Hold out 5% of the cells for validation; SEED makes the split reproducible.
data_train, data_val = train_test_split(data, random_state=SEED, test_size=0.05)

train_dataset, val_dataset = SCDataset(data_train), SCDataset(data_val)

In [None]:
# 4 worker processes with pinned memory. This requires SCDataset to return CPU
# tensors: CUDA cannot be initialised in forked workers, and only CPU tensors can be pinned.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=True, num_workers=4)

# Model Initialization and Training

In [None]:
# Build the model, then move its parameters/buffers to `device`.
model = scBERT2()
model = model.to(device)

In [None]:
# Optimise the model built above (this previously referenced an undefined `student_model`).
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# The training loop calls `scheduler.step()` every iteration, so a scheduler must
# exist. This keeps the learning rate constant.
# TODO: replace with a warmup/cosine schedule if one is wanted.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_ID, reduction='mean').to(device)
softmax = nn.Softmax(dim=-1)
# NOTE(review): not used below. The Accelerator (mixed_precision='fp16') performs loss scaling itself.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Hand the loaders, model and optimizer to Accelerate; it returns wrapped versions in the same order.
train_loader, val_loader, model, optimizer = accelerator.prepare(
    train_loader,
    val_loader,
    model,
    optimizer,
)

In [None]:
EPOCHS = 1
VALIDATE_EVERY = 2  # NOTE(review): with EPOCHS = 1, validation never runs (1 % 2 != 0)
# NOTE(review): unused. Gradient accumulation is configured on the Accelerator
# (gradient_accumulation_steps) and applied through `accelerator.accumulate`.
GRADIENT_ACCUMULATION = 4
MAX_GRAD_NORM = 100  # same threshold as the previous int(1e2)
# PAD_TOKEN_ID is deliberately NOT reassigned here. `loss_fn` (ignore_index) and
# `data_mask` (default pad_token_id) captured its value when they were created,
# so redefining it now could make the accuracy metric disagree with the labels.

for i in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    num_batches = 0
    for data in tqdm(train_loader):
        with accelerator.accumulate(model):
            data = data.to(device)
            data, labels = data_mask(data)
            logits = model(data)
            loss = loss_fn(logits.transpose(1, 2), labels)
            accelerator.backward(loss)
            # Clip through Accelerate so fp16 gradients are unscaled before the norm
            # is measured, and only on the micro-batch that actually steps.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        num_batches += 1
        running_loss += loss.item()
        # Predict among classes 1..C-2 (first and last logit excluded, as before).
        # argmax of the logits equals argmax of their softmax, so softmax is skipped.
        final = logits[..., 1:-1].argmax(dim=-1) + 1
        is_target = labels != PAD_TOKEN_ID
        pred_num = is_target.sum(dim=-1)
        correct_num = (is_target & (final == labels)).sum(dim=-1)
        # clamp avoids 0/0 = NaN for a sequence with no target positions (it counts as 0).
        cum_acc += (correct_num / pred_num.clamp(min=1)).mean().item()

    epoch_loss = running_loss / max(num_batches, 1)
    epoch_acc = 100 * cum_acc / max(num_batches, 1)
    accelerator.print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')

    if i % VALIDATE_EVERY == 0:
        model.eval()
        running_loss = 0.0
        num_batches = 0
        correct_total = 0
        target_total = 0
        with torch.no_grad():
            for data in tqdm(val_loader):
                num_batches += 1
                data = data.to(device)
                data, labels = data_mask(data)
                logits = model(data)
                loss = loss_fn(logits.transpose(1, 2), labels)
                running_loss += loss.item()
                final = logits[..., 1:-1].argmax(dim=-1) + 1
                is_target = labels != PAD_TOKEN_ID
                # Accumulate counts per batch instead of keeping every prediction
                # and label tensor around until the end (same totals, far less memory).
                correct_total += (is_target & (final == labels)).sum().item()
                target_total += is_target.sum().item()
        val_loss = running_loss / max(num_batches, 1)
        val_acc = 100 * correct_total / max(target_total, 1)
        accelerator.print(f'    ==  Epoch: {i} | Validation Loss: {val_loss:.6f} | Accuracy: {val_acc:6.4f}%  ==')

    # save_ckpt(i, model, optimizer, epoch_loss, model_name, ckpt_dir)

In [None]:
# NOTE(review): `train_model` is not defined in any cell shown here; the training
# loop above runs at top level. To launch it this way, wrap that loop in
# `def train_model():`. Per Accelerate's notebook_launcher docs, the Accelerator
# should also be created inside that function, not in an earlier cell. Verify before use.
notebook_launcher(train_model, args=(), num_processes=2)