In [2]:
%%capture
%pip install -U transformers
%pip install -U datasets
%pip install -U accelerate
%pip install -U peft
%pip install -U trl
%pip install -U bitsandbytes
%pip install -U wandb

In [3]:
from transformers import(
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)

from peft import(
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model
)

import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from huggingface_hub import login

In [4]:
login(token = "HF_Token")

In [5]:
wandb.login(key = "wandb_token")

run = wandb.init(
    project = "Fine-tune Gemma2b-it on Legal qna",
    job_type = "training",
    anonymous = "allow"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mabhinav0231[0m ([33mabhinav0231-krmangalam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
base_model = "google/gemma-2-2b-it"
dataset_name = "viber1/indian-law-dataset"
new_model = "Gemma-2b-it-legal-assistant"

In [7]:
if torch.cuda.get_device_capability()[0] >= 8:
  !pip install -qqq flash-attn
  torch_dtype = torch.bfloat16
  attn_implementation = "flash_attention_2"
else:
  torch_dtype = torch.float16
  attn_implementation = "eager"

In [8]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
)

In [9]:
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [10]:
import bitsandbytes as bnb

def find_all_linear_names(model):
  cls = bnb.nn.Linear4bit
  lora_module_names = set()
  for name, module in model.named_modules():
    if isinstance(module, cls):
      names = name.split(".")
      lora_module_names.add(names[0] if len(names) == 1 else names[-1])
  if "lm_head" in lora_module_names:
    lora_module_names.remove("lm_head")
  return list(lora_module_names)

modules = find_all_linear_names(model)

In [11]:
peft_config = LoraConfig(
    r = 16,
    lora_alpha = 32,
    lora_dropout = 0.05,
    bias = "none",
    task_type = "CAUSAL_LM",
    target_modules = modules
)

tokenizer.chat_template = None
model, tokenizer = setup_chat_format(model, tokenizer)
model = get_peft_model(model, peft_config)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [19]:
dataset = load_dataset(dataset_name, split="all")

def format_chat_template(row):
  row_json = [{"role": "system", "content": "You are a helpful legal assistant who responds to question with respect to India law and contitution. Write a response that appropriately completes the request."},
              {"role": "user", "content": row["Instruction"]},
              {"role": "assistant", "content": row["Response"]}]

  row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
  return row

dataset = dataset.map(
    format_chat_template,
    num_proc=4
)

dataset

Map (num_proc=4):   0%|          | 0/24607 [00:00<?, ? examples/s]

Dataset({
    features: ['Instruction', 'Response', 'text'],
    num_rows: 24607
})

In [20]:
dataset["text"][2000]

'<|im_start|>system\nYou are a helpful legal assistant who responds to question with respect to India law and contitution. Write a response that appropriately completes the request.<|im_end|>\n<|im_start|>user\nWhat are the restrictions placed on the legislative powers of the Union and States regarding trade and commerce?<|im_end|>\n<|im_start|>assistant\nThe restrictions placed on the legislative powers of the Union and States regarding trade and commerce are as follows:\n\n1. Freedom of trade, commerce, and intercourse within the territory of India (Article 301).\n2. Parliament has the power to impose restrictions on trade, commerce, and intercourse between states and between India and other countries (Article 302).\n3. Restrictions can be imposed on the legislative powers of the Union and States with regard to trade and commerce (Article 303).\n4. There may be restrictions on trade, commerce, and intercourse among states (Article 304).\n5. Existing laws and laws providing for state 

In [21]:
def tokenize_function(examples):
    # Tokenize the text field
    tokenized_inputs = tokenizer(
        examples["text"], 
        padding="max_length", 
        truncation=True, 
        max_length=512
    )
    
    # This creates the targets for the language model
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
    
    return tokenized_inputs

# Apply tokenization to create input_ids, attention_mask, etc.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["Instruction", "Response", "text"]  # Remove original columns
)

tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.1)

Map:   0%|          | 0/24607 [00:00<?, ? examples/s]

In [23]:
from peft import prepare_model_for_kbit_training

# Prepare the model for training
model = prepare_model_for_kbit_training(model)

In [24]:
from peft import get_peft_model

# Apply LoRA adapters
model = get_peft_model(model, peft_config)

In [25]:
for name, param in model.named_parameters():
    if "lora" in name:
        param.requires_grad = True

In [27]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=0.1,
    logging_steps=20,
    warmup_steps=50,
    logging_strategy="steps",
    learning_rate=1e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb",
    save_strategy="steps",
    save_steps=0.2,
    save_total_limit=2,
    gradient_checkpointing=True,
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset.select(range(900)),  # Using the original dataset, not tokenized
    eval_dataset=dataset.select(range(900, 1000)),
    tokenizer=tokenizer
)

  trainer = SFTTrainer(


Converting train dataset to ChatML:   0%|          | 0/900 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/900 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/900 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/900 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/100 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [44]:
start_time = time.time()
model.config.use_cache = False
trainer.train()



Step,Training Loss,Validation Loss
6,No log,4.140848
12,No log,3.518024
18,No log,2.894062
24,2.604300,2.258404
30,2.604300,1.8904
36,2.604300,1.860423
42,1.836300,1.750801
48,1.836300,1.712162
54,1.836300,1.69817


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

TrainOutput(global_step=56, training_loss=2.0257974011557445, metrics={'train_runtime': 1801.1849, 'train_samples_per_second': 0.5, 'train_steps_per_second': 0.031, 'total_flos': 3409914860734464.0, 'train_loss': 2.0257974011557445})

In [48]:
end_time = time.time()
train_time = (end_time - start_time) / 60
print(f"Total Train Time: {train_time} minutes.")

Total Train Time: 30.035316054026286 minutes.


In [50]:
wandb.finish()
model.config.use_cache = True

0,1
eval/loss,█▆▄▃▂▁▁▁▁
eval/mean_token_accuracy,▂▁▃▅▇███▇
eval/runtime,▂▅█▄▅▇▁▆▃
eval/samples_per_second,▇▄▁▅▄▂█▃▆
eval/steps_per_second,▇▄▁▅▄▂█▃▆
train/epoch,▁▂▃▃▄▄▅▆▆▇██
train/global_step,▁▂▃▃▄▄▅▆▆▇██
train/grad_norm,█▁
train/learning_rate,▁█
train/loss,█▁

0,1
eval/loss,1.69817
eval/mean_token_accuracy,0.67585
eval/runtime,28.5602
eval/samples_per_second,3.501
eval/steps_per_second,0.875
total_flos,3409914860734464.0
train/epoch,0.99556
train/global_step,56.0
train/grad_norm,1.27441
train/learning_rate,8e-05


In [49]:
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False)

adapter_model.safetensors:   0%|          | 0.00/83.1M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/abhinav0231/Gemma-2b-it-legal-assistant/commit/b4696f0b3499aa48bad9ec475ef168f129d9c4c0', commit_message='Upload model', commit_description='', oid='b4696f0b3499aa48bad9ec475ef168f129d9c4c0', pr_url=None, repo_url=RepoUrl('https://huggingface.co/abhinav0231/Gemma-2b-it-legal-assistant', endpoint='https://huggingface.co', repo_type='model', repo_id='abhinav0231/Gemma-2b-it-legal-assistant'), pr_revision=None, pr_num=None)

In [51]:
import shutil
import os

dir_path = "/kaggle/working/"

zip_path = "gemma2b_finetune.zip"

shutil.make_archive(zip_path.replace(".zip", ""), 'zip', dir_path)

# Download the zip file
from IPython.display import FileLink
FileLink(zip_path)
