In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to local .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
# Suppress every Python warning for the rest of the session. This keeps cell
# output short, but it also hides deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")

# Set before any tokenizer is created.
# NOTE(review): presumably read by the Hugging Face `tokenizers` library to
# allow its internal parallelism; confirm against the installed version.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

## 🧑‍🍳 Step 1: Grab the Recipe — Load Your Training Configuration

Every great dish starts with a recipe — and in this notebook, that recipe is your config file.

In this step, we’re using Hydra to load a YAML configuration that defines all the key ingredients and settings for your federated fine-tuning experiment. Think of it as pulling out the instruction card before cooking.

📦 What this does:
- Loads a configuration file (in this case, federated_7b.yml) using Hydra.
- Prints out the full config in a readable YAML format, thanks to OmegaConf.

In [12]:
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf

def get_config(config_name: str, config_path: str = "../config/"):
    """Compose a Hydra configuration by name.

    Args:
        config_name: Name of the configuration to compose, e.g.
            ``"federated_7b.yml"``.
        config_path: Directory handed to ``hydra.initialize`` as the place to
            look for configuration files.

    Returns:
        The configuration object produced by ``hydra.compose``.
    """
    hydra_context = initialize(config_path=config_path, version_base="1.1")
    # Composition happens inside the initialize context, which is exited
    # before the composed config is handed back to the caller.
    with hydra_context:
        return compose(config_name=config_name)

def print_config(config: DictConfig):
    """Print ``config`` to stdout, rendered as YAML by ``OmegaConf.to_yaml``."""
    yaml_text = OmegaConf.to_yaml(config)
    print(yaml_text)

In [11]:
# Compose the 7B federated fine-tuning config; the file is looked up under
# get_config's default config_path ("../config/").
cfg = get_config("federated_7b.yml")

In [13]:
# Dump the composed config as YAML so the experiment settings can be inspected.
print_config(cfg)

dataset:
  name: medalpaca/medical_meadow_medical_flashcards
model:
  name: mistralai/Mistral-7B-v0.1
  quantization: 4
  gradient_checkpointing: true
  use_fast_tokenizer: false
  lora:
    peft_lora_r: 16
    peft_lora_alpha: 64
    target_modules:
    - q_proj
    - v_proj
train:
  num_rounds: ${flower.num_rounds}
  save_every_round: 5
  learning_rate_max: 5.0e-05
  learning_rate_min: 1.0e-06
  seq_length: 512
  padding_side: left
  evaluate_split: true
  training_arguments:
    output_dir: null
    learning_rate: null
    per_device_train_batch_size: 16
    gradient_accumulation_steps: 1
    logging_steps: 10
    num_train_epochs: 3
    max_steps: 10
    report_to: null
    save_steps: 1000
    save_total_limit: 10
    gradient_checkpointing: ${model.gradient_checkpointing}
    lr_scheduler_type: constant
client_resources:
  num_cpus: 8
  num_gpus: 1.0
dp:
  noise_mult: 0.02
  clip_norm: 0.5
flower:
  num_clients: 20
  num_rounds: 200
  fraction_fit: 0.8
  client_resources:
    num

## 🧑‍🍳 Step 2: Inspect Your Ingredients — Visualize the Dataset Partitions

Now that we have our recipe config, let’s take a look at the ingredients we’ll be cooking with.

In this step, we load a federated dataset using `flwr_datasets` and plot how the training data is split across different partitions (or “clients”). Think of this as checking how much of each ingredient each chef (or client) gets!

📝 What this does:
- Loads one partition to trigger dataset setup.
- Determines how many partitions (or clients) we’re working with.
- Plots a bar chart to show how many samples are in each client’s slice of the dataset.

In [6]:
from flwr_datasets import FederatedDataset
import matplotlib.pyplot as plt

def visualize_partitions(fed_dataset: FederatedDataset, split: str = "train"):
    """Plot a bar chart of the number of examples held by each partition.

    Args:
        fed_dataset: Federated dataset whose partitions are counted.
        split: Name of the partitioned split to inspect. Defaults to
            ``"train"``, the split this function previously hard-coded. The
            split is passed explicitly to ``load_partition`` so the function
            also works when more than one split is partitioned.

    Returns:
        None. The chart is drawn with pyplot on the current figure.
    """
    # Load one partition up front to trigger dataset setup before the
    # partitioner is asked how many partitions it has.
    _ = fed_dataset.load_partition(0, split=split)
    num_partitions = fed_dataset.partitioners[split].num_partitions

    partition_ids = range(num_partitions)
    partition_sizes = [
        len(fed_dataset.load_partition(partition_id, split=split))
        for partition_id in partition_ids
    ]

    plt.bar(partition_ids, partition_sizes)
    plt.xticks(partition_ids)
    plt.xlabel("Partition ID")
    plt.ylabel("Number of examples")
    plt.title(f"IID partitioning into {num_partitions} partitions")

## Model functions

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import (
    PeftModel,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from peft.utils import prepare_model_for_kbit_training
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

def get_model(model_cfg: DictConfig):
    """Load the base causal LM and wrap it with LoRA adapters.

    When CUDA is available, the base model is loaded with 4-bit or 8-bit
    bitsandbytes quantization (chosen by ``model_cfg.quantization``) and then
    passed through ``prepare_model_for_kbit_training``. Without CUDA the model
    is loaded unquantized and ``model_cfg.quantization`` is not read.

    Args:
        model_cfg: Model section of the config. The fields read are ``name``,
            ``quantization``, ``gradient_checkpointing`` and
            ``lora.peft_lora_r`` / ``lora.peft_lora_alpha`` /
            ``lora.target_modules``.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.

    Raises:
        ValueError: If CUDA is available and ``model_cfg.quantization`` is
            neither 4 nor 8.
    """
    use_cuda = torch.cuda.is_available()
    model_name = model_cfg.name

    # Quantization is only configured (and validated) on the CUDA path.
    quantization_config = None
    if use_cuda:
        if model_cfg.quantization == 4:
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        elif model_cfg.quantization == 8:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError(
                f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}"
            )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

    if use_cuda:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
        )

    # Convert the config's sequence of module names into a plain list before
    # handing it to LoraConfig; an empty/None value is passed through as-is.
    target_modules = model_cfg.lora.target_modules
    if target_modules:
        target_modules = list(target_modules)
    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

    peft_model = get_peft_model(model, peft_config)
    if not use_cuda:
        # The non-CUDA path skips prepare_model_for_kbit_training, so input
        # gradients are enabled explicitly here.
        # NOTE(review): presumably needed for training with frozen base
        # weights / gradient checkpointing; confirm.
        peft_model.enable_input_require_grads()

    if model_cfg.gradient_checkpointing:
        # NOTE(review): the generation cache is switched off when gradient
        # checkpointing is on, presumably because the two conflict during
        # training; confirm.
        model.config.use_cache = False

    return peft_model

In [None]:
def compute_communication_costs(config, comm_bw_mbps: float = 20, bytes_per_param: int = 4):
    """Print an estimate of the communication cost of the federated run.

    The model is built with ``get_model(config.model)`` purely to count its
    parameters. Sizes are reported for the full model and for the trainable
    (fine-tuned) parameters only, together with the upload time of each over a
    link of ``comm_bw_mbps`` and the total traffic over all rounds.

    Args:
        config: Experiment config. Reads ``config.model``,
            ``config.flower.num_rounds``, ``config.flower.num_clients`` and
            ``config.flower.fraction_fit``.
        comm_bw_mbps: Link bandwidth in megabits per second.
        bytes_per_param: Bytes assumed per transmitted parameter. Defaults to
            4, the value this function previously hard-coded.

    Returns:
        None. All results are printed to stdout.
    """
    model = get_model(config.model)

    trainable, all_parameters = model.get_nb_trainable_parameters()

    # Sizes are in units of 1024**2 bytes and are labelled "MB" in the output.
    bytes_per_mb = 1024**2
    total_size = bytes_per_param * all_parameters / bytes_per_mb
    trainable_size = bytes_per_param * trainable / bytes_per_mb

    # Megabits per second -> megabytes per second.
    bandwidth_mb_per_s = comm_bw_mbps / 8
    upload_time_total = total_size / bandwidth_mb_per_s
    upload_time_finetune = trainable_size / bandwidth_mb_per_s

    print(f"Full model:\n\t{all_parameters/1e6:.3f} M parameters\n\t{total_size:.2f} MB --> upload in {upload_time_total:.2f}s @ {comm_bw_mbps}Mbps")
    print(f"Finetuned model:\n\t{trainable/1e6:.3f} M parameters\n\t{trainable_size:.2f} MB --> upload in {upload_time_finetune:.2f}s @ {comm_bw_mbps}Mbps")

    num_rounds = config.flower.num_rounds
    # int() truncates, so a fractional client count is rounded down.
    num_clients_per_round = int(config.flower.num_clients * config.flower.fraction_fit)
    print(f"Federated Learning setting: "
          f"\n\tNumber of rounds: {num_rounds}"
          f"\n\tNumber of clients per round: {num_clients_per_round}")

    # NOTE(review): the factor 2 presumably counts one download plus one
    # upload per selected client per round; confirm.
    num_transfers = 2 * num_rounds * num_clients_per_round

    print("-----------------------------------------------")
    print(f"Total Communication costs (Full model): {num_transfers*total_size/1024:.1f} GB")
    # Formatted to one decimal, matching the full-model line above; this value
    # used to be printed as a raw float.
    print(f"Total Communication costs (Finetuning): {num_transfers*trainable_size:.1f} MB")
    print(f"Communication savings: {all_parameters/trainable:.1f}x")