In [1]:
from pathlib import Path
import json

from src.student_tahoe_x1.configuration_student_tx import StudentTXConfig
from src.student_tahoe_x1.trainer import train_distillation_model

from tahoe_x1.tokenizer import GeneVocab 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Paths and student model configuration ---
# NOTE(review): absolute, machine-specific paths below — consider deriving them
# from a single configurable DATA_DIR so the notebook runs on other machines.

# Directory holding the student config JSON. Kept under its own name so it is
# not confused with `output_dir` (the training-artefact directory below); it was
# chosen to avoid conflicts with previous dummy data.
config_dir = Path("/home/oem/vcivale/DistillationScFoundation/output")
config_dir.mkdir(parents=True, exist_ok=True)

vocab_path = Path("vocab.json")
# The training call reads the student config from this path.
student_config_save_path = config_dir / "config.json"

train_h5ad_path = Path("/home/oem/vcivale/scFoundation/dataset/data_yuto/tahoe_x1_embeddings/70m/data_yuto_with_clusters_chunk_001.h5ad")
val_h5ad_path = Path("/home/oem/vcivale/scFoundation/dataset/data_yuto/tahoe_x1_embeddings/70m/data_yuto_with_clusters_chunk_002.h5ad")

# Training artefacts are written to a separate, CWD-relative directory.
output_dir = Path("distillation_output")
output_dir.mkdir(parents=True, exist_ok=True)

# Load the teacher's gene vocabulary; `with` closes the file handle promptly
# instead of leaving it to the garbage collector.
with open(vocab_path, "r") as vocab_file:
    gene_dict = json.load(vocab_file)
vocab = GeneVocab(gene_dict)

# Build a small (2 layers, 2 heads, d_model=128) student config for a quick
# test run. It is written to disk in a separate cell.
student_test_config = StudentTXConfig(
    vocab_size=len(vocab),
    n_layers=2,
    n_heads=2,
    d_model=128,
    expansion_ratio=4,
    pad_token_id=vocab["<pad>"],
    pad_value=0,
    n_input_bins=51,
    use_flash_attention=True,
    max_position_embeddings=100,
)


In [3]:
# Persist the student config to exactly the location the training call reads
# (`student_config_save_path`), rather than repeating the absolute directory
# literal, which could drift out of sync with that path.
student_test_config.save_pretrained(
    str(student_config_save_path.parent),
    filename=student_config_save_path.name,
)

In [None]:
# --- Launch distillation training ---
# Short smoke-test run: a single epoch with batch size 2 for train and eval.
# `max_seq_len` here matches `max_position_embeddings=100` in the config cell.
# `teacher_embedding_key` selects which stored teacher embedding to distill from.
distillation_kwargs = dict(
    student_config_path=str(student_config_save_path),
    teacher_vocab_path=str(vocab_path),
    train_h5ad_path=str(train_h5ad_path),
    validation_h5ad_path=str(val_h5ad_path),
    output_dir=str(output_dir),
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    max_seq_len=100,
    teacher_embedding_key="Tx1-70m",
)
train_distillation_model(**distillation_kwargs)