# Tutorial 4: Binding Avidity Regression

Fine-tune TCRfoundation to predict TCR-antigen binding avidity (quantitative binding strength), framed as a multi-output regression on per-cell binding counts with one output per antigen.

## 1. Setup

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import scanpy as sc
import tcrfoundation as tcrf
from tcrfoundation.finetune.avidity import (
    train_binding_counts_regressor,
    build_regression_results_dataframe,
    plot_regression_metrics_charts
)

## 2. Configuration

In [2]:
checkpoint_path = "../TCR_foundation_model/foundation_model_best.pt"
results_dir = "../results/binding_counts"
num_epochs = 2  # Demonstration only; for a full training run, set num_epochs = 50.
batch_size = 128

os.makedirs(results_dir, exist_ok=True)

## 3. Load Data

We need to combine data from two files:
- `adata_avidity.h5ad`: Contains binding counts
- `speci_adata.h5ad`: Contains TCR sequences and gene expression

In [3]:
# Load binding counts data
adata_avidity = sc.read("../data/adata_avidity.h5ad")
print(f"Avidity data: {adata_avidity.n_obs} cells")

# Load TCR and gene expression data
adata = sc.read("../data/speci_adata.h5ad")
print(f"Specificity data: {adata.n_obs} cells, {adata.n_vars} genes")

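# Transfer binding counts and split labels to the main adata.
# The obsm assignment is positional, so both files must list the same cells in the same order.
assert adata.obs_names.equals(adata_avidity.obs_names), "cell order differs between the two files"
adata.obsm['binding_counts'] = adata_avidity.obsm['binding_counts']
adata.obs['set'] = adata_avidity.obs['set']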

# Check data
n_antigens = adata.obsm['binding_counts'].shape[1]
print(f"\nBinding counts for {n_antigens} antigens")
print(f"Data splits: {adata.obs['set'].value_counts().to_dict()}")

Avidity data: 60114 cells
Specificity data: 60114 cells, 3000 genes

Binding counts for 8 antigens
Data splits: {'train': 38472, 'test': 12023, 'val': 9619}
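If `binding_counts` is stored as a plain array, it carries no cell labels. In that case `adata.obsm['binding_counts'] = ...` copies it by row position. Pandas aligns the `obs['set']` assignment by index, but the `obsm` array gets no such protection. Equal cell counts (60114 in both files) do not guarantee equal cell order. The following is a minimal sketch of name-based alignment. It assumes both objects expose unique cell barcodes as `obs_names`, a `pandas.Index`, as AnnData does. `align_to` is a hypothetical helper, not part of TCRfoundation.

```python
import numpy as np
import pandas as pd


def align_to(target_names: pd.Index, source_names: pd.Index, values) -> np.ndarray:
    """Reorder rows of `values` (one row per cell in `source_names`) to follow `target_names`.

    Hypothetical helper: raises if the source has duplicate names or lacks a target cell.
    """
    if not source_names.is_unique:
        raise ValueError("source cell names must be unique")
    pos = source_names.get_indexer(target_names)  # -1 marks target cells absent from source
    if (pos < 0).any():
        raise ValueError(f"{int((pos < 0).sum())} target cells are missing from the source")
    return np.asarray(values)[pos]


# With real AnnData objects this would read (not executed here):
# adata.obsm['binding_counts'] = align_to(adata.obs_names, adata_avidity.obs_names,
#                                         adata_avidity.obsm['binding_counts'])
# adata.obs['set'] = align_to(adata.obs_names, adata_avidity.obs_names,
#                             adata_avidity.obs['set'].to_numpy())

# Toy demo: the source lists the same three cells in a different order.
target = pd.Index(["c1", "c2", "c3"])
source = pd.Index(["c3", "c1", "c2"])
counts = np.array([[3.0, 30.0], [1.0, 10.0], [2.0, 20.0]])
splits = np.array(["test", "train", "val"])

print(align_to(target, source, counts).tolist())  # [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
print(align_to(target, source, splits).tolist())  # ['train', 'val', 'test']
```

`get_indexer` does the label lookup in one vectorised call. Checking for `-1` turns a silent misalignment into an explicit error.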


## 4. Train Avidity Regressor

In [4]:
print("\n=== Training Binding Avidity Regressor ===")

results = train_binding_counts_regressor(
    adata,
    checkpoint_path=checkpoint_path,
    num_epochs=num_epochs,
    batch_size=batch_size,
    return_training_history=False
)

print("\n✓ Training complete")


=== Training Binding Avidity Regressor ===
Output dimension: 8

Loaded model with max_length: 30
Mode rna_only Epoch 1/2: Train Loss = 2868.7943 | Val Loss = 2214.6270 | Val R² = -0.1805
--> Best model saved with Val R² = -0.1805
Mode rna_only Epoch 2/2: Train Loss = 2833.1833 | Val Loss = 2176.2131 | Val R² = -0.1682
--> Best model saved with Val R² = -0.1682

Loaded model with max_length: 30
Mode tcr_only Epoch 1/2: Train Loss = 2810.2117 | Val Loss = 2118.3756 | Val R² = -0.1730
--> Best model saved with Val R² = -0.1730
Mode tcr_only Epoch 2/2: Train Loss = 2724.8559 | Val Loss = 2033.9279 | Val R² = -0.1436
--> Best model saved with Val R² = -0.1436

Loaded model with max_length: 30
Mode tcra_only Epoch 1/2: Train Loss = 2841.1611 | Val Loss = 2168.9397 | Val R² = -0.1930
--> Best model saved with Val R² = -0.1930
Mode tcra_only Epoch 2/2: Train Loss = 2788.4668 | Val Loss = 2117.1266 | Val R² = -0.1762
--> Best model saved with Val R² = -0.1762

Loaded model with max_length: 30
Mode tcrb_only Epoch 1/2: Train Loss = 2847.3957 | Val Loss = 2179.0847 | Val R² = -0.1914
--> Best model saved with Val R² = -0.1914
Mode tcrb_only Epoch 2/2: Train Loss = 2801.2477 | Val Loss = 2130.1765 | Val R² = -0.1722
--> Best model saved with Val R² = -0.1722

Loaded model with max_length: 30
Mode rna_tcr Epoch 1/2: Train Loss = 2800.0972 | Val Loss = 2115.3714 | Val R² = -0.1498
--> Best model saved with Val R² = -0.1498
Mode rna_tcr Epoch 2/2: Train Loss = 2710.2179 | Val Loss = 2033.8319 | Val R² = -0.1123
--> Best model saved with Val R² = -0.1123

✓ Training complete
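The log shows the model being reloaded for each input modality: `rna_only`, `tcr_only`, `tcra_only`, `tcrb_only` and `rna_tcr`. Within each modality, a new best model is reported whenever validation R² improves. The sketch below reproduces that selection rule on the logged values. It assumes each mode keeps its highest-Val-R² epoch. `select_best_epoch` is a hypothetical helper and does not describe the internals of `train_binding_counts_regressor`.

```python
def select_best_epoch(val_r2_history):
    """Return (epoch, val_r2) for the highest validation R²; epochs are 1-based.

    Mirrors a 'save whenever Val R² improves' rule; a real training loop would
    also write model weights whenever a new best is found.
    """
    best_epoch, best_r2 = None, float("-inf")
    for epoch, r2 in enumerate(val_r2_history, start=1):
        if r2 > best_r2:  # strict improvement -> new "best model" checkpoint
            best_epoch, best_r2 = epoch, r2
    return best_epoch, best_r2


# Validation R² per epoch, copied from the training log above.
logged_val_r2 = {
    "rna_only": [-0.1805, -0.1682],
    "tcr_only": [-0.1730, -0.1436],
    "tcra_only": [-0.1930, -0.1762],
    "tcrb_only": [-0.1914, -0.1722],
    "rna_tcr": [-0.1498, -0.1123],
}

best = {mode: select_best_epoch(history) for mode, history in logged_val_r2.items()}
best_mode = max(best, key=lambda mode: best[mode][1])
print(best_mode, best[best_mode])  # rna_tcr (2, -0.1123)
```

In this 2-epoch run, every epoch improved on the previous one, so each epoch triggered a save. `rna_tcr`, which by its name combines RNA and TCR inputs, reached the least negative Val R².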


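Every validation R² in the training log is negative. R² is 1 - SS_res/SS_tot, where SS_tot is the squared error of always predicting each antigen's mean count. A value below 0 therefore means the model is still worse than that constant baseline, which is unsurprising for a 2-epoch demonstration run. The following is a minimal sketch of the four metrics reported in the summary table (R², MSE, MAE, MSLE). It assumes each `avg_*` value is a uniform average of per-antigen metrics over the 8 output columns, in the style of scikit-learn's `multioutput='uniform_average'`. TCRfoundation's actual implementation is not shown in this tutorial and may differ, for example in how MSLE treats negative predictions.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Uniform average over output columns (antigens) of per-column R², MSE, MAE and MSLE.

    y_true, y_pred: arrays of shape (n_cells, n_antigens). Assumes every column of
    y_true varies (otherwise R² divides by zero). Predictions are clipped at 0 for
    MSLE only, since log1p is undefined below -1.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    ss_res = (err ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    r2 = 1.0 - ss_res / ss_tot  # 0 = column-mean baseline, < 0 = worse than baseline
    mse = (err ** 2).mean(axis=0)
    mae = np.abs(err).mean(axis=0)
    msle = ((np.log1p(y_true) - np.log1p(np.clip(y_pred, 0.0, None))) ** 2).mean(axis=0)
    return {"avg_r2": r2.mean(), "avg_mse": mse.mean(),
            "avg_mae": mae.mean(), "avg_msle": msle.mean()}


# Toy counts for 3 cells x 2 antigens.
y_true = np.array([[0.0, 10.0], [2.0, 20.0], [4.0, 30.0]])
mean_baseline = np.tile(y_true.mean(axis=0), (len(y_true), 1))
all_zero = np.zeros_like(y_true)

print(regression_metrics(y_true, mean_baseline)["avg_r2"])  # 0.0
print(regression_metrics(y_true, all_zero)["avg_r2"])       # -3.75
```

Predicting each column's mean gives R² = 0 exactly. Predicting zero for every cell gives R² of -1.5 and -6.0 on the two columns, so the average is -3.75.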
## 5. Performance Summary
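The next cell prints a table `df` that none of the cells shown here create. It presumably comes from the imported `build_regression_results_dataframe` in the cell `In [5]`, which is missing from this export. The following is a hedged sketch of that flattening step. It assumes `results` maps each mode to per-split metric dicts. This tutorial only confirms `results[mode]['test']` with the keys `avg_r2`, `avg_mse` and `avg_mae`; the `train`/`val` entries and the `avg_msle` key are inferred from the printed table. `results_to_frame` is a hypothetical name.

```python
import pandas as pd


def results_to_frame(results, splits=("train", "val", "test")):
    """Flatten {mode: {split: {metric: value}}} into one row per (mode, split).

    The metric keys are assumptions inferred from the printed table;
    a metric missing from `results` is left empty in the frame.
    """
    columns = {"R²": "avg_r2", "MSE": "avg_mse", "MAE": "avg_mae", "MSLE": "avg_msle"}
    rows = []
    for mode, per_split in results.items():
        for split in splits:
            metrics = per_split.get(split, {})
            row = {"Mode": mode, "Split": split}
            row.update({col: metrics.get(key) for col, key in columns.items()})
            rows.append(row)
    return pd.DataFrame(rows, columns=["Mode", "Split", *columns])


# rna_only rows copied from the summary table printed below.
results_demo = {
    "rna_only": {
        "train": {"avg_r2": -0.164948, "avg_mse": 2807.111328, "avg_mae": 12.578028, "avg_msle": 2.224296},
        "val": {"avg_r2": -0.168183, "avg_mse": 2176.213135, "avg_mae": 12.488777, "avg_msle": 2.215775},
        "test": {"avg_r2": -0.165725, "avg_mse": 2006.111572, "avg_mae": 12.432177, "avg_msle": 2.236494},
    }
}
df_demo = results_to_frame(results_demo)
print(df_demo)
```

A long-format table with one row per (mode, split) is easy to filter, for example with `df_demo[df_demo["Split"] == "test"]`, or to pivot for plotting.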

In [6]:
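# `df` is not created in the cells shown (In [5] is missing from this export);
# building it from `results` with the imported helper is an assumed call signature.
df = build_regression_results_dataframe(results)

print("\n" + "="*60)
print("Regression Performance by Modality")
print("="*60)
print(df)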

print("\n" + "="*60)
print("Test Set Performance")
print("="*60)
for mode in results.keys():
    test_metrics = results[mode]['test']
    r2 = test_metrics['avg_r2']
    mse = test_metrics['avg_mse']
    mae = test_metrics['avg_mae']
    print(f"{mode:15s}: R²={r2:.3f}, MSE={mse:.4f}, MAE={mae:.4f}")


Regression Performance by Modality
         Mode  Split        R²          MSE        MAE      MSLE
0    rna_only  train -0.164948  2807.111328  12.578028  2.224296
1    rna_only    val -0.168183  2176.213135  12.488777  2.215775
2    rna_only   test -0.165725  2006.111572  12.432177  2.236494
3    tcr_only  train -0.143940  2666.379150  12.329477  2.273770
4    tcr_only    val -0.143627  2033.927979  12.240131  2.261877
5    tcr_only   test -0.141133  1874.746582  12.190807  2.291583
6   tcra_only  train -0.174499  2751.299072  12.522174  2.580707
7   tcra_only    val -0.176214  2117.126465  12.430248  2.561954
8   tcra_only   test -0.174020  1952.712524  12.381844  2.603213
9   tcrb_only  train -0.171604  2762.735596  12.452492  2.475381
10  tcrb_only    val -0.172159  2130.176514  12.357960  2.460135
11  tcrb_only   test -0.169991  1964.404053  12.308035  2.493986
12    rna_tcr  train -0.111450  2665.060791  12.147854  1.886841
13    rna_tcr    val -0.112304  2033.831909  12.066460