# In [None]:
import pandas as pd
import torch
from transformers import pipeline
import logging
from typing import Dict, Optional, List
from datetime import datetime
import os
from tqdm import tqdm

# Configure root logging at import time: INFO and above is written both to
# 'maritime_classifier.log' (file handler) and to a default stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler('maritime_classifier.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class MaritimeRiskClassifier:
    """Zero-shot severity classifier for maritime incident records.

    Wraps a ``transformers`` zero-shot-classification pipeline and scores the
    combined headline / risk type / description text against five severity
    labels (level 1 "Low" through level 5 "Catastrophic").
    """

    # Number of leading description characters included in the model input.
    _MAX_DESCRIPTION_CHARS = 200

    # Leading word of a candidate label (the part before " -") -> numeric level.
    _SEVERITY_BY_NAME = {
        "Low": 1,
        "Moderate": 2,
        "High": 3,
        "Critical": 4,
        "Catastrophic": 5,
    }

    # Column order of the DataFrame returned by classify_dataset(); passing it
    # explicitly keeps the columns present even when there are no rows.
    _RESULT_COLUMNS = [
        'headline',
        'Final Classification',
        'severity_level',
        'severity_label',
        'confidence',
        'input_text',
    ]

    def __init__(self, model_name: str = "cross-encoder/nli-distilroberta-base"):
        """
        Load the zero-shot classification pipeline and define severity labels.

        Args:
            model_name (str): Name of the transformer model to use

        Raises:
            Exception: Whatever the pipeline constructor raises is logged and
                re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Maritime Risk Classifier")

        # Initialize device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info("Using device: %s", self.device)

        try:
            self.classifier = pipeline(
                "zero-shot-classification",
                model=model_name,
                # The pipeline takes a device index: 0 for the first GPU, -1 for CPU.
                device=0 if self.device.type == "cuda" else -1,
            )
        except Exception as e:
            self.logger.error("Failed to load model: %s", e)
            raise
        else:
            self.logger.info("Successfully loaded classification model")

        # Severity levels with descriptive labels; the label strings themselves
        # are what the zero-shot model scores the text against.
        self.severity_levels = {
            1: "Low - Minimal impacts",
            2: "Moderate - Some disruption",
            3: "High - Significant operational impact; may involve injuries",
            4: "Critical - Major incident with severe impact, likely with injuries",
            5: "Catastrophic - Severe consequences, likely fatalities or major environmental damage"
        }

        self.candidate_labels = list(self.severity_levels.values())

    def clean_input(self, text: Optional[str]) -> str:
        """
        Normalize an input value to a stripped string.

        Args:
            text: Input text to clean; ``None`` and NaN-like values are
                treated as missing

        Returns:
            str: Stripped text, or an empty string for missing values
        """
        if text is None or pd.isna(text):
            return ""
        return str(text).strip()

    def predict(self, headline: str, description: str, risk_type: str) -> Dict:
        """
        Predict the severity level for a single maritime incident.

        Args:
            headline: Incident headline
            description: Incident description (only the first
                ``_MAX_DESCRIPTION_CHARS`` characters are used)
            risk_type: Type of risk

        Returns:
            Dict with keys ``severity_level``, ``severity_label``,
            ``confidence`` and ``input_text``. If every input is empty, or any
            error occurs, the default prediction (level 1, confidence 0.0) is
            returned instead of raising.
        """
        try:
            # Clean inputs
            headline = self.clean_input(headline)
            description = self.clean_input(description)
            risk_type = self.clean_input(risk_type)
            desc_text = description[:self._MAX_DESCRIPTION_CHARS]

            # Check emptiness on the parts, not on the combined string: the
            # " | " separators would otherwise make it look non-empty and the
            # model would be asked to classify bare separators.
            if not (headline or risk_type or desc_text):
                self.logger.warning("Empty input received")
                return self._get_default_prediction()

            text = f"{headline} | {risk_type} | {desc_text}".strip()

            # Get prediction; the top entry of 'labels'/'scores' is used as
            # the predicted label and its confidence.
            result = self.classifier(text, self.candidate_labels)
            predicted_label = result['labels'][0]

            return {
                'severity_level': self._get_severity_level(predicted_label),
                'severity_label': predicted_label,
                'confidence': result['scores'][0],
                'input_text': text
            }

        except Exception as e:
            # Best-effort by design: a failed prediction degrades to the
            # default result so a batch run is not aborted by one record.
            self.logger.error("Prediction error: %s", e)
            return self._get_default_prediction()

    def _get_severity_level(self, label: str) -> int:
        """Extract the numeric severity level from a candidate label.

        Raises:
            KeyError: If the text before " -" is not a known severity name.
        """
        return self._SEVERITY_BY_NAME[label.split(' -')[0]]

    def _get_default_prediction(self) -> Dict:
        """Return the default prediction used for empty input and error cases."""
        return {
            'severity_level': 1,
            'severity_label': self.severity_levels[1],
            'confidence': 0.0,
            'input_text': ''
        }

    def classify_dataset(self, df: pd.DataFrame, batch_size: int = 100) -> pd.DataFrame:
        """
        Classify every row of a dataset of maritime incidents.

        Reads the columns 'Cleaned_Headline', 'Cleaned_Description' and
        'Final Classification' from each row (missing columns are treated as
        empty strings) and shows a progress bar while processing.

        Args:
            df: Input DataFrame
            batch_size: Currently unused; rows are classified one at a time.
                Kept so existing callers that pass it keep working.

        Returns:
            DataFrame with one row per processed record and the columns listed
            in ``_RESULT_COLUMNS`` (present even when there are no rows).
        """
        self.logger.info("Starting classification of %d records", len(df))
        results = []

        # Clean DataFrame
        df = df.fillna('')

        try:
            rows = tqdm(df.iterrows(), total=len(df), desc="Processing records")
            # idx is the row position (not the index label), used in error logs.
            for idx, (_, row) in enumerate(rows):
                try:
                    prediction = self.predict(
                        str(row.get('Cleaned_Headline', '')),
                        str(row.get('Cleaned_Description', '')),
                        str(row.get('Final Classification', ''))
                    )

                    results.append({
                        'headline': row.get('Cleaned_Headline', ''),
                        'Final Classification': row.get('Final Classification', ''),
                        'severity_level': prediction['severity_level'],
                        'severity_label': prediction['severity_label'],
                        'confidence': prediction['confidence'],
                        'input_text': prediction['input_text']
                    })

                except Exception as e:
                    self.logger.error("Error processing row %s: %s", idx, e)
                    results.append(self._get_error_row(row))

        except Exception as e:
            # Deliberately best-effort: log and return whatever was collected.
            self.logger.error("Error in classify_dataset: %s", e)

        return pd.DataFrame(results, columns=self._RESULT_COLUMNS)

    def _get_error_row(self, row: pd.Series) -> Dict:
        """Create the placeholder result row for a record that failed to process."""
        return {
            'headline': row.get('Cleaned_Headline', ''),
            'Final Classification': row.get('Final Classification', ''),
            'severity_level': 1,
            'severity_label': 'Error in processing',
            'confidence': 0.0,
            'input_text': ''
        }

def main():
    """Run the end-to-end classification job and write its reports.

    Reads 'Updated_Filtered_Dataset.xlsx' from the working directory,
    classifies every row with MaritimeRiskClassifier, and writes two
    timestamped files into the 'output' directory: an Excel workbook with the
    per-row results and a text file with the count and percentage of
    incidents per severity level.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = "output"
    logger = logging.getLogger(__name__)

    try:
        # Inside the try so a failure to create the directory is logged too.
        os.makedirs(output_dir, exist_ok=True)

        # Load data
        logger.info("Loading data...")
        df = pd.read_excel('Updated_Filtered_Dataset.xlsx')
        logger.info("Loaded %d rows of data", len(df))

        # Initialize classifier
        logger.info("Initializing classifier...")
        classifier = MaritimeRiskClassifier()

        # Classify dataset
        logger.info("Starting classification...")
        classified_df = classifier.classify_dataset(df)

        # Save results
        output_file = os.path.join(output_dir, f'Severity_Output_{timestamp}.xlsx')
        classified_df.to_excel(output_file, index=False)
        logger.info("Results saved to %s", output_file)

        # Generate summary. An empty result may come back without a
        # 'severity_level' column; treat that as an empty summary instead of
        # failing with a KeyError after the results were already saved.
        if 'severity_level' in classified_df.columns:
            summary = classified_df['severity_level'].value_counts().sort_index()
        else:
            logger.warning("No classification results available; summary is empty")
            summary = pd.Series(dtype="int64")

        total = len(classified_df)
        # Format each line once and reuse it for both the log and the file.
        summary_lines = [
            f"Severity Level {level}: {count} incidents ({count/total*100:.1f}%)"
            for level, count in summary.items()
        ]

        logger.info("\nClassification Summary:")
        for line in summary_lines:
            logger.info(line)

        # Save summary
        summary_file = os.path.join(output_dir, f'Summary_{timestamp}.txt')
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("Classification Summary\n")
            f.write("---------------------\n")
            f.writelines(f"{line}\n" for line in summary_lines)

        logger.info("Summary saved to %s", summary_file)

    except Exception as e:
        logger.error(f"Critical error in main: {str(e)}", exc_info=True)
        raise

# Run the classification job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()