In [1]:
import math
import shutil
from pathlib import Path

import torch
from custom_module import CustomConv1D
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Conv1D,
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)

from concrete.ml.torch.hybrid_model import HybridFHEModel

SEED = 0
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# FREEZE WEIGHTS
for param in model.parameters():
    param.requires_grad = False

In [2]:
def generate_text(prompt, model, tokenizer, max_length=30, fhe="disable"):
    # Encode the input prompt
    inputs = tokenizer.encode_plus(prompt, return_tensors="pt")

    attention_mask = inputs["attention_mask"]

    # Generate text
    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [3]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be


In [4]:
# Example usage
prompt = "Who's Barack Obama ?"
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

Who's Barack Obama?

I think he's very much in a position to be president. I mean, he was nominated by his party.


In [5]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=4,
    lora_alpha=32,
    lora_dropout=0.05,
    fan_in_fan_out=True,
)

peft_model = get_peft_model(model, peft_config)

In [6]:
def replace_conv1d(module, module_index_to_skip=0):
    for name, child in module.named_children():
        if isinstance(child, Conv1D):

            # Skip the module if the index has not been reached, and decrement the index
            if module_index_to_skip >= 0:
                module_index_to_skip -= 1
            else:
                custom_linear = CustomConv1D(child.weight, bias=child.bias)
                setattr(module, name, custom_linear)
        else:
            module_index_to_skip = replace_conv1d(child, module_index_to_skip=module_index_to_skip)

    return module_index_to_skip


# Gradients of the first base layer that is used for fine-tuning are not needed. We
# therefore need to exclude the backward module from the remote_names since calibration
# won't get through it (which raises an issue with hybrid models)
replace_conv1d(peft_model, module_index_to_skip=0);

In [7]:
class LoraTraining(torch.nn.Module):
    def __init__(self, inference_model, gradient_accumulation_steps) -> None:
        super().__init__()

        self.inference_model = inference_model

        self.optimizer = None
        self.lr_scheduler = None

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = None

        self.calibrate = False
        self.run_optimizer = False

    def update_training_parameters(self, optimizer, lr_scheduler, training_args):
        assert self.gradient_accumulation_steps == training_args.gradient_accumulation_steps

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.max_grad_norm = training_args.max_grad_norm

    def forward(self, inputs):
        # FIXME: handle multi-inputs in hybrid model
        x, y = inputs

        # some parts on server side
        outputs = self.inference_model(input_ids=x, labels=y)

        loss = outputs.loss
        loss = loss / self.gradient_accumulation_steps

        # Update gradients
        loss.backward()

        grad_norm = None
        if not self.calibrate and self.run_optimizer:
            assert self.optimizer is not None
            assert self.lr_scheduler is not None
            assert self.max_grad_norm is not None

            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.inference_model.parameters(), max_norm=self.max_grad_norm, norm_type=2
            )

            self.optimizer.step()
            self.lr_scheduler.step()

            self.inference_model.zero_grad()

        # Clean gradients after calibration
        elif self.calibrate:
            self.inference_model.zero_grad()

        return (loss, grad_norm)

    def toggle_calibrate(self, enable: bool = True):
        self.calibrate = enable

    def toggle_run_optimizer(self, enable: bool = True):
        self.run_optimizer = enable

In [8]:
GRADIENT_ACCUMULATION_STEPS = 2

lora_training = LoraTraining(peft_model, GRADIENT_ACCUMULATION_STEPS)

In [9]:
BLOCK_SIZE = 128


def load_dataset(file_path, tokenizer):
    dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=BLOCK_SIZE,
        cache_dir="cache_dataset",
    )
    return dataset


train_dataset = load_dataset("data_finetune/what_is_fhe.txt", tokenizer)

In [10]:
tokenizer.parallelism = False

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

EPOCHS = 100
PER_DEVICE_TRAIN_BATCH_SIZE = 4

training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    save_total_limit=1,
    use_cpu=True,
    learning_rate=5e-4,
    logging_strategy="epoch",
    optim="adamw_torch",
    seed=SEED,
    data_seed=SEED,
    weight_decay=0.0,
    warmup_steps=0,
    max_grad_norm=1.0,
)

In [11]:
trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

train_dataloader = trainer.get_train_dataloader()

len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // training_args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)

trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)

In [12]:
lora_training.update_training_parameters(trainer.optimizer, trainer.lr_scheduler, training_args)

In [13]:
def get_remote_names(model):
    remote_names = []
    for name, module in model.named_modules():
        # Some gradients are not needed for fine-tuning, so need to exclude the backward module
        # from the remote_names since calibration won't get through it (which raises an issue with
        # hybrid models). We however still need to include the associated module's forward pass in
        # the hybrid model
        if isinstance(module, Conv1D):
            remote_names.append(name)

        elif isinstance(module, CustomConv1D):
            remote_names.append(name + ".forward_module")
            remote_names.append(name + ".backward_module")

    return remote_names


remote_names = get_remote_names(lora_training)

hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

In [14]:
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)

inputset = (input_tensor, label_tensor)

In [16]:
hybrid_model.model.toggle_calibrate(enable=True)

hybrid_model.compile_model(
    inputset, n_bits=8, rounding_threshold_bits={"n_bits": 6, "method": "approximate"}, p_error=1e-5
)

hybrid_model.model.toggle_calibrate(enable=False)

In [None]:
def train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable"):
    device = "cpu"
    hybrid_model.model.to(device)

    # Training loop
    hybrid_model.model.inference_model.train()

    total_epochs = int(training_args.num_train_epochs)
    epoch_pbar = tqdm(total=total_epochs, desc="Training Progress", position=0)

    total_batched_samples = 0
    for epoch in range(total_epochs):
        total_loss = 0
        grad_norms = []

        steps_in_epoch = len(train_dataloader)
        for step, batch in enumerate(train_dataloader):
            total_batched_samples += 1

            batch = {k: v.to(device) for k, v in batch.items()}

            # Gradient accumulation
            is_last_batch_step = (
                steps_in_epoch <= training_args.gradient_accumulation_steps
                and (step + 1) == steps_in_epoch
            )
            accumulate_gradients = (
                total_batched_samples % training_args.gradient_accumulation_steps == 0
            )

            run_optimizer = is_last_batch_step or accumulate_gradients

            hybrid_model.model.toggle_run_optimizer(enable=run_optimizer)

            loss, grad_norm = hybrid_model((batch["input_ids"], batch["labels"]), fhe=fhe)

            total_loss += loss.item()

            if grad_norm is not None:
                grad_norms.append(grad_norm)

        # Get current learning rate
        current_lr = hybrid_model.model.lr_scheduler.get_last_lr()[0]

        # Get last grad norm
        current_grad_norm = grad_norms[-1]

        # Log epoch results
        print(
            f"Epoch {epoch + 1}/{training_args.num_train_epochs}, "
            f"Loss: {total_loss:.4f}, grad norm: {current_grad_norm}, lr: {current_lr}"
        )

        epoch_pbar.update(1)

    # Save model checkpoint
    if training_args.output_dir is not None:
        save_path = f"{training_args.output_dir}/checkpoint-{epoch + 1}"
        hybrid_model.model.inference_model.save_pretrained(save_path)

    epoch_pbar.close()

In [None]:
torch.manual_seed(SEED)

train_custom_model(hybrid_model, train_dataloader, training_args, fhe="disable")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Training Progress:  50%|█████     | 1/2 [02:29<02:29, 149.35s/it]

Epoch 1/2, Loss: 1.8521, grad norm: 0.29645875096321106, lr: 0.0001


Training Progress: 100%|██████████| 2/2 [04:58<00:00, 149.01s/it]

Epoch 2/2, Loss: 1.8171, grad norm: 0.21446043252944946, lr: 0.0


Training Progress: 100%|██████████| 2/2 [04:58<00:00, 149.21s/it]


In [None]:
fine_tuned_model = hybrid_model.model.inference_model

hybrid_model.set_fhe_mode("disable")

In [None]:
fine_tuned_model.disable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

fine_tuned_model.enable_adapter_layers()

What is FHE?

He is
HE is not
.
 (This is a not. HE
I
It is NOT.


In [None]:
# Example usage
prompt = "Who's Barack Obama ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

Who's Barack Obama? I have an idea for the Obama, I think that is I can't even imagine. But let me just say, it


In [None]:
# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE? I don't know. I'm just a big big fat fat.

I have no idea, but I do not


In [None]:
fine_tuned_model.disable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE? F1 was FH is always F I think Fhehe F he F He He is, He, is he,


In [None]:
fine_tuned_model.enable_adapter_layers()

# Example usage
prompt = "What is FHE ?"
generated_text = generate_text(prompt, fine_tuned_model, tokenizer)
print(generated_text)

What is FHE?

You're probably thinking of a few different ways
As far as I'm thinking about it's not really that different


In [None]:
def print_weights_and_size(model, print_detail=False):
    total_weights = 0
    for name, param in model.named_parameters():
        total_weights += param.numel()
        if print_detail:
            print(name, param.numel())

    print(f"Total number of weights: {total_weights}")

    return total_weights

In [None]:
total_weights_size = print_weights_and_size(hybrid_model.model)

Total number of weights: 124587264


In [None]:
path = Path("gpt2_lora_finetuned_hybrid_deployment")

if path.is_dir() and any(path.iterdir()):
    shutil.rmtree(path)

hybrid_model.save_and_clear_private_info(path)

In [None]:
total_weights_size_private = print_weights_and_size(hybrid_model.model)

Total number of weights: 39569664


In [None]:
print(
    f"Weights removed: {(total_weights_size - total_weights_size_private) / total_weights_size * 100:.2f} %"
)

Weights removed: 68.24 %
