In [None]:
#This notebook fine-tunes Llama-3.1-8B-Instruct to explain medical diagnoses to patients using the Med-EASi dataset.

!pip install -q "transformers>=4.43.0" peft bitsandbytes accelerate datasets evaluate wandb huggingface_hub


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m39.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from huggingface_hub import login
from getpass import getpass

HF_TOKEN = getpass("Enter your Hugging Face token (with access to Llama-3.1-8B): ")
login(HF_TOKEN)


Enter your Hugging Face token (with access to Llama-3.1-8B): ··········


In [3]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
ds = load_dataset("cbasu/Med-EASi")
print(ds)
print("Train sample:\n", ds["train"][0])

README.md: 0.00B [00:00, ?B/s]

train.csv: 0.00B [00:00, ?B/s]

validation.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/1397 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/196 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/300 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 1397
    })
    validation: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 196
    })
    test: Dataset({
        features: ['Expert', 'Simple', 'Annotation', 'sim', 'sentence_sim', 'compression', 'expert_fk_grade', 'expert_ari', 'layman_fk_grade', 'layman_ari', 'umls_expert', 'umls_layman', 'expert_terms', 'layman_terms', 'idx'],
        num_rows: 300
    })
})
Train sample:
 {'Expert': '75-90 % of the affected people have mild intellectual disability.', 'Simple': "People with syndromic intellectual disabi

In [None]:
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    token=HF_TOKEN,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
)

model.config.use_cache = False


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

In [None]:
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605


In [None]:
MAX_LENGTH = 256 #change later if needed

SYSTEM_PROMPT = (
    "You are a medical explainer. Your job is to explain medical diagnoses "
    "in clear, patient-friendly English. Use short sentences, avoid medical "
    "jargon when possible, and do not add information that is not in the "
    "original text.\n\n"
)

def build_training_prompt(expert_text: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                "Explain the following medical diagnosis to a patient in simple, "
                "everyday language:\n\n"
                f"{expert_text}\n\n"
                "Patient-friendly explanation:"
            ),
        },
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt

def preprocess_function(batch):
    experts = batch["Expert"]
    simples = batch["Simple"]

    prompts = [build_training_prompt(e) for e in experts]
    targets = simples

    full_texts = [p + t + tokenizer.eos_token for p, t in zip(prompts, targets)]

    tokenized = tokenizer(
        full_texts,
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
    )

    input_ids = tokenized["input_ids"]
    labels = []
    attention_masks = tokenized["attention_mask"]

    for i, full_ids in enumerate(input_ids):
        #tokenize prompt to get number of prompt tokens
        prompt_ids = tokenizer(
            prompts[i],
            add_special_tokens=False,
            max_length=MAX_LENGTH,
            truncation=True,
        )["input_ids"]

        label_ids = full_ids.copy()
        prompt_len = min(len(prompt_ids), len(label_ids))

        #we only want loss on answer tokens, so mask out the prompt tokens
        label_ids[:prompt_len] = [-100] * prompt_len
        labels.append(label_ids)

    tokenized["labels"] = labels
    return tokenized

cols_to_remove = ds["train"].column_names

tokenized_ds = ds.map(
    preprocess_function,
    batched=True,
    remove_columns=cols_to_remove,
)

tokenized_ds


Map:   0%|          | 0/1397 [00:00<?, ? examples/s]

Map:   0%|          | 0/196 [00:00<?, ? examples/s]

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1397
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 196
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 300
    })
})

In [None]:
def data_collator(features):
    batch = {}
    for key in ["input_ids", "attention_mask", "labels"]:
        batch[key] = torch.stack(
            [torch.tensor(f[key]) for f in features]
        )
    return batch

In [None]:
wandb.init(
    project="med-easi-llama3",
    name="llama3_1_8b_model",
)

training_args = TrainingArguments(
    output_dir="checkpoints/llama3_1_8b_med_easi",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=2,
    learning_rate=2e-4,
    warmup_ratio=0.05,
    logging_steps=20,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    optim="paged_adamw_8bit",
    report_to=["wandb"],
)


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

train_result = trainer.train()
print("Training finished.")

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.
  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss
1,0.2264,0.376193
2,0.1339,0.403679


  return fn(*args, **kwargs)


Training finished.


In [None]:
HF_TOKEN = getpass("Enter token with write access to upload model to HF repo")
login(HF_TOKEN)

In [None]:
repo_id = "smedara/llama3-med-easi-explainer" #REMINDER: change to your HF repo ID
model.push_to_hub(repo_id=repo_id, private=True)
tokenizer.push_to_hub(repo_id)

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...adapter_model.safetensors:   1%|          |  555kB / 83.9MB            

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...mpxkz5ltdl/tokenizer.json: 100%|##########| 17.2MB / 17.2MB            

CommitInfo(commit_url='https://huggingface.co/smedara/llama3-med-easi-explainer/commit/f1180ff60d998d8e1d2235b9de5b8a503c38bd6b', commit_message='Upload tokenizer', commit_description='', oid='f1180ff60d998d8e1d2235b9de5b8a503c38bd6b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/smedara/llama3-med-easi-explainer', endpoint='https://huggingface.co', repo_type='model', repo_id='smedara/llama3-med-easi-explainer'), pr_revision=None, pr_num=None)