In [3]:
import importlib
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Make the in-repo `scLinguist` package importable (path is relative to the
# notebook's working directory).
sys.path.append('../')
from scLinguist.data_loaders.data_loader import scMultiDataset
from scLinguist.model.configuration_hyena import HyenaConfig
from scLinguist.model.model import scTrans

# Register `scLinguist.model` under the top-level name `model`.
# NOTE(review): presumably the checkpoints were pickled when the package was
# importable as `model`, and this alias lets unpickling resolve them — confirm.
sys.modules['model'] = importlib.import_module('scLinguist.model')

ENCODER_CKPT = Path("../pretrained_model/encoder.ckpt")
DECODER_CKPT = Path("../pretrained_model/decoder.ckpt")
FINETUNE_CKPT = Path("../pretrained_model/finetune.ckpt")
SAVE_DIR = Path("../docs/tutorials/fewshot_output")
# parents=True: intermediate directories may not exist in a fresh checkout.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Seed torch's RNG so the shuffled few-shot fine-tuning is reproducible.
SEED = 42
torch.manual_seed(SEED)

In [4]:
BATCH_SIZE = 4

# Paired datasets: `data_dir_1` is the RNA AnnData file and `data_dir_2`
# the matching ADT file (pairing inferred from the file names).
fewshot_data = scMultiDataset(
    data_dir_1="../data/fewshot_sample_rna.h5ad",
    data_dir_2="../data/fewshot_sample_adt.h5ad",
)
test_data = scMultiDataset(
    data_dir_1="../data/test_sample_rna.h5ad",
    data_dir_2="../data/test_sample_adt.h5ad",
)

# Settings shared by both loaders.
_loader_common = dict(batch_size=BATCH_SIZE, pin_memory=True)

# Training loader: reshuffled every epoch, loaded by 8 worker processes.
fewshot_dataloader = DataLoader(
    fewshot_data,
    shuffle=True,
    num_workers=8,
    **_loader_common,
)
# Evaluation loader: fixed order, keeps the last partial batch, and loads
# in the main process.
test_dataloader = DataLoader(
    test_data,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    **_loader_common,
)

In [5]:
# Encoder and decoder Hyena configs differ only in sequence/vocabulary length.
_hyena_shared = dict(
    d_model=128,
    emb_dim=5,
    n_layer=1,
    output_hidden_states=False,
)
enc_cfg = HyenaConfig(max_seq_len=19202, vocab_len=19202, **_hyena_shared)
dec_cfg = HyenaConfig(max_seq_len=6427, vocab_len=6427, **_hyena_shared)
# NOTE(review): enc_cfg / dec_cfg are not passed to load_from_checkpoint below;
# confirm whether later cells use them or whether they can be dropped.

# Restore the fine-tuned model, then point it at the pretrained encoder/decoder
# checkpoints and select the RNA -> protein direction.
model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Train on GPU index 1 when it exists; otherwise fall back to a single GPU so
# the cell also runs on one-GPU machines (devices=[1] would fail there).
TRAIN_DEVICES = [1] if torch.cuda.device_count() > 1 else 1

# Keep only the checkpoint with the lowest `valid_loss`.
ckpt_cb = ModelCheckpoint(
    dirpath=SAVE_DIR / "ckpt",
    monitor="valid_loss",
    mode="min",
    save_top_k=1,
    filename="best-{epoch}-{valid_loss:.4f}",
)
# NOTE: with max_epochs=1, patience=3 can never trigger; this only matters
# if max_epochs is raised.
early_cb = EarlyStopping(monitor="valid_loss", mode="min", patience=3)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=TRAIN_DEVICES,
    max_epochs=1,
    log_every_n_steps=50,
    callbacks=[ckpt_cb, early_cb],
)

# Positional args: (model, train dataloader, validation dataloader).
# NOTE(review): the few-shot set is also the validation set, so `valid_loss`
# (and thus checkpoint selection) measures fit to the training data.
trainer.fit(model, fewshot_dataloader, fewshot_dataloader)
best_ckpt = ckpt_cb.best_model_path
# An empty path means no checkpoint was written (e.g. training was interrupted).
assert best_ckpt, "no checkpoint was saved during fine-tuning"

In [None]:
import numpy as np
import scanpy as sc
import torch
from scipy import sparse

# Take the first 10 observations of the test RNA data for a quick demo.
test_adata = sc.read_h5ad("../data/test_sample_rna.h5ad")[:10]
# `.X` may be stored sparse or dense; `.todense()` exists only on the sparse
# case, so densify conditionally.
rna_matrix = (
    test_adata.X.toarray() if sparse.issparse(test_adata.X) else np.asarray(test_adata.X)
)
rna_tensor = torch.tensor(rna_matrix, dtype=torch.float32).cuda()