In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# Load the pretrained LLaMA 2 model and tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The LLaMA 2 tokenizer is published without a padding token, so any call that
# pads (fixed-length tokenization, the Trainer's padding collator) raises a
# ValueError. Reusing EOS as the pad token avoids adding a new vocabulary entry
# and therefore avoids resizing the model's embedding matrix.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Define LoRA configuration
lora_config = LoraConfig(
    r=8,  # Low-rank dimension
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # Modules to apply LoRA
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap the model with LoRA
lora_model = get_peft_model(model, lora_config)


In [None]:
from datasets import load_dataset

# Load domain-specific dataset
dataset = load_dataset("lex_glue", "ecthr_a")  # Example: Legal text dataset

# Fixed-length padding needs a pad token; fall back to EOS when the tokenizer
# does not define one so the tokenization below cannot fail on padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize a batch of documents and build causal-LM training targets.

    Args:
        examples: A batch from ``Dataset.map(batched=True)``. ``examples["text"]``
            holds one entry per document; each entry may be a single string or
            a list of paragraph strings (the layout used by LexGLUE ``ecthr_a``).

    Returns:
        The tokenizer output (padded/truncated to 512 tokens) with an added
        ``labels`` field: a copy of ``input_ids`` where padded positions are
        replaced by -100 so they are ignored by the language-modeling loss.
    """
    # The tokenizer expects one string per document, so multi-paragraph
    # documents are flattened into a single newline-separated string.
    documents = [
        "\n".join(entry) if isinstance(entry, (list, tuple)) else entry
        for entry in examples["text"]
    ]
    encoded = tokenizer(documents, truncation=True, padding="max_length", max_length=512)

    # Mask by attention_mask rather than by token id: the pad token may be the
    # same id as EOS, and real EOS tokens must still contribute to the loss.
    encoded["labels"] = [
        [token_id if attended else -100 for token_id, attended in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded


# Drop the raw columns: the dataset's own `labels` column holds classification
# labels, which must not be fed to the causal LM as training targets.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]


In [None]:
from transformers import TrainingArguments, Trainer

# Checkpoints and the final adapter weights share one output directory.
output_dir = "./llama2-legal-lora"

# Hyper-parameters for the fine-tuning run
training_args = TrainingArguments(
    output_dir=output_dir,
    logging_dir="./logs",
    # Optimisation
    num_train_epochs=3,
    learning_rate=5e-4,  # Higher learning rate for LoRA
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluate and checkpoint once per epoch
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

# Assemble the Trainer around the LoRA-wrapped model
trainer_kwargs = dict(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer = Trainer(**trainer_kwargs)

# Run fine-tuning
trainer.train()

# Persist the fine-tuned LoRA model
lora_model.save_pretrained(output_dir)
