In [None]:
import json
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login
import matplotlib.pyplot as plt
import seaborn as sns
import os


class ModelEvaluator:
    """Load a causal language model and compute generation-quality metrics.

    The evaluator holds a Hugging Face tokenizer/model pair, optionally with a
    PEFT adapter applied on top of the base weights, and offers sampling-based
    text generation plus lexical-diversity and perplexity measurements.

    Attributes set by ``__init__``:
        tokenizer: tokenizer loaded from ``base_model_name``.
        model: the base model, or the PEFT-wrapped model when an adapter path
            was supplied.
    """

    def __init__(self, base_model_name: str, trained_model_path: str = None):
        """Load the tokenizer, the base model and (optionally) a PEFT adapter.

        Args:
            base_model_name: identifier passed to ``from_pretrained`` for both
                the tokenizer and the base model.
            trained_model_path: adapter location passed to
                ``PeftModel.from_pretrained``. When falsy, the base model is
                used unchanged.
        """
        self._login_to_hub()

        print(f"Loading model: {base_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Some tokenizers ship without a pad token; reuse EOS in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        # Load trained model if path provided
        if trained_model_path:
            print(f"Loading trained model from: {trained_model_path}")
            self.model = PeftModel.from_pretrained(
                base_model,
                trained_model_path,
                is_trainable=False
            )
        else:
            self.model = base_model

        print("Model loaded\n")

    @staticmethod
    def _login_to_hub() -> None:
        """Best-effort Hugging Face login using the Colab secret ``HF_TOKEN``.

        Outside Colab (``google.colab`` not importable) this is a no-op. A
        failed secret lookup or login is reported and otherwise ignored, so
        model loading can still proceed with whatever credentials are cached.
        """
        try:
            from google.colab import userdata
        except ImportError:
            return

        try:
            login(token=userdata.get('HF_TOKEN'))
        except Exception as exc:
            # Deliberately non-fatal, but no longer silent and no longer
            # swallowing KeyboardInterrupt/SystemExit like a bare ``except``.
            print(f"Hugging Face login skipped: {exc}")

    def generate_text(self, prompt: str, max_length: int = 150) -> str:
        """Sample one continuation for ``prompt`` and return the decoded text.

        Args:
            prompt: text to condition the model on.
            max_length: forwarded to ``generate`` as ``max_length``.
                NOTE(review): in Hugging Face this bounds prompt plus generated
                tokens, not only the new tokens - confirm that is intended.

        Returns:
            The first returned sequence decoded with special tokens removed.
            Sampling is enabled (temperature 0.8, top-p 0.9), so repeated
            calls are not deterministic unless the RNG is seeded by the caller.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=0.8,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def generate_from_prompts(self, prompts: list) -> list:
        """Generate one text per prompt, printing progress every 10 prompts.

        Args:
            prompts: prompt strings, processed one at a time in order.

        Returns:
            List of generated texts in the same order as ``prompts``.
        """
        total = len(prompts)
        print(f"Generating from {total} prompts...")
        generated = []
        for done, prompt in enumerate(prompts, start=1):
            generated.append(self.generate_text(prompt))
            if done % 10 == 0:
                print(f"  Generated {done}/{total}")
        return generated

    def calculate_lexical_diversity(self, texts: list) -> dict:
        """Compute corpus-level type-token ratio and bigram diversity.

        Texts are lower-cased and split on whitespace. Bigrams are formed
        within each text only (never across text boundaries).

        Args:
            texts: list of strings to analyse.

        Returns:
            Dict with ``type_token_ratio`` (unique words / total words),
            ``vocabulary_size`` (unique words), ``total_tokens`` (total words)
            and ``bigram_diversity`` (unique bigrams / total bigrams). Both
            ratios are rounded to 4 decimals and are 0.0 for empty input.
        """
        vocabulary = set()
        unique_bigrams = set()
        tokens_count = 0
        total_bigrams = 0

        for text in texts:
            words = text.lower().split()
            tokens_count += len(words)
            vocabulary.update(words)

            # Tuples rather than "a_b" strings: joined strings would merge
            # distinct pairs such as ("a_b", "c") and ("a", "b_c").
            pairs = list(zip(words, words[1:]))
            total_bigrams += len(pairs)
            unique_bigrams.update(pairs)

        types = len(vocabulary)
        ttr = types / tokens_count if tokens_count > 0 else 0.0
        bigram_diversity = len(unique_bigrams) / total_bigrams if total_bigrams > 0 else 0.0

        return {
            'type_token_ratio': round(ttr, 4),
            'vocabulary_size': types,
            'total_tokens': tokens_count,
            'bigram_diversity': round(bigram_diversity, 4)
        }

    def calculate_perplexity(self, texts: list) -> float:
        """Return exp(mean per-text loss) of ``self.model`` on ``texts``.

        Each text is tokenized (truncated to 512 tokens) and scored with its
        own input ids as labels; the per-text losses are averaged with equal
        weight per text, and the exponential of that mean is returned.

        Args:
            texts: list of strings to score.

        Returns:
            Perplexity rounded to 2 decimals, or ``nan`` when no text could be
            scored (empty list, or every text shorter than two tokens).
        """
        loss_sum = 0.0
        scored = 0

        for text in texts:
            encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            encoded = {k: v.to(self.model.device) for k, v in encoded.items()}

            # A next-token loss needs at least one target position; shorter
            # inputs would contribute an undefined (NaN) loss to the average.
            if encoded["input_ids"].shape[-1] < 2:
                continue

            with torch.no_grad():
                outputs = self.model(**encoded, labels=encoded["input_ids"])
            loss_sum += outputs.loss.item()
            scored += 1

        if scored == 0:
            # Perplexity of nothing is undefined; avoid ZeroDivisionError.
            return float("nan")

        return round(float(np.exp(loss_sum / scored)), 2)


def load_test_prompts(filepath: str, n: int = 50) -> list:
    """Load up to ``n`` test prompts from a JSONL file (default 50).

    Each non-blank line must be a JSON object with a ``'text'`` field. Only
    the first ``2 * n`` records are considered (extra headroom in case some
    are too short). A record qualifies when its text has at least 20
    whitespace-separated words; its prompt is the first 15 of those words.

    Args:
        filepath: path to the JSONL file, read as UTF-8.
        n: maximum number of prompts to return; values <= 0 yield ``[]``.

    Returns:
        List of at most ``n`` prompt strings, in file order.

    Raises:
        json.JSONDecodeError: if a considered line is not valid JSON.
        KeyError: if a considered record has no ``'text'`` field.
    """
    prompts = []
    if n <= 0:
        return prompts

    record_limit = n * 2  # Get extra in case some are too short
    records_seen = 0

    # Stream the file so only the records that are actually needed get parsed.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            if records_seen >= record_limit:
                break
            records_seen += 1

            words = json.loads(line)['text'].split()
            if len(words) < 20:
                continue

            prompts.append(' '.join(words[:15]))
            if len(prompts) >= n:
                break

    return prompts


def main():
    """Evaluate the base model and each fine-tuned variant, then report results.

    Mounts Google Drive, loads 50 test prompts, and for every configured model
    generates one text per prompt and computes lexical diversity and
    perplexity on the generated texts. A model whose adapter path does not
    exist, or whose evaluation raises, is skipped. Afterwards a summary table
    is printed, the results are written to
    ``evaluation_results_recursive.json`` under the project root, and the
    charts are drawn when at least three models were evaluated.

    Returns:
        Dict mapping model name to a dict with ``'lexical_diversity'``,
        ``'perplexity'`` and ``'num_prompts'`` for every model that was
        evaluated successfully.
    """
    import gc
    import traceback

    from google.colab import drive
    drive.mount('/content/drive')

    project_root = "/content/drive/MyDrive/FinalProject"
    BASE_MODEL = "meta-llama/Llama-3.2-1B"

    # Load test prompts - 50 for real evaluation
    test_file = f"{project_root}/human_baseline_data/test.jsonl"
    prompts = load_test_prompts(test_file, n=50)
    print(f"Loaded {len(prompts)} test prompts\n")

    # Models to evaluate - INCLUDING RECURSIVE GENERATIONS
    models_config = {
        'Base (Untrained)': None,
        'Human-trained': f"{project_root}/trained_models_v2/human_baseline_data_llama",
        'AI Gen 1': f"{project_root}/trained_models_v2/ai_generated_data_gpt2_medium_llama",
        'AI Gen 2': f"{project_root}/trained_models_v2/ai_gen2_llama",
        'AI Gen 3': f"{project_root}/trained_models_v2/ai_gen3_llama",
        'Mixed': f"{project_root}/trained_models_v2/mixed_data_gpt2_medium_llama"
    }

    results = {}

    for model_name, model_path in models_config.items():
        print("="*60)
        print(f"EVALUATING: {model_name}")
        print("="*60)

        # Check if model exists (for Gen 2 and Gen 3 which may not exist yet)
        if model_path and not os.path.exists(model_path):
            print(f"⚠️  Model not found. Skipping {model_name}")
            print("   Train this model first!\n")
            continue

        evaluator = None
        try:
            evaluator = ModelEvaluator(BASE_MODEL, model_path)
            generated = evaluator.generate_from_prompts(prompts)

            print("\nCalculating metrics...")

            lexical = evaluator.calculate_lexical_diversity(generated)
            perplexity = evaluator.calculate_perplexity(generated)

            results[model_name] = {
                'lexical_diversity': lexical,
                'perplexity': perplexity,
                'num_prompts': len(prompts)
            }

            print(f"\n{model_name} Results:")
            print(f"  Type-Token Ratio: {lexical['type_token_ratio']}")
            print(f"  Vocabulary Size: {lexical['vocabulary_size']}")
            print(f"  Bigram Diversity: {lexical['bigram_diversity']}")
            print(f"  Perplexity: {perplexity}")
            print()

        except Exception as e:
            # One broken model must not abort the whole comparison.
            print(f"Error: {e}")
            traceback.print_exc()
            print("Skipping...\n")

        finally:
            # Release the model on failure as well as on success; otherwise a
            # failed run keeps its weights resident while the next one loads.
            evaluator = None
            gc.collect()
            torch.cuda.empty_cache()

    # Summary
    print("\n" + "="*60)
    print(f"RECURSIVE TRAINING EVALUATION ({len(prompts)} prompts)")
    print("="*60)

    # Print as table
    print(f"\n{'Model':<20} {'TTR':<8} {'Vocab':<8} {'Bigram':<8} {'Perplexity':<12}")
    print("-" * 62)

    for model_name, metrics in results.items():
        ttr = metrics['lexical_diversity']['type_token_ratio']
        vocab = metrics['lexical_diversity']['vocabulary_size']
        bigram = metrics['lexical_diversity']['bigram_diversity']
        ppl = metrics['perplexity']
        print(f"{model_name:<20} {ttr:<8.4f} {vocab:<8} {bigram:<8.4f} {ppl:<12.2f}")

    # Save before plotting so a charting failure cannot discard the metrics.
    results_path = f"{project_root}/evaluation_results_recursive.json"
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n✓ Results saved to: {results_path}")

    # Create Recursive Degradation Visualization
    if len(results) >= 3:  # Only if we have multiple generations
        create_recursive_degradation_chart(results, project_root)

    print("\n" + "="*60)
    print("RECURSIVE EVALUATION COMPLETE")
    print("="*60)

    return results


def create_recursive_degradation_chart(results: dict, project_root: str):
    """Plot lexical-diversity metrics across training generations and models.

    Two figures are produced, saved as PNG files under ``project_root`` and
    shown:

    * ``recursive_degradation_line_chart.png`` - one line plot per metric
      (type-token ratio, vocabulary size, bigram diversity) over the
      generation models present in ``results`` ('Human-trained' as
      generation 0, then 'AI Gen 1'..'AI Gen 3').
    * ``recursive_collapse_bar_chart.png`` - type-token ratio of every model
      in ``results`` as a bar chart.

    Args:
        results: mapping of model name to a dict whose ``'lexical_diversity'``
            entry holds ``'type_token_ratio'``, ``'vocabulary_size'`` and
            ``'bigram_diversity'``.
        project_root: directory the PNG files are written to.
    """

    print("\n" + "="*60)
    print("GENERATING RECURSIVE DEGRADATION CHARTS")
    print("="*60)

    # Extract generation data
    generations = []
    ttr_values = []
    vocab_values = []
    bigram_values = []

    # Map models to generations (insertion order is the plotting order)
    gen_mapping = {
        'Human-trained': 0,
        'AI Gen 1': 1,
        'AI Gen 2': 2,
        'AI Gen 3': 3
    }

    for model_name, generation in gen_mapping.items():
        if model_name in results:
            lexical = results[model_name]['lexical_diversity']
            generations.append(generation)
            ttr_values.append(lexical['type_token_ratio'])
            vocab_values.append(lexical['vocabulary_size'])
            bigram_values.append(lexical['bigram_diversity'])

    # Create line chart
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    metrics_data = [
        ('Type-Token Ratio', ttr_values, 'TTR'),
        ('Vocabulary Size', vocab_values, 'Vocab'),
        ('Bigram Diversity', bigram_values, 'Bigram')
    ]

    for ax, (title, values, label) in zip(axes, metrics_data):
        ax.plot(generations, values, marker='o', linewidth=2, markersize=10, color='red', label='Degradation')
        ax.set_xlabel('Generation', fontsize=12, fontweight='bold')
        ax.set_ylabel(label, fontsize=12, fontweight='bold')
        ax.set_title(f'{title} Degradation\nAcross Recursive Training', fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_xticks(generations)
        ax.set_xticklabels([f'Gen {g}' if g > 0 else 'Human' for g in generations])

        # Add value labels (floats with 3 decimals, integer counts as-is)
        for gen, val in zip(generations, values):
            ax.annotate(f'{val:.3f}' if isinstance(val, float) else f'{val}',
                       xy=(gen, val), xytext=(0, 10),
                       textcoords='offset points',
                       ha='center', fontweight='bold')

        # Calculate degradation percentage (first plotted point vs. last)
        if len(values) >= 2:
            if values[0]:
                total_degradation = ((values[0] - values[-1]) / values[0]) * 100
                degradation_text = f'Total Degradation: {total_degradation:.1f}%'
            else:
                # A zero baseline makes the relative change undefined; avoid
                # a ZeroDivisionError that would abort all chart generation.
                degradation_text = 'Total Degradation: n/a'
            ax.text(0.5, 0.05, degradation_text,
                   transform=ax.transAxes, ha='center',
                   bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5),
                   fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.savefig(f'{project_root}/recursive_degradation_line_chart.png', dpi=300, bbox_inches='tight')
    print("✓ Saved: recursive_degradation_line_chart.png")
    plt.show()
    # Free the figure; with non-interactive backends show() does not do so.
    plt.close(fig)

    # Create bar chart comparison
    fig, ax = plt.subplots(figsize=(12, 6))

    model_names = list(results.keys())
    ttr_all = [results[m]['lexical_diversity']['type_token_ratio'] for m in model_names]

    # gray = base, green = human-trained, orange = mixed, red = everything else
    colors = ['gray' if 'Base' in m else 'green' if 'Human' in m else 'orange' if 'Mixed' in m else 'red' for m in model_names]

    bars = ax.bar(range(len(model_names)), ttr_all, color=colors, alpha=0.8)
    ax.set_xlabel('Model', fontsize=12, fontweight='bold')
    ax.set_ylabel('Type-Token Ratio', fontsize=12, fontweight='bold')
    ax.set_title('Recursive Model Collapse: TTR Degradation', fontsize=14, fontweight='bold')
    ax.set_xticks(range(len(model_names)))
    ax.set_xticklabels(model_names, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels
    for bar, val in zip(bars, ttr_all):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{val:.3f}',
               ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig(f'{project_root}/recursive_collapse_bar_chart.png', dpi=300, bbox_inches='tight')
    print("✓ Saved: recursive_collapse_bar_chart.png")
    plt.show()
    plt.close(fig)


# Script entry point: run the full evaluation and keep the returned metrics
# dict in the module-level name `results`.
if __name__ == "__main__":
    results = main()