# Tutorial 5: TCR-to-Gene Expression Prediction

Fine-tune TCRfoundation for cross-modal prediction, with TCR sequences as input and gene expression as the regression target.

## 1. Setup

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import json
import torch
import numpy as np
import pandas as pd
import scanpy as sc
import tcrfoundation as tcrf

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

## 2. Configuration

In [2]:
# Run settings; paths are relative to the notebook's working directory
config = {
    "checkpoint_path": "../TCR_foundation_model/foundation_model_best.pt",
    "num_epochs": 2,
    "batch_size": 512,
    "modalities": None,  # None trains all three modalities: tcr_only, tcra_only, tcrb_only
    "val_split": 0.2,
    "test_split": 0.2,
    "save_splits": False,
    "save_predictions": False,
    "task": "TCR2gene"
}

results_dir = f"../results/{config['task']}"
os.makedirs(results_dir, exist_ok=True)

# Save configuration
with open(f"{results_dir}/config.json", 'w') as f:
    json.dump(config, f, indent=2)

print(f"Task: {config['task']}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print(f"Results directory: {results_dir}")

Task: TCR2gene
Device: cuda
Results directory: ../results/TCR2gene


## 3. Load Data

In [3]:
adata = sc.read("../data/adata_finetune.h5ad")
print(f"Dataset: {adata.n_obs} cells, {adata.n_vars} genes")
print("\nTCR sequences:")
print(f"  CDR3a: {adata.obs['CDR3a'].notna().sum()} available")
print(f"  CDR3b: {adata.obs['CDR3b'].notna().sum()} available")

Dataset: 444979 cells, 3000 genes

TCR sequences:
  CDR3a: 444979 available
  CDR3b: 444979 available
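As a sanity check, the split fractions and batch size determine how many batches each pass should process. The sketch below is plain arithmetic under two assumptions that this tutorial does not confirm: that `val_split` and `test_split` are fractions of the full dataset, and that the final partial batch is kept. The resulting counts can be compared against the progress bars in the training log.

```python
import math

n_cells = 444979              # from the dataset summary above
batch_size = 512              # config["batch_size"]
val_split, test_split = 0.2, 0.2

# Assumption: each split is a fraction of the full dataset.
n_val = int(n_cells * val_split)
n_test = int(n_cells * test_split)
n_train = n_cells - n_val - n_test

# Assumption: the last, smaller batch is kept rather than dropped.
batches = {
    name: math.ceil(n / batch_size)
    for name, n in [("train", n_train), ("val", n_val), ("test", n_test)]
}
print(n_train, n_val, n_test)
print(batches)
```

If the counts differ from what the progress bars report, the library is likely splitting or batching differently from what is assumed here.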


## 4. Train Cross-Modal Regressor

This step trains one regressor per input modality, each predicting gene expression from TCR sequences:
- TCR α+β (`tcr_only`) → Gene expression
- TCR α only (`tcra_only`) → Gene expression
- TCR β only (`tcrb_only`) → Gene expression
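To make "cross-modal regressor" concrete, here is a minimal PyTorch sketch of the general idea: a small head maps a fixed-size sequence embedding to one value per gene and is trained with mean squared error. This is not TCRfoundation's architecture. The class name `ExpressionHead`, the embedding size of 128, and the random tensors standing in for encoder output and expression values are all hypothetical.

```python
import torch
import torch.nn as nn

class ExpressionHead(nn.Module):
    """Hypothetical regression head: TCR embedding -> per-gene expression."""

    def __init__(self, embed_dim: int, n_genes: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_genes),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)

torch.manual_seed(0)
embed_dim, n_genes = 128, 3000        # 128 is an assumed embedding size
head = ExpressionHead(embed_dim, n_genes)

z = torch.randn(8, embed_dim)         # stand-in for encoder output
target = torch.rand(8, n_genes)       # stand-in for expression values

optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# One optimization step: forward pass, MSE loss, backward pass, update.
pred = head(z)
loss = loss_fn(pred, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(pred.shape, loss.item())
```

In the actual workflow, `train_regressor` handles the encoder, the data loading, and the training loop. The sketch only shows the shape of the problem, with one embedding in and one expression vector out.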

In [4]:
print(f"\n=== Training {config['task']} Regressor ===")

results, adata_with_predictions = tcrf.finetune.cross_modal.train_regressor(
    adata,
    checkpoint_path=config["checkpoint_path"],
    num_epochs=config["num_epochs"],
    batch_size=config["batch_size"],
    modalities=config["modalities"],
    val_split=config["val_split"],
    test_split=config["test_split"],
    save_splits=config["save_splits"],
    save_predictions=config["save_predictions"]
)

print("\n✓ Training complete")


=== Training TCR2gene Regressor ===
Loaded model with max_length: 30

=== Training tcr_only regressor ===


Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:21<00:00, 23.75it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:03<00:00, 45.97it/s]


Epoch 1/2 - Train Loss: 0.040502, Val Loss: 0.039059, Train R²: -103.7333, Val R²: -0.2391
Saved best model checkpoint (Val Loss: 0.039059)


Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:22<00:00, 23.43it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 72.20it/s]


Epoch 2/2 - Train Loss: 0.039063, Val Loss: 0.038805, Train R²: -10.9746, Val R²: -0.6951
Saved best model checkpoint (Val Loss: 0.038805)

Evaluating on test set with best model...


Testing: 100%|████████████████████████████████████████████| 174/174 [00:03<00:00, 44.76it/s]


Test Loss: 0.038711, Test MSE: 0.038711, Test R²: -0.7039

=== Training tcra_only regressor ===


Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:16<00:00, 31.67it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 73.45it/s]


Epoch 1/2 - Train Loss: 0.038970, Val Loss: 0.038550, Train R²: -17.3715, Val R²: -1.0270
Saved best model checkpoint (Val Loss: 0.038550)


Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.17it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.99it/s]


Epoch 2/2 - Train Loss: 0.038423, Val Loss: 0.038114, Train R²: -3.4055, Val R²: -1.1113
Saved best model checkpoint (Val Loss: 0.038114)

Evaluating on test set with best model...


Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 74.76it/s]


Test Loss: 0.038035, Test MSE: 0.038034, Test R²: -1.0393

=== Training tcrb_only regressor ===


Epoch 1/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 35.80it/s]
Epoch 1/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 74.57it/s]


Epoch 1/2 - Train Loss: 0.038810, Val Loss: 0.038457, Train R²: -11.7012, Val R²: -0.6569
Saved best model checkpoint (Val Loss: 0.038457)


Epoch 2/2 (Train): 100%|██████████████████████████████████| 522/522 [00:14<00:00, 37.16it/s]
Epoch 2/2 (Val): 100%|████████████████████████████████████| 174/174 [00:02<00:00, 71.41it/s]


Epoch 2/2 - Train Loss: 0.038336, Val Loss: 0.038090, Train R²: -2.8365, Val R²: -1.1950
Saved best model checkpoint (Val Loss: 0.038090)

Evaluating on test set with best model...


Testing: 100%|████████████████████████████████████████████| 174/174 [00:02<00:00, 71.89it/s]


Test Loss: 0.038028, Test MSE: 0.038028, Test R²: -1.1855

✓ Training complete
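The negative R² values in the log mean that the predictions have a larger squared error than a constant predictor equal to the mean of the targets. The sketch below uses toy numbers and the standard definition, R² = 1 − SS_res / SS_tot. How TCRfoundation aggregates R² across genes and batches is not stated in this tutorial, so the sketch illustrates the definition only.

```python
import numpy as np

def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

y_true = np.array([0.0, 0.1, 0.2, 0.3])

# Predicting the mean of the targets gives R² = 0 by construction.
r2_mean = r2_score(y_true, np.full(4, y_true.mean()))

# A constant prediction away from the mean does worse than the mean,
# so R² is negative even though the MSE itself is small.
y_off = np.full(4, 0.4)
r2_off = r2_score(y_true, y_off)
mse_off = float(np.mean((y_true - y_off) ** 2))

print(r2_mean, r2_off, mse_off)
```

A small MSE combined with a negative R² therefore indicates low-variance targets that the model does not yet explain better than their mean.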


## 5. Performance Summary

In [5]:
# Create summary dataframe
summary_rows = []
for mode in results:
    for split in ["train", "val", "test"]:
        metrics = results[mode][split]
        row = {
            "modality": mode,
            "split": split,
            "loss": metrics["loss"],
            "mse": metrics["mse"],
            "r2": metrics["r2"]
        }
        summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(f"{results_dir}/{config['task']}_summary.csv", index=False)

print("\n" + "="*60)
print(f"{config['task']} Results Summary")
print("="*60)
print(summary_df.to_string())
print(f"\n✓ Summary saved to: {results_dir}/{config['task']}_summary.csv")


TCR2gene Results Summary
    modality  split      loss       mse         r2
0   tcr_only  train  0.039063  0.039061 -10.974631
1   tcr_only    val  0.038805  0.038805  -0.695051
2   tcr_only   test  0.038711  0.038711  -0.703923
3  tcra_only  train  0.038423  0.038421  -3.405451
4  tcra_only    val  0.038114  0.038114  -1.111292
5  tcra_only   test  0.038035  0.038034  -1.039305
6  tcrb_only  train  0.038336  0.038335  -2.836521
7  tcrb_only    val  0.038090  0.038090  -1.195003
8  tcrb_only   test  0.038028  0.038028  -1.185543

✓ Summary saved to: ../results/TCR2gene/TCR2gene_summary.csv
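The long-format summary is easier to compare across modalities once pivoted into one row per modality and one column per split. The sketch below rebuilds the R² column from the table above so that it runs on its own. Inside the notebook, the existing `summary_df` can be pivoted directly.

```python
import pandas as pd

# R² values copied from the summary table above.
rows = [
    ("tcr_only", "train", -10.974631),
    ("tcr_only", "val", -0.695051),
    ("tcr_only", "test", -0.703923),
    ("tcra_only", "train", -3.405451),
    ("tcra_only", "val", -1.111292),
    ("tcra_only", "test", -1.039305),
    ("tcrb_only", "train", -2.836521),
    ("tcrb_only", "val", -1.195003),
    ("tcrb_only", "test", -1.185543),
]
r2_long = pd.DataFrame(rows, columns=["modality", "split", "r2"])

# One row per modality, one column per split, in a fixed column order.
r2_wide = r2_long.pivot(index="modality", columns="split", values="r2")
r2_wide = r2_wide[["train", "val", "test"]]
print(r2_wide.round(3))

# Modality with the highest (least negative) test R².
best_modality = r2_wide["test"].idxmax()
print(best_modality)
```

The same pivot works for the `loss` and `mse` columns by changing the `values` argument.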
