In [1]:
# ! pip install transformers trl accelerate torch bitsandbytes peft datasets -qU

In [2]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from huggingface_hub import login
from datetime import datetime

from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

In [3]:
# check your torch version make sure it is GPU version

from torch.cuda import is_available
print(is_available())
print(torch.version.cuda)
print(torch.__version__)

# set up accelerator may not be necessary for QLoRA
#device = 'cuda:0'
#device = torch.device(device)

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

True
12.1
2.4.1+cu121


In [4]:
# load in datasets
train_dataset = load_dataset('jonathanli/law-stack-exchange', split='train')

train_dataset = train_dataset.shuffle(seed=1234)
# print to see the data point
print(train_dataset[134]["body"])
print(train_dataset[134]["text_label"])
print(train_dataset[134]["title"])

README.md:   0%|          | 0.00/407 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Drink and riding (bicycle) offence in Germany
<p>I am charged with a criminal offence in a drink-and-riding-a-bicycle case. I have received a letter mentioning the alcohol content 1.66 promile. I have accepted the charges and agreed to pay the fine.</p>
<ol>
<li>What are the possible fines for this offence?</li>
<li>Is there any possible legal effect on my resident permit?</li>
<li>How long will the record stay in the register?</li>
</ol>
<p>I do not have a driving license.</p>

[{'answer_id': 55185, 'body': '<ol>\n<li><p>The fine would typically be around your monthly income.<br />\n<sub>Legal basis: drunk driving per §316 StGB is punishable by up to one year in prison, but per §47 and §40 StGB short sentences are converted to a fine that depends on your daily net income (Tagessätze).</sub></p>\n</li>\n<li><p>There is likely no impact. Despite this being a crime, it will not appear in your criminal record that some employers need.<br />\n<sub>Legal basis: Per §32 BZRG the criminal rec

In [5]:
# login()

In [6]:
# load in the model
base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
# quantize to save memeory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
#############################
####### Tokenization  #######
#############################
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    model_max_length=512,
    padding_side="left",
    add_eos_token=True)
tokenizer.pad_token = tokenizer.eos_token

def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result

def generate_and_tokenize_prompt(data_point):
    bos_token = "<s>"
    original_system_message = "Below is a question that describes a problem about law. Write a response that could become the title of the problem."
    system_message = "Use the provided input to create a title that could have been used to generate the response with an LLM."
    response = data_point["title"].replace(original_system_message, "").replace("\n\n### Title\n", "").replace("\n### Body\n", "").strip()
    input = data_point["body"]
    eos_token = "</s>"

    full_prompt = ""
    full_prompt += bos_token
    full_prompt += "### Instruction:"
    full_prompt += "\n" + system_message
    full_prompt += "\n\n### Input:"
    full_prompt += "\n" + input
    full_prompt += "\n\n### Response:"
    full_prompt += "\n" + response
    full_prompt += eos_token
    # full_prompt =f"""Given a question about punishments when breaking laws. Answer with the correct punishments.

    #                 ### Question:
    #                 {data_point["question_body"]}

    #                 ### Answer:
    #                 {data_point["answers"]}
    #                """
    return tokenize(full_prompt)
####################################
####### End of Tokenization  #######
####################################

In [8]:
tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)

# print to see examples
print(tokenized_train_dataset[4]['input_ids'])
print(len(tokenized_train_dataset[4]['input_ids']))

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 27332, 3133, 3112, 28747, 13, 8543, 272, 3857, 2787, 298, 2231, 264, 3941, 369, 829, 506, 750, 1307, 298, 8270, 272, 2899, 395, 396, 16704, 28755, 28723, 13, 13, 27332, 11232, 28747, 13, 28737, 837, 7101, 719, 288, 356, 264, 2572, 12181, 28733, 23158, 286, 21732, 2488, 298, 1950, 264, 5181, 28725, 395, 799, 9909, 28723, 415, 4099, 302, 272, 2696, 622, 347, 6431, 28723, 13, 13, 13, 28737, 837, 3653, 456, 4993, 395, 586, 25325, 28725, 298, 2231, 264, 5181, 354, 5088, 297, 272, 8932, 1834, 28723, 1136, 1259, 28725, 315, 506, 6140, 

In [9]:
# setup LoRA
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
print_trainable_parameters(model)

trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705


In [10]:
# Apply the accelerator. You can comment this out to remove the accelerator.
model = accelerator.prepare_model(model)
# print(model)

In [11]:
import numpy
import transformers

print("Numpy version:", numpy.__version__)
print("Transformers version:", transformers.__version__)

Numpy version: 1.20.3
Transformers version: 4.44.2


In [12]:
# Training
project = "law-stack-exchange"
base_model_name = "mistral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=70000,
        learning_rate=2.5e-5, # Want about 10x smaller than the Mistral learning rate
        logging_steps=1000,
        bf16=True,
        optim="paged_adamw_8bit",
        logging_dir="./logs",        # Directory for storing logs
        save_strategy="steps",       # Save the model checkpoint every logging step
        save_steps=1000,               # Save checkpoints every 5000 steps
        # evaluation_strategy="steps", # Evaluate the model every logging step
        # eval_steps=50,               # Evaluate and save checkpoints every 50 steps
        do_eval=False,                # Perform evaluation at the end of training
        # report_to="wandb",           # Comment this out if you don't want to use weights & baises
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"          # Name of the W&B run (optional)
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train('C:/Users/User\Desktop/SideProject/content/drive/MyDrive/mistral-law-stack-exchange/checkpoint-5000')

max_steps is given, it will override any value given in num_train_epochs
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
	logging_steps: 1000 (from args) != 5000 (from trainer_state.json)
	save_steps: 1000 (from args) != 5000 (from trainer_state.json)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\User\_netrc


  0%|          | 0/70000 [00:00<?, ?it/s]

  checkpoint_rng_state = torch.load(rng_file)
  return fn(*args, **kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
