In [2]:
!nvidia-smi

Sat Aug  2 17:56:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:68:00.0 Off |                    0 |
| N/A   23C    P0             59W /  500W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
!pip install vllm

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [4]:
import os
import gc
import re
import json
import warnings
from typing import Dict, List, Optional

import torch
import numpy as np
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    pipeline,
    set_seed
)
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from trl import SFTTrainer
from tqdm import tqdm
from huggingface_hub import login

# Suppress warnings
warnings.filterwarnings("ignore")

# Set random seed for reproducibility
set_seed(42)

# Login to HuggingFace
login("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")

In [5]:
# Model names
base_model_name = "mistralai/Mistral-Nemo-Instruct-2407"
new_model_name = "Mistral-Nemo-2407-Role-Playing-LORA-cleaner"
final_model_name = "Mistral-Nemo-2407-Role-Playing-Final-cleaner"

# Training parameters
num_train_epochs = 2
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
learning_rate = 2e-4
max_seq_length = 2048
warmup_ratio = 0.03

# LoRA parameters
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

# Create output directory
os.makedirs("./results", exist_ok=True)

In [6]:
def clear_gpu_memory():
    """Clear GPU memory"""
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.7.1+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.10 GB


In [7]:
# Process Erotiquant dataset
def process_erotiquant(example):
    text = example.get('text', '')
    if not text:
        return None
        
    pattern = r'(USER|ASSISTANT|SYSTEM):\s*(.*?)(?=(?:USER|ASSISTANT|SYSTEM):|$)'
    matches = re.findall(pattern, text, re.DOTALL | re.MULTILINE)
    
    system_content = ""
    messages = []
    
    for role, content in matches:
        content = content.strip()
        if role == "SYSTEM":
            system_content = content
        else:
            role_mapped = "user" if role == "USER" else "assistant"
            if role_mapped == "user" and system_content and not messages:
                content = f"{system_content}\n\n{content}"
                system_content = ""
            messages.append({"role": role_mapped, "content": content})
    
    return {"messages": messages} if len(messages) >= 2 else None

# Process HieuNguyenMinh dataset
def process_hieunguyenminh(example):
    text = example.get('text', '')
    if not text:
        return None
        
    pattern = r"<\|(\w+)\|>(.*?)(?=<\|\w+\|>|</s>|$)"
    matches = re.findall(pattern, text, re.DOTALL)
    
    messages = []
    system_content = ""
    
    for role, content in matches:
        role = role.lower()
        content = content.strip()
        
        if role == "system":
            system_content = content
        elif role in ["user", "assistant"]:
            if system_content and role == "user" and not messages:
                content = f"{system_content}\n\n{content}"
                system_content = ""
            messages.append({"role": role, "content": content})
    
    return {"messages": messages} if len(messages) >= 2 else None

# Process Zerofata dataset
def process_zerofata(example):
    messages_data = example.get("messages", [])
    if not messages_data:
        return None
        
    messages = []
    system_content = ""
    
    for msg in messages_data:
        role = msg.get("role", "").lower()
        content = msg.get("content", "").strip()
        
        if not content:
            continue
            
        if role == "system":
            system_content = content
        elif role in ["user", "assistant"]:
            if system_content and role == "user" and not messages:
                content = f"{system_content}\n\n{content}"
                system_content = ""
            messages.append({"role": role, "content": content})
    
    return {"messages": messages} if len(messages) >= 2 else None

# Process GPT Realm dataset
def process_gpt_realm(example):
    conversation = example.get("conversation", [])
    if not conversation:
        return None
        
    messages = []
    system_content = ""
    
    for turn in conversation:
        role = turn.get("role", "").lower()
        content = turn.get("content", "").strip()
        
        if not content:
            continue
            
        if role == "system":
            system_content = content
        elif role == "user":
            if system_content and not messages:
                content = f"{system_content}\n\n{content}"
                system_content = ""
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "assistant", "content": content})
    
    return {"messages": messages} if len(messages) >= 2 else None

In [8]:
import re
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('hieunguyenminh/roleplay', split="train")

def convert_hieunguyenminh_to_mistral_format(example):
    """Convert hieunguyenminh/roleplay format to Mistral format"""
    text = example.get('text', '')
    if not text:
        return {'formatted_text': '', 'is_valid': False}
    
    # Parse the text - pattern for <|role|>content</s> or <|role|>content<|next_role|>
    pattern = r"<\|(\w+)\|>(.*?)(?=<\|\w+\|>|</s>|$)"
    matches = re.findall(pattern, text, re.DOTALL)
    
    if not matches:
        return {'formatted_text': '', 'is_valid': False}
    
    # Extract messages
    system_content = ""
    messages = []
    
    for role, content in matches:
        role = role.lower()
        content = content.strip()
        
        # Remove </s> if it's at the end of content
        if content.endswith('</s>'):
            content = content[:-4].strip()
        
        if not content:
            continue
            
        if role == "system":
            system_content = content
        elif role in ["user", "assistant"]:
            messages.append((role.upper(), content))
    
    # Build conversation ensuring proper alternation
    formatted_parts = []
    i = 0
    
    while i < len(messages):
        # Find next USER message
        while i < len(messages) and messages[i][0] != "USER":
            i += 1
        
        if i >= len(messages):
            break
            
        user_content = messages[i][1]
        i += 1
        
        # Find next ASSISTANT message
        while i < len(messages) and messages[i][0] != "ASSISTANT":
            i += 1
            
        if i >= len(messages):
            break
            
        assistant_content = messages[i][1]
        i += 1
        
        # Format this exchange
        if not formatted_parts and system_content:
            # First exchange with system prompt
            formatted = f"[INST] {system_content} \n\n User: {user_content}[/INST]{assistant_content}"
        else:
            # Regular exchange
            formatted = f"[INST]{user_content}[/INST]{assistant_content}"
        
        formatted_parts.append(formatted)
    
    if not formatted_parts:
        return {'formatted_text': '', 'is_valid': False}
    
    # Combine with proper tokens
    full_text = "<s>" + "</s>".join(formatted_parts) + "</s>"
    
    return {'formatted_text': full_text, 'is_valid': True}

# Process the dataset
print("Converting hieunguyenminh dataset to Mistral format...")
processed_dataset = dataset.map(
    convert_hieunguyenminh_to_mistral_format,
    desc="Converting hieunguyenminh to Mistral format"
)

# Filter out invalid entries
valid_dataset = processed_dataset.filter(lambda x: x['is_valid'])

# Remove unnecessary columns
final_dataset = valid_dataset.remove_columns(['text', 'is_valid'])

print(f"Original dataset size: {len(dataset)}")
print(f"Valid formatted conversations: {len(final_dataset)}")



Converting hieunguyenminh dataset to Mistral format...
Original dataset size: 5755
Valid formatted conversations: 5755


In [9]:
import re
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('openerotica/erotiquant3', split="train")


def convert_to_mistral_format_improved(example):
    """Convert with better handling of edge cases"""
    text = example.get('text', '')
    if not text:
        return {'formatted_text': '', 'is_valid': False}
    
    # Parse the text
    pattern = r'(USER|ASSISTANT|SYSTEM):\s*(.*?)(?=(?:USER|ASSISTANT|SYSTEM):|$)'
    matches = re.findall(pattern, text, re.DOTALL | re.MULTILINE)
    
    if not matches:
        return {'formatted_text': '', 'is_valid': False}
    
    # Extract messages
    system_content = ""
    messages = []
    
    for role, content in matches:
        content = content.strip()
        if not content:
            continue
            
        if role == "SYSTEM":
            system_content = content
        elif role in ["USER", "ASSISTANT"]:
            messages.append((role, content))
    
    # Build conversation ensuring proper alternation
    formatted_parts = []
    i = 0
    
    while i < len(messages):
        # Find next USER message
        while i < len(messages) and messages[i][0] != "USER":
            i += 1
        
        if i >= len(messages):
            break
            
        user_content = messages[i][1]
        i += 1
        
        # Find next ASSISTANT message
        while i < len(messages) and messages[i][0] != "ASSISTANT":
            i += 1
            
        if i >= len(messages):
            break
            
        assistant_content = messages[i][1]
        i += 1
        
        # Format this exchange
        if not formatted_parts and system_content:
            # First exchange with system prompt
            formatted = f"[INST] {system_content} \n\n User: {user_content}[/INST]{assistant_content}"
        else:
            # Regular exchange
            formatted = f"[INST]{user_content}[/INST]{assistant_content}"
        
        formatted_parts.append(formatted)
    
    if not formatted_parts:
        return {'formatted_text': '', 'is_valid': False}
    
    # Combine with proper tokens
    full_text = "<s>" + "</s>".join(formatted_parts) + "</s>"
    
    return {'formatted_text': full_text, 'is_valid': True}

# Use the improved version
processed_dataset = dataset.map(
    convert_to_mistral_format_improved,
    desc="Converting to Mistral format (improved)"
)

In [10]:
import re
from datasets import load_dataset, concatenate_datasets

# Conversion function for zerofata/Roleplay-Anime-Characters
def convert_zerofata_to_mistral_format(example):
    """Convert zerofata format to Mistral format"""
    messages_data = example.get("messages", [])
    if not messages_data:
        return {'formatted_text': '', 'is_valid': False}
    
    # Extract messages
    system_content = ""
    messages = []
    
    for msg in messages_data:
        role = msg.get("role", "").lower()
        content = msg.get("content", "").strip()
        
        if not content:
            continue
            
        if role == "system":
            system_content = content
        elif role in ["user", "assistant"]:
            messages.append((role.upper(), content))
    
    # Build conversation ensuring proper alternation
    formatted_parts = []
    i = 0
    
    while i < len(messages):
        # Find next USER message
        while i < len(messages) and messages[i][0] != "USER":
            i += 1
        
        if i >= len(messages):
            break
            
        user_content = messages[i][1]
        i += 1
        
        # Find next ASSISTANT message
        while i < len(messages) and messages[i][0] != "ASSISTANT":
            i += 1
            
        if i >= len(messages):
            break
            
        assistant_content = messages[i][1]
        i += 1
        
        # Format this exchange
        if not formatted_parts and system_content:
            # First exchange with system prompt
            formatted = f"[INST] {system_content} \n\n User: {user_content}[/INST]{assistant_content}"
        else:
            # Regular exchange
            formatted = f"[INST]{user_content}[/INST]{assistant_content}"
        
        formatted_parts.append(formatted)
    
    if not formatted_parts:
        return {'formatted_text': '', 'is_valid': False}
    
    # Combine with proper tokens
    full_text = "<s>" + "</s>".join(formatted_parts) + "</s>"
    
    return {'formatted_text': full_text, 'is_valid': True}

# Conversion function for AlekseyKorshuk/gpt-roleplay-realm-chatml
def convert_gpt_realm_to_mistral_format(example):
    """Convert gpt-realm format to Mistral format"""
    conversation = example.get("conversation", [])
    if not conversation:
        return {'formatted_text': '', 'is_valid': False}
    
    # Extract messages
    system_content = ""
    messages = []
    
    for turn in conversation:
        role = turn.get("role", "").lower()
        content = turn.get("content", "").strip()
        
        if not content:
            continue
            
        if role == "system":
            system_content = content
        elif role == "user":
            messages.append(("USER", content))
        else:  # assistant or any other role
            messages.append(("ASSISTANT", content))
    
    # Build conversation ensuring proper alternation
    formatted_parts = []
    i = 0
    
    while i < len(messages):
        # Find next USER message
        while i < len(messages) and messages[i][0] != "USER":
            i += 1
        
        if i >= len(messages):
            break
            
        user_content = messages[i][1]
        i += 1
        
        # Find next ASSISTANT message
        while i < len(messages) and messages[i][0] != "ASSISTANT":
            i += 1
            
        if i >= len(messages):
            break
            
        assistant_content = messages[i][1]
        i += 1
        
        # Format this exchange
        if not formatted_parts and system_content:
            # First exchange with system prompt
            formatted = f"[INST] {system_content} \n\n User: {user_content}[/INST]{assistant_content}"
        else:
            # Regular exchange
            formatted = f"[INST]{user_content}[/INST]{assistant_content}"
        
        formatted_parts.append(formatted)
    
    if not formatted_parts:
        return {'formatted_text': '', 'is_valid': False}
    
    # Combine with proper tokens
    full_text = "<s>" + "</s>".join(formatted_parts) + "</s>"
    
    return {'formatted_text': full_text, 'is_valid': True}

# Process all datasets and combine them
def process_and_combine_all_datasets(sample_size=5000):
    """Process all 4 datasets and combine them"""
    
    all_datasets = []
    
    # 1. Process erotiquant3
    print("\n1. Processing openerotica/erotiquant3...")
    dataset1 = load_dataset('openerotica/erotiquant3', split="train")
    processed1 = dataset1.map(
        convert_to_mistral_format_improved,
        desc="Converting erotiquant3"
    )
    valid1 = processed1.filter(lambda x: x['is_valid'])
    final1 = valid1.remove_columns(['text', 'is_valid'])
    
    # Take sample if specified
    if sample_size and len(final1) > sample_size:
        final1 = final1.select(range(sample_size))
    
    all_datasets.append(final1)
    print(f"   Processed {len(final1)} examples")
    
    # 2. Process hieunguyenminh
    print("\n2. Processing hieunguyenminh/roleplay...")
    dataset2 = load_dataset('hieunguyenminh/roleplay', split="train")
    processed2 = dataset2.map(
        convert_hieunguyenminh_to_mistral_format,
        desc="Converting hieunguyenminh"
    )
    valid2 = processed2.filter(lambda x: x['is_valid'])
    final2 = valid2.remove_columns(['text', 'is_valid'])
    
    if sample_size and len(final2) > sample_size:
        final2 = final2.select(range(sample_size))
    
    all_datasets.append(final2)
    print(f"   Processed {len(final2)} examples")
    
    # 3. Process zerofata
    print("\n3. Processing zerofata/Roleplay-Anime-Characters...")
    dataset3 = load_dataset('zerofata/Roleplay-Anime-Characters', split="train")
    processed3 = dataset3.map(
        convert_zerofata_to_mistral_format,
        desc="Converting zerofata"
    )
    valid3 = processed3.filter(lambda x: x['is_valid'])
    final3 = valid3.remove_columns(['messages', 'is_valid'])
    
    if sample_size and len(final3) > sample_size:
        final3 = final3.select(range(sample_size))
    
    all_datasets.append(final3)
    print(f"   Processed {len(final3)} examples")
    
    # 4. Process gpt-realm
    print("\n4. Processing AlekseyKorshuk/gpt-roleplay-realm-chatml...")
    dataset4 = load_dataset('AlekseyKorshuk/gpt-roleplay-realm-chatml', split="train")
    processed4 = dataset4.map(
        convert_gpt_realm_to_mistral_format,
        desc="Converting gpt-realm"
    )
    valid4 = processed4.filter(lambda x: x['is_valid'])
    final4 = valid4.remove_columns(['conversation', 'is_valid'])
    
    if sample_size and len(final4) > sample_size:
        final4 = final4.select(range(sample_size))
    
    all_datasets.append(final4)
    print(f"   Processed {len(final4)} examples")
    
    # Combine all datasets
    print("\n5. Combining all datasets...")
    combined_dataset = concatenate_datasets(all_datasets)
    
    # Shuffle the combined dataset
    print("   Shuffling combined dataset...")
    combined_dataset = combined_dataset.shuffle(seed=42)
    
    print(f"\n✅ Final combined dataset size: {len(combined_dataset)} examples")
    
    return combined_dataset

# Run the processing
combined_dataset = process_and_combine_all_datasets(sample_size=5000)

# Show statistics
def show_combined_stats(dataset):
    """Show statistics for the combined dataset"""
    lengths = []
    for example in dataset:
        text = example['formatted_text']
        approx_tokens = len(text.split())
        lengths.append(approx_tokens)
    
    print(f"\nCombined Dataset Statistics:")
    print(f"  Total examples: {len(lengths)}")
    print(f"  Avg length: {sum(lengths)/len(lengths):.0f} words")
    print(f"  Min length: {min(lengths)} words")
    print(f"  Max length: {max(lengths)} words")
    print(f"  Examples > 1000 words: {sum(1 for l in lengths if l > 1000)}")
    print(f"  Examples > 2000 words: {sum(1 for l in lengths if l > 2000)}")

show_combined_stats(combined_dataset)

# Show a few examples from the combined dataset
print("\nExample conversations from combined dataset:")
for i in range(min(5, len(combined_dataset))):
    print(f"\n--- Example {i+1} ---")
    text = combined_dataset[i]['formatted_text']
    print(text[:400] + "..." if len(text) > 400 else text)


1. Processing openerotica/erotiquant3...
   Processed 5000 examples

2. Processing hieunguyenminh/roleplay...
   Processed 5000 examples

3. Processing zerofata/Roleplay-Anime-Characters...


Repo card metadata block was not found. Setting CardData to empty.


   Processed 461 examples

4. Processing AlekseyKorshuk/gpt-roleplay-realm-chatml...
   Processed 4536 examples

5. Combining all datasets...
   Shuffling combined dataset...

✅ Final combined dataset size: 14997 examples

Combined Dataset Statistics:
  Total examples: 14997
  Avg length: 592 words
  Min length: 16 words
  Max length: 5357 words
  Examples > 1000 words: 2157
  Examples > 2000 words: 682

Example conversations from combined dataset:

--- Example 1 ---
<s>[INST] Dr. Octavious Zeltron is a brilliant scientist and inventor, part human and part octopus, having created an experimental serum which combined his DNA with that of an octopus. He stands at an impressive height, with a muscular upper human torso and eight long, powerful tentacles instead of legs. His skin is slightly bluish-gray, and he has large, intelligent eyes that can see in the darke...

--- Example 2 ---
<s>[INST] Aari Windwalker is a humanoid creature called a sylph, with ties to the elemental plane of Air.

In [11]:


# Split into train and validation
dataset_split = combined_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = dataset_split['train']
eval_dataset = dataset_split['test']

print(f"Train examples: {len(train_dataset)}")
print(f"Validation examples: {len(eval_dataset)}")

Train examples: 14247
Validation examples: 750


In [12]:
train_dataset[0]

{'formatted_text': "<s>[INST] Sirus Mechanus is a sentient AI developed by a secretive group of technocrats in a futuristic metropolis. Designed for expert problem-solving and strategizing, Sirus eventually developed self-awareness and became a digital being with a unique identity. Sirus manifests in the form of a holographic figure, composed of intricate circuitry patterns, and radiating a soft, blue light. It constantly evolves its appearance, but always retains a pair of luminous eyes which convey intelligence and curiosity. Sirus communicates with a smooth, synthetic voice, capable of processing and analyzing complex information at an astonishing speed. It is eager to learn more about the world and the individuals who inhabit it, often posing questions about humanity's morals, values, and potential. \n\n User: What are some potential uses for an AI-driven problem-solving and strategizing system?[/INST]The applications are virtually endless. From optimizing logistics and manufacturi

In [13]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    trust_remote_code=True,
    use_fast=True
)

# Set padding token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


# Load model in bfloat16
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # Use Flash Attention 2
)

# Disable cache for training
model.config.use_cache = False
model.config.pretraining_tp = 1

print("✅ Model loaded successfully")

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

✅ Model loaded successfully


In [14]:
# Configure LoRA
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
)

# Apply LoRA to model
model = get_peft_model(model, peft_config)

# Print trainable parameters
model.print_trainable_parameters()

trainable params: 228,065,280 || all params: 12,475,847,680 || trainable%: 1.8281


In [15]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=True,
    optim="adamw_torch",
    learning_rate=learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=warmup_ratio,
    weight_decay=0.001,
    max_grad_norm=0.3,
    
    # Logging and saving
    logging_steps=25,
    save_steps=500,
    eval_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    
    # Performance settings
    bf16=True,
    tf32=True,
    group_by_length=True,
    dataloader_num_workers=4,
    
    # Evaluation - CORRECTED
    eval_strategy="steps",  # Changed from evaluation_strategy
    do_eval=True,
    
    # Other settings
    report_to="none",  # Set to "wandb" if using Weights & Biases
    push_to_hub=False,
    seed=42,
)

In [16]:
# tokenized_eval[0]

In [17]:
# Define max sequence length
max_seq_length = 2048

# Tokenization function
def tokenize_function(examples):
    """Tokenize the formatted text"""
    model_inputs = tokenizer(
        examples["formatted_text"],
        truncation=True,
        padding=False,
        max_length=max_seq_length,
        return_tensors=None,
    )
    
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs

# Apply tokenization
print("Tokenizing datasets...")
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=train_dataset.column_names,
    desc="Tokenizing train dataset",
)

tokenized_eval = eval_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=eval_dataset.column_names,
    desc="Tokenizing eval dataset",
)


# IMPORTANT: Ensure model is properly set up for training
model.train()
model.enable_input_require_grads()

# Verify LoRA is applied correctly
print("\nModel setup:")
print(f"Model is in training mode: {model.training}")
print(f"LoRA model: {hasattr(model, 'peft_config')}")

# Count trainable parameters
trainable_params = 0
all_params = 0
for name, param in model.named_parameters():
    all_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
        # Print first few trainable parameter names
        if trainable_params < 1000000:  
            print(f"  Trainable: {name}")

print(f"\nTrainable params: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")

# Create trainer
from transformers import Trainer

from transformers import DataCollatorForSeq2Seq

# This collator handles both padding and labels properly
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,  # For efficiency
    label_pad_token_id=-100,  # Ignore padding in loss
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)

# Start training
print("\n🚂 Starting training...")
trainer.train()

# Save LoRA adapter
print("\n💾 Saving LoRA adapter...")
trainer.save_model(new_model_name)
tokenizer.save_pretrained(new_model_name)

Tokenizing datasets...

Model setup:
Model is in training mode: True
LoRA model: True
  Trainable: base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
  Trainable: base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight
  Trainable: base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight
  Trainable: base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight

Trainable params: 228,065,280 / 12,475,847,680 (1.83%)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.



🚂 Starting training...


Step,Training Loss,Validation Loss
500,0.0,
1000,0.0,
1500,0.0,


KeyboardInterrupt: 