# Finetuning on your custom datasets
In this tutorial, we further fine-tune the already fine-tuned scTrans model on your own (custom) dataset.

First, import the necessary packages and define some paths.

In [5]:
import importlib
import sys
from pathlib import Path

from torch.utils.data import DataLoader

# Make the repository root importable so `scLinguist` resolves from this folder.
sys.path.append('../')
from scLinguist.data_loaders.data_loader import scMultiDataset
from scLinguist.model.configuration_hyena import HyenaConfig
from scLinguist.model.model import scTrans

# Also register `scLinguist.model` under the top-level module name `model`.
# NOTE(review): presumably required because the checkpoints were pickled while the
# package was importable as `model`; confirm before removing this alias.
sys.modules['model'] = importlib.import_module('scLinguist.model')

# Pretrained weights (encoder / decoder) and the fine-tuned starting checkpoint
ENCODER_CKPT = Path("../pretrained_model/encoder.ckpt")
DECODER_CKPT = Path("../pretrained_model/decoder.ckpt")
FINETUNE_CKPT = Path("../pretrained_model/finetune.ckpt")
# Where results (checkpoints, predictions) are saved
SAVE_DIR = Path("../docs/tutorials/finetune_output")
# parents=True so this also works when intermediate folders do not exist yet
SAVE_DIR.mkdir(parents=True, exist_ok=True)

Load your custom paired RNA/protein (ADT) dataset into PyTorch data loaders.

In [6]:
BATCH_SIZE = 4
DATA_DIR = "../data"


def make_dataset(split: str):
    """Build the paired RNA/ADT dataset for one split ("train", "valid" or "test").

    Reads ``{DATA_DIR}/{split}_sample_rna.h5ad`` and ``{DATA_DIR}/{split}_sample_adt.h5ad``.
    """
    return scMultiDataset(
        data_dir_1=f"{DATA_DIR}/{split}_sample_rna.h5ad",
        data_dir_2=f"{DATA_DIR}/{split}_sample_adt.h5ad",
    )


def make_loader(dataset, *, train: bool, num_workers: int) -> DataLoader:
    """Wrap `dataset` in a DataLoader.

    Training loaders shuffle and drop the last incomplete batch; evaluation
    loaders keep the original order and every sample.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=train,
        drop_last=train,
        num_workers=num_workers,
        pin_memory=True,
    )


train_ds = make_dataset("train")
valid_ds = make_dataset("valid")
test_ds = make_dataset("test")

train_dataloader = make_loader(train_ds, train=True, num_workers=8)
valid_dataloader = make_loader(valid_ds, train=False, num_workers=8)
# NOTE(review): the test loader uses num_workers=0 unlike train/valid — confirm this is intentional.
test_dataloader = make_loader(test_ds, train=False, num_workers=0)


Load our 3M fine-tuned model and point it to the pretrained encoder/decoder checkpoints.

In [7]:
# Sequence lengths of the two modalities (number of RNA / protein positions).
N_RNA_FEATURES = 19202
N_PROTEIN_FEATURES = 6427


def make_hyena_config(n_features: int):
    """Hyena config shared by encoder and decoder; only the sequence length differs."""
    return HyenaConfig(
        d_model=128,
        emb_dim=5,
        max_seq_len=n_features,
        vocab_len=n_features,
        n_layer=1,
        output_hidden_states=False,
    )


# NOTE(review): these configs are not passed to `load_from_checkpoint` below; the
# architecture presumably comes from the checkpoint's saved hyperparameters.
# They are kept here for reference — confirm before relying on them.
enc_cfg = make_hyena_config(N_RNA_FEATURES)      # RNA encoder
dec_cfg = make_hyena_config(N_PROTEIN_FEATURES)  # protein decoder

model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"  # use the RNA -> protein translation mode

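Then, start training. `ModelCheckpoint` keeps the checkpoint with the lowest `valid_loss`, and `EarlyStopping` stops training once `valid_loss` stops improving.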

In [9]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Seed Python/NumPy/torch (and dataloader workers) so the shuffling order and other
# random ops are reproducible across runs (up to GPU nondeterminism).
pl.seed_everything(42, workers=True)

# Keep only the checkpoint with the lowest validation loss.
ckpt_cb = ModelCheckpoint(
    dirpath=SAVE_DIR / "ckpt",
    monitor="valid_loss",
    mode="min",
    save_top_k=1,
    filename="best-{epoch}-{valid_loss:.4f}",
)
# Stop once valid_loss has not improved for 3 validation checks.
# With max_epochs=1 this has little effect; it matters once you raise max_epochs.
early_cb = EarlyStopping(monitor="valid_loss", mode="min", patience=3)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],  # for several GPUs use list(range(N))
    max_epochs=1,
    log_every_n_steps=50,
    callbacks=[ckpt_cb, early_cb],
)

trainer.fit(model, train_dataloader, valid_dataloader)

best_ckpt = ckpt_cb.best_model_path
# best_model_path stays "" when no checkpoint was written (e.g. 'valid_loss' was never
# logged); fail here instead of later when the path is loaded.
assert best_ckpt, "No checkpoint was saved; make sure the model logs 'valid_loss' during validation."


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name       | Type             | Params
------------------------------------------------
0 | encoder    | scHeyna_enc      | 313 K 
1 | decoder    | scHeyna_dec      | 249 K 
2 | translator | MLPTranslator    | 284 M 
3 | cos_gene   | CosineSimilarity | 0     
4 | cos_cell   | CosineSimilarity | 0     
------------------------------------------------
285 M     Trainable params
0         Non-trainable params
285 M     Total params
1,141.275 Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Use the test data to predict protein expression. You can restrict the output to your own list of protein names (one name per line in `protein_names.txt`).

In [33]:
import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch

# Number of test cells to predict; kept small so the demo runs quickly.
N_PREDICT_CELLS = 10

test_adata = sc.read_h5ad("../data/test_sample_rna.h5ad")[:N_PREDICT_CELLS]
# `.X` can be a sparse matrix or a dense array depending on how the file was written;
# dense arrays have no `.todense()`, so densify only when needed.
rna_matrix = test_adata.X.toarray() if sp.issparse(test_adata.X) else np.asarray(test_adata.X)
# NOTE(review): the gene columns are assumed to match the model's RNA input
# (19202 positions in the encoder config) — confirm for your own data.
rna_tensor = torch.tensor(rna_matrix, dtype=torch.float32).cuda()


In [34]:
# Reload the weights with the lowest validation loss from the training run above.
# An empty path means ModelCheckpoint never saved a checkpoint.
if not best_ckpt:
    raise FileNotFoundError(
        "No best checkpoint was recorded during training; check that 'valid_loss' is logged."
    )
model = scTrans.load_from_checkpoint(best_ckpt)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"  # use the RNA -> protein translation mode
model.eval().cuda()  # eval() switches off training-only behavior such as dropout

with torch.no_grad():
    # The forward pass returns three outputs; only the last one is used here.
    _, _, protein_pred = model(rna_tensor)   # protein_pred shape: cells x 6427


In [35]:
import warnings

import pandas as pd

# Target protein names, one per line; blank lines are skipped.
with open("../docs/tutorials/protein_names.txt", encoding="utf-8") as fh:
    target_proteins = [line.strip() for line in fh if line.strip()]

# Build a name -> index dictionary from an example index<->name mapping table.
prot_map = pd.read_csv("../docs/tutorials/protein_index_map.csv")     # columns: name, index
name_to_idx = dict(zip(prot_map["name"], prot_map["index"]))

# Filter names and indices together so every column label matches its prediction
# column; names absent from the map are reported and skipped.
found_proteins = [p for p in target_proteins if p in name_to_idx]
missing_proteins = [p for p in target_proteins if p not in name_to_idx]
if missing_proteins:
    warnings.warn(
        f"{len(missing_proteins)} protein(s) not found in the index map and skipped: {missing_proteins}"
    )
idx = [int(name_to_idx[p]) for p in found_proteins]

pred_df = pd.DataFrame(
    protein_pred[:, idx].cpu().numpy(),
    columns=found_proteins,
    index=test_adata.obs_names,
)
pred_df.to_csv(SAVE_DIR / "predicted_protein_expression.csv")
pred_df.head()
