# DyAb Model: Training and Inference Tutorial
# ====================================
#
# This notebook demonstrates how to use the DyAbDataFrameLightningDataModule with the
# DyAbModel for training and performing inference on antibody data.

In [None]:
import pandas as pd
import numpy as np
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

# Import the necessary lobster modules
from lobster.data import DyAbDataFrameLightningDataModule
from lobster.tokenization import AminoAcidTokenizerFast
from lobster.transforms import TokenizerTransform

# For reproducibility
SEED = 42
pl.seed_everything(SEED, workers=True)

# Fall back to the CPU when no GPU is available so that later `.to(device)`
# calls do not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Create a Sample Dataset

For this tutorial, we'll create a synthetic dataset to demonstrate the functionality.
In real applications, you would use your own antibody data.

The DyAb model is trained to predict the difference between properties (y1 - y2),
so we need to create our data accordingly. DyAb is also designed to be trained on
pairs of sequences that are one edit apart. We can ignore this constraint in our toy
example, but we will make all sequences the same length to ensure we train a reasonable model.

In [None]:
def generate_sample_data(num_samples=500, seed=42):
    """Generate synthetic antibody data for demonstration with DyAb model.

    One heavy-chain length and one light-chain length are drawn per call, so all
    heavy chains in the returned frame share a length, as do all light chains.
    The synthetic pKD grows with the number of 'A' residues in the heavy chain
    and 'M' residues in the light chain, plus Gaussian noise.

    Args:
        num_samples: Number of antibodies (rows) to generate.
        seed: Seed of the private random generator. The same seed always
            yields the same data, including the chain lengths.

    Returns:
        DataFrame with the columns fv_heavy, fv_light and pKD.
    """
    # A private legacy generator yields the same stream as
    # `np.random.seed(seed)` followed by `np.random.*` calls, but leaves
    # NumPy's global random state (set by `pl.seed_everything`) untouched.
    rng = np.random.RandomState(seed)

    # Create random heavy and light chain sequences
    amino_acids = list("ARNDCEQGHILKMFPSTWYV")

    heavy_length = rng.randint(110, 130, 1)[0]
    light_length = rng.randint(105, 115, 1)[0]

    heavy_chains = []
    light_chains = []
    pkd_values = []

    for _ in range(num_samples):
        heavy = ''.join(rng.choice(amino_acids, heavy_length))
        light = ''.join(rng.choice(amino_acids, light_length))
        heavy_chains.append(heavy)
        light_chains.append(light)

        # Create synthetic pKD values (binding affinity)
        # For this toy example, we'll make more A's in the hc and M's in the lc
        # mean high affinity
        noise = rng.normal(0, 0.5, 1)[0]
        pkd_values.append((heavy.count('A') + light.count('M')) / 2 + noise)

    # Create a DataFrame
    return pd.DataFrame({
        'fv_heavy': heavy_chains,
        'fv_light': light_chains,
        'pKD': pkd_values
    })

# Build the 500-row toy dataset used by the rest of the tutorial
df = generate_sample_data(num_samples=500)

## 2. Set Up the DyAbDataFrameLightningDataModule

The DyAb model expects data in the form of paired sequences with their corresponding
target values, where during training it learns to predict the differences between pairs.

Let's set up the data module:

In [None]:
# Tokenizer transform: every sequence is padded/truncated to 256 tokens
tokenizer = AminoAcidTokenizerFast()
transform_fn = TokenizerTransform(
    tokenizer=tokenizer, padding="max_length", max_length=256, truncation=True, return_attention_mask=True
)

# DyAb datamodule over the toy DataFrame
dyab_datamodule = DyAbDataFrameLightningDataModule(
    data=df,
    transform_fn=transform_fn,
    remove_nulls=True,
    lengths=[0.8, 0.1, 0.1],  # Train, val, test split
    seed=SEED,
    batch_size=16,
    num_workers=4,
    max_length=256,
)

# Prepare the data and build the splits needed for fitting
dyab_datamodule.prepare_data()
dyab_datamodule.setup(stage="fit")

# Peek at the first training batch to see what the dataloader yields
train_dataloader = dyab_datamodule.train_dataloader()
batch = next(iter(train_dataloader), None)
if batch is not None:
    sequence1, sequence2, target1, target2 = batch
    shapes = {
        "Sequence1": np.array(sequence1[0]).shape,  # First element of the tuple for heavy chain
        "Sequence2": np.array(sequence2[0]).shape,  # First element of the tuple for heavy chain
        "Target1": target1.shape,
        "Target2": target2.shape,
    }
    for label, shape in shapes.items():
        print(f"{label} shape:", shape)
    mean_difference = (target1 - target2).mean().item()
    print("Target difference (what model will predict):", mean_difference)

## 3. Using the DyAb Model

Now we'll initialize and train the DyAb model. We'll use the FlexBERT architecture
which is designed to handle protein sequences.

In [None]:
# Import the DyAb model
from lobster.model import DyAbModel

# Initialize the model
dyab_model = DyAbModel(
    model_name="esm2_t6_8M_UR50D",
    embedding_img_size=224,
    diff_channel_0="diff",
    diff_channel_1="diff",
    diff_channel_2="diff"
)

# Define callbacks for training
callbacks = [
    ModelCheckpoint(
        dirpath='checkpoints/',
        # The monitored metric is logged under the key 'val/loss'. The old
        # '{val_loss}' placeholder did not match that key, so the number in the
        # file name would not be the monitored loss. Because the key contains
        # a '/', the metric label is written by hand and automatic name
        # insertion is disabled (otherwise the '/' would end up in the path).
        filename='dyab-epoch={epoch:02d}-val_loss={val/loss:.4f}',
        auto_insert_metric_name=False,
        save_top_k=3,
        monitor='val/loss',
        mode='min'
    ),
    # Stop once 'val/loss' has not improved for 10 validation checks
    EarlyStopping(
        monitor='val/loss',
        patience=10,
        mode='min'
    )
]

# Initialize the trainer
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=callbacks,
    accelerator='auto',
    devices=1,
    log_every_n_steps=10
)

## 4. Train the Model

Now let's train the DyAb model using our datamodule.

In [None]:
# Fit DyAb on the splits provided by the datamodule
trainer.fit(model=dyab_model, datamodule=dyab_datamodule)

## 5. Evaluate the Model

Let's evaluate our model on the test set.

In [None]:
# Run the test loop with the datamodule and report the returned metrics
test_results = trainer.test(model=dyab_model, datamodule=dyab_datamodule)
print(f"Test results: {test_results}")

In [None]:
# Visualize embeddings
# Show only the first entry of the model's embedding cache.
for emb in dyab_model.embedding_cache:
    cached = dyab_model.embedding_cache[emb]
    print(cached.shape)
    # imshow needs a 2-D array, so a 3-D tensor is reduced to its first slice
    plt.imshow(cached[0].cpu() if cached.dim() == 3 else cached.cpu())
    plt.title('antiberty_emb')
    plt.show()
    break

# Build and display the embedding-difference image for the first training batch.
for batch in dyab_datamodule.train_dataloader():
    sequences1, sequences2, _, _ = batch

    if len(sequences1) == 2:  # concat multiple chains
        sequences1 = [s1 + "." + s2 for s1, s2 in zip(*sequences1)]
    if len(sequences2) == 2:  # concat multiple chains
        sequences2 = [s1 + "." + s2 for s1, s2 in zip(*sequences2)]

    # Embed every sequence of the batch that is not cached yet. The tensors are
    # placed on the model's own device rather than on a hard-coded one.
    for seq in (*sequences1, *sequences2):
        if seq not in dyab_model.embedding_cache:
            with torch.inference_mode():
                hidden_states = dyab_model.model.sequences_to_latents([seq])[-2].to(dyab_model.device).float()
            dyab_model.embedding_cache[seq] = hidden_states

    # Concatenate only once all sequences are cached; doing this inside the
    # caching loop repeats the work and can hit sequences that are not cached yet.
    embeddings1 = torch.concat([dyab_model.embedding_cache[seq] for seq in sequences1], dim=0).to(dyab_model.device)
    embeddings2 = torch.concat([dyab_model.embedding_cache[seq] for seq in sequences2], dim=0).to(dyab_model.device)

    embedding_image = dyab_model._resize_embeddings(embeddings1, embeddings2)

    # Move the channel axis last, which is the layout imshow expects
    numpy_img = embedding_image[0].cpu().numpy()
    numpy_img = np.transpose(numpy_img, (1, 2, 0))

    plt.imshow(numpy_img)
    plt.title("embedding_diff")
    plt.show()

    break
    

## Visualize an embedding

DyAb uses embeddings generated by another protein language model to predict the difference between two
sequences. Let's visualize one of these embeddings.

## 6. Perform Inference on New Sequences

Now let's demonstrate how to use the trained model for inference on new antibody sequences.

In [None]:
def predict_binding_differences(model, new_data):
    """
    Predict binding affinity differences for pairs of antibody sequences.

    Builds a DyAbDataFrameLightningDataModule over ``new_data`` (using the
    module-level ``transform_fn`` and ``SEED``), runs ``model.predict_step`` on
    every batch of its predict dataloader and collects the predictions next to
    the true differences ``target1 - target2``.

    Args:
        model: Trained DyAb model
        new_data: DataFrame with fv_heavy, fv_light, pKD columns

    Returns:
        DataFrame with the columns Predicted_Difference and Actual_Difference,
        one row per pair yielded by the predict dataloader
    """
    # Create a datamodule for inference
    inference_datamodule = DyAbDataFrameLightningDataModule(
        data=new_data,
        remove_nulls=True,
        transform_fn=transform_fn,
        lengths=[0, 0, 1],  # All data for inference
        batch_size=16,
        seed=SEED,
        num_workers=4,
        max_length=256
    )

    # Set up the datamodule
    inference_datamodule.prepare_data()
    inference_datamodule.setup(stage="predict")

    # Perform inference
    model.eval()
    predictions = []
    actual_diffs = []

    with torch.no_grad():
        for idx, batch in enumerate(inference_datamodule.predict_dataloader()):
            _, _, target1, target2 = batch

            preds, _ = model.predict_step(batch, idx)

            # Flatten so that e.g. a (batch, 1) output still contributes one
            # scalar per pair instead of one array per pair.
            predictions.extend(preds.detach().cpu().numpy().reshape(-1))
            actual_diffs.extend((target1 - target2).detach().cpu().numpy().reshape(-1))

    # Create results DataFrame
    return pd.DataFrame({
        "Predicted_Difference": predictions,
        "Actual_Difference": actual_diffs,
    })

# Generate some new sequences for inference.
# generate_sample_data defaults to seed=42, the seed the training data was built
# with, so calling it without a seed would just reproduce the first 100 training
# rows. A different seed gives genuinely unseen sequences (the chain lengths are
# drawn from the seed as well, so they may differ from the training set).
new_data = generate_sample_data(100, seed=SEED + 1)

# Predict binding affinity differences
results_df = predict_binding_differences(dyab_model, new_data)

# Display results
print(results_df.head(10))

## 7. Visualize Predictions vs. Actual Differences

Let's visualize how well our model's predictions match the actual differences.

In [None]:
# Scatter the predicted differences against the actual ones
actual = results_df["Actual_Difference"]
predicted = results_df["Predicted_Difference"]

# End points of the y = x reference line, spanning the range of the actual values
diagonal = [min(actual), max(actual)]

plt.figure(figsize=(8, 6))
plt.scatter(actual, predicted, alpha=0.7)
plt.plot(diagonal, diagonal, 'r--')
plt.xlabel('Actual Difference')
plt.ylabel('Predicted Difference')
plt.title('Predicted vs. Actual Affinity Differences')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Coefficient of determination between actual and predicted differences
r2 = r2_score(actual, predicted)
print(f"R² score: {r2:.4f}")

## 8. Using the Model for Ranking Antibodies

One valuable application of the DyAb model is to rank antibodies based on predicted affinities.
Let's demonstrate how to use the model for this purpose.

In [None]:
def rank_antibodies(model, antibody_pool, reference_antibody):
    """
    Rank a pool of antibodies against a reference antibody based on predicted binding affinity.

    The reference and each candidate are written as alternating rows of one
    DataFrame, which is fed through a DyAbDataFrameLightningDataModule (built
    with the module-level ``transform_fn`` and ``SEED``). Every second
    prediction is then taken as the reference-vs-candidate prediction.

    Args:
        model: Trained DyAb model
        antibody_pool: DataFrame with fv_heavy, fv_light, pKD columns for candidate antibodies
        reference_antibody: Dict with fv_heavy, fv_light for the reference antibody

    Returns:
        DataFrame with the columns rank, candidate_heavy, candidate_light,
        candidate_pKD and predicted_diff_ref_candidate, sorted by descending
        predicted difference
    """
    # Create pairs of reference and candidate antibodies (one row per candidate)
    pairs_df = pd.DataFrame([
        {
            "fv_heavy": reference_antibody["fv_heavy"],
            "fv_light": reference_antibody["fv_light"],
            "pKD": 0.0,  # Dummy value for reference
            "candidate_heavy": candidate["fv_heavy"],
            "candidate_light": candidate["fv_light"],
            "candidate_pKD": candidate["pKD"]
        }
        for _, candidate in antibody_pool.iterrows()
    ])

    # Create a combined dataset for DyAb format: for every pair, the reference
    # row is immediately followed by its candidate row.
    combined_data = []

    for _, row in pairs_df.iterrows():
        # Reference row
        combined_data.append({
            "fv_heavy": row["fv_heavy"],
            "fv_light": row["fv_light"],
            "pKD": row["pKD"]
        })

        # Candidate row
        combined_data.append({
            "fv_heavy": row["candidate_heavy"],
            "fv_light": row["candidate_light"],
            "pKD": row["candidate_pKD"]
        })

    combined_df = pd.DataFrame(combined_data)

    # Create a datamodule
    inference_datamodule = DyAbDataFrameLightningDataModule(
        data=combined_df,
        remove_nulls=True,
        transform_fn=transform_fn,
        lengths=[0, 0, 1],  # All data for inference
        batch_size=16,
        seed=SEED,
        num_workers=4,
        max_length=256
    )

    # Set up the datamodule
    inference_datamodule.prepare_data()
    inference_datamodule.setup(stage="predict")

    # Perform inference
    model.eval()
    all_preds = []

    with torch.no_grad():
        for idx, batch in enumerate(inference_datamodule.predict_dataloader()):
            # Use the model passed in (not the notebook-global one) and the same
            # predict_step call as predict_binding_differences.
            preds, _ = model.predict_step(batch, idx)

            # Flatten so every prediction is a scalar
            all_preds.extend(preds.detach().cpu().numpy().reshape(-1))

    # Process predictions (every second prediction is for ref-candidate pair)
    # NOTE(review): this assumes the predict dataloader yields pairs in the row
    # order of combined_df, with the reference-candidate pair at even positions.
    # Confirm against how the datamodule forms its pairs.
    ref_candidate_preds = all_preds[::2]

    # Add predictions to the original dataframe
    for i, pred in enumerate(ref_candidate_preds):
        pairs_df.loc[i, "predicted_diff_ref_candidate"] = pred

    # Sort by prediction (higher predicted difference means candidate is better than reference)
    # NOTE(review): the tutorial elsewhere defines the difference as
    # target1 - target2; confirm that a higher value here really favours the
    # candidate rather than the reference.
    ranked_df = pairs_df.sort_values(by="predicted_diff_ref_candidate", ascending=False)

    # Add rank column
    ranked_df["rank"] = range(1, len(ranked_df) + 1)

    return ranked_df[["rank", "candidate_heavy", "candidate_light", "candidate_pKD", "predicted_diff_ref_candidate"]]

# Generate a pool of candidate antibodies.
# A seed other than the default is required: generate_sample_data defaults to
# seed=42, which is also what the training DataFrame `df` was built with, so the
# pool would otherwise repeat the first 100 training rows and its first
# candidate would be the reference antibody itself.
candidate_pool = generate_sample_data(100, seed=SEED + 2)

# Select a reference antibody
reference_antibody = {
    "fv_heavy": df.iloc[0]["fv_heavy"],
    "fv_light": df.iloc[0]["fv_light"]
}

# Rank antibodies
ranked_candidates = rank_antibodies(dyab_model, candidate_pool, reference_antibody)

# Display top 10 ranked antibodies
print("Top 10 Ranked Antibodies:")
print(ranked_candidates.head(10))

## 9. Saving and Loading the Model

Demonstrating how to save and load a trained DyAb model.

In [None]:
# Save the model
model_path = "dyab_model.ckpt"
trainer.save_checkpoint(model_path)

# Load the model
loaded_model = DyAbModel.load_from_checkpoint(model_path)

# Verify the loaded model.
# Both models are switched to eval mode so the comparison is not affected by
# train-time behaviour (e.g. dropout, if the model uses any).
dyab_model.eval()
loaded_model.eval()

# Only the first validation batch is compared.
for batch in dyab_datamodule.val_dataloader():
    # Make predictions with both models; inference_mode already disables
    # gradient tracking, so no additional no_grad context is needed.
    with torch.inference_mode():
        original_output = dyab_model._compute_loss(batch)
        loaded_output = loaded_model._compute_loss(batch)
    assert original_output is not None
    assert loaded_output is not None
    _, original_pred, _ = original_output
    _, loaded_pred, _ = loaded_output

    # The two models may live on different devices; compare on the CPU.
    original_pred = original_pred.detach().cpu()
    loaded_pred = loaded_pred.detach().cpu()

    # Verify they give the same predictions
    print("Original model prediction mean:", original_pred.mean().item())
    print("Loaded model prediction mean:", loaded_pred.mean().item())

    # Check if predictions are identical
    is_identical = torch.allclose(original_pred, loaded_pred, atol=1e-5)
    print(f"Models give identical predictions: {is_identical}")
    break

## 10. Conclusion

In this tutorial, we've demonstrated how to:

1. Prepare antibody data for the DyAb model
2. Train a DyAb model using DyAbDataFrameLightningDataModule
3. Evaluate the model's performance
4. Use the model for inference and ranking of antibodies
5. Save and load the model

The DyAb model is particularly useful for learning pairwise relationships between sequences,
which makes it valuable for tasks like antibody optimization, where you want to predict if
a modification to a sequence will improve its properties.