# LegalEasy: Fine-tuning a Model to Produce Natural Language Descriptions of Legalese

Today we'll be looking at a specific example of fine-tuning [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2): Summarization in a particular style for a particular domain!

Let's start, as we always do, with installing our dependencies.



In [1]:
# !pip install -qU transformers peft trl accelerate bitsandbytes datasets

Next, let's set up some data!

## Data

We'll be using the Legal Summarization [dataset](https://github.com/lauramanor/legal_summarization) from the paper [Plain English Summarization of Contracts](https://www.aclweb.org/anthology/W19-2201) today.

This dataset contains pairs in the following format:

- Original Text: A blob of legal text, think ToS
- Reference Summary: A short natural language summary of the legal text

We'll start by cloning the repository containing our data.

In [2]:
# !git clone https://github.com/lauramanor/legal_summarization.git

Let's convert this into an expected format - in this case a list of `json` objects.

In [3]:
import json

jsonl_array = []

with open('legal_summarization/tldrlegal_v1.json') as f:
  data = json.load(f)
  for key, value in data.items():
    jsonl_array.append(value)

In [None]:
jsonl_array

Now we can convert that into the desired format for our fine-tuning - a Hugging Face `Dataset`!

In [4]:
from datasets import Dataset

legal_dataset = Dataset.from_list(jsonl_array)

Let's see how many items we're working with in our dataset.

> NOTE: Keep in mind that this is a relatively simple example meant to showcase fine-tuning. In practice, we'd want somewhere in the range of 500 to 50,000 examples.

In [5]:
legal_dataset

Dataset({
    features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
    num_rows: 85
})

Let's look at an example of our original text and summary!

In [6]:
print(f"Original Text: {legal_dataset[0]['original_text']}\n\nSummary: {legal_dataset[0]['reference_summary']}")

Original Text: welcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.

Summary: hi.


Now, we mentioned earlier we were going to fine-tune Mistral-7B-Instruct-v0.2, which is important for our next step: Creating the instruction format.

Let's take a look at the instructions (so meta) for building an instruction prompt, taken from the [model card](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format):


> In order to leverage instruction fine-tuning, your prompt should be surrounded by [INST] and [/INST] tokens. The very first instruction should begin with a begin of sentence id. The next instructions should not. The assistant generation will be ended by the end-of-sentence token id.
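
To make the quoted rule concrete, here is a minimal sketch of a multi-turn formatter. The function name `build_mistral_prompt`, the `(user, assistant)` tuple layout, and the example messages are hypothetical choices for illustration, and the exact whitespace around the tags may differ from what the official chat template produces.

```python
BOS, EOS = "<s>", "</s>"

def build_mistral_prompt(turns):
    """Format (user, assistant) turns following the rule quoted above.

    `turns` is a list of (user_message, assistant_reply) tuples; pass
    None as the final reply to leave the prompt open for generation.
    """
    prompt = BOS  # the begin-of-sentence token appears once, at the very start
    for user_message, assistant_reply in turns:
        prompt += f"[INST] {user_message} [/INST]"
        if assistant_reply is not None:
            # each completed assistant reply is closed by the end-of-sentence token
            prompt += f" {assistant_reply}{EOS}"
    return prompt

# Hypothetical two-turn conversation; the second turn awaits a model reply.
example = build_mistral_prompt([
    ("Summarize clause 1.", "You must be 13 or older."),
    ("Summarize clause 2.", None),
])
print(example)
```

Only the first instruction is preceded by `<s>`, and each finished assistant reply ends with `</s>`, which is the same structure we'll use for our single-turn training prompts.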

Let's look at an example of how we might format our instruction - and then reproduce that in code.

```text
<s>[INST]Please convert the following legal content into a human-readable summary

[LEGAL_DOC]
welcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.
[END_LEGAL_DOC]

[/INST]
hi.</s>
```

> NOTE: We're adding our own delimiters here, `[LEGAL_DOC]` and `[END_LEGAL_DOC]`, to help the model locate the legal text in our prompt. These are not special tokens that the model already understands; the tokenizer treats them as ordinary text.

In [7]:
INSTRUCTION_PROMPT_TEMPLATE = """\
<s>[INST]Please convert the following legal content into a human-readable summary

[LEGAL_DOC]
{LEGAL_TEXT}
[END_LEGAL_DOC]

[/INST]
"""

RESPONSE_TEMPLATE = """\
{NATURAL_LANGUAGE_SUMMARY}</s>"""

Now we can create a helper function that will convert our dataset row into the above prompt!

In [8]:
def create_instruction(sample, return_response=True):
  prompt = INSTRUCTION_PROMPT_TEMPLATE.format(LEGAL_TEXT=sample["original_text"])

  if return_response:
    prompt += RESPONSE_TEMPLATE.format(NATURAL_LANGUAGE_SUMMARY=sample["reference_summary"])

  return prompt

Let's try it out!

In [10]:
create_instruction(legal_dataset[0])

'<s>[INST]Please convert the following legal content into a human-readable summary\n\n[LEGAL_DOC]\nwelcome to the pokémon go video game services which are accessible via the niantic inc niantic mobile device application the app. to make these pokémon go terms of service the terms easier to read our video game services the app and our websites located at http pokemongo nianticlabs com and http www pokemongolive com the site are collectively called the services. please read carefully these terms our trainer guidelines and our privacy policy because they govern your use of our services.\n[END_LEGAL_DOC]\n\n[/INST]\nhi.</s>'

We'll partition our dataset so we can test some of the outputs after we've completed our training.

In [11]:
prepared_legal_dataset = legal_dataset.train_test_split(test_size=0.1)

In [12]:
prepared_legal_dataset

DatasetDict({
    train: Dataset({
        features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
        num_rows: 76
    })
    test: Dataset({
        features: ['doc', 'id', 'original_text', 'reference_summary', 'title', 'uid'],
        num_rows: 9
    })
})
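
The 76/9 split above can be reproduced with a little arithmetic. This is a minimal sketch that assumes the test partition is rounded up to a whole row, which is consistent with the sizes printed above; `split_sizes` is a hypothetical helper, not part of `datasets`.

```python
import math

def split_sizes(num_rows, test_size):
    """Return (train_rows, test_rows) for a fractional test_size.

    Assumes the test partition is rounded up, which matches the
    76/9 split shown above for 85 rows.
    """
    test_rows = math.ceil(num_rows * test_size)
    return num_rows - test_rows, test_rows

train_rows, test_rows = split_sizes(85, 0.1)
print(train_rows, test_rows)
```

Note that `train_test_split` shuffles by default, so which rows land in each partition changes between runs unless a `seed` is passed, for example `train_test_split(test_size=0.1, seed=42)` (42 is an arbitrary choice).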

## Loading Our Model

Now we can move on to loading our model!

We're going to depend on two major techniques to allow us to train our model with <=16GB of GPU RAM:

1. Quantization
2. LoRA

> NOTE: We've done some events on [LoRA](https://www.youtube.com/watch?v=kV8yXIUC5_4&list=PLrSHiQgy4VjGMzyXsSlvN-TjPaqFFsAGP&index=4) and [QLoRA](https://www.youtube.com/watch?v=XOb-djcw6hs&list=PLrSHiQgy4VjGMzyXsSlvN-TjPaqFFsAGP&index=5) for deeper dives into those respective technologies
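
To see why quantization matters for our memory budget, here is a rough back-of-the-envelope sketch. It assumes a parameter count of about 7.24 billion for Mistral-7B and counts only the weights; activations, gradients, optimizer state, and the small overhead of quantization constants are ignored.

```python
# Approximate parameter count for Mistral-7B (assumption: ~7.24 billion).
NUM_PARAMS = 7.24e9

def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to hold the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
nf4_gb = weight_memory_gb(NUM_PARAMS, 4)

print(f"fp16 weights : ~{fp16_gb:.1f} GB")
print(f"4-bit weights: ~{nf4_gb:.1f} GB")
```

Under these assumptions, the half-precision weights alone would nearly fill a 16GB card, while the 4-bit weights leave room for the LoRA adapters and training state.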

In [15]:
!pip install accelerate
!pip install -i https://pypi.org/simple/ bitsandbytes


Looking in indexes: https://pypi.org/simple/


In [16]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

> NOTE: If this cell raises `ImportError: Using bitsandbytes 8-bit quantization requires Accelerate`, install `accelerate` and `bitsandbytes` with the cell above and re-run this cell. You may also need to restart the notebook runtime so the newly installed packages are picked up.

We'll load our tokenizer and do a few pre-processing steps to prepare it for training.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"

Now we can set up our LoRA configuration - which will let the TRL library know how to create our LoRA adapters!

In [None]:
from peft import LoraConfig

peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
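
To get a feel for what `r=32` and `lora_alpha=64` mean, here is a small sketch of the parameter arithmetic for a single linear layer. It assumes a square 4096 x 4096 projection (4096 being Mistral-7B's hidden size); other layers targeted by `"all-linear"` have different shapes, and `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one linear layer.

    LoRA learns two low-rank matrices, A (r x d_in) and B (d_out x r).
    """
    return r * d_in + d_out * r

# Assumption: a square 4096 x 4096 projection, as in Mistral-7B's hidden size.
d_in = d_out = 4096
r, lora_alpha = 32, 64

full_params = d_in * d_out
adapter_params = lora_param_count(d_in, d_out, r)
scaling = lora_alpha / r  # the low-rank update is multiplied by alpha / r

print(f"full layer : {full_params:,} parameters")
print(f"LoRA update: {adapter_params:,} parameters ({adapter_params / full_params:.2%})")
print(f"scaling    : {scaling}")
```

For this layer shape the adapter trains well under 2% of the parameters the full weight matrix holds, which is where most of LoRA's memory savings during training come from.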

### Fine-tuning!

Now onto the star of today's show: Fine-tuning!

We're going to use the `SFTTrainer` or "Supervised Fine-tuning Trainer" from the [TRL](https://github.com/huggingface/trl/tree/main) library today.

In essence, this is a trainer that will handle most of the fine-tuning process for us - including but not limited to:

- Dataset packing
- LoRA initialization
- Tokenizing

Let's set up some training hyper-parameters through the `TrainingArguments` class from `transformers` to get started. Here's a breakdown of which parameters are doing what:

- `output_dir` - indicates where we store the results of training locally
- `num_train_epochs` - how many epochs we will train for (somewhere ~3-4 is a good place to start)
- `per_device_train_batch_size` - how many samples go into each batch on each device. In this case, we only have one device, but we set this to a low value to keep memory consumption below 16GB GPU RAM
- `gradient_accumulation_steps` - how many steps we accumulate gradients for before each optimizer update. This is a way to "cheat out" a larger effective batch size without the memory cost of a larger per-step batch
- `gradient_checkpointing` - this lets us [trade off memory consumption for increased training time](https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing)
- `optim` - our optimizer! In this case, we're using a fused AdamW optimizer. The fused implementation is faster than the default AdamW implementation but relies on CUDA (GPU). More information can be read [here](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)

The rest of the hyper-parameters are taken directly from the QLoRA [paper](https://arxiv.org/abs/2305.14314) and are discussed in more detail there!
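
As a quick illustration of the gradient accumulation idea described above, the sketch below uses a toy one-parameter regression (made-up data, not our legal dataset) to show that averaging gradients over two micro-batches of 2 samples matches the gradient of a single batch of 4. It assumes a single GPU, as in this notebook.

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

per_device_train_batch_size = 2
gradient_accumulation_steps = 2
num_devices = 1  # assumption: a single GPU, as in this notebook

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)

# Toy data: four samples, processed either at once or as two micro-batches.
xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], 0.5

full_batch_grad = mse_gradient(w, xs, ys)

micro_grads = [
    mse_gradient(
        w,
        xs[i:i + per_device_train_batch_size],
        ys[i:i + per_device_train_batch_size],
    )
    for i in range(0, len(xs), per_device_train_batch_size)
]
# Averaging the micro-batch gradients reproduces the full-batch gradient.
accumulated_grad = sum(micro_grads) / gradient_accumulation_steps

print(effective_batch_size, full_batch_grad, accumulated_grad)
```

Only one micro-batch needs to be held in memory at a time, so the optimizer sees an effective batch size of 4 while the GPU only ever processes 2 samples per step.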

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="leagaleasy-mistral-7b-instruct-v0.2-v1",
    num_train_epochs=4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,
)

Because we're going to automatically push our model to the hub, thanks to `push_to_hub=True`, we'll want to provide a Hugging Face token with write access.

> NOTE: You can skip this step by commenting out `push_to_hub=True`

In [None]:
from huggingface_hub import notebook_login

notebook_login()

Now, finally, we can set up our `SFTTrainer`, which is going to help us fine-tune this model on the dataset we created at the beginning of the notebook!

We'll discuss a few parameters to clarify what they're doing:

- `formatting_func` - since we created a helper function to convert our dataset row into a Mistral-style Instruction prompt, we need to let TRL know to use it to create our prompts!
- `peft_config` - TRL will automatically wrap our model with LoRA adapters using this config.
- `packing` - this will let our model "pack" the context window to ensure we're not wasting precious memory on padding tokens where possible
- `dataset_kwargs` - because we already added the special tokens to our prompts, we want to ensure we don't "re-add" them!
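
To illustrate what `packing` does conceptually, here is a simplified sketch: tokenized examples are concatenated into one stream and cut into blocks of `max_seq_length`. The function `pack_sequences` and the token ids are hypothetical, and TRL's actual implementation may differ in details such as how leftover tokens are handled.

```python
def pack_sequences(tokenized_examples, max_seq_length):
    """Concatenate token lists and cut them into fixed-length blocks.

    Simplified: a trailing remainder shorter than max_seq_length is dropped.
    """
    stream = [tok for example in tokenized_examples for tok in example]
    usable = len(stream) - len(stream) % max_seq_length
    return [stream[i:i + max_seq_length] for i in range(0, usable, max_seq_length)]

# Hypothetical token ids for three short examples of lengths 3, 5 and 4.
examples = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12]]
packed = pack_sequences(examples, max_seq_length=4)
print(packed)
```

Padding each of these three examples to the longest length (5) would take 15 token slots, 3 of them padding. Packing fits the same 12 tokens into three full blocks with no padding at all.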

With those parameters set - we're clear for training!

In [None]:
from trl import SFTTrainer

max_seq_length = 2048

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=prepared_legal_dataset["train"],
    formatting_func=create_instruction,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens" : False,
        "append_concat_token" : False,
    }
)

All that's left to do is fine-tune our model!

In [None]:
trainer.train()

Now we can save it.

In [None]:
trainer.save_model()

Let's clear up memory so we can do inference while staying under our memory budget.

In [None]:
del model
del trainer
torch.cuda.empty_cache()

We'll need to load our model back as a PEFT model, due to the use of LoRA, and then merge the LoRA layers back into the original model for use in Hugging Face pipelines.

In [None]:
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

merged_model = model.merge_and_unload()
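
Conceptually, merging folds the low-rank update into the base weights: W' = W + (alpha / r) * B A. The sketch below checks this on a toy 2x2 layer with made-up values, using plain Python lists; it illustrates the arithmetic only and is not how PEFT implements `merge_and_unload()` internally.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(w, lora_a, lora_b, scaling):
    """Return W + scaling * (B @ A), the merged weight matrix."""
    delta = matmul(lora_b, lora_a)
    return [
        [w_ij + scaling * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]

# Toy 2x2 layer with a rank-1 adapter (all values are made up).
w = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 2.0]]      # shape (r=1, d_in=2)
lora_b = [[0.5], [-0.5]]   # shape (d_out=2, r=1)
scaling = 2.0              # lora_alpha / r

x = [[3.0], [4.0]]         # one input column vector

# Adapter kept separate: W x + scaling * B (A x)
base = matmul(w, x)
update = matmul(lora_b, matmul(lora_a, x))
separate = [[b[0] + scaling * u[0]] for b, u in zip(base, update)]

# Adapter merged into the weights: (W + scaling * B A) x
merged = matmul(merge_lora(w, lora_a, lora_b, scaling), x)

print(separate, merged)
```

Both paths give the same output, which is why the merged model can be used like any ordinary model without the PEFT wrapper.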

Now we can load our pipeline for `text-generation`.

In [None]:
from transformers import pipeline

text_pipe = pipeline("text-generation", merged_model, tokenizer=tokenizer)

## Testing Fine-tuned Model

Now that we've fine-tuned, let's see how we did!

In [None]:
prepared_legal_dataset["test"][0]["original_text"]

In [None]:
outputs = text_pipe(
    create_instruction(prepared_legal_dataset["test"][0], return_response=False),
    max_new_tokens=256,
    do_sample=True,  # needed for temperature, top_k and top_p to take effect
    temperature=0.1,
    top_k=50,
    top_p=0.1,
    return_full_text=False,
)

In [None]:
outputs

In [None]:
prepared_legal_dataset["test"][0]["reference_summary"]

Another example!

In [None]:
prepared_legal_dataset["test"][1]["original_text"]

In [None]:
outputs = text_pipe(
    create_instruction(prepared_legal_dataset["test"][1], return_response=False),
    max_new_tokens=256,
    do_sample=True,  # needed for temperature, top_k and top_p to take effect
    temperature=0.1,
    top_k=50,
    top_p=0.1,
    return_full_text=False,
)

In [None]:
outputs

In [None]:
prepared_legal_dataset["test"][1]["reference_summary"]
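
Eyeballing outputs next to references is a fine start, but a rough number can help when comparing runs. Below is a minimal sketch of a ROUGE-1-style unigram F1 score. It is a simplified stand-in (whitespace tokenization, no stemming), not an official ROUGE implementation and not a metric used elsewhere in this notebook; `unigram_f1` and the example strings are hypothetical.

```python
from collections import Counter

def unigram_f1(candidate, reference):
    """ROUGE-1-style F1: overlap of whitespace-separated, lowercased tokens."""
    cand_tokens = candidate.lower().split()
    ref_tokens = reference.lower().split()
    if not cand_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((Counter(cand_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(cand_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

# Made-up strings for illustration; not outputs from the fine-tuned model.
score = unigram_f1("you must be 13 to play", "you must be at least 13")
print(round(score, 3))
```

In the notebook, this could be called as `unigram_f1(outputs[0]["generated_text"], prepared_legal_dataset["test"][1]["reference_summary"])`. With only 9 test rows, any such score is a sanity check rather than a reliable benchmark.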

Overall, our fine-tuning did a great job and allowed our model to generate our desired output - all with <16GB GPU memory, and 4 epochs of fine-tuning!