# Fine-tuning GPT-OSS 20B Model for Persona Emulation

This notebook fine-tunes the GPT-OSS 20B model on a Gmail dataset so that it emulates your persona by learning your email-writing patterns.


## Imports and Setup


In [None]:
import os
import torch
from unsloth import FastLanguageModel
from unsloth.chat_templates import standardize_sharegpt, train_on_responses_only
from transformers import TextStreamer
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset, load_from_disk
from pathlib import Path


## Configuration

Set your parameters here:


In [None]:
# --- Dataset -----------------------------------------------------------------
# Hub repository that holds the Gmail dataset.
hf_repo_id = "ryanlin10/gmail_dataset"
# Read the Hub token from the environment (None when HF_TOKEN is unset).
hf_token = os.getenv("HF_TOKEN")
# Point this at a folder (e.g. "./gmail_dataset") to load from disk instead.
local_path = None

# --- Model -------------------------------------------------------------------
model_name = "unsloth/gpt-oss-20b"
max_seq_length = 1024

# --- Output ------------------------------------------------------------------
output_dir = "outputs/gpt_oss_finetuned"
save_model = True

# --- LoRA --------------------------------------------------------------------
lora_r = 8
lora_alpha = 16

# --- Training ----------------------------------------------------------------
num_train_epochs = 1
# Leave as None to train by epochs; set an int (e.g. 100) to cap the step count.
max_steps = None
per_device_batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 2e-4
warmup_steps = 5

# --- Testing -----------------------------------------------------------------
# When True, a sample generation is run after training.
test_inference = True


## Dataset Transformation Functions

Each example in the Gmail dataset is a single system message that contains both the context and the email content. We split it into a user/assistant pair where:
- User message contains context (recipient, subject, original email if reply)
- Assistant message contains your actual email content (what the model learns)

No AI assistant system prompts or "write an email" instructions are included.


In [None]:
# Markers and keywords used by the Gmail export's combined system prompt.
_REPLY_MARKER = "--- Your reply to the original email ---"
_ORIGINAL_START = "--- Original Email ---"
_ORIGINAL_END = "--- End Original Email ---"
_INSTRUCTION_TEXT = "This is a reply to"
_CONTEXT_KEYWORDS = ("Context:", "Recipient:", "Subject:", "Date:", "You are writing")


def _split_context_and_body(system_content):
    """Split the combined system prompt into ``(context, email_body)``."""
    if _REPLY_MARKER in system_content:
        context, _, body = system_content.partition(_REPLY_MARKER)
        return context.strip(), body.strip()

    # Fallback when the marker is missing: the body starts at the first line
    # that is a '---' separator or a non-blank line without a context keyword.
    lines = system_content.split("\n")
    for index, line in enumerate(lines):
        stripped = line.strip()
        looks_like_context = any(keyword in line for keyword in _CONTEXT_KEYWORDS)
        if stripped.startswith("---") or (stripped and not looks_like_context):
            context = "\n".join(lines[:index]).strip()
            body = "\n".join(lines[index:]).strip()
            return context, body

    # Every line looked like context: use the whole text as the body.
    return system_content.strip(), system_content


def _extract_original_email(context):
    """Return the quoted original email from *context* without wrapper markers."""
    section = context[context.find(_ORIGINAL_START):]
    section = section.replace(_ORIGINAL_START, "").replace(_ORIGINAL_END, "").strip()
    if _INSTRUCTION_TEXT in section:
        # Drop the "This is a reply to ..." instruction lines; the final strip()
        # also removes any blank lines they leave at the top.
        kept = [line for line in section.split("\n") if _INSTRUCTION_TEXT not in line]
        section = "\n".join(kept).strip()
    return section


def transform_gmail_to_gpt_oss_format(example):
    """
    Transform one Gmail dataset example into a user/assistant conversation.

    The Gmail dataset stores everything in a single system message. This splits
    it into:
    - a user message holding only the context (the original email for replies,
      otherwise recipient and subject), and
    - an assistant message holding the email you actually wrote, which is what
      the model learns to emulate.

    No AI-assistant system prompt or "write an email" instruction is added.

    Args:
        example: Mapping with a "messages" list and an optional "metadata"
            mapping ("subject", "recipient", "is_reply"). Missing or None
            values are tolerated.

    Returns:
        ``{"messages": [...]}``. The list is empty when the example has no
        messages or no email body could be extracted, so the caller can filter
        it out. Examples that are not a single system message are returned
        unchanged.
    """
    # `or` (rather than a .get default) also covers keys that exist but hold None.
    original_messages = example.get("messages") or []
    metadata = example.get("metadata") or {}

    if not original_messages:
        return {"messages": []}

    # Anything other than the single-system-message layout is passed through.
    if len(original_messages) != 1 or original_messages[0].get("role") != "system":
        return {"messages": original_messages}

    system_content = original_messages[0].get("content") or ""
    context, email_content = _split_context_and_body(system_content)

    # Without an email body there is nothing to learn from; mark for filtering.
    if not email_content.strip():
        return {"messages": []}

    if metadata.get("is_reply") and _ORIGINAL_START in context:
        user_content = f"Replying to email:\n{_extract_original_email(context)}"
    else:
        # New email: recipient and subject are the only context.
        subject = metadata.get("subject")
        if subject is None:
            subject = "No Subject"
        recipient = metadata.get("recipient")
        if recipient is None:
            recipient = "Unknown Recipient"
        user_content = f"To: {recipient}\nSubject: {subject}"

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": email_content},
        ]
    }


def formatting_prompts_func(examples, tokenizer):
    """Render each conversation in a batch to text with the tokenizer's chat template.

    Args:
        examples: Batch mapping whose "messages" entry is a list of conversations.
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}


In [None]:
def load_gmail_dataset(dataset_path=None, hf_repo_id=None, hf_token=None, local_path=None):
    """
    Load the Gmail dataset from the HuggingFace Hub or from a local folder.

    A non-empty ``hf_repo_id`` takes precedence; the local folder is only
    consulted when no repo id is given.

    Args:
        dataset_path: Deprecated alias for ``local_path``.
        hf_repo_id: HuggingFace repository ID (e.g., "ryanlin10/gmail_dataset").
        hf_token: HuggingFace token (only needed for private datasets).
        local_path: Local path to a dataset folder.

    Returns:
        Dataset object. When the Hub returns a dict of splits, the "train"
        split is used.

    Raises:
        FileNotFoundError: If no repo id is given and the local path is
            missing or does not exist.
    """
    # Honour the legacy argument name when the new one is not supplied.
    source_dir = local_path or dataset_path

    if hf_repo_id:
        print(f"Loading dataset from HuggingFace: {hf_repo_id}")
        # A falsy token is passed as None, the same as omitting the argument.
        loaded = load_dataset(hf_repo_id, token=hf_token if hf_token else None)
        if isinstance(loaded, dict):
            # Several splits came back: use "train" (the fallback is the first split).
            loaded = loaded.get("train", list(loaded.values())[0])
        print(f"‚úì Loaded Gmail dataset from HF with {len(loaded)} examples")
        return loaded

    if source_dir and Path(source_dir).exists():
        print(f"Loading dataset from local path: {source_dir}")
        loaded = load_from_disk(source_dir)
        print(f"‚úì Loaded Gmail dataset locally with {len(loaded)} examples")
        return loaded

    raise FileNotFoundError(
        "Dataset not found. Please provide either:\n"
        "  1. HuggingFace repo_id (e.g., 'ryanlin10/gmail_dataset') with optional token\n"
        "  2. Local path to dataset folder"
    )


## Main Execution

### Step 1: Load Dataset


In [None]:
banner = "=" * 60
print(banner)
print("GPT-OSS 20B Fine-tuning for Persona Emulation")
print(banner)

print("\nüìä Step 1: Loading Gmail dataset...")
# A configured local path wins: the repo id is withheld so the loader reads from disk.
repo_to_load = None if local_path else hf_repo_id
dataset = load_gmail_dataset(
    hf_repo_id=repo_to_load,
    hf_token=hf_token,
    local_path=local_path,
)


### Step 2: Transform Dataset to GPT-OSS Format


In [None]:
print("\nüîÑ Step 2: Transforming dataset to GPT-OSS format...")
print(f"Dataset columns: {dataset.column_names}")
if len(dataset) > 0:
    print(f"First example keys: {list(dataset[0].keys())}")

dataset = dataset.map(
    transform_gmail_to_gpt_oss_format,
    remove_columns=[col for col in dataset.column_names if col != "messages"]
)

# Filter out examples without valid messages
dataset = dataset.filter(lambda x: "messages" in x and x["messages"] is not None and len(x["messages"]) > 0)
print(f"After transformation: {len(dataset)} examples")


### Step 3: Standardize Dataset Format


In [None]:
print("\nüìù Step 3: Standardizing dataset format...")
dataset = standardize_sharegpt(dataset)


### Step 4: Load Model and Tokenizer


In [None]:
print(f"\nü§ñ Step 4: Loading model {model_name}...")
dtype = None  # Auto-detect
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    dtype=dtype,
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # 4 bit quantization to reduce memory
    full_finetuning=False,
    token=hf_token if model_name.startswith(("hf_", "openai/")) else None,
)


### Step 5: Add LoRA Adapters


In [None]:
print(f"\nüîß Step 5: Adding LoRA adapters (r={lora_r}, alpha={lora_alpha})...")
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_r,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj",],
    lora_alpha=lora_alpha,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)


### Step 6: Format Prompts


In [None]:
print("\nüìã Step 6: Formatting prompts...")
formatting_func = lambda examples: formatting_prompts_func(examples, tokenizer)
dataset = dataset.map(formatting_func, batched=True)


### Step 7: Setup Trainer


In [None]:
print("\nüöÄ Step 7: Setting up trainer...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs if max_steps is None else None,
        max_steps=max_steps,
        learning_rate=learning_rate,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=output_dir,
        report_to="none",
    ),
)


### Step 8: Configure Response-Only Training


In [None]:
print("\nüéØ Step 8: Configuring response-only training...")
gpt_oss_kwargs = dict(
    instruction_part="<|start|>user<|message|>",
    response_part="<|start|>assistant<|channel|>final<|message|>"
)
trainer = train_on_responses_only(trainer, **gpt_oss_kwargs)


### Step 9: Display Training Configuration


In [None]:
print("\nüìä Training Configuration:")
gpu_stats = torch.cuda.get_device_properties(0) if torch.cuda.is_available() else None
if gpu_stats:
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"  GPU = {gpu_stats.name}")
    print(f"  Max memory = {max_memory} GB")
    print(f"  Reserved memory = {start_gpu_memory} GB")
print(f"  Training examples = {len(dataset)}")
print(f"  Batch size = {per_device_batch_size} x {gradient_accumulation_steps}")
if max_steps:
    print(f"  Max steps = {max_steps}")
else:
    print(f"  Epochs = {num_train_epochs}")
print(f"  Learning rate = {learning_rate}")


### Step 10: Train Model


In [None]:
print("\nüèÉ Step 10: Starting training...")
trainer_stats = trainer.train()


### Step 11: Training Statistics


In [None]:
# Statistics are printed only when a CUDA device is available.
if torch.cuda.is_available():
    # Peak reserved GPU memory in GB, rounded to three decimals.
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    print("\nüìà Training Statistics:")
    runtime_seconds = trainer_stats.metrics['train_runtime']
    runtime_minutes = runtime_seconds / 60
    print(f"  Runtime = {runtime_seconds:.2f} seconds ({runtime_minutes:.2f} minutes)")
    print(f"  Peak reserved memory = {used_memory} GB")


### Step 12: Save Model


In [None]:
if save_model:
    print(f"\nüíæ Step 12: Saving model to {output_dir}...")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Save the model first, then the tokenizer, into the same folder.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)
    print(f"‚úì Model saved to {output_dir}")

    # Pushing to the Hub is opt-in: PUSH_TO_HUB must be "true" (any casing)
    # and HUB_REPO_ID must name the target repository.
    push_to_hub = os.getenv("PUSH_TO_HUB", "").lower() == "true"
    hub_repo_id = os.getenv("HUB_REPO_ID")
    if push_to_hub and hub_repo_id:
        print(f"\nüì§ Pushing model to HuggingFace Hub: {hub_repo_id}...")
        for artifact in (model, tokenizer):
            artifact.push_to_hub(hub_repo_id, token=hf_token)
        print(f"‚úì Model pushed to {hub_repo_id}")


### Step 13: Test Inference (Optional)


In [None]:
if test_inference:
    print("\nüß™ Step 13: Testing inference...")
    # Prompt in the same shape as the training data for a new email: a single
    # user turn with "To:" and "Subject:" lines. No system prompt or
    # "write an email" instruction is used, because the training format omits
    # both, and a prompt outside that format would not exercise what was learned.
    messages = [
        {"role": "user", "content": "To: test@example.com\nSubject: Project Update - Q4 Results"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        reasoning_effort="medium",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    print("\nGenerated email:")
    print("-" * 60)
    # TextStreamer prints tokens as they are generated; the return value is unused.
    _ = model.generate(**inputs, max_new_tokens=256, streamer=TextStreamer(tokenizer))
    print("-" * 60)

print("\n‚úÖ Fine-tuning complete!")
# Only claim a saved model when saving was actually enabled.
if save_model:
    print(f"Model saved to: {output_dir}")
