# Fine-Tuning Llama-3.2-1B with LoRA and FHE using `LoraTrainer`

This notebook demonstrates how to fine-tune a Llama-3.2-1B model using LoRA (Low-Rank Adaptation) with Fully Homomorphic Encryption (FHE). We use the `LoraTrainer` API from the `concrete.ml.torch.lora` module to simplify the process.


In [1]:
import random
import shutil
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from utils_lora import generate_and_print

# Import LoraTrainer from the provided library
from concrete.ml.torch.lora import LoraTrainer

Concrete ML implements LoRA fine-tuning in a 'hybrid' setting: the client machine outsources every
computation that involves the original model weights and runs gradient descent on the LoRA layers locally.

The client machine therefore executes part of the LoRA training protocol itself, and it can do so on a CPU
or on a dedicated accelerator.
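
To make the split concrete, here is a minimal NumPy sketch of a single LoRA-adapted linear layer. It is an illustration under stated assumptions, not the Concrete ML implementation: the names `remote_base_matmul` and `lora_linear` are hypothetical, the layer sizes are arbitrary, and the "outsourced" product is computed in the clear where the real protocol would evaluate it under FHE.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not taken from the model above)
d_in, d_out, r, alpha = 16, 12, 4, 8.0

# Frozen base weight: in the hybrid setting this never leaves the server.
W = rng.normal(size=(d_out, d_in))

# Trainable LoRA factors: these stay on the client.
A = rng.normal(size=(r, d_in)) * 0.01
B = np.zeros((d_out, r))  # usual LoRA init: B = 0, so the adapter starts as a no-op


def remote_base_matmul(x):
    """Stand-in for the outsourced computation (hypothetical name).

    Here it runs in the clear; in the real protocol the server would
    evaluate this product on encrypted activations.
    """
    return x @ W.T


def lora_linear(x):
    """Base output (outsourced) plus the low-rank update (computed locally)."""
    base = remote_base_matmul(x)
    update = (x @ A.T) @ B.T * (alpha / r)
    return base + update


x = rng.normal(size=(3, d_in))
y = lora_linear(x)
print(y.shape)
```

Only `A` and `B` receive gradient updates, which is why the optimizer can run entirely on the client.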

In [2]:
# Set seed for reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = "cpu"
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

import concrete_ml_extensions as fhext

cuda_fhext = fhext.is_cuda_enabled() and fhext.is_cuda_available()  # pylint: disable=no-member
print(
    "Original model linear layers execute with FHE on: ",
    "cuda" if cuda_fhext else "cpu",
)
print("Non-FHE layers and the LoRA weight optimizer executed on: ", device)

Original model linear layers execute with FHE on:  cpu
Non-FHE layers and the LoRA weight optimizer executed on:  cpu


## Set-up

Load the Llama model, tokenize the dataset, and create the LoRA fine-tuning configuration.

In [3]:
# Load the model and tokenizer
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Ensure the tokenizer has a pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Freeze the original model's weights
for param in model.parameters():
    param.requires_grad = False

In [4]:
# Apply LoRA configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.01,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)
peft_model = get_peft_model(model, peft_config)
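
As a rough illustration of what `r=8` and `lora_alpha=32` imply, the sketch below counts the trainable parameters one adapter adds to a single linear layer and computes the scaling factor applied to the low-rank update. The layer size `2048 x 2048` is an illustrative assumption, and `lora_param_count` is a hypothetical helper, not part of `peft`.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters of one LoRA adapter: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)


r, lora_alpha = 8, 32
d_in = d_out = 2048  # illustrative layer size (assumption)

full = d_in * d_out                      # parameters in the frozen weight matrix
lora = lora_param_count(d_in, d_out, r)  # parameters actually trained
scaling = lora_alpha / r                 # factor applied to the low-rank update

print(f"frozen: {full}, trainable: {lora}, ratio: {full / lora:.0f}x, scaling: {scaling}")
```

The adapter trains a small fraction of the weights of each targeted layer, which keeps the locally executed part of the protocol cheap.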

In [5]:
# Load the dataset and tokenize it
dataset = load_dataset("json", data_files="data_finetune/dataset.jsonl", split="train")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="longest", truncation=True)


tokenized_dataset = dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/46 [00:00<?, ? examples/s]
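
The following pure-Python sketch shows, on toy token ids, what padding to the longest sequence and causal-LM label creation amount to. It is a simplified stand-in for the tokenizer and `DataCollatorForLanguageModeling(mlm=False)`, not their actual code; `pad_to_longest`, `make_causal_labels`, and the pad id `0` are hypothetical.

```python
PAD_ID = 0           # hypothetical pad token id
IGNORE_INDEX = -100  # label value skipped by the loss


def pad_to_longest(batch, pad_id=PAD_ID):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in batch)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in batch]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in batch]
    return input_ids, attention_mask


def make_causal_labels(input_ids, pad_id=PAD_ID):
    """Copy the inputs as labels and mask padding positions out of the loss."""
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in seq] for seq in input_ids]


batch = [[5, 6, 7], [8, 9]]
input_ids, attention_mask = pad_to_longest(batch)
labels = make_causal_labels(input_ids)

print(input_ids)
print(attention_mask)
print(labels)
```

Because this notebook sets the pad token to the end-of-sequence token, masking by pad id would also mask genuine end-of-sequence tokens in the labels.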

In [6]:
# Define training arguments
EPOCHS = 10
PER_DEVICE_TRAIN_BATCH_SIZE = 4
training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=1,
    save_total_limit=1,
    use_cpu=True,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    seed=SEED,
    data_seed=SEED,
    warmup_steps=10,
    weight_decay=0.01,
    prediction_loss_only=True,
)
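
The sketch below illustrates the learning-rate curve that `learning_rate=2e-4`, `warmup_steps=10`, and `lr_scheduler_type="linear"` describe: a linear ramp from zero followed by a linear decay to zero. It is a hand-written approximation for illustration, not the scheduler object used later in the notebook; `linear_schedule_lr` is a hypothetical helper, and the step count assumes the 46 examples reported by the dataset mapping output with the last partial batch kept.

```python
import math


def linear_schedule_lr(step, base_lr, warmup_steps, total_steps):
    """Linear warmup from 0 to base_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


num_examples = 46  # assumption: dataset size read from the mapping output
batch_size = 4
epochs = 10

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * epochs

for step in (0, 5, 10, 65, total_steps):
    lr = linear_schedule_lr(step, base_lr=2e-4, warmup_steps=10, total_steps=total_steps)
    print(f"step {step:3d}: lr = {lr:.2e}")
```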

In [7]:
# Create optimizer and scheduler using HuggingFace's Trainer
hf_trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
train_dataloader = hf_trainer.get_train_dataloader()
hf_trainer.create_optimizer_and_scheduler(num_training_steps=len(train_dataloader) * EPOCHS)

optimizer = hf_trainer.optimizer
lr_scheduler = hf_trainer.lr_scheduler


# Define a causal LM loss function
def causal_lm_loss(logits, labels, ignore_index=-100):
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels = shift_labels.view(-1)
    loss = torch.nn.functional.cross_entropy(
        shift_logits, shift_labels, ignore_index=ignore_index, reduction="mean"
    )
    return loss
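
To see what the shift in `causal_lm_loss` does, here is a NumPy re-implementation run on a toy input. It is a sketch for checking the logic by hand, with the hypothetical name `causal_lm_loss_np`; the notebook itself keeps using the PyTorch version above. The logits at position `t` are scored against the token at position `t + 1`, and positions labelled `-100` are excluded from the mean.

```python
import numpy as np


def causal_lm_loss_np(logits, labels, ignore_index=-100):
    """Next-token cross-entropy: position t predicts the token at t + 1."""
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[..., 1:].reshape(-1)

    # Drop positions that must not contribute to the loss
    keep = shift_labels != ignore_index
    shift_logits = shift_logits[keep]
    shift_labels = shift_labels[keep]

    # Numerically stable log-softmax
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    # Mean negative log-likelihood of the target tokens
    return -log_probs[np.arange(len(shift_labels)), shift_labels].mean()


vocab_size = 5
labels = np.array([[1, 2, 3, -100]])    # last position is padding
logits = np.zeros((1, 4, vocab_size))   # uniform prediction over the vocabulary

loss = causal_lm_loss_np(logits, labels)
print(loss, np.log(vocab_size))
```

With uniform logits the loss equals the log of the vocabulary size, which gives a quick sanity check for the shifting and masking.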

## Test the original model

In [8]:
# Print the initial generation with the base model
PROMPT = "from concrete.ml.sklearn import LogisticRegression\n\nmodel = LogisticRegression("
print("Initial generation with base model:")
# generate_and_print prints the generation itself and returns None
generate_and_print(PROMPT, model, tokenizer, seed=SEED)

Initial generation with base model:


from concrete.ml.sklearn import LogisticRegression

model = LogisticRegression( eta=0.1, n_iter=1000, random_state=42)


## Convert the model to use FHE

As with all Concrete ML models, LoRA fine-tuning is set up by compiling the
model. Compilation requires a representative set of input data.

In [9]:
# Prepare input data for calibration
lengths = [len(item["input_ids"]) for item in tokenized_dataset]
if not all(length == lengths[0] for length in lengths):
    raise ValueError("All examples must have the same length for calibration.")
BLOCK_SIZE = lengths[0]

input_tensor = torch.randint(
    0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long
)
label_tensor = torch.randint(
    0, tokenizer.vocab_size, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long
)
attention_mask = torch.ones((PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE), dtype=torch.long)
inputset = {"input_ids": input_tensor, "attention_mask": attention_mask, "labels": label_tensor}

# Initialize LoraTrainer
training_args_dict = vars(training_args)
lora_trainer = LoraTrainer(
    model=peft_model,
    optimizer=optimizer,
    loss_fn=causal_lm_loss,
    lr_scheduler=lr_scheduler,
    training_args=training_args_dict,
    n_layers_to_skip_for_backprop=3,
)

LoRA layers detected in the model.


Compile the model, quantizing it to 16 bits (`n_bits=16`).
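
The sketch below shows why quantization needs calibration data and what `n_bits` controls. It uses plain symmetric uniform quantization as a generic stand-in; the scheme Concrete ML applies internally may differ, and `calibrate_scale`, `quantize`, and `dequantize` are hypothetical helpers.

```python
import numpy as np


def calibrate_scale(calibration_values, n_bits):
    """Choose a step size so the calibration range fits in signed n_bits integers."""
    max_abs = float(np.max(np.abs(calibration_values)))
    q_max = 2 ** (n_bits - 1) - 1
    return max_abs / q_max if max_abs > 0 else 1.0


def quantize(x, scale, n_bits):
    """Map floats to integers, clipping anything outside the calibrated range."""
    q_max = 2 ** (n_bits - 1) - 1
    return np.clip(np.round(x / scale), -q_max, q_max).astype(np.int64)


def dequantize(q, scale):
    """Map integers back to approximate floats."""
    return q * scale


rng = np.random.default_rng(0)
calibration = rng.normal(size=1000)  # stand-in for representative activations
x = calibration[:8]                  # values inside the calibrated range

scales, errors = {}, {}
for n_bits in (4, 8, 16):
    scales[n_bits] = calibrate_scale(calibration, n_bits)
    x_hat = dequantize(quantize(x, scales[n_bits], n_bits), scales[n_bits])
    errors[n_bits] = float(np.max(np.abs(x - x_hat)))
    print(f"n_bits={n_bits:2d}  scale={scales[n_bits]:.3e}  max error={errors[n_bits]:.3e}")
```

For values inside the calibrated range the rounding error is bounded by half the step size, so a larger `n_bits` gives a finer grid at the cost of wider integers.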

In [10]:
# Compile the model with FHE
lora_trainer.compile(inputset, n_bits=16)

Compiling FHE layers: 100%|██████████| 221/221 [03:05<00:00,  1.19it/s]




## Test-run Concrete ML LoRA fine-tuning on clear data with quantization

To check that everything works properly, it's possible to dry-run the fine-tuning on clear data.

In [11]:
# Train the model using LoraTrainer
print("Starting training using LoraTrainer...")
lora_trainer.train(train_dataloader, num_epochs=EPOCHS, fhe="disable", device=device)

Starting training using LoraTrainer...


Training:  10%|█         | 1/10 [00:19<02:56, 19.59s/epoch, Epoch=1, Avg Loss=2.5540, FHE Mode=disable]
Training:  20%|██        | 2/10 [00:33<02:11, 16.38s/epoch, Epoch=2, Avg Loss=1.5369, FHE Mode=disable]
Training:  30%|███       | 3/10 [00:48<01:50, 15.74s/epoch, Epoch=3, Avg Loss=0.8921, FHE Mode=disable]
Training:  40%|████      | 4/10 [01:05<01:37, 16.27s/epoch, Epoch=4, Avg Loss=0.5614, FHE Mode=disable]
Training:  50%|█████     | 5/10 [01:20<01:18, 15.62s/epoch, Epoch=5, Avg Loss=0.2630, FHE Mode=disable]
Training:  60%|██████    | 6/10 [01:34<01:01, 15.27s/epoch, Epoch=6, Avg Loss=0.1861, FHE Mode=disable]
Training:  70%|███████   | 7/10 [01:49<00:44, 14.93s/epoch, Epoch=7, Avg Loss=0.1352, FHE Mode=disable]
Training:  80%|████████  | 8/10 [02:03<00:29, 14.74s/epoch, Epoch=8, Avg Loss=0.1094, FHE Mode=disable]
Training:  90%|█████████ | 9/10 [02:17<00:14, 14.59s/epoch, Epoch=9, Avg Loss=0.0994, FHE Mode=disable]
Training: 100%|██████████| 10/10 [02:32<00:00, 15.24s/epoch, Epoch=10, Avg Loss=0.0885, FHE Mode=disable]

Training completed. Final Avg Loss: 0.0885, FHE Mode: disable





## Evaluation

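We compare code generation from the original model with generation from the fine-tuned model. The
original model's output is obtained by disabling the LoRA adapter layers with
`peft_model.disable_adapter_layers()`; re-enabling them gives the fine-tuned output.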

In [12]:
# Compare generation before and after fine-tuning
peft_model.disable_adapter_layers()
print("Original model generation:")
# generate_and_print prints the generation itself and returns None
generate_and_print(PROMPT, peft_model, tokenizer, seed=SEED)

Original model generation:


from concrete.ml.sklearn import LogisticRegression

model = LogisticRegression( eta=0.1, max_iter=1000, random_state=1)


In [13]:
peft_model.enable_adapter_layers()
print("Fine-tuned model generation:")
# generate_and_print prints the generation itself and returns None
generate_and_print(PROMPT, peft_model, tokenizer, seed=SEED)

Fine-tuned model generation:


from concrete.ml.sklearn import LogisticRegression

model = LogisticRegression( eta=0.01, n_bits=8)


## Fine-tuning on encrypted data

Next, we benchmark the time to train on a single encrypted batch. Each example is a
code snippet of ~130 tokens.

In [14]:
# Create a small data loader that holds a single batch of PER_DEVICE_TRAIN_BATCH_SIZE examples
hf_trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_dataset.select(list(range(PER_DEVICE_TRAIN_BATCH_SIZE))),
    data_collator=data_collator,
)

train_dataloader = hf_trainer.get_train_dataloader()

In [15]:
# Uncomment the call below to execute fine-tuning on encrypted data, using the GPU when it is available
# lora_trainer.train(train_dataloader, num_epochs=1, fhe="execute")

## Save the fine-tuned LoRA weights

In [16]:
# Save the fine-tuned model
save_path = Path("deployment/llama_lora_finetuned")
if save_path.is_dir() and any(save_path.iterdir()):
    shutil.rmtree(save_path)
lora_trainer.save_and_clear_private_info(save_path)

print("Model saved to:", save_path)

Model saved to: deployment/llama_lora_finetuned
