# TRL API Interface Documentation

This notebook documents the programming interface for the **Airline Sentiment Analysis** library. 
It demonstrates the usage of the high-level `SentimentModel` wrapper and data utilities defined in `trl_utils`.

### Supported Model Registry
The library supports the following Hugging Face Hub repositories:

| Model Name | Type | Hub Path |
| :--- | :--- | :--- |
| **Baseline BERT** | `bert` | `blank4hd/airline-sentiment-bert-baseline` |
| **Baseline GPT-2** | `gpt2` | `blank4hd/airline-sentiment-baseline-gpt2-sft` |
| **Improved SFT** | `gpt2` | `blank4hd/airline-sentiment-gpt2-improved-sft` |
| **DPO (Active)** | `gpt2` | `blank4hd/airline-sentiment-gpt2-dpo-active-learning` |
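
For programmatic lookups, the table above can be mirrored as a small in-code registry. This is only a sketch: `MODEL_REGISTRY` and `resolve_model` are hypothetical names used for illustration and are not confirmed to exist in `trl_utils`.

```python
from typing import Dict, Tuple

# Hypothetical registry mirroring the table above; not part of trl_utils.
# Each entry maps a display name to (model_type, hub_path).
MODEL_REGISTRY: Dict[str, Tuple[str, str]] = {
    "Baseline BERT": ("bert", "blank4hd/airline-sentiment-bert-baseline"),
    "Baseline GPT-2": ("gpt2", "blank4hd/airline-sentiment-baseline-gpt2-sft"),
    "Improved SFT": ("gpt2", "blank4hd/airline-sentiment-gpt2-improved-sft"),
    "DPO (Active)": ("gpt2", "blank4hd/airline-sentiment-gpt2-dpo-active-learning"),
}


def resolve_model(name: str) -> Tuple[str, str]:
    """Return (model_type, hub_path) for a registry name, or raise KeyError."""
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        known = ", ".join(MODEL_REGISTRY)
        raise KeyError(f"Unknown model {name!r}; expected one of: {known}") from None
```

The returned pair lines up with the `model_type` and `repo_id` arguments that `SentimentModel` expects.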

In [1]:
import trl_utils
import torch

# We filter warnings here to keep the API demonstration clean and readable
import warnings
warnings.filterwarnings('ignore')


## 1. Device Configuration API
The library detects the available compute device automatically, preferring hardware accelerators in this order:
1. **MPS** (Metal Performance Shaders) for Apple Silicon (Mac M1/M2/M3).
2. **CUDA** for NVIDIA GPUs.
3. **CPU** as a fallback.

**Function:** `get_device()`

In [2]:
device = trl_utils.get_device()
print(f"Active Computation Device: {device}")

Active Computation Device: mps
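
The priority order described above can be sketched as a small pure function. This is an illustration under stated assumptions, not the library's actual implementation: `choose_device_name` is a hypothetical helper, and the availability flags are passed in explicitly so the logic can be checked without any accelerator present.

```python
def choose_device_name(mps_available: bool, cuda_available: bool) -> str:
    """Pick a device name using the MPS > CUDA > CPU priority order."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"
```

With PyTorch installed, the two flags would typically come from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, and the resulting name can be passed to `torch.device(...)`.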


## 2. Model Wrapper Interface
The core of the API is the `SentimentModel` class. This wrapper abstracts the complexity of switching between **Sequence Classification** (BERT) and **Causal Language Modeling** (GPT-2).

### Initialization Contract
To load a model, you must specify the repository ID and the model architecture type (`bert` or `gpt2`).

```python
model = trl_utils.SentimentModel(
    repo_id="blank4hd/airline-sentiment-bert-baseline",  # str: the Hugging Face Hub ID
    model_type="bert",                                   # str: 'bert' or 'gpt2'
    device=None,                                         # torch.device or None: optional override
)
```

In [3]:
# Example 1: Loading the Baseline BERT (Discriminative)
print("Initializing BERT Baseline...")
bert_api = trl_utils.SentimentModel(
    repo_id="blank4hd/airline-sentiment-bert-baseline",
    model_type="bert"
)

# Example 2: Loading the DPO Active Learning Model (Generative)
print("Initializing GPT-2 DPO Model...")
dpo_api = trl_utils.SentimentModel(
    repo_id="blank4hd/airline-sentiment-gpt2-dpo-active-learning",
    model_type="gpt2"
)

print("\n✅ Models initialized successfully.")

Initializing BERT Baseline...
Loading blank4hd/airline-sentiment-bert-baseline (bert)...
Initializing GPT-2 DPO Model...
Loading blank4hd/airline-sentiment-gpt2-dpo-active-learning (gpt2)...

✅ Models initialized successfully.
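
One way such a wrapper could switch between the two architectures is to map `model_type` to a `transformers` Auto class. The sketch below assumes the wrapper builds on the Hugging Face `transformers` library, which this notebook does not confirm; `AUTO_CLASS_BY_TYPE` and `load_backbone` are hypothetical names.

```python
from typing import Any, Tuple

# Assumed mapping from the wrapper's model_type to a transformers Auto class name.
AUTO_CLASS_BY_TYPE = {
    "bert": "AutoModelForSequenceClassification",
    "gpt2": "AutoModelForCausalLM",
}


def load_backbone(repo_id: str, model_type: str) -> Tuple[Any, Any]:
    """Load (tokenizer, model) for a supported model_type (hypothetical helper)."""
    if model_type not in AUTO_CLASS_BY_TYPE:
        raise ValueError(
            f"model_type must be one of {sorted(AUTO_CLASS_BY_TYPE)}, got {model_type!r}"
        )
    # Imported lazily so the validation above works even without transformers installed.
    import transformers

    auto_cls = getattr(transformers, AUTO_CLASS_BY_TYPE[model_type])
    tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)
    model = auto_cls.from_pretrained(repo_id)
    return tokenizer, model
```

Validating `model_type` before touching the Hub means an unsupported value fails fast, before any download starts.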


## 3. Inference Interface
The `predict` method provides a unified output format regardless of the underlying model architecture. 

- For **BERT**, it calculates the softmax probability of the target class.
- For **GPT-2**, it constructs a prompt, generates a completion, and parses the resulting token.

**Method:** `predict(text: str)`  
**Returns:** `Tuple[str, float]`, unpacked as `(label, confidence)`
- `label`: One of `['negative', 'neutral', 'positive']`
- `confidence`: 
    - **BERT**: Float between 0.0 and 1.0 (the softmax probability), representing certainty.
    - **GPT-2**: Always `0.0`, a placeholder. The label is parsed from generated text, so no classification probability is computed.
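
The two code paths behind this return contract can be sketched in plain Python. This is a minimal illustration, not the library's implementation. It assumes the classifier logits are ordered `negative, neutral, positive`, that the generative path picks the first label word found in the completion, and that unparseable completions fall back to `neutral`. Both function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

# Assumed label order for the classifier logits.
LABELS: List[str] = ["negative", "neutral", "positive"]


def classify_from_logits(logits: Sequence[float]) -> Tuple[str, float]:
    """BERT-style path: softmax over logits, return the top label and its probability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]


def parse_generated_label(completion: str, default: str = "neutral") -> Tuple[str, float]:
    """GPT-2-style path: find the earliest label word in the completion; confidence is 0.0."""
    lowered = completion.lower()
    hits = [(lowered.find(label), label) for label in LABELS if label in lowered]
    label = min(hits)[1] if hits else default  # fallback label is an assumption
    return label, 0.0
```

Both helpers return the same `(label, confidence)` shape, which is what lets calling code treat the two model types interchangeably.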

In [4]:
sample_input = "The flight was delayed but the staff was very helpful."

# BERT Prediction
label_b, conf_b = bert_api.predict(sample_input)
print(f"[BERT Baseline] Input: '{sample_input}'")
print(f"-> Label: {label_b} | Confidence: {conf_b:.4f}")

# DPO Prediction
label_g, conf_g = dpo_api.predict(sample_input)
print(f"\n[GPT-2 DPO]     Input: '{sample_input}'")
print(f"-> Label: {label_g} | Confidence: {conf_g:.4f}")

[BERT Baseline] Input: 'The flight was delayed but the staff was very helpful.'
-> Label: positive | Confidence: 0.7588

[GPT-2 DPO]     Input: 'The flight was delayed but the staff was very helpful.'
-> Label: positive | Confidence: 0.0000


## 4. Data Loading Interface
The library provides helpers to load and split the airline tweet dataset (`Tweets.csv`). Loading applies the same cleaning steps (regular expressions that remove URLs and user mentions) and the same label mapping every time, so preprocessing stays consistent across all experiments.

**Function:** `load_airline_data(csv_path: str)` returns `(df, label_map)`  
**Function:** `get_data_splits(df: DataFrame)` returns `(train, val, test)`
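
The cleaning step mentioned above (removing URLs and user mentions) could look roughly like the following. The exact patterns used by `trl_utils` are not shown in this notebook, so the regular expressions and the `clean_tweet` helper here are assumptions for illustration only.

```python
import re

# Assumed patterns; the library's actual regexes may differ.
URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
MENTION_PATTERN = re.compile(r"@\w+")


def clean_tweet(text: str) -> str:
    """Remove URLs and @mentions, then collapse leftover whitespace."""
    text = URL_PATTERN.sub("", text)
    text = MENTION_PATTERN.sub("", text)
    return " ".join(text.split())
```

For example, `clean_tweet("@united thanks! see https://t.co/abc now")` yields `"thanks! see now"`.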

In [7]:
try:
    # Attempt to load data if file exists locally
    # This function automatically applies text cleaning and label mapping
    df, label_map = trl_utils.load_airline_data("./data/Tweets.csv")
    print(f"Data loaded successfully. Total Rows: {df.shape[0]}")
    print(f"Label Mapping Used: {label_map}")
    
    # Split the data into Train/Val/Test
    train, val, test = trl_utils.get_data_splits(df)
    print(f"Held-out Test Split Size: {len(test)} samples")
    
except FileNotFoundError:
    print("⚠️ Note: './data/Tweets.csv' not found. The data loading API requires the dataset file at that path.")

Data loaded successfully. Total Rows: 14640
Label Mapping Used: {'negative': 0, 'neutral': 1, 'positive': 2}
Held-out Test Split Size: 1464 samples
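
The output above shows a test split of 1464 rows out of 14640, which is 10% of the data. A split with that proportion can be sketched as follows. The validation fraction, the fixed seed, and the plain (non-stratified) shuffle are assumptions, and `split_rows` is a hypothetical stand-in rather than the actual `get_data_splits` logic.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_rows(
    rows: Sequence[T],
    val_frac: float = 0.1,   # assumed; not confirmed by the notebook
    test_frac: float = 0.1,  # inferred from 1464 / 14640
    seed: int = 42,          # assumed seed for reproducibility
) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle row indices deterministically and slice into train/val/test."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)
    n_test = round(len(rows) * test_frac)
    n_val = round(len(rows) * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]
    return (
        [rows[i] for i in train_idx],
        [rows[i] for i in val_idx],
        [rows[i] for i in test_idx],
    )
```

Using a seeded `random.Random` instance keeps the split reproducible without touching the global random state.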
