# In [None]:
import dataclasses

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset

# Load LLaMA model and tokenizer from Hugging Face
model_name = "huggingface/llama"  # Placeholder: replace with an actual LLaMA checkpoint id or local path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# LLaMA tokenizers typically ship without a padding token, and padding examples
# to a fixed length during tokenization fails without one. Reuse EOS as the pad
# token and record it on the model config so tokenizer and model agree.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

# Assume `data` is a list of tuples (word, plain_explanation, formal_definition)
data = [
    ("Hypertension", "High blood pressure", "A condition in which the force of the blood against the artery walls is too high."),
    # Add more examples here...
]

# Prepare dataset for training
def format_example(example):
    """Turn one raw record into a prompt/target pair for training.

    Args:
        example: a 3-item sequence ``(word, plain_explanation, formal_definition)``.

    Returns:
        dict with ``"input_text"`` (a prompt asking for a plain-language
        explanation of ``word``, with its formal definition as context) and
        ``"target_text"`` (the plain explanation).
    """
    term, lay_text, formal_text = example
    return {
        "input_text": f"Explain in plain language: {term}. Formal definition: {formal_text}",
        "target_text": lay_text,
    }

# Wrap the raw records in a single-column Dataset, then expand each record into
# its "input_text"/"target_text" pair and drop the original "data" column.
def _format_record(row):
    """Apply ``format_example`` to the raw record stored in the ``"data"`` column."""
    return format_example(row["data"])


dataset = Dataset.from_dict({"data": data}).map(_format_record, remove_columns=["data"])

# Tokenize prompt + target into single causal-LM training sequences
def tokenize_function(example, max_input_length=256, max_target_length=128):
    """Tokenize a batch of prompt/target pairs for causal-LM fine-tuning.

    A causal LM is trained on ONE sequence: the prompt followed by the target.
    ``labels`` must have the same length as ``input_ids`` (the model shifts
    them internally). Prompt and padding positions are set to -100 so the loss
    is computed only on the target tokens.

    Args:
        example: batched mapping with ``"input_text"`` and ``"target_text"``
            lists, as passed by ``Dataset.map(..., batched=True)``.
        max_input_length: token budget for the prompt (truncated beyond it).
        max_target_length: token budget for the target, including the EOS
            token appended to it.

    Returns:
        dict of lists ``input_ids``, ``attention_mask`` and ``labels``; every
        row is right-padded to ``max_input_length + max_target_length``.
    """
    total_length = max_input_length + max_target_length
    eos_ids = [] if tokenizer.eos_token_id is None else [tokenizer.eos_token_id]
    # The pad id is only a filler: padded positions get attention_mask 0 and
    # label -100. Fall back to EOS because LLaMA tokenizers may lack a pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_ids[0] if eos_ids else 0

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for prompt, target in zip(example["input_text"], example["target_text"]):
        prompt_ids = tokenizer(prompt, truncation=True, max_length=max_input_length)["input_ids"]
        # No special tokens (e.g. BOS) mid-sequence; the appended EOS teaches
        # the model where the answer ends.
        target_ids = tokenizer(
            target,
            add_special_tokens=False,
            truncation=True,
            max_length=max_target_length - len(eos_ids),
        )["input_ids"] + eos_ids

        input_ids = prompt_ids + target_ids
        labels = [-100] * len(prompt_ids) + target_ids
        n_pad = total_length - len(input_ids)  # >= 0 given the budgets above

        batch["input_ids"].append(input_ids + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(input_ids) + [0] * n_pad)
        batch["labels"].append(labels + [-100] * n_pad)
    return batch

# Tokenize once, dropping the raw text columns so only model inputs remain.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

# Shared location for checkpoints and the final model/tokenizer.
OUTPUT_DIR = "./llama-medical-plain-language"

# `evaluation_strategy` was renamed to `eval_strategy` in transformers 4.41, and
# newer releases reject the old spelling. Use whichever name the installed
# TrainingArguments dataclass actually declares.
_training_arg_names = {fld.name for fld in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    # NOTE(review): with a small dataset these step intervals may never be
    # reached (no intermediate checkpoints or logs); tune to the dataset size.
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=500,
    **{_eval_strategy_key: "epoch"},
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    # Evaluating on the training data measures fit, not generalization;
    # ideally pass a separate held-out validation split here.
    eval_dataset=tokenized_dataset,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model and its tokenizer side by side
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
