# Training Gemma 3 (4B) Model

In [None]:
import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM # Unsloth handles this
from unsloth import FastLanguageModel
import os
from datasets import load_dataset # For loading example datasets
from trl import SFTTrainer
from transformers import TrainingArguments

# --- Base model + LoRA adapter setup (Unsloth) ---
# Local checkpoint directory; assumed to hold the Gemma 3 4B instruction-tuned weights.
model_dir = '../Resources/gemma-3-4b-it'
max_seq_length = 2048  # Longest sequence length used below; lower it if VRAM is tight
dtype = None           # None = auto-detect. Otherwise float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True    # Load the base weights 4-bit quantized to reduce memory

# Keyword arguments for FastLanguageModel.from_pretrained; local_files_only
# keeps the load from reaching out to the Hugging Face Hub.
base_model_kwargs = dict(
    model_name=model_dir,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    local_files_only=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_kwargs)

# LoRA adapter hyper-parameters: rank-16 adapters on every attention and MLP
# projection. Dropout 0 and bias "none" are the settings flagged as optimized.
lora_kwargs = dict(
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
    use_rslora=False,   # Rank-stabilized LoRA is available but disabled
    loftq_config=None,  # LoftQ initialization disabled
)
model = FastLanguageModel.get_peft_model(model, **lora_kwargs)

# --- Persona LoRA Training (Example with a dummy dataset) ---
# In a real scenario, replace this with your actual PersonaPlugs dataset,
# e.g. text files loaded from a directory.
# For this example, we use a small slice of an open-source dataset.
persona_dataset_name = "wikitext"
persona_dataset_config = "wikitext-2-raw-v1" # A small dataset for demonstration
persona_dataset = load_dataset(persona_dataset_name, persona_dataset_config, split="train[:1%]") # Use a tiny fraction
# wikitext keeps the blank lines between paragraphs as empty / whitespace-only
# rows. They carry no training signal and would become empty SFT examples, so
# drop them before formatting and tokenization.
persona_dataset = persona_dataset.filter(lambda example: example["text"].strip() != "")

# Preprocess dataset (example)
def formatting_prompts_func(examples):
    """Map a batch of raw examples to the ``{"text": ...}`` layout used for SFT.

    Used with ``Dataset.map(..., batched=True)``, so ``examples["text"]`` holds
    the batch's text values. They are passed through untouched (the very same
    object), and only the ``text`` key is returned. Replace this with real
    prompt formatting for the actual persona data.
    """
    return {"text": examples["text"]}
persona_dataset = persona_dataset.map(formatting_prompts_func, batched = True,)


# SFT run that trains the LoRA adapter attached above on the persona text.
trainer_persona = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = persona_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # packing=True can make training up to 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4, # Effective batch size 2 * 4 = 8 per device
        warmup_steps = 5,
        # max_steps = 60, # Set a low number of steps for quick demo
        num_train_epochs = 1, # Train for 1 epoch as per llm.md for initial LoRAs
        learning_rate = 2e-4,
        # Use bf16 where the GPU supports it, otherwise fall back to fp16.
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs_persona_lora",
    ),
)
print("Starting Persona LoRA training...")
trainer_persona.train()
print("Persona LoRA training finished.")
# The task section below continues training this *same* adapter in place,
# which overwrites the persona-only weights. Persist them now (with the
# tokenizer, so the directory is self-contained) so they are not lost.
model.save_pretrained("lora_model_persona")
tokenizer.save_pretrained("lora_model_persona")


# --- Task-Specific LoRA Training (Example: Continuation) ---
# This would normally be a separate training run, either after saving the
# Persona LoRA or after re-initializing the LoRA layers for the new task.
# For simplicity, we keep training the existing adapter on a new dataset.
# Ideally, you'd train separate LoRAs as per llm.md: load a fresh base model
# and attach a new LoRA adapter for each task.

# Example Task: Text Continuation (using the larger wikitext-103 config;
# its content may overlap with the wikitext-2 slice used above)
task_dataset_name = "wikitext"
task_dataset_config = "wikitext-103-raw-v1" # A larger dataset for the task
# Note: 1% of wikitext-103 is still far more rows than 1% of wikitext-2.
task_dataset = load_dataset(task_dataset_name, task_dataset_config, split="train[:1%]")
# Drop the empty / whitespace-only separator rows, as for the persona data,
# so they do not become empty SFT examples.
task_dataset = task_dataset.filter(lambda example: example["text"].strip() != "")

task_dataset = task_dataset.map(formatting_prompts_func, batched = True,)

trainer_task = SFTTrainer(
    model = model, # Continue with the same LoRA-adapted model for this example
                   # Or, load a fresh base model and add a new LoRA adapter for the task
    tokenizer = tokenizer,
    train_dataset = task_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # packing=True can make training up to 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # max_steps = 60,
        num_train_epochs = 1, # Train for 1 epoch
        learning_rate = 2e-4, # Can be different for task LoRA
        # Use bf16 where the GPU supports it, otherwise fall back to fp16.
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs_task_lora",
    ),
)
print("Starting Task-Specific (Continuation) LoRA training...")
trainer_task.train()
print("Task-Specific (Continuation) LoRA training finished.")
# model.save_pretrained("lora_model_task_continue") # Optionally save the Task LoRA

# The `llm.md` mentions DPO refinement later. That would be another step.
# This script covers the initial independent LoRA training phase.

# Apple MLX LoRA Training (Experimental)

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.utils import load, generate # Assuming these are available or similar utilities
from mlx_lm.models.gemma import Model as GemmaModel # Adjust if model class name differs
import os

# --- Apple MLX LoRA Training ---
print("Setting up Apple MLX LoRA training...")

# Model and tokenizer path. mlx_lm's `load` expects an MLX-format checkpoint,
# so the Hugging Face weights used above must be converted to MLX first
# (see the mlx-lm conversion tooling).
mlx_model_path = '../Resources/gemma-3-4b-it-mlx' # Placeholder: path to MLX-converted Gemma model
if not os.path.exists(mlx_model_path):
    print(f"MLX model not found at {mlx_model_path}. Please convert the HF model to MLX format.")
    print("Skipping MLX LoRA training part.")
else:
    # Load the MLX model and tokenizer. Any failure is reported and the MLX
    # section is skipped instead of aborting the notebook.
    try:
        model, tokenizer = load(mlx_model_path) # mlx_lm's load utility
    except Exception as e:
        print(f"Error loading MLX model: {e}. Make sure it's converted and path is correct.")
        model = None # Ensure model is None if loading fails

    # mlx.nn.Module subclasses dict, so a bare `if model:` would test "has
    # child modules" rather than "was loaded"; compare against None explicitly.
    if model is not None:
        from mlx.utils import tree_flatten  # flattens parameter trees for counting / saving

        lora_rank = 8
        lora_alpha = 16
        lora_dropout = 0.1
        lora_scale = lora_alpha / lora_rank

        # Projection layers to wrap with LoRA. mlx-lm matches these names against
        # each transformer block's named_modules(), so they need the sub-module
        # prefix. NOTE(review): assumes the Gemma 3 blocks expose `self_attn` and
        # `mlp` sub-modules like other mlx-lm Gemma ports; confirm for your version.
        keys = [
            "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
            "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
        ]
        lora_config = {"rank": lora_rank, "scale": lora_scale, "dropout": lora_dropout, "keys": keys}

        # Freeze the base weights first so that only the LoRA matrices added
        # below are trainable. Otherwise value_and_grad and Adam would cover
        # every weight in the model.
        model.freeze()
        lora_applied = False
        try:
            # (model, num_layers, config) matches recent mlx-lm releases; older
            # releases used different signatures, hence the guarded call.
            # len(model.layers) targets every transformer block.
            linear_to_lora_layers(model, len(model.layers), lora_config)
            lora_applied = True
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            print(f"Could not apply LoRA layers with this mlx-lm version: {e}")
            print("Refer to mlx-lm documentation for the precise way to add and train LoRA adapters.")

        trainable = tree_flatten(model.trainable_parameters()) if lora_applied else []
        n_trainable = sum(p.size for _, p in trainable)
        if n_trainable == 0:
            # Either LoRA could not be applied or no key matched a layer. Training
            # now would be a no-op (or a full fine-tune), so stop here instead.
            print("No trainable LoRA parameters; skipping MLX LoRA training loop.")
        else:
            print(f"MLX Model loaded. LoRA applied with {n_trainable} trainable parameters.")

            # Dummy batch of 4 identical rows. The token ids are placeholders,
            # not meaningful Gemma ids. Labels are the inputs shifted left by
            # one, padded with pad_id; replace with real tokenized data.
            pad_id = 0
            dummy_input_ids = mx.array([[101, 7592, 2026, 3899, 102]] * 4)
            dummy_labels = mx.array([[7592, 2026, 3899, 102, pad_id]] * 4)

            def loss_fn(model, inputs, targets):
                """Mean next-token cross-entropy over non-padding target positions."""
                # assumes model(inputs) returns (batch, seq, vocab) logits; TODO confirm
                logits = model(inputs).astype(mx.float32)
                mask = (targets != pad_id).astype(mx.float32)
                ce = nn.losses.cross_entropy(logits, targets) * mask
                # Average over real tokens only; the maximum guards an all-padding batch.
                return ce.sum() / mx.maximum(mask.sum(), 1.0)

            # Gradients are taken only w.r.t. the model's trainable (LoRA) parameters.
            loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
            optimizer = optim.Adam(learning_rate=1e-5)

            print("Starting conceptual MLX LoRA training loop...")
            for epoch in range(1): # Minimal epochs for demo
                for step in range(5): # Minimal steps for demo
                    loss, grads = loss_and_grad_fn(model, dummy_input_ids, dummy_labels)
                    optimizer.update(model, grads)
                    # MLX is lazy: force evaluation of the updated weights and optimizer state.
                    mx.eval(model.parameters(), optimizer.state)
                    print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
            print("Conceptual MLX LoRA training finished.")
            # To persist only the adapter weights, e.g.:
            # mx.save_safetensors("mlx_lora_adapter_persona.safetensors",
            #                     dict(tree_flatten(model.trainable_parameters())))

print("MLX training script part finished.")