# TP2 - Part 3: Classification with Embeddings

**Day 2 - AI for Sciences Winter School**

**Instructor:** Raphael Cousin

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/racousin/ai_for_sciences/blob/main/day2/tp2_part3.ipynb)

---

## The Task: Predicting Blood-Brain Barrier Permeability

The **blood-brain barrier (BBB)** is a selective barrier that protects the brain from toxins in the bloodstream. For drug development, it's crucial to know whether a molecule can cross this barrier:

- **BBB-permeable**: Drug can enter the brain (needed for CNS drugs)
- **BBB-impermeable**: Drug cannot enter (needed for peripherally-acting drugs)

**Our Goal**: Build a classifier that predicts BBB permeability using molecular embeddings.

## Objectives

1. Understand how to **use embeddings as features** for classification
2. Train a simple **MLP classifier** on molecular embeddings
3. Evaluate model performance and understand what the model learns
4. Experiment with different configurations

---

# Part 1: The Classification Pipeline

The key insight: **Embeddings transform the problem.**

Instead of teaching a model to understand molecular structure from scratch, we:

1. Use a **pre-trained model** (ChemBERTa) to convert SMILES → embeddings
2. Train a **simple classifier** on the embeddings

```
┌─────────────┐    ┌────────────────┐    ┌─────────────┐    ┌──────────────┐
│   SMILES    │ →  │   ChemBERTa    │ →  │  768-dim    │ →  │   Classifier │
│   string    │    │  (frozen)      │    │  embedding  │    │   (trained)  │
└─────────────┘    └────────────────┘    └─────────────┘    └──────────────┘
                    Pre-trained                              BBB_permeable
                    knowledge                                or BBB_impermeable
```

This is called **transfer learning**: leveraging knowledge learned from one task (molecular language modeling) for another task (BBB prediction).
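To make "frozen encoder, trained classifier" concrete, here is a minimal sketch. It assumes a tiny random `encoder` as a stand-in for the pre-trained model (it is not ChemBERTa), and the names `encoder` and `head` are hypothetical. Only the head's parameters are given to the optimizer, so the encoder weights stay exactly as they were.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pre-trained encoder (assumption: a real one would be loaded from disk).
encoder = nn.Sequential(nn.Linear(16, 8), nn.Tanh())
for p in encoder.parameters():
    p.requires_grad_(False)  # freeze: no gradients are computed for these weights
encoder.eval()

# Small trainable head on top of the frozen features.
head = nn.Linear(8, 1)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

# Toy inputs and binary labels, for illustration only.
x = torch.randn(32, 16)
y = torch.randint(0, 2, (32, 1)).float()

encoder_before = [p.clone() for p in encoder.parameters()]

for _ in range(5):
    with torch.no_grad():
        features = encoder(x)  # fixed features, computed without tracking gradients
    loss = loss_fn(head(features), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

encoder_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(encoder_before, encoder.parameters())
)
n_frozen = sum(p.numel() for p in encoder.parameters())
n_trainable = sum(p.numel() for p in head.parameters() if p.requires_grad)
print(f"Encoder unchanged: {encoder_unchanged}")
print(f"Frozen parameters: {n_frozen}, trainable parameters: {n_trainable}")
```

Because the encoder never changes, its outputs can be computed once and reused for every training epoch, which is the approach taken in the rest of this notebook.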

## Setup

In [None]:
# Install packages
!pip install -q transformers torch pandas matplotlib scikit-learn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Check GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print("Setup complete!")

---

# Part 2: Preparing the Data

## Load the BBBP Dataset

In [None]:
# Load dataset
url = "https://raw.githubusercontent.com/racousin/ai_for_sciences/main/day2/data/molecules_bbbp.csv"
df = pd.read_csv(url)

print(f"Dataset shape: {df.shape}")
print(f"\nClass distribution:")
print(df['label_name'].value_counts())
print(f"\nFraction of permeable molecules: {df['label'].mean():.2%}")

### 🤔 Question 1

Look at the class distribution:

1. Are the classes balanced?
2. What would be the accuracy of a model that always predicts "permeable"?
3. Why is this baseline important to know?
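As a reference for how such a baseline can be computed, here is a small sketch on made-up labels (the list `toy_labels` and the helper `majority_baseline_accuracy` are hypothetical, not part of the dataset or the notebook). The same idea applies to `df['label']`.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Return the most frequent class and the accuracy of always predicting it."""
    counts = Counter(labels)
    majority_class, majority_count = counts.most_common(1)[0]
    return majority_class, majority_count / len(labels)


# Toy labels for illustration only: 7 positives, 3 negatives.
toy_labels = [1, 1, 1, 0, 1, 0, 1, 1, 0, 1]
majority_class, baseline_acc = majority_baseline_accuracy(toy_labels)
print(f"Always predicting {majority_class} gives {baseline_acc:.0%} accuracy")
```

Any trained model should be compared against this number, since a classifier that scores below it has learned nothing useful about the minority class.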

## Compute Molecular Embeddings

We'll use ChemBERTa to convert SMILES strings to embeddings.

**Note**: Computing embeddings for the full dataset takes a while, especially on CPU. In practice, you would compute them once, save them to disk, and reload them in later sessions.
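One possible way to cache embeddings is sketched below. The helper `load_or_compute`, the file name, and the stand-in `fake_compute` function are all hypothetical; the sketch only assumes the embeddings are a NumPy array.

```python
import os
import tempfile

import numpy as np


def load_or_compute(cache_path, compute_fn):
    """Load an array from cache_path if it exists; otherwise compute and save it."""
    if os.path.exists(cache_path):
        return np.load(cache_path)
    array = compute_fn()
    np.save(cache_path, array)
    return array


# Stand-in for an expensive embedding computation; counts how often it runs.
calls = {"n": 0}


def fake_compute():
    calls["n"] += 1
    return np.arange(12, dtype=np.float32).reshape(3, 4)


cache_path = os.path.join(tempfile.mkdtemp(), "embeddings_demo.npy")
first = load_or_compute(cache_path, fake_compute)   # computes and writes the file
second = load_or_compute(cache_path, fake_compute)  # reads the file, no recomputation
print(f"compute_fn was called {calls['n']} time(s)")
```

In this notebook, `compute_fn` could be something like `lambda: compute_embeddings(df['SMILES'].tolist())`. If the dataset or the embedding model changes, the cache file needs to be deleted or renamed, since this sketch does not detect stale caches.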

In [None]:
from transformers import AutoTokenizer, AutoModel

# Load ChemBERTa
print("Loading ChemBERTa...")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model.eval()
model.to(device)

print(f"Model loaded on {device}")
print(f"Embedding dimension: {model.config.hidden_size}")

In [None]:
def compute_embeddings(smiles_list, batch_size=32):
    """Compute embeddings for a list of SMILES strings."""
    embeddings = []
    
    for i in range(0, len(smiles_list), batch_size):
        batch = smiles_list[i:i+batch_size]
        
        # Tokenize
        inputs = tokenizer(batch, return_tensors="pt", 
                          padding=True, truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Get embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            # Use the first-token embedding (RoBERTa's <s>, which plays the role of BERT's [CLS])
            batch_emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(batch_emb)
        
        if (i // batch_size) % 10 == 0:
            print(f"  Processed {min(i+batch_size, len(smiles_list))}/{len(smiles_list)} molecules")
    
    return np.vstack(embeddings)

# Compute embeddings for all molecules
print("Computing embeddings...")
X = compute_embeddings(df['SMILES'].tolist())
y = df['label'].values

print(f"\nEmbedding matrix shape: {X.shape}")
print(f"Labels shape: {y.shape}")
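The cell above keeps only the first token's vector for each molecule. Another common pooling choice, which this notebook does not use, is to average all token vectors while ignoring padding positions. The sketch below illustrates that alternative on a tiny hand-made array; `masked_mean_pool` is a hypothetical helper, and the inputs stand in for `last_hidden_state` and `attention_mask`.

```python
import numpy as np


def masked_mean_pool(hidden, mask):
    """Average token vectors over the sequence axis, skipping padded positions.

    hidden: (batch, seq_len, dim) token embeddings
    mask:   (batch, seq_len) with 1 for real tokens and 0 for padding
    """
    mask = mask[:, :, None].astype(hidden.dtype)
    summed = (hidden * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # avoid division by zero
    return summed / counts


# Two toy "molecules" with 3 token positions and 2 dimensions each.
# The last position of the first molecule is padding and must be ignored.
hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],
])
mask = np.array([[1, 1, 0], [1, 1, 1]])

pooled = masked_mean_pool(hidden, mask)
first_token = hidden[:, 0, :]  # what first-token pooling would return
print("Mean-pooled:\n", pooled)
print("First-token:\n", first_token)
```

Which pooling works better depends on the model and the task, so it is worth treating as one more setting to experiment with rather than a fixed rule.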

## Split Data into Train and Test Sets

In [None]:
# Split data: 80% train, 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")
print(f"\nTrain class distribution: {y_train.mean():.2%} permeable")
print(f"Test class distribution: {y_test.mean():.2%} permeable")

---

# Part 3: Building the Classifier

We'll build a simple **Multi-Layer Perceptron (MLP)**:

```
Input (768) → Hidden (256) → ReLU → Hidden (64) → ReLU → Output (1) → Sigmoid
```

This is a small network that learns to map embeddings to BBB permeability.
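To see how small, the parameter count can be worked out by hand. The sketch below assumes the layer sizes from the diagram (768, 256, 64, 1) and uses the fact that a fully connected layer has one weight per input-output pair plus one bias per output. The helper name `linear_params` is hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights (n_in * n_out) plus biases (n_out) of a fully connected layer."""
    return n_in * n_out + n_out


layer_sizes = [768, 256, 64, 1]
per_layer = [
    linear_params(n_in, n_out)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]
total = sum(per_layer)

for (n_in, n_out), count in zip(zip(layer_sizes, layer_sizes[1:]), per_layer):
    print(f"Linear({n_in}, {n_out}): {count:,} parameters")
print(f"Total: {total:,} parameters")  # 196,864 + 16,448 + 65 = 213,377
```

ReLU, Dropout, and Sigmoid have no learnable parameters, so almost all of the capacity sits in the first layer, whose size is driven by the 768-dimensional input.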

In [None]:
class BBBClassifier(nn.Module):
    """Simple MLP classifier for BBB permeability prediction."""
    
    def __init__(self, input_dim, hidden_dim=256, dropout=0.2):
        super().__init__()
        
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 4, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.network(x)

# Create model
input_dim = X_train.shape[1]  # 768
classifier = BBBClassifier(input_dim=input_dim, hidden_dim=256, dropout=0.2)
classifier.to(device)

# Count parameters
n_params = sum(p.numel() for p in classifier.parameters())
print(f"Model architecture:")
print(classifier)
print(f"\nTotal parameters: {n_params:,}")

### 🤔 Question 2

Look at the model architecture:

1. Why do we use ReLU activation after hidden layers but Sigmoid at the output?
2. What does Dropout do? Why might it help?
3. The model has ~213K parameters. ChemBERTa has ~85M. What does this tell us about transfer learning?

## Prepare Data for Training

In [None]:
# Convert to PyTorch tensors
X_train_t = torch.FloatTensor(X_train)
y_train_t = torch.FloatTensor(y_train).unsqueeze(1)
X_test_t = torch.FloatTensor(X_test)
y_test_t = torch.FloatTensor(y_test).unsqueeze(1)

# Create data loaders
batch_size = 64  # <-- You can experiment with this!

train_dataset = TensorDataset(X_train_t, y_train_t)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(X_test_t, y_test_t)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

---

# Part 4: Training the Classifier

The training loop:
1. **Forward pass**: Compute predictions
2. **Compute loss**: Binary Cross-Entropy
3. **Backward pass**: Compute gradients
4. **Update weights**: Optimizer step
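Binary cross-entropy for a predicted probability p and a label y in {0, 1} is -(y·log(p) + (1 - y)·log(1 - p)), averaged over the batch. The sketch below works through it in plain Python on made-up predictions. The clipping with `eps` is an assumption added here for numerical safety and is not how `nn.BCELoss` handles extreme values internally.

```python
import math


def binary_cross_entropy(probs, labels, eps=1e-7):
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1 - eps)  # keep log() away from 0
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


labels = [1, 0]
confident_right = binary_cross_entropy([0.9, 0.1], labels)
unsure = binary_cross_entropy([0.5, 0.5], labels)
confident_wrong = binary_cross_entropy([0.1, 0.9], labels)

print(f"Confident and right: {confident_right:.4f}")
print(f"Unsure (p = 0.5):    {unsure:.4f}")
print(f"Confident and wrong: {confident_wrong:.4f}")
```

The loss grows sharply for confident mistakes, which is why the gradients in the backward pass push hardest on the examples the model gets most wrong.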

In [None]:
# Training configuration
learning_rate = 0.001  # <-- You can experiment with this!
n_epochs = 50          # <-- You can experiment with this!

# Loss and optimizer
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

print(f"Training configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {n_epochs}")
print(f"  Batch size: {batch_size}")

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        # Forward pass
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)


def evaluate(model, loader, criterion, device):
    """Evaluate on a dataset."""
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []
    
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            total_loss += loss.item()
            
            # Store predictions
            all_preds.extend((y_pred > 0.5).cpu().numpy().flatten())
            all_labels.extend(y_batch.cpu().numpy().flatten())
    
    accuracy = accuracy_score(all_labels, all_preds)
    return total_loss / len(loader), accuracy

In [None]:
# Training loop
train_losses, test_losses = [], []
train_accs, test_accs = [], []

print("Training...\n")
print(f"{'Epoch':>6} {'Train Loss':>12} {'Test Loss':>12} {'Train Acc':>12} {'Test Acc':>12}")
print("-" * 60)

for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(classifier, train_loader, criterion, optimizer, device)
    
    # Evaluate
    train_loss_eval, train_acc = evaluate(classifier, train_loader, criterion, device)
    test_loss, test_acc = evaluate(classifier, test_loader, criterion, device)
    
    # Store metrics
    train_losses.append(train_loss_eval)
    test_losses.append(test_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    
    # Print progress
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"{epoch+1:>6} {train_loss_eval:>12.4f} {test_loss:>12.4f} {train_acc:>12.2%} {test_acc:>12.2%}")

print("\nTraining complete!")

## Visualize Training Progress

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
axes[0].plot(train_losses, label='Train', linewidth=2)
axes[0].plot(test_losses, label='Test', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=11)
axes[0].set_ylabel('Loss', fontsize=11)
axes[0].set_title('Training and Test Loss', fontsize=12, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(train_accs, label='Train', linewidth=2)
axes[1].plot(test_accs, label='Test', linewidth=2)
axes[1].axhline(y=y_train.mean(), color='gray', linestyle='--', 
                label=f'Baseline ({y_train.mean():.2%})', alpha=0.7)
axes[1].set_xlabel('Epoch', fontsize=11)
axes[1].set_ylabel('Accuracy', fontsize=11)
axes[1].set_title('Training and Test Accuracy', fontsize=12, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### 🤔 Question 3

Look at the training curves:

1. Is the model overfitting? (Hint: Compare train vs test curves)
2. Has the model converged?
3. Is the model better than the baseline (always predicting "permeable")?

---

# Part 5: Evaluating the Model

Accuracy alone doesn't tell the full story. Let's look at more detailed metrics.

In [None]:
# Get final predictions
classifier.eval()
with torch.no_grad():
    y_pred_proba = classifier(X_test_t.to(device)).cpu().numpy().flatten()
    y_pred = (y_pred_proba > 0.5).astype(int)

# Classification report
print("Classification Report:")
print("=" * 55)
print(classification_report(y_test, y_pred, 
                            target_names=['BBB_impermeable', 'BBB_permeable']))

In [None]:
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(cm, cmap='Blues')

# Labels
labels = ['Impermeable', 'Permeable']
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

# Add values
for i in range(2):
    for j in range(2):
        ax.text(j, i, f'{cm[i,j]}', ha='center', va='center', fontsize=14,
               color='white' if cm[i,j] > cm.max()/2 else 'black')

ax.set_xlabel('Predicted', fontsize=11)
ax.set_ylabel('Actual', fontsize=11)
ax.set_title('Confusion Matrix', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

# Interpret
print(f"\nInterpretation:")
print(f"  True Negatives (correctly predicted impermeable): {cm[0,0]}")
print(f"  False Positives (incorrectly predicted permeable): {cm[0,1]}")
print(f"  False Negatives (incorrectly predicted impermeable): {cm[1,0]}")
print(f"  True Positives (correctly predicted permeable): {cm[1,1]}")
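The four counts above are enough to derive the metrics shown in the classification report. The sketch below uses made-up counts (not results from this model) and a hypothetical helper `metrics_from_confusion`, with "permeable" as the positive class.

```python
def metrics_from_confusion(tn, fp, fn, tp):
    """Derive common binary-classification metrics from confusion-matrix counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
    }


# Illustrative counts only: an imbalanced test set with 150 positives and 50 negatives.
metrics = metrics_from_confusion(tn=30, fp=20, fn=10, tp=140)
for name, value in metrics.items():
    print(f"{name:>12}: {value:.3f}")
```

With these toy counts, accuracy and recall look healthy while specificity is much lower, which shows how a single accuracy figure can hide weak performance on the minority class.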

### 🤔 Question 4

Look at the confusion matrix:

1. Which type of error is more common: false positives or false negatives?
2. In drug discovery, which error might be more costly?
   - False positive: Predict drug enters brain when it doesn't
   - False negative: Predict drug doesn't enter brain when it does
3. How might you adjust the model to reduce one type of error?

---

# Part 6: Experiments

Now it's your turn to experiment! Try modifying the parameters and see how performance changes.

## Exercise 1: Change the Architecture

What happens if you make the hidden layers wider or narrower, or change the dropout rate?

In [None]:
# TODO: Experiment with different architectures
# Try changing:
# - hidden_dim (64, 128, 256, 512)
# - dropout (0.0, 0.1, 0.3, 0.5)

# Example: Create a smaller network
small_classifier = BBBClassifier(
    input_dim=input_dim, 
    hidden_dim=64,       # <-- Smaller hidden layer
    dropout=0.1          # <-- Less dropout
).to(device)             # Move to the same device as the training data

n_params_small = sum(p.numel() for p in small_classifier.parameters())
print(f"Original model parameters: {n_params:,}")
print(f"Small model parameters: {n_params_small:,}")
print(f"\nDoes a smaller model perform worse? Try training it!")

## Exercise 2: Adjust the Decision Threshold

By default, we predict "permeable" if probability > 0.5.

What if we change this threshold?

In [None]:
# Try different thresholds
thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]

print(f"{'Threshold':>10} {'Accuracy':>12} {'Precision':>12} {'Recall':>12}")
print("-" * 50)

for thresh in thresholds:
    y_pred_thresh = (y_pred_proba > thresh).astype(int)
    acc = accuracy_score(y_test, y_pred_thresh)
    
    # Compute precision and recall for permeable class
    tp = ((y_pred_thresh == 1) & (y_test == 1)).sum()
    fp = ((y_pred_thresh == 1) & (y_test == 0)).sum()
    fn = ((y_pred_thresh == 0) & (y_test == 1)).sum()
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    
    print(f"{thresh:>10.1f} {acc:>12.2%} {precision:>12.2%} {recall:>12.2%}")

### 🤔 Question 5

Look at how threshold affects metrics:

1. What happens to precision as threshold increases?
2. What happens to recall as threshold increases?
3. Why is there a trade-off between precision and recall?
4. What threshold would you choose for a drug discovery application?

## Exercise 3: Compare to a Baseline Model

Let's see how our MLP compares to a simple logistic regression.

In [None]:
from sklearn.linear_model import LogisticRegression

# Train logistic regression
lr_model = LogisticRegression(max_iter=1000)
lr_model.fit(X_train, y_train)

# Evaluate
lr_pred = lr_model.predict(X_test)
lr_acc = accuracy_score(y_test, lr_pred)

print("Model Comparison:")
print("=" * 40)
print(f"{'Model':<25} {'Accuracy':>12}")
print("-" * 40)
print(f"{'Baseline (always 1)':<25} {y_test.mean():>12.2%}")
print(f"{'Logistic Regression':<25} {lr_acc:>12.2%}")
print(f"{'MLP (ours)':<25} {test_accs[-1]:>12.2%}")

---

# Summary: Key Takeaways

## What We Learned

1. **Embeddings as features**: Pre-trained models (ChemBERTa) convert molecules to useful representations

2. **Transfer learning**: A small classifier on top of frozen embeddings can be very effective

3. **Evaluation matters**: Accuracy alone isn't enough - look at precision, recall, confusion matrix

4. **Trade-offs**: Threshold, architecture, and hyperparameters all affect performance

## Key Insight

> **Pre-trained embeddings dramatically simplify the problem.** Instead of learning chemistry from scratch, we leverage knowledge from models trained on millions of molecules.

## Practical Applications

This same workflow applies to many scientific domains:

| Domain | Embedding Model | Classification Task |
|--------|-----------------|---------------------|
| Molecules | ChemBERTa | Toxicity, solubility, BBB |
| Proteins | ESM-2 | Function prediction, localization |
| DNA | DNABERT | Promoter detection, modification |
| Text | SciBERT | Topic classification, sentiment |

---

## Reflection Questions

1. **For your research**, what classification problems could you solve with embeddings?

2. **What pre-trained model** would you use, and what classifier would you build?

3. **What evaluation metrics** would matter most for your application?

4. **What's the limitation** of this approach? When might you need to fine-tune the embedding model itself?

---

## Bonus: Save Your Model

If you trained a good model, you can save it for later use!

In [None]:
# Save model
torch.save(classifier.state_dict(), 'bbb_classifier.pt')
print("Model saved to 'bbb_classifier.pt'")

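# To load later (map_location lets a GPU-trained model load on a CPU-only machine):
# loaded_model = BBBClassifier(input_dim=768)
# loaded_model.load_state_dict(torch.load('bbb_classifier.pt', map_location=device))
# loaded_model.eval()  # Disable dropout before making predictions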