In [None]:
import os
import re
import json
from typing import List, Dict
import pandas as pd
import numpy as np
from tqdm import tqdm
import argparse

from analytics.sentiment_analysis import EmailPreprocessor, EnsembleSentimentAnalyzer

from config.logger import CustomLogger
# Module-level logger shared by BaseChainsAnalyzer for progress and status messages.
logger = CustomLogger(name="BaseChainAnalyzer")


class BaseChainsAnalyzer:
    """
    Analyzes sentiment of base chains from a JSONL file where each entry contains
    a string with multiple emails marked as "Email 1:", "Email 2:", etc.

    Each chain is split into its individual emails, every email is preprocessed
    and scored, and the concatenation of all preprocessed emails is scored once
    more to obtain a chain-level sentiment.
    """

    # Matches the per-email markers ("Email 1:", "Email2:", "<Email 3>:"), including
    # any whitespace that follows the colon. Compiled once instead of on every call.
    _EMAIL_MARKER = re.compile(r'(?:<)?Email\s*\d+(?:>)?:\s*')

    def __init__(self, jsonl_path: str):
        """
        Initialize the analyzer with the path to the JSONL file.

        Loads the file into ``self.df``, creates the preprocessor and the
        ensemble sentiment analyzer, and derives ``self.model_name`` from the
        path ("Llama-3B", "Llama-8B", or "Unknown").

        Args:
            jsonl_path (str): Path to the JSONL file containing base chains
        """
        self.jsonl_path = jsonl_path
        logger.info(f"Loading data from {jsonl_path}...")
        self.df = pd.read_json(jsonl_path, lines=True)
        logger.info(f"Loaded {len(self.df)} email chains")

        self.preprocessor = EmailPreprocessor()
        self.ensemble_analyzer = EnsembleSentimentAnalyzer()
        self.results = []

        # The generating model is inferred purely from the file path.
        lowered_path = jsonl_path.lower()
        if "llama3b" in lowered_path:
            self.model_name = "Llama-3B"
        elif "llama8b" in lowered_path:
            self.model_name = "Llama-8B"
        else:
            self.model_name = "Unknown"

        logger.ok(f"Detected model: {self.model_name}")

    def split_email_chain(self, chain_text: str) -> List[str]:
        """
        Split a chain text containing multiple emails into individual email texts.

        Args:
            chain_text (str): The chain text with multiple emails

        Returns:
            List[str]: Individual email texts, stripped of surrounding whitespace,
            with empty segments removed. A non-string input (e.g. ``None`` or a
            float NaN from a missing value) yields an empty list.
        """
        # A missing value cannot be split; treat it as a chain without emails so
        # that process_chain falls back to its neutral default instead of raising.
        if not isinstance(chain_text, str):
            return []

        segments = (segment.strip() for segment in self._EMAIL_MARKER.split(chain_text))
        return [segment for segment in segments if segment]

    def process_chain(self, chain_text: str, chain_id: int) -> Dict:
        """
        Process a single email chain and compute its sentiment.

        Args:
            chain_text (str): The chain text with multiple emails
            chain_id (int): ID for the chain

        Returns:
            Dict: Chain-level result with keys ``chain_id``, ``model``,
            ``num_emails``, ``sentiment_neg``, ``sentiment_neu``,
            ``sentiment_pos`` and ``individual_sentiments`` (one dict per email
            that survived preprocessing).
        """
        emails = self.split_email_chain(chain_text)

        processed_emails = []
        individual_sentiments = []

        # email_index is the position in the split chain, so it keeps counting
        # over emails that are dropped because preprocessing left them empty.
        for email_index, email in enumerate(emails):
            processed_email = self.preprocessor.process_email(email)
            if not processed_email.strip():
                continue

            processed_emails.append(processed_email)
            sentiment = self.ensemble_analyzer.predict_sentiment(processed_email)
            individual_sentiments.append({
                "chain_id": chain_id,
                "email_index": email_index,
                "model": self.model_name,
                "sentiment_neg": sentiment["sentiment_neg"],
                "sentiment_neu": sentiment["sentiment_neu"],
                "sentiment_pos": sentiment["sentiment_pos"],
            })

        aggregated_text = " ".join(processed_emails)
        if aggregated_text:
            chain_sentiment = self.ensemble_analyzer.predict_sentiment(aggregated_text)
        else:
            # Nothing left to score: report the chain as fully neutral.
            chain_sentiment = {
                "sentiment_neg": 0.0,
                "sentiment_neu": 1.0,
                "sentiment_pos": 0.0,
            }

        return {
            "chain_id": chain_id,
            "model": self.model_name,
            "num_emails": len(processed_emails),
            "sentiment_neg": chain_sentiment["sentiment_neg"],
            "sentiment_neu": chain_sentiment["sentiment_neu"],
            "sentiment_pos": chain_sentiment["sentiment_pos"],
            "individual_sentiments": individual_sentiments,
        }

    def analyze_all_chains(self) -> List[Dict]:
        """
        Analyze all chains in the dataframe.

        Reads the ``chain`` column of every row, uses the row index as the chain
        ID, and replaces ``self.results`` with the fresh results.

        Returns:
            List[Dict]: List of sentiment analysis results for each chain
        """
        self.results = []

        logger.info(f"Analyzing sentiment for {len(self.df)} chains")
        for chain_id, row in tqdm(self.df.iterrows(), total=len(self.df)):
            self.results.append(self.process_chain(row["chain"], chain_id))

        return self.results

    def save_results(self, output_dir: str = "output/sentiment_analysis") -> str:
        """
        Save the analysis results to JSON files.

        Two files are written into ``output_dir``: the chain-level results
        (without the nested per-email entries) and a flat list of all per-email
        results. The analysis is run first if ``self.results`` is empty.

        Args:
            output_dir (str): Directory where results will be saved

        Returns:
            str: Path to the saved chain-level file
        """
        if not self.results:
            self.analyze_all_chains()

        os.makedirs(output_dir, exist_ok=True)

        model_str = self.model_name.lower().replace('-', '')
        file_path = os.path.join(output_dir, f"sentiment_base_{model_str}.json")
        individual_file_path = os.path.join(
            output_dir, f"sentiment_base_{model_str}_individual.json")

        # Chain-level file: everything except the nested per-email list.
        chain_results = [
            {key: value for key, value in result.items() if key != "individual_sentiments"}
            for result in self.results
        ]
        with open(file_path, 'w') as f:
            json.dump(chain_results, f, indent=2)

        # Per-email file: the nested lists of all chains flattened into one list.
        individual_results = [
            entry
            for result in self.results
            for entry in result["individual_sentiments"]
        ]
        with open(individual_file_path, 'w') as f:
            json.dump(individual_results, f, indent=2)

        logger.ok(f"Saved chain results to: {file_path}")
        logger.ok(f"Saved individual email results to: {individual_file_path}")

        return file_path

    def get_summary_statistics(self) -> Dict:
        """
        Calculate summary statistics for the sentiment analysis.

        The analysis is run first if ``self.results`` is empty.

        Returns:
            Dict: Model name, number of chains, and the mean and standard
            deviation (``np.std``) of the chain-level negative, neutral and
            positive scores.
        """
        if not self.results:
            self.analyze_all_chains()

        summary = {
            "model": self.model_name,
            "num_chains": len(self.results),
        }
        # Keys are emitted as <score>_mean followed by <score>_std, per score.
        for score_key in ("sentiment_neg", "sentiment_neu", "sentiment_pos"):
            scores = [result[score_key] for result in self.results]
            summary[f"{score_key}_mean"] = np.mean(scores)
            summary[f"{score_key}_std"] = np.std(scores)

        return summary

def analyze_base_chains(jsonl_path, output_dir="output/sentiment_analysis"):
    """
    Run the full base-chain sentiment pipeline for one JSONL file.

    Builds a BaseChainsAnalyzer for the file, scores every chain, writes the
    result files into ``output_dir`` and returns the summary statistics.

    Args:
        jsonl_path (str): Path to the JSONL file
        output_dir (str): Directory to save results

    Returns:
        Dict: Summary statistics of the analysis
    """
    chain_analyzer = BaseChainsAnalyzer(jsonl_path)
    chain_analyzer.analyze_all_chains()
    # The path returned by save_results is not needed here.
    chain_analyzer.save_results(output_dir)
    return chain_analyzer.get_summary_statistics()

In [None]:
# Score the Llama-3B base chains; the result files go to output/llama3b_sentiment
# and the returned summary statistics are kept in results_3b.
jsonl_path_3b = "../data/email_datasets/synthetic/baserefine/base/llama3b/base_chains.jsonl"
results_3b = analyze_base_chains(jsonl_path_3b, output_dir="output/llama3b_sentiment")

In [None]:
# Score the Llama-8B base chains; the result files go to output/llama8b_sentiment
# and the returned summary statistics are kept in results_8b.
jsonl_path_8b = "../data/email_datasets/synthetic/baserefine/base/llama8b/base_chains.jsonl"
results_8b = analyze_base_chains(jsonl_path_8b, output_dir="output/llama8b_sentiment")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.ticker as mtick
import os
import json
import glob

def _summarize_sentiment_frame(df, base_model, refiner_model, stage):
    """Build one summary row (mean/std per sentiment score) for a data frame.

    Returns None when any of the three sentiment columns is missing.
    """
    sentiment_cols = ['sentiment_neg', 'sentiment_neu', 'sentiment_pos']
    if not all(col in df.columns for col in sentiment_cols):
        return None

    return {
        'base_model': base_model,
        'refiner_model': refiner_model,
        'stage': stage,
        'sentiment_neg_avg': df['sentiment_neg'].mean(),
        'sentiment_neu_avg': df['sentiment_neu'].mean(),
        'sentiment_pos_avg': df['sentiment_pos'].mean(),
        'sentiment_neg_std': df['sentiment_neg'].std(),
        'sentiment_neu_std': df['sentiment_neu'].std(),
        'sentiment_pos_std': df['sentiment_pos'].std(),
        'sample_count': len(df)
    }


def load_and_process_sentiment_data(method_filepath, base_llama3b, base_llama8b):
    """
    Load and process raw sentiment data from files and calculate averages.

    Base data is read from the two JSON files; refined data is read from every
    ``sentiment_distribution_*.csv`` file in ``method_filepath``, with the base
    model and refiner inferred from the file name. Files that cannot be read,
    lack the sentiment columns, or have an unrecognised name are reported and
    skipped. When at least one file was processed, the summary is also written
    to ``../output/sentiment/recomputed_sentiment_averages.csv``.

    Parameters:
    -----------
    method_filepath : str
        Directory containing sentiment distribution data for refined models
    base_llama3b : str
        Path to base Llama-3B sentiment data JSON file
    base_llama8b : str
        Path to base Llama-8B sentiment data JSON file

    Returns:
    --------
    DataFrame with calculated sentiment averages (one row per processed file),
    plus the ``base_model_display`` and ``refiner_model_display`` columns. The
    frame is empty, but has all columns, when nothing could be processed.
    """
    summary_columns = [
        'base_model', 'refiner_model', 'stage',
        'sentiment_neg_avg', 'sentiment_neu_avg', 'sentiment_pos_avg',
        'sentiment_neg_std', 'sentiment_neu_std', 'sentiment_pos_std',
        'sample_count',
    ]
    known_refiners = ['claude', 'deepseek', 'gemini', 'gpt4', 'mistral']
    results = []

    base_models = {
        'llama3b': base_llama3b,
        'llama8b': base_llama8b
    }

    for model_key, file_path in base_models.items():
        # Best-effort loading: a broken or missing file is reported and skipped
        # so that the remaining inputs are still processed.
        try:
            with open(file_path, 'r') as f:
                base_data = json.load(f)
            summary = _summarize_sentiment_frame(
                pd.DataFrame(base_data), model_key, 'base', 'Base')
        except Exception as e:
            print(f"Error processing {model_key} base data: {e}")
            continue

        if summary is None:
            print(f"Warning: Missing sentiment columns in {model_key} base data")
            continue

        results.append(summary)
        print(f"Processed {model_key} base data with {summary['sample_count']} samples")

    pattern = os.path.join(method_filepath, "sentiment_distribution_*.csv")

    # glob returns files in arbitrary order; sorting keeps the row order (and
    # therefore the saved CSV) reproducible between runs.
    for file_path in sorted(glob.glob(pattern)):
        file_name = os.path.basename(file_path)
        lowered_name = file_name.lower()

        if 'llama3b' in lowered_name:
            base_model = 'llama3b'
        elif 'llama8b' in lowered_name:
            base_model = 'llama8b'
        else:
            print(f"Unknown base model in file: {file_name}")
            continue

        # First known refiner whose name appears in the file name.
        refiner_model = next(
            (refiner for refiner in known_refiners if refiner in lowered_name), None)
        if refiner_model is None:
            print(f"Unknown refiner model in file: {file_name}")
            continue

        try:
            summary = _summarize_sentiment_frame(
                pd.read_csv(file_path), base_model, refiner_model, 'Refined')
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

        if summary is None:
            print(f"Warning: Missing sentiment columns in {file_name}")
            continue

        results.append(summary)
        print(f"Processed {base_model}_{refiner_model} data with {summary['sample_count']} samples")

    # Passing the column list keeps the schema intact even when no file could be
    # processed, so the display columns below can always be derived.
    results_df = pd.DataFrame(results, columns=summary_columns)

    results_df['base_model_display'] = results_df['base_model'].apply(
        lambda x: 'Llama-3B' if x == 'llama3b' else 'Llama-8B')

    results_df['refiner_model_display'] = results_df['refiner_model'].apply(
        lambda x: x.upper() if x.lower() == 'gpt4' else x.title() if x != 'base' else 'Base')

    if results_df.empty:
        # Do not overwrite a previously saved summary with an empty file.
        print("Warning: No sentiment data could be processed; nothing saved")
        return results_df

    # Create the directory that is actually written to (the original created
    # "output/sentiment" but wrote into "../output/sentiment").
    output_path = "../output/sentiment/recomputed_sentiment_averages.csv"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    results_df.to_csv(output_path, index=False)
    print(f"Saved recomputed sentiment averages to {output_path}")

    return results_df

def create_sentiment_visualization(results_df, sentiment_type, color, output_dir="output/sentiment",
                                   y_limits=(-0.16, 0.16)):
    """
    Creates a side-by-side visualization of sentiment changes between base and refined models.

    For every base model, the base value of the chosen sentiment average is
    subtracted from the value of each refiner, and the differences are drawn as
    one bar chart per base model. The figure is saved as
    ``<sentiment_type>_sentiment_sidebyside.png`` in ``output_dir``.

    Note that this function updates matplotlib's global ``rcParams``.

    Parameters:
    -----------
    results_df : DataFrame
        DataFrame containing sentiment results; must provide the
        ``base_model_display`` and ``refiner_model_display`` columns and the
        ``sentiment_<type>_avg`` column being plotted
    sentiment_type : str
        Type of sentiment to visualize ('neg', 'pos', or 'neu'); any other
        value plots the negative sentiment column
    color : str or tuple
        Color to use for the bars
    output_dir : str
        Directory to save the output figures
    y_limits : tuple or None
        (lower, upper) limits of the shared y-axis; ``None`` leaves the limits
        to matplotlib

    Returns:
    --------
    The matplotlib figure, or None when there is no base/refined pair to plot.
    """
    sentiment_map = {
        'neg': ('sentiment_neg_avg', 'Negative'),
        'pos': ('sentiment_pos_avg', 'Positive'),
        'neu': ('sentiment_neu_avg', 'Neutral')
    }

    column_name, display_name = sentiment_map.get(sentiment_type, sentiment_map['neg'])
    delta_col = f'delta_{sentiment_type}'

    # The set of refiners does not depend on the base model, so compute it once.
    is_base_row = results_df['refiner_model_display'] == 'Base'
    refiner_names = results_df.loc[~is_base_row, 'refiner_model_display'].unique()

    delta_data = []

    for base_model in results_df['base_model_display'].unique():
        is_model_row = results_df['base_model_display'] == base_model
        base_data = results_df[is_model_row & is_base_row]

        if base_data.empty:
            print(f"No base data for {base_model}")
            continue

        base_value = base_data[column_name].values[0]
        print(f"{base_model} base {sentiment_type} value: {base_value:.6f}")

        for refiner in refiner_names:
            refined_data = results_df[
                is_model_row & (results_df['refiner_model_display'] == refiner)]

            if refined_data.empty:
                print(f"No refined data for {base_model} with {refiner}")
                continue

            # Only the first matching row is used if a pair occurs more than once.
            refined_value = refined_data[column_name].values[0]
            delta = refined_value - base_value

            print(f"{base_model} with {refiner} {sentiment_type}: {refined_value:.6f}, delta: {delta:.6f}")

            delta_data.append({
                'base_model': base_model,
                'refiner_model': refiner,
                delta_col: delta
            })

    delta_df = pd.DataFrame(delta_data)

    if delta_df.empty:
        print(f"No delta data to plot for {display_name} sentiment")
        return None

    rcParams['font.size'] = 10
    rcParams['axes.linewidth'] = 0.8
    rcParams['axes.spines.top'] = False
    rcParams['axes.spines.right'] = False
    rcParams['xtick.major.width'] = 0.8
    rcParams['ytick.major.width'] = 0.8

    base_models = sorted(delta_df['base_model'].unique())

    # One panel per base model (4 inches wide each, i.e. 8 x 3.5 for two models).
    # squeeze=False keeps `axes` indexable even when there is a single panel.
    fig, axes = plt.subplots(1, len(base_models),
                             figsize=(4 * len(base_models), 3.5),
                             sharey=True, squeeze=False)
    panel_axes = axes[0]

    for ax, base_model in zip(panel_axes, base_models):
        # delta_df holds exactly one row per (base model, refiner) pair, so
        # sorting by refiner gives the bar order and the bar heights together.
        model_data = delta_df[delta_df['base_model'] == base_model].sort_values('refiner_model')
        refiners = model_data['refiner_model'].tolist()
        delta_values = model_data[delta_col].tolist()

        x = np.arange(len(refiners))

        bars = ax.bar(x, delta_values, 0.7,
                      color=color,
                      edgecolor='black',
                      linewidth=0.5)

        for bar in bars:
            height = bar.get_height()

            # Place the label just above positive bars and just below negative ones.
            if height >= 0:
                va = 'bottom'
                offset = 3
            else:
                va = 'top'
                offset = -9

            # Very small changes get more decimals so they do not print as 0.0%.
            if abs(height) < 0.001:
                label = f'{height * 100:+.3f}%'
                fontsize = 8
            else:
                label = f'{height * 100:+.1f}%'
                fontsize = 9

            ax.annotate(label,
                        xy=(bar.get_x() + bar.get_width()/2, height),
                        xytext=(0, offset),
                        textcoords="offset points",
                        ha='center',
                        va=va,
                        fontsize=fontsize)

        ax.axhline(y=0, color='black', linestyle='-', alpha=0.3, linewidth=0.8)
        ax.grid(True, linestyle='--', alpha=0.3, axis='y')
        # Values are fractions (1.0 == 100%), hence xmax=1.
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=1))

        if y_limits is not None:
            ax.set_ylim(*y_limits)

        ax.set_xticks(x)
        ax.set_xticklabels(refiners, rotation=0, ha='center', fontsize=9)

        ax.set_title(base_model, fontsize=11, fontweight="bold")

    panel_axes[0].set_ylabel('Relative Change', fontsize=10)

    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    output_path_png = os.path.join(output_dir, f"{sentiment_type}_sentiment_sidebyside.png")

    plt.savefig(output_path_png, format='png', dpi=300, bbox_inches='tight')

    return fig

def main():
    """
    Recompute the sentiment averages and draw one figure per sentiment type.

    Loads the base and refined sentiment data, then creates the negative,
    positive and neutral side-by-side figures in ``output/sentiment``.
    """
    method_filepath = "../output/sentiment_distribution"
    base_llama3b = "../output/llama3b_sentiment/sentiment_base_llama3b.json"
    # The Llama-8B base results are written to the llama8b_sentiment directory
    # by the analysis step; the path previously pointed into llama3b_sentiment.
    base_llama8b = "../output/llama8b_sentiment/sentiment_base_llama8b.json"

    results_df = load_and_process_sentiment_data(method_filepath, base_llama3b, base_llama8b)

    sentiment_colors = (
        ('neg', '#B3003F'),  # Burgundy red
        ('pos', '#006C3B'),  # Forest green
        ('neu', '#474C55'),  # Slate gray
    )

    for sentiment_type, bar_color in sentiment_colors:
        create_sentiment_visualization(results_df, sentiment_type, bar_color, "output/sentiment")
    

# Run the recomputation and plotting pipeline when this code is executed as the
# main module.
if __name__ == "__main__":
    main()