In [1]:
import argparse
import os
from time import time

import pandas as pd
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from utils_dev import *

from concrete.ml.torch.hybrid_model import HybridFHEMode
from concrete.ml.torch.lora import LoraTrainer

  from .autonotebook import tqdm as notebook_tqdm


# Hybrid Fine-Tuning of LLaMA with LoRA

This notebook showcases how to fine-tune the LLaMA-3.2-1B model using LoRA (Low-Rank Adaptation) on the Orca Math Word Problems dataset. The fine-tuning is performed using the _HybridModel_ paradigm, which enables a seamless separation of the computational workload of large language models between the client and a remote server.

To preserve data privacy while maintaining performance, this hybrid setup leverages Fully Homomorphic Encryption (FHE) on the remote side. The execution pipeline is structured as follows:

- Remote linear layers — which account for the majority of the model's weights and computational cost — are offloaded to a remote server and executed on encrypted data using FHE.
- Local non-linear layers — such as activation functions — are retained on-premise and executed in plaintext on the client side.
- The client’s dataset remains strictly local and is never transferred externally.

This approach allows for privacy-preserving fine-tuning and inference, while reducing the computational burden on the client and ensuring that sensitive data never leaves the local environment.

In [2]:
# LoRA adapter hyper-parameters, unpacked into `peft.LoraConfig` further down.
PEFT_ARGS = dict(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",  # PEFT shorthand: target the model's linear layers
)

# Hugging Face `TrainingArguments` keyword arguments; the same dict is also
# handed to `LoraTrainer` when the trainer is built.
TRAINING_ARGS = dict(
    output_dir="./checkpoints",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    save_total_limit=1,
    use_cpu=True,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    seed=SEED,
    data_seed=SEED,
    warmup_steps=10,
    weight_decay=0.01,
    prediction_loss_only=True,
    report_to="none",
)


# Client-side device is forced to CPU, consistent with `use_cpu=True` above.
DEVICE = get_device(force_device='cpu')

## Load data

The question-answer dataset has been preprocessed and filtered in the `processed_data.py` script and saved to disk for convenience. We load it here directly to simplify the fine-tuning workflow.

> ⚠️ If the files are missing, please run `processed_data.py` to regenerate them.


In [3]:
# Pre-processed train / test splits stored on disk by `processed_data.py`.
train_dataset = load_from_disk(TRAIN_PATH)
test_dataset = load_from_disk(TEST_PATH)
# Collation function shared by the training and evaluation data loaders.
collator = DataCollator(TOKENIZER)

## Load Model and Tokenizer

Load the LLaMA model and tokenizer, and test the base model output.

In [4]:
pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)
pretrained_model = pretrained_model.to(DEVICE)

# Reuse the EOS token id as the padding token id.
model_config = pretrained_model.config
model_config.pad_token_id = model_config.eos_token_id

# Freeze every base-model parameter; only the LoRA adapters added later should train.
pretrained_model.requires_grad_(False)

# Quick smoke test of the untouched base model on a sample math question.
PROMPT = "When you multiply a number by 7, it becomes 98. What is that number?\n"
_ = generate_and_print(PROMPT, pretrained_model, TOKENIZER, seed=SEED)

Prompt: `When you multiply a number by 7, it becomes 98. What is that number?
`
Response: `A. 0
B. 1
C. 2
D. 3
E. 4
Answer: B`



## LoRA Configuration

Set up LoRA parameters and apply them to the model.

In [5]:
# Wrap the frozen base model with LoRA adapters configured by PEFT_ARGS.
lora_config = LoraConfig(**PEFT_ARGS)
peft_model = get_peft_model(pretrained_model, lora_config)
peft_model = peft_model.to(DEVICE)

## Training Arguments

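Configure the training hyperparameters, build the data loaders, optimizer and learning-rate scheduler, and wrap the model in a `LoraTrainer` that connects to the remote FHE server.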

In [6]:
# In this notebook the Hugging Face Trainer only builds the training dataloader,
# the optimizer and the LR scheduler. The training loop itself is run by
# `LoraTrainer.train` later on.
hf_trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(**TRAINING_ARGS),
    train_dataset=train_dataset,
    data_collator=collator,
)

train_dl = hf_trainer.get_train_dataloader()
eval_dl = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

# Scheduler horizon = batches per epoch * epochs. This assumes one optimizer step per
# batch, which matches gradient_accumulation_steps=1 in TRAINING_ARGS.
hf_trainer.create_optimizer_and_scheduler(len(train_dl) * TRAINING_ARGS["num_train_epochs"])
optimizer, lr_scheduler = hf_trainer.optimizer, hf_trainer.lr_scheduler

# Address of the server that executes the remote (encrypted) linear layers.
# Set the SERVER_REMOTE_ADDRESS environment variable to point at another server
# instead of editing a hard-coded IP in the notebook.
SERVER_REMOTE_ADDRESS = os.environ.get("SERVER_REMOTE_ADDRESS", "http://13.36.240.77:8001")

lora_trainer = LoraTrainer(
    model=peft_model,
    optimizer=optimizer,
    loss_fn=causal_lm_loss,
    lr_scheduler=lr_scheduler,
    training_args=TRAINING_ARGS,
    n_layers_to_skip_for_backprop=3,
    eval_loader=eval_dl,
    eval_metric_fn=metric_fn,
    logging_steps=1,
    eval_steps=100,
    train_log_path=TRAIN_LOG_FILE,
    machine_type="M4",
    server_remote_address=SERVER_REMOTE_ADDRESS,
    model_name="meta-llama",
)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
2025-07-16 08:35:19,187 - INFO - === Starting new training session ===
2025-07-16 08:35:19,190 - INFO - Processing '5' Remote Modules.
2025-07-16 08:35:19,191 - INFO - Benchmark file already created: '/Users/kcelia/Zama/concrete-ml/use_case_examples/deploy_llama_finetuning/client_benchmarks.csv'


LoRA layers detected in the model.


## Compilation

In [None]:
# Random input set used to compile the remote FHE layers
# (presumably token ids, given the vocab_size argument).
inputset = get_random_inputset(
    vocab_size=VOCAB_SIZE, batch_size=BATCH_SIZE, max_length=MAX_LENGTH, device=DEVICE
)

compile_start = time()
lora_trainer.compile(inputset, n_bits=N_BITS, device=DEVICE)
compile_duration = time() - compile_start
print(f"Compilation completed under: {compile_duration:.2f}s using {DEVICE=}")

Compiling FHE layers: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


## Evaluate the model before fine-tuning

In [None]:
peft_model.eval()

# Snapshot of the LoRA weights before training (presumably for comparison with
# the weights extracted after fine-tuning).
initial_weights = extract_lora_weights(peft_model)

# Baseline evaluation on the test loader; this is the *initial* perplexity.
initial_metrics = metric_fn(peft_model, eval_dl, PROMPT, EVAL_RESPONSES_FILE, DEVICE)
print(f"Initial perplexity before fine-tuning: {initial_metrics['perplexity']:.2f}")

Prompt: `When you multiply a number by 7, it becomes 98. What is that number?
`
Response: `A. 0
B. 1
C. 2
D. 3
E. 4
Answer: B`



                                                           

Final perplexity after extended training: 116.61




## Separate Remote Modules

In a hybrid execution setup, we must isolate the parts of the model that will run remotely (typically, the linear layers) from those that will stay on the client side (non-linear layers, activations, etc.).

The following line performs this separation by:

- Saving the compiled remote modules (linear layers quantized and ready for remote execution);
- Removing sensitive information such as calibration data or client-side metadata.


In [9]:
# Save the compiled remote modules under COMPILED_MODELS_PATH (using the MLIR
# serialization path) and clear private information from the trainer, as the
# method name indicates.
lora_trainer.save_and_clear_private_info(COMPILED_MODELS_PATH, via_mlir=True)

2025-07-16 08:34:25,986 - INFO - Model saved at compiled_models/meta-llama


## Initialize Client-Side Model

Here we generate keys and send the public evaluation key to the server.

In [None]:
# Client-side artifact directory inside the compiled-models folder.
client_path = COMPILED_MODELS_PATH / "client"
hybrid_model = lora_trainer.hybrid_model

# Generate the client keys. The logged output shows the public evaluation key
# being saved under `client_path`.
hybrid_model.init_client(path_to_clients=client_path, path_to_keys=PATH_TO_CLIENTS_KEYS)

# Switch to remote FHE mode: linear layers will now be executed on the server
hybrid_model.set_fhe_mode(HybridFHEMode.REMOTE)

2025-07-16 08:34:26,520 - INFO - Generating keys...
2025-07-16 08:34:27,151 - INFO - Keys generated...
2025-07-16 08:34:27,170 - INFO - Saving the public evaluation key at compiled_models/meta-llama/client/public_evaluation_key.serverKey...


## Run a Short Fine-Tuning Loop with Remote FHE

We fine-tune the model for a few batches using remote FHE mode:

In [None]:
# NOTE(review): the recorded run of this cell ended with an AssertionError whose cause
# is not visible from this notebook. TODO: investigate before trusting the results below.
N_FINETUNE_BATCHES = 3  # keep the remote-FHE demo run short
limited_batches = get_limited_batches(train_dl, N_FINETUNE_BATCHES)
lora_trainer.train(limited_batches, fhe="remote", device=DEVICE)

                                              

AssertionError: 

## Evaluate the model after fine-tuning

We evaluate the model on the held-out test set (`test_dataset`) to compute its final perplexity, a standard metric for language modeling:

In [None]:

# Snapshot of the LoRA weights after the short remote-FHE fine-tuning run.
finetuned_weights = extract_lora_weights(peft_model)

peft_model.eval()
metrics_final = metric_fn(peft_model, eval_dl, PROMPT, EVAL_RESPONSES_FILE, DEVICE)
# Only a few batches were trained above, so label this as plain fine-tuning,
# not "extended" training, and report the change against the baseline.
print(f"Final perplexity after fine-tuning: {metrics_final['perplexity']:.2f}")
print(
    "Perplexity change vs. before fine-tuning: "
    f"{metrics_final['perplexity'] - initial_metrics['perplexity']:+.2f}"
)

## Benchmark

In [None]:

# Benchmark timings written by the client and by the FHE server (semicolon-separated CSVs).
client, server = (
    pd.read_csv(csv_path, sep=";")
    for csv_path in ("client_benchmarks.csv", "server_benchmarks.csv")
)

In [16]:
# Server-side timings: one row per request ('Key' upload or 'compute' call on a remote layer).
server

Unnamed: 0,endpoint,date,device,machine,uid,layer_name,index,input_shape,remote_weight_shape,time_read_key,time_deserialization_key,time_serialization_key,time_storage_key,time_read_input,time_deserialize_input,encrypted_input_size,time_weight_quantization,time_serialization_output,time_matmul,time_packing_output_response,total_add_key_func,total_compute_func
0,Key,2025-07-15 19:28:15,cpu,g4dn.16xlarge,ade27f60-523a-4e26-b678-63bfddf1192f,,,,,0.009222,0.023158,0.016631,0.025622,,,,,,,,0.05841,
1,compute,2025-07-15 19:28:50,cpu,g4dn.16xlarge,ade27f60-523a-4e26-b678-63bfddf1192f,inference_model.base_model.model.model.layers....,0.0,"(64, 2048)","(2048, 2048)",,,,,2.3e-05,7.8e-05,530968.0,0.014169,0.000216,34.905582,9.1e-05,,34.929621
2,Key,2025-07-15 19:31:34,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,,,,,0.007196,0.023966,0.019641,0.027247,,,,,,,,0.058755,
3,compute,2025-07-15 19:32:02,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,inference_model.base_model.model.model.layers....,0.0,"(64, 2048)","(2048, 2048)",,,,,9e-06,8.4e-05,530968.0,0.00994,0.000159,28.538269,1.6e-05,,28.557032
4,compute,2025-07-15 19:32:11,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,inference_model.base_model.model.model.layers....,1.0,"(64, 2048)","(2048, 512)",,,,,7.7e-05,0.000136,530968.0,0.005112,0.000112,8.304419,1.4e-05,,8.312503
5,compute,2025-07-15 19:32:19,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,inference_model.base_model.model.model.layers....,2.0,"(64, 2048)","(2048, 512)",,,,,1.5e-05,5.1e-05,530968.0,0.004382,0.000201,8.446628,1.3e-05,,8.453925
6,compute,2025-07-15 19:32:52,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,inference_model.base_model.model.model.layers....,3.0,"(64, 2048)","(2048, 2048)",,,,,2.4e-05,0.000305,530968.0,0.003923,0.00024,32.715363,1.4e-05,,32.726437
7,compute,2025-07-15 19:33:26,cpu,g4dn.16xlarge,72751a51-ad40-460f-8b59-32502bac92cb,inference_model.base_model.model.model.layers....,4.0,"(64, 2048)","(2048, 2048)",,,,,4e-05,9.4e-05,530968.0,0.003682,0.000177,33.900339,2.4e-05,,33.91084


In [17]:
# Client-side timings per remote-layer request (encryption, serialization, decryption, totals).
client

Unnamed: 0,date,device,machine,uid,server_remote_address,layer_name,input_shape,remote_weight_shape,time_encryption_input,time_serialization_input,total_send_input_func,time_deserialization_output,time_decryption_output,time_dequantization_output,total_compute_func,total_timing
0,2025-07-15 19:28:50,cpu,M4,ade27f60-523a-4e26-b678-63bfddf1192f,http://127.0.0.1:8001,remote_weights_layer0.npy,"(64, 2048)","(2048, 2048)",0.020287,0.000111,,0.000145,0.020579,0.000434,34.934256,35.019693
1,2025-07-15 19:32:02,cpu,M4,72751a51-ad40-460f-8b59-32502bac92cb,http://127.0.0.1:8001,remote_weights_layer0.npy,"(64, 2048)","(2048, 2048)",0.020675,0.000111,,9.4e-05,0.011994,0.000334,28.560515,28.621814
2,2025-07-15 19:32:11,cpu,M4,72751a51-ad40-460f-8b59-32502bac92cb,http://127.0.0.1:8001,remote_weights_layer1.npy,"(64, 2048)","(2048, 512)",0.020222,0.000327,,6e-05,0.011821,0.000569,8.315347,8.366446
3,2025-07-15 19:32:19,cpu,M4,72751a51-ad40-460f-8b59-32502bac92cb,http://127.0.0.1:8001,remote_weights_layer2.npy,"(64, 2048)","(2048, 512)",0.020303,0.000115,,5.5e-05,0.015216,0.00037,8.456745,8.511736
4,2025-07-15 19:32:52,cpu,M4,72751a51-ad40-460f-8b59-32502bac92cb,http://127.0.0.1:8001,remote_weights_layer3.npy,"(64, 2048)","(2048, 2048)",0.023046,4.7e-05,,9.5e-05,0.012572,0.000298,32.729561,32.794023
5,2025-07-15 19:33:26,cpu,M4,72751a51-ad40-460f-8b59-32502bac92cb,http://127.0.0.1:8001,remote_weights_layer4.npy,"(64, 2048)","(2048, 2048)",0.023485,0.000132,,0.000118,0.018325,0.000671,33.914646,33.994649


This separation is enabled by the HybridFHEModel, allowing efficient encrypted inference while preserving data privacy and minimizing computational load on the client side.