In [7]:
import sys
import os

# Make the project package and the pygaggle module directory under it importable.
# The membership check keeps sys.path free of duplicate entries when this cell is re-run.
for _extra_path in ("./llms_for_legal", "./llms_for_legal/modules/pygaggle"):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# Tell PyJNIus which libjvm to load. setdefault keeps a JVM_PATH that is already
# exported in the environment, so the machine-specific path below is only a
# fallback and other users do not have to edit the notebook.
os.environ.setdefault("JVM_PATH", "/home/s2210405/jdk-19.0.2/lib/server/libjvm.so")

## Predict cases with monot5-large-10k hard-negative sampling

In [2]:
from eval_monot5 import predict_all_monot5, predict_all_bm25

2024-01-16 11:55:29.808035: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-16 11:55:29.855315: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Compute BM25 scores through the project helper; the result is later indexed as
# bm25_scores[case] and queried per candidate file name.
# eval_segment=None is passed through unchanged (presumably "no segment filter" — TODO confirm).
# NOTE(review): the dataset path is the *train* files while the index path ends in
# "test" — confirm this pairing is intended.
bm25_scores = predict_all_bm25(
    dataset_path="./data/task2_train_files_2024",
    bm25_index_path="./data/bm25_indexes/coliee_task2/test",
    eval_segment=None
)

In [6]:
# Score candidates with the checkpoint stored under ckpt_path; the result is later
# indexed as monot5_scores[case][candidate file name].
# eval_segment=None is passed through unchanged (presumably "no segment filter" — TODO confirm).
monot5_scores = predict_all_monot5(
    ckpt_path="./llms_for_legal/train_logs/monot5-large-10k_ns/ckpt",
    dataset_path="./data/task2_train_files_2024",
    eval_segment=None
)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [8]:
from llms_for_legal.src.data import get_task2_data
import os

# Load the "val" segment of the dataset. The three results are used below as:
#   corpus_dir - path-like root, joined as corpus_dir / case / "paragraphs"
#   cases_dir  - iterable of case ids
#   label_data - mapping from case id to that case's gold paragraph names
corpus_dir, cases_dir, label_data = get_task2_data(
    dataset_path="./data/task2_train_files_2024",
    segment="val"
)


def get_pred_cases(scores, bm25_scores, top_k, margin, alpha):
    """Rank every case's candidate paragraphs and keep the ``top_k`` best.

    Reads the notebook globals ``cases_dir`` (iterable of case ids) and
    ``corpus_dir`` (path-like root that contains ``<case>/paragraphs``).

    Args:
        scores: mapping ``case -> {candidate file name -> model score}``.
        bm25_scores: mapping ``case -> {candidate file name -> BM25 score}``.
        top_k: number of highest-scoring candidates kept per case.
        margin: accepted for interface compatibility but not applied. The
            previous implementation built a margin-filtered list and then
            immediately overwrote it with the plain top-k list; that effective
            behaviour is kept so downstream cells see the same predictions.
        alpha: weight of the model score. For ``alpha == 1`` the model score
            is used alone, except that candidates missing from the BM25 table
            get a score of 0. Otherwise the final score is
            ``alpha * model + (1 - alpha) * bm25`` with a missing BM25 score
            counted as 0.

    Returns:
        dict mapping each case to a list of ``[candidate, score]`` pairs,
        sorted by descending score and truncated to ``top_k`` entries. A case
        without candidates maps to an empty list.
    """
    pred_cases = {}
    for case in cases_dir:
        bm25_score = bm25_scores[case]
        score = scores[case]

        candidate_dir = corpus_dir / case / "paragraphs"

        ranked = []
        for cand_case in sorted(os.listdir(candidate_dir)):
            if alpha == 1:
                # BM25 only acts as a filter here. Every entry must still be a
                # [name, score] pair: the sort below reads index 1, and bare
                # scalars made this branch raise TypeError.
                combined = score[cand_case] if cand_case in bm25_score else 0
            else:
                combined = alpha * score[cand_case] + (1 - alpha) * bm25_score.get(
                    cand_case, 0
                )
            ranked.append([cand_case, combined])

        # Stable sort: candidates with equal scores stay in file-name order.
        ranked.sort(key=lambda pair: -pair[1])
        pred_cases[case] = ranked[:top_k]

    return pred_cases
    
# Build per-case predictions from the interpolated scores: alpha=0.9 weights the
# model score, the remaining 0.1 goes to BM25.
# With top_k=1 every case keeps exactly one candidate, so the
# `len(predictions) >= 2` branch of the zero-shot loop never runs with these settings.
monot5_pred_cases = get_pred_cases(
    scores=monot5_scores, bm25_scores=bm25_scores, top_k=1, margin=0, alpha=0.9
)

2024-01-16 13:45:43 [INFO] env: 
Using override env var JVM_PATH (/home/s2210405/jdk-19.0.2/lib/server/libjvm.so) to load libjvm.
Please report your system information (os version, java
version, etc), and the path that works for you, to the
PyJNIus project, at https://github.com/kivy/pyjnius/issues.
so we can improve the automatic discovery.

2024-01-16 13:45:44 [INFO] loader: Loading faiss with AVX2 support.
2024-01-16 13:45:44 [INFO] loader: Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2024-01-16 13:45:44 [INFO] loader: Loading faiss.
2024-01-16 13:45:44 [INFO] loader: Successfully loaded faiss.
2024-01-16 13:45:45.948476: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-16 13:45:45.950513: I tenso

## Few-shot Learning with LLMs

In [2]:
import json

# Load the score tables from their JSON files. The context managers close each
# file as soon as it has been parsed instead of leaving the handle open until
# garbage collection.
with open("./data/bm25_scores.json") as score_file:
    bm25_scores = json.load(score_file)
with open("./data/monot5_scores.json") as score_file:
    monot5_scores = json.load(score_file)

In [3]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hub identifier of the seq2seq model used for the zero-/few-shot prompts below.
model_checkpoint = "google/flan-t5-xxl"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Load the model with 8-bit weights requested and device placement left to
# transformers/accelerate (device_map="auto").
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint, torch_dtype=torch.float16, load_in_8bit=True, device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [4]:
import os
import re


def load_txt(file_path, skip=0):
    """Read a UTF-8 text file, discarding its leading lines.

    Args:
        file_path: location of the file to read.
        skip: number of lines to drop from the start of the file. Zero or a
            negative value keeps everything; a value larger than the number
            of lines yields an empty string.

    Returns:
        The remaining file content as one string.
    """
    with open(file_path, encoding="utf-8") as handle:
        for _ in range(skip):
            handle.readline()
        return handle.read()


def preprocess_case_data(
    file_path,
    max_length=None,
    min_sentence_length=None,
    uncased=False,
    filter_min_length=None,
):
    """Load a case/paragraph text file and normalise it into a single line.

    Processing order:
      1. read the file, turn newlines into spaces and delete the literal
         markers ``FRAGMENT_SUPPRESSED``, ``FACTUAL``, ``BACKGROUND`` and
         ``ORDER`` (plain substring removal, so the letters are also removed
         when they occur inside a longer word);
      2. optionally lower-case the text;
      3. collapse every run of whitespace into one space;
      4. if a bracketed number such as ``[12]`` occurs, drop everything up to
         and including its first occurrence;
      5. apply the optional length filters described below;
      6. make sure the result ends with a full stop.

    Args:
        file_path: path of the text file.
        max_length: if truthy, keep only the first ``max_length`` words.
        min_sentence_length: if truthy, the text is passed through
            ``filter_document(text, min_sentence_length)``.
        uncased: lower-case the text when true.
        filter_min_length: if truthy, texts with at most this many words are
            rejected.

    Returns:
        The cleaned text, or ``None`` when the file does not exist or the
        text is rejected by ``filter_min_length``.
    """
    if not os.path.exists(file_path):
        return None

    text = load_txt(file_path)

    text = (
        text.strip()
        .replace("\n", " ")
        .replace("FRAGMENT_SUPPRESSED", "")
        .replace("FACTUAL", "")
        .replace("BACKGROUND", "")
        .replace("ORDER", "")
    )
    if uncased:
        text = text.lower()

    # split()/join collapses all whitespace runs and trims both ends in one
    # pass; it replaces the previous regex substitution that was followed by
    # an identical, redundant split/join.
    text = " ".join(text.split())

    # Raw string: "\[" in a normal string literal is an invalid escape
    # sequence and triggers a warning on recent Python versions.
    cite_number = re.search(r"\[[0-9]+\]", text)
    if cite_number:
        text = text[cite_number.end():].strip()

    if filter_min_length:
        if len(text.split()) <= filter_min_length:
            return None

    if min_sentence_length:
        # NOTE(review): filter_document is not defined in the visible part of
        # this notebook; it has to be in scope before this option is used.
        text = filter_document(text, min_sentence_length)
    if max_length:
        text = " ".join(text.split()[:max_length])
    if not text.endswith("."):
        text = text + "."
    return text


def format_output(text):
    """Remove ``<...>`` tag-like spans from ``text`` and normalise the rest.

    Returns the remaining text with surrounding whitespace trimmed and all
    characters lower-cased.
    """
    tag_pattern = re.compile("<.*?>")
    without_tags = tag_pattern.sub("", text)
    return without_tags.strip().lower()

### Zero-shot Learning

In [9]:
from tqdm.notebook import tqdm
import random

# Zero-shot prompt: the first slot receives the numbered candidate documents
# (one per line), the second slot the query text.
# NOTE(review): "bellow" looks like a typo for "below"; it is left untouched
# because this literal is sent to the model verbatim.
zero_shot_prompt_template = "In bellow documents:\n{}\nQuestion: which documents really relevant to query '{}'?"


def zero_short_generate_prompt(query, candidates):
    """Render the zero-shot prompt for one query and its candidates.

    Args:
        query: query text inserted into the prompt template.
        candidates: sequence of ``(identifier, text)`` pairs.

    Returns:
        ``(prompt, document_map)`` where ``document_map[i]`` is the identifier
        of the candidate shown to the model as ``Document i+1``.
    """
    document_map = [cand[0] for cand in candidates]
    numbered_docs = "\n".join(
        f"Document {position}: {cand[1]}"
        for position, cand in enumerate(candidates, start=1)
    )
    return zero_shot_prompt_template.format(numbered_docs, query), document_map


def get_document_id(answer, document_map):
    """Translate a model answer such as ``"document 2"`` into a document id.

    The last whitespace-separated token of ``answer`` is read as a 1-based
    document number; trailing sentence punctuation (``"document 2."``) is
    ignored.

    Args:
        answer: text produced by the model.
        document_map: list whose entry ``i`` is the id shown as ``Document i+1``.

    Returns:
        The entry of ``document_map`` the answer refers to.

    Raises:
        ValueError: the answer is empty or its last token is not an integer.
        IndexError: the number is outside ``1..len(document_map)``.
    """
    tokens = answer.split()
    if not tokens:
        raise ValueError("model answer is empty; no document number to parse")

    position = int(tokens[-1].rstrip(".,;:!?"))

    # Explicit range check: without it an answer of "document 0" would index
    # with -1 and silently return the LAST document instead of failing.
    if not 1 <= position <= len(document_map):
        raise IndexError(
            f"document number {position} is outside 1..{len(document_map)}"
        )
    return document_map[position - 1]

In [15]:
# Zero-shot re-ranking: for every case with several shortlisted candidates, ask
# the LLM which document is relevant; a single candidate is accepted as-is.
final_preds = {}
for case, predictions in tqdm(monot5_pred_cases.items()):
    if len(predictions) < 2:
        # Nothing to arbitrate: keep the only shortlisted candidate.
        final_preds[case] = [predictions[0][0]]
        continue

    query = preprocess_case_data(
        f"./data/task2_train_files_2024/{case}/entailed_fragment.txt"
    )
    candidates = [
        (
            pred[0],
            preprocess_case_data(
                f"./data/task2_train_files_2024/{case}/paragraphs/{pred[0]}"
            ),
        )
        for pred in predictions
    ]
    prompt, document_map = zero_short_generate_prompt(query, candidates)

    with torch.no_grad():
        # Use the device chosen when the model was loaded rather than a
        # hard-coded "cuda".
        inputs = tokenizer(prompt, return_tensors="pt", padding="longest").to(device)[
            "input_ids"
        ]
        # Two new tokens are requested; the answer is parsed as "document N".
        outputs = model.generate(inputs, max_new_tokens=2)

    raw_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    output_text = [format_output(e.replace(prompt, "")) for e in raw_output]
    try:
        final_preds[case] = [
            get_document_id(text, document_map) for text in output_text
        ]
    except (ValueError, IndexError):
        # Only answer-parsing failures are handled here. A bare `except:`
        # would also swallow KeyboardInterrupt and genuine programming errors.
        # Dump the offending exchange and stop so it can be inspected.
        print(raw_output)
        print(case)
        print(prompt)
        break

  0%|          | 0/100 [00:00<?, ?it/s]

In [21]:
# Spot-check: pull the prediction list stored for case id '526'.
sample = monot5_pred_cases['526']

In [22]:
# Show the spot-checked entry: a list of [candidate file name, score] pairs.
sample

[['024.txt', 92.46726432793902]]

### Few-shot Learning

In [10]:
import json

def load_json(file):
    """Parse the JSON document stored at ``file`` and return the decoded object."""
    with open(file) as handle:
        return json.load(handle)

# Merge the train and validation label tables into one lookup; if a case id
# appears in both files, the validation entry wins.
train_labels = {
    **load_json("./data/train_labels.json"),
    **load_json("./data/val_labels.json"),
}
test_labels = load_json("./data/test_labels.json")

In [32]:
from collections import Counter

# How many test cases have 1, 2, 3, ... gold labels, most frequent count first.
labels_count = Counter(map(len, test_labels.values()))
labels_count.most_common()

[(1, 86), (2, 9), (3, 4), (4, 1)]

In [11]:
from pathlib import Path

# Build the pool of labelled examples that few-shot prompts are sampled from:
#   train_data[case id] = {"fragment": query text,
#                          "candidates": [[file name, text, 1 or 0], ...]}
train_data = {}
for case in Path("./data/task2_train_files_2024").iterdir():
    # Skip stray entries (hidden files, checkpoints, label files) whose name
    # is not a numeric case id; int() below would raise ValueError on them.
    if not case.name.isdecimal():
        continue
    # NOTE(review): 625 is presumably the highest case id of the training
    # portion of the dataset — TODO confirm.
    if int(case.name) > 625:
        continue

    entailed_fragment = preprocess_case_data(case / "entailed_fragment.txt")
    candidates = [
        [
            cand.name,
            preprocess_case_data(cand),
            # 1 marks a paragraph listed among the gold labels of this case.
            1 if cand.name in train_labels[case.name] else 0,
        ]
        for cand in (case / "paragraphs").iterdir()
    ]
    train_data[case.name] = {
        "fragment": entailed_fragment,
        "candidates": candidates,
    }

In [None]:
# Template for one solved few-shot example: the first slot receives the numbered
# document list, the second the query. Note the trailing space:
# shot_generate_prompt appends "is ..." / "are ..." right after it.
shot_prompt_template = (
    'In bellow documents:\n{}\nThe documents really relevant to query "{}" '
)

# Target number of documents per few-shot example: all positives plus enough
# randomly sampled negatives to reach this count.
num_doc_per_shot = 5


def shot_generate_prompt(query, candidates):
    """Build one solved few-shot example from a labelled case.

    All positive candidates are shown, topped up with randomly sampled
    negatives until ``num_doc_per_shot`` documents are reached (fewer if there
    are not enough negatives). The documents are shuffled, numbered from 1 and
    the answer sentence naming the positive documents is appended.

    Args:
        query: query text inserted into ``shot_prompt_template``.
        candidates: sequence of ``[name, text, label]`` triples, where label 1
            marks a relevant document and 0 an irrelevant one.

    Returns:
        ``(prompt, document_map)`` where ``document_map[i]`` is the name of
        the candidate shown as ``Document i+1``.
    """
    positive_candidates = [cand for cand in candidates if cand[2] == 1]
    negative_candidates = [cand for cand in candidates if cand[2] == 0]

    # Clamp at zero: a case with more positives than num_doc_per_shot would
    # otherwise request a negative sample size, which random.sample rejects
    # with ValueError.
    num_negatives = max(
        0,
        min(len(negative_candidates), num_doc_per_shot - len(positive_candidates)),
    )
    shown = positive_candidates + random.sample(negative_candidates, num_negatives)
    random.shuffle(shown)

    document_map = [cand[0] for cand in shown]
    candidate_string_list = [
        f"Document {i+1}: {cand[1]}" for i, cand in enumerate(shown)
    ]
    prompt = shot_prompt_template.format("\n".join(candidate_string_list), query)

    # Positions of the positives after shuffling, listed in label order.
    answers = [document_map.index(cand[0]) for cand in positive_candidates]

    prompt += "is " if len(answers) == 1 else "are "
    prompt += " ".join([f"document {a+1}" for a in answers]) + "."
    return prompt, document_map


# Number of solved examples placed before the real question in each prompt.
num_shots = 3
final_preds = {}

for case, predictions in tqdm(monot5_pred_cases.items()):
    query = preprocess_case_data(
        f"./data/task2_train_files_2024/{case}/entailed_fragment.txt"
    )
    candidates = [
        (
            pred[0],
            preprocess_case_data(
                f"./data/task2_train_files_2024/{case}/paragraphs/{pred[0]}"
            ),
        )
        for pred in predictions
    ]

    # Draw the demonstrations. They are rendered inside a comprehension so
    # that no loop variable can overwrite `case`: previously the inner loop
    # reused that name, and every prediction was stored under the id of the
    # last sampled training case instead of the case being predicted.
    samples = random.sample(list(train_data.items()), num_shots)
    few_shots = [
        shot_generate_prompt(shot_sample["fragment"], shot_sample["candidates"])[0]
        for _, shot_sample in samples
    ]

    # The real question goes last; its document_map decodes the model answer.
    last_shot, document_map = zero_short_generate_prompt(query, candidates)
    few_shots.append(last_shot)
    prompt = "\n####\n".join(few_shots)

    with torch.no_grad():
        # Use the device chosen when the model was loaded rather than a
        # hard-coded "cuda".
        inputs = tokenizer(prompt, return_tensors="pt", padding="longest").to(device)[
            "input_ids"
        ]
        outputs = model.generate(inputs, max_new_tokens=2)

    raw_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    output_text = [format_output(e.replace(prompt, "")) for e in raw_output]
    final_preds[case] = [
        get_document_id(text, document_map) for text in output_text
    ]

### Evaluate

In [20]:
# Micro-averaged precision / recall / F1 over all cases.
# tp counts predicted documents that appear in the gold labels of their case.
tp = sum(
    sum(1 for doc in pred if doc in label_data[case])
    for case, pred in final_preds.items()
)
num_predicted = sum(len(pred) for pred in final_preds.values())
# Recall is measured against every gold label in label_data, including cases
# for which no prediction was stored.
num_gold = sum(len(gold) for gold in label_data.values())

# Guard each ratio: with no predictions, no gold labels or no true positives
# the plain formulas divide by zero.
p = tp / num_predicted if num_predicted else 0.0
r = tp / num_gold if num_gold else 0.0
f = 2 * p * r / (p + r) if (p + r) else 0.0
print(f, p, r)

0.7614678899082568 0.83 0.7033898305084746
