# MLX Model Fine-Tuning with LoRA

This notebook will guide you through the steps of loading a pre-trained model, modifying it with LoRA layers, and training it on a specific dataset. This process is crucial for adapting large pre-trained models to new tasks with relatively small datasets and computational resources. We will load a pre-trained model, modify it with LoRA layers, and train it on a specific dataset.

---

# Setting Up Jupyter Notebook for MLX Fine-Tuning

## Installation
Before you can start fine-tuning models with MLX, you need to set up your environment. We recommend using JupyterLab for this tutorial as it provides a robust, interactive development environment for Jupyter notebooks.

### Ensure Python is installed on your system.

- Create the MLX environment:

```bash
conda create -n mlx-lora -y
```

- Activate the MLX environment:

```bash
conda activate mlx-lora
```

### Install Jupyter Notebook
If you haven't already installed JupyterLab, you can do so using Conda, a popular package and environment management system. Run the following command in your terminal:

```bash
conda install jupyter notebook
```

*This command will install JupyterLab and all required dependencies in your Conda environment.*

## Launch Jupyter Notebook
Once the installation is complete, you can launch JupyterLab by running:

```bash
jupyter notebook
```


*This command starts the Jupyter Notbook server and opens Jupyter Notebook in your default web browser. You can create a new notebook by clicking on the "New" button and selecting "Python 3" from the dropdown menu.*

## Next Steps
With Jupyter Notebook running, you can now proceed to the tutorial sections in this notebook to start fine-tuning your MLX model with LoRA layers.

---

## Importing Necessary Libraries and Modules
Before we start, we need to import all necessary libraries and modules that will be used throughout this notebook. This includes standard libraries for handling files and JSON data, as well as specific modules from the MLX library for model loading, modification, and training. 

*Before we begin the tutorial, it's important to ensure that all necessary Python libraries are installed. This includes libraries for machine learning, data manipulation, and model training. We will install these from a `requirements.txt` file that lists all the dependencies.*

In [18]:
# Clone mlx-examples repo
!git clone https://github.com/ml-explore/mlx-examples
!python3 -m pip install -r ./mlx-examples/lora/requirements.txt

# Install the necessary libraries from the requirements.txt file
!pip install mlx mlx-lm mlx_lm torch numpy transformers

Cloning into 'mlx-examples'...
remote: Enumerating objects: 3026, done.[K
remote: Counting objects: 100% (1547/1547), done.[K
remote: Compressing objects: 100% (406/406), done.[K
remote: Total 3026 (delta 1349), reused 1196 (delta 1141), pack-reused 1479[K
Receiving objects: 100% (3026/3026), 4.63 MiB | 2.40 MiB/s, done.
Resolving deltas: 100% (2041/2041), done.
[0m

In [2]:
# Importing necessary libraries and modules
import random
from typing import Tuple
from mlx_lm import load  # Load function to load models
from mlx_lm.tuner.lora import LoRALinear  # LoRA module for linear transformations
from mlx.utils import tree_flatten  # Utility to flatten model parameters
from mlx_lm.tuner.trainer import TrainingArgs, train  # Training utilities
import mlx.optimizers as optim  # Optimizers for model training
import json  # Module to work with JSON data
from pathlib import Path  # Module for handling filesystem paths

---

## Dataset Class Definition
Here we define a `Dataset` class to handle data operations. This class will be responsible for loading and accessing our dataset. It takes a list of data items and a key under which text data is stored. This abstraction allows us to easily fetch data by index and get its length, which are essential operations during training.

In [3]:
# Definition of the Dataset class to handle data operations
class Dataset:
    def __init__(self, data, key: str = "text"):
        self._data = data
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        return len(self._data)

## Loading the Dataset
To train our model, we first need to load our training and validation datasets. This function `load_dataset` takes a file path as input, checks for the file's existence, and reads the data. It returns an instance of the `Dataset` class filled with the loaded data. This setup is crucial for handling data efficiently during model training.

In [4]:
# Function to load a dataset from a specified path
def load_dataset(path: str):
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    with open(path, "r") as fid:
        data = [json.loads(line) for line in fid]
    return Dataset(data)

---

## Setting Up the Model and Data
In this cell, we define the `setup` function which initializes and returns essential components for our training: the model, tokenizer, and datasets. We load a pre-trained model and tokenizer from a specified path and load both training and validation datasets using the previously defined `load_dataset` function.

In [9]:
# Main function setup
def setup():
    train_dataset_path = "./data/dorian_training_dataset.jsonl"
    val_dataset_path = "./data/dorian_valid_dataset.jsonl"
    model_path = "/Users/anima/DorainGray-Phi3-4k-MLX"
    model, tokenizer = load(model_path)
    train_dst, valid_dst = load_dataset(train_dataset_path), load_dataset(val_dataset_path)
    return model, tokenizer, train_dst, valid_dst

---

## Modifying the Model with LoRA
In this section, we will modify the pre-trained model by integrating LoRA layers. LoRA allows us to adapt large pre-trained models with minimal additional parameters, making it efficient for fine-tuning on specific tasks. Below, we will freeze the original model parameters and add LoRA layers where necessary.

In [10]:
from mlx_lm.tuner.lora import LoRALinear 

# Modify the model with LoRA layers
def modify_model_with_lora(model):
    # Freeze the model to prevent updating weights of non-LoRA layers
    model.freeze()
    for l in model.model.layers:
        # Iterate through each layer in the model
        # Define the projections you want to update
        projections = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ]
        
        # Update self_attn projections if they exist
        for proj in projections[:]:  # For q_proj, k_proj, v_proj, o_proj
            if hasattr(l.self_attn, proj):
                # Replace existing linear layers with LoRALinear layers
                setattr(l.self_attn, proj, LoRALinear.from_linear(
                    getattr(l.self_attn, proj), r=32, lora_alpha=64
                ))

                        # Update additional custom projections
        if hasattr(l, 'gate_proj'):
            l.gate_proj = LoRALinear.from_linear(l.gate_proj, r=32, lora_alpha=64)
        if hasattr(l, 'up_proj'):
            l.up_proj = LoRALinear.from_linear(l.up_proj, r=32, lora_alpha=64)
        if hasattr(l, 'down_proj'):
            l.down_proj = LoRALinear.from_linear(l.down_proj, r=32, lora_alpha=64)


---

## Training Configuration and Execution
Now that our model has been modified to include LoRA layers, we need to set up the training configuration. This includes defining the training arguments, learning rate schedule, and optimizer. We will then proceed to train the model using the specified training and validation datasets. The training process is monitored by evaluating the model periodically and saving the model at specified intervals.

In [11]:
# Configure and execute training
def train_model(model, tokenizer, train_dst, valid_dst):
    trainingArgs = TrainingArgs(
        batch_size=1,
        iters=5000,
        val_batches=25,
        steps_per_report=10,
        steps_per_eval=200,
        steps_per_save=100,
        adapter_file="adapters.npz",
        max_seq_length=4096,
    )
    decay_steps = trainingArgs.iters
    lr_schedule = optim.cosine_decay(1e-5, decay_steps)
    opt = optim.AdamW(learning_rate=lr_schedule)

    
    train(model=model, 
          tokenizer=tokenizer, 
          args=trainingArgs, 
          optimizer=opt, 
          train_dataset=train_dst, 
          val_dataset=valid_dst)

---

## Executing the Main Function
Finally, we execute the main function which orchestrates the setup, model modification, and training process. This cell will trigger all the defined functions and start the model training process. Watch the outputs for progress and any potential issues that might need debugging.

### The saved adapaters will appear in your directory as training progresses

In [12]:
# Execute main function
model, tokenizer, train_dst, valid_dst = setup()
modify_model_with_lora(model)
train_model(model, tokenizer, train_dst, valid_dst)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Starting training..., iters: 5000
Iter 1: Val loss 1.258, Val took 47.805s
Iter 10: Train loss 1.049, Learning Rate 1.000e-05, It/sec 0.046, Tokens/sec 95.092, Trained Tokens 20750
Iter 20: Train loss 0.921, Learning Rate 1.000e-05, It/sec 0.054, Tokens/sec 109.715, Trained Tokens 41124
Iter 30: Train loss 0.935, Learning Rate 9.999e-06, It/sec 0.053, Tokens/sec 115.291, Trained Tokens 62700
Iter 40: Train loss 1.039, Learning Rate 9.998e-06, It/sec 0.045, Tokens/sec 107.939, Trained Tokens 86431
Iter 50: Train loss 0.886, Learning Rate 9.998e-06, It/sec 0.054, Tokens/sec 117.412, Trained Tokens 108177
Iter 60: Train loss 0.808, Learning Rate 9.997e-06, It/sec 0.068, Tokens/sec 129.936, Trained Tokens 127400
Iter 70: Train loss 0.957, Learning Rate 9.995e-06, It/sec 0.048, Tokens/sec 111.106, Trained Tokens 150448
Iter 80: Train loss 0.939, Learning Rate 9.994e-06, It/sec 0.048, Tokens/sec 110.488, Trained Tokens 173324
Iter 90: Train loss 0.943, Learning Rate 9.992e-06, It/sec 0.045, 

KeyboardInterrupt: 

---

## Fuse Trained Adapters to the Base Model

After training adapters for specific tasks, you can fuse these adapters to the base model. This step integrates the specialized capabilities of the adapters directly into the model, which in turn creates a single model that can be used for inference. This model will be used as the starting point for conversion into GGUF format. This allows us to interact with it locally!

### Breakdown

- `python3`: This invokes the Python interpreter to run the script.

- `./mlx-examples/lora/fuse.py`: This is the path to the Python script that handles the fusion of adapters to the base model.

- `--model ./path/to/model`: Specifies the path to the base model file. This should be the path where the pre-trained or previously fine-tuned model is stored.

- `--save-path ./new-fused-model-name`: This option sets the path and name for the output model file after the fusion process. This file will contain the base model with the adapters integrated.

- `--de-quantize`: This flag indicates that if the model is quantized, it should be de-quantized before fusion. This is often necessary to ensure compatibility between the model and the adapters.

- `--adapter-file ./adapters.npz.safetensors`: Specifies the path to the adapter file. This file contains the trained adapter parameters that will be fused with the base model.

### Customization Options

- **Model Path (`--model`)**: You can specify different models to which you want to apply the adapters, allowing for flexibility in experimenting with various base models.

- **Output Path (`--save-path`)**: Adjust this path based on where you want to store the fused model. This is useful for organizing different versions or types of fused models.

- **De-quantization (`--de-quantize`)**: This option can be toggled based on whether the input model is quantized. If your workflow involves models that are not quantized, this flag can be omitted.

- **Adapter File (`--adapter-file`)**: This path can be changed to point to different adapter files, allowing you to fuse various adapters with the base model depending on the specific enhancements or customizations you've developed.

In [None]:
!python3 -m mlx_lm.fuse --model /Users/anima/DorainGray-Phi3-4k-MLX \
        --save-path ./DorainGray-Phi3-4k \
        --de-quantize \
        --adapter-file /Users/anima/Vodalus-Master-RAG-Wiki/MLX_Fine-Tuning/checkpoints/200_adapters.npz