In [None]:
from psutil import virtual_memory

# Total memory reported by psutil, converted from bytes to decimal gigabytes.
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
print('You are using a high-RAM runtime!' if ram_gb >= 20 else 'Not using a high-RAM runtime')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive/; later cells read the project code,
# the dataset and the checkpoint from paths under /content/drive/MyDrive/.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# List the current working directory (IPython automagic for the shell `ls`).
ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [None]:
import os

# Directory holding the project's performer_pytorch package on the mounted Drive.
_performer_dir = "/content/drive/MyDrive/scFasterBERT/performer_pytorch"

# PYTHONPATH is an os.pathsep-separated list. Split it, add the directory once
# and join it back. A plain `+=` would glue the new path onto the last existing
# entry and would raise KeyError when the variable is not set at all.
_pythonpath_entries = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
if _performer_dir not in _pythonpath_entries:
    # Guarded so that re-running the cell does not add the path again.
    _pythonpath_entries.append(_performer_dir)
os.environ['PYTHONPATH'] = os.pathsep.join(_pythonpath_entries)

In [None]:
!pip install einops
!pip install local_attention
!pip install scanpy

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
Collecting local_attention
  Downloading local_attention-1.9.1-py3-none-any.whl (8.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->local_attention)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->local_attention)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->local_attention)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26

In [None]:
import sys
# Put the Drive copies of the project ahead of most of sys.path so the imports
# below (`performer_pytorch`, and presumably `utils` - TODO confirm where it
# lives) resolve to the project's own modules.
sys.path.insert(1, '/content/drive/MyDrive/scFasterBERT/performer_pytorch')
sys.path.insert(2, '/content/drive/MyDrive/scFasterBERT/')
import os
import gc
import argparse
import json
import random
import math
import random
from functools import reduce
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *
import scipy.sparse
import h5py
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [None]:
# TensorBoard writer; the training loop logs per-epoch loss and accuracy
# scalars into this run directory.
Writer = SummaryWriter('./runs/scBERT_ours_pretrained')

In [None]:
# Relative location of the pre-training data (.h5ad).
# NOTE(review): in the visible cells this is only referenced by the
# commented-out SCDataset; the active loading cell reads a hard-coded Drive
# path instead - confirm which one is intended.
data_path = '../data/panglao_human.h5ad'

In [None]:
# --- Run configuration -------------------------------------------------------
SEED = 2021                     # passed to torch.manual_seed and train_test_split
EPOCHS = 100                    # number of passes of the training loop
BATCH_SIZE = 3                  # DataLoader batch size (train and validation)
GRADIENT_ACCUMULATION = 60      # optimizer steps once every this many batches
LEARNING_RATE = 1e-4            # Adam learning rate
SEQ_LEN = 16907                 # max_seq_len given to PerformerLM; presumably number of genes + the 1 token SCDataset appends - TODO confirm
VALIDATE_EVERY = 1              # run validation every N epochs
CLASS = 7                       # num_tokens of the model; SCDataset clips values to CLASS - 2
MASK_PROB = 0.15                # default fraction of maskable tokens selected by data_mask
REPLACE_PROB = 0.9              # probability that a selected token is replaced by MASK_TOKEN_ID
RANDOM_TOKEN_PROB = 0.          # probability of random-token replacement (0 disables that branch)
MASK_TOKEN_ID = CLASS - 1       # token id written into masked positions
PAD_TOKEN_ID = CLASS - 1        # same id as the mask token; also the loss ignore_index and the label of unmasked positions
MASK_IGNORE_TOKEN_IDS = [0]     # token ids that are never selected for masking
POS_EMBED_USING = True          # forwarded to PerformerLM as g2v_position_emb

# NOTE(review): in the visible cells these two are only referenced by the
# commented-out save_ckpt call at the end of the training loop.
model_name = 'panglao_pretrain_ours_1'
ckpt_dir = './checkpoints/'

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook draws from, not just torch: SCDataset picks rows
# with Python's `random`, so seeding torch alone leaves the sampled cells
# different on every run. NumPy is seeded as well for any library code that
# relies on its global generator.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x79787910c810>

# Masking

In [None]:
def prob_mask_like(t, prob):
    """Return a boolean tensor shaped like `t`, True where a uniform draw is below `prob`.

    Each element is decided independently by sampling from uniform(0, 1) and
    comparing the sample against `prob`.
    """
    noise = torch.zeros_like(t).float()
    noise.uniform_(0, 1)
    return noise < prob

def mask_with_tokens(t, token_ids):
    """Return a boolean tensor shaped like `t`, True where `t` equals any id in `token_ids`.

    The caller uses the result to mark positions that must not be masked.
    With an empty `token_ids` the result is all False.
    """
    hits = torch.full_like(t, False, dtype=torch.bool)
    for token_id in token_ids:
        hits = hits | (t == token_id)
    return hits

def get_mask_subset_with_prob(mask, prob):
    """Randomly pick a `prob` fraction of the True positions in each row of `mask`.

    Args:
        mask: 2-D boolean tensor (rows, positions); True marks positions that
            are eligible for selection.
        prob: fraction of each row's eligible positions to select.

    Returns:
        Boolean tensor with the same shape as `mask`. Row i has
        ceil(prob * number_of_True_in_row_i) True entries, all located at
        positions where `mask` is True.
    """
    n_rows, n_cols = mask.shape
    dev = mask.device

    # Upper bound on picks per row (as if every position were eligible).
    budget = math.ceil(prob * n_cols)
    # Actual number of picks allowed per row, based on its eligible positions.
    quota = (mask.sum(dim=-1, keepdim=True) * prob).ceil()

    # Pick slot j of a row is surplus when j >= that row's quota.
    slot_ids = torch.arange(n_cols, device=dev).float().expand(n_rows, n_cols)
    surplus = (slot_ids >= quota)[:, :budget]

    # Random score per position; ineligible positions get -1e9 so they rank last.
    scores = torch.rand((n_rows, n_cols), device=dev).masked_fill(~mask, -1e9)
    picked = scores.topk(budget, dim=-1).indices

    # Shift indices by one so that index 0 can serve as a discard bin for surplus picks.
    picked = (picked + 1).masked_fill_(surplus, 0)
    canvas = torch.zeros((n_rows, n_cols + 1), device=dev)
    canvas.scatter_(-1, picked, 1)

    # Drop the discard column; True marks a selected position.
    return canvas[:, 1:].bool()

def data_mask(data,
    mask_prob = MASK_PROB,
    replace_prob = REPLACE_PROB,
    num_tokens = None,
    random_token_prob = RANDOM_TOKEN_PROB,
    mask_token_id = MASK_TOKEN_ID,
    pad_token_id = PAD_TOKEN_ID,
    mask_ignore_token_ids = MASK_IGNORE_TOKEN_IDS
):
    """Build a masked-language-modelling (input, labels) pair from a token batch.

    Args:
        data: integer token tensor; it is cloned, never modified.
        mask_prob: fraction of maskable tokens per row selected as prediction targets.
        replace_prob: probability that a selected token is overwritten with `mask_token_id`
            (otherwise it is left unchanged in the input).
        num_tokens: vocabulary size; required only when `random_token_prob > 0`.
        random_token_prob: probability of replacing a position with a random token.
        mask_token_id: id written into masked input positions.
        pad_token_id: id excluded from masking and used as the label of non-target positions.
        mask_ignore_token_ids: further ids that are never selected for masking.

    Returns:
        (masked_input, labels): `masked_input` is `data` with replacements applied;
        `labels` keeps the original token at selected positions and holds
        `pad_token_id` everywhere else.
    """
    # Ids that can never be chosen as targets: the ignore list plus the pad id.
    protected_ids = {*mask_ignore_token_ids, pad_token_id}
    maskable = ~mask_with_tokens(data, protected_ids)

    # True marks the positions the model will be asked to predict.
    mask = get_mask_subset_with_prob(maskable, mask_prob)

    masked_input = data.clone().detach()

    # Optional random-token replacement.
    if random_token_prob > 0:
        assert num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement'
        swap_here = prob_mask_like(data, random_token_prob)
        random_tokens = torch.randint(0, num_tokens, data.shape, device=data.device)
        # Never write a protected id as the random replacement.
        swap_here &= ~mask_with_tokens(random_tokens, protected_ids)
        swap_idx = torch.nonzero(swap_here, as_tuple=True)
        masked_input[swap_idx] = random_tokens[swap_idx]

    # Overwrite a `replace_prob` share of the selected positions with the mask token.
    use_mask_token = prob_mask_like(data, replace_prob)
    masked_input = masked_input.masked_fill(mask & use_mask_token, mask_token_id)

    # Non-target positions carry the pad id as their label.
    labels = data.masked_fill(~mask, pad_token_id)
    return masked_input, labels

# Dataset and Dataloader

In [None]:
# total_samples = 1357593  # Replace with the actual total length of your dataset
# train_ratio = 0.95

# # Calculate the number of samples in each set
# num_train_samples = int(total_samples * train_ratio)
# num_valid_samples = total_samples - num_train_samples

# # Generate indices for training and validation sets
# train_indices = list(range(0, num_train_samples))
# valid_indices = list(range(num_train_samples, total_samples))

# print("Training indices:", len(train_indices))
# print("Validation indices:", len(valid_indices))

In [None]:
# class SCDataset(Dataset):
#     def __init__(self, file_path, indices):
#         self.file_path = file_path
#         self.data = sc.read_h5ad(data_path, backed='r')
#         self.length = self.data.X.shape[0]
#         self.indices = indices
#         self.indices_len = len(self.indices)

#     def __getitem__(self, index):
#         rand_start = random.randint(0, self.indices_len-1)
#         data = self.data.X[self.indices[rand_start]]
#         # Convert sparse matrix row to dense if necessary
#         if isinstance(data, scipy.sparse.csr_matrix):
#             data = data.toarray().squeeze(0)
#             # print(data)

#         # Apply the same preprocessing as before
#         data[data > (CLASS - 2)] = CLASS - 2
#         data = torch.from_numpy(data).long()
#         data = torch.cat((data, torch.tensor([0]))).to(device)
#         return data

#     def __len__(self):
#         return self.length

In [None]:

# train_dataset = SCDataset(data_path, train_indices)
# val_dataset = SCDataset(data_path, valid_indices)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class SCDataset(Dataset):
    """Dataset over the rows of a sparse matrix, yielding token sequences.

    Note that `__getitem__` ignores the requested index and returns a randomly
    chosen row (drawn with Python's `random`), so rows are sampled with
    replacement regardless of how the DataLoader iterates.
    """

    def __init__(self, data):
        super().__init__()
        # Matrix whose rows support `.toarray()`; one row per sample.
        self.data = data

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        # `index` is intentionally unused: a random row is drawn instead.
        row_id = random.randint(0, len(self) - 1)
        counts = self.data[row_id].toarray()[0]

        # Clip values so the largest token id is CLASS - 2.
        ceiling = CLASS - 2
        counts[counts > ceiling] = ceiling

        tokens = torch.from_numpy(counts).long()
        # Append one extra 0 token to the end of the sequence, then move it to `device`.
        return torch.cat((tokens, torch.tensor([0]))).to(device)

# Load the AnnData file from Drive and keep only its expression matrix (.X).
# NOTE(review): the path is hard-coded here; `data_path` defined earlier is not used.
data = sc.read_h5ad('/content/drive/MyDrive/scFasterBERT/data/panglao_human.h5ad')
data = data.X
# Hold out 5% of the rows for validation; the split is seeded with SEED.
data_train, data_val = train_test_split(data, test_size=0.05,random_state=SEED)

# Wrap both splits so each item is a clipped, tokenised row (see SCDataset).
train_dataset = SCDataset(data_train)
val_dataset = SCDataset(data_val)



  utils.warn_names_duplicates("obs")


In [None]:
# Batch both splits. No shuffling is requested here; SCDataset already returns
# a randomly drawn row for every item it is asked for.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model

In [None]:
# Performer language model over CLASS token ids and sequences of up to SEQ_LEN
# positions, moved to the selected device.
model = PerformerLM(
    num_tokens = CLASS,
    dim = 200,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 10,
    local_attn_heads = 0,
    g2v_position_emb = POS_EMBED_USING
    ).to(device)



# Adam over all model parameters with the configured learning rate.
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Load the saved checkpoint onto the CPU first (map_location), then copy its
# 'model_state_dict' weights into the model created above.
ckpt = torch.load('/content/drive/MyDrive/scFasterBERT/panglao_pretrain.pth',map_location=torch.device('cpu'))
model.load_state_dict(ckpt['model_state_dict'])

<All keys matched successfully>

In [None]:
# Wrap the model between a QuantStub and a DeQuantStub.
# NOTE(review): `quantized_model` is not referenced by any later visible cell;
# the training loop uses `model` directly - confirm whether this is still needed.
quantized_model = nn.Sequential(
    torch.quantization.QuantStub(),
    model,
    torch.quantization.DeQuantStub(),
)

In [None]:
# Mean cross-entropy over target positions only: labels equal to PAD_TOKEN_ID
# (every position data_mask did not select) are ignored.
loss_fn = nn.CrossEntropyLoss(ignore_index = PAD_TOKEN_ID, reduction='mean').to(device)
# Softmax over the class dimension, used when turning logits into predictions.
softmax = nn.Softmax(dim=-1)

In [None]:
def _predict_classes(logits):
    """Return the most probable class per position, restricted to classes 1 .. C-2.

    The first and last class columns are dropped before the argmax, and 1 is
    added back so the result is expressed in original class ids.
    """
    return softmax(logits)[..., 1:-1].argmax(dim=-1) + 1


for epoch in range(1, EPOCHS + 1):
    # ------------------------------ training ------------------------------
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    num_batches = 0
    # Wrap the loader (not the enumerate object) so tqdm can show the total.
    for index, batch in enumerate(tqdm(train_loader), start=1):
        num_batches = index
        batch = batch.to(device)
        masked_input, labels = data_mask(batch)

        logits = model(masked_input)
        # CrossEntropyLoss expects (batch, classes, positions).
        loss = loss_fn(logits.transpose(1, 2), labels)
        # Scale only the gradient contribution, so the accumulated gradient is
        # the mean over GRADIENT_ACCUMULATION batches.
        (loss / GRADIENT_ACCUMULATION).backward()

        # Step the optimizer once every GRADIENT_ACCUMULATION batches.
        if index % GRADIENT_ACCUMULATION == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e2))
            optimizer.step()
            optimizer.zero_grad()

        # Track the unscaled loss so the reported training loss is on the same
        # scale as the validation loss (previously it was divided by
        # GRADIENT_ACCUMULATION before being averaged).
        running_loss += loss.item()

        with torch.no_grad():
            final = _predict_classes(logits)
            is_target = labels != PAD_TOKEN_ID
            pred_num = is_target.sum(dim=-1)
            correct_num = (is_target & (final == labels)).sum(dim=-1)
            # clamp avoids 0/0 = NaN for a sample with no masked positions;
            # such a sample contributes an accuracy of 0.
            cum_acc += torch.true_divide(correct_num, pred_num.clamp(min=1)).mean().item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(num_batches, 1)
    epoch_acc = 100 * cum_acc / max(num_batches, 1)
    print(f'    ==  Epoch: {epoch} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')
    Writer.add_scalar('Training loss', epoch_loss, epoch)
    Writer.add_scalar('Training accuracy',epoch_acc, epoch)

    # ----------------------------- validation -----------------------------
    if epoch % VALIDATE_EVERY == 0:
        model.eval()
        running_loss = 0.0
        # Running counts instead of keeping every prediction/label tensor: the
        # accuracy is identical, but memory no longer grows with the size of
        # the validation split.
        correct_num = 0
        val_num = 0
        num_val_batches = 0
        with torch.no_grad():
            for index, batch in enumerate(tqdm(val_loader), start=1):
                num_val_batches = index
                batch = batch.to(device)
                masked_input, labels = data_mask(batch)
                logits = model(masked_input)
                loss = loss_fn(logits.transpose(1, 2), labels)
                running_loss += loss.item()
                final = _predict_classes(logits)
                is_target = labels != PAD_TOKEN_ID
                correct_num += (is_target & (final == labels)).sum().item()
                val_num += is_target.sum().item()
        val_loss = running_loss / max(num_val_batches, 1)
        val_acc = 100 * correct_num / max(val_num, 1)
        print(f'    ==  Epoch: {epoch} | Validation Loss: {val_loss:.6f} | Accuracy: {val_acc:6.4f}%  ==')
        Writer.add_scalar('Valid loss', val_loss, epoch)
        Writer.add_scalar('Valid accuracy',val_acc, epoch)

    # TODO(review): no checkpoint is ever written during training. The original
    # (disabled, and apparently mistyped) call is kept below; check save_ckpt's
    # signature in utils before re-enabling it.
    # save_ckpt(i, model, optimizerepoch_loss, model_name, ckpt_dir)