<a href="https://colab.research.google.com/github/spamhamneggs/FinalProjectCOMP6885/blob/main/brautigan_llm_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Brautigan LLM Training Pipeline
Fine-tunes Qwen3-4B to act as an interactive haiku editor that suggests
specific improvements, rewrites, and alternatives based on learned patterns.

## Installation

In [10]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
!pip install evaluate rouge_score bert_score

## Imports

In [11]:
import os
import json
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, DataCollatorForLanguageModeling
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import warnings
import random
import evaluate

from unsloth import FastLanguageModel

warnings.filterwarnings('ignore')

In [12]:
# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class TrainingConfig:
    """Training configuration parameters"""
    model_name: str = "unsloth/Qwen3-4B-Instruct-2507"
    dataset_path: str = "haiku_analysis_results.csv"
    output_dir: str = "./haiku_brautigan_model"

    # Training hyperparameters
    num_epochs: int = 3
    batch_size: int = 2
    gradient_accumulation_steps: int = 2
    # FIX 1: Increase Learning Rate for QLoRA
    learning_rate: float = 3e-4
    max_seq_length: int = 1024
    weight_decay: float = 0.1

    lora_r: int = 128
    lora_alpha: int = 256
    lora_dropout: float = 0.05
    min_quality_score: float = 0.3
    max_samples: int = 25600

In [13]:
# ============================================================================
# GRAMMARLY-STYLE SUGGESTION GENERATOR
# ============================================================================

class BrautiganTrainer:
    """Generates Grammarly-style training examples from metrics"""

    def __init__(self):
        self.suggestion_templates = self._load_templates()

    def _load_templates(self) -> Dict:
        """Natural language templates for different types of suggestions"""
        return {
            'syllable_line': [
                "Line {line} has {current} syllables but should have {target}. Try removing/adding a syllable.",
                "Line {line} needs adjustment: currently {current} syllables, target is {target}.",
                "Consider revising line {line} to match the {target}-syllable target (currently {current})."
            ],
            'weak_coherence': [
                "The connection between lines feels weak. Consider strengthening the thematic link.",
                "Lines don't flow together smoothly. Try connecting them with a shared image or concept.",
                "The three lines feel disconnected. Unify them around a single moment or observation."
            ],
            'abstract_imagery': [
                "The language is abstract. Replace with concrete, sensory details.",
                "Use more specific imagery. Show the reader what you see, hear, or feel.",
                "Too conceptual. Ground the haiku in physical, observable details."
            ],
            'weak_kireji': [
                "The haiku lacks juxtaposition. Try contrasting two images or adding a shift in perspective.",
                "Consider adding a 'cutting word' or moment of contrast between lines.",
                "The lines flow but don't create tension. Introduce an unexpected element or shift."
            ],
            'no_nature': [
                "Traditional haiku often include nature imagery. Consider adding a seasonal word (kigo).",
                "Try incorporating a natural element to ground the haiku in the physical world.",
                "Add natural imagery - seasons, weather, plants, or animals."
            ],
            'repetitive_words': [
                "Word choice is repetitive. Vary your vocabulary for stronger impact.",
                "Some words repeat. Find more precise alternatives.",
                "Diversify your language. Each word should earn its place."
            ],
            'low_sensory': [
                "Enhance sensory detail. What can be seen, heard, or felt?",
                "Add sensory language - visual, auditory, or tactile details.",
                "Make the imagery more vivid. Engage the reader's senses."
            ]
        }

    def generate_suggestions(self, row: pd.Series, haiku_text: str) -> Dict:
        """Generate Grammarly-style suggestions with optional rewrites"""

        suggestions = []

        # 1. Syllable issues (specific line-level feedback)
        if not row['follows_575_pattern']:
            if row['line1_syllables'] != 5:
                line_text = row['line1']
                suggestions.append({
                    "type": "syllable_count",
                    "severity": "error",
                    "line": 1,
                    "message": random.choice(self.suggestion_templates['syllable_line']).format(
                        line=1, current=int(row['line1_syllables']), target=5
                    ),
                    "original": line_text,
                    "alternatives": self._generate_syllable_alternatives(line_text, int(row['line1_syllables']), 5)
                })

            if row['line2_syllables'] != 7:
                line_text = row['line2']
                suggestions.append({
                    "type": "syllable_count",
                    "severity": "error",
                    "line": 2,
                    "message": random.choice(self.suggestion_templates['syllable_line']).format(
                        line=2, current=int(row['line2_syllables']), target=7
                    ),
                    "original": line_text,
                    "alternatives": self._generate_syllable_alternatives(line_text, int(row['line2_syllables']), 7)
                })

            if row['line3_syllables'] != 5:
                line_text = row['line3']
                suggestions.append({
                    "type": "syllable_count",
                    "severity": "error",
                    "line": 3,
                    "message": random.choice(self.suggestion_templates['syllable_line']).format(
                        line=3, current=int(row['line3_syllables']), target=5
                    ),
                    "original": line_text,
                    "alternatives": self._generate_syllable_alternatives(line_text, int(row['line3_syllables']), 5)
                })

        # 2. Semantic coherence (whole-haiku feedback)
        if row['semantic_coherence_sbert'] < 0.5:
            suggestions.append({
                "type": "coherence",
                "severity": "warning",
                "message": random.choice(self.suggestion_templates['weak_coherence']),
                "score": float(row['semantic_coherence_sbert']),
                "suggestion": "Consider: What single moment or image connects all three lines?"
            })

        # 3. Imagery concreteness (word-level suggestions)
        if row['imagery_concreteness'] < 0.5:
            suggestions.append({
                "type": "imagery",
                "severity": "warning",
                "message": random.choice(self.suggestion_templates['abstract_imagery']),
                "score": float(row['imagery_concreteness']),
                "examples": self._get_concrete_imagery_examples()
            })

        # 4. Kireji/juxtaposition (structural suggestion)
        if row['kireji_strength'] < 0.3:
            suggestions.append({
                "type": "structure",
                "severity": "suggestion",
                "message": random.choice(self.suggestion_templates['weak_kireji']),
                "score": float(row['kireji_strength']),
                "technique": "Try: observation in lines 1-2, then shift perspective or add contrast in line 3"
            })

        # 5. Nature imagery (content suggestion)
        if row['nature_score'] < 0.2:
            suggestions.append({
                "type": "content",
                "severity": "suggestion",
                "message": random.choice(self.suggestion_templates['no_nature']),
                "score": float(row['nature_score']),
                "kigo_examples": ["cherry blossoms (spring)", "cicada (summer)", "red leaves (autumn)", "snow (winter)"]
            })

        # 6. Lexical diversity (word-level feedback)
        if row['lexical_diversity'] < 0.6:
            suggestions.append({
                "type": "word_choice",
                "severity": "suggestion",
                "message": random.choice(self.suggestion_templates['repetitive_words']),
                "score": float(row['lexical_diversity'])
            })

        # 7. Sensory details (enhancement suggestion)
        sensory_max = max(
            row['imagery_visual'],
            row['imagery_auditory'],
            row['imagery_tactile'],
            row['imagery_gustatory'],
            row['imagery_olfactory']
        )

        if sensory_max < 0.3:
            suggestions.append({
                "type": "enhancement",
                "severity": "suggestion",
                "message": random.choice(self.suggestion_templates['low_sensory']),
                "prompt": "What specific colors, sounds, or textures are present?"
            })

        # Generate overall assessment
        quality = row['composite_haiku_quality_score']

        if quality >= 0.8:
            overall = "Excellent haiku! Strong across all dimensions."
        elif quality >= 0.6:
            overall = "Good foundation. A few refinements would elevate this."
        elif quality >= 0.4:
            overall = "Solid attempt. Focus on the suggestions to strengthen impact."
        else:
            overall = "This needs work. Start with structural and imagery improvements."

        return {
            "overall_assessment": overall,
            "quality_score": float(quality),
            "suggestions": suggestions if suggestions else [
                {
                    "type": "praise",
                    "severity": "none",
                    "message": "Well-crafted haiku! No major issues detected."
                }
            ]
        }

    def _generate_syllable_alternatives(self, line: str, current: int, target: int) -> List[str]:
        """Generate plausible alternatives for syllable count fixes"""
        # This is illustrative - in practice, the LLM will learn to generate these
        if current > target:
            return [f"[Remove {current - target} syllable(s) from this line]"]
        else:
            return [f"[Add {target - current} syllable(s) to this line]"]

    def _get_concrete_imagery_examples(self) -> List[str]:
        """Examples of concrete vs abstract imagery"""
        return [
            "Instead of 'beauty' → 'pink petals'",
            "Instead of 'sadness' → 'wilting stems'",
            "Instead of 'time passes' → 'shadows lengthen'"
        ]

In [14]:
# ============================================================================
# DATASET PREPARATION
# ============================================================================

class HaikuDatasetBuilder:
    """Builds Grammarly-style training dataset"""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.trainer_gen = BrautiganTrainer()
        # CRITICAL FIX: Load tokenizer for diagnostics
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_name,
            trust_remote_code=True
        )

    def load_and_prepare(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Load CSV and create conversational training examples with Train/Val/Test split"""
        print(f"Loading dataset from {self.config.dataset_path}...")
        df = pd.read_csv(self.config.dataset_path)
        print(f"Initial dataset size: {len(df)}")

        # Filter by quality
        if self.config.min_quality_score is not None:
            df = df[df['composite_haiku_quality_score'] >= self.config.min_quality_score]
            print(f"After quality filter (>= {self.config.min_quality_score}): {len(df)} examples")

        if df.empty:
            raise ValueError("DataFrame is empty after filtering!")

        # Cap dataset size with stratified sampling
        if len(df) > self.config.max_samples:
            quality_bins = pd.qcut(df['composite_haiku_quality_score'], q=4,
                                  labels=['low', 'med', 'good', 'excellent'],
                                  duplicates='drop')
            samples_per_bin = self.config.max_samples // len(quality_bins.unique())
            sampled_dfs = []

            for bin_label in quality_bins.unique():
                bin_df = df[quality_bins == bin_label]
                n_samples = min(len(bin_df), samples_per_bin)
                sampled_dfs.append(bin_df.sample(n=n_samples, random_state=42))

            df = pd.concat(sampled_dfs).sample(frac=1, random_state=42)
            print(f"Stratified sampling: {len(df)} examples across quality spectrum")

        # Create training examples - SIMPLIFIED OUTPUT FORMAT
        examples = []
        for idx, row in df.iterrows():
            haiku_text = f"{row['line1']}\n{row['line2']}\n{row['line3']}"
            feedback = self.trainer_gen.generate_suggestions(row, haiku_text)

            # MUCH SIMPLER: Natural language instead of JSON
            simple_output = f"Quality: {feedback['quality_score']:.2f}/1.0\n\n{feedback['overall_assessment']}\n\nIssues:"

            for i, s in enumerate(feedback['suggestions'][:3], 1):
                simple_output += f"\n{i}. [{s['severity'].upper()}] {s['message']}"

            example = {
                "instruction": self._get_system_prompt(),
                "input": haiku_text,
                "output": simple_output
            }
            examples.append(example)

        print(f"Created {len(examples)} training examples")

        # === DIAGNOSTICS ===
        print("\n" + "="*60)
        print("DATASET DIAGNOSTICS")
        print("="*60)

        # Sample one example
        sample = examples[0]
        formatted = f"""<|im_start|>system
{sample['instruction']}<|im_end|>
<|im_start|>user
Please review this haiku:

{sample['input']}<|im_end|>
<|im_start|>assistant
{sample['output']}<|im_end|>"""

        print("\nSample formatted prompt:")
        print(formatted[:500] + "..." if len(formatted) > 500 else formatted)
        print("\n" + "-"*60)

        # Tokenize and check length
        tokens = self.tokenizer(formatted, add_special_tokens=True)
        print(f"Sample token count: {len(tokens['input_ids'])}")
        print(f"Max sequence length: {self.config.max_seq_length}")
        print(f"Average output length: {np.mean([len(ex['output']) for ex in examples]):.0f} chars")
        print(f"Max output length: {max([len(ex['output']) for ex in examples])} chars")

        # Check for truncation issues
        all_lengths = []
        for ex in examples:
            full_prompt = f"<|im_start|>system\n{ex['instruction']}<|im_end|>\n<|im_start|>user\nPlease review this haiku:\n\n{ex['input']}<|im_end|>\n<|im_start|>assistant\n{ex['output']}<|im_end|>"
            token_len = len(self.tokenizer(full_prompt, add_special_tokens=True)['input_ids'])
            all_lengths.append(token_len)

        too_long = sum(1 for l in all_lengths if l > self.config.max_seq_length)
        print(f"\nExamples exceeding max_seq_length: {too_long}/{len(examples)} ({100*too_long/len(examples):.1f}%)")
        print(f"Average prompt length: {np.mean(all_lengths):.0f} tokens")
        print(f"Max prompt length: {max(all_lengths)} tokens")
        print("="*60 + "\n")

        # === TRAIN/VAL/TEST SPLIT ===
        full_dataset = Dataset.from_pandas(pd.DataFrame(examples))

        # 1. Split off 20% for Val+Test
        train_valtest = full_dataset.train_test_split(test_size=0.2, seed=42)
        train_dataset = train_valtest['train']

        # 2. Split the 20% into 50/50 (10% Val, 10% Test)
        val_test = train_valtest['test'].train_test_split(test_size=0.5, seed=42)
        val_dataset = val_test['train']
        test_dataset = val_test['test']

        print(f"Training Set:   {len(train_dataset)} examples (80%)")
        print(f"Validation Set: {len(val_dataset)} examples (10%)")
        print(f"Test Set:       {len(test_dataset)} examples (10%)")

        return train_dataset, val_dataset, test_dataset

    def _get_system_prompt(self) -> str:
        """SIMPLIFIED system prompt"""
        return """You are a haiku editor. Analyze haiku and provide concise feedback.

Format:
Quality: X/1.0
[Brief overall assessment]
Issues:
1. [SEVERITY] specific issue and suggestion
2. [SEVERITY] specific issue and suggestion
3. [SEVERITY] specific issue and suggestion

Focus on: syllable count (5-7-5), imagery, coherence, and structure.
Be direct and actionable."""

In [15]:
# ============================================================================
# MODEL TRAINING
# ============================================================================
class BrautiganModelTrainer:
    """Fine-tunes Qwen3-4B as a haiku editor"""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = None
        self.tokenizer = None

    def setup_model(self):
        """Load and prepare model for training"""
        print(f"Loading model: {self.config.model_name}")

        # 1. Load Model via Unsloth for optimization
        self.model, _ = FastLanguageModel.from_pretrained(
            model_name = self.config.model_name,
            max_length = self.config.max_seq_length,
            dtype = None,
            load_in_4bit = True,
        )

        # 2. Load Tokenizer explicitly via Transformers to ensure clean state
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True
        )

        # Correctly apply PEFT/LoRA using Unsloth's optimized method
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r = self.config.lora_r,
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                              "gate_proj", "up_proj", "down_proj"],
            lora_alpha = self.config.lora_alpha,
            lora_dropout = self.config.lora_dropout,
            bias = "none",
            use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
            random_state = 3407,
        )
        self.model.print_trainable_parameters()

    def train(self, train_dataset: Dataset, eval_dataset: Dataset = None):
        """Train with diagnostics"""
        print("Starting training with completion masking...")

        # === FORCE TOKENIZER FIX IMMEDIATELY BEFORE USE ===
        # TRL/Unsloth sometimes revert tokens or defaults are wrong.
        # We forcefully correct it here to ensure SFTTrainer sees the right token.
        self.tokenizer.eos_token = "<|im_end|>"
        self.tokenizer.pad_token = "<|im_end|>"
        self.tokenizer.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        print(f"Tokenizer force-fixed inside train(): eos='{self.tokenizer.eos_token}' (id={self.tokenizer.eos_token_id})")
        # ==================================================

        # === PRE-TRAINING TEST ===
        print("\n" + "="*60)
        print("PRE-TRAINING TEST")
        print("="*60)

        test_example = train_dataset[0]
        test_prompt = f"""<|im_start|>system\n{test_example['instruction']}<|im_end|>\n<|im_start|>user\nPlease review this haiku:\n\n{test_example['input']}<|im_end|>\n<|im_start|>assistant\n"""

        inputs = self.tokenizer(test_prompt, return_tensors="pt").to(self.model.device)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id
            )

        pre_train_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("Pre-training response (last 300 chars):")
        print(pre_train_output[-300:])
        print("="*60 + "\n")

        # Format dataset
        def formatting_func(example):
            if isinstance(example['instruction'], list):
                return [
                    f"<|im_start|>system\n{example['instruction'][i]}<|im_end|>\n<|im_start|>user\nPlease review this haiku:\n\n{example['input'][i]}<|im_end|>\n<|im_start|>assistant\n{example['output'][i]}<|im_end|>"
                    for i in range(len(example['instruction']))
                ]
            else:
                return [f"<|im_start|>system\n{example['instruction']}<|im_end|>\n<|im_start|>user\nPlease review this haiku:\n\n{example['input']}<|im_end|>\n<|im_start|>assistant\n{example['output']}<|im_end|>"]


        # Configure training
        sft_config = SFTConfig(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_epochs,
            per_device_train_batch_size=self.config.batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.3,
            weight_decay=self.config.weight_decay,
            fp16=False,
            bf16=True,
            logging_steps=10,
            logging_first_step=True,
            save_steps=2000,
            save_total_limit=2,
            optim="paged_adamw_32bit",
            eval_strategy="steps" if eval_dataset else "no",
            eval_steps=100,
            per_device_eval_batch_size=2,
            max_grad_norm=0.5,
            dataset_text_field="text", # Kept as standard, though formatting_func takes precedence
            max_length=self.config.max_seq_length,
            packing=False,
            report_to="none",
        )

        trainer = SFTTrainer(
            model=self.model,
            processing_class=self.tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            formatting_func=formatting_func,
            args=sft_config,
        )

         # ADD THIS BEFORE trainer.train():
        print("\n" + "="*60)
        print("GRADIENT CHECK")
        print("="*60)

        # Get one batch
        self.model.train()
        batch = next(iter(trainer.get_train_dataloader()))
        batch = {k: v.to(self.model.device) for k, v in batch.items()}

        # Forward pass
        outputs = self.model(**batch)
        loss = outputs.loss

        print(f"Loss value: {loss.item()}")
        print(f"Loss requires grad: {loss.requires_grad}")

        # Backward pass
        loss.backward()

        # Check gradients
        grad_norms = []
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grad_norm = param.grad.norm().item()
                grad_norms.append(grad_norm)
                if len(grad_norms) <= 5:  # Print first 5
                    print(f"  {name}: grad_norm={grad_norm:.6f}")

        print(f"\nTotal params with gradients: {len(grad_norms)}")
        print(f"Gradient stats: min={min(grad_norms):.6f}, max={max(grad_norms):.6f}, mean={np.mean(grad_norms):.6f}")

        if max(grad_norms) < 1e-6:
            print("⚠️  WARNING: Gradients are near zero! Model is not learning!")

        print("="*60 + "\n")

        # Clear gradients before actual training
        self.model.zero_grad()

        trainer.train()

        # === POST-TRAINING TEST ===
        print("\n" + "="*60)
        print("POST-TRAINING TEST")
        print("="*60)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id
            )

        post_train_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("Post-training response (last 300 chars):")
        print(post_train_output[-300:])
        print("\n" + "="*60)
        print("COMPARISON:")
        print("Before:", pre_train_output[-150:])
        print("After: ", post_train_output[-150:])
        print("="*60 + "\n")

        # Save
        self.model.save_pretrained(self.config.output_dir)
        self.tokenizer.save_pretrained(self.config.output_dir)
        print(f"\n✅ Model saved to {self.config.output_dir}")

    def run_test_evaluation(self, dataset: Dataset, num_samples: int = 30):
        """Run rigorous metrics-based testing (ROUGE, BLEU, BERTScore)"""
        print("\n" + "="*60)
        print(f"RUNNING FINAL TEST EVALUATION (on {num_samples} samples)")
        print("="*60)

        # Load metrics
        try:
            rouge = evaluate.load("rouge")
            bleu = evaluate.load("bleu")
            bertscore = evaluate.load("bertscore")
        except Exception as e:
            print(f"Error loading metrics: {e}")
            return {}

        # Select samples
        n = min(len(dataset), num_samples)
        test_subset = dataset.select(range(n))

        predictions = []
        references = []

        self.model.eval()
        print("Generating predictions...")

        for i, example in enumerate(test_subset):
            prompt = f"<|im_start|>system\n{example['instruction']}<|im_end|>\n<|im_start|>user\nPlease review this haiku:\n\n{example['input']}<|im_end|>\n<|im_start|>assistant\n"

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False, # Deterministic for evaluation
                    pad_token_id=self.tokenizer.pad_token_id
                )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response
            if "<|im_start|>assistant" in generated_text:
                response = generated_text.split("<|im_start|>assistant")[-1].strip()
            else:
                response = generated_text

            predictions.append(response)
            references.append(example['output'])

            if i % 10 == 0:
                print(f"Processed {i}/{n}")

        print("\nComputing metrics...")
        results = {}

        # 1. ROUGE
        try:
            results['rouge'] = rouge.compute(predictions=predictions, references=references)
            print("\nROUGE Scores:", results['rouge'])
        except Exception as e:
            print(f"ROUGE calculation failed: {e}")

        # 2. BLEU
        try:
            results['bleu'] = bleu.compute(predictions=predictions, references=references)
            print("\nBLEU Score:", results['bleu'])
        except Exception as e:
            print(f"BLEU calculation failed: {e}")

        # 3. BERTScore (Semantic Similarity)
        try:
            bert_res = bertscore.compute(predictions=predictions, references=references, lang="en", model_type="distilbert-base-uncased")
            results['bertscore_f1'] = np.mean(bert_res['f1'])
            print(f"\nBERTScore (Mean F1): {results['bertscore_f1']:.4f}")
        except Exception as e:
            print(f"\nBERTScore calculation failed: {e}")

        return results

In [16]:
# ============================================================================
# INFERENCE (GRAMMARLY-STYLE INTERFACE)
# ============================================================================

class Brautigan:
    """Grammarly-style interface for haiku editing"""

    def __init__(self, model_path: str):
        print(f"Loading Brautigan model from {model_path}")
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_path,
            max_length=1024,
            dtype=None,
            load_in_4bit=True,
        )

        self.system_prompt = """You are a haiku editor. Analyze haiku and provide concise feedback.

Format:
Quality: X/1.0
[Brief overall assessment]
Issues:
1. [SEVERITY] specific issue and suggestion
2. [SEVERITY] specific issue and suggestion
3. [SEVERITY] specific issue and suggestion

Focus on: syllable count (5-7-5), imagery, coherence, and structure.
Be direct and actionable."""

    def review(self, haiku: str, max_new_tokens: int = 512) -> str:
        """Review a haiku and return formatted feedback as text"""
        prompt = f"""<|im_start|>system
{self.system_prompt}<|im_end|>
<|im_start|>user
Please review this haiku:

{haiku}<|im_end|>
<|im_start|>assistant
"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract just the assistant's response
        if '<|im_start|>assistant' in response:
            return response.split('<|im_start|>assistant')[-1].strip()
        return response

    def suggest_improvement(self, haiku: str, focus_area: str = None) -> str:
        """Get specific improvement suggestions for a focus area"""
        focus_prompt = f"\n\nFocus particularly on: {focus_area}" if focus_area else ""

        prompt = f"""<|im_start|>system
{self.system_prompt}<|im_end|>
<|im_start|>user
Please review this haiku and suggest specific improvements:{focus_prompt}

{haiku}<|im_end|>
<|im_start|>assistant
"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract assistant response
        if '<|im_start|>assistant' in response:
            return response.split('<|im_start|>assistant')[-1].strip()
        return response

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
# ========================================================================
# MAIN EXECUTION
# ========================================================================

def main():
    """Main training pipeline"""
    config = TrainingConfig(
        dataset_path="/content/drive/MyDrive/haiku_dataset/haiku_analysis_results.csv",
        output_dir="brautigan_haiku_suggester",
    )

    # Build dataset (Now returns 3-way split)
    builder = HaikuDatasetBuilder(config)
    train_dataset, val_dataset, test_dataset = builder.load_and_prepare()

    # Train model
    trainer = BrautiganModelTrainer(config)
    trainer.setup_model()

    # Train with Validation Set (for loss monitoring)
    trainer.train(train_dataset, val_dataset)

    # Test with strictly held-out Test Set (for final metrics)
    trainer.run_test_evaluation(test_dataset)

    print("\n✅ Training and Validation complete!")
    print(f"Model saved to: {config.output_dir}")

    # Test inference
    print("\n" + "="*60)
    print("Testing Brautigan...")
    print("="*60)

    brautigan = Brautigan(config.output_dir)

    test_haiku = 'old pond still water\na frog leaps into the pond very quickly\nsplash heard everywhere'

    print(f"\nTest Haiku:\n{test_haiku}\n")

    result = brautigan.review(test_haiku)
    print("Brautigan Review:")
    print(result)

if __name__ == "__main__":
    main()

Loading dataset from /content/drive/MyDrive/haiku_dataset/haiku_analysis_results.csv...
Initial dataset size: 173726
After quality filter (>= 0.3): 173645 examples
Stratified sampling: 25600 examples across quality spectrum
Created 25600 training examples

DATASET DIAGNOSTICS

Sample formatted prompt:
<|im_start|>system
You are a haiku editor. Analyze haiku and provide concise feedback.

Format:
Quality: X/1.0
[Brief overall assessment]
Issues:
1. [SEVERITY] specific issue and suggestion
2. [SEVERITY] specific issue and suggestion
3. [SEVERITY] specific issue and suggestion

Focus on: syllable count (5-7-5), imagery, coherence, and structure.
Be direct and actionable.<|im_end|>
<|im_start|>user
Please review this haiku:

enough pockets
the eye turned towards the moon
nearly full<|im_end|>
<|i...

------------------------------------------------------------
Sample token count: 228
Max sequence length: 1024
Average output length: 344 chars
Max output length: 388 chars

Examples exceeding

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/20480 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/2560 [00:00<?, ? examples/s]


GRADIENT CHECK
Loss value: 3.9995503425598145
Loss requires grad: True
  base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: grad_norm=0.000000
  base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: grad_norm=0.355898
  base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight: grad_norm=0.000000
  base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight: grad_norm=0.358727
  base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: grad_norm=0.000000

Total params with gradients: 504
Gradient stats: min=0.000000, max=5.004268, mean=0.417173



==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 20,480 | Num Epochs = 3 | Total steps = 15,360
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 2 x 1) = 4
 "-____-"     Trainable parameters = 264,241,152 of 4,286,709,248 (6.16% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
100,0.4353,0.415311
200,0.374,0.340368
300,0.3445,0.333668
400,0.3181,0.329853
500,0.3329,0.3272
600,0.3138,0.324372
700,0.3169,0.323046
800,0.3363,0.323622
900,0.3174,0.321239
1000,0.2993,0.322223


Unsloth: Not an error, but Qwen3ForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient



POST-TRAINING TEST
Post-training response (last 300 chars):
on the suggestions to strengthen impact.

Issues:
1. [ERROR] Line 1 has 4 syllables but should have 5. Try removing/adding a syllable.
2. [ERROR] Line 2 has 5 syllables but should have 7. Try removing/adding a syllable.
3. [SUGGESTION] The lines flow but don't create tension. Introduce an unexpected

COMPARISON:
Before: ssues:  
1. HIGH – Repetition and redundancy: "click quietly click click" repeats the same sound with no variation or progression. Suggestion: Replace
After:   2 has 5 syllables but should have 7. Try removing/adding a syllable.
3. [SUGGESTION] The lines flow but don't create tension. Introduce an unexpected


✅ Model saved to brautigan_haiku_suggester

RUNNING FINAL TEST EVALUATION (on 30 samples)


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Generating predictions...
Processed 0/30
Processed 10/30
Processed 20/30

Computing metrics...

ROUGE Scores: {'rouge1': np.float64(0.4279856946381893), 'rouge2': np.float64(0.32848115030877517), 'rougeL': np.float64(0.38990090472413774), 'rougeLsum': np.float64(0.4224540034581765)}

BLEU Score: {'bleu': 0.2558679005049569, 'precisions': [0.3296069239091237, 0.2605148658448151, 0.23040466642362378, 0.21664222873900293], 'brevity_penalty': 1.0, 'length_ratio': 2.416557734204793, 'translation_length': 5546, 'reference_length': 2295}


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]


BERTScore (Mean F1): 0.8583

✅ Training and Validation complete!
Model saved to: brautigan_haiku_suggester

Testing Brautigan...
Loading Brautigan model from brautigan_haiku_suggester
==((====))==  Unsloth 2025.12.1: Fast Qwen3 patching. Transformers: 4.56.2.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
unsloth/qwen3-4b-instruct-2507-unsloth-bnb-4bit does not have a padding token! Will use pad_token = <|vision_pad|>.

Test Haiku:
old pond still water
a frog leaps into the pond very quickly
splash heard everywhere

Brautigan Review:
system
You are a haiku editor. Analyze haiku and provide concise feedback.

Format:
Quality: X/1.0
[Brief overall assessment]