# 🔧 ULTIMATE SIMPLE FIX - Back to Absolute Basics

## 💥 **Previous Failures:**
- **LoRA + Large Dataset**: 100% different, 0% quality (gibberish)
- **Full Fine-Tuning**: Complete failure ("saskatchewan saskatchewan...")

## 🎯 **NEW STRATEGY: Minimal Viable Fine-Tuning**
- **Tiny dataset** (5 perfect examples)
- **Short training** (3 epochs, learning rate 5e-4)
- **Small model** (FLAN-T5-base with LoRA adapters)
- **Basic format** (simple input → output)

## ✅ **Success Criteria:**
- Model says "Monetary Authority of Singapore" for "What does MAS stand for?"
- **Quality over everything else**


In [None]:
!pip install transformers datasets torch peft -q

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from peft import LoraConfig, TaskType, get_peft_model
from datasets import Dataset
import json

# Announce which variant of the experiment this notebook runs: a title line,
# a 60-character rule, and a one-line description.
banner_lines = (
    '🔧 ULTIMATE SIMPLE FIX - Proper LoRA Implementation',
    '=' * 60,
    'Using comprehensive LoRA setup with minimal perfect dataset',
)
print(*banner_lines, sep='\n')


## 📊 Minimal Perfect Dataset (5 Examples)


In [None]:
# Five hand-written (question, answer) pairs about Singapore finance.
_qa_pairs = (
    (
        'What does MAS stand for?',
        'MAS stands for Monetary Authority of Singapore.',
    ),
    (
        'What currency does Singapore use?',
        'Singapore uses the Singapore Dollar (SGD).',
    ),
    (
        'Who regulates banks in Singapore?',
        'The Monetary Authority of Singapore (MAS) regulates banks.',
    ),
    (
        "What is Singapore's central bank?",
        "Singapore's central bank is the Monetary Authority of Singapore (MAS).",
    ),
    (
        'What does SGD stand for?',
        'SGD stands for Singapore Dollar.',
    ),
)

# Each record stores the prompt under 'input' and the reference answer under
# 'output'.
minimal_data = [
    {'input': question, 'output': answer} for question, answer in _qa_pairs
]

print(f'📊 Created minimal dataset: {len(minimal_data)} perfect examples')
# Echo every pair, numbered from 1, so the data can be checked by eye.
for number, record in enumerate(minimal_data, start=1):
    question, answer = record['input'], record['output']
    print(f'{number}. {question} → {answer}')


## 🤖 Comprehensive LoRA Setup (Your Method)


In [None]:
# Load the tokenizer and two copies of the same checkpoint: `model` is wrapped
# with LoRA below, `original_model` is left as loaded so it can serve as the
# baseline in comparisons.
base_id = "google/flan-t5-base"
tok = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
original_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)

print(f'✅ Loaded {base_id}')
print(f'📊 Parameters: {model.num_parameters():,}')

# Module names handed to LoRA as targets: q/k/v/o under both SelfAttention and
# EncDecAttention, plus the three DenseReluDense layers.
_attention_parts = ("q", "k", "v", "o")
_lora_targets = [f"SelfAttention.{part}" for part in _attention_parts]
_lora_targets += ["DenseReluDense.wi_0", "DenseReluDense.wi_1", "DenseReluDense.wo"]
_lora_targets += [f"EncDecAttention.{part}" for part in _attention_parts]

lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=_lora_targets,
)

# Wrap the trainable copy with the adapters and report the parameter counts.
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()  # Shows only LoRA params

print('✅ Comprehensive LoRA setup complete - targeting ALL key modules!')

# Baseline probe: ask the untouched model one question before any training.
test_input = 'What does MAS stand for?'
inputs = tok(test_input, return_tensors='pt')

with torch.no_grad():  # inference only, no gradients needed
    outputs = original_model.generate(max_new_tokens=20, **inputs)

base_response = tok.decode(outputs[0], skip_special_tokens=True)
print(f'🔍 Base model response: "{base_response}"')


## 🚀 Proper Training Setup


In [None]:
# Prepare minimal dataset
dataset = Dataset.from_list(minimal_data)


def preprocess_function(examples):
    """Tokenize a batch of question/answer pairs for seq2seq training.

    Args:
        examples: Batch mapping with an 'input' column (prompts) and an
            'output' column (reference answers).

    Returns:
        The tokenizer output for the prompts, with the token ids of the
        answers stored under 'labels'.
    """
    # Deliberately no padding here. Padding the targets at this stage would
    # put pad-token ids into `labels`, and those positions would then be
    # scored by the loss. Leaving sequences unpadded lets the seq2seq data
    # collator pad each batch and mask label padding.
    model_inputs = tok(examples['input'], max_length=128, truncation=True)
    labels = tok(examples['output'], max_length=128, truncation=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs


# Tokenize everything in one batched pass and drop the raw text columns.
ds_train = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names,
)
print('✅ Dataset preprocessed')

# Proper data collator and training arguments (your method)
# The collator needs the model object, not the checkpoint name: it only
# prepares decoder_input_ids when the object it is given exposes
# `prepare_decoder_input_ids_from_labels`, which a plain string never does.
collator = DataCollatorForSeq2Seq(tokenizer=tok, model=model, padding="longest")

# bf16 is not available on every CUDA device, so check for actual support
# instead of enabling it whenever a GPU is present.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# No evaluation strategy is passed: "no" is the default, and the old
# `evaluation_strategy` keyword is deprecated in newer transformers releases.
args = TrainingArguments(
    output_dir="minimal_flan_t5_model",
    learning_rate=5e-4,  # Your higher LR
    num_train_epochs=3,  # Your epochs
    per_device_train_batch_size=2,  # Smaller for minimal data
    gradient_accumulation_steps=2,  # Your gradient accumulation
    bf16=use_bf16,  # Mixed precision only where the GPU supports it
    logging_steps=1,  # Log every step for minimal data
    save_strategy="epoch",
    report_to="none",  # No wandb
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_train,
    data_collator=collator,
)

print('✅ Proper training setup complete with your comprehensive LoRA method!')


In [None]:
# Run the fine-tuning loop configured on `trainer`.
print(
    '🚀 Starting training with comprehensive LoRA setup...',
    'Using ALL attention and feed-forward modules!',
    sep='\n',
)
trainer.train()

# Write the trained model via save_pretrained into a sub-folder of the
# training output directory.
adapter_dir = "minimal_flan_t5_model/lora_adapters"
model.save_pretrained(adapter_dir)
print('✅ Training completed and LoRA adapters saved!')


## 🧪 Critical Test - Comprehensive LoRA Results


In [None]:
print('🧪 CRITICAL TEST - Comprehensive LoRA Results')
print('=' * 60)

test_questions = [
    'What does MAS stand for?',
    'What currency does Singapore use?',
    'Who regulates banks in Singapore?'
]

# Lower-case phrases that count as a correct answer, keyed by question.
# Checking per question keeps an answer that is right for a *different*
# question (e.g. "Singapore Dollar" for the MAS question) from being scored
# as a success.
expected_keywords = {
    'What does MAS stand for?': ('monetary authority of singapore',),
    'What currency does Singapore use?': ('singapore dollar', 'sgd'),
    'Who regulates banks in Singapore?': ('monetary authority', 'mas regulates'),
}


def _generate_answer(lm, encoded, max_new_tokens=30):
    """Generate from `lm` for the tokenized prompt and return the decoded text."""
    with torch.no_grad():
        output_ids = lm.generate(**encoded, max_new_tokens=max_new_tokens)
    return tok.decode(output_ids[0], skip_special_tokens=True)


success_count = 0
# Run both models on whatever device the fine-tuned model currently lives on.
device = next(model.parameters()).device
original_model = original_model.to(device)

# Switch both models to inference mode once, before the loop.
original_model.eval()
model.eval()

for i, question in enumerate(test_questions, 1):
    print(f'\n{i}. Question: {question}')

    inputs = tok(question, return_tensors='pt').to(device)

    base_response = _generate_answer(original_model, inputs)
    ft_response = _generate_answer(model, inputs)

    print(f'   Base:       "{base_response}"')
    print(f'   LoRA Fine-tuned: "{ft_response}"')

    # A question passes only if the fine-tuned answer contains one of the
    # phrases expected for that specific question.
    answer_lower = ft_response.lower()
    if any(keyword in answer_lower for keyword in expected_keywords[question]):
        print('   ✅ SUCCESS: Contains expected Singapore content!')
        success_count += 1
    else:
        print('   ❌ FAILED: No expected Singapore content')

# Final assessment
success_rate = (success_count / len(test_questions)) * 100
print('\n' + '=' * 60)
print(f'🎯 COMPREHENSIVE LORA RESULTS: {success_count}/{len(test_questions)} ({success_rate:.1f}%)')

if success_rate >= 66:
    print('🎉 SUCCESS: Comprehensive LoRA approach works!')
    print('✅ Your method with all modules is effective!')
    print('✅ Ready to scale up to larger datasets!')
elif success_rate >= 33:
    print('⚠️ PARTIAL: Some success, method shows promise')
    print('💡 Try: More epochs, different LR, or more data')
else:
    print('❌ FAILED: Even comprehensive LoRA doesn\'t work')
    print('💡 Need: Different model architecture or approach')

print('\n🔬 COMPARISON TO PREVIOUS FAILURES:')
print('   Previous LoRA (limited modules): 0% quality (gibberish)')
print('   Full Fine-Tuning: 0% quality (broken repetition)')
print(f'   Comprehensive LoRA (all modules): {success_rate:.1f}% quality')

if success_rate > 0:
    print('\n🎯 BREAKTHROUGH: First approach to show ANY quality improvement!')
    print('✅ This validates your comprehensive LoRA method!')
else:
    print('\n⚠️ Still no success - may need fundamental approach change')
