In [1]:
import math
import shutil
from pathlib import Path

import torch
from custom_module import CustomConv1D
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Conv1D,
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)

from concrete.ml.torch.hybrid_model import HybridFHEModel

SEED = 0
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# FREEZE WEIGHTS
for param in model.parameters():
    param.requires_grad = False

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [2]:
def generate_text(prompt, model, tokenizer, max_length=30, fhe="disable"):
    # Encode the input prompt
    inputs = tokenizer.encode_plus(prompt, return_tensors="pt")

    attention_mask = inputs["attention_mask"]

    # Generate text
    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [3]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be


In [4]:
# Example usage
prompt = "Who's Barack Obama ?"
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

Who's Barack Obama?

I think he's very much in a position to be president. I mean, he was nominated by his party.


In [5]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=4,
    lora_alpha=32,
    lora_dropout=0.05,
    fan_in_fan_out=True,
)

peft_model = get_peft_model(model, peft_config)

In [6]:
def replace_conv1d(module, module_index_to_skip=0):
    for name, child in module.named_children():
        if isinstance(child, Conv1D):

            # Skip the module if the index has not been reached, and decrement the index
            if module_index_to_skip >= 0:
                module_index_to_skip -= 1
            else:
                custom_linear = CustomConv1D(child.weight, bias=child.bias)
                setattr(module, name, custom_linear)
        else:
            module_index_to_skip = replace_conv1d(child, module_index_to_skip=module_index_to_skip)

    return module_index_to_skip


# Gradients of the first base layer that is used for fine-tuning are not needed. We
# therefore need to exclude the backward module from the remote_names since calibration
# won't get through it (which raises an issue with hybrid models)
replace_conv1d(peft_model, module_index_to_skip=0);

In [7]:
class LoraTraining(torch.nn.Module):
    def __init__(self, inference_model, gradient_accumulation_steps) -> None:
        super().__init__()

        self.inference_model = inference_model

        self.optimizer = None
        self.lr_scheduler = None

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = None

        self.calibrate = False
        self.run_optimizer = False

    def update_training_parameters(self, optimizer, lr_scheduler, training_args):
        assert self.gradient_accumulation_steps == training_args.gradient_accumulation_steps

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.max_grad_norm = training_args.max_grad_norm

    def forward(self, inputs):
        # FIXME: handle multi-inputs in hybrid model
        x, y = inputs

        # some parts on server side
        outputs = self.inference_model(input_ids=x, labels=y)

        loss = outputs.loss
        loss = loss / self.gradient_accumulation_steps

        # Update gradients
        loss.backward()

        grad_norm = None
        if not self.calibrate and self.run_optimizer:
            assert self.optimizer is not None
            assert self.lr_scheduler is not None
            assert self.max_grad_norm is not None

            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.inference_model.parameters(), max_norm=self.max_grad_norm, norm_type=2
            )

            self.optimizer.step()
            self.lr_scheduler.step()

            self.inference_model.zero_grad()

        # Clean gradients after calibration
        elif self.calibrate:
            self.inference_model.zero_grad()

        return (loss, grad_norm)

    def toggle_calibrate(self, enable: bool = True):
        self.calibrate = enable

    def toggle_run_optimizer(self, enable: bool = True):
        self.run_optimizer = enable

In [8]:
GRADIENT_ACCUMULATION_STEPS = 2

lora_training = LoraTraining(peft_model, GRADIENT_ACCUMULATION_STEPS)

In [9]:
BLOCK_SIZE = 128


def load_dataset(file_path, tokenizer):
    dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=BLOCK_SIZE,
        cache_dir="cache_dataset",
    )
    return dataset


train_dataset = load_dataset("data_finetune/what_is_fhe.txt", tokenizer)

In [10]:
tokenizer.parallelism = False

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

EPOCHS = 100
PER_DEVICE_TRAIN_BATCH_SIZE = 4

training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    save_total_limit=1,
    use_cpu=True,
    learning_rate=5e-4,
    logging_strategy="epoch",
    optim="adamw_torch",
    seed=SEED,
    data_seed=SEED,
    weight_decay=0.0,
    warmup_steps=0,
    max_grad_norm=1.0,
)

In [11]:
trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

train_dataloader = trainer.get_train_dataloader()

len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // training_args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)

trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)

In [12]:
lora_training.update_training_parameters(trainer.optimizer, trainer.lr_scheduler, training_args)

In [13]:
def get_remote_names(model):
    remote_names = []
    for name, module in model.named_modules():
        # Some gradients are not needed for fine-tuning, so need to exclude the backward module
        # from the remote_names since calibration won't get through it (which raises an issue with
        # hybrid models). We however still need to include the associated module's forward pass in
        # the hybrid model
        if isinstance(module, Conv1D):
            remote_names.append(name)

        elif isinstance(module, CustomConv1D):
            remote_names.append(name + ".forward_module")
            remote_names.append(name + ".backward_module")

    return remote_names


remote_names = get_remote_names(lora_training)

hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

In [14]:
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)

inputset = (input_tensor, label_tensor)

In [15]:
hybrid_model.model.toggle_calibrate(enable=True)

hybrid_model.compile_model(
    inputset, n_bits=8, rounding_threshold_bits={"n_bits": 6, "method": "approximate"}, p_error=1e-5
)

hybrid_model.model.toggle_calibrate(enable=False)

In [16]:
def train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable"):
    device = "cpu"
    hybrid_model.model.to(device)

    # Training loop
    hybrid_model.model.inference_model.train()

    total_epochs = int(training_args.num_train_epochs)
    epoch_pbar = tqdm(total=total_epochs, desc="Training Progress", position=0)

    total_batched_samples = 0
    for epoch in range(total_epochs):
        total_loss = 0
        grad_norms = []

        steps_in_epoch = len(train_dataloader)
        for step, batch in enumerate(train_dataloader):
            total_batched_samples += 1

            batch = {k: v.to(device) for k, v in batch.items()}

            # Gradient accumulation
            is_last_batch_step = (
                steps_in_epoch <= training_args.gradient_accumulation_steps
                and (step + 1) == steps_in_epoch
            )
            accumulate_gradients = (
                total_batched_samples % training_args.gradient_accumulation_steps == 0
            )

            run_optimizer = is_last_batch_step or accumulate_gradients

            hybrid_model.model.toggle_run_optimizer(enable=run_optimizer)

            loss, grad_norm = hybrid_model((batch["input_ids"], batch["labels"]), fhe=fhe)

            total_loss += loss.item()

            if grad_norm is not None:
                grad_norms.append(grad_norm)

        # Get current learning rate
        current_lr = hybrid_model.model.lr_scheduler.get_last_lr()[0]

        # Get last grad norm
        current_grad_norm = grad_norms[-1]

        # Log epoch results
        print(
            f"Epoch {epoch + 1}/{training_args.num_train_epochs}, "
            f"Loss: {total_loss:.4f}, grad norm: {current_grad_norm}, lr: {current_lr}"
        )

        epoch_pbar.update(1)

    # Save model checkpoint
    if training_args.output_dir is not None:
        save_path = f"{training_args.output_dir}/checkpoint-{epoch + 1}"
        hybrid_model.model.inference_model.save_pretrained(save_path)

    epoch_pbar.close()

In [17]:
torch.manual_seed(SEED)

train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable")

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Training Progress:   1%|          | 1/100 [00:01<02:12,  1.34s/it]

Epoch 1/100, Loss: 1.5326, grad norm: 0.660683274269104, lr: 0.000495


Training Progress:   2%|▏         | 2/100 [00:01<01:13,  1.33it/s]

Epoch 2/100, Loss: 1.5084, grad norm: 0.5709879994392395, lr: 0.00049


Training Progress:   3%|▎         | 3/100 [00:01<00:53,  1.81it/s]

Epoch 3/100, Loss: 1.4739, grad norm: 0.41818052530288696, lr: 0.00048499999999999997


Training Progress:   4%|▍         | 4/100 [00:02<00:44,  2.17it/s]

Epoch 4/100, Loss: 1.5047, grad norm: 0.5671930313110352, lr: 0.00048


Training Progress:   5%|▌         | 5/100 [00:02<00:38,  2.44it/s]

Epoch 5/100, Loss: 1.4619, grad norm: 0.5933054685592651, lr: 0.000475


Training Progress:   6%|▌         | 6/100 [00:02<00:35,  2.64it/s]

Epoch 6/100, Loss: 1.4429, grad norm: 0.4501505196094513, lr: 0.00047


Training Progress:   7%|▋         | 7/100 [00:03<00:33,  2.78it/s]

Epoch 7/100, Loss: 1.4080, grad norm: 0.5017417669296265, lr: 0.000465


Training Progress:   8%|▊         | 8/100 [00:03<00:33,  2.75it/s]

Epoch 8/100, Loss: 1.4019, grad norm: 0.5268748998641968, lr: 0.00046


Training Progress:   9%|▉         | 9/100 [00:03<00:32,  2.79it/s]

Epoch 9/100, Loss: 1.3801, grad norm: 0.6914800405502319, lr: 0.000455


Training Progress:  10%|█         | 10/100 [00:04<00:32,  2.79it/s]

Epoch 10/100, Loss: 1.3623, grad norm: 0.6342785358428955, lr: 0.00045000000000000004


Training Progress:  11%|█         | 11/100 [00:04<00:31,  2.78it/s]

Epoch 11/100, Loss: 1.3374, grad norm: 0.5781344771385193, lr: 0.00044500000000000003


Training Progress:  12%|█▏        | 12/100 [00:05<00:32,  2.75it/s]

Epoch 12/100, Loss: 1.3084, grad norm: 0.5696594715118408, lr: 0.00044


Training Progress:  13%|█▎        | 13/100 [00:05<00:31,  2.78it/s]

Epoch 13/100, Loss: 1.2947, grad norm: 0.6448227763175964, lr: 0.000435


Training Progress:  14%|█▍        | 14/100 [00:05<00:30,  2.86it/s]

Epoch 14/100, Loss: 1.3175, grad norm: 0.5205936431884766, lr: 0.00043


Training Progress:  15%|█▌        | 15/100 [00:06<00:28,  2.94it/s]

Epoch 15/100, Loss: 1.2651, grad norm: 0.5077526569366455, lr: 0.000425


Training Progress:  16%|█▌        | 16/100 [00:06<00:28,  2.99it/s]

Epoch 16/100, Loss: 1.2341, grad norm: 0.5621659159660339, lr: 0.00042


Training Progress:  17%|█▋        | 17/100 [00:06<00:27,  3.03it/s]

Epoch 17/100, Loss: 1.2169, grad norm: 0.6973215937614441, lr: 0.000415


Training Progress:  18%|█▊        | 18/100 [00:07<00:27,  2.97it/s]

Epoch 18/100, Loss: 1.2161, grad norm: 0.565183162689209, lr: 0.00041


Training Progress:  19%|█▉        | 19/100 [00:07<00:27,  2.98it/s]

Epoch 19/100, Loss: 1.1753, grad norm: 0.5360667705535889, lr: 0.00040500000000000003


Training Progress:  20%|██        | 20/100 [00:07<00:26,  2.97it/s]

Epoch 20/100, Loss: 1.1838, grad norm: 0.5932585000991821, lr: 0.0004


Training Progress:  21%|██        | 21/100 [00:08<00:27,  2.93it/s]

Epoch 21/100, Loss: 1.1520, grad norm: 0.5894972681999207, lr: 0.000395


Training Progress:  22%|██▏       | 22/100 [00:08<00:26,  2.93it/s]

Epoch 22/100, Loss: 1.1352, grad norm: 0.5806944370269775, lr: 0.00039000000000000005


Training Progress:  23%|██▎       | 23/100 [00:08<00:26,  2.93it/s]

Epoch 23/100, Loss: 1.1270, grad norm: 0.7183179259300232, lr: 0.00038500000000000003


Training Progress:  24%|██▍       | 24/100 [00:09<00:25,  2.95it/s]

Epoch 24/100, Loss: 1.0836, grad norm: 0.5701566934585571, lr: 0.00038


Training Progress:  25%|██▌       | 25/100 [00:09<00:25,  2.97it/s]

Epoch 25/100, Loss: 1.0673, grad norm: 0.5711877346038818, lr: 0.000375


Training Progress:  26%|██▌       | 26/100 [00:09<00:25,  2.92it/s]

Epoch 26/100, Loss: 1.0852, grad norm: 0.7138945460319519, lr: 0.00037


Training Progress:  27%|██▋       | 27/100 [00:10<00:24,  2.93it/s]

Epoch 27/100, Loss: 1.0376, grad norm: 0.5935937762260437, lr: 0.000365


Training Progress:  28%|██▊       | 28/100 [00:10<00:24,  2.93it/s]

Epoch 28/100, Loss: 1.0368, grad norm: 0.6115938425064087, lr: 0.00035999999999999997


Training Progress:  29%|██▉       | 29/100 [00:10<00:24,  2.94it/s]

Epoch 29/100, Loss: 1.0376, grad norm: 0.7490789294242859, lr: 0.000355


Training Progress:  30%|███       | 30/100 [00:11<00:23,  2.95it/s]

Epoch 30/100, Loss: 1.0188, grad norm: 0.58741295337677, lr: 0.00035


Training Progress:  31%|███       | 31/100 [00:11<00:23,  2.92it/s]

Epoch 31/100, Loss: 1.0044, grad norm: 0.6657335758209229, lr: 0.000345


Training Progress:  32%|███▏      | 32/100 [00:11<00:23,  2.93it/s]

Epoch 32/100, Loss: 0.9904, grad norm: 0.6093658804893494, lr: 0.00034


Training Progress:  33%|███▎      | 33/100 [00:12<00:22,  2.94it/s]

Epoch 33/100, Loss: 0.9754, grad norm: 0.6680149435997009, lr: 0.000335


Training Progress:  34%|███▍      | 34/100 [00:12<00:22,  2.95it/s]

Epoch 34/100, Loss: 0.9651, grad norm: 0.7779239416122437, lr: 0.00033


Training Progress:  35%|███▌      | 35/100 [00:12<00:22,  2.95it/s]

Epoch 35/100, Loss: 0.9071, grad norm: 0.635871171951294, lr: 0.00032500000000000004


Training Progress:  36%|███▌      | 36/100 [00:13<00:21,  2.96it/s]

Epoch 36/100, Loss: 0.9114, grad norm: 0.7584870457649231, lr: 0.00032


Training Progress:  37%|███▋      | 37/100 [00:13<00:21,  2.93it/s]

Epoch 37/100, Loss: 0.8917, grad norm: 0.6525894403457642, lr: 0.000315


Training Progress:  38%|███▊      | 38/100 [00:13<00:21,  2.93it/s]

Epoch 38/100, Loss: 0.8810, grad norm: 0.6806576251983643, lr: 0.00031


Training Progress:  39%|███▉      | 39/100 [00:14<00:20,  2.93it/s]

Epoch 39/100, Loss: 0.9112, grad norm: 0.8276675939559937, lr: 0.000305


Training Progress:  40%|████      | 40/100 [00:14<00:20,  2.93it/s]

Epoch 40/100, Loss: 0.8587, grad norm: 0.7379283308982849, lr: 0.0003


Training Progress:  41%|████      | 41/100 [00:14<00:20,  2.94it/s]

Epoch 41/100, Loss: 0.8368, grad norm: 0.758834719657898, lr: 0.000295


Training Progress:  42%|████▏     | 42/100 [00:15<00:19,  2.94it/s]

Epoch 42/100, Loss: 0.8286, grad norm: 0.7046515941619873, lr: 0.00029


Training Progress:  43%|████▎     | 43/100 [00:15<00:19,  2.91it/s]

Epoch 43/100, Loss: 0.8371, grad norm: 0.76859450340271, lr: 0.000285


Training Progress:  44%|████▍     | 44/100 [00:15<00:19,  2.91it/s]

Epoch 44/100, Loss: 0.8316, grad norm: 0.7844902276992798, lr: 0.00028000000000000003


Training Progress:  45%|████▌     | 45/100 [00:16<00:18,  2.91it/s]

Epoch 45/100, Loss: 0.8014, grad norm: 0.9919225573539734, lr: 0.000275


Training Progress:  46%|████▌     | 46/100 [00:16<00:18,  2.91it/s]

Epoch 46/100, Loss: 0.7876, grad norm: 0.8102647066116333, lr: 0.00027


Training Progress:  47%|████▋     | 47/100 [00:16<00:18,  2.92it/s]

Epoch 47/100, Loss: 0.7944, grad norm: 0.872577428817749, lr: 0.00026500000000000004


Training Progress:  48%|████▊     | 48/100 [00:17<00:17,  2.93it/s]

Epoch 48/100, Loss: 0.7882, grad norm: 1.0337998867034912, lr: 0.00026000000000000003


Training Progress:  49%|████▉     | 49/100 [00:17<00:17,  2.92it/s]

Epoch 49/100, Loss: 0.7497, grad norm: 0.8379231095314026, lr: 0.000255


Training Progress:  50%|█████     | 50/100 [00:17<00:17,  2.92it/s]

Epoch 50/100, Loss: 0.7629, grad norm: 0.828416109085083, lr: 0.00025


Training Progress:  51%|█████     | 51/100 [00:18<00:17,  2.87it/s]

Epoch 51/100, Loss: 0.7471, grad norm: 0.8946866393089294, lr: 0.000245


Training Progress:  52%|█████▏    | 52/100 [00:18<00:17,  2.78it/s]

Epoch 52/100, Loss: 0.7387, grad norm: 0.9598793387413025, lr: 0.00024


Training Progress:  53%|█████▎    | 53/100 [00:19<00:16,  2.77it/s]

Epoch 53/100, Loss: 0.7344, grad norm: 0.9103481769561768, lr: 0.000235


Training Progress:  54%|█████▍    | 54/100 [00:19<00:16,  2.81it/s]

Epoch 54/100, Loss: 0.7339, grad norm: 0.8195675611495972, lr: 0.00023


Training Progress:  55%|█████▌    | 55/100 [00:19<00:15,  2.84it/s]

Epoch 55/100, Loss: 0.7063, grad norm: 0.8284441232681274, lr: 0.00022500000000000002


Training Progress:  56%|█████▌    | 56/100 [00:20<00:15,  2.85it/s]

Epoch 56/100, Loss: 0.6895, grad norm: 0.9184778332710266, lr: 0.00022


Training Progress:  57%|█████▋    | 57/100 [00:20<00:14,  2.88it/s]

Epoch 57/100, Loss: 0.6907, grad norm: 0.9521369338035583, lr: 0.000215


Training Progress:  58%|█████▊    | 58/100 [00:20<00:14,  2.89it/s]

Epoch 58/100, Loss: 0.6779, grad norm: 0.8893874287605286, lr: 0.00021


Training Progress:  59%|█████▉    | 59/100 [00:21<00:14,  2.89it/s]

Epoch 59/100, Loss: 0.6663, grad norm: 1.0947909355163574, lr: 0.000205


Training Progress:  60%|██████    | 60/100 [00:21<00:13,  2.90it/s]

Epoch 60/100, Loss: 0.6705, grad norm: 1.2344825267791748, lr: 0.0002


Training Progress:  61%|██████    | 61/100 [00:21<00:13,  2.91it/s]

Epoch 61/100, Loss: 0.6427, grad norm: 0.9753999710083008, lr: 0.00019500000000000002


Training Progress:  62%|██████▏   | 62/100 [00:22<00:13,  2.92it/s]

Epoch 62/100, Loss: 0.6698, grad norm: 1.297074794769287, lr: 0.00019


Training Progress:  63%|██████▎   | 63/100 [00:22<00:12,  2.91it/s]

Epoch 63/100, Loss: 0.6517, grad norm: 1.1463088989257812, lr: 0.000185


Training Progress:  64%|██████▍   | 64/100 [00:22<00:12,  2.92it/s]

Epoch 64/100, Loss: 0.6580, grad norm: 1.096678376197815, lr: 0.00017999999999999998


Training Progress:  65%|██████▌   | 65/100 [00:23<00:12,  2.91it/s]

Epoch 65/100, Loss: 0.6380, grad norm: 1.0079728364944458, lr: 0.000175


Training Progress:  66%|██████▌   | 66/100 [00:23<00:11,  2.92it/s]

Epoch 66/100, Loss: 0.6181, grad norm: 1.0824249982833862, lr: 0.00017


Training Progress:  67%|██████▋   | 67/100 [00:23<00:11,  2.93it/s]

Epoch 67/100, Loss: 0.6007, grad norm: 1.3051782846450806, lr: 0.000165


Training Progress:  68%|██████▊   | 68/100 [00:24<00:10,  2.94it/s]

Epoch 68/100, Loss: 0.6196, grad norm: 1.015985131263733, lr: 0.00016


Training Progress:  69%|██████▉   | 69/100 [00:24<00:10,  2.92it/s]

Epoch 69/100, Loss: 0.6318, grad norm: 1.1285070180892944, lr: 0.000155


Training Progress:  70%|███████   | 70/100 [00:24<00:10,  2.93it/s]

Epoch 70/100, Loss: 0.6122, grad norm: 0.9725406169891357, lr: 0.00015


Training Progress:  71%|███████   | 71/100 [00:25<00:09,  2.92it/s]

Epoch 71/100, Loss: 0.6077, grad norm: 1.143409013748169, lr: 0.000145


Training Progress:  72%|███████▏  | 72/100 [00:25<00:09,  2.93it/s]

Epoch 72/100, Loss: 0.6069, grad norm: 1.1408828496932983, lr: 0.00014000000000000001


Training Progress:  73%|███████▎  | 73/100 [00:25<00:09,  2.93it/s]

Epoch 73/100, Loss: 0.5852, grad norm: 1.286801815032959, lr: 0.000135


Training Progress:  74%|███████▍  | 74/100 [00:26<00:08,  2.93it/s]

Epoch 74/100, Loss: 0.5694, grad norm: 1.2028415203094482, lr: 0.00013000000000000002


Training Progress:  75%|███████▌  | 75/100 [00:26<00:08,  2.93it/s]

Epoch 75/100, Loss: 0.5810, grad norm: 1.3242664337158203, lr: 0.000125


Training Progress:  76%|███████▌  | 76/100 [00:26<00:08,  2.94it/s]

Epoch 76/100, Loss: 0.5566, grad norm: 1.083222508430481, lr: 0.00012


Training Progress:  77%|███████▋  | 77/100 [00:27<00:07,  2.94it/s]

Epoch 77/100, Loss: 0.5792, grad norm: 1.2287765741348267, lr: 0.000115


Training Progress:  78%|███████▊  | 78/100 [00:27<00:07,  2.94it/s]

Epoch 78/100, Loss: 0.5823, grad norm: 1.0556411743164062, lr: 0.00011


Training Progress:  79%|███████▉  | 79/100 [00:27<00:07,  2.94it/s]

Epoch 79/100, Loss: 0.5649, grad norm: 0.9160956740379333, lr: 0.000105


Training Progress:  80%|████████  | 80/100 [00:28<00:06,  2.94it/s]

Epoch 80/100, Loss: 0.5611, grad norm: 1.1066553592681885, lr: 0.0001


Training Progress:  81%|████████  | 81/100 [00:28<00:06,  2.94it/s]

Epoch 81/100, Loss: 0.5437, grad norm: 0.964350163936615, lr: 9.5e-05


Training Progress:  82%|████████▏ | 82/100 [00:29<00:06,  2.94it/s]

Epoch 82/100, Loss: 0.5731, grad norm: 1.0591192245483398, lr: 8.999999999999999e-05


Training Progress:  83%|████████▎ | 83/100 [00:29<00:05,  2.92it/s]

Epoch 83/100, Loss: 0.5410, grad norm: 1.133894443511963, lr: 8.5e-05


Training Progress:  84%|████████▍ | 84/100 [00:29<00:05,  2.90it/s]

Epoch 84/100, Loss: 0.5574, grad norm: 1.059548258781433, lr: 8e-05


Training Progress:  85%|████████▌ | 85/100 [00:30<00:05,  2.90it/s]

Epoch 85/100, Loss: 0.5498, grad norm: 1.1908363103866577, lr: 7.5e-05


Training Progress:  86%|████████▌ | 86/100 [00:30<00:04,  2.92it/s]

Epoch 86/100, Loss: 0.5341, grad norm: 1.0931476354599, lr: 7.000000000000001e-05


Training Progress:  87%|████████▋ | 87/100 [00:30<00:04,  2.93it/s]

Epoch 87/100, Loss: 0.5560, grad norm: 1.2003991603851318, lr: 6.500000000000001e-05


Training Progress:  88%|████████▊ | 88/100 [00:31<00:04,  2.94it/s]

Epoch 88/100, Loss: 0.5429, grad norm: 1.0784202814102173, lr: 6e-05


Training Progress:  89%|████████▉ | 89/100 [00:31<00:03,  2.95it/s]

Epoch 89/100, Loss: 0.5375, grad norm: 1.1330125331878662, lr: 5.5e-05


Training Progress:  90%|█████████ | 90/100 [00:31<00:03,  2.95it/s]

Epoch 90/100, Loss: 0.5439, grad norm: 1.1624584197998047, lr: 5e-05


Training Progress:  91%|█████████ | 91/100 [00:32<00:03,  2.95it/s]

Epoch 91/100, Loss: 0.5356, grad norm: 1.1482391357421875, lr: 4.4999999999999996e-05


Training Progress:  92%|█████████▏| 92/100 [00:32<00:02,  2.96it/s]

Epoch 92/100, Loss: 0.5354, grad norm: 1.2584677934646606, lr: 4e-05


Training Progress:  93%|█████████▎| 93/100 [00:32<00:02,  2.95it/s]

Epoch 93/100, Loss: 0.5323, grad norm: 1.1905620098114014, lr: 3.5000000000000004e-05


Training Progress:  94%|█████████▍| 94/100 [00:33<00:02,  2.95it/s]

Epoch 94/100, Loss: 0.5086, grad norm: 1.1136118173599243, lr: 3e-05


Training Progress:  95%|█████████▌| 95/100 [00:33<00:01,  2.95it/s]

Epoch 95/100, Loss: 0.5075, grad norm: 1.0383167266845703, lr: 2.5e-05


Training Progress:  96%|█████████▌| 96/100 [00:33<00:01,  2.95it/s]

Epoch 96/100, Loss: 0.5545, grad norm: 1.2329591512680054, lr: 2e-05


Training Progress:  97%|█████████▋| 97/100 [00:34<00:01,  2.94it/s]

Epoch 97/100, Loss: 0.5128, grad norm: 1.0316741466522217, lr: 1.5e-05


Training Progress:  98%|█████████▊| 98/100 [00:34<00:00,  2.94it/s]

Epoch 98/100, Loss: 0.5146, grad norm: 0.9778032898902893, lr: 1e-05


Training Progress:  99%|█████████▉| 99/100 [00:34<00:00,  2.94it/s]

Epoch 99/100, Loss: 0.5316, grad norm: 1.1549869775772095, lr: 5e-06


Training Progress: 100%|██████████| 100/100 [00:35<00:00,  2.94it/s]

Epoch 100/100, Loss: 0.5379, grad norm: 1.066272497177124, lr: 0.0


Training Progress: 100%|██████████| 100/100 [00:35<00:00,  2.83it/s]




In [18]:
fine_tuned_model = hybrid_model.model.inference_model

hybrid_model.set_fhe_mode("disable")

In [19]:
fine_tuned_model.disable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

fine_tuned_model.enable_adapter_layers()

What is FHE?

The FH (F) is a mathematical equation with respect to which a physical body is physically connected.
 (


In [20]:
# Example usage
prompt = "Who's Barack Obama ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

Who's Barack Obama?
Obama is a highly influential figure in the international community, and he is the most influential person to date in shaping the Internet


In [21]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE?

FHE is a cryptographic technique that allows the interpretation of computable data and generate encrypted data without public keystroke.


In [22]:
fine_tuned_model.disable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE?

FHE is a free-to-use application for users to create free apps that will give them a choice of


In [23]:
fine_tuned_model.enable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE? It is a cryptographic technique that allows the computations to be performed on arbitrary data sets. It can be used to perform comput


In [24]:
def print_weights_and_size(model, print_detail=False):
    total_weights = 0
    for name, param in model.named_parameters():
        total_weights += param.numel()
        if print_detail:
            print(name, param.numel())

    print(f"Total number of weights: {total_weights}")

    return total_weights

In [25]:
total_weights_size = print_weights_and_size(hybrid_model.model)

Total number of weights: 124587264


In [26]:
path = Path("gpt2_lora_finetuned_hybrid_deployment")

if path.is_dir() and any(path.iterdir()):
    shutil.rmtree(path)

hybrid_model.save_and_clear_private_info(path)

In [27]:
total_weights_size_private = print_weights_and_size(hybrid_model.model)

Total number of weights: 39569664


In [28]:
print(
    "Weights removed: "
    f"{(total_weights_size - total_weights_size_private) / total_weights_size * 100:.2f} %"
)

Weights removed: 68.24 %
