## 1. Setup and Imports

In [None]:
import os
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# The project root is taken to be the parent of the current working directory,
# i.e. the notebook is expected to be run from a folder directly under the root.
PROJECT_ROOT = os.path.dirname(os.getcwd())

# Put <root>/src at the front of the import path so the project modules
# imported below (model, utils, config) can be found.
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')
sys.path.insert(0, SRC_DIR)

from transformers import DistilBertTokenizer
from model import DisasterTweetClassifier, load_model
from utils import preprocess_tweet, predict_single_tweet, get_device
from config import MODEL_NAME, MAX_LENGTH, DROPOUT

# Setup: select seaborn's 'white' style for subsequent plots, then confirm
# that the imports and path configuration above ran without errors.
sns.set_style('white')
print("Setup complete!")

## 2. Load Pre-trained Model

In [None]:
# Compute device, as selected by the project's get_device() helper
device = get_device()

# Load the tokenizer that matches the base model named in config
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
print(f"Tokenizer loaded: {MODEL_NAME}")

# Location of the fine-tuned classifier checkpoint inside the project tree
checkpoint_path = os.path.join(PROJECT_ROOT, 'checkpoints', 'disaster_tweet_classifier.pt')

if not os.path.exists(checkpoint_path):
    # Stop here with an actionable message. Merely printing the error would
    # leave `model` undefined, and every later cell would then fail with an
    # unrelated-looking NameError.
    raise FileNotFoundError(
        f"Model checkpoint not found at {checkpoint_path}. "
        "Please download the pre-trained model from the link in README.md"
    )

# Rebuild the classifier (2 output classes, configured dropout) from the checkpoint
model = load_model(checkpoint_path, device, num_classes=2, dropout=DROPOUT)
print(f"Model loaded from: {checkpoint_path}")

## 3. Single Tweet Classification

Enter any tweet to classify it as disaster or not disaster.

In [None]:
def classify_tweet(tweet_text):
    """Classify one tweet, print a formatted report, and return the result.

    Uses the notebook-level ``model``, ``tokenizer``, ``device`` and
    ``MAX_LENGTH`` defined in earlier cells.

    Args:
        tweet_text: The tweet text to classify.

    Returns:
        A ``(label, confidence)`` tuple as produced by
        ``predict_single_tweet``.
    """
    label, confidence = predict_single_tweet(
        model, tokenizer, tweet_text, device, MAX_LENGTH
    )

    heavy_rule = "=" * 70
    light_rule = "-" * 70
    report_lines = [
        heavy_rule,
        f"Tweet: {tweet_text}",
        light_rule,
        f"Prediction: {label}",
        f"Confidence: {confidence:.1%}",
        heavy_rule,
    ]
    print(*report_lines, sep="\n")

    return label, confidence

In [None]:
# Example: Disaster tweet (a literal report of a real emergency).
# classify_tweet prints the report and returns (label, confidence).
classify_tweet("BREAKING: Massive earthquake hits California, buildings collapsed")

In [None]:
# Example: Non-disaster tweet (mentions the word "disaster" but describes a movie).
# classify_tweet prints the report and returns (label, confidence).
classify_tweet("Just watched a disaster movie, it was so good!")

## 4. Batch Classification Demo

Classify multiple sample tweets and visualize results.

In [None]:
# Sample tweets for demonstration
sample_tweets = [
    # Disaster tweets
    "BREAKING: Massive earthquake hits California, buildings collapsed",
    "Forest fire spreading rapidly in Oregon, evacuations ordered",
    "Tsunami warning issued for coastal areas after underwater earthquake",
    "Multiple casualties reported in train derailment accident",
    "Hurricane approaching Florida coast, residents urged to evacuate",

    # Non-disaster tweets
    "Just watched a disaster movie, it was so good!",
    "My kitchen is a war zone after cooking dinner",
    "This concert is absolutely fire! Best night ever!",
    "The new iPhone launch was explosive! So many new features",
    "Traffic is so bad today, it's like a disaster out here"
]

# Classify all tweets, keeping a display copy of the text that is cut to
# 50 characters (plus an ellipsis) when the tweet is longer than that.
results = []
for tweet in sample_tweets:
    label, confidence = predict_single_tweet(model, tokenizer, tweet, device, MAX_LENGTH)
    results.append({
        'tweet': (tweet[:50] + '...') if len(tweet) > 50 else tweet,
        'prediction': label,
        'confidence': confidence
    })

# Display results
results_df = pd.DataFrame(results)

# The icons are written as Unicode escapes so the source stays ASCII-only and
# cannot be garbled by an encoding round-trip (the previous literals had been
# corrupted into mojibake).
DISASTER_ICON = "\U0001F6A8"   # police car light
NOT_DISASTER_ICON = "\u2705"   # check mark button

print("Classification Results:")
print("=" * 80)
for _, row in results_df.iterrows():
    # NOTE(review): assumes predict_single_tweet labels positives as "DISASTER" - confirm in utils
    icon = DISASTER_ICON if row['prediction'] == "DISASTER" else NOT_DISASTER_ICON
    print(f"{icon} [{row['confidence']:.1%}] {row['prediction']:12s} | {row['tweet']}")
print("=" * 80)

In [None]:
# Horizontal bar chart: one bar per sample tweet, bar length = confidence
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(12, 6))

# Coral marks rows predicted as 'DISASTER'; every other row is steel blue
palette = {True: 'coral', False: 'steelblue'}
colors = [palette[p == 'DISASTER'] for p in results_df['prediction']]

row_positions = range(len(results_df))
bars = ax.barh(row_positions, results_df['confidence'], color=colors)

# Label each bar with the (possibly truncated) tweet text
ax.set_yticks(row_positions)
ax.set_yticklabels(results_df['tweet'], fontsize=10)
ax.set(
    xlabel='Confidence Score',
    title='Disaster Tweet Classification Results',
    xlim=(0, 1),
)

# Legend entries are built by hand because the bars carry no per-class label
legend_elements = [
    Patch(facecolor=face, label=text)
    for face, text in (('coral', 'Disaster'), ('steelblue', 'Not Disaster'))
]
ax.legend(handles=legend_elements, loc='lower right')

plt.tight_layout()

# Write the figure into <project root>/results, creating the folder if needed
results_dir = os.path.join(PROJECT_ROOT, 'results')
os.makedirs(results_dir, exist_ok=True)
results_path = os.path.join(results_dir, 'demo_classification_results.png')
plt.savefig(results_path, dpi=150, bbox_inches='tight')
print(f"Results saved to: {results_path}")

plt.show()

## 5. Interactive Classification

Try your own tweets below!

In [None]:
# Enter your own tweet here! Replace the placeholder string below with any text.
your_tweet = "Enter your tweet here to classify it"

# Prints the formatted report; the call also returns (label, confidence)
classify_tweet(your_tweet)

## 6. Model Information

In [None]:
# Summarize the loaded model's configuration and size
print("Model Architecture:")
print("=" * 50)
print(f"Base Model: {MODEL_NAME}")
print(f"Max Sequence Length: {MAX_LENGTH}")
print(f"Dropout: {DROPOUT}")
# Count every element of every parameter tensor registered on the model
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Device: {device}")
print("\nClassification Head:")
# The arrow is written as a Unicode escape (U+2192, rightwards arrow) so it
# cannot be garbled by an encoding round-trip, as the previous literal was.
print("  " + " \u2192 ".join(["768", "Dropout", "256", "ReLU", "Dropout", "2"]))

---

## Summary

This demo showcases the Disaster Tweet Classifier's ability to:

1. **Accurately classify** tweets as disaster-related or not
2. **Handle metaphorical language** (e.g., "This concert is fire!" ≠ disaster)
3. **Provide confidence scores** for each prediction

The model achieves **~79% F1 score** on the validation set, outperforming traditional TF-IDF baselines.