## 安裝套件 (For Google Colab)

In [None]:
# Notebook shell commands that install the third-party packages used below.
# NOTE(review): `opencc` is imported by this notebook but is not installed
# here - confirm it is already available in the target environment.
# NOTE(review): both the CPU and GPU builds of faiss are installed; presumably
# only one is needed - confirm which build the runtime should use.
! pip install datasets evaluate transformers[sentencepiece]
! pip install scikit-learn
! pip install pytorch-ignite
! pip install jieba
! pip install faiss-cpu
! pip install faiss-gpu

In [1]:
import transformers as T
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from ignite.metrics import Rouge
import re
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import jieba
import matplotlib.pyplot as plt
from opencc import OpenCC

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [2]:
# Training hyperparameters.
# NOTE(review): lr, epochs, train_batch_size and test_batch_size are not
# referenced in the visible part of this notebook - confirm they are still used.
lr = 2e-5
epochs = 25
train_batch_size = 1
validation_batch_size = 1  # batch size of the DataLoaders built for retrieval
test_batch_size = 1
random_seed = 42  # random_state of the train/test split
max_length = 300  # token limit for the DPR tokenizers and cap on generated length

# Path settings that shouldn't be changed
tokenizer_path = "./saved_tokenizer"
dataset_path = "./train.xlsx"
test_dataset_path = './test.xlsx'
model_path = "./saved_models"

In [None]:
# Patterns that carve a judgment's main text into its sections.
# re.DOTALL lets '.' span line breaks; every capture is non-greedy, so each
# section ends at the first occurrence of its closing marker.
facts_pattern = re.compile(r'(.+?)，本院判決如下：', re.DOTALL)
arguments_pattern = re.compile(r'理由(.+?)據上論結', re.DOTALL)
decisions_pattern_1 = re.compile(r'據上論結，(.+?)判決', re.DOTALL)
decisions_pattern_2 = re.compile(r'如下：主文(.+?)。', re.DOTALL)

cc = OpenCC('tw2sp')
# Single characters dropped by remove_stopwords().
stopwords = {"的", "在", "是", "了", "和"}
# Runs of Chinese numerals. preprocess_text() applies this pattern BEFORE the
# text is converted to simplified Chinese, and chinese_to_arabic_number() keys
# its large units on the traditional forms, so 萬/億 must be listed here;
# otherwise a numeral such as 三萬五千 is split at 萬 and converted piecewise.
chinese_number_pattern = r'[零一二三四五六七八九十百千萬億]+'

def chinese_to_arabic_number(chinese_num):
    """Convert a string of Chinese numerals to an int.

    Supports unit notation (e.g. 一百零五 -> 105, 三萬五千 -> 35000) as well
    as purely positional digit strings (e.g. 二零二三 -> 2023). A unit with
    no digit in front of it counts as one of that unit, so 十 -> 10 and
    十五 -> 15. The simplified large units 万 / 亿 are accepted as aliases
    of 萬 / 億.

    Args:
        chinese_num: String made only of the numerals 零-九 and the units
            十, 百, 千, 萬/万, 億/亿.

    Returns:
        The integer value; 0 for an empty string.

    Raises:
        KeyError: If the string contains any other character.
    """
    chinese_digits = {
        '零': 0, '一': 1, '二': 2, '三': 3, '四': 4,
        '五': 5, '六': 6, '七': 7, '八': 8, '九': 9,
    }

    chinese_units = {
        '十': 10, '百': 100, '千': 1000,
        '萬': 10000, '億': 100000000
    }

    def convert_section(section):
        # Walk right-to-left: a unit character sets the multiplier for the
        # digit to its left; consecutive digits are read positionally.
        total = 0
        unit = 1
        unit_pending = False  # a unit was seen but no digit has claimed it yet
        for char in reversed(section):
            if char in chinese_units:
                if unit_pending:
                    # Two units in a row (一百十五): the inner one stands for 1.
                    total += unit
                unit = chinese_units[char]
                unit_pending = True
            else:
                total += chinese_digits[char] * unit
                unit *= 10
                unit_pending = False
        if unit_pending:
            # Leading unit with no digit before it (十, 十五): implicit 1.
            total += unit
        return total

    # Normalise simplified large units so the splitting below sees one form.
    chinese_num = chinese_num.replace('万', '萬').replace('亿', '億')

    sections = chinese_num.split('億')
    num = 0

    if len(sections) == 2:
        num += convert_section(sections[0]) * 100000000
        sections = sections[1].split('萬')
    else:
        sections = sections[0].split('萬')

    if len(sections) == 2:
        num += convert_section(sections[0]) * 10000
        num += convert_section(sections[1])
    else:
        num += convert_section(sections[0])

    return num

def transfer_special_token(text):
    """Replace the parenthesized numerals ㈠-㈩ with the plain numerals 一-十.

    Args:
        text: Input string.

    Returns:
        A copy of `text` with every parenthesized numeral substituted; all
        other characters are left untouched.
    """
    translation = str.maketrans({
        '㈠': '一',
        '㈡': '二',
        '㈢': '三',
        '㈣': '四',
        '㈤': '五',
        '㈥': '六',
        '㈦': '七',
        '㈧': '八',
        '㈨': '九',
        '㈩': '十',
    })
    return text.translate(translation)
    

def clean_text(text):
    """Strip punctuation and normalise whitespace.

    Keeps only ASCII letters, ASCII digits, characters in the range
    U+4E00-U+9FFF and whitespace, then collapses every whitespace run to a
    single space and trims both ends.

    Args:
        text: Input string.

    Returns:
        The cleaned string.
    """
    kept = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fff\s]', '', text)
    # split() with no argument cuts on whitespace runs and discards the
    # leading/trailing ones, so joining with one space collapses and trims.
    return ' '.join(kept.split())

def trad_to_zh(text):
    """Convert `text` with OpenCC's 'tw2sp' configuration.

    A new converter is constructed on every call.

    Args:
        text: String to convert (traditional Chinese, going by the
            function name).

    Returns:
        The converted string.
    """
    return OpenCC('tw2sp').convert(text)

def zh_to_trad(text):
    """Convert `text` with OpenCC's 's2twp' configuration.

    A new converter is constructed on every call.

    Args:
        text: String to convert (simplified Chinese, going by the
            function name).

    Returns:
        The converted string.
    """
    return OpenCC('s2twp').convert(text)

def remove_stopwords(text):
    """Filter out every element of `text` that is in the module-level `stopwords`.

    Args:
        text: Any iterable; iterating a str yields single characters, so a
            string is filtered character by character.

    Returns:
        A list of the elements that were kept, in their original order.
    """
    kept = []
    for token in text:
        if token in stopwords:
            continue
        kept.append(token)

    return kept

# The function that does the text preprocessing for the custom dataset
def preprocess_text(text):
    """Normalise a judgment text before it is fed to the model.

    Steps, in order:
      1. every run matched by `chinese_number_pattern` is replaced by its
         Arabic-digit value;
      2. parenthesized numerals (㈠ ...) become plain Chinese numerals - this
         runs after step 1, so they are not turned into digits;
      3. the text is converted with `trad_to_zh`;
      4. characters listed in `stopwords` are removed.

    Args:
        text: Raw text string.

    Returns:
        The preprocessed string.
    """
    # text = clean_text(text)

    # Substitute each match in place in a single pass. Replacing match by
    # match with str.replace() rewrites every occurrence of a short numeral
    # (一) at once, including inside a longer numeral further on (一百),
    # which then can no longer be found and is left half-converted.
    text = re.sub(
        chinese_number_pattern,
        lambda match: str(chinese_to_arabic_number(match.group(0))),
        text,
    )
    text = transfer_special_token(text)
    text = trad_to_zh(text)
    # remove_stopwords() returns a list of characters; stitch them back.
    return ''.join(remove_stopwords(text))

## RAG

In [None]:
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer

# Seq2seq generator and its tokenizer (mBART-50 checkpoint); the tokenizer
# truncates prompts to at most 512 tokens when truncation is requested.
rag_model = T.AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50", cache_dir="./cache/").to(device)
rag_tokenizer = T.AutoTokenizer.from_pretrained("facebook/mbart-large-50", cache_dir="./cache/", model_max_length=512)

# Load the question encoder and tokenizer (used to embed retrieval queries).
# NOTE(review): the checkpoint name suggests an English Natural Questions
# model - confirm it is suitable for the Chinese judgment text used here.
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base', cache_dir="./cache/").to(device)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base', cache_dir="./cache/")

# Load the context encoder and tokenizer (used to embed the stored documents).
context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', cache_dir="./cache/").to(device)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', cache_dir="./cache/")


In [None]:
class Legal_Judgment_RAG_Dataset(Dataset):
    """Judgments from the spreadsheet at `dataset_path`, split into tagged sections.

    Each item is a dict with two string entries:
      * "origin_context": "<facts> ... </facts> <arguments> ... </arguments>"
      * "decisions":      "<decisions> ... </decisions>"

    A section whose pattern does not match is emitted with an empty body.
    """

    def __init__(self) -> None:
        super().__init__()
        frame = pd.read_excel(dataset_path)
        self.data = []
        for _, row in frame.iterrows():
            judgment = str(row['裁判原文'])

            # Facts: everything before the "本院判決如下" marker.
            facts_match = facts_pattern.search(judgment)
            facts_body = str(facts_match.group(1)) if facts_match else ''
            facts = f'<facts> {facts_body} </facts>'

            # Arguments: the text between "理由" and "據上論結".
            arguments_match = arguments_pattern.search(judgment)
            arguments_body = str(arguments_match.group(1)) if arguments_match else ''
            arguments = f'<arguments> {arguments_body} </arguments>'

            # Decisions: both decision patterns, concatenated in this order.
            decision_parts = [
                str(found.group(1))
                for found in (
                    decisions_pattern_1.search(judgment),
                    decisions_pattern_2.search(judgment),
                )
                if found
            ]
            decisions = '<decisions> ' + ''.join(decision_parts) + ' </decisions>'

            self.data.append({
                "origin_context": f"{facts} {arguments}",
                "decisions": decisions,
            })

    def __getitem__(self, index):
        """Return the record dict stored at `index`."""
        return self.data[index]

    def __len__(self):
        """Return the number of parsed judgments."""
        return len(self.data)

# Parse the training spreadsheet once; this dataset is the retrieval corpus.
rag_dataset = Legal_Judgment_RAG_Dataset()
# 80/20 split with a fixed seed.
# NOTE(review): untrain_rag_dataset is not referenced in the visible code.
train_rag_dataset, untrain_rag_dataset = train_test_split(rag_dataset, test_size=0.2, random_state=random_seed)

In [None]:
def get_rag_tensor(sample):
    """Collate a batch of records into (contexts, decisions).

    Despite the name, no tensors are built here: the model inputs and the
    ground-truth decisions are returned as two parallel lists.

    Args:
        sample: Batch of dicts with "origin_context" and "decisions" entries.

    Returns:
        Tuple `(contexts, decisions)` of two lists, in batch order.
    """
    def column(key):
        return [record[key] for record in sample]

    return column("origin_context"), column("decisions")
# Loader over every record in dataset order (shuffle=False), so that
# encoded_docs[i] lines up with rag_dataset[i] when retrieve() maps the best
# indices back to records.
summary_rag = DataLoader(rag_dataset, collate_fn=get_rag_tensor, batch_size=validation_batch_size, shuffle=False)
# NOTE(review): d_train is not referenced again in the visible code.
d_train = DataLoader(train_rag_dataset, collate_fn=get_rag_tensor, batch_size=validation_batch_size, shuffle=False)
# One context-encoder embedding per batch. Only the first item of each batch
# (o[0]) is encoded, so every record is covered only while the batch size is 1.
encoded_docs = []
for o, s in summary_rag:
    combined_text = o[0]
    inputs = context_tokenizer(combined_text, return_tensors='pt', truncation=True, padding='max_length', max_length=max_length).to(device)
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        embeddings = context_encoder(**inputs).pooler_output
    embeddings = embeddings.to(torch.float32)
    encoded_docs.append(embeddings)

In [None]:
import torch
import faiss

def retrieve(query, k=5):
    """Return the `k` stored judgments most similar to `query`.

    The query is embedded with the DPR question encoder and compared, by
    cosine similarity, with the pre-computed embeddings in `encoded_docs`.

    Args:
        query: Query text.
        k: Maximum number of documents to return.

    Returns:
        List of at most `k` strings, best match first, each formatted as
        "origin_context: ... decisions: ..." from the matching `rag_dataset`
        record. Empty when no documents have been encoded.
    """
    if not encoded_docs:
        return []

    inputs = question_tokenizer(query, return_tensors='pt', truncation=True, padding='max_length', max_length=max_length).to(device)

    # Inference only: without no_grad every query would build (and keep alive)
    # an autograd graph through the whole encoder.
    with torch.no_grad():
        query_embedding = question_encoder(**inputs).pooler_output
        # Each stored embedding has a leading batch dimension of 1, so
        # concatenating on dim 0 gives one row per document; the query row is
        # broadcast against all of them in a single call.
        doc_matrix = torch.cat(encoded_docs, dim=0)
        similarities = torch.nn.functional.cosine_similarity(query_embedding, doc_matrix, dim=1).tolist()

    # Get top-k documents (stable sort: ties keep the lower index first)
    top_k_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:k]
    top_k_docs = []
    for i in top_k_indices:
        record = rag_dataset[i]
        top_k_docs.append("origin_context: " + record["origin_context"] + " decisions: " + record["decisions"])
    return top_k_docs

# Smoke test: retrieve for the first record only (the loop exits after one
# iteration). The query is itself part of the corpus being searched.
for d, _ in summary_rag:
    # Example usage
    query = d[0]
    top_docs = retrieve(query)
    print(top_docs)
    break


In [None]:
def rag_generate(query):
    """Generate a decision for `query`, prompted with retrieved example judgments.

    Args:
        query: Text describing the case (facts and arguments).

    Returns:
        The text decoded from the first generated sequence, with special
        tokens skipped.
    """
    rag_model.eval()

    # Retrieved judgments are appended to the prompt as worked examples.
    examples = " ".join(retrieve(query))
    prompt = "make decision: " + query + "\nexamples:\n" + examples

    # Generate the response with the mBART model; the tokenizer truncates
    # the prompt to its maximum length.
    encoded = rag_tokenizer(prompt, return_tensors='pt', truncation=True, padding=True).to(device)
    output_ids = rag_model.generate(encoded['input_ids'], max_length=max_length, early_stopping=False)
    return rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Smoke test: generate a decision for the first record only (the loop exits
# after one iteration).
for d, _ in summary_rag:
    # Example usage
    query = d[0]
    print(rag_generate(query))
    break

### Test RAG

In [None]:
def test_rag(model, test_dataset):
    """Generate a decision for every row of `test_dataset`.

    For each row the facts and arguments sections are extracted from the
    "裁判原文" column, passed to `rag_generate`, and the generated text (with
    "<unk>" / "<pad>" removed) is written to the "判決" column.

    Args:
        model: Unused; kept for interface compatibility. Generation goes
            through `rag_generate`, which uses the module-level model.
        test_dataset: pandas DataFrame with a "裁判原文" column. It is
            modified in place.

    Returns:
        The same DataFrame object, with the "判決" column filled in.
    """
    # A single progress bar over the rows, carrying the description.
    rows = tqdm(test_dataset.iterrows(), total=test_dataset.shape[0], desc="Testing")

    for idx, testdata in rows:
        main_text = str(testdata['裁判原文'])

        facts_match = facts_pattern.search(main_text)
        if facts_match:
            facts = '<facts> ' + str(facts_match.group(1)) + ' </facts>'
        else:
            facts = '<facts>  </facts>'

        arguments_match = arguments_pattern.search(main_text)
        if arguments_match:
            arguments = '<arguments> ' + str(arguments_match.group(1)) + ' </arguments>'
        else:
            arguments = '<arguments>  </arguments>'

        # The decisions section is what is being predicted, so it is not
        # extracted from the test text.
        ori_data = "summary: " + f"{facts} {arguments}"
        outputs = rag_generate(ori_data).replace("<unk>", "").replace("<pad>", "")
        test_dataset.loc[idx, "判決"] = outputs

    return test_dataset

In [None]:
# Load the test spreadsheet, add an empty "判決" column for the predictions,
# fill it row by row with test_rag(), and save the result next to the notebook.
test_dataset = pd.read_excel(test_dataset_path)
test_dataset['判決'] = None
result_data = test_rag(rag_model, test_dataset)
result_data.to_excel("rag_result.xlsx", index=False)