## Tahoe-x1 Training Tutorial

This notebook demonstrates how to train a Tahoe-x1 model from scratch or fine-tune a pre-trained model.

### 0. Prerequisites
- Access to GPU resources (NVIDIA H100/H200 recommended)
- Tahoe-x1 package installed (refer to the README for the installation guide)
- Access to training data (see the README for dataset information)
    - The training data must either be available locally on your machine, or you must provide AWS S3 credentials so that it can be streamed from our public S3 bucket (recommended)
- Weights & Biases account (optional, for logging)
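
Before launching a run, it can help to confirm that AWS credentials are discoverable, since streaming from S3 fails with `Unable to locate credentials` otherwise. The sketch below is a minimal standard-library check. The helper name `aws_credentials_visible` is hypothetical, and it only looks at two common sources (environment variables and the shared credentials file), so a negative result does not rule out other mechanisms such as instance roles.

```python
import os
from pathlib import Path


def aws_credentials_visible(env=None, home=None):
    """Return True if AWS credentials are found in a common location.

    Hypothetical helper: checks only environment variables and the
    shared credentials file, not instance roles or SSO sessions.
    """
    env = os.environ if env is None else env
    home = Path.home() if home is None else Path(home)

    # Source 1: static keys exported in the environment.
    if env.get("AWS_ACCESS_KEY_ID") and env.get("AWS_SECRET_ACCESS_KEY"):
        return True

    # Source 2: the shared credentials file (its location can be overridden).
    cred_file = env.get("AWS_SHARED_CREDENTIALS_FILE")
    cred_path = Path(cred_file) if cred_file else home / ".aws" / "credentials"
    return cred_path.is_file()


if aws_credentials_visible():
    print("AWS credentials found; S3 streaming should be able to authenticate.")
else:
    print("No AWS credentials found in the usual places.")
```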


### 1. Load and Customize Config

You can start from `test_run.yaml`, a sample config for training the 70M model, and customize it for your own training run.

In [1]:
import os
import sys
from omegaconf import OmegaConf as om

sys.path.insert(0, os.path.abspath('..'))

# Load the base configuration
cfg = om.load("../configs/test_run.yaml")
print(om.to_yaml(cfg))

seed: 777
device_train_batch_size: 100
global_train_batch_size: 100
device_eval_batch_size: 100
device_train_microbatch_size: auto
vocabulary:
  remote: s3://tahoe-hackathon-data/MFM/vevo_v2_vocab.json
  local: vocab.json
model:
  name: tahoex
  d_model: 512
  n_layers: 12
  init_device: cpu
  expansion_ratio: 4
  standard_scale_outputs: false
  transformer_activation: relu
  n_heads: 8
  norm_scheme: pre
  use_generative_training: true
  use_cell_conditioned_generation: false
  use_glu: false
  cell_emb_style: cls
  attn_config:
    attn_impl: flash
    attn_type: grouped_query_attention
    kv_nheads: 8
    attn_pdrop: 0.0
    use_attn_mask: false
  norm_config:
    norm_type: layernorm
    eps: 1.0e-05
  expression_encoder:
    input_emb_style: continuous
    dropout: 0.1
    max_value: 512
    activation: relu
    use_norm: true
  gene_encoder:
    use_norm: true
  mvc:
    arch_style: inner product
    query_activation: sigmoid
    scaled_dot_product: true
  expression_decoder:
  

In [2]:
# Customize the config based on your system, design choice, etc

# Training settings
cfg.global_train_batch_size = 256  # Total batch size across all devices
cfg.max_duration = "20ba"  # Train for 20 batches; use e.g. "2ep" to train for 2 epochs

# Model configuration
cfg.model.d_model = 512
cfg.model.n_layers = 12
cfg.model.n_heads = 8

# IMPORTANT: Current codebase only supports flash attention without attention mask
cfg.model.attn_config.attn_impl = "flash"
cfg.model.attn_config.use_attn_mask = False

# Data loader settings
cfg.train_loader.num_workers = 4  # Adjust based on your system
cfg.train_loader.prefetch_factor = 2 # Adjust based on your system

# Collator settings
# Set to True if your training data includes drug info (such as Tahoe100M)
# and you want to inject it into the model
cfg.collator.use_chem_token = False

# Optimizer settings
cfg.optimizer.lr = 3.0e-4
cfg.optimizer.weight_decay = 1.0e-05

# Logging
cfg.run_name = "custom_test_run"
cfg.loggers.wandb.project = "tahoex-tutorial"
save_folder = cfg.save_folder = f"./checkpoints/{cfg.run_name}"
cfg.save_interval = "500ba"  # Save every 500 batches

# Save the config
custom_config_path = "./my_training_config.yaml"
om.save(cfg, custom_config_path)
print(f"Configuration saved to: {custom_config_path}")

Configuration saved to: ./my_training_config.yaml
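
The batch-size and duration settings interact, so it can be useful to check the arithmetic before a long run. The sketch below assumes that the global batch is split evenly across devices and that a duration string is an integer followed by `ba` (batches) or `ep` (epochs), as in the config above. `parse_duration`, `per_device_batch_size`, and `samples_seen` are hypothetical helpers, not part of the Tahoe-x1 or Composer APIs.

```python
import math
import re


def parse_duration(duration):
    """Split a duration string such as "20ba" or "2ep" into (value, unit)."""
    match = re.fullmatch(r"(\d+)(ba|ep)", duration)
    if match is None:
        raise ValueError(f"Unsupported duration string: {duration!r}")
    return int(match.group(1)), match.group(2)


def per_device_batch_size(global_batch_size, num_devices):
    """Assume the global batch is divided evenly across devices."""
    if global_batch_size % num_devices != 0:
        raise ValueError("global batch size must be divisible by device count")
    return global_batch_size // num_devices


def samples_seen(duration, global_batch_size, dataset_size):
    """Approximate number of training samples consumed by a run."""
    value, unit = parse_duration(duration)
    if unit == "ba":
        return value * global_batch_size
    return value * dataset_size  # unit == "ep"


# With the tutorial settings: 20 batches of 256 samples each.
print(samples_seen("20ba", 256, 60_746_795))   # 5120
print(per_device_batch_size(256, 1))           # 256 on a single GPU
print(math.ceil(60_746_795 / 256))             # batches in one full epoch
```

The training log in the next section is consistent with this arithmetic: after 10 batches it reports 2560 samples, and `device_train_batch_size` is 256 on a single GPU.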


### 2. Training from Scratch

#### Option A: Train using the Python API

In [3]:
from train import main
cfg = om.load(custom_config_path)

# Train the model
trainer = main(cfg)

print(f"Training completed and checkpoints saved at {save_folder}!")

  from .autonotebook import tqdm as notebook_tqdm
  import pkg_resources
2025-10-21 20:51:36,435: rank0[325359][MainThread]: INFO: train: Downloading vocab...


Error downloading the file from s3://tahoe-hackathon-data/MFM/vevo_v2_vocab.json: Unable to locate credentials


[rank0]:[W1021 20:51:36.856222259 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
2025-10-21 20:51:36,984: rank0[325359][MainThread]: INFO: train: Setting vocab size to: 62720
2025-10-21 20:51:37,175: rank0[325359][MainThread]: INFO: train: Building DataLoaders...
2025-10-21 20:51:39,636: rank0[325359][MainThread]: INFO: train: train set number of samples: 60746795
2025-10-21 20:51:39,679: rank0[325359][MainThread]: INFO: train: Validation set number of samples: 613604
2025-10-21 20:51:40,025: rank0[325359][MainThread]: INFO: tahoex.model.model: MosaicML recommends using config.init_device="meta" with Composer + FSDP for faster initialization.
2025-10-21 20:51:41,142: rank0[325359][MainThread]: INFO: train: 

  return torch.load(_ensure_valid_checkpoint(checkpoint_filepath), map_location=map_location)
2025-10-21 20:51:44,866: rank0[325359][MainThread]: INFO: train: Logging config


seed: 777
device_train_batch_size: 256
global_train_batch_size: 256
device_eval_batch_size: 100
device_train_microbatch_size: auto
vocabulary:
  remote: s3://tahoe-hackathon-data/MFM/vevo_v2_vocab.json
  local: vocab.json
model:
  name: tahoex
  d_model: 512
  n_layers: 12
  init_device: cpu
  expansion_ratio: 4
  standard_scale_outputs: false
  transformer_activation: relu
  n_heads: 8
  norm_scheme: pre
  use_generative_training: true
  use_cell_conditioned_generation: false
  use_glu: false
  cell_emb_style: cls
  attn_config:
    attn_impl: flash
    attn_type: grouped_query_attention
    kv_nheads: 8
    attn_pdrop: 0.0
    use_attn_mask: false
  norm_config:
    norm_type: layernorm
    eps: 1.0e-05
  expression_encoder:
    input_emb_style: continuous
    dropout: 0.1
    max_value: 512
    activation: relu
    use_norm: true
  gene_encoder:
    use_norm: true
  mvc:
    arch_style: inner product
    query_activation: sigmoid
    scaled_dot_product: true
  expression_decoder:
  

2025-10-21 20:51:45,240: rank0[325359][MainThread]: INFO: train: Starting training...
******************************
Config:
composer_commit_hash: None
composer_version: 0.28.0
enabled_algorithms/GradientClipping: true
enabled_algorithms/LowPrecisionLayerNorm: true
node_name: unknown because NODENAME environment variable not set
num_gpus_per_node: 1
num_nodes: 1
rank_zero_seed: 777
time/remaining_estimate_unit: hours

******************************
[batch=11/20]:
	 Train time/batch: 10
	 Train time/sample: 2560
	 Train time/batch_in_epoch: 10
	 Train time/sample_in_epoch: 2560
	 Train memory/current_allocated_mem: 1.4400
	 Train memory/current_active_mem: 1.4400
	 Train memory/current_inactive_mem: 5.4282
	 Train memory/current_reserved_mem: 26.4090
	 Train memory/peak_allocated_mem: 40.4660
	 Train memory/peak_active_mem: 40.4660
	 Train memory/peak_inactive_mem: 6.0906
	 Train memory/peak_reserved_mem: 41.2530
	 Train memory/alloc_retries: 1
	 Train trainer/device_train_microbatch_si

Training completed and checkpoints saved at ./checkpoints/custom_test_run!


#### Option B: Train using Composer CLI

Alternatively, you can train using the command line with composer:

In [None]:
# Run training via shell command (adjust the path to train.py to match your working directory)
!composer ../train.py -f {custom_config_path}

/tahoe/tahoe-x1/.venv/bin/python3: can't open file '/tahoe/tahoe-x1/scripts/../train.py': [Errno 2] No such file or directory
ERROR:composer.cli.launcher:Rank 0 crashed with exit code 2.
Waiting up to 30 seconds for all training processes to terminate. Press Ctrl-C to exit immediately.
Global rank 0 (PID 303486) exited with code 2
ERROR:composer.cli.launcher:Global rank 0 (PID 303486) exited with code 2


#### Resume Training

If your run stops unexpectedly and you want to resume from where it stopped, use the **same `run_name` and `save_folder`** in the configuration. The trainer automatically detects existing checkpoints and resumes with the full state (model weights, optimizer, scheduler, etc.).

```python
resume_cfg = om.load(custom_config_path)

# Keep the same run_name and save_folder - training will auto-resume
trainer = main(resume_cfg)
```
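
To check what a resumed run would find, you can list the checkpoints already present in `save_folder`. The sketch below assumes a local folder and file names of the form `ep{epoch}-ba{batch}-rank{rank}.pt` (an assumed naming pattern that your config may override). `latest_checkpoint` is a hypothetical helper and does not reproduce the trainer's own resume logic.

```python
import re
from pathlib import Path

# Assumed file-name pattern; adjust if your config uses a custom file name.
CKPT_PATTERN = re.compile(r"ep(\d+)-ba(\d+)-rank(\d+)\.pt")


def latest_checkpoint(save_folder):
    """Return the checkpoint path with the largest (epoch, batch), or None."""
    folder = Path(save_folder)
    if not folder.is_dir():
        return None
    best_key, best_path = None, None
    for path in folder.iterdir():
        match = CKPT_PATTERN.fullmatch(path.name)
        if match is None:
            continue  # skip files that do not look like checkpoints
        epoch, batch, _rank = (int(group) for group in match.groups())
        key = (epoch, batch)
        if best_key is None or key > best_key:
            best_key, best_path = key, path
    return best_path


found = latest_checkpoint("./checkpoints/custom_test_run")
print(found if found else "No checkpoint found; training would start from scratch.")
```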



### 3. Fine-tuning a Pre-trained Model

When loading from a checkpoint, you have two options:

**Option 1:** Full Recovery
- Set `load_path` to your checkpoint directory or file
- Loads both model weights AND optimizer/scheduler states

```python
cfg.load_path = "s3://bucket/path/to/checkpoint/"
# This recovers everything: weights + optimizer + scheduler
```

**Option 2:** Weights Only 
- Set `load_path` AND `load_weights_only=True`
- Loads **only model weights**, optimizer/scheduler are initialized fresh

```python
cfg.load_path = "s3://bucket/path/to/checkpoint/"
cfg.load_weights_only = True
# This loads only weights, optimizer/scheduler start fresh
```
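
The difference between the two options can be illustrated with a toy checkpoint. The sketch below uses a plain dictionary with `model`, `optimizers`, and `schedulers` entries. This layout and the `restore` helper are assumptions for illustration only, not the actual checkpoint format or loading code.

```python
def restore(checkpoint, fresh_state, load_weights_only=False):
    """Return the training state after loading a (toy) checkpoint.

    Full recovery takes every component from the checkpoint; weights-only
    loading takes just the model and keeps the freshly initialised
    optimizer and scheduler state.
    """
    restored = dict(fresh_state)  # start from newly initialised components
    restored["model"] = checkpoint["model"]
    if not load_weights_only:
        restored["optimizers"] = checkpoint["optimizers"]
        restored["schedulers"] = checkpoint["schedulers"]
    return restored


# Toy stand-ins for real state dicts.
checkpoint = {
    "model": {"w": 0.42},
    "optimizers": {"step": 500, "lr": 3.0e-4},
    "schedulers": {"last_batch": 500},
}
fresh = {
    "model": {"w": 0.0},
    "optimizers": {"step": 0, "lr": 1.0e-5},
    "schedulers": {"last_batch": 0},
}

full = restore(checkpoint, fresh)
weights_only = restore(checkpoint, fresh, load_weights_only=True)
print(full["optimizers"]["step"])          # 500
print(weights_only["optimizers"]["step"])  # 0
```

In both cases the model weights come from the checkpoint; only the optimizer and scheduler state differ.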

In [17]:
# Load configuration for fine-tuning
finetune_cfg = om.load("../configs/test_run.yaml")

# Set checkpoint path
checkpoint_path = "s3://tahoe-hackathon-data/MFM/ckpts/70m/best-model.pt"  # Or local path
finetune_cfg.load_path = checkpoint_path

# Adjust the learning rate and scheduler for fine-tuning
finetune_cfg.optimizer.lr = 1.0e-5
finetune_cfg.optimizer.weight_decay = 1.0e-6
finetune_cfg.scheduler = {}
finetune_cfg.scheduler.name = 'constant_with_warmup'
finetune_cfg.scheduler.t_warmup = '0ba'

# Shorter training duration for fine-tuning
finetune_cfg.max_duration = "30ba"

# Update save folder
finetune_cfg.save_folder = "./checkpoints/finetuned_{run_name}"

print("Fine-tuning configuration:")
print(om.to_yaml(finetune_cfg))

Fine-tuning configuration:
seed: 777
device_train_batch_size: 100
global_train_batch_size: 100
device_eval_batch_size: 100
device_train_microbatch_size: auto
vocabulary:
  remote: s3://tahoe-hackathon-data/MFM/vevo_v2_vocab.json
  local: vocab.json
model:
  name: tahoex
  d_model: 512
  n_layers: 12
  init_device: cpu
  expansion_ratio: 4
  standard_scale_outputs: false
  transformer_activation: relu
  n_heads: 8
  norm_scheme: pre
  use_generative_training: true
  use_cell_conditioned_generation: false
  use_glu: false
  cell_emb_style: cls
  attn_config:
    attn_impl: flash
    attn_type: grouped_query_attention
    kv_nheads: 8
    attn_pdrop: 0.0
    use_attn_mask: false
  norm_config:
    norm_type: layernorm
    eps: 1.0e-05
  expression_encoder:
    input_emb_style: continuous
    dropout: 0.1
    max_value: 512
    activation: relu
    use_norm: true
  gene_encoder:
    use_norm: true
  mvc:
    arch_style: inner product
    query_activation: sigmoid
    scaled_dot_product: tr
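
The scheduler settings above can be read as a learning-rate multiplier that ramps linearly from 0 to 1 over the warmup period and then stays at 1. The sketch below is a simplified model of a constant-with-warmup schedule measured in batches, not Composer's implementation. `lr_at_batch` is a hypothetical helper.

```python
def lr_at_batch(batch, base_lr, warmup_batches):
    """Learning rate at a given batch for a constant-with-warmup schedule."""
    if warmup_batches <= 0 or batch >= warmup_batches:
        return base_lr  # no warmup, or warmup finished: constant rate
    return base_lr * batch / warmup_batches  # linear ramp starting from 0


# t_warmup = "0ba": the rate is constant from the first batch.
print([lr_at_batch(b, 1.0e-5, 0) for b in (0, 1, 29)])

# A hypothetical 10-batch warmup for comparison.
print([lr_at_batch(b, 1.0e-5, 10) for b in (0, 5, 10, 29)])
```

With `t_warmup = '0ba'`, as in the cell above, the schedule reduces to a constant learning rate.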

In [18]:
# Start fine-tuning
finetune_trainer = main(finetune_cfg)
print("Fine-tuning completed!")

2025-10-21 21:37:14,410: rank0[325359][MainThread]: INFO: train: Downloading vocab...


2025-10-21 21:37:17,234: rank0[325359][MainThread]: INFO: train: Setting vocab size to: 62720
2025-10-21 21:37:17,244: rank0[325359][MainThread]: INFO: train: Building DataLoaders...


File downloaded successfully to vocab.json


2025-10-21 21:37:25,160: rank0[325359][MainThread]: INFO: train: train set number of samples: 60746795
2025-10-21 21:37:26,840: rank0[325359][MainThread]: INFO: train: Validation set number of samples: 613604
2025-10-21 21:37:27,223: rank0[325359][MainThread]: INFO: tahoex.model.model: MosaicML recommends using config.init_device="meta" with Composer + FSDP for faster initialization.
2025-10-21 21:37:28,361: rank0[325359][MainThread]: INFO: train: Total parameters: 70.996993 M
2025-10-21 21:37:28,362: rank0[325359][MainThread]: INFO: train: Total trainable parameters: 70.996993 M 
2025-10-21 21:37:28,363: rank0[325359][MainThread]: INFO: train: gene_encoder: 32.113664 M parameters
2025-10-21 21:37:28,363: rank0[325359][MainThread]: INFO: train: flag_encoder: 0.001024 M parameters
2025-10-21 21:37:28,364: rank0[325359][MainThread]: INFO: train: expression_encoder: 0.264704 M parameters
2025-10-21 21:37:28,365: rank0[325359][MainThread]: INFO: train: transformer_encoder: 37.829632 M para

training duration, take note that the warmup duration is calculated in the
same unit as the trainer's max_duration parameter.
Downloading MFM/ckpts/70m/best-model.pt:   0%|          | 262k/284M [00:07<2:07:30, 37.1kiB/s]
  return torch.load(_ensure_valid_checkpoint(checkpoint_filepath), map_location=map_location)
2025-10-21 21:37:40,472: rank0[325359][MainThread]: INFO: train: Logging config


seed: 777
device_train_batch_size: 100
global_train_batch_size: 100
device_eval_batch_size: 100
device_train_microbatch_size: auto
vocabulary:
  remote: s3://tahoe-hackathon-data/MFM/vevo_v2_vocab.json
  local: vocab.json
model:
  name: tahoex
  d_model: 512
  n_layers: 12
  init_device: cpu
  expansion_ratio: 4
  standard_scale_outputs: false
  transformer_activation: relu
  n_heads: 8
  norm_scheme: pre
  use_generative_training: true
  use_cell_conditioned_generation: false
  use_glu: false
  cell_emb_style: cls
  attn_config:
    attn_impl: flash
    attn_type: grouped_query_attention
    kv_nheads: 8
    attn_pdrop: 0.0
    use_attn_mask: false
  norm_config:
    norm_type: layernorm
    eps: 1.0e-05
  expression_encoder:
    input_emb_style: continuous
    dropout: 0.1
    max_value: 512
    activation: relu
    use_norm: true
  gene_encoder:
    use_norm: true
  mvc:
    arch_style: inner product
    query_activation: sigmoid
    scaled_dot_product: true
  expression_decoder:
  

2025-10-21 21:37:41,201: rank0[325359][MainThread]: INFO: train: Starting training...
******************************
Config:
composer_commit_hash: None
composer_version: 0.28.0
enabled_algorithms/GradientClipping: true
enabled_algorithms/LowPrecisionLayerNorm: true
node_name: unknown because NODENAME environment variable not set
num_gpus_per_node: 1
num_nodes: 1
rank_zero_seed: 777
time/remaining_estimate_unit: hours

******************************
training duration, take note that the warmup duration is calculated in the
same unit as the trainer's max_duration parameter.
[batch=1/30]:
	 Train time/epoch: 0
	 Train time/batch: 0
	 Train time/sample: 0
	 Train time/batch_in_epoch: 0
	 Train time/sample_in_epoch: 0
	 Train memory/current_allocated_mem: 1.8710
	 Train memory/current_active_mem: 1.8710
	 Train memory/current_inactive_mem: 4.2276
	 Train memory/current_reserved_mem: 20.5190
	 Train memory/peak_allocated_mem: 40.4660
	 Train memory/peak_active_mem: 40.4660
	 Train memory/pea

Fine-tuning completed!


### Tips

1. You can monitor your training in **Weights & Biases**
2. Model checkpoints are saved according to `save_interval`
3. If you encounter OOM issues, try reducing `device_train_batch_size`
4. If you are using a single GPU, you can remove `fsdp_config` from your custom configuration.
5. Ensure `attn_impl: flash` and `use_attn_mask: False`, as the Triton backend is no longer supported by our codebase (email us if you have questions about using the Triton backend with custom attention masking)
6. You can add the `cell_classification` and `marginal_essentiality` callbacks to the configuration files so that the model is automatically evaluated on these benchmarks (adding some samples is a TODO)

7. After training you can:
    1. **Prepare model for inference**: Use `scripts/prepare_for_inference.py`
    2. **Extract cell and gene embeddings**: See `scripts/clustering_tutorial.ipynb` and `inference.predict_embeddings`
    3. **Run benchmarks**: See `scripts/depmap/` and `scripts/msigdb/`
    4. **Upload to HuggingFace**: For sharing your trained model
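
As a rough guide to why batch size is the main lever for OOM errors, the sketch below estimates the memory taken by parameters, gradients, and optimizer state for the 70M model. It assumes fp32 storage and an Adam-style optimizer that keeps two extra values per parameter. Both are assumptions, and `static_memory_gb` is a hypothetical helper.

```python
def static_memory_gb(num_params, bytes_per_value=4, optimizer_states=2):
    """Estimate memory (in GB, 1e9 bytes) that does not depend on batch size.

    Counts one copy each of weights and gradients plus the optimizer's
    per-parameter state (two values for Adam-style optimizers).
    """
    copies = 1 + 1 + optimizer_states  # weights + gradients + optimizer state
    return num_params * bytes_per_value * copies / 1e9


num_params = 70_996_993  # "Total parameters" reported in the training log
print(f"weights only: {num_params * 4 / 1e9:.3f} GB")
print(f"weights + grads + optimizer: {static_memory_gb(num_params):.3f} GB")
```

Under these assumptions the static footprint is on the order of 1 GB, far below the peak allocation of about 40 reported in the training log (assuming that figure is in GB). This suggests that activations, which grow with `device_train_batch_size`, account for most of the memory.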

For more details, refer to the [README.md](../README.md).