In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
from tqdm import tqdm
import os
import re
import pandas as pd

os.chdir('/home/s2310409/workspace/coliee-2024/')
from utils.misc import get_summary, get_query



In [2]:
device = "cuda" # the device to load the model onto

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def load_data(dir):
    with open(dir, 'r') as fp:
        train_data = json.load(fp)

    data = []
    for key in train_data.keys():
        data.append([key, train_data[key]])

    return pd.DataFrame(data, columns=['source', 'target'])

with open('dataset/c2023/bm25_candidates_test.json', 'r') as fp:
    candidate_dict = json.load(fp)

data_df = load_data(f'dataset/test.json')
data_df['candidates'] = data_df['source'].apply(lambda x: [c for c in candidate_dict[x] if c != x])
data_df['negative_candidates'] = data_df.apply(lambda x: [c for c in x['candidates'] if c not in x['target']], axis=1)
data_df

Unnamed: 0,source,target,candidates,negative_candidates
0,070318.txt,[015076.txt],"[032432.txt, 071237.txt, 019716.txt, 027423.tx...","[032432.txt, 071237.txt, 019716.txt, 027423.tx..."
1,077960.txt,"[009054.txt, 040860.txt]","[071412.txt, 060516.txt, 024547.txt, 087722.tx...","[071412.txt, 060516.txt, 024547.txt, 087722.tx..."
2,042319.txt,"[093691.txt, 075956.txt, 084953.txt, 022987.txt]","[027719.txt, 067612.txt, 059275.txt, 026904.tx...","[027719.txt, 067612.txt, 059275.txt, 026904.tx..."
3,041766.txt,[039269.txt],"[071818.txt, 056351.txt, 009599.txt, 046346.tx...","[071818.txt, 056351.txt, 009599.txt, 046346.tx..."
4,077407.txt,[038669.txt],"[038092.txt, 096647.txt, 056351.txt, 060210.tx...","[038092.txt, 096647.txt, 056351.txt, 060210.tx..."
...,...,...,...,...
314,085079.txt,"[044669.txt, 003144.txt]","[080328.txt, 056351.txt, 068423.txt, 041404.tx...","[080328.txt, 056351.txt, 068423.txt, 041404.tx..."
315,031370.txt,"[096341.txt, 060602.txt, 047107.txt, 084522.tx...","[027678.txt, 086122.txt, 060516.txt, 031040.tx...","[027678.txt, 086122.txt, 060516.txt, 031040.tx..."
316,085828.txt,"[004301.txt, 074887.txt, 088994.txt]","[008459.txt, 053850.txt, 003821.txt, 087722.tx...","[008459.txt, 053850.txt, 003821.txt, 087722.tx..."
317,024957.txt,"[015009.txt, 080348.txt]","[066045.txt, 077315.txt, 075868.txt, 022332.tx...","[066045.txt, 077315.txt, 075868.txt, 022332.tx..."


# Zero-shot reranking

In [29]:
# messages = [
#     {"role": "user", "content": "What is your favourite condiment?"},
#     {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
#     {"role": "user", "content": "Do you have mayonnaise recipes?"}
# ]

# def reranking_prompting(list_articles, query_content):
#     prompting = f"In bellow articles:  "
#     for a_id in list_articles:
#         a_content = get_summary(a_id)
#         prompting = prompting + f"\n##Article {a_id}: {a_content},"
        
#     prompting = prompting +f"\n##Question: which articles really relevant to the following article? Answer the article name only. \n##Article: {query_content}"
#     prompting = prompting + "\n##Answer:"
#     return prompting

# source = data_df['source'][0]
# candidates = candidate_dict[source][0:7]
# text = reranking_prompting(candidates, get_summary(source))

# encodeds = tokenizer(text, return_tensors="pt")
# model_inputs = encodeds.to(device)

# generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True)
# decoded = tokenizer.batch_decode(generated_ids)
# print(candidates)
# print(decoded[0].replace(text, ''))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [4]:
def clean_text(text):
    text = re.sub(r'\n', '', text)
    text = re.sub(r'\[', '', text)
    text = re.sub(r'\]', '', text)
    return text

def instruct_rerank_prompting(base_case, candidates):
    prompt = f"""[INST] You are a helpful legal assistant. You are helping a user to find relevant articles to the following base article.
    ## Base Article : \n{clean_text(get_summary(base_case))}
    ## Candidates : """
    for c in candidates:
        content = get_summary(c)
        content = clean_text(content)
        prompt = prompt + f"\nAricle {c.split('.')[0]}: {content}"
    # tokenizer.encode(prompt)
    prompt = prompt + f"\n## Question : Which articles are closely relevant to the base article? Answer the relevant article name only:[\INST]"
    return prompt
    
base_case = data_df['source'][0]
candidates = candidate_dict[base_case][0:5]

prompt = instruct_rerank_prompting(base_case, candidates)

encodeds = tokenizer(prompt, return_tensors="pt")
model_inputs = encodeds.to(device)

generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(generated_ids)
print(candidates)
print(decoded[0].replace(prompt, ''))
print(list(set(re.findall(r'\d{6}', decoded[0].replace(prompt, '')))))

['032432.txt', '071237.txt', '019716.txt', '027423.txt', '012462.txt']
<s>  Article 032432, Article 071237, and Article 019716. 

Explanation: The articles pertain to issues related to bias, natural justice, and fairness in decision-making processes, which are also present in the base article. In Article 032432, the Court of Appeal discussed the ability to be impartial when deciding on two occasions and the length of the interrogations during a hearing raising an apprehension of bias. In Article 071237, the Federal Court of Appeal considered the applicant's allegations of his right to a fair hearing and the board's interventions that interfered with his ability to present his case. In Article 019716, the applicants requested materials that were relevant to their claims of bias and breach of procedural fairness. These articles address similar issues and concepts as the base article.</s>
['032432', '019716', '071237']


In [34]:
result = []
for _ in tqdm(range(50)):
    generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.batch_decode(generated_ids)
    result.extend(list(set(re.findall(r'\d{6}', decoded[0].replace(prompt, '')))))

# count each article
from collections import Counter
Counter(result)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [13:23<00:00, 16.07s/it]


Counter({'032432': 47, '071237': 35, '012462': 31, '019716': 18, '027423': 14})

In [None]:
# predict on data_df
prediction_dict = {}
for i in tqdm(range(len(data_df))):
    base_case = data_df['source'][i]
    prediction_dict[base_case] = []
    # group of 5 candidates
    for j in range(0, len(data_df['candidates'][i]), 5):
        candidates = data_df['candidates'][i][j:j+5]
        prompt = instruct_rerank_prompting(base_case, candidates)

        encodeds = tokenizer(prompt, return_tensors="pt")
        model_inputs = encodeds.to(device)

        generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
        decoded = tokenizer.batch_decode(generated_ids)

        predictions = list(set(re.findall(r'\d{6}', decoded[0].replace(prompt, ''))))
        prediction_dict[base_case].extend(predictions)
        

# Few-shot classification

In [None]:
def few_shot_prompting(list_articles, query_content):
    pass

# Summarize

In [4]:
def summarize_prompt(doc):
    prompt = f"""[INST] You are a helpful legal assistant. You are helping a user to summarize case law documents.
    ## Article : \n{doc}"""
    # tokenizer.encode(prompt)
    prompt = prompt + f"\n## TLDR:[\INST]"
    return prompt

model.eval()

list_files = os.listdir('dataset/processed')
list_files = [f for f in list_files if f.endswith('.txt')]
for file in tqdm(list_files):
    with open(f'dataset/processed/{file}', 'r') as fp:
        doc = fp.read()
    doc = tokenizer.decode(tokenizer.encode(doc, max_length=10000, truncation=True))
    prompt = summarize_prompt(doc)
    with torch.no_grad():
        encodeds = tokenizer(prompt, return_tensors="pt")
        model_inputs = encodeds.to(device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    summarized_doc = decoded[0].split('[\INST]')[1].strip()
    with open(f'dataset/mixtral_summarized/{file}', 'w') as fp:
        fp.write(summarized_doc)

  0%|▍                                                                                                         | 27/7469 [25:34<118:32:04, 57.34s/it]