# Airline Sentiment Analysis: Model Benchmark

This notebook is the reference implementation for the project's `trl_utils` module. It benchmarks four models, each loaded through the `trl_utils.SentimentModel` wrapper, to determine which model and training approach best classifies airline customer tweets as negative, neutral, or positive.

In [7]:
import trl_utils
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm

# Set display options for cleaner tables
pd.set_option('display.max_colwidth', 150)

## 1. Data Loading
We load the cleaned dataset, take its held-out test split, and draw a random sample of 50 tweets from it for evaluation.
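The split logic inside `trl_utils.get_data_splits` is not shown in this notebook. As a rough sketch of what such a helper might do, the hypothetical `three_way_split` below carves off a stratified test set first and then a stratified validation set, so each split keeps the label proportions of the full data. The 10%/10% fractions, the seed, and the use of stratification are assumptions, not values taken from `trl_utils`.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def three_way_split(df, label_col="airline_sentiment", val_frac=0.1, test_frac=0.1, seed=42):
    """Hypothetical stratified train/validation/test split (fractions are assumptions)."""
    # Convert fractions to absolute row counts so both splits are sized from the full data.
    n_test = int(round(test_frac * len(df)))
    n_val = int(round(val_frac * len(df)))
    # Carve off the test set first, preserving the label distribution.
    train_val, test = train_test_split(
        df, test_size=n_test, stratify=df[label_col], random_state=seed
    )
    # Split the remainder into train and validation, again stratified by label.
    train, val = train_test_split(
        train_val, test_size=n_val, stratify=train_val[label_col], random_state=seed
    )
    return train, val, test


# Toy frame standing in for the cleaned Tweets.csv data (60/20/20 label mix).
toy = pd.DataFrame({
    "clean_text": [f"tweet {i}" for i in range(100)],
    "airline_sentiment": ["negative"] * 60 + ["neutral"] * 20 + ["positive"] * 20,
})
toy_train, toy_val, toy_test = three_way_split(toy)
print(len(toy_train), len(toy_val), len(toy_test))
print(toy_test["airline_sentiment"].value_counts().to_dict())
```

Stratifying matters because the classes are imbalanced (the evaluation sample in Section 4 contains 30 negative, 8 neutral and 12 positive tweets); an unstratified split could leave the smaller classes under-represented in the test set.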

In [8]:
try:
    # Load full data
    df, label2id = trl_utils.load_airline_data("./data/Tweets.csv")
    
    # Get the held-out test split
    _, _, test_df = trl_utils.get_data_splits(df)
    
    # Use a random sample of 50 tweets for this demo run to keep execution fast
    sample_df = test_df.sample(50, random_state=42).reset_index(drop=True)
    print(f"Test Set Sample: {len(sample_df)} tweets")
    
except FileNotFoundError:
    print("Error: './data/Tweets.csv' not found. Falling back to a two-row dummy sample.")
    # Create a dummy dataframe so the rest of the notebook cells don't crash
    sample_df = pd.DataFrame({
        "clean_text": ["flight was great", "service was bad"],
        "airline_sentiment": ["positive", "negative"]
    })

Test Set Sample: 50 tweets
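Because each accuracy figure in this notebook comes from only 50 tweets, it carries substantial sampling uncertainty. As an addition that the original analysis does not compute, a minimal sketch of the Wilson score interval (a standard confidence interval for a binomial proportion, here at roughly 95% with z = 1.96) makes that uncertainty concrete for the accuracies reported in Section 4: 44, 42 and 36 correct out of 50.

```python
import math


def wilson_interval(correct, n, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95% coverage)."""
    p_hat = correct / n
    denom = 1 + z**2 / n
    center = (p_hat + z**2 / (2 * n)) / denom
    half_width = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return center - half_width, center + half_width


# 44, 42 and 36 correct out of 50 correspond to the 88%, 84% and 72%
# accuracies reported in Section 4.
for name, correct in [("Baseline BERT", 44), ("Improved SFT", 42), ("Baseline GPT-2", 36)]:
    low, high = wilson_interval(correct, 50)
    print(f"{name}: {correct / 50:.0%} (95% CI {low:.1%} to {high:.1%})")
```

On 50 tweets the intervals are roughly ±9 to ±12 percentage points wide and all three overlap, so the ranking on this sample is best treated as provisional; evaluating on the full test split would narrow them.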


## 2. Model Initialization
We define the configuration for the four models under comparison and load each one with `trl_utils.SentimentModel`. A model that fails to load is reported and skipped.

In [9]:
# Display name -> repo ID and model type for each of the four models
MODELS_CONFIG = {
    "Baseline BERT":  {"path": "blank4hd/airline-sentiment-bert-baseline", "type": "bert"},
    "Baseline GPT-2": {"path": "blank4hd/airline-sentiment-baseline-gpt2-sft", "type": "gpt2"},
    "Improved SFT":   {"path": "blank4hd/airline-sentiment-gpt2-improved-sft",  "type": "gpt2"},
    "DPO (Active)":   {"path": "blank4hd/airline-sentiment-gpt2-dpo-active-learning", "type": "gpt2"},
}

initialized_models = {}

print("Loading models...")
for name, conf in MODELS_CONFIG.items():
    try:
        print(f"-> Loading {name}...")
        initialized_models[name] = trl_utils.SentimentModel(
            repo_id=conf["path"],
            model_type=conf["type"]
        )
    except Exception as e:
        print(f"❌ Failed to load {name}: {e}")

Loading models...
-> Loading Baseline BERT...
Loading blank4hd/airline-sentiment-bert-baseline (bert)...
-> Loading Baseline GPT-2...
Loading blank4hd/airline-sentiment-baseline-gpt2-sft (gpt2)...
-> Loading Improved SFT...
Loading blank4hd/airline-sentiment-gpt2-improved-sft (gpt2)...
-> Loading DPO (Active)...
Loading blank4hd/airline-sentiment-gpt2-dpo-active-learning (gpt2)...
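The inference loop in the next section relies only on each model exposing `predict(text)` and returning a pair whose first element is the label; the second element is discarded, and its meaning is not shown in this notebook. A minimal sketch of that contract as a `typing.Protocol`, together with a hypothetical keyword-based stand-in that lets the loop run offline without downloading any checkpoints. `Predictor`, `KeywordStub` and `predict_all` are illustrative names, not part of `trl_utils`.

```python
from typing import Any, Dict, Protocol, Tuple


class Predictor(Protocol):
    """Structural type for anything the inference loop can call (hypothetical)."""

    def predict(self, text: str) -> Tuple[str, Any]:
        ...


class KeywordStub:
    """Hypothetical offline stand-in: a crude keyword rule, not a trained model."""

    def predict(self, text: str) -> Tuple[str, Any]:
        lowered = text.lower()
        if any(word in lowered for word in ("bad", "rude", "delay")):
            return "negative", None
        if any(word in lowered for word in ("great", "thanks", "amazing")):
            return "positive", None
        return "neutral", None


def predict_all(models: Dict[str, Predictor], text: str) -> Dict[str, str]:
    """Return one label per model, keeping only the first element of predict()."""
    return {name: model.predict(text)[0] for name, model in models.items()}


stub_models = {"Keyword stub": KeywordStub()}
for tweet in ["flight was great", "service was bad", "go go"]:
    print(tweet, "->", predict_all(stub_models, tweet))
```

Placing a `KeywordStub` in `initialized_models` would let the inference and evaluation cells run end to end on the dummy DataFrame created when `Tweets.csv` is missing.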


## 3. Inference Loop
We run every sampled tweet through each successfully loaded model and record the predicted labels alongside the true label.

In [10]:
results = []

print("Running inference...")
for _, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
    text = row["clean_text"]
    true_label = row["airline_sentiment"]
    
    row_result = {
        "Tweet": text,
        "True Label": true_label
    }
    
    # Ask each model
    for model_name, model_obj in initialized_models.items():
        pred_label, _ = model_obj.predict(text)
        row_result[model_name] = pred_label
        
    results.append(row_result)

results_df = pd.DataFrame(results)
results_df.head(5)

Running inference...


100%|██████████| 50/50 [00:06<00:00,  7.69it/s]


|   | Tweet | True Label | Baseline BERT | Baseline GPT-2 | Improved SFT | DPO (Active) |
|---|---|---|---|---|---|---|
| 0 | i paid using paypal online and after i was charged there was a system error which is what i ended up calling about | negative | negative | negative | negative | negative |
| 1 | amazing hospitality and helpfulness from anthony lastella great staff def flying united again | positive | positive | negative | positive | positive |
| 2 | yes ive boarded this way many times amp have never had to show my pass on the tarmac multiple times path was railed off only way | neutral | negative | negative | negative | negative |
| 3 | go go | neutral | neutral | neutral | neutral | neutral |
| 4 | your customer service in philly is deplorable rude amp unprofessional gate agents after delays amp cancelled flightations takingthistothetop | negative | negative | negative | negative | negative |


## 4. Evaluation & Analysis
We compare the models on overall accuracy and on per-class precision, recall, and F1 from scikit-learn's `classification_report`, including the macro and weighted averages.

In [11]:
truth = results_df["True Label"]
labels = ["negative", "neutral", "positive"]

for model_name in initialized_models.keys():
    if model_name in results_df.columns:
        preds = results_df[model_name]
        acc = accuracy_score(truth, preds)
        
        print(f"\n{'='*40}")
        print(f"MODEL: {model_name} (Accuracy: {acc:.2%})")
        print(f"{'='*40}")
        print(classification_report(truth, preds, labels=labels, zero_division=0))


MODEL: Baseline BERT (Accuracy: 88.00%)
              precision    recall  f1-score   support

    negative       0.97      0.93      0.95        30
     neutral       0.64      0.88      0.74         8
    positive       0.90      0.75      0.82        12

    accuracy                           0.88        50
   macro avg       0.83      0.85      0.83        50
weighted avg       0.90      0.88      0.88        50


MODEL: Baseline GPT-2 (Accuracy: 72.00%)
              precision    recall  f1-score   support

    negative       0.70      1.00      0.82        30
     neutral       0.80      0.50      0.62         8
    positive       1.00      0.17      0.29        12

    accuracy                           0.72        50
   macro avg       0.83      0.56      0.57        50
weighted avg       0.79      0.72      0.66        50


MODEL: Improved SFT (Accuracy: 84.00%)
              precision    recall  f1-score   support

    negative       0.96      0.83      0.89        30
     n
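The Baseline BERT report shows a gap between its macro-averaged F1 (0.83) and its weighted-averaged F1 (0.88). Per-class precision, recall and F1 depend only on three counts per class: correct predictions (TP), total predictions of that class, and support. The counts below are back-derived from the reported figures rather than read from raw output; for example, recall 0.93 on 30 negative tweets implies 28 correct, and precision 0.97 then implies 29 negative predictions. Recomputing from them reproduces the report and shows that the macro average, which weights every class equally, is pulled down by the small neutral class (F1 0.74 on 8 tweets).

```python
# Counts inferred from the Baseline BERT report: (true positives, predicted count, support)
BERT_COUNTS = {
    "negative": (28, 29, 30),
    "neutral": (7, 11, 8),
    "positive": (9, 10, 12),
}


def class_scores(tp, n_pred, support):
    """Precision, recall and F1 for one class from its counts."""
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / support if support else 0.0
    # F1 = 2PR / (P + R), which simplifies to 2 * TP / (predicted + support).
    f1 = 2 * tp / (n_pred + support) if (n_pred + support) else 0.0
    return precision, recall, f1


def summarize(counts):
    """Per-class scores plus macro F1, support-weighted F1 and accuracy."""
    scores = {label: class_scores(*c) for label, c in counts.items()}
    supports = {label: c[2] for label, c in counts.items()}
    total = sum(supports.values())
    macro_f1 = sum(s[2] for s in scores.values()) / len(scores)
    weighted_f1 = sum(scores[label][2] * supports[label] for label in scores) / total
    accuracy = sum(c[0] for c in counts.values()) / total
    return scores, macro_f1, weighted_f1, accuracy


scores, macro_f1, weighted_f1, accuracy = summarize(BERT_COUNTS)
for label, (p, r, f) in scores.items():
    print(f"{label:>8}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"macro F1={macro_f1:.2f}  weighted F1={weighted_f1:.2f}  accuracy={accuracy:.2f}")
```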

## 5. Comparative Analysis
Next we list the tweets on which Baseline BERT and the DPO model disagree. Inspecting these cases shows, tweet by tweet, where the DPO model's predictions depart from the stronger of the two baselines.

In [12]:
# Filter for cases where Baseline BERT and DPO disagree
if "Baseline BERT" in results_df.columns and "DPO (Active)" in results_df.columns:
    disagreements = results_df[
        results_df["Baseline BERT"] != results_df["DPO (Active)"]
    ]
    
    print(f"Found {len(disagreements)} disagreements between BERT and DPO.")
    if not disagreements.empty:
        display(disagreements[["Tweet", "True Label", "Baseline BERT", "DPO (Active)"]].head())

Found 2 disagreements between BERT and DPO.


|    | Tweet | True Label | Baseline BERT | DPO (Active) |
|---|---|---|---|---|
| 7 | thanks next up we will see how the slog from jfk to the city goes | positive | neutral | positive |
| 40 | whats going on with your website amp mobile app help | negative | neutral | negative |
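The disagreement listing shows which tweets differ but not which model was right. A small hypothetical helper, `disagreement_breakdown`, counts how often each model matches the true label among the disagreeing rows. It is demonstrated here on the two disagreement rows above plus the agreeing row 3 ("go go") from the earlier `head()` output; on the full `results_df` it would be called the same way.

```python
import pandas as pd


def disagreement_breakdown(df, model_a, model_b, truth_col="True Label"):
    """Among rows where two models disagree, count which one (if either) matches the truth."""
    diff = df[df[model_a] != df[model_b]]
    a_right = diff[model_a] == diff[truth_col]
    b_right = diff[model_b] == diff[truth_col]
    return {
        f"{model_a} correct": int(a_right.sum()),
        f"{model_b} correct": int(b_right.sum()),
        "neither correct": int((~a_right & ~b_right).sum()),
    }


# Rows copied from the outputs above: two disagreements and one agreement.
demo = pd.DataFrame({
    "Tweet": [
        "thanks next up we will see how the slog from jfk to the city goes",
        "whats going on with your website amp mobile app help",
        "go go",
    ],
    "True Label": ["positive", "negative", "neutral"],
    "Baseline BERT": ["neutral", "neutral", "neutral"],
    "DPO (Active)": ["positive", "negative", "neutral"],
})
print(disagreement_breakdown(demo, "Baseline BERT", "DPO (Active)"))
```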
