# Fewshot Tutorial

This tutorial demonstrates how to finetune the scLinguist model on a fewshot dataset and then use it to predict protein expression from RNA data.

Import necessary packages and define paths for checkpoints and save directory.

In [1]:
import torch
from pathlib import Path
from torch.utils.data import DataLoader
import sys
sys.path.append('../../')
from scLinguist.data_loaders.data_loader import scMultiDataset
from scLinguist.model.configuration_hyena import HyenaConfig
from scLinguist.model.model import scTrans
import importlib, sys
sys.modules['model'] = importlib.import_module('scLinguist.model')

ENCODER_CKPT = Path("../../pretrained_model/encoder.ckpt")
DECODER_CKPT = Path("../../pretrained_model/decoder.ckpt")
FINETUNE_CKPT = Path("../../pretrained_model/finetune.ckpt")
SAVE_DIR = Path("../../docs/tutorials/fewshot_output")
SAVE_DIR.mkdir(exist_ok=True)

Configure dataloaders for fewshot and test datasets.

First, inspect the data structure of the datasets to ensure they are compatible with the model.

In [25]:
import scanpy as sc
rna_train = sc.read_h5ad('../../data/fewshot_sample_rna.h5ad')
rna_train

AnnData object with n_obs × n_vars = 20 × 19202
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'

In [26]:
rna_train.X.todense()

matrix([[ 0.,  0.,  0., ..., 15.,  1., 27.],
        [ 0.,  0.,  0., ...,  0.,  0.,  2.],
        [ 0.,  0.,  0., ...,  2.,  0.,  7.],
        ...,
        [ 0.,  0.,  0., ...,  2.,  0.,  1.],
        [ 0.,  0.,  0., ..., 31.,  0., 69.],
        [ 0.,  0.,  0., ...,  2.,  0.,  8.]], dtype=float32)

In [27]:
rna_train.var_names

Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',
       'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',
       'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',
       'ENSG00000187608',
       ...
       'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',
       'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',
       'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',
       'ENSG00000198727'],
      dtype='object', length=19202)

In [28]:
import scanpy as sc
rna_test = sc.read_h5ad('../../data/test_sample_rna.h5ad')
rna_test

AnnData object with n_obs × n_vars = 10546 × 19202
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'

In [29]:
rna_test.X.todense()

matrix([[ 0.,  0.,  0., ...,  5.,  0.,  3.],
        [ 0.,  0.,  0., ...,  0.,  0.,  7.],
        [ 0.,  0.,  0., ...,  1.,  0.,  6.],
        ...,
        [ 0.,  0.,  0., ...,  1.,  0.,  2.],
        [ 0.,  0.,  0., ...,  2.,  0.,  2.],
        [ 0.,  0.,  0., ...,  4.,  0., 12.]], dtype=float32)

In [30]:
rna_test.var_names

Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',
       'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',
       'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',
       'ENSG00000187608',
       ...
       'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',
       'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',
       'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',
       'ENSG00000198727'],
      dtype='object', length=19202)

In [31]:
import scanpy as sc
adt_train = sc.read_h5ad('../../data/train_sample_adt.h5ad')
adt_train

AnnData object with n_obs × n_vars = 16994 × 6427
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'

In [32]:
import numpy as np
mask = ~np.isnan(adt_train.X[0].toarray())
adt_train[:, mask].X.todense()

matrix([[1.100e+02, 1.400e+01, 4.900e+01, ..., 1.100e+01, 2.120e+02,
         2.800e+01],
        [1.200e+02, 1.000e+00, 5.890e+02, ..., 2.000e+00, 3.600e+01,
         2.400e+01],
        [6.450e+02, 5.000e+00, 1.256e+03, ..., 1.000e+00, 7.200e+01,
         1.320e+02],
        ...,
        [2.330e+02, 2.700e+01, 8.420e+02, ..., 3.000e+00, 7.700e+01,
         4.600e+01],
        [3.120e+02, 1.500e+01, 1.079e+03, ..., 4.000e+00, 4.800e+01,
         8.000e+01],
        [1.960e+02, 2.000e+00, 1.910e+02, ..., 0.000e+00, 3.400e+01,
         1.900e+01]])

In [38]:
adt_train.var_names

Index(['SP110', 'GTPBA', 'SNX2', 'FRG1', 'TT21A', 'RHG18', 'AR', 'DOCK1',
       'RAB1A', 'MUC1.HMFG2',
       ...
       'CYTSA', 'LFNG', 'PFKFB4', 'LIPB1', 'ZN225', 'TRI69', 'CCL14', 'ZN541',
       'TAP1', 'SCG3'],
      dtype='object', length=6427)

In [33]:
import scanpy as sc
adt_test = sc.read_h5ad('../../data/test_sample_adt.h5ad')
adt_test

AnnData object with n_obs × n_vars = 10546 × 6427
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'

Then, create the dataloaders for both fewshot and test datasets. The `scMultiDataset` class is used to load the RNA and protein data from the specified paths.

In [34]:
BATCH_SIZE = 4
fewshot_data = scMultiDataset(
    data_dir_1="../../data/fewshot_sample_rna.h5ad",
    data_dir_2="../../data/fewshot_sample_adt.h5ad",
)
test_data = scMultiDataset(
    data_dir_1="../../data/test_sample_rna.h5ad",
    data_dir_2="../../data/test_sample_adt.h5ad",
)
fewshot_dataloader = DataLoader(
    fewshot_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
)

Last, configure the model with the appropriate encoder and decoder checkpoints, and set the mode to "RNA-protein". The `HyenaConfig` class is used to define the model configuration parameters such as `d_model`, `emb_dim`, `max_seq_len`, `vocab_len`, and `n_layer`.

In [35]:
enc_cfg = HyenaConfig(
    d_model        = 128,
    emb_dim        = 5,
    max_seq_len    = 19202,
    vocab_len      = 19202,
    n_layer        = 1,
    output_hidden_states=False,
)
dec_cfg = HyenaConfig(
    d_model        = 128,
    emb_dim        = 5,
    max_seq_len    = 6427,
    vocab_len      = 6427,
    n_layer        = 1,
    output_hidden_states=False,
)
model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"

Start training the model using PyTorch Lightning. The `ModelCheckpoint` callback is used to save the best model based on validation loss, and the `EarlyStopping` callback is used to stop training if the validation loss does not improve for a specified number of epochs.

In [37]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

ckpt_cb = ModelCheckpoint(
    dirpath      = SAVE_DIR/"ckpt",
    monitor      = "valid_loss",
    mode         = "min",
    save_top_k   = 1,
    filename     = "best-{epoch}-{valid_loss:.4f}",
)
early_cb = EarlyStopping(monitor="valid_loss", mode="min", patience=3)

trainer = pl.Trainer(
    accelerator       = "gpu",
    devices           = [1],
    max_epochs        = 6,
    log_every_n_steps = 50,
    callbacks         = [ckpt_cb, early_cb],
)

trainer.fit(model, fewshot_dataloader, fewshot_dataloader)
best_ckpt = ckpt_cb.best_model_path

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name       | Type             | Params
------------------------------------------------
0 | encoder    | scHeyna_enc      | 313 K 
1 | decoder    | scHeyna_dec      | 249 K 
2 | translator | MLPTranslator    | 284 M 
3 | cos_gene   | CosineSimilarity | 0     
4 | cos_cell   | CosineSimilarity | 0     
------------------------------------------------
285 M     Trainable params
0         Non-trainable params
285 M     Total params
1,141.275 Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Inference with the trained model on the test dataset. Use RNA data to predict proteins in ../../docs/tutorials/protein_names.txt

In [None]:
import scanpy as sc
import torch

# only use 10 cells for example
fewshot_adata = sc.read_h5ad("../../data/test_sample_rna.h5ad")[:10]
sc.pp.normalize_total(fewshot_adata, target_sum=10000)
sc.pp.log1p(fewshot_adata)
fewshot_rna_tensor = torch.tensor(fewshot_adata.X.todense(), dtype=torch.float32).cuda()

model.eval().cuda()

with torch.no_grad():
    _, _, protein_pred = model(fewshot_rna_tensor)

# predict given proteins
target_proteins = [line.strip() for line in open("../../docs/tutorials/protein_names.txt")]

import pandas as pd
prot_map = pd.read_csv("../../docs/tutorials/protein_index_map.csv")
name_to_idx = dict(zip(prot_map["name"], prot_map["index"]))

idx = [name_to_idx[p] for p in target_proteins if p in name_to_idx]

pred_df = pd.DataFrame(
    protein_pred[:, idx].cpu().numpy(),
    columns = target_proteins,
    index   = fewshot_adata.obs_names,
)
pred_df.to_csv(SAVE_DIR/"predicted_protein_expression.csv")