# FART Evaluation Pipeline

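This notebook fine-tunes a RoBERTa-based (ChemBERTa) sequence classifier on the FART (Flavor Aroma Recognition Task) dataset splits to predict taste classes from SMILES strings. It then evaluates the model on the held-out test split and reports ROC curves, a confusion matrix, and accuracy/precision/recall/F1 metrics.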

In [None]:
# Google Colab setup
# The block below is wrapped in a triple-quoted string so it is inert when the
# notebook runs locally. On Colab, delete the surrounding quotes to mount Google
# Drive and install the project requirements, then point `data_dir` in the
# configuration cell at the Drive copy of the dataset splits.
'''
from google.colab import drive
import os

# mount drive
drive.mount('/content/drive')

# install libraries
!pip install -r "/content/drive/MyDrive/6.7910/code/requirements.txt"

# TODO: make sure to edit data-dir in configuration cell as well
'''

In [None]:
# Imports
import itertools
import json
import math
import os
from collections import Counter
from copy import deepcopy
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset
from rdkit import Chem
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
)
from sklearn.preprocessing import LabelEncoder, label_binarize
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Set matplotlib to display plots inline (IPython magic: only valid inside Jupyter/IPython)
%matplotlib inline
# NOTE: the "seaborn-v0_8-*" style names require a recent matplotlib (3.6+, where the
# old "seaborn-*" styles were renamed).
plt.style.use('seaborn-v0_8-darkgrid')

## ⚙️ Configuration

**Configure all parameters here before running the notebook.**

In [None]:
# Configuration: every tunable knob for the run lives in this cell.
model_checkpoint = "terrytwk/chemberta-700k-augmented-smiles"  # weights fine-tuned below
tokenizer_checkpoint = "seyonec/SMILES_tokenized_PubChem_shard00_160k"  # empty -> use model_checkpoint
data_dir = "./dataset/splits"  # must contain fart_train.csv, fart_val.csv, fart_test.csv
augmentation = False  # Set to True to enable SMILES augmentation
augmentation_numbers = [10, 10, 10, 10, 10]  # Only used if augmentation=True; one count per entry of `tastes`
num_train_epochs = 2
per_device_train_batch_size = 16
per_device_eval_batch_size = 16
max_length = 512  # tokenizer truncation/padding length
run_name = "fart_evaluation"

tastes = ["bitter", "sour", "sweet", "umami", "undefined"]
label_column = "Canonicalized Taste"
smiles_column = "Canonicalized SMILES"

# augmentation_numbers[i] is the number of random SMILES generated for tastes[i],
# so the two lists must line up one-to-one. Fail fast here rather than with an
# IndexError (or silently wrong counts) deep inside the augmentation step.
if len(augmentation_numbers) != len(tastes):
    raise ValueError(
        f"augmentation_numbers has {len(augmentation_numbers)} entries but tastes has "
        f"{len(tastes)}; they must match one-to-one."
    )

print("=" * 80)
print("FART EVALUATION PIPELINE")
print("=" * 80)
print(f"Model: {model_checkpoint}")
print(f"Tokenizer: {tokenizer_checkpoint}")
print(f"Data directory: {data_dir}")
print(f"Augmentation: {augmentation}")
print("=" * 80)

## Helper Functions

In [None]:
def control_smiles_duplication(random_smiles, duplicate_control=lambda x: 1):
    """Rebuild a SMILES list with a controlled number of copies per distinct string.

    Args:
        random_smiles: Iterable of SMILES strings, possibly containing repeats.
        duplicate_control: Maps the observed count of a SMILES string to the number
            of copies to keep; the value is rounded up with ``math.ceil``. The
            default keeps exactly one copy, i.e. it de-duplicates.

    Returns:
        A list in which each distinct SMILES appears
        ``ceil(duplicate_control(count))`` times. Copies are grouped together and
        groups follow first-occurrence order.
    """
    occurrences = Counter(random_smiles)
    copies_per_smiles = {
        smi: math.ceil(duplicate_control(n_seen)) for smi, n_seen in occurrences.items()
    }
    expanded = []
    for smi, n_copies in copies_per_smiles.items():
        expanded.extend(itertools.repeat(smi, n_copies))
    return expanded


def smiles_to_random(smiles, int_aug=50):
    """Generate randomized (non-canonical) SMILES spellings of one molecule.

    Args:
        smiles: Input SMILES string.
        int_aug: Number of random SMILES to generate. ``0`` returns the input
            unchanged as a one-element list.

    Returns:
        ``None`` if RDKit cannot parse ``smiles``. Otherwise, a list of ``int_aug``
        randomized SMILES (repeats are possible), or ``[smiles]`` when
        ``int_aug == 0``.

    Raises:
        ValueError: if ``int_aug`` is neither positive nor zero. This is only
            checked for parseable input.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None
    if int_aug == 0:
        return [smiles]
    if int_aug > 0:
        return [
            Chem.MolToSmiles(molecule, canonical=False, doRandom=True)
            for _ in range(int_aug)
        ]
    raise ValueError("int_aug must be greater or equal to zero.")


def augmentation_without_duplication(smiles, augmentation_number):
    """Return up to ``augmentation_number`` distinct randomized SMILES of ``smiles``.

    Randomized SMILES are generated with ``smiles_to_random`` and repeats are then
    collapsed to a single copy each. If RDKit cannot parse ``smiles``,
    ``smiles_to_random`` returns ``None``, which ``Counter`` treats as empty, so the
    result is ``[]``.
    """
    candidates = smiles_to_random(smiles, augmentation_number)
    keep_one_copy = lambda x: 1  # noqa: E731 - one copy per distinct SMILES
    return control_smiles_duplication(candidates, keep_one_copy)


def augment_dataset(dataset: Dataset, augmentation_numbers, tastes, label_column, smiles_column):
    """Expand a dataset with randomized (non-canonical) SMILES for each row.

    A row whose ``label_column`` value equals ``tastes[i]`` is replaced by up to
    ``augmentation_numbers[i]`` distinct randomized SMILES of the same molecule.
    All other columns are copied. Rows whose label is not in ``tastes`` are kept
    once, unchanged. A row whose SMILES RDKit cannot parse is also kept once,
    unchanged, rather than being dropped.

    NOTE(review): the number of copies is chosen from each row's *label*. When this
    is applied to val/test splits, the augmentation is label-dependent.

    Args:
        dataset: Source ``Dataset``.
        augmentation_numbers: Copies to generate per taste, aligned with ``tastes``.
        tastes: Label values that should be augmented.
        label_column: Column holding the taste label.
        smiles_column: Column holding the SMILES string to randomize.

    Returns:
        A new ``Dataset`` with the same columns and features as ``dataset``. An
        empty input is returned as-is.

    Raises:
        ValueError: if ``augmentation_numbers`` and ``tastes`` differ in length.
    """
    if len(augmentation_numbers) != len(tastes):
        raise ValueError(
            f"augmentation_numbers ({len(augmentation_numbers)}) and tastes "
            f"({len(tastes)}) must have the same length."
        )
    # Look up each row's augmentation count directly so every row is visited once.
    # Looping over tastes and then rows would re-append every non-matching row once
    # per other taste.
    augmentations_per_taste = dict(zip(tastes, augmentation_numbers))

    augmented_data = []
    for entry in dataset:
        taste = entry[label_column]
        if taste not in augmentations_per_taste:
            augmented_data.append(entry)
            continue
        new_smiles_list = augmentation_without_duplication(
            entry[smiles_column], augmentations_per_taste[taste]
        )
        if not new_smiles_list:
            # Unparseable SMILES: keep the original row instead of silently losing it.
            augmented_data.append(entry)
            continue
        for new_smiles in new_smiles_list:
            new_entry = deepcopy(entry)
            new_entry[smiles_column] = new_smiles
            augmented_data.append(new_entry)

    if not augmented_data:
        return dataset
    return Dataset.from_dict(
        {key: [entry[key] for entry in augmented_data] for key in dataset.column_names},
        features=dataset.features,
    )


class CustomTrainer(Trainer):
    """Trainer with optional class-weighted cross-entropy loss.

    Args:
        class_weights: Optional per-class weights indexed by encoded label id. The
            value may be a tensor, a NumPy array or a sequence. ``None`` gives the
            ordinary unweighted cross-entropy loss.
        *args, **kwargs: Forwarded unchanged to ``Trainer``.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Cross-entropy between the model's logits and ``inputs["labels"]``.

        ``labels`` is popped from ``inputs`` so it is not forwarded to the model's
        forward call. Extra keyword arguments supplied by the Trainer are accepted
        and ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        weight = None
        if self.class_weights is not None:
            # Match both device AND dtype of the logits. CrossEntropyLoss rejects a
            # weight whose dtype differs from the input (e.g. float64 weights from
            # sklearn's compute_class_weight against float32 logits). as_tensor
            # also accepts lists/NumPy arrays, which have no `.to`.
            weight = torch.as_tensor(
                self.class_weights, dtype=logits.dtype, device=logits.device
            )
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss


def load_csvs(data_dir: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read the train/val/test split CSVs from ``data_dir``.

    Expects ``fart_train.csv``, ``fart_val.csv`` and ``fart_test.csv`` inside
    ``data_dir``. Each frame gets a fresh 0..n-1 index.

    Returns:
        ``(train_df, val_df, test_df)`` in that order.
    """
    split_files = ("fart_train.csv", "fart_val.csv", "fart_test.csv")
    return tuple(
        pd.read_csv(os.path.join(data_dir, file_name)).reset_index(drop=True)
        for file_name in split_files
    )

## 1. Load Data

In [None]:
print("\n[1/7] Loading data...")
train_df, val_df, test_df = load_csvs(data_dir)
# Wrap each pandas split as a Hugging Face Dataset for tokenization and the Trainer.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_pandas(split_df) for split_df in (train_df, val_df, test_df)
)
print(f"✓ Train samples: {len(train_dataset)}")
print(f"✓ Validation samples: {len(val_dataset)}")
print(f"✓ Test samples: {len(test_dataset)}")

## 2. SMILES Augmentation (optional)

In [None]:
if not augmentation:
    print("\n[2/7] Skipping augmentation...")
else:
    print("\n[2/7] Performing SMILES augmentation...")
    # Test/val are augmented too so that ensemble voting can aggregate predictions
    # over several SMILES spellings of each molecule (see the voting cell).
    train_dataset, val_dataset, test_dataset = (
        augment_dataset(split, augmentation_numbers, tastes, label_column, smiles_column)
        for split in (train_dataset, val_dataset, test_dataset)
    )
    print(f"✓ Augmented train samples: {len(train_dataset)}")
    print(f"✓ Augmented validation samples: {len(val_dataset)}")
    print(f"✓ Augmented test samples: {len(test_dataset)}")

## 3. Load Tokenizer

In [None]:
print(f"\n[3/7] Loading model and tokenizer...")
# A separate tokenizer checkpoint may be configured; an empty value falls back to
# the model checkpoint's own tokenizer.
tokenizer_path = tokenizer_checkpoint or model_checkpoint
print(f"  Model checkpoint: {model_checkpoint}")
print(f"  Tokenizer: {tokenizer_path}")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size})")

## 4. Tokenization

In [None]:
print("\n[4/7] Tokenizing datasets...")


def tokenize_function(examples):
    """Tokenize a batch of SMILES strings.

    Every sequence is truncated/padded to exactly ``max_length`` tokens, so batches
    have a fixed shape and can be stacked directly.
    """
    return tokenizer(
        examples[smiles_column],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )


train_dataset, val_dataset, test_dataset = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)
print("✓ Tokenization complete")

## 5. Label Encoding

In [None]:
print("\n[5/7] Encoding labels...")
# Fit the encoder on the training split only; transform() raises on any val/test
# label that never appears in train.
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_dataset[label_column])
val_labels, test_labels = (
    label_encoder.transform(split[label_column]) for split in (val_dataset, test_dataset)
)

train_dataset = train_dataset.add_column("labels", train_labels)
val_dataset = val_dataset.add_column("labels", val_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
print(f"✓ Classes: {label_encoder.classes_}")

# "Balanced" weights are computed for reporting only; training uses the unweighted
# loss (class_weights = None) to mirror legacy behavior.
class_weight_values = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(train_labels),
    y=train_labels,
)
class_weights = None

print("\nClass distribution in training set:")
unique, counts = np.unique(train_labels, return_counts=True)
for label, count in zip(unique, counts):
    class_name = label_encoder.classes_[label]
    print(f"  {class_name}: {count} samples (weight: {class_weight_values[label]:.4f})")

## 6. Training Setup

In [None]:
print("\n[6/7] Setting up training...")
num_labels = len(label_encoder.classes_)

# Seed torch BEFORE building the model. from_pretrained creates the new
# classification head here, before the Trainer applies TrainingArguments.seed,
# so without this the head's random initialization differs between runs.
training_seed = 42  # same value as TrainingArguments' default seed
torch.manual_seed(training_seed)
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)
print(f"✓ Classification head initialized with {num_labels} labels")

# Create a temporary output directory for training checkpoints
output_dir = "./temp_training_output"
os.makedirs(output_dir, exist_ok=True)

# Use up to 8 dataloader workers, but never more than this machine has CPUs.
dataloader_workers = min(8, os.cpu_count() or 1)

training_args = TrainingArguments(
    run_name=run_name,
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    weight_decay=0.01,
    # Evaluation/saving happen every eval_steps/save_steps (library defaults).
    # NOTE(review): if a run has fewer optimizer steps than that, nothing is
    # evaluated mid-training and load_best_model_at_end has nothing to restore.
    eval_strategy="steps",
    logging_dir=os.path.join(output_dir, "logs"),
    save_strategy="steps",
    load_best_model_at_end=True,
    save_total_limit=5,
    seed=training_seed,
    dataloader_num_workers=dataloader_workers,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
)


def compute_metrics(eval_pred):
    """Accuracy of argmax predictions, reported during Trainer evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}


trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    class_weights=class_weights,
)

## 7. Training

In [None]:
print("\nStarting training...")
# Keep the returned TrainOutput so its summary metrics are visible after training.
train_result = trainer.train()
print("✓ Training complete!")
for metric_name, metric_value in train_result.metrics.items():
    if isinstance(metric_value, float):
        print(f"  {metric_name}: {metric_value:.4f}")
    else:
        print(f"  {metric_name}: {metric_value}")

## 8. Evaluation and Metrics

In [None]:
print("\n[7/7] Running evaluation...")
print("\nValidation Results:")
val_results = trainer.evaluate(eval_dataset=val_dataset)
for key, value in val_results.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

print("\nGenerating test predictions...")
# trainer.predict yields raw logits. Convert them to per-class probabilities (used
# by the ROC/AUC cells), then take the most probable class as the prediction.
predictions = trainer.predict(test_dataset)
probs = softmax(predictions.predictions, axis=1)
pred_labels = probs.argmax(axis=1)
true_labels = predictions.label_ids

## Test Set Metrics

In [None]:
# Calculate all metrics
accuracy = accuracy_score(true_labels, pred_labels)
precision, recall, f1, support = precision_recall_fscore_support(
    true_labels, pred_labels, labels=np.arange(num_labels)
)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    true_labels, pred_labels, average="macro", labels=np.arange(num_labels)
)
precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
    true_labels, pred_labels, average="weighted", labels=np.arange(num_labels)
)

label_names = label_encoder.inverse_transform(range(num_labels))

# Display metrics
print("\n" + "=" * 80)
print("TEST SET RESULTS")
print("=" * 80)
print(f"\nOverall Metrics:")
print(f"  Accuracy: {accuracy:.4f}")
print(f"  Macro Precision: {precision_macro:.4f}")
print(f"  Macro Recall: {recall_macro:.4f}")
print(f"  Macro F1 Score: {f1_macro:.4f}")
print(f"  Weighted Precision: {precision_weighted:.4f}")
print(f"  Weighted Recall: {recall_weighted:.4f}")
print(f"  Weighted F1 Score: {f1_weighted:.4f}")

print("\nPer-Class Metrics:")
for i, (p, r, f, s) in enumerate(zip(precision, recall, f1, support)):
    print(f"  Class {label_names[i]}:")
    print(f"    Precision: {p:.4f}")
    print(f"    Recall: {r:.4f}")
    print(f"    F1 Score: {f:.4f}")
    print(f"    Support: {s}")

# Create a metrics DataFrame for better visualization
metrics_df = pd.DataFrame({
    'Class': label_names,
    'Precision': precision,
    'Recall': recall,
    'F1 Score': f1,
    'Support': support
})
print("\n" + "=" * 80)
print("Metrics Summary Table")
print("=" * 80)
print(metrics_df.to_string(index=False))

## Confusion Matrix

In [None]:
# Rows are actual classes and columns are predicted classes. labels= pins the matrix
# to num_labels x num_labels in label-id order so it always lines up with
# label_names, even when a class never occurs in the test labels or predictions.
conf_matrix = confusion_matrix(true_labels, pred_labels, labels=np.arange(num_labels))

fig, ax = plt.subplots(figsize=(10, 7))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=label_names,
    yticklabels=label_names,
    ax=ax,
)
ax.set_title("Confusion Matrix", fontsize=16, fontweight="bold")
ax.set_ylabel("Actual Label", fontsize=12)
ax.set_xlabel("Predicted Label", fontsize=12)
fig.tight_layout()
plt.show()

## ROC Curves

In [None]:
# One-vs-rest ROC curve per class from the softmax probabilities.
# One-hot encode explicitly so column i is always class i. label_binarize collapses
# to a single column when there are only two classes.
true_labels_bin = (np.asarray(true_labels)[:, None] == np.arange(num_labels)).astype(int)
n_samples = len(true_labels_bin)

fig, ax = plt.subplots(figsize=(10, 8))

auc_scores = {}
for i in range(num_labels):
    n_positive = int(true_labels_bin[:, i].sum())
    # ROC AUC is undefined unless the class has both positive and negative samples
    # (roc_auc_score raises otherwise), so skip such classes and say so.
    if not 0 < n_positive < n_samples:
        print(f"  Skipping ROC for {label_names[i]}: {n_positive}/{n_samples} positive samples")
        continue
    auc = roc_auc_score(true_labels_bin[:, i], probs[:, i])
    auc_scores[label_names[i]] = auc
    fpr, tpr, _ = roc_curve(true_labels_bin[:, i], probs[:, i])
    ax.plot(fpr, tpr, label=f"{label_names[i]} (AUC = {auc:.4f})", linewidth=2)

ax.plot([0, 1], [0, 1], "k--", label="Random Classifier (AUC = 0.5)", linewidth=1.5)
ax.set_xlabel("False Positive Rate", fontsize=12)
ax.set_ylabel("True Positive Rate", fontsize=12)
ax.set_title("ROC Curves for Each Class", fontsize=16, fontweight="bold")
ax.legend(loc="lower right", fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print("\nAUC Scores:")
for class_name, auc in auc_scores.items():
    print(f"  {class_name}: {auc:.4f}")

## Ensemble Voting (if augmentation enabled)

In [None]:
if augmentation:
    print("\nPerforming ensemble voting...")
    # With augmentation each test molecule appears once per randomized SMILES.
    # Rows are grouped back into molecules with this column, which augment_dataset
    # does not rewrite (it only replaces smiles_column).
    # NOTE(review): assumes the split CSVs contain this column; confirm.
    group_column = "Standardized SMILES"
    # A molecule is kept only if its most frequent predicted class received at
    # least this many votes. Molecules below it are excluded from the voted
    # metrics; the kept/total count is printed below so the exclusion is visible.
    min_votes = 10

    df = pd.DataFrame(
        {
            group_column: test_dataset[group_column],
            "label": test_dataset["labels"],
            "pred_probs": list(probs),
        }
    )
    df["pred_labels"] = np.argmax(probs, axis=1)

    def _majority_vote(votes):
        """Most frequent predicted class, or NaN when it has fewer than `min_votes` votes."""
        vote_counts = votes.value_counts()
        return vote_counts.idxmax() if vote_counts.iloc[0] >= min_votes else np.nan

    grouped = (
        df.groupby(group_column)
        .agg({"pred_labels": _majority_vote, "label": "first"})
        .dropna()
        .reset_index()
    )
    n_molecules = df[group_column].nunique()
    print(f"  Molecules with >= {min_votes} agreeing votes: {len(grouped)}/{n_molecules}")

    if grouped.empty:
        print("\nEnsemble voting produced no molecules above the vote threshold; skipping voted metrics.")
    else:
        # Mean class probabilities per molecule (one groupby pass), then reorder
        # the rows to match `grouped`.
        mean_probs = pd.DataFrame(probs, index=df[group_column].to_numpy()).groupby(level=0).mean()
        pred_probs_voted = mean_probs.loc[grouped[group_column]].to_numpy()
        true_labels_voted = grouped["label"].to_numpy().astype(int)
        # Column is float because of the NaN placeholders; cast back to class ids.
        pred_labels_voted = grouped["pred_labels"].to_numpy().astype(int)

        accuracy_voted = accuracy_score(true_labels_voted, pred_labels_voted)
        precision_v, recall_v, f1_v, support_v = precision_recall_fscore_support(
            true_labels_voted, pred_labels_voted, labels=np.arange(num_labels)
        )
        precision_macro_v, recall_macro_v, f1_macro_v, _ = precision_recall_fscore_support(
            true_labels_voted, pred_labels_voted, average="macro", labels=np.arange(num_labels)
        )

        print("\nEnsemble Voting Results:")
        print(f"  Voted Accuracy: {accuracy_voted:.4f}")
        print(f"  Voted Macro Precision: {precision_macro_v:.4f}")
        print(f"  Voted Macro Recall: {recall_macro_v:.4f}")
        print(f"  Voted Macro F1 Score: {f1_macro_v:.4f}")

        # ROC curves for ensemble voting (one-vs-rest on the mean probabilities).
        true_labels_bin_voted = (
            true_labels_voted[:, None] == np.arange(num_labels)
        ).astype(int)
        n_voted = len(true_labels_bin_voted)
        fig, ax = plt.subplots(figsize=(10, 8))

        auc_scores_voted = {}
        for i in range(num_labels):
            n_positive = int(true_labels_bin_voted[:, i].sum())
            # AUC needs both positives and negatives for the class.
            if not 0 < n_positive < n_voted:
                continue
            auc = roc_auc_score(true_labels_bin_voted[:, i], pred_probs_voted[:, i])
            auc_scores_voted[label_names[i]] = auc
            fpr, tpr, _ = roc_curve(true_labels_bin_voted[:, i], pred_probs_voted[:, i])
            ax.plot(fpr, tpr, label=f"{label_names[i]} (AUC = {auc:.4f})", linewidth=2)

        ax.plot([0, 1], [0, 1], "k--", label="Random Classifier (AUC = 0.5)", linewidth=1.5)
        ax.set_xlabel("False Positive Rate", fontsize=12)
        ax.set_ylabel("True Positive Rate", fontsize=12)
        ax.set_title("ROC Curves for Each Class (Ensemble Voting)", fontsize=16, fontweight="bold")
        ax.legend(loc="lower right", fontsize=10)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        plt.show()

        print("\nEnsemble Voting AUC Scores:")
        for class_name, auc in auc_scores_voted.items():
            print(f"  {class_name}: {auc:.4f}")
else:
    print("\nEnsemble voting skipped (augmentation disabled)")