<a href="https://colab.research.google.com/github/zahraniayudyaa/finnalterm-dl/blob/main/01_MNLI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **FINE-TUNING HUGGINGFACE MODELS (MNLI)**

## **1. Setup dan Instalasi**

In [None]:
# 1. Setup dan Instalasi
!pip install transformers datasets torch scikit-learn pandas numpy matplotlib seaborn evaluate

import inspect
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)

try:
    from datasets import load_metric
except ImportError:
    # Newer `datasets` releases no longer export `load_metric`; the `evaluate`
    # package installed above provides `load` as its replacement.
    from evaluate import load as load_metric
# Silence library warnings for the remainder of the notebook.
warnings.filterwarnings('ignore')

# Select the compute device: GPU when CUDA is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## **2. Load Dataset**

In [None]:
# 2. Load dataset - MNLI (Multi-Genre Natural Language Inference)
print("Loading MNLI dataset from GLUE...")
dataset = load_dataset("glue", "mnli")

print("\nDataset structure:")
print(dataset)

# (display name, split key) pairs, reported in this order.
split_report = (
    ("Train", "train"),
    ("Validation matched", "validation_matched"),
    ("Validation mismatched", "validation_mismatched"),
    ("Test matched", "test_matched"),
    ("Test mismatched", "test_mismatched"),
)
for display_name, split_key in split_report:
    print(f"{display_name} samples: {len(dataset[split_key])}")

# 3. Examine one training example
print("\nSample data from training set:")
sample = dataset['train'][0]
sample_label_id = sample['label']
sample_label_text = ['entailment', 'neutral', 'contradiction'][sample_label_id]
print(f"Premise: {sample['premise']}")
print(f"Hypothesis: {sample['hypothesis']}")
print(f"Label: {sample_label_id} ({sample_label_id} = {sample_label_text})")
print(f"Genre: {sample['genre']}")

# 4. Analyze dataset statistics
def analyze_mnli_dataset(dataset_split, split_name):
    """Print label, genre and text-length statistics for one dataset split.

    Args:
        dataset_split: Split object offering column access through
            ``dataset_split['label']``, ``['premise']`` and ``['hypothesis']``
            (plus ``['genre']`` when present) and a ``features`` container
            that supports ``in`` checks on column names.
        split_name: Human-readable name printed in the section header.

    Returns:
        None. Every statistic is printed.
    """
    print(f"\n{split_name} Statistics:")

    label_names = ['entailment', 'neutral', 'contradiction']

    # Label distribution
    labels = dataset_split['label']
    unique, counts = np.unique(labels, return_counts=True)

    print(f"Label distribution:")
    for label, count in zip(unique, counts):
        label_id = int(label)
        if 0 <= label_id < len(label_names):
            name = label_names[label_id]
        else:
            # Never resolve out-of-range ids through list indexing: a negative
            # id (e.g. a -1 placeholder, presumably what unlabeled rows carry)
            # would silently be reported as 'contradiction'.
            name = f"unlabeled ({label_id})"
        print(f"  {name}: {count} samples ({count/len(labels)*100:.1f}%)")

    # Genre distribution
    if 'genre' in dataset_split.features:
        genres = dataset_split['genre']
        unique_genres, genre_counts = np.unique(genres, return_counts=True)
        print(f"\nGenre distribution:")
        for genre, count in zip(unique_genres, genre_counts):
            print(f"  {genre}: {count} samples ({count/len(genres)*100:.1f}%)")

    # Text length statistics (whitespace-separated word counts)
    premise_lengths = [len(p.split()) for p in dataset_split['premise']]
    hypothesis_lengths = [len(h.split()) for h in dataset_split['hypothesis']]

    print(f"\nText length statistics:")
    if not premise_lengths or not hypothesis_lengths:
        # max() raises on an empty sequence, so report the empty split instead.
        print("  (no samples)")
        return
    print(f"  Premise - Avg: {np.mean(premise_lengths):.1f} words, Max: {max(premise_lengths)}")
    print(f"  Hypothesis - Avg: {np.mean(hypothesis_lengths):.1f} words, Max: {max(hypothesis_lengths)}")

# Report statistics for the training split and the matched validation split.
for split_key, split_title in (
    ('train', 'Training Set'),
    ('validation_matched', 'Validation Matched Set'),
):
    analyze_mnli_dataset(dataset[split_key], split_title)

## **3. Preprocessing Data**

In [None]:
# 5. Preprocessing and tokenization for NLI
MODEL_NAME = "bert-base-uncased"  # Can be swapped for distilbert-base-uncased or roberta-base
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess_mnli(examples):
    """Tokenize premise/hypothesis pairs for sequence-pair classification.

    Args:
        examples: Mapping with 'premise' and 'hypothesis' entries; this
            notebook passes batches of them via ``map(..., batched=True)``.

    Returns:
        The tokenizer output for the pairs, truncated to at most 128 tokens.
    """
    # NLI input format: "[CLS] premise [SEP] hypothesis [SEP]".
    #
    # NOTE(review): the legacy `truncation_strategy='only_first'` keyword was
    # dropped. To my reading of the tokenizer API it is superseded by
    # `truncation` and has no effect once `truncation=True` is given, so the
    # pair was never truncated premise-only; `truncation=True` keeps the
    # strategy that was actually in effect.
    #
    # No padding here: the DataCollatorWithPadding handed to the Trainer pads
    # each batch when it is assembled, so padding inside `map` only stores
    # extra pad tokens.
    return tokenizer(
        examples['premise'],
        examples['hypothesis'],
        truncation=True,
        max_length=128,
    )

print("\nTokenizing dataset...")
# Tokenize every split, drop the raw-text/bookkeeping columns, and expose the
# target under the 'labels' column name.
tokenized_datasets = (
    dataset
    .map(preprocess_mnli, batched=True)
    .remove_columns(['premise', 'hypothesis', 'idx', 'genre'])
    .rename_column('label', 'labels')
)

# 6. Hold out 10% of the training split as an internal validation set
train_val_split = tokenized_datasets['train'].train_test_split(test_size=0.1, seed=42)
tokenized_datasets['train'], tokenized_datasets['val'] = (
    train_val_split['train'],
    train_val_split['test'],
)

## **4. Load Model dan Training**

In [None]:
# 7. Load the sequence-classification model with a 3-way output head
#    (entailment, neutral, contradiction) and move it to the selected device.
print(f"\nLoading model: {MODEL_NAME}")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
model.to(device)

# 8. Training arguments
# The keyword that schedules evaluation is spelled `eval_strategy` in newer
# transformers releases and `evaluation_strategy` in older ones. Use whichever
# name the installed TrainingArguments accepts so the cell runs on both.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results_mnli",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
    logging_steps=100,
    report_to="none",
    save_total_limit=2,
    push_to_hub=False,
    # Evaluate once per epoch, matching save_strategy above.
    **{_eval_strategy_key: "epoch"},
)

# 9. Load the GLUE metric for MNLI
metric = load_metric("glue", "mnli")

def compute_metrics(eval_pred):
    """Turn a (scores, labels) evaluation pair into a metrics dictionary.

    Args:
        eval_pred: Pair unpacked as ``(predictions, labels)``; the predicted
            class is the argmax of ``predictions`` along axis 1.

    Returns:
        dict with an "accuracy" entry plus every entry produced by the GLUE
        metric (the GLUE entries are applied last and win on key clashes).
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)

    metrics = {"accuracy": accuracy_score(references, predicted_ids)}
    metrics.update(metric.compute(predictions=predicted_ids, references=references))
    return metrics

# 10. Data collator (pads the examples of a batch when the batch is built)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 11. Initialize Trainer
# Newer transformers releases take the tokenizer through `processing_class`;
# older ones only know `tokenizer`. Use whichever name this Trainer accepts.
_tokenizer_key = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["val"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_key: tokenizer},
)

# 12. Train the model
print("\nTraining model for Natural Language Inference...")
train_result = trainer.train()

## **5. Evaluasi**

In [None]:
# 13. Evaluate on the held-out validation split
print("\nEvaluating model on validation set...")
eval_result = trainer.evaluate()
print(f"\nValidation results:")
for metric_name in eval_result:
    metric_value = eval_result[metric_name]
    if not isinstance(metric_value, float):
        continue  # only float-valued entries are reported
    print(f"  {metric_name}: {metric_value:.4f}")

# 14. Evaluate on both official validation sets (matched and mismatched)
section_rule = "=" * 80
print("\n" + section_rule)
print("Evaluation on Official Validation Sets")
print(section_rule)

# validation_matched
print("\nEvaluating on validation_matched...")
val_matched_results = trainer.predict(tokenized_datasets['validation_matched'])
val_matched_metrics = val_matched_results.metrics
matched_accuracy = val_matched_metrics.get('test_accuracy', 0)
print(f"Validation Matched Accuracy: {matched_accuracy:.4f}")

# validation_mismatched
print("\nEvaluating on validation_mismatched...")
val_mismatched_results = trainer.predict(tokenized_datasets['validation_mismatched'])
val_mismatched_metrics = val_mismatched_results.metrics
mismatched_accuracy = val_mismatched_metrics.get('test_accuracy', 0)
print(f"Validation Mismatched Accuracy: {mismatched_accuracy:.4f}")

# 15. Save the fine-tuned model together with its tokenizer
print("\nSaving model...")
save_directory = "./saved_model_mnli"
trainer.save_model(save_directory)
tokenizer.save_pretrained(save_directory)

In [None]:
# 16. Visualization Functions
def plot_confusion_matrix_nli(y_true, y_pred, labels, title='MNLI Confusion Matrix',
                              save_path='confusion_matrix_mnli.png'):
    """Draw, save and show a confusion-matrix heatmap.

    Args:
        y_true: Gold class ids.
        y_pred: Predicted class ids.
        labels: Class names used as tick labels; position ``i`` is taken to be
            the name of class id ``i``.
        title: Figure title.
        save_path: File the figure is written to before it is shown.

    Returns:
        None.
    """
    # Fix the matrix to one row/column per entry of `labels`. Without an
    # explicit label list the matrix only covers ids that occur in the data,
    # so a class absent from both inputs would shrink it and shift the ticks.
    class_ids = list(range(len(labels)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

def plot_training_history_nli(trainer_state):
    """Plot logged loss curves and validation accuracy side by side.

    Args:
        trainer_state: Object whose ``log_history`` is a list of log records
            convertible to a DataFrame with a 'step' column.

    Returns:
        None. When ``log_history`` is empty nothing is drawn; otherwise the
        figure is saved to 'training_history_mnli.png' and shown.
    """
    log_records = trainer_state.log_history
    if not log_records:
        return

    history = pd.DataFrame(log_records)

    def draw_curve(column, legend_label, **line_style):
        # Only the records that actually logged `column` carry a value for it.
        if column in history.columns:
            logged_rows = history[history[column].notna()]
            plt.plot(logged_rows['step'], logged_rows[column], label=legend_label, **line_style)

    def decorate_panel(panel_title, y_label):
        plt.title(panel_title)
        plt.xlabel('Steps')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(12, 4))

    # Left panel: training and validation loss
    plt.subplot(1, 2, 1)
    draw_curve('loss', 'Training Loss')
    draw_curve('eval_loss', 'Validation Loss')
    decorate_panel('Training and Validation Loss', 'Loss')

    # Right panel: validation accuracy
    plt.subplot(1, 2, 2)
    draw_curve('eval_accuracy', 'Validation Accuracy', color='green')
    decorate_panel('Validation Accuracy', 'Accuracy')

    plt.tight_layout()
    plt.savefig('training_history_mnli.png')
    plt.show()

# Gold labels and predicted class ids for validation_matched
labels = val_matched_results.label_ids
predictions = np.argmax(val_matched_results.predictions, axis=1)

# Class names for MNLI, indexed by class id
class_names = ["entailment", "neutral", "contradiction"]

# Per-class precision / recall / F1
print("\nDetailed Classification Report (validation_matched):")
report_text = classification_report(labels, predictions, target_names=class_names, digits=4)
print(report_text)

# Confusion matrix for the matched validation set
plot_confusion_matrix_nli(labels, predictions, class_names, title='MNLI - Confusion Matrix (Matched)')

# Loss / accuracy curves recorded during training
plot_training_history_nli(trainer.state)

# 17. Inference Function untuk NLI
def predict_nli(premise, hypothesis, model, tokenizer, device):
    """Classify one premise/hypothesis pair and describe the outcome.

    Args:
        premise: Premise text.
        hypothesis: Hypothesis text.
        model: Model called with the tokenized inputs; its output must expose
            ``logits``.
        tokenizer: Callable that encodes the pair into PyTorch tensors.
        device: Device the input tensors are moved to.

    Returns:
        dict with the (possibly shortened) input texts, the predicted class id
        and name, its probability, the full probability vector, the classes
        ranked by probability, and a textual interpretation.
    """
    def shorten(text):
        # Texts longer than 100 characters are cut and marked with an ellipsis.
        return text[:100] + "..." if len(text) > 100 else text

    class_names = ["entailment", "neutral", "contradiction"]

    encoded = tokenizer(
        premise,
        hypothesis,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt"
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**batch)

    # Softmax over the class dimension turns logits into probabilities.
    class_probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(class_probs, dim=1).item()
    probabilities = class_probs[0].cpu().numpy()

    # Class indices ordered from most to least probable.
    ranked_indices = np.argsort(probabilities)[-3:][::-1]
    top3_predictions = [(class_names[i], probabilities[i]) for i in ranked_indices]

    predicted_label = class_names[predicted_class]
    confidence = probabilities[predicted_class]

    return {
        "premise": shorten(premise),
        "hypothesis": shorten(hypothesis),
        "predicted_class": predicted_class,
        "predicted_label": predicted_label,
        "confidence": confidence,
        "probabilities": probabilities,
        "top3_predictions": top3_predictions,
        "interpretation": get_nli_interpretation(predicted_label, confidence)
    }

def get_nli_interpretation(label, confidence):
    """Return a one-sentence reading of an NLI label plus a strength tag.

    Args:
        label: Relation name; names other than 'entailment', 'neutral' and
            'contradiction' yield "Unknown relationship".
        confidence: Score of the prediction. Above 0.9 is tagged "(strong)",
            above 0.7 "(moderate)", anything else "(weak)".

    Returns:
        The interpretation sentence followed by the strength tag.
    """
    interpretations = {
        "entailment": "The hypothesis follows from the premise.",
        "neutral": "The hypothesis might be true given the premise, but it's not necessarily entailed.",
        "contradiction": "The hypothesis contradicts the premise."
    }

    # Thresholds are checked from highest to lowest; the first one exceeded wins.
    strength = " (weak)"
    for threshold, tag in ((0.9, " (strong)"), (0.7, " (moderate)")):
        if confidence > threshold:
            strength = tag
            break

    if label in interpretations:
        sentence = interpretations[label]
    else:
        sentence = "Unknown relationship"
    return sentence + strength

# 18. Test inference on a handful of hand-written examples
print("\n" + "="*80)
print("Natural Language Inference Test Examples")
print("="*80)

# Each entry pairs a premise/hypothesis with the relation a correct model
# should predict; "expected" drives the ✓/✗ flag printed by the demo loop.
test_examples = [
    {
        "premise": "A man is playing guitar on stage.",
        "hypothesis": "A musician is performing.",
        "expected": "entailment"
    },
    {
        "premise": "The cat is sleeping on the sofa.",
        "hypothesis": "The dog is sleeping on the sofa.",
        "expected": "contradiction"
    },
    {
        # Buying apples necessarily means purchasing fruit, so the hypothesis
        # is entailed (this example was previously marked "neutral").
        "premise": "She bought apples from the market.",
        "hypothesis": "She purchased fruits.",
        "expected": "entailment"
    },
    {
        "premise": "All students passed the exam.",
        "hypothesis": "No student failed the exam.",
        "expected": "entailment"
    },
    {
        "premise": "The restaurant was empty.",
        "hypothesis": "The restaurant was full of customers.",
        "expected": "contradiction"
    }
]

print("\nTesting NLI predictions:")
print("-" * 80)

for example_number, example in enumerate(test_examples, start=1):
    expected_label = example["expected"]
    result = predict_nli(example["premise"], example["hypothesis"], model, tokenizer, device)
    predicted_label = result['predicted_label']

    print(f"\nExample {example_number}:")
    print(f"Premise: {result['premise']}")
    print(f"Hypothesis: {result['hypothesis']}")
    print(f"Predicted: {predicted_label} (Expected: {expected_label})")
    print(f"Confidence: {result['confidence']:.2%}")
    print(f"Interpretation: {result['interpretation']}")

    # Flag whether the prediction agrees with the hand-assigned expectation.
    if predicted_label == expected_label:
        match = "✓"
    else:
        match = "✗"
    print(f"Correct: {match}")

    print("Top 3 predictions:")
    for ranked_label, ranked_prob in result['top3_predictions']:
        print(f"  - {ranked_label}: {ranked_prob:.2%}")

# 19. Per-relation analysis
def analyze_relation_performance(model, tokenizer, device, test_data, num_samples=50, seed=None):
    """Estimate per-relation accuracy on a random sample of examples.

    Args:
        model: Model forwarded to ``predict_nli``.
        tokenizer: Tokenizer forwarded to ``predict_nli``.
        device: Device forwarded to ``predict_nli``.
        test_data: Indexable collection of examples with 'premise',
            'hypothesis' and integer 'label' entries.
        num_samples: Maximum number of examples drawn (without replacement).
        seed: Optional seed for a reproducible sample. ``None`` keeps drawing
            from NumPy's global random state.

    Returns:
        dict mapping each relation name to ``{"correct": int, "total": int}``.
        The same figures are printed.
    """
    label_names = ["entailment", "neutral", "contradiction"]
    relations = {name: {"correct": 0, "total": 0} for name in label_names}

    # Sample data
    n_draw = min(num_samples, len(test_data))
    rng = np.random if seed is None else np.random.RandomState(seed)
    sample_indices = rng.choice(len(test_data), n_draw, replace=False) if n_draw > 0 else []

    for idx in sample_indices:
        sample = test_data[int(idx)]
        true_label = sample['label']

        # Skip ids outside 0..2 (e.g. a -1 placeholder, presumably what
        # unlabeled rows carry): list indexing would otherwise count a
        # negative id as 'contradiction'.
        if not 0 <= true_label < len(label_names):
            continue
        true_label_name = label_names[true_label]

        result = predict_nli(sample['premise'], sample['hypothesis'], model, tokenizer, device)

        relations[true_label_name]["total"] += 1
        if result['predicted_label'] == true_label_name:
            relations[true_label_name]["correct"] += 1

    # Calculate and display results
    print("\n" + "="*80)
    print("Per-Relation Performance Analysis")
    print("="*80)

    for relation, stats in relations.items():
        if stats["total"] > 0:
            accuracy = stats["correct"] / stats["total"]
            print(f"{relation.capitalize()}: {stats['correct']}/{stats['total']} = {accuracy:.2%}")
        else:
            print(f"{relation.capitalize()}: No samples")

    return relations

# Per-relation accuracy on a random sample of the matched validation set
analyze_relation_performance(
    model=model,
    tokenizer=tokenizer,
    device=device,
    test_data=dataset['validation_matched'],
)

# 20. Error Analysis - Contoh salah prediksi
def find_error_cases(model, tokenizer, device, test_data, num_cases=5, pool_size=100, seed=None):
    """Collect and print misclassified examples from a random sample.

    Args:
        model: Model forwarded to ``predict_nli``.
        tokenizer: Tokenizer forwarded to ``predict_nli``.
        device: Device forwarded to ``predict_nli``.
        test_data: Indexable collection of examples with 'premise',
            'hypothesis' and integer 'label' entries.
        num_cases: Stop once this many misclassified examples were found.
        pool_size: Maximum number of examples drawn (without replacement) to
            search for errors.
        seed: Optional seed for a reproducible sample. ``None`` keeps drawing
            from NumPy's global random state.

    Returns:
        list of dicts, one per misclassified example, holding the texts, the
        true and predicted label names, the confidence and the probabilities.
        The same cases are printed.
    """
    print("\n" + "="*80)
    print("Error Analysis - Misclassified Examples")
    print("="*80)

    label_names = ["entailment", "neutral", "contradiction"]
    error_cases = []

    n_draw = min(pool_size, len(test_data))
    rng = np.random if seed is None else np.random.RandomState(seed)
    sample_indices = rng.choice(len(test_data), n_draw, replace=False) if n_draw > 0 else []

    for idx in sample_indices:
        if len(error_cases) >= num_cases:
            break

        sample = test_data[int(idx)]
        premise = sample['premise']
        hypothesis = sample['hypothesis']
        true_label = sample['label']

        # Skip ids outside 0..2 (e.g. a -1 placeholder, presumably what
        # unlabeled rows carry): they have no gold relation to compare against.
        if not 0 <= true_label < len(label_names):
            continue
        true_label_name = label_names[true_label]

        result = predict_nli(premise, hypothesis, model, tokenizer, device)

        if result['predicted_label'] != true_label_name:
            error_cases.append({
                "premise": premise,
                "hypothesis": hypothesis,
                "true_label": true_label_name,
                "predicted_label": result['predicted_label'],
                "confidence": result['confidence'],
                "probabilities": result['probabilities']
            })

    # Display error cases
    for i, case in enumerate(error_cases, 1):
        print(f"\nError Case {i}:")
        print(f"Premise: {case['premise']}")
        print(f"Hypothesis: {case['hypothesis']}")
        print(f"True: {case['true_label']}, Predicted: {case['predicted_label']}")
        print(f"Confidence: {case['confidence']:.2%}")
        print(f"Probabilities: entailment={case['probabilities'][0]:.2%}, "
              f"neutral={case['probabilities'][1]:.2%}, "
              f"contradiction={case['probabilities'][2]:.2%}")

    return error_cases

# Show a few misclassified examples from the matched validation set
find_error_cases(
    model=model,
    tokenizer=tokenizer,
    device=device,
    test_data=dataset['validation_matched'],
)


# 22. Advanced analysis - genre-wise performance
if 'genre' in dataset['validation_matched'].features:
    print("\n" + "="*80)
    print("Genre-wise Performance Analysis")
    print("="*80)

    genres = np.asarray(list(dataset['validation_matched']['genre']))
    unique_genres = np.unique(genres)

    # validation_matched was already scored in the evaluation step
    # (val_matched_results), so slice those predictions per genre instead of
    # tokenizing and running inference again for every genre.
    # NOTE(review): this relies on the predictions being returned in dataset
    # row order - the length check below guards the alignment.
    matched_predictions = np.argmax(val_matched_results.predictions, axis=1)
    matched_labels = np.asarray(val_matched_results.label_ids)
    assert len(genres) == len(matched_labels), "predictions are not aligned with validation_matched"

    genre_performance = {}

    for genre in unique_genres:
        in_genre = genres == genre
        genre_performance[genre] = {
            "accuracy": accuracy_score(matched_labels[in_genre], matched_predictions[in_genre]),
            "samples": int(in_genre.sum())
        }

    # Display results
    print("\nGenre Performance:")
    for genre, stats in genre_performance.items():
        print(f"  {genre}: {stats['accuracy']:.4f} accuracy ({stats['samples']} samples)")

# Closing summary
closing_rule = "=" * 80
print("\n" + closing_rule)
print("MNLI Fine-tuning Complete!")
print("Model saved to: ./saved_model_mnli")
print("Inference script: inference_mnli.py")
print(closing_rule)