This notebook provides a basic illustration of how to use different parts of LegalBench. 

In [None]:
import os
import torch
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="/home/jye/huggingface/pretrained_model/qwen/Qwen1.5-7B-Chat", trust_remote_code=True, dtype=torch.float16)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

In [None]:
from tqdm.auto import tqdm
import datasets

from tasks import TASKS, ISSUE_TASKS
from utils import generate_prompts

In [None]:
# Suppress progress bars which appear every time a task is downloaded
datasets.utils.logging.set_verbosity_error()

### Task organization

`tasks.py` provides data structures which organize all LegalBench tasks. For instance, `TASKS` lists all LegalBench tasks, and `ISSUE_TASKS` lists all tasks in the issue-spotting reasoning category.

In [None]:
print(len(TASKS), TASKS[:10])
print()
print(len(ISSUE_TASKS), ISSUE_TASKS)

### Loading task data

LegalBench can be downloaded from Hugging Face: https://huggingface.co/datasets/nguha/legalbench. Each LegalBench dataset comes with a `train` split and a `test` split.

- The `train` split is small (usually fewer than 10 samples). Following the [RAFT](https://raft.elicit.org/) benchmark, it's intended to provide labeled samples that can be used as few-shot demonstrations for prompts.
- The `test` split is larger, and contains samples to evaluate an LLM on. 

Documentation for each task can be found in the GitHub repository, under the task-specific folder. For instance, the documentation for the `abercrombie` task can be found at <https://github.com/HazyResearch/legalbench/tree/main/tasks/abercrombie>.

In [None]:
dataset = datasets.load_dataset("nguha/legalbench", "abercrombie")
dataset["train"].to_pandas()
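
The few-shot use of the `train` split described above can be sketched as follows. This is only an illustration: the column names `text` and `answer` are assumed, the `Text:`/`Label:` layout is a hypothetical format rather than LegalBench's own (each task folder ships its own prompt templates), and `build_few_shot_prompt` is not part of the repository. The toy rows stand in for `dataset["train"]`, whose rows can be iterated as dictionaries in the same way.

```python
def build_few_shot_prompt(train_rows, query_text, text_col="text", label_col="answer"):
    """Format labeled rows as demonstrations, then append the unanswered query."""
    blocks = [
        f"Text: {row[text_col]}\nLabel: {row[label_col]}"
        for row in train_rows
    ]
    # The final block leaves the label empty for the model to complete.
    blocks.append(f"Text: {query_text}\nLabel:")
    return "\n\n".join(blocks)


# Toy stand-in for dataset["train"]; these rows are made up for illustration.
toy_train = [
    {"text": "first example", "answer": "label_a"},
    {"text": "second example", "answer": "label_b"},
]
few_shot_prompt = build_few_shot_prompt(toy_train, "a new example")
print(few_shot_prompt)
```

Because the `train` split usually has fewer than 10 samples, using every row as a demonstration keeps the prompt short.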

### Loading and applying prompts

Each task folder also stores prompt templates which can be used with different models. In LegalBench, prompt templates are represented as text files, in which `{{col_name}}` denotes a placeholder for the value of the dataset column `col_name`.

For instance:

In [None]:
# Load base prompt
with open("tasks/abercrombie/base_prompt.txt") as in_file:
    prompt_template = in_file.read()
print(prompt_template)

The script `utils.py` provides a simple function for generating prompts for a dataset given a template.

In [None]:
test_df = dataset["test"].to_pandas()
prompts = generate_prompts(prompt_template=prompt_template, data_df=test_df)
print(prompts[0])
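
To make the placeholder convention concrete, here is a minimal sketch of how a `{{col_name}}` template might be filled from one row. It is not the implementation of `generate_prompts` in `utils.py`, which may differ; `fill_template`, the toy template, and the toy row are all hypothetical.

```python
import re


def fill_template(template, row):
    """Replace each {{col_name}} placeholder with the matching value from row."""
    def substitute(match):
        # match.group(1) is the column name between the double braces.
        return str(row[match.group(1)])

    return re.sub(r"\{\{(\w+)\}\}", substitute, template)


# Toy template and row, made up for illustration.
toy_template = "Description: {{text}}\nType:"
toy_row = {"text": "example mark description", "answer": "generic"}
filled = fill_template(toy_template, toy_row)
print(filled)
```

In this sketch a placeholder naming a column that is missing from the row raises a `KeyError`, so template mistakes surface immediately rather than producing a silently broken prompt.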

In [None]:
from transformers import Qwen2ForCausalLM, Qwen2Tokenizer

model_name_or_path = "/home/jye/huggingface/pretrained_model/qwen/Qwen1.5-7B-Chat"

tokenizer = Qwen2Tokenizer.from_pretrained(model_name_or_path)
model = Qwen2ForCausalLM.from_pretrained(model_name_or_path)

text = prompts[0]
inputs = tokenizer(text, return_tensors="pt")
# max_new_tokens bounds only the generated tokens; max_length would also count the prompt
outputs = model.generate(**inputs, max_new_tokens=100)
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
output_text

In [None]:
prompts[0]

### Evaluation

Most LegalBench tasks are evaluated using balanced accuracy. A handful of tasks which involve extraction or multilabel classification are evaluated using F1. To simplify evaluation, we provide an evaluation function which scores performance.

In [None]:
from evaluation import evaluate
import numpy as np

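# Random-prediction baseline for abercrombie (uncomment to use instead of a model)
# classes = ["generic", "descriptive", "suggestive", "arbitrary", "fanciful"]
# generations = np.random.choice(classes, len(test_df))

# Generate one prediction per prompt with the Qwen model loaded above
generations = []
for prompt in tqdm(prompts):
    inputs = tokenizer(prompt, return_tensors="pt")
    # Small token budget, on the assumption that the answer is a single label
    output_ids = model.generate(**inputs, max_new_tokens=10)
    # generate() returns a batch; keep the first sequence and drop the echoed prompt tokens
    new_token_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    generations.append(tokenizer.decode(new_token_ids, skip_special_tokens=True).strip())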

generations
# evaluate("abercrombie", generations, test_df["answer"].tolist())
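
For intuition about the balanced accuracy metric mentioned above, the following sketch computes it as the mean of per-class recall. It assumes exact string matching between predictions and answers; the `evaluate` function in `evaluation.py` may normalize generations or differ in other details, and `balanced_accuracy` here is a hypothetical helper, not the repository's code.

```python
from collections import defaultdict


def balanced_accuracy(predictions, answers):
    """Mean of per-class recall, so every class counts equally regardless of size."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, ans in zip(predictions, answers):
        total[ans] += 1
        if pred == ans:
            correct[ans] += 1
    recalls = [correct[label] / total[label] for label in total]
    return sum(recalls) / len(recalls)


# Toy data, made up for illustration: a model that always answers "generic".
toy_answers = ["generic", "generic", "generic", "fanciful"]
toy_predictions = ["generic", "generic", "generic", "generic"]
score = balanced_accuracy(toy_predictions, toy_answers)
print(score)
```

On this toy data plain accuracy is 0.75, but balanced accuracy is 0.5, because recall is 1.0 on `generic` and 0.0 on `fanciful`. This is why the metric suits tasks with imbalanced label distributions.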

### Selecting tasks by license

LegalBench tasks are covered under different licenses. The following code allows you to select the tasks released under a given license.

In [None]:
target_license = "CC BY 4.0"
tasks_with_target_license = []
for task in tqdm(TASKS):
    dataset = datasets.load_dataset("nguha/legalbench", task, split="train")
    if dataset.info.license == target_license:
        tasks_with_target_license.append(task)
print("Tasks with target license:", tasks_with_target_license)