# In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Environmental Claims Classifier
-------------------------------
Tiny-BERT architecture with:
- Lightweight text preprocessing
- Class imbalance handling
- Comprehensive metrics tracking
- Hyperparameter optimization

Runtime: ≈ 15 min on CPU-only, faster with GPU
"""

import os

def install_dependencies(install_cpu_torch: bool = True) -> bool:
    """Install the third-party packages this script needs, using pip.

    pip is invoked as ``sys.executable -m pip`` so the packages are installed
    into the interpreter that is actually running this script (a bare ``pip``
    on PATH may belong to a different Python). Commands are passed as argument
    lists, so no shell is involved.

    Args:
        install_cpu_torch: When True (the default, matching the previous
            behaviour), install the CPU-only ``torch==2.6.0+cpu`` wheels (plus
            torchvision/torchaudio) from the PyTorch CPU index. Pass False to
            keep an existing PyTorch install, e.g. a CUDA build on a GPU
            machine; replacing it with the ``+cpu`` wheel makes
            ``torch.cuda.is_available()`` False for the rest of the script.

    Returns:
        True if every pip invocation exited with status 0, False otherwise.
        A failure is reported on stderr but does not raise, so a run with the
        packages already installed can still proceed (e.g. offline).
    """
    import subprocess
    import sys

    pip_install = [sys.executable, "-m", "pip", "install", "-qU"]
    commands = []
    if install_cpu_torch:
        commands.append(
            pip_install
            + [
                "torch==2.6.0+cpu",
                "torchvision",
                "torchaudio",
                "--index-url",
                "https://download.pytorch.org/whl/cpu",
            ]
        )
    commands.append(
        pip_install
        + [
            "transformers",
            "datasets",
            "pytorch-lightning",
            "torchmetrics",
            "matplotlib",
            "pandas",
            "scikit-learn",
        ]
    )

    print("Installing dependencies...")
    all_ok = True
    for cmd in commands:
        completed = subprocess.run(cmd, check=False)
        if completed.returncode != 0:
            # Previously the exit status was discarded; surface it so a broken
            # environment is visible before the imports below fail.
            all_ok = False
            print(
                f"WARNING: pip exited with status {completed.returncode}: "
                f"{' '.join(cmd)}",
                file=sys.stderr,
            )
    return all_ok

# Install dependencies
# NOTE: this call runs at module level, before the third-party imports below,
# so executing the file top to bottom (e.g. as a notebook cell) installs the
# packages before they are imported. Importing this module therefore also
# runs pip as a side effect.
install_dependencies()

import re
import logging
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import datasets
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from sklearn.metrics import classification_report, confusion_matrix


# Root logging configuration: INFO and above, each record prefixed with a
# timestamp, the emitting logger's name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)


class Config:
    """Class-level settings shared by data loading, training and evaluation.

    Every setting is a plain class attribute, so it can be read from an
    instance (``config.SEED``) or straight from the class
    (``Config.WARMUP_RATIO``).
    """

    # ---- Reproducibility and model ---------------------------------------
    SEED: int = 42
    MODEL_NAME: str = "prajjwal1/bert-tiny"
    MAX_SEQ_LEN: int = 512      # tokenizer truncation / padding length

    # ---- Data --------------------------------------------------------------
    # Identifier handed to ``datasets.load_dataset``.
    DATASET_NAME: str = "climatebert/environmental_claims"
    BATCH_SIZE: int = 64

    # ---- Training loop -----------------------------------------------------
    MAX_EPOCHS: int = 6
    WARMUP_RATIO: float = 0.1   # fraction of optimizer steps spent in LR warmup
    LOG_EVERY_N_STEPS: int = 10

    # ---- Early stopping (maximise validation macro-F1) ---------------------
    EARLY_STOP_METRIC: str = "val_f1"
    EARLY_STOP_MODE: str = "max"
    EARLY_STOP_PATIENCE: int = 2

    # ---- Hyperparameter grid -----------------------------------------------
    # One training run per entry; each dict is passed as keyword arguments to
    # the classifier. Slicing the encoder layer list stops at its end, so a
    # freeze_layers value above the encoder depth freezes the whole encoder.
    # NOTE(review): prajjwal1/bert-tiny is published as a 2-layer encoder; if
    # so, freeze_layers=4 freezes the same layers as freeze_layers=2 — confirm.
    HYPERPARAMETER_CONFIGS: List[Dict[str, Any]] = [
        {"lr": 3e-5, "freeze_layers": 0, "weight_decay": 0.01},
        {"lr": 5e-5, "freeze_layers": 2, "weight_decay": 0.0},
        {"lr": 2e-5, "freeze_layers": 0, "weight_decay": 0.01},
        {"lr": 3e-5, "freeze_layers": 4, "weight_decay": 0.0},
    ]


class DataProcessor:
    """Load, clean, tokenize and batch the environmental-claims dataset.

    Reads ``MODEL_NAME``, ``DATASET_NAME``, ``MAX_SEQ_LEN`` and ``BATCH_SIZE``
    from the supplied config.
    """

    # Patterns used by ``clean_text``; compiled once rather than per example.
    _URL_RE = re.compile(r"http\S+")
    # Keeps only a-z, 0-9, whitespace and the characters . , ! ? ' — this
    # removes non-ASCII characters AND other ASCII punctuation such as
    # %, -, $, ( and ).
    _DISALLOWED_CHARS_RE = re.compile(r"[^a-z0-9\s\.,!?']")
    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(self, config: Config):
        self.config = config
        # Tokenizer matching the model that will consume the token ids.
        self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

    def load_dataset(self) -> Tuple[datasets.Dataset, datasets.Dataset]:
        """Load the dataset and return its ``train`` and ``validation`` splits."""
        logger.info(f"Loading dataset: {self.config.DATASET_NAME}")
        env_ds = datasets.load_dataset(self.config.DATASET_NAME)
        return env_ds["train"], env_ds["validation"]

    @staticmethod
    def clean_text(text: str) -> str:
        """Lower-case ``text``, drop URLs and disallowed characters, squeeze spaces.

        URLs (anything starting with ``http`` up to the next whitespace) and
        every character outside ``[a-z0-9\\s.,!?']`` are replaced by a space;
        runs of whitespace then collapse to one space and the result is
        stripped.
        """
        text = text.lower()
        text = DataProcessor._URL_RE.sub(" ", text)
        text = DataProcessor._DISALLOWED_CHARS_RE.sub(" ", text)
        text = DataProcessor._WHITESPACE_RE.sub(" ", text).strip()
        return text

    def clean_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Clean the ``text`` field of a single example.

        Used with ``Dataset.map`` without ``batched=True``, so
        ``batch["text"]`` is one string, not a list.
        """
        return {"text": self.clean_text(batch["text"])}

    def tokenize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize a batch of texts (``batched=True`` map, so a list of strings).

        Every sequence is truncated/padded to ``MAX_SEQ_LEN`` so the default
        DataLoader collate can stack fixed-size tensors.
        """
        return self.tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=self.config.MAX_SEQ_LEN
        )

    def prepare_datasets(self) -> Tuple[datasets.Dataset, datasets.Dataset]:
        """Load, clean and tokenize both splits; return them in torch format."""
        train_ds, val_ds = self.load_dataset()

        logger.info("Cleaning text data")
        train_ds = train_ds.map(self.clean_batch)
        val_ds = val_ds.map(self.clean_batch)

        # The raw text column is dropped once tokenized.
        logger.info("Tokenizing datasets")
        train_ds = train_ds.map(self.tokenize_batch, batched=True, remove_columns=["text"])
        val_ds = val_ds.map(self.tokenize_batch, batched=True, remove_columns=["text"])

        # Indexing now yields torch tensors, which the DataLoader collates.
        train_ds.set_format("torch")
        val_ds.set_format("torch")

        return train_ds, val_ds

    def create_dataloaders(self, train_ds: datasets.Dataset, val_ds: datasets.Dataset) -> Tuple[DataLoader, DataLoader]:
        """Create a shuffled training loader and an ordered validation loader."""
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only run
        # it does nothing except trigger a PyTorch warning.
        pin = torch.cuda.is_available()
        train_loader = DataLoader(
            train_ds,
            batch_size=self.config.BATCH_SIZE,
            shuffle=True,
            pin_memory=pin
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=self.config.BATCH_SIZE,
            pin_memory=pin
        )
        return train_loader, val_loader

    def calculate_class_weights(
        self,
        train_ds: datasets.Dataset,
        num_classes: Optional[int] = None,
    ) -> torch.Tensor:
        """Return inverse-frequency ("balanced") class weights for the loss.

        ``weight[c] = n_samples / (n_classes * count[c])``, so a perfectly
        balanced dataset gives all-ones weights.

        Args:
            train_ds: Dataset with an integer ``label`` column.
            num_classes: Minimum length of the weight vector. Pass the model's
                class count so a class absent from ``train_ds`` still gets an
                entry; ``None`` sizes the vector from the largest label seen.

        Returns:
            Float tensor of shape ``(n_classes,)``.
        """
        # int() per item works whether the column comes back as a tensor, a
        # list, or a lazy column object, depending on the datasets version.
        labels = np.fromiter((int(label) for label in train_ds["label"]), dtype=np.int64)
        cls_counts = np.bincount(labels, minlength=num_classes or 0)
        # Guard against division by zero (inf weight) for a class with no samples.
        safe_counts = np.maximum(cls_counts, 1)
        weights = torch.tensor(
            cls_counts.sum() / (len(cls_counts) * safe_counts),
            dtype=torch.float
        )
        return weights


class EnvironmentalClaimClassifier(pl.LightningModule):
    """Lightning wrapper around a Hugging Face sequence-classification model.

    Trains with (optionally class-weighted) cross-entropy and tracks
    macro-averaged accuracy and F1 on the validation set.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 2,
        lr: float = 3e-5,
        freeze_layers: int = 0,
        weight_decay: float = 0.01,
        class_weights: Optional[torch.Tensor] = None,
        warmup_ratio: Optional[float] = None,
    ):
        """Initialize the classifier.

        Args:
            model_name: Model id/path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_classes: Number of output classes.
            lr: AdamW learning rate.
            freeze_layers: Number of leading encoder layers to freeze. A value
                above the encoder depth is clamped to the depth, with a
                warning, and the log reports the number actually frozen.
            weight_decay: AdamW weight decay.
            class_weights: Optional per-class weights for ``CrossEntropyLoss``.
            warmup_ratio: Fraction of total optimizer steps used for linear
                LR warmup; ``None`` falls back to ``Config.WARMUP_RATIO``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_classes
        )

        if freeze_layers > 0:
            # base_model resolves to the backbone (e.g. ``.bert``) without
            # hard-coding the attribute name.
            encoder_layers = self.model.base_model.encoder.layer
            n_frozen = min(freeze_layers, len(encoder_layers))
            if n_frozen < freeze_layers:
                logger.warning(
                    "Requested freezing %d encoder layers but %s has only %d; "
                    "freezing all %d.",
                    freeze_layers, model_name, len(encoder_layers), n_frozen,
                )
            logger.info(f"Freezing {n_frozen} encoder layers")
            for param in encoder_layers[:n_frozen].parameters():
                param.requires_grad_(False)

        self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
        self.acc = MulticlassAccuracy(num_classes=num_classes, average="macro")
        self.f1 = MulticlassF1Score(num_classes=num_classes, average="macro")

    def forward(self, **x):
        """Return the model's classification logits for the given input tensors."""
        return self.model(**x).logits

    def _shared_step(self, batch):
        """Compute loss, predictions and labels for one batch.

        The batch dict is not modified: the ``label`` entry is read and
        excluded from the model inputs rather than popped.
        """
        labels = batch["label"]
        inputs = {k: v for k, v in batch.items() if k != "label"}
        logits = self(**inputs)
        loss = self.loss_fn(logits, labels)
        preds = logits.argmax(dim=1)
        return loss, preds, labels

    def training_step(self, batch, _):
        """Compute and log the training loss for one batch."""
        loss, _, labels = self._shared_step(batch)
        # Explicit batch_size: Lightning cannot infer it reliably from a dict batch.
        self.log("train_loss", loss, prog_bar=True, batch_size=labels.size(0))
        return loss

    def validation_step(self, batch, _):
        """Accumulate validation metrics and log epoch-level loss/acc/F1."""
        loss, preds, labels = self._shared_step(batch)

        self.acc.update(preds, labels)
        self.f1.update(preds, labels)

        # Logging the Metric objects lets Lightning compute/reset them per epoch.
        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False,
                 batch_size=labels.size(0))
        self.log("val_acc", self.acc, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_f1", self.f1, prog_bar=True, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """AdamW over the trainable parameters plus a per-step linear warmup/decay schedule."""
        # Frozen parameters never receive gradients; leaving them out of the
        # optimizer is numerically identical and makes the intent explicit.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )

        ratio = self.hparams.get("warmup_ratio")
        if ratio is None:
            ratio = Config.WARMUP_RATIO
        total_steps = self.trainer.estimated_stepping_batches
        warmup_steps = int(ratio * total_steps)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"}
        }


class Evaluator:
    """Evaluate a trained classifier and plot experiment results."""

    @staticmethod
    def evaluate_model(model, val_loader):
        """Run ``model`` over ``val_loader`` and compute validation metrics.

        Args:
            model: Module callable as ``model(**inputs)`` that returns logits
                with the class dimension at dim 1.
            val_loader: Iterable of dict batches containing a ``"label"``
                tensor plus the model's input tensors.

        Returns:
            Dict with ``"report"`` (sklearn text report, 3 digits), ``"cm"``
            (confusion matrix over labels 0 and 1), and ``"acc"`` / ``"f1"``
            (accuracy and macro-averaged F1 as full-precision floats).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.eval().to(device)
        y_true: List[int] = []
        y_pred: List[int] = []

        with torch.no_grad():
            for batch in val_loader:
                # Build the inputs without mutating the loader's batch dict.
                inputs = {k: v.to(device) for k, v in batch.items() if k != "label"}
                logits = model(**inputs)
                # Plain Python ints keep sklearn's input handling unambiguous.
                y_true.extend(batch["label"].tolist())
                y_pred.extend(logits.argmax(dim=1).cpu().tolist())

        report = classification_report(y_true, y_pred, digits=3, zero_division=0)
        # Read the numbers from the structured report instead of regex-parsing
        # the formatted text, which was fragile and truncated to 3 digits.
        summary = classification_report(y_true, y_pred, zero_division=0, output_dict=True)
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

        acc = float(summary["accuracy"])
        f1 = float(summary["macro avg"]["f1-score"])

        return {"report": report, "cm": cm, "acc": acc, "f1": f1}

    @staticmethod
    def plot_results(results_df, best_cm):
        """Show per-run accuracy/F1 bars and the best run's confusion matrix.

        Args:
            results_df: DataFrame with one row per run and ``"acc"`` and
                ``"f1"`` columns.
            best_cm: 2x2 confusion matrix (rows = true label, columns =
                predicted label, as returned by sklearn).
        """
        # Grouped bar chart: accuracy and F1 side by side for each run.
        fig, ax = plt.subplots(figsize=(8, 5))
        x = np.arange(len(results_df))
        width = 0.35

        ax.bar(x - width/2, results_df["acc"], width, label="Accuracy")
        ax.bar(x + width/2, results_df["f1"], width, label="F1-Score")

        ax.set_xticks(x)
        ax.set_xticklabels([f"Run {i}" for i in x])
        ax.set_ylim(0, 1)
        ax.set_xlabel("Experiment Run")
        ax.set_ylabel("Score")
        ax.legend()
        ax.set_title("Validation Metrics by Run")
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()
        plt.show()

        # Confusion matrix heat map for the best run.
        fig, ax = plt.subplots(figsize=(4, 4))
        im = ax.imshow(best_cm, cmap='Blues')

        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels(["No Claim", "Claim"])
        ax.set_yticklabels(["No Claim", "Claim"])
        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("True Label")
        ax.set_title("Best Model Confusion Matrix")

        # White text on dark cells, black on light ones, for readability.
        threshold = best_cm.max() / 2
        for i in range(2):
            for j in range(2):
                text_color = "white" if best_cm[i, j] > threshold else "black"
                ax.text(j, i, best_cm[i, j], ha="center", va="center", color=text_color)

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        plt.show()


def _build_trainer(config: Config) -> pl.Trainer:
    """Create a fresh Trainer, with its own EarlyStopping callback, for one run."""
    early_stopping = pl.callbacks.EarlyStopping(
        monitor=config.EARLY_STOP_METRIC,
        mode=config.EARLY_STOP_MODE,
        patience=config.EARLY_STOP_PATIENCE,
        verbose=False
    )
    return pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=config.MAX_EPOCHS,
        deterministic=True,
        log_every_n_steps=config.LOG_EVERY_N_STEPS,
        callbacks=[early_stopping],
        enable_checkpointing=False,
    )


def train_and_evaluate(config: Config):
    """Train one model per hyperparameter config, evaluate each, and plot results.

    Args:
        config: Settings; ``HYPERPARAMETER_CONFIGS`` defines one run per entry.

    Returns:
        Dict with ``"results_df"`` (one row per run: hyperparameters, ``acc``,
        ``f1``), ``"best_config"`` (the row dict with the highest ``f1``) and
        ``"best_confusion_matrix"``.

    Raises:
        ValueError: If ``config.HYPERPARAMETER_CONFIGS`` is empty.
    """
    if not config.HYPERPARAMETER_CONFIGS:
        raise ValueError("config.HYPERPARAMETER_CONFIGS is empty; nothing to train.")

    pl.seed_everything(config.SEED, workers=True)

    processor = DataProcessor(config)

    logger.info("Preparing datasets")
    train_ds, val_ds = processor.prepare_datasets()
    train_loader, val_loader = processor.create_dataloaders(train_ds, val_ds)

    class_weights = processor.calculate_class_weights(train_ds)
    logger.info(f"Class weights: {class_weights}")

    results = []
    confusion_matrices = []

    for i, hp_config in enumerate(config.HYPERPARAMETER_CONFIGS):
        logger.info(f"\n▶️  Run {i}: {hp_config}")

        # Re-seed per run so head initialisation and shuffle order are the same
        # for every config; metric differences then come from the hyperparameters.
        pl.seed_everything(config.SEED, workers=True)

        model = EnvironmentalClaimClassifier(
            model_name=config.MODEL_NAME,
            class_weights=class_weights,
            **hp_config
        )

        # A new Trainer per run: a reused Trainer keeps its epoch progress and
        # its EarlyStopping callback keeps the previous run's best score and
        # wait count, so later runs would stop at once or stop against stale state.
        trainer = _build_trainer(config)
        trainer.fit(model, train_loader, val_loader)

        eval_results = Evaluator.evaluate_model(model, val_loader)

        results.append({**hp_config, "acc": eval_results["acc"], "f1": eval_results["f1"]})
        confusion_matrices.append(eval_results["cm"])

        print(eval_results["report"])

    results_df = pd.DataFrame(results)

    # Default RangeIndex, so the DataFrame index lines up with the list indices.
    best_idx = results_df.f1.idxmax()
    best_config = results[best_idx]
    best_cm = confusion_matrices[best_idx]

    print("\n=== Results Summary ===")
    print(results_df)
    print(f"\nBest run: {best_config}")

    Evaluator.plot_results(results_df, best_cm)

    return {
        "results_df": results_df,
        "best_config": best_config,
        "best_confusion_matrix": best_cm
    }


if __name__ == "__main__":
    # Entry point (script or notebook cell): run every configured experiment
    # with the default settings. ``config`` and ``results`` stay bound at
    # module level for interactive inspection afterwards.
    config: Config = Config()
    results = train_and_evaluate(config=config)