# Gemma fine-tuning ⚙️👨🏻‍💻
In this notebook we will use the Apple MLX framework to fine-tune the open source Gemma model created by Google.

As input for fine-tuning, we will use the questions and responses generated by the larger llama3-70b model, which we ran via Replicate. This is a student-teacher approach: a large model generates the training data and a smaller model learns from it.

The instructions for fine-tuning this model are based on [this notebook](https://gist.github.com/alexweberk/635431b5c5773efd6d1755801020429f). I thank Alex for sharing this information 🙏 

 -> ***Currently I'm running this notebook on a 2023 M3 Pro with 36GB of RAM***

We need to install the following packages:
- `transformers`
- `mlx`
- `torch`
- `mlx_lm`

-> It is recommended to use a package manager such as [poetry](https://python-poetry.org/) to install the dependencies.
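Before loading the model, it can help to confirm that these packages are importable in the active environment. The following is a minimal sketch that uses only the standard library. The helper name `missing_packages` is our own. Note that the `mlx_lm` module is published on PyPI as `mlx-lm`.

```python
import importlib.util

# Import names of the packages listed above (the PyPI name of mlx_lm is "mlx-lm")
REQUIRED = ["transformers", "mlx", "torch", "mlx_lm"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment.

    find_spec locates a top-level module without importing it and returns
    None when the module is not installed.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("Missing packages:", missing or "none")
```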

## The Model
⚠️ We will use the instruction-tuned gemma-1.1-7b-it model for this fine-tuning task ⚠️
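The following is a rough, back-of-the-envelope memory check, not a measurement. When `load()` downloads gemma-1.1-7b-it, it fetches four safetensors shards of about 5.00, 4.98, 4.98 and 2.11 GB. We assume decimal gigabytes and bfloat16 weights (2 bytes per parameter). Under those assumptions, the shard sizes translate into an approximate parameter count:

```python
# Shard sizes (GB) reported while downloading google/gemma-1.1-7b-it
shard_sizes_gb = [5.00, 4.98, 4.98, 2.11]
total_gb = sum(shard_sizes_gb)

# Assumption: weights stored as bfloat16, i.e. 2 bytes per parameter
bytes_per_param = 2
approx_params = total_gb * 1e9 / bytes_per_param

print(f"{total_gb:.2f} GB on disk, roughly {approx_params / 1e9:.1f}B parameters")
# → 17.07 GB on disk, roughly 8.5B parameters
```

Roughly 17 GB of weights fits within the 36GB of RAM mentioned above. Actual peak usage during fine-tuning will be higher than the weights alone.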

In [1]:
from mlx_lm import generate, load

In [2]:
model, tokenizer = load("google/gemma-1.1-7b-it")



Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/620 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

Once the model is downloaded, we can run a prompt to check that inference works.

In [4]:
# Generating without adding a prompt template manually
prompt = """
Why is the sky blue?
""".strip()
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,  # Print the prompt, the response and the generation speed
    temp=0.0,  # Temperature 0: greedy, deterministic decoding
    max_tokens=256,  # Upper bound on the number of generated tokens
)

Prompt: Why is the sky blue?


**Answer:**

The sky is blue due to a phenomenon called **Rayleigh scattering**. 

* Sunlight is composed of all the colors of the rainbow, each with a specific wavelength.
* When sunlight interacts with molecules in the atmosphere, such as nitrogen and oxygen, the molecules scatter the light.
* Different wavelengths of light are scattered differently. 
* **Blue light has a shorter wavelength than other colors** and is scattered more efficiently by these molecules.
* This means that more blue light is scattered in all directions, reaching our eyes and making the sky appear blue.
Prompt: 0.853 tokens-per-sec
Generation: 7.404 tokens-per-sec
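The call above passes the bare question, without Gemma's chat turn markers. To prompt the instruction-tuned model in its conversational format, the question can be wrapped by hand. The following is a minimal sketch. `build_gemma_prompt` is a hypothetical helper, and it omits `<bos>` on the assumption that the tokenizer prepends it when encoding.

```python
def build_gemma_prompt(user_message: str) -> str:
    """Wrap one user message in Gemma's turn markers.

    The prompt ends with an open model turn, so that generation continues
    as the model's reply. <bos> is deliberately omitted (assumed to be
    added by the tokenizer during encoding).
    """
    return (
        "<start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


templated_prompt = build_gemma_prompt("Why is the sky blue?")
print(templated_prompt)
```

The result can be passed as `prompt=` to the same `generate` call. The Hugging Face tokenizer's `apply_chat_template` method is a more robust alternative, because it uses the template shipped with the model.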


## Creating the dataset

We need to create a dataset of JSONL files with the following structure:
```json
{"text": "<bos><start_of_turn>user\nWhat is the capital of France?<end_of_turn>\n<start_of_turn>model\nParis is the capital of France.<end_of_turn><eos>"}
```
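Each training record can be built from a question/answer pair produced by the teacher model. The following is a minimal sketch. `to_jsonl_line` is a hypothetical helper that reproduces the example record above exactly.

```python
import json


def to_jsonl_line(question: str, answer: str) -> str:
    """Format one question/answer pair as a single JSONL record in Gemma's turn format."""
    text = (
        "<bos><start_of_turn>user\n"
        f"{question}<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{answer}<end_of_turn><eos>"
    )
    # json.dumps escapes newlines and quotes inside the generated text
    return json.dumps({"text": text})


line = to_jsonl_line("What is the capital of France?", "Paris is the capital of France.")
print(line)
```

Using `json.dumps` instead of string formatting ensures that quotes and newlines inside the generated answers are escaped correctly.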

We need to create `train.jsonl` and `valid.jsonl` with a 90/10 split between the two files. Let's build the dataset by querying the PostgreSQL database we populated in the site_crawl_processing example notebook.
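The 90/10 split can be sketched as follows. This assumes each record is already a JSON string in the format shown above. `write_train_valid` is a hypothetical helper. It shuffles with a fixed seed so that the split is reproducible, then writes `train.jsonl` and `valid.jsonl`, the file names that mlx_lm's LoRA tooling looks for in its data directory.

```python
import json
import random
import tempfile
from pathlib import Path


def write_train_valid(lines, out_dir, train_fraction=0.9, seed=0):
    """Shuffle JSONL lines and write them to train.jsonl and valid.jsonl.

    Returns the number of records written to each file.
    """
    lines = list(lines)
    random.Random(seed).shuffle(lines)  # reproducible shuffle
    split = int(len(lines) * train_fraction)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, chunk in (("train.jsonl", lines[:split]), ("valid.jsonl", lines[split:])):
        # One record per line, no blank lines
        (out / name).write_text("".join(line + "\n" for line in chunk), encoding="utf-8")
    return split, len(lines) - split


# Demo with placeholder records written to a temporary directory
demo_lines = [json.dumps({"text": f"example {i}"}) for i in range(10)]
demo_dir = tempfile.mkdtemp()
n_train, n_valid = write_train_valid(demo_lines, demo_dir)
print(n_train, n_valid)  # → 9 1
```

In the notebook, the output directory would be whichever data directory the fine-tuning step reads from. The temporary directory here only keeps the sketch self-contained.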

In [17]:
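import pandas as pd
import sqlalchemy

# Local PostgreSQL database populated by the site_crawl_processing notebook
connection = "postgresql+psycopg://postgres:mysecretpassword@localhost/paginx"

engine = sqlalchemy.create_engine(connection)

# Fetch the stored document text and its JSON metadata (which includes the NER annotations)
query = """
select "document", cmetadata
from langchain_pg_embedding lpe
limit 1000
"""

df = pd.read_sql(query, engine)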

In [19]:
print(df.iloc[0]['cmetadata']['NER'])

{'category': 'advertising', 'language': 'en', 'entity_list': [{'entity_name': 'Quiet 2', 'entity_type': 'product'}, {'entity_name': 'Loop', 'entity_type': 'brand'}]}
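The `NER` entry shown above is a plain dictionary, so its entities can be regrouped by type. This could be used, for example, to filter documents by brand or product before generating questions. The following is a minimal sketch. `entities_by_type` is a hypothetical helper based only on the single row printed above.

```python
from collections import defaultdict


def entities_by_type(ner: dict) -> dict:
    """Group entity names from an NER metadata dict by their entity_type."""
    grouped = defaultdict(list)
    for entity in ner.get("entity_list", []):
        grouped[entity["entity_type"]].append(entity["entity_name"])
    return dict(grouped)


# The NER metadata printed for the first row
sample_ner = {
    "category": "advertising",
    "language": "en",
    "entity_list": [
        {"entity_name": "Quiet 2", "entity_type": "product"},
        {"entity_name": "Loop", "entity_type": "brand"},
    ],
}
print(entities_by_type(sample_ner))  # → {'product': ['Quiet 2'], 'brand': ['Loop']}
```

To apply it to the whole DataFrame, for example with `df["cmetadata"].map(lambda m: entities_by_type(m["NER"]))`, every row must carry an `NER` key. The single row printed above does not guarantee this.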
