In [2]:
from os import listdir
from os.path import isfile, join

def get_all_files_from_path(mypath):
    filenames = [join(mypath, f) for f in listdir(mypath) if isfile(join(mypath, f))]
    return filenames

from bs4 import BeautifulSoup
import re
import json

def get_article(articles):
    result = {}
    current_statue = "(non-statute)"
    for i in re.split(r"(.*)", articles.strip()):
        if len(i) == 0 or i == "\n":
            continue
        if re.search(r"^\(.*\)$", i):
            current_statue = i.strip()
            if current_statue not in result:
                result.update({current_statue: []})
        else:
            if current_statue not in result:
                result.update({current_statue: []})
            result[current_statue].append(i)
    return result

def build_test(filename):
    result = {}
    with open(filename, 'r') as f:
        data = f.read()

    data = BeautifulSoup(data, "xml").find_all('pair')
    for i in data:
        id = i.get('id')
        result.update({id: {}})
        result[id].update({"label": i.get('label')})
        articles = i.find('t1').text.strip()
        # articles = get_article(articles)
        result[id].update({"result": articles})
        result[id].update({"content": i.find('t2').text.strip()})
    return result

def write_json(filename, data):
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

import xml.etree.ElementTree as Et
import glob

def format_first_line(text):
    lines = text.split("\n")
    results = []
    for line in lines:
        if line[0] == "":
            continue
        if line[0] == "(" and line[-1] == ")":
            continue
        results.append(line)
    return "\n".join(results)

def load_samples(filexml):
    # try:
    tree = Et.parse(filexml)
    root = tree.getroot()
    samples = []
    for i in range(0, len(root)):
        sample = {'result': []}
        for j, e in enumerate(root[i]):
            if e.tag == "t1":
                sample['result'] = format_first_line(e.text.strip())
            elif e.tag == "t2":
                question = e.text.strip()
                sample['content'] = question if len(question) > 0 else None
        sample.update(
            {'index': root[i].attrib['id'], 'label': root[i].attrib.get('label', "N")})
        # filter the noise samples
        if sample['content'] is not None:
            samples.append(sample)
        else:
            print("[Important warning] samples {} is ignored".format(sample))
    return samples

def load_test_data_samples(path_folder_base, test_id):
    data = []
    test = load_samples(f"{path_folder_base}/riteval_{test_id}.xml")
    for file_path in glob.glob(f"{path_folder_base}/riteval_{test_id}.xml"):
        data = data + load_samples(file_path)
    return data


def load_all_data_samples(path_folder_base):
    data = []
    for file_path in glob.glob("{}/*.xml".format(path_folder_base)):
        data = data + load_samples(file_path)
    return data

def check_false_labels(pred, false_labels):
	for label in false_labels:
		if label in pred:
			return True
	return False

from tqdm import tqdm

def format_output(text):
	CLEANR = re.compile('<.*?>') 
	cleantext = re.sub(CLEANR, '', text)
	return cleantext.strip().lower()

def readfile(filename):
    f = open(filename)
    data = json.load(f)
    return data

In [3]:
output_data_path = "../data/finetune_exp/fewshot_query"
input_data_path = "../data/COLIEE2024statute_data-English/train"

In [4]:
from sentence_transformers import SentenceTransformer, util
query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base')
passage_encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base')

def dpr(testfile=None, path="../data/COLIEE2024statute_data-English/train"):
    datas = get_all_files_from_path(path)
    corpus = []
    content = []
    labels = []
    for data in datas:
        if testfile != None and testfile in data:
            continue
        data = load_samples(data)
        for item in data:
            # corpus.append(item["result"].replace("\n", " ").strip())
            corpus.append(item["result"].strip())
            content.append(item["content"].strip().replace(".", ""))
            labels.append(item["label"].strip())
    print(len(corpus))
    retrival_passage_embeddings = passage_encoder.encode(corpus)
    content_passage_embeddings = passage_encoder.encode(content)
    return corpus, content, labels, retrival_passage_embeddings, content_passage_embeddings


In [5]:
def prompting(premise, hypothesis, template=None):
    text = template.replace("{{premise}}", premise).replace("{{hypothesis}}", hypothesis)
    return text

def writefile(data, filename):
    # Serializing json
    json_object = json.dumps(data, indent=1)
    # Writing to sample.json
    with open(filename, "w") as outfile:
        outfile.write(json_object)
        
def few_shot_prompting(indexes, corpus, content, labels, prompt_template):
    result = ""
    for i in indexes:
        if "true or false" in prompt_template.lower():
            answer = "True"
            if "N" == labels[i]:
                answer = "False"
        else:
            answer = "Yes"
            if "N" == labels[i]:
                answer = "No"
        prompt = prompt_template.replace("{{premise}}", corpus[i]).replace('{{hypothesis}}', content[i]).replace('{{answer}}', answer)
        result += prompt
    return result

In [6]:
fewshot = "Document: {{premise}}\nQuestion: {{hypothesis}}? True or False\nAnswer: {{answer}}\n\n"
template = "Document: {{premise}}\nQuestion: {{hypothesis}}? True or False "

In [7]:
import torch

files = get_all_files_from_path(input_data_path)

for file in files:
    outfile = file.split("/")[-1].replace(".xml", "")
    outdata = []
    data = load_samples(file)
    corpus, content, labels, retrival_passage_embeddings, content_passage_embeddings = dpr(outfile)
    for item in data:
        result = {}
        label = item["label"]
        if label == "N":
            label = "false"
        else:
            label = "true"
        hypothesis = item["content"]
        premise = item["result"]
        #Important: You must use dot-product, not cosine_similarity
        query_embedding = query_encoder.encode(hypothesis)
        scores = util.dot_score(query_embedding, content_passage_embeddings)
        indexes = torch.topk(scores, 3).indices[0]
        few_shot = few_shot_prompting(indexes, corpus, content, labels, fewshot)
        text = few_shot + prompting(premise, hypothesis, template)
        result.update({"index": item["index"]})
        result.update({"content": hypothesis})
        result.update({"result": premise})
        result.update({"prompt": text})
        result.update({"label": label})
        outdata.append(result)
    writefile(outdata, f"{output_data_path}/{outfile}.jsonl")

659
625
637
635
648
641
616
646
654
621
654
646
658


In [8]:
data = "../data/finetune_exp/fewshot_retrival/riteval_H18_en.jsonl"

readfile(data)

FileNotFoundError: [Errno 2] No such file or directory: '../data/finetune_exp/fewshot_retrival/riteval_H18_en.jsonl'