In [1]:
import numpy as np
import scanpy as sc
from pathlib import Path
import torch
from tqdm.auto import tqdm

from distilled_tx1.preprocessing.pipeline import TahoePreprocessor, PreprocessingConfig
from distilled_tx1.models.modeling_distilled_tahoe import DistilledTahoeModel, DistilledTahoeConfig
from distilled_tx1.training.distillation import train_distilled_model
from distilled_tx1.data.load_h5ad_folder import load_h5ad_folder_lazy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load one chunk as the reference AnnData object.
reference_chunk_path = "data_yuto_with_clusters_chunk_001.h5ad"
ref_adata = sc.read_h5ad(reference_chunk_path)

In [7]:
# Show the AnnData summary (dimensions, obs/var columns, obsm keys) of the reference chunk.
ref_adata

AnnData object with n_obs × n_vars = 100000 × 36391
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_HTO', 'nFeature_HTO', 'percent.mt', 'percent.ribo', 'log2_nCount', 'log2_nFeature', 'log2_mt', 'Donor_id', 'Age_group', 'Sex', 'Age', 'Tube_id', 'Batch', 'File_name', 'Cluster_names', 'Cluster_numbers', 'HLA-DR Antibody', 'Pan-Kir Antibody', 'CD20 Antibody', 'IgM Antibody', 'IgD Antibody', 'CD3 Antibody', 'CD4 Antibody', 'CD8 Antibody', 'CD45RA Antibody', 'CD69 Antibody', 'CD62L Antibody', 'CD38 Antibody', 'CD194 Antibody', 'CD25 Antibody', 'CD45RO Antibody', 'CD195 Antibody', 'CD103 Antibody', 'CD27 Antibody', 'CD57 Antibody', 'CD56 Antibody', 'Detailed_Cluster_names'
    var: 'gene_id', 'id_in_vocab'
    obsm: 'Tx1-70m'

In [3]:
# Empty 1-D float placeholder (size 0) for the teacher embeddings.
teacher_embeddings = np.empty(0)

In [4]:
# Tokenization settings handed to the Tahoe preprocessor.
config = PreprocessingConfig(
    seq_len=512,
    n_bins=51,
    # Normalization is switched off; target_sum is still passed through
    # (presumably only relevant when normalize=True — TODO confirm).
    normalize=False,
    target_sum=1e4,
    gene_sampling_strategy="topk",
    add_cls_token=True,
    # Column of `adata.var` holding the gene identifiers (or None to use var_names).
    gene_id_key="gene_id",
)

# Preprocessor configured for the 70m model size with the vocabulary in vocab.json.
preprocessor = TahoePreprocessor(
    config=config,
    tahoe_model_size="70m",
    vocab_path="vocab.json",
)

In [5]:
# Three independent empty 1-D float placeholders (size 0) for the tokenized inputs.
gene_ids, expression_bins, attention_masks = (np.empty(0) for _ in range(3))

In [9]:
# Tokenize every .h5ad chunk and collect the student inputs together with the
# teacher embeddings stored in each chunk's obsm['Tx1-70m'].
#
# The arrays are rebuilt from scratch on every run, so re-executing this cell
# does not append duplicate rows.

# Sort the paths: glob order is filesystem-dependent, and a materialized list
# also lets tqdm display a total.
h5ad_files = sorted(Path("70m/sub").glob("*.h5ad"))
if not h5ad_files:
    raise FileNotFoundError("No .h5ad files found in 70m/sub")

# Per-chunk pieces are gathered in lists and concatenated once at the end;
# concatenating inside the loop would re-copy all accumulated data each time.
gene_id_chunks = []
expression_bin_chunks = []
attention_mask_chunks = []
teacher_embedding_chunks = []

for h5ad_file in tqdm(h5ad_files, desc="Preprocessing chunks"):
    adata = sc.read_h5ad(h5ad_file)
    # Copy gene ids from the reference chunk. This is a pandas assignment, so
    # values are aligned on the var index.
    # NOTE(review): assumes every chunk shares ref_adata's var index — confirm.
    adata.var['gene_id'] = ref_adata.var['gene_id']

    processed = preprocessor.process_adata(adata, return_dict=True)

    gene_id_chunks.append(processed["gene_ids"].numpy())
    expression_bin_chunks.append(processed["expression_bins"].numpy())
    attention_mask_chunks.append(processed["attention_mask"].numpy())
    teacher_embedding_chunks.append(adata.obsm['Tx1-70m'])

    del adata

gene_ids = np.concatenate(gene_id_chunks)
expression_bins = np.concatenate(expression_bin_chunks)
attention_masks = np.concatenate(attention_mask_chunks)
teacher_embeddings = np.concatenate(teacher_embedding_chunks)

# Release the per-chunk copies now that the merged arrays exist.
del gene_id_chunks, expression_bin_chunks, attention_mask_chunks, teacher_embedding_chunks

# Inputs and targets are concatenated along the first axis in the same file
# order, so their row counts must agree for them to stay paired.
if not (len(gene_ids) == len(expression_bins) == len(attention_masks) == len(teacher_embeddings)):
    raise ValueError(
        "Row count mismatch between tokenized inputs and teacher embeddings: "
        f"gene_ids={len(gene_ids)}, expression_bins={len(expression_bins)}, "
        f"attention_masks={len(attention_masks)}, teacher_embeddings={len(teacher_embeddings)}"
    )

0it [00:00, ?it/s]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.57batch/s, cells=100,000/100,000, batch_size=1000]
1it [01:32, 92.19s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:08<00:00, 11.29batch/s, cells=100,000/100,000, batch_size=1000]
2it [03:05, 92.97s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.74batch/s, cells=100,000/100,000, batch_size=1000]
3it [04:46, 96.63s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:11<00:00,  9.04batch/s, cells=100,000/100,000, batch_size=1000]
4it [06:33, 100.51s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.87batch/s, cells=100,000/100,000, batch_size=1000]
5it [08:19, 102.49s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:09<00:00, 10.23batch/s, cells=100,000/100,000, batch_size=1000]
6it [10:05, 103.88s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.63batch/s, cells=100,000/100,000, batch_size=1000]
7it [11:55, 105.68s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.53batch/s, cells=100,000/100,000, batch_size=1000]
8it [13:43, 106.69s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.14batch/s, cells=100,000/100,000, batch_size=1000]
9it [15:34, 107.76s/it]

Gene vocabulary matching:
  Total genes in data: 36391
  Genes in vocabulary: 36391
  Coverage: 100.0%
Creating sequences for 100000 cells...


Tokenizing cells: 100%|██████████| 100/100 [00:10<00:00,  9.25batch/s, cells=100,000/100,000, batch_size=1000]
10it [17:34, 105.46s/it]


In [10]:
# Architecture of the student model; vocabulary size, bin count and sequence
# length are taken from the preprocessing setup so the two stay in sync.
student_config = DistilledTahoeConfig(
    vocab_size=preprocessor.vocab.vocab_size,
    n_bins=config.n_bins,
    # Intended to match the teacher embedding dimension
    # (TODO confirm against teacher_embeddings.shape[1]).
    hidden_size=512,
    # Author's note: smaller than Tahoe X1 (12-24 layers) — not verifiable from this notebook.
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=2048,
    max_position_embeddings=config.seq_len,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    # Alternative pooling option: "mean".
    pooling_strategy="cls",
)

In [None]:
# Train the student against the teacher embeddings gathered from the chunks.
model = train_distilled_model(
    # Tokenized inputs and distillation targets.
    gene_ids=gene_ids,
    expression_bins=expression_bins,
    attention_masks=attention_masks,
    teacher_embeddings=teacher_embeddings,
    # Optional classification labels; none are used here.
    labels=None,
    config=student_config,
    output_dir="./model_outputs/distilled_tahoe",
    # Optimization schedule.
    num_epochs=10,
    # Adjust the batch size to the available GPU memory.
    batch_size=64,
    # NOTE(review): 5e-3 looks high for transformer training — confirm it is intended.
    learning_rate=5e-3,
    warmup_steps=1000,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # Logging, checkpointing and evaluation.
    logging_steps=100,
    save_steps=5000,
    eval_split=0.1,
    # Optional Weights & Biases logging.
    use_wandb=True,
    wandb_project="distilled-tahoe-x1",
    cosine_loss_weight=1.0,
)
    

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.



Epoch 1/10


Training:  36%|███▌      | 5000/14063 [36:22<1:29:01,  1.70it/s, loss=0.2025]


Saved checkpoint to model_outputs/distilled_tahoe/checkpoint-5000


Training:  45%|████▍     | 6299/14063 [46:49<1:02:47,  2.06it/s, loss=0.1782]