# Fine Tune Llama Model

In [1]:
import json
from functools import partial
import os
import sys
import gc
from datetime import datetime

from tqdm import tqdm
from pymongo import MongoClient
import torch

# Unsloth has to be imported before trl / transformers / peft: it warns
# ("Please restructure your imports with 'import unsloth' at the top of your
# file") when those libraries are loaded first.
from unsloth import FastLanguageModel # FastLanguageModel for LLMs
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from peft import prepare_model_for_kbit_training

  from .autonotebook import tqdm as notebook_tqdm

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import is_bfloat16_supported


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


## Params

In [2]:
# Notebook-wide settings. Every later cell reads these names, so tune the run here.

# --- Working directory and storage ---
app_path = '../'  # directory passed to os.chdir() before the utils.* imports
s3_bucket = "watspeed-data-gr-project"  # bucket used for adapter download and LoRA upload
s3_prefix = "models"  # key prefix inside s3_bucket
use_s3 = True  # when False no S3 client is created and nothing is downloaded/uploaded
mongo_uri = "mongodb://localhost:27017/"  # connection string handed to BioRxivDataset
mongo_db_name = "biorxiv"
mongo_db_collection = "abstracts"
local_model_path = "models"  # local directory for downloaded adapters and saved weights

# --- Model selection ---
base_model_name = "unsloth/Llama-3.2-3B" 
use_adapted_model = False  # True: load an existing adapter from adapter_path instead of the base model
adapter_path = None # path is relative to local_model_path or s3_prefix

# --- Data split ---
use_time_series_split = False  # forwarded to dataset.train_test_split (semantics live in utils.pytorch_dataset)
test_size = 0.10  # forwarded to dataset.train_test_split as the held-out share

# --- Trainer behaviour ---
report_to = "tensorboard" # report to tensorboard
disable_tqdm = False
# Both values are forwarded unchanged to TrainingArguments.
eval_steps=None
eval_strategy="epoch"
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False. Works for llama 8b but not 3.2-1b
# Effective batch size per optimizer step is the product of the two values below.
per_device_train_batch_size=8
gradient_accumulation_steps=8

In [3]:
# Move to the application root so that relative paths resolve from there.
# app_path is relative ('../'), so a plain os.chdir(app_path) climbs one more
# directory every time this cell is re-run. Remember the directory the kernel
# started in and always resolve app_path against that, which makes the cell
# idempotent. (os.path.join returns app_path unchanged if it is absolute.)
if "_notebook_start_dir" not in globals():
    _notebook_start_dir = os.getcwd()
os.chdir(os.path.join(_notebook_start_dir, app_path))

In [4]:
# NOTE(review): this import sits below the os.chdir(app_path) cell, presumably
# because `utils` is only importable from the application root — confirm before
# moving it into the top import cell.
from utils.aws import get_boto3_client

# Always bind `s3`. Leaving it undefined when use_s3 is False turns the later
# `assert s3 is not None` guard into a NameError instead of a readable failure.
s3 = get_boto3_client("s3") if use_s3 else None

Loaded .env — assuming local environment


In [5]:
# Create the local model directory if it is missing. exist_ok=True replaces the
# separate exists() check, which could race with another process creating it.
os.makedirs(local_model_path, exist_ok=True)

## Model Prep

In [6]:
## Model Setup
# Loads either a previously saved adapter (optionally fetched from S3 first) or
# the base model wrapped with fresh LoRA adapters. Defines `model`, `tokenizer`
# and `max_seq_length` for the cells below.
print('Model Setup')
print(datetime.now())
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+


def _download_s3_prefix(s3_client, bucket, prefix, dest_dir):
    """Download every object stored under ``prefix`` in ``bucket`` into ``dest_dir``.

    The listing prefix is forced to end with '/' so that only objects inside
    that "folder" match (a bare 'models/foo' would also match 'models/foo_v2/...').
    The key's path relative to the prefix is recreated below ``dest_dir`` so
    that files with the same name in different sub-folders do not overwrite
    each other.

    Args:
        s3_client: boto3 S3 client (uses get_paginator and download_file).
        bucket: name of the bucket to read from.
        prefix: key prefix of the folder to mirror, with or without trailing '/'.
        dest_dir: local directory that receives the files; created if missing.
    """
    folder_prefix = prefix.rstrip("/") + "/"
    os.makedirs(dest_dir, exist_ok=True)
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=folder_prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):  # Skip folder placeholder objects
                continue
            # Path of the object relative to the mirrored folder.
            rel_parts = key[len(folder_prefix):].split('/')
            local_path = os.path.join(dest_dir, *rel_parts)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            print(f"Downloading {key} to {local_path}")
            s3_client.download_file(bucket, key, local_path)


if use_adapted_model:
    assert adapter_path is not None, "Adapter path must be specified when using adapted model."
    # The adapter is always loaded from the local copy; S3 only refreshes it.
    full_local_model_path = os.path.join(local_model_path, adapter_path)
    if use_s3:
        assert s3 is not None, "S3 client is not initialized."
        s3_model_path = f"{s3_prefix}/{adapter_path}"
        # Files already present locally are overwritten but never removed.
        _download_s3_prefix(s3, s3_bucket, s3_model_path, full_local_model_path)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = full_local_model_path,
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit
    )
else:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = base_model_name,
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
        # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
    )
    # Only referenced by the commented-out layers_to_transform option below.
    num_layers = model.config.num_hidden_layers
    model = FastLanguageModel.get_peft_model(
                model,
                r = 8, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
                target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                                  "gate_proj", "up_proj", "down_proj",],
                # layers_to_transform=[num_layers - 1],
                lora_alpha = 16,
                lora_dropout = 0, # Supports any, but = 0 is optimized
                bias = "none",    # Supports any, but = "none" is optimized
                # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
                use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
                random_state = 3407,
                use_rslora = False,  # We support rank stabilized LoRA
                loftq_config = None, # And LoftQ
            )

Model Setup
2025-08-11 23:41:17.324168
==((====))==  Unsloth 2025.8.4: Fast Llama patching. Transformers: 4.55.0.
   \\   /|    NVIDIA GeForce RTX 4060 Laptop GPU. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.8.4 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


In [7]:
# Sanity check: report how many parameters are trainable versus the total.
model.print_trainable_parameters()

trainable params: 12,156,928 || all params: 3,224,906,752 || trainable%: 0.3770


## Data Setup

In [8]:
from utils.pytorch_dataset import BioRxivDataset

# Build the dataset from the MongoDB collection named in the params cell.
dataset = BioRxivDataset(
    mongo_uri=mongo_uri,
    db_name=mongo_db_name,
    collection_name=mongo_db_collection,
)

# Split into train / eval parts; the fixed random_state keeps the split
# repeatable between runs.
train_dataset, eval_dataset = dataset.train_test_split(
    test_size=test_size,
    random_state=42,
    use_time_series_split=use_time_series_split,
)


In [9]:
# Bare expression: display the object's repr to confirm what the split returned.
train_dataset

<utils.pytorch_dataset.BioRxivDataset at 0x73e8dc201e20>

In [10]:
# Number of training examples.
len(train_dataset)

39138

In [11]:
# Number of evaluation examples.
len(eval_dataset)

4349

In [12]:
# Preview the first two records (the to_dict() result sliced with [0:2]).
train_dataset.to_dict()[0:2]

[{'_id': '689835383c834e4e5e1097b7',
  'doi': '10.1101/2024.07.24.604968',
  'text': 'Phylogenomics has emerged as a transformative approach in systematics, conservation biology, and biomedicine, enabling the inference of evolutionary relationships by leveraging hundreds to thousands of genes from genomic or transcriptomic data. However, acquiring high-quality genomes and transcriptomes necessitates samples with intact DNA and RNA, substantial sequencing investments, and extensive bioinformatic processing, such as genome/transcriptome assembly and annotation. This challenge is particularly pronounced for rare or difficult-to-collect species, such as those inhabiting the deep sea, where only fragmented DNA reads are often available due to environmental degradation or suboptimal preservation conditions. To address these limitations, we introduce VEHoP (Versatile, Easy-to-use Homology-based Phylogenomic pipeline), a tool designed to infer protein-coding regions from diverse inputs, includ

In [13]:
def add_eos(example):
    """Return a ``{"text": ...}`` record with the tokenizer's EOS token appended.

    Reads the notebook-level ``tokenizer``. Only the "text" field is kept in
    the returned dict. A missing or empty "text" yields ``{"text": ""}``
    without an EOS token.

    Raises:
        ValueError: if the tokenizer has no EOS token (checked before the text
            is inspected, so it is raised even for empty examples).
    """
    eos = tokenizer.eos_token
    if eos is None:
        raise ValueError("Tokenizer does not define an EOS token.")

    raw_text = example.get("text", "")
    return {"text": raw_text + eos if raw_text else ""}

In [14]:
from datasets import Dataset


def _records_with_eos(split, split_name):
    """Collect EOS-terminated ``{"text": ...}`` records from one dataset split.

    Args:
        split: indexable dataset (supports ``len()`` and integer indexing) whose
            items are dict-like.
        split_name: label used in the "skipping" message ("train" / "eval").

    Returns:
        list of the dicts produced by ``add_eos`` for every item that has a
        "text" key; items without that key are reported and skipped.
    """
    records = []
    for i in tqdm(range(len(split))):
        item = split[i]
        if "text" in item.keys():
            records.append(add_eos(item))
        else:
            print("skipping for index {} in {} dataset".format(i, split_name))
    return records


# One helper serves both splits, so the two passes cannot drift apart.
print('converting train data')
train_hf_dataset = _records_with_eos(train_dataset, "train")
print('converting eval data')
eval_hf_dataset = _records_with_eos(eval_dataset, "eval")

converting train data


100%|█████████████████████████████████████████████████████████████████████████| 39138/39138 [00:00<00:00, 496749.01it/s]
100%|███████████████████████████████████████████████████████████████████████████| 4349/4349 [00:00<00:00, 469875.28it/s]


In [15]:
# Wrap the lists of {"text": ...} records in `datasets.Dataset` objects, the
# form passed to SFTTrainer as train_dataset / eval_dataset.
# NOTE(review): these names are rebound from list to Dataset, so re-running
# this cell alone feeds a Dataset (not a list) into from_list.
train_hf_dataset = Dataset.from_list(train_hf_dataset)
eval_hf_dataset = Dataset.from_list(eval_hf_dataset)

## Training Setup

In [16]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Use bf16 where the GPU supports it, otherwise fall back to fp16.
use_bf16 = is_bfloat16_supported()

# Optimisation / logging settings, kept separate from the trainer wiring.
training_args = TrainingArguments(
    # batch geometry
    per_device_train_batch_size = per_device_train_batch_size,
    gradient_accumulation_steps = gradient_accumulation_steps,
    # schedule
    num_train_epochs = 2, # Set this for 1 full training run.
    # max_steps = 5,
    warmup_steps = 5,
    learning_rate = 1e-5,
    lr_scheduler_type = "linear",
    weight_decay = 0.01,
    optim = "adamw_8bit",
    # precision
    fp16 = not use_bf16,
    bf16 = use_bf16,
    # evaluation
    eval_strategy = eval_strategy,
    eval_steps = eval_steps,
    # logging / output
    logging_steps = 1,
    logging_dir = "logs",
    report_to = report_to, # Use this for WandB etc
    disable_tqdm = disable_tqdm,
    output_dir = "outputs",
    seed = 3407,
)

# Supervised fine-tuning on the "text" field of the converted datasets.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_hf_dataset,
    eval_dataset = eval_hf_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = training_args,
)

Unsloth: Tokenizing ["text"]: 100%|██████████████████████████████████████| 39138/39138 [00:04<00:00, 9317.64 examples/s]
Unsloth: Tokenizing ["text"]: 100%|████████████████████████████████████████| 4349/4349 [00:00<00:00, 9567.31 examples/s]


## Run Training

In [17]:
# @title Show current memory stats
# Baseline GPU memory before training. `start_gpu_memory` and `max_memory`
# are read again when the final statistics are printed.
bytes_per_gib = 1024 ** 3
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / bytes_per_gib, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gib, 3)
for line in (
    f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.",
    f"{start_gpu_memory} GB of memory reserved.",
):
    print(line)

GPU = NVIDIA GeForce RTX 4060 Laptop GPU. Max memory = 7.996 GB.
2.953 GB of memory reserved.


In [18]:
# Wall-clock timestamp displayed right before training starts.
datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

'Monday, August 11, 2025 at 11:41 PM'

In [19]:
# Run fine-tuning. The result is kept because its .metrics['train_runtime'] is
# read when the final time statistics are printed.
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 39,138 | Num Epochs = 2 | Total steps = 1,224
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 8 x 1) = 64
 "-____-"     Trainable parameters = 12,156,928 of 3,224,906,752 (0.38% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Epoch,Training Loss,Validation Loss
1,2.2519,2.198221
2,2.1308,2.192798


Unsloth: Not an error, but LlamaForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


In [20]:
# Wall-clock timestamp displayed once training has returned.
datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

'Tuesday, August 12, 2025 at 09:07 AM'

In [21]:
# @title Show final memory and time stats
# Training time and peak GPU memory, compared with the baseline values
# (`start_gpu_memory`, `max_memory`) captured before training.
train_seconds = trainer_stats.metrics['train_runtime']
peak_reserved_bytes = torch.cuda.max_memory_reserved()

used_memory = round(peak_reserved_bytes / 1024 ** 3, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

report_lines = [
    f"{train_seconds} seconds used for training.",
    f"{round(train_seconds/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
]
for line in report_lines:
    print(line)

33948.032 seconds used for training.
565.8 minutes used for training.
Peak reserved memory = 5.684 GB.
Peak reserved memory for training = 2.731 GB.
Peak reserved memory % of max memory = 71.086 %.
Peak reserved memory for training % of max memory = 34.155 %.


## Save LoRA Weights

In [22]:
import posixpath
import re
# Save the LoRA adapter and tokenizer locally and, when use_s3 is set, upload
# the saved files to S3 under <s3_prefix>/<base_model_folder>/lora_weights/.
print("Saving LoRA Weights...")

# Folder name: base model id with '/' replaced, plus a timestamp so that
# repeated runs never overwrite each other.
base_model_folder = base_model_name.replace("/", "_") + "_{}".format(datetime.now().strftime("%Y%m%d_%H%M%S"))

model_subdir = os.path.join(local_model_path, base_model_folder)
lora_weights_path = os.path.join(model_subdir, "lora_weights")
# Creates model_subdir as well; exist_ok avoids the separate exists() checks.
os.makedirs(lora_weights_path, exist_ok=True)

trainer.save_model(lora_weights_path)
tokenizer.save_pretrained(lora_weights_path)

if use_s3:
    print("Uploading LoRA Weight Files to S3...")
    # Only regular files directly inside lora_weights_path are uploaded.
    for fname in os.listdir(lora_weights_path):
        fpath = os.path.join(lora_weights_path, fname)
        if os.path.isfile(fpath):
            print("{}".format(fpath))
            s3.upload_file(
                Filename=fpath,
                Bucket=s3_bucket,
                # S3 keys always use '/', so build them with posixpath rather
                # than os.path (which would emit '\\' on Windows).
                Key=posixpath.join(s3_prefix, base_model_folder, "lora_weights", fname)
            )

Saving LoRA Weights...
Uploading LoRA Weight Files to S3...
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/special_tokens_map.json
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/tokenizer_config.json
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/training_args.bin
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/README.md
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/adapter_model.safetensors
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/tokenizer.json
models/unsloth_Llama-3.2-3B_20250812_090718/lora_weights/adapter_config.json


In [23]:
# Wall-clock timestamp displayed after the weights have been saved.
datetime.now().strftime("%A, %B %d, %Y at %I:%M %p")

'Tuesday, August 12, 2025 at 09:07 AM'