# LegalEasy: Fine-tuning a Model to Produce Natural Language Descriptions of Legalese

Today we'll be looking at a specific example of fine-tuning [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2): Summarization in a particular style for a particular domain!

Let's start, as we always do, with installing our dependencies.



In [None]:
!pip install -qU transformers peft trl accelerate bitsandbytes datasets


Next, let's set up some data!

## Data

We'll be using the Legal Summarization [dataset](https://github.com/lauramanor/legal_summarization) from the paper [Plain English Summarization of Contracts](https://www.aclweb.org/anthology/W19-2201) today.

This dataset contains pairs in the following format:

- Original Text: A blob of legal text, think ToS
- Reference Summary: A short natural language summary of the legal text

We'll start by cloning the repository containing our data.

In [None]:
!git clone https://github.com/lauramanor/legal_summarization.git

Cloning into 'legal_summarization'...
remote: Enumerating objects: 31, done.
remote: Counting objects: 100% (6/6), done.
remote: Compressing objects: 100% (6/6), done.
remote: Total 31 (delta 2), reused 0 (delta 0), pack-reused 25
Receiving objects: 100% (31/31), 136.60 KiB | 188.00 KiB/s, done.
Resolving deltas: 100% (10/10), done.


Let's convert this into the format we need - in this case, a list of JSON-style records (one Python dictionary per document).

In [None]:
import json

jsonl_array = []

with open('legal_summarization/tldrlegal_v1.json') as f:
  data = json.load(f)
  for key, value in data.items():
    jsonl_array.append(value)
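
The loop above simply collects the values of the top-level dictionary. As a minimal sketch, here is the same idea on a tiny in-memory stand-in for the file. The two records and their keys (`uid_a`, `uid_b`) are made up for illustration; only the field names come from the dataset.

```python
import json

# Hypothetical stand-in for the JSON file: a dict with one record per key.
raw = json.loads("""
{
  "uid_a": {"uid": "uid_a", "original_text": "some legal text", "reference_summary": "short summary"},
  "uid_b": {"uid": "uid_b", "original_text": "more legal text", "reference_summary": "another summary"}
}
""")

# Equivalent to looping over .items() and appending each value to a list.
records = list(raw.values())

print(len(records))
print(records[0]["reference_summary"])
```

Either form gives a list of dictionaries, which is the shape `Dataset.from_list` expects.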

Now we can convert that into the desired format for our fine-tuning - a Hugging Face `Dataset`!

In [None]:
from datasets import Dataset

legal_dataset = Dataset.from_list(jsonl_array)

Let's see how many items we're working with in our dataset.

> NOTE: Keep in mind that this is a relatively simple example meant to showcase fine-tuning - in practice, we'd want to use somewhere in the neighbourhood of ~500-50,000 examples.

In [None]:
legal_dataset

Dataset({
    features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
    num_rows: 85
})

Let's look at an example of our original text and summary!

In [None]:
print(f"Original Text: {legal_dataset[0]['original_text']}\n\nSummary: {legal_dataset[0]['reference_summary']}")

Original Text: welcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.

Summary: hi.


Now, we mentioned earlier we were going to fine-tune Mistral-7B-Instruct-v0.2, which is important for our next step: Creating the instruction format.

Let's take a look at the instructions (so meta) for building an instruction prompt, from the [model card](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format):


> In order to leverage instruction fine-tuning, your prompt should be surrounded by [INST] and [/INST] tokens. The very first instruction should begin with a begin of sentence id. The next instructions should not. The assistant generation will be ended by the end-of-sentence token id.

Let's look at an example of how we might format our instruction - and then reproduce that in code.

```python
<s>[INST]Please convert the following legal content into a human-readable summary

[LEGAL_DOC]
welcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.
[END_LEGAL_DOC]

[/INST]
hi.</s>
```

> NOTE: We're adding our own delimiters here, `[LEGAL_DOC]` and `[END_LEGAL_DOC]`, to help the model tell where the legal text starts and ends. These are not special tokens in the model's vocabulary, so they are tokenized as ordinary text.
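
The model card's rule also covers multi-turn prompts: only the first instruction gets the begin-of-sentence marker, and every assistant reply ends with the end-of-sentence marker. A minimal sketch of that rule follows. `format_conversation` is a hypothetical helper, not part of any library, and the whitespace follows this notebook's template rather than the model card's exact spacing.

```python
BOS, EOS = "<s>", "</s>"

def format_conversation(turns):
    """Build a prompt from (instruction, answer) pairs.

    The answer may be None for a final turn the model should complete.
    """
    text = BOS  # only the very first instruction is preceded by BOS
    for instruction, answer in turns:
        text += f"[INST]{instruction}[/INST]"
        if answer is not None:
            text += f"{answer}{EOS}"  # each assistant reply ends with EOS
    return text

example = format_conversation([
    ("Summarize clause 1.", "be nice."),
    ("Summarize clause 2.", None),
])
print(example)
```

Our fine-tuning data only ever has a single turn, so the template we use in this notebook is the one-pair special case of this.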

In [None]:
INSTRUCTION_PROMPT_TEMPLATE = """\
<s>[INST]Please convert the following legal content into a human-readable summary

[LEGAL_DOC]
{LEGAL_TEXT}
[END_LEGAL_DOC]

[/INST]
"""

RESPONSE_TEMPLATE = """\
{NATURAL_LANGUAGE_SUMMARY}</s>"""

Now we can create a helper function that will convert our dataset row into the above prompt!

In [None]:
def create_instruction(sample, return_response=True):
  prompt = INSTRUCTION_PROMPT_TEMPLATE.format(LEGAL_TEXT=sample["original_text"])

  if return_response:
    prompt += RESPONSE_TEMPLATE.format(NATURAL_LANGUAGE_SUMMARY=sample["reference_summary"])

  return prompt

Let's try it out!

In [None]:
create_instruction(legal_dataset[0])

'<s>[INST]Please convert the following legal content into a human-readable summary\n\n[LEGAL_DOC]\nwelcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.\n[END_LEGAL_DOC]\n\n[/INST]\nhi.</s>'

We'll partition our dataset so we can test some of the outputs after we've completed our training.

In [None]:
prepared_legal_dataset = legal_dataset.train_test_split(test_size=0.1)

In [None]:
prepared_legal_dataset

DatasetDict({
    train: Dataset({
        features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
        num_rows: 76
    })
    test: Dataset({
        features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
        num_rows: 9
    })
})
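
The 76/9 split follows from the arithmetic on 85 rows. This small sketch assumes the test count is rounded up and the remainder goes to the training set, which matches the counts shown above.

```python
import math

n_rows = 85      # num_rows reported for the full dataset
test_size = 0.1  # fraction passed to train_test_split

# Assumption: the test count is rounded up, the rest is used for training.
n_test = math.ceil(test_size * n_rows)
n_train = n_rows - n_test

print(n_train, n_test)
```

Note that the split is shuffled, so the rows that land in the test set will differ from run to run unless a `seed` is passed to `train_test_split`.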

## Loading Our Model

Now we can move onto loading our model!

We're going to depend on two major technologies to train our model with <=16GB of GPU RAM:

1. Quantization
2. LoRA

> NOTE: We've done some events on [LoRA](https://www.youtube.com/watch?v=kV8yXIUC5_4&list=PLrSHiQgy4VjGMzyXsSlvN-TjPaqFFsAGP&index=4) and [QLoRA](https://www.youtube.com/watch?v=XOb-djcw6hs&list=PLrSHiQgy4VjGMzyXsSlvN-TjPaqFFsAGP&index=5) if you'd like deeper dives into those respective technologies.
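
To see why quantization matters for a 16GB budget, here is a back-of-the-envelope sketch of the memory taken by the weights alone. The parameter count is an assumed approximate figure for Mistral-7B, and the estimate ignores activations, optimizer state, and the small overhead of quantization constants.

```python
n_params = 7.24e9  # assumed approximate parameter count for Mistral-7B

def weights_gb(bits_per_param):
    """Memory for the weights alone, in GB (10^9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

fp16_gb = weights_gb(16)
nf4_gb = weights_gb(4)

print(f"fp16 weights:  {fp16_gb:.1f} GB")
print(f"4-bit weights: {nf4_gb:.1f} GB")
```

At 16 bits the weights alone nearly fill the budget, while at 4 bits they leave room for the LoRA adapters, gradients, and activations.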

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)


We'll load our tokenizer and do a few pre-processing steps to prepare it for training.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"


Now we can set up our LoRA configuration - which will let the TRL library know how to create our LoRA adapters!

In [None]:
from peft import LoraConfig

peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
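
To get a feel for what `r=32` on `"all-linear"` means in practice, here is a rough count of the trainable parameters this configuration adds. The layer shapes are assumptions taken from Mistral-7B's published configuration (not from this notebook), and the sketch assumes `"all-linear"` targets the seven linear layers in each decoder block but not the output head.

```python
r, lora_alpha = 32, 64

# Assumed Mistral-7B shapes (from its published config, not from this notebook).
hidden, intermediate, kv_dim, n_layers = 4096, 14336, 1024, 32

# (in_features, out_features) of each linear layer in one decoder block.
linear_shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# LoRA adds two small matrices per layer: A (r x in) and B (out x r).
per_block = sum(r * (d_in + d_out) for d_in, d_out in linear_shapes.values())
trainable = per_block * n_layers
scaling = lora_alpha / r  # factor applied to the low-rank update

print(f"trainable LoRA parameters: {trainable:,}")
print(f"update scaling (lora_alpha / r): {scaling}")
```

Under these assumptions the adapters hold roughly 84 million parameters, a little over 1% of the base model.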

### Fine-tuning!

Now onto the star of today's show: Fine-tuning!

We're going to use the `SFTTrainer` or "Supervised Fine-tuning Trainer" from the [TRL](https://github.com/huggingface/trl/tree/main) library today.

In essence, this is a trainer that will handle most of the fine-tuning process for us - including but not limited to:

- Dataset packing
- LoRA initialization
- Tokenizing

Let's set up some training hyper-parameters through the transformers `TrainingArguments` class to get started. Here's a breakdown of which parameters are doing what:

- `output_dir` - indicates where we store the results of training locally
- `num_train_epochs` - how many epochs we will train for (somewhere ~3-4 is a good place to start)
- `per_device_train_batch_size` - how many examples go into each batch on each device. In this case, we only have one device - but we set this to a low value to keep memory consumption below 16GB GPU RAM
- `gradient_accumulation_steps` - how many batches we accumulate gradients over before each optimizer update. This is a way to "cheat out" a larger effective batch size without the extra memory consumption
- `gradient_checkpointing` - this lets us [trade increased training time for lower memory consumption](https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing)
- `optim` - our optimizer! In this case, we're using a fused AdamW optimizer. The fused implementation is a faster version of AdamW but is reliant on CUDA (GPU). More information can be read [here](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)

The rest of the hyper-parameters are taken directly from the QLoRA [paper](https://arxiv.org/abs/2305.14314) and are discussed in more detail there!
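
As a quick illustration of the gradient accumulation point, the sketch below computes the effective batch size for the values used in this notebook and checks, on made-up per-sample "gradients", that averaging two accumulated micro-batches gives the same result as one larger batch. The numbers in `samples` are placeholders, not real gradients.

```python
per_device_train_batch_size = 2
gradient_accumulation_steps = 2
n_devices = 1  # single GPU, as in this notebook

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * n_devices
)

# Made-up scalar stand-ins for per-sample gradients.
samples = [1.0, 2.0, 3.0, 4.0]
micro_batches = [samples[0:2], samples[2:4]]

# Accumulate the mean gradient of each micro-batch, then average.
accumulated = sum(sum(mb) / len(mb) for mb in micro_batches) / len(micro_batches)
# One pass over all four samples at once.
full_batch = sum(samples) / len(samples)

print(effective_batch_size, accumulated, full_batch)
```

Only one micro-batch of activations has to live in memory at a time, which is where the saving comes from.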

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="leagaleasy-mistral-7b-instruct-v0.2-v1",
    num_train_epochs=4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,
)

Because we're going to automatically push our model to the hub, thanks to `push_to_hub=True`, we'll want to provide a Hugging Face Write token.

> NOTE: You can skip this step by commenting out `push_to_hub=True`

In [None]:
from huggingface_hub import notebook_login

notebook_login()


Now, finally, we can set up our `SFTTrainer`, which is going to help us fine-tune this model on the dataset we created at the beginning of the notebook!

We'll discuss a few parameters to clarify what they're doing:

- `formatting_func` - since we created a helper function to convert our dataset row into a Mistral-style Instruction prompt, we need to let TRL know to use it to create our prompts!
- `peft_config` - TRL will automatically load our model in LoRA format with this config.
- `packing` - this will "pack" multiple examples into each training sequence, up to `max_seq_length`, so we're not wasting precious memory on padding tokens where possible
- `dataset_kwargs` - because we already added the special tokens to our prompts, we want to ensure we don't "re-add" them!

With those parameters set - we're clear for training!
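
To make the `packing` idea concrete, here is a minimal sketch of the general technique: concatenate the tokenized examples into one stream and cut it into fixed-length blocks. This is an illustration only, not TRL's implementation. `pack` is a hypothetical helper, the token ids are made up, and dropping the trailing partial block is a simplification of this sketch.

```python
def pack(tokenized_examples, seq_length):
    """Concatenate token-id lists and cut the stream into full blocks."""
    stream = [tok for example in tokenized_examples for tok in example]
    n_blocks = len(stream) // seq_length  # a trailing partial block is dropped
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(n_blocks)]

# Made-up token ids: 1 stands in for <s>, 2 for </s>.
examples = [[1, 10, 11, 2], [1, 20, 2], [1, 30, 31, 32, 2]]
blocks = pack(examples, seq_length=4)
print(blocks)
```

Notice that an example can straddle two blocks, and that no block contains padding. This is also why the number of training steps depends on the total token count rather than on the number of rows.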

In [None]:
from trl import SFTTrainer

max_seq_length=2048

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=prepared_legal_dataset["train"],
    formatting_func=create_instruction,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens" : False,
        "append_concat_token" : False,
    }
)



All that's left to do is fine-tune our model!

In [None]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


TrainOutput(global_step=12, training_loss=0.8255784511566162, metrics={'train_runtime': 83.6627, 'train_samples_per_second': 0.526, 'train_steps_per_second': 0.143, 'total_flos': 3889889670070272.0, 'train_loss': 0.8255784511566162, 'epoch': 4.0})

Now we can save it.

In [None]:
trainer.save_model()


Let's clear up memory so we can do inference while staying under our memory budget.

In [None]:
del model
del trainer
torch.cuda.empty_cache()

We'll need to load our model back as a PEFT model, due to the use of LoRA, and then merge the LoRA layers back into the original model for use in Hugging Face pipelines.

In [None]:
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

merged_model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
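
Conceptually, merging folds the low-rank update into the base weight, so the merged layer computes the same output as the base layer plus the adapter path. The sketch below checks this on tiny hand-written matrices using plain Python lists. The values are arbitrary, and the helpers `matmul` and `add_scaled` are written just for this illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add_scaled(w, delta, scale):
    """Return w + scale * delta, element-wise."""
    return [[wi + scale * di for wi, di in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]

r, lora_alpha = 1, 2
W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2 x 2)
B = [[1.0], [2.0]]            # LoRA B matrix (2 x r)
A = [[0.5, -0.5]]             # LoRA A matrix (r x 2)
x = [[3.0], [1.0]]            # input column vector

scale = lora_alpha / r

# Adapter path: W x + scale * B (A x)
adapter_out = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scale)

# Merged path: (W + scale * B A) x
W_merged = add_scaled(W, matmul(B, A), scale)
merged_out = matmul(W_merged, x)

print(adapter_out, merged_out)
```

Because the two paths agree, the merged model needs no PEFT wrapper at inference time and can be used like any other causal language model.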

Now we can load our pipeline for `text-generation`.

In [None]:
from transformers import pipeline

text_pipe = pipeline("text-generation", merged_model, tokenizer=tokenizer)

## Testing Fine-tuned Model

Now that we've fine-tuned, let's see how we did!

In [None]:
prepared_legal_dataset["test"][0]["original_text"]

'we do our best to keep facebook safe but we cannot guarantee it. we need your help to keep facebook safe which includes the following commitments by you you will not post unauthorized commercial communications such as spam on facebook. you will not collect users content or information or otherwise access facebook using automated means such as harvesting bots robots spiders or scrapers without our prior permission. you will not engage in unlawful multi level marketing such as a pyramid scheme on facebook. you will not upload viruses or other malicious code. you will not solicit login information or access an account belonging to someone else. you will not bully intimidate or harass any user. you will not post content that is hate speech threatening or pornographic incites violence or contains nudity or graphic or gratuitous violence. you will not develop or operate a third party application containing alcohol related dating or other mature content including advertisements without appro

In [None]:
outputs = text_pipe(create_instruction(prepared_legal_dataset["test"][0], return_response=False), max_new_tokens=256, temperature=0.1, top_k=50, top_p=0.1, return_full_text=False)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
outputs

[{'generated_text': 'you will not spam post hate speech or hack facebook.'}]

In [None]:
prepared_legal_dataset["test"][0]["reference_summary"]

'don t break these detailed lists of rules. basically don t do anything that negatively impacts the facebook platform or community.'
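
A note on the generation settings: `temperature`, `top_k` and `top_p` only matter when sampling is enabled, and a `top_p` as low as 0.1 keeps the output close to greedy decoding. The sketch below illustrates the top-p (nucleus) filtering idea on a made-up next-token distribution. `nucleus` is a hypothetical helper for illustration, not the library's implementation.

```python
def nucleus(probs, top_p):
    """Keep the smallest set of most likely tokens whose total probability reaches top_p."""
    kept, cumulative = [], 0.0
    for token, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept.append(token)
        cumulative += p
        if cumulative >= top_p:
            break
    return kept

# Made-up next-token distribution.
probs = {"you": 0.6, "users": 0.25, "don": 0.1, "the": 0.05}

narrow = nucleus(probs, top_p=0.1)
wide = nucleus(probs, top_p=0.9)
print(narrow, wide)
```

With `top_p=0.1`, only the single most likely token survives here, so the summaries are essentially deterministic. A larger value would allow more varied wording.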

Another example!

In [None]:
prepared_legal_dataset["test"][1]["original_text"]

'you agree not to use the sites to 1. try to gain unauthorized access to any portion of the sites or any other systems or networks connected to it or to any tldr server or to any of the content offered on or through the sites by circumventing the site s access control measures either by hacking password mining or any other means 2. take any action that imposes an unreasonable or disproportionately large load on the infrastructure of the sites or tldr s systems or networks or any systems or networks connected to the sites or to tldr 3. post illegal material or use the sites for illegal activity 4. post or use the sites to distribute junk mail spam chain letters pyramid schemes phishing or other unsolicited advertising or promotion or material without significant value to the community designed to drive traffic mask its source or deceive as to authorship distribute viruses trojans or other malware or whose purpose is affiliate marketing 5. remove circumvent disable damage or otherwise in

In [None]:
outputs = text_pipe(create_instruction(prepared_legal_dataset["test"][1], return_response=False), max_new_tokens=256, temperature=0.1, top_k=50, top_p=0.1, return_full_text=False)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
outputs

[{'generated_text': 'you agree not to hack the site.'}]

In [None]:
prepared_legal_dataset["test"][1]["reference_summary"]

'no illegal shady stuff with tldrlegal or usage of the site brand in ways not intended.'
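
Eyeballing two examples only goes so far. As a rough sketch of how the comparison could be made more systematic, the snippet below computes a simple unigram-overlap F1 between each generated summary and its reference. This is a crude, hand-rolled stand-in for a proper summarization metric, `unigram_f1` is a hypothetical helper, and two examples are far too few to draw conclusions from.

```python
import re
from collections import Counter

def unigram_f1(candidate, reference):
    """F1 over word overlap between a candidate and a reference summary."""
    cand = Counter(re.findall(r"\w+", candidate.lower()))
    ref = Counter(re.findall(r"\w+", reference.lower()))
    overlap = sum((cand & ref).values())  # words shared by both, with multiplicity
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

# (generated summary, reference summary) pairs copied from the outputs above.
pairs = [
    ("you will not spam post hate speech or hack facebook.",
     "don t break these detailed lists of rules. basically don t do anything "
     "that negatively impacts the facebook platform or community."),
    ("you agree not to hack the site.",
     "no illegal shady stuff with tldrlegal or usage of the site brand in ways not intended."),
]

scores = [unigram_f1(candidate, reference) for candidate, reference in pairs]
for score in scores:
    print(f"{score:.2f}")
```

Word overlap is low for both pairs even though the generated summaries read plausibly, which is typical for very short, freely worded summaries and a good reason to also inspect outputs by hand.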

Overall, our fine-tuning worked well: for both test examples, the model produced short, plain-language summaries in the style of the reference summaries - all with <=16GB GPU memory, and only 4 epochs of fine-tuning!