In [1]:
import math
import shutil
from pathlib import Path

import torch
from custom_module import CustomConv1D
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Conv1D,
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)

from concrete.ml.torch.hybrid_model import HybridFHEModel

SEED = 0
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# FREEZE WEIGHTS
for param in model.parameters():
    param.requires_grad = False

In [2]:
def generate_text(prompt, model, tokenizer, max_length=30, fhe="disable"):
    # Encode the input prompt
    inputs = tokenizer.encode_plus(prompt, return_tensors="pt")

    attention_mask = inputs["attention_mask"]

    # Generate text
    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [3]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be


In [4]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=4,
    lora_alpha=32,
    lora_dropout=0.05,
    fan_in_fan_out=True,
)

peft_model = get_peft_model(model, peft_config)

In [5]:
def replace_conv1d(module, module_index_to_skip=0):
    for name, child in module.named_children():
        if isinstance(child, Conv1D):

            # Skip the module if the index has not been reached, and decrement the index
            if module_index_to_skip >= 0:
                module_index_to_skip -= 1
            else:
                custom_linear = CustomConv1D(child.weight, bias=child.bias)
                setattr(module, name, custom_linear)
        else:
            module_index_to_skip = replace_conv1d(child, module_index_to_skip=module_index_to_skip)

    return module_index_to_skip


# Gradients of the first base layer that is used for fine-tuning are not needed. We
# therefore need to exclude the backward module from the remote_names since calibration
# won't get through it (which raises an issue with hybrid models)
replace_conv1d(peft_model, module_index_to_skip=0);

In [6]:
class LoraTraining(torch.nn.Module):
    def __init__(self, inference_model, gradient_accumulation_steps) -> None:
        super().__init__()

        self.inference_model = inference_model

        self.optimizer = None
        self.lr_scheduler = None

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = None

        self.calibrate = False
        self.run_optimizer = False

    def update_training_parameters(self, optimizer, lr_scheduler, training_args):
        assert self.gradient_accumulation_steps == training_args.gradient_accumulation_steps

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.max_grad_norm = training_args.max_grad_norm

    def forward(self, inputs):
        # FIXME: handle multi-inputs in hybrid model
        x, y = inputs

        # some parts on server side
        outputs = self.inference_model(input_ids=x, labels=y)

        loss = outputs.loss
        loss = loss / self.gradient_accumulation_steps

        # Update gradients
        loss.backward()

        grad_norm = None
        if not self.calibrate and self.run_optimizer:
            assert self.optimizer is not None
            assert self.lr_scheduler is not None
            assert self.max_grad_norm is not None

            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.inference_model.parameters(), max_norm=self.max_grad_norm, norm_type=2
            )

            self.optimizer.step()
            self.lr_scheduler.step()

            self.inference_model.zero_grad()

        # Clean gradients after calibration
        elif self.calibrate:
            self.inference_model.zero_grad()

        return (loss, grad_norm)

    def toggle_calibrate(self, enable: bool = True):
        self.calibrate = enable

    def toggle_run_optimizer(self, enable: bool = True):
        self.run_optimizer = enable

In [7]:
GRADIENT_ACCUMULATION_STEPS = 2

lora_training = LoraTraining(peft_model, GRADIENT_ACCUMULATION_STEPS)

In [8]:
BLOCK_SIZE = 128


def load_dataset(file_path, tokenizer):
    dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=BLOCK_SIZE,
        cache_dir="cache_dataset",
    )
    return dataset


train_dataset = load_dataset("data_finetune/what_is_fhe.txt", tokenizer)

In [9]:
tokenizer.parallelism = False

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

EPOCHS = 100
PER_DEVICE_TRAIN_BATCH_SIZE = 4

training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    save_total_limit=1,
    use_cpu=True,
    learning_rate=5e-4,
    logging_strategy="epoch",
    optim="adamw_torch",
    seed=SEED,
    data_seed=SEED,
    weight_decay=0.0,
    warmup_steps=0,
    max_grad_norm=1.0,
)

In [10]:
trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

train_dataloader = trainer.get_train_dataloader()

len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // training_args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)

trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)

In [11]:
lora_training.update_training_parameters(trainer.optimizer, trainer.lr_scheduler, training_args)

In [12]:
def get_remote_names(model):
    remote_names = []
    for name, module in model.named_modules():
        # Some gradients are not needed for fine-tuning, so need to exclude the backward module
        # from the remote_names since calibration won't get through it (which raises an issue with
        # hybrid models). We however still need to include the associated module's forward pass in
        # the hybrid model
        if isinstance(module, Conv1D):
            remote_names.append(name)

        elif isinstance(module, CustomConv1D):
            remote_names.append(name + ".forward_module")
            remote_names.append(name + ".backward_module")

    return remote_names


remote_names = get_remote_names(lora_training)

hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

In [13]:
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)

inputset = (input_tensor, label_tensor)

In [14]:
hybrid_model.model.toggle_calibrate(enable=True)

hybrid_model.compile_model(
    inputset, n_bits=8, rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)

hybrid_model.model.toggle_calibrate(enable=False)

In [15]:
def train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable"):
    device = "cpu"
    hybrid_model.model.to(device)

    # Training loop
    hybrid_model.model.inference_model.train()

    total_epochs = int(training_args.num_train_epochs)
    epoch_pbar = tqdm(total=total_epochs, desc="Training Progress", position=0)

    total_batched_samples = 0
    for epoch in range(total_epochs):
        total_loss = 0
        grad_norms = []

        steps_in_epoch = len(train_dataloader)
        for step, batch in enumerate(train_dataloader):
            total_batched_samples += 1

            batch = {k: v.to(device) for k, v in batch.items()}

            # Gradient accumulation
            is_last_batch_step = (
                steps_in_epoch <= training_args.gradient_accumulation_steps
                and (step + 1) == steps_in_epoch
            )
            accumulate_gradients = (
                total_batched_samples % training_args.gradient_accumulation_steps == 0
            )

            run_optimizer = is_last_batch_step or accumulate_gradients

            hybrid_model.model.toggle_run_optimizer(enable=run_optimizer)

            loss, grad_norm = hybrid_model((batch["input_ids"], batch["labels"]), fhe=fhe)

            total_loss += loss.item()

            if grad_norm is not None:
                grad_norms.append(grad_norm)

        # Get current learning rate
        current_lr = hybrid_model.model.lr_scheduler.get_last_lr()[0]

        # Get last grad norm
        current_grad_norm = grad_norms[-1]

        # Log epoch results
        print(
            f"Epoch {epoch + 1}/{training_args.num_train_epochs}, "
            f"Loss: {total_loss:.4f}, grad norm: {current_grad_norm}, lr: {current_lr}"
        )

        epoch_pbar.update(1)

    # Save model checkpoint
    if training_args.output_dir is not None:
        save_path = f"{training_args.output_dir}/checkpoint-{epoch + 1}"
        hybrid_model.model.inference_model.save_pretrained(save_path)

    epoch_pbar.close()

In [16]:
torch.manual_seed(SEED)

train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable")


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]


Training Progress:   1%|          | 1/100 [00:01<01:56,  1.17s/it]

Epoch 1/100, Loss: 1.5326, grad norm: 0.6547388434410095, lr: 0.000495



Training Progress:   2%|▏         | 2/100 [00:01<01:10,  1.39it/s]

Epoch 2/100, Loss: 1.5092, grad norm: 0.5734729170799255, lr: 0.00049



Training Progress:   3%|▎         | 3/100 [00:01<00:54,  1.77it/s]

Epoch 3/100, Loss: 1.4762, grad norm: 0.4197540581226349, lr: 0.00048499999999999997



Training Progress:   4%|▍         | 4/100 [00:02<00:47,  2.02it/s]

Epoch 4/100, Loss: 1.5085, grad norm: 0.569969117641449, lr: 0.00048



Training Progress:   5%|▌         | 5/100 [00:02<00:42,  2.21it/s]

Epoch 5/100, Loss: 1.4667, grad norm: 0.5897998213768005, lr: 0.000475



Training Progress:   6%|▌         | 6/100 [00:03<00:40,  2.34it/s]

Epoch 6/100, Loss: 1.4486, grad norm: 0.44352057576179504, lr: 0.00047



Training Progress:   7%|▋         | 7/100 [00:03<00:38,  2.42it/s]

Epoch 7/100, Loss: 1.4159, grad norm: 0.506279706954956, lr: 0.000465



Training Progress:   8%|▊         | 8/100 [00:03<00:37,  2.48it/s]

Epoch 8/100, Loss: 1.4051, grad norm: 0.6538838148117065, lr: 0.00046



Training Progress:   9%|▉         | 9/100 [00:04<00:36,  2.52it/s]

Epoch 9/100, Loss: 1.3889, grad norm: 0.6592888236045837, lr: 0.000455



Training Progress:  10%|█         | 10/100 [00:04<00:35,  2.55it/s]

Epoch 10/100, Loss: 1.3730, grad norm: 0.6219719052314758, lr: 0.00045000000000000004



Training Progress:  11%|█         | 11/100 [00:05<00:34,  2.57it/s]

Epoch 11/100, Loss: 1.3494, grad norm: 0.6324566602706909, lr: 0.00044500000000000003



Training Progress:  12%|█▏        | 12/100 [00:05<00:34,  2.59it/s]

Epoch 12/100, Loss: 1.3169, grad norm: 0.5421789288520813, lr: 0.00044



Training Progress:  13%|█▎        | 13/100 [00:05<00:34,  2.56it/s]

Epoch 13/100, Loss: 1.3013, grad norm: 0.5423504114151001, lr: 0.000435



Training Progress:  14%|█▍        | 14/100 [00:06<00:33,  2.59it/s]

Epoch 14/100, Loss: 1.3303, grad norm: 0.6302087903022766, lr: 0.00043



Training Progress:  15%|█▌        | 15/100 [00:06<00:32,  2.59it/s]

Epoch 15/100, Loss: 1.2771, grad norm: 0.5095004439353943, lr: 0.000425



Training Progress:  16%|█▌        | 16/100 [00:06<00:32,  2.59it/s]

Epoch 16/100, Loss: 1.2506, grad norm: 0.5400538444519043, lr: 0.00042



Training Progress:  17%|█▋        | 17/100 [00:07<00:31,  2.60it/s]

Epoch 17/100, Loss: 1.2341, grad norm: 0.5874373316764832, lr: 0.000415



Training Progress:  18%|█▊        | 18/100 [00:07<00:31,  2.61it/s]

Epoch 18/100, Loss: 1.2215, grad norm: 0.5731167793273926, lr: 0.00041



Training Progress:  19%|█▉        | 19/100 [00:08<00:30,  2.62it/s]

Epoch 19/100, Loss: 1.1856, grad norm: 0.5122016072273254, lr: 0.00040500000000000003



Training Progress:  20%|██        | 20/100 [00:08<00:30,  2.62it/s]

Epoch 20/100, Loss: 1.1938, grad norm: 0.5971183180809021, lr: 0.0004



Training Progress:  21%|██        | 21/100 [00:08<00:30,  2.62it/s]

Epoch 21/100, Loss: 1.1668, grad norm: 0.6376621127128601, lr: 0.000395



Training Progress:  22%|██▏       | 22/100 [00:09<00:29,  2.60it/s]

Epoch 22/100, Loss: 1.1436, grad norm: 0.5452390909194946, lr: 0.00039000000000000005



Training Progress:  23%|██▎       | 23/100 [00:09<00:29,  2.62it/s]

Epoch 23/100, Loss: 1.1361, grad norm: 0.5471293330192566, lr: 0.00038500000000000003



Training Progress:  24%|██▍       | 24/100 [00:10<00:29,  2.58it/s]

Epoch 24/100, Loss: 1.1000, grad norm: 0.6130229234695435, lr: 0.00038



Training Progress:  25%|██▌       | 25/100 [00:10<00:28,  2.60it/s]

Epoch 25/100, Loss: 1.0795, grad norm: 0.6525614261627197, lr: 0.000375



Training Progress:  26%|██▌       | 26/100 [00:10<00:28,  2.57it/s]

Epoch 26/100, Loss: 1.0930, grad norm: 0.9915198683738708, lr: 0.00037



Training Progress:  27%|██▋       | 27/100 [00:11<00:28,  2.57it/s]

Epoch 27/100, Loss: 1.0531, grad norm: 0.590857207775116, lr: 0.000365



Training Progress:  28%|██▊       | 28/100 [00:11<00:28,  2.55it/s]

Epoch 28/100, Loss: 1.0564, grad norm: 0.8754357695579529, lr: 0.00035999999999999997



Training Progress:  29%|██▉       | 29/100 [00:11<00:27,  2.58it/s]

Epoch 29/100, Loss: 1.0520, grad norm: 0.8149130344390869, lr: 0.000355



Training Progress:  30%|███       | 30/100 [00:12<00:27,  2.59it/s]

Epoch 30/100, Loss: 1.0313, grad norm: 0.5920228958129883, lr: 0.00035



Training Progress:  31%|███       | 31/100 [00:12<00:26,  2.58it/s]

Epoch 31/100, Loss: 1.0182, grad norm: 0.6779032349586487, lr: 0.000345



Training Progress:  32%|███▏      | 32/100 [00:13<00:26,  2.58it/s]

Epoch 32/100, Loss: 0.9980, grad norm: 0.5544361472129822, lr: 0.00034



Training Progress:  33%|███▎      | 33/100 [00:13<00:26,  2.58it/s]

Epoch 33/100, Loss: 0.9932, grad norm: 0.7196674942970276, lr: 0.000335



Training Progress:  34%|███▍      | 34/100 [00:13<00:25,  2.60it/s]

Epoch 34/100, Loss: 0.9819, grad norm: 0.7083548903465271, lr: 0.00033



Training Progress:  35%|███▌      | 35/100 [00:14<00:24,  2.62it/s]

Epoch 35/100, Loss: 0.9250, grad norm: 0.7313346266746521, lr: 0.00032500000000000004



Training Progress:  36%|███▌      | 36/100 [00:14<00:24,  2.63it/s]

Epoch 36/100, Loss: 0.9198, grad norm: 0.6564635634422302, lr: 0.00032



Training Progress:  37%|███▋      | 37/100 [00:15<00:23,  2.64it/s]

Epoch 37/100, Loss: 0.9157, grad norm: 0.7937288880348206, lr: 0.000315



Training Progress:  38%|███▊      | 38/100 [00:15<00:23,  2.65it/s]

Epoch 38/100, Loss: 0.8932, grad norm: 0.6338443756103516, lr: 0.00031



Training Progress:  39%|███▉      | 39/100 [00:15<00:23,  2.65it/s]

Epoch 39/100, Loss: 0.9295, grad norm: 0.8935690522193909, lr: 0.000305



Training Progress:  40%|████      | 40/100 [00:16<00:22,  2.65it/s]

Epoch 40/100, Loss: 0.8730, grad norm: 0.7592346668243408, lr: 0.0003



Training Progress:  41%|████      | 41/100 [00:16<00:22,  2.64it/s]

Epoch 41/100, Loss: 0.8485, grad norm: 0.7101594805717468, lr: 0.000295



Training Progress:  42%|████▏     | 42/100 [00:16<00:21,  2.65it/s]

Epoch 42/100, Loss: 0.8411, grad norm: 0.6478201150894165, lr: 0.00029



Training Progress:  43%|████▎     | 43/100 [00:17<00:21,  2.65it/s]

Epoch 43/100, Loss: 0.8544, grad norm: 0.7164880037307739, lr: 0.000285



Training Progress:  44%|████▍     | 44/100 [00:17<00:21,  2.64it/s]

Epoch 44/100, Loss: 0.8414, grad norm: 0.7436962127685547, lr: 0.00028000000000000003



Training Progress:  45%|████▌     | 45/100 [00:18<00:20,  2.63it/s]

Epoch 45/100, Loss: 0.8121, grad norm: 0.9844059944152832, lr: 0.000275



Training Progress:  46%|████▌     | 46/100 [00:18<00:20,  2.62it/s]

Epoch 46/100, Loss: 0.8048, grad norm: 0.9871523976325989, lr: 0.00027



Training Progress:  47%|████▋     | 47/100 [00:18<00:20,  2.62it/s]

Epoch 47/100, Loss: 0.8153, grad norm: 0.8394853472709656, lr: 0.00026500000000000004



Training Progress:  48%|████▊     | 48/100 [00:19<00:19,  2.62it/s]

Epoch 48/100, Loss: 0.8046, grad norm: 0.9217925667762756, lr: 0.00026000000000000003



Training Progress:  49%|████▉     | 49/100 [00:19<00:19,  2.61it/s]

Epoch 49/100, Loss: 0.7614, grad norm: 1.015302062034607, lr: 0.000255



Training Progress:  50%|█████     | 50/100 [00:19<00:19,  2.60it/s]

Epoch 50/100, Loss: 0.7760, grad norm: 0.9043252468109131, lr: 0.00025



Training Progress:  51%|█████     | 51/100 [00:20<00:18,  2.60it/s]

Epoch 51/100, Loss: 0.7693, grad norm: 0.8068227767944336, lr: 0.000245



Training Progress:  52%|█████▏    | 52/100 [00:20<00:18,  2.61it/s]

Epoch 52/100, Loss: 0.7422, grad norm: 0.9263298511505127, lr: 0.00024



Training Progress:  53%|█████▎    | 53/100 [00:21<00:17,  2.61it/s]

Epoch 53/100, Loss: 0.7486, grad norm: 1.0840318202972412, lr: 0.000235



Training Progress:  54%|█████▍    | 54/100 [00:21<00:17,  2.61it/s]

Epoch 54/100, Loss: 0.7469, grad norm: 0.8277450799942017, lr: 0.00023



Training Progress:  55%|█████▌    | 55/100 [00:21<00:17,  2.61it/s]

Epoch 55/100, Loss: 0.7148, grad norm: 0.8486602306365967, lr: 0.00022500000000000002



Training Progress:  56%|█████▌    | 56/100 [00:22<00:16,  2.61it/s]

Epoch 56/100, Loss: 0.7018, grad norm: 0.9315493106842041, lr: 0.00022



Training Progress:  57%|█████▋    | 57/100 [00:22<00:16,  2.61it/s]

Epoch 57/100, Loss: 0.6978, grad norm: 0.8715642690658569, lr: 0.000215



Training Progress:  58%|█████▊    | 58/100 [00:23<00:16,  2.61it/s]

Epoch 58/100, Loss: 0.6954, grad norm: 0.9117729067802429, lr: 0.00021



Training Progress:  59%|█████▉    | 59/100 [00:23<00:15,  2.61it/s]

Epoch 59/100, Loss: 0.6805, grad norm: 0.8932844996452332, lr: 0.000205



Training Progress:  60%|██████    | 60/100 [00:23<00:15,  2.60it/s]

Epoch 60/100, Loss: 0.6801, grad norm: 1.0779385566711426, lr: 0.0002



Training Progress:  61%|██████    | 61/100 [00:24<00:14,  2.61it/s]

Epoch 61/100, Loss: 0.6582, grad norm: 0.9519742131233215, lr: 0.00019500000000000002



Training Progress:  62%|██████▏   | 62/100 [00:24<00:14,  2.60it/s]

Epoch 62/100, Loss: 0.6777, grad norm: 1.0926264524459839, lr: 0.00019



Training Progress:  63%|██████▎   | 63/100 [00:24<00:14,  2.59it/s]

Epoch 63/100, Loss: 0.6813, grad norm: 1.2714309692382812, lr: 0.000185



Training Progress:  64%|██████▍   | 64/100 [00:25<00:13,  2.60it/s]

Epoch 64/100, Loss: 0.6696, grad norm: 1.0693631172180176, lr: 0.00017999999999999998



Training Progress:  65%|██████▌   | 65/100 [00:25<00:13,  2.59it/s]

Epoch 65/100, Loss: 0.6621, grad norm: 1.1618248224258423, lr: 0.000175



Training Progress:  66%|██████▌   | 66/100 [00:26<00:13,  2.59it/s]

Epoch 66/100, Loss: 0.6314, grad norm: 0.9860178232192993, lr: 0.00017



Training Progress:  67%|██████▋   | 67/100 [00:26<00:12,  2.59it/s]

Epoch 67/100, Loss: 0.6267, grad norm: 1.0095081329345703, lr: 0.000165



Training Progress:  68%|██████▊   | 68/100 [00:26<00:12,  2.57it/s]

Epoch 68/100, Loss: 0.6275, grad norm: 0.9747483134269714, lr: 0.00016



Training Progress:  69%|██████▉   | 69/100 [00:27<00:12,  2.58it/s]

Epoch 69/100, Loss: 0.6391, grad norm: 1.1718988418579102, lr: 0.000155



Training Progress:  70%|███████   | 70/100 [00:27<00:11,  2.54it/s]

Epoch 70/100, Loss: 0.6172, grad norm: 0.8902360796928406, lr: 0.00015



Training Progress:  71%|███████   | 71/100 [00:28<00:11,  2.55it/s]

Epoch 71/100, Loss: 0.6180, grad norm: 1.0743216276168823, lr: 0.000145



Training Progress:  72%|███████▏  | 72/100 [00:28<00:10,  2.56it/s]

Epoch 72/100, Loss: 0.6124, grad norm: 1.4731453657150269, lr: 0.00014000000000000001



Training Progress:  73%|███████▎  | 73/100 [00:28<00:10,  2.57it/s]

Epoch 73/100, Loss: 0.5906, grad norm: 1.2012979984283447, lr: 0.000135



Training Progress:  74%|███████▍  | 74/100 [00:29<00:10,  2.57it/s]

Epoch 74/100, Loss: 0.5892, grad norm: 1.3028196096420288, lr: 0.00013000000000000002



Training Progress:  75%|███████▌  | 75/100 [00:29<00:09,  2.58it/s]

Epoch 75/100, Loss: 0.5887, grad norm: 1.0304925441741943, lr: 0.000125



Training Progress:  76%|███████▌  | 76/100 [00:30<00:09,  2.58it/s]

Epoch 76/100, Loss: 0.5687, grad norm: 0.9565426707267761, lr: 0.00012



Training Progress:  77%|███████▋  | 77/100 [00:30<00:09,  2.55it/s]

Epoch 77/100, Loss: 0.5913, grad norm: 1.1523699760437012, lr: 0.000115



Training Progress:  78%|███████▊  | 78/100 [00:30<00:08,  2.52it/s]

Epoch 78/100, Loss: 0.5948, grad norm: 1.1738296747207642, lr: 0.00011



Training Progress:  79%|███████▉  | 79/100 [00:31<00:08,  2.49it/s]

Epoch 79/100, Loss: 0.5651, grad norm: 1.260327696800232, lr: 0.000105



Training Progress:  80%|████████  | 80/100 [00:31<00:08,  2.50it/s]

Epoch 80/100, Loss: 0.5798, grad norm: 1.1174153089523315, lr: 0.0001



Training Progress:  81%|████████  | 81/100 [00:32<00:07,  2.39it/s]

Epoch 81/100, Loss: 0.5539, grad norm: 0.9862734079360962, lr: 9.5e-05



Training Progress:  82%|████████▏ | 82/100 [00:32<00:07,  2.45it/s]

Epoch 82/100, Loss: 0.5794, grad norm: 0.9966534972190857, lr: 8.999999999999999e-05



Training Progress:  83%|████████▎ | 83/100 [00:32<00:06,  2.48it/s]

Epoch 83/100, Loss: 0.5628, grad norm: 1.0192633867263794, lr: 8.5e-05



Training Progress:  84%|████████▍ | 84/100 [00:33<00:06,  2.50it/s]

Epoch 84/100, Loss: 0.5633, grad norm: 0.9687011241912842, lr: 8e-05



Training Progress:  85%|████████▌ | 85/100 [00:33<00:05,  2.55it/s]

Epoch 85/100, Loss: 0.5684, grad norm: 1.0944551229476929, lr: 7.5e-05



Training Progress:  86%|████████▌ | 86/100 [00:33<00:05,  2.61it/s]

Epoch 86/100, Loss: 0.5491, grad norm: 1.0738519430160522, lr: 7.000000000000001e-05



Training Progress:  87%|████████▋ | 87/100 [00:34<00:04,  2.69it/s]

Epoch 87/100, Loss: 0.5521, grad norm: 1.282754898071289, lr: 6.500000000000001e-05



Training Progress:  88%|████████▊ | 88/100 [00:34<00:04,  2.76it/s]

Epoch 88/100, Loss: 0.5486, grad norm: 0.9802149534225464, lr: 6e-05



Training Progress:  89%|████████▉ | 89/100 [00:35<00:03,  2.81it/s]

Epoch 89/100, Loss: 0.5413, grad norm: 1.0144387483596802, lr: 5.5e-05



Training Progress:  90%|█████████ | 90/100 [00:35<00:03,  2.85it/s]

Epoch 90/100, Loss: 0.5474, grad norm: 1.223002552986145, lr: 5e-05



Training Progress:  91%|█████████ | 91/100 [00:35<00:03,  2.87it/s]

Epoch 91/100, Loss: 0.5436, grad norm: 1.1522656679153442, lr: 4.4999999999999996e-05



Training Progress:  92%|█████████▏| 92/100 [00:36<00:02,  2.88it/s]

Epoch 92/100, Loss: 0.5376, grad norm: 1.1914536952972412, lr: 4e-05



Training Progress:  93%|█████████▎| 93/100 [00:36<00:02,  2.90it/s]

Epoch 93/100, Loss: 0.5398, grad norm: 1.0207066535949707, lr: 3.5000000000000004e-05



Training Progress:  94%|█████████▍| 94/100 [00:36<00:02,  2.91it/s]

Epoch 94/100, Loss: 0.5116, grad norm: 1.0995105504989624, lr: 3e-05



Training Progress:  95%|█████████▌| 95/100 [00:37<00:01,  3.06it/s]

Epoch 95/100, Loss: 0.5242, grad norm: 1.0830743312835693, lr: 2.5e-05



Training Progress:  96%|█████████▌| 96/100 [00:37<00:01,  3.00it/s]

Epoch 96/100, Loss: 0.5603, grad norm: 1.2351734638214111, lr: 2e-05



Training Progress:  97%|█████████▋| 97/100 [00:37<00:01,  2.97it/s]

Epoch 97/100, Loss: 0.5152, grad norm: 0.9580557346343994, lr: 1.5e-05



Training Progress:  98%|█████████▊| 98/100 [00:38<00:00,  2.96it/s]

Epoch 98/100, Loss: 0.5217, grad norm: 0.9174291491508484, lr: 1e-05



Training Progress:  99%|█████████▉| 99/100 [00:38<00:00,  2.95it/s]

Epoch 99/100, Loss: 0.5333, grad norm: 1.0415540933609009, lr: 5e-06



Training Progress: 100%|██████████| 100/100 [00:38<00:00,  2.94it/s]

Epoch 100/100, Loss: 0.5352, grad norm: 0.9800519943237305, lr: 0.0



Training Progress: 100%|██████████| 100/100 [00:38<00:00,  2.56it/s]




In [17]:
fine_tuned_model = hybrid_model.model.inference_model

hybrid_model.set_fhe_mode("disable")

In [18]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE?

FHE is a cryptographic technique that enables computations on arbitrary data structures. It consists in generating computable FAs


In [None]:
with peft_model.disable_adapter_layers():
    # Example usage
    prompt = "What is FHE ?"
    generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
    print(generated_text)

In [18]:
def print_weights_and_size(model, print_detail=False):
    total_weights = 0
    total_lora_weights = 0
    for name, param in model.named_parameters():
        total_weights += param.numel()

        if "lora" in name:
            total_lora_weights += param.numel()

        if print_detail:
            print(name, param.numel())

    print(f"Total number of weights: {total_weights}")
    print(f"Total number of LoRA weights: {total_lora_weights}")

    return total_weights

In [19]:
total_weights_size = print_weights_and_size(hybrid_model.model)

Total number of weights: 124587264
Total number of LoRA weights: 147456


In [21]:
path = Path("gpt2_lora_finetuned_hybrid_deployment")

if path.is_dir() and any(path.iterdir()):
    shutil.rmtree(path)

hybrid_model.save_and_clear_private_info(path)

In [22]:
total_weights_size_private = print_weights_and_size(hybrid_model.model)

Total number of weights: 39569664


In [23]:
print(
    "Total weights removed: "
    f"{(total_weights_size - total_weights_size_private) / total_weights_size * 100:.2f} %"
)

Weights removed: 68.24 %
