In [None]:
import transformers
import torch
import torch.nn as nn
import transformers
from transformers import DataCollatorForLanguageModeling
from transformers import LongformerForSequenceClassification, LongformerTokenizer
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
import json
import os
import random

In [None]:
# Device selection (disabled): the training cell below selects the device itself.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# n_gpu = torch.cuda.device_count()
# torch.cuda.get_device_name(0)

In [None]:
# Root directory of the dataset. Defaults to the original local path; set the
# LEGAL_RETRIEVAL_DATA_DIR environment variable to run the notebook elsewhere.
data_root = os.environ.get(
    "LEGAL_RETRIEVAL_DATA_DIR", "E:/personal/Code/Python/LegalRetrieval/data"
).rstrip("/\\")
# Folder holding one JSON file per case.
data_folder = f"{data_root}/longformer_out/"
# JSON file holding the query -> candidate -> label information.
data_info_file = f"{data_root}/full_data_labels.json"

# Tokens dropped from case text before tokenization: common function words
# plus the stand-alone punctuation tokens "," and ";".
list_skipped_words = [
    'the', 'a', 'an', 'in', 'on', 'at', 'to', 'of', 'for', 'with', 'by', 'and', 'or', 'but', 'so', 'nor', 'yet',
    'from', 'into', 'onto', 'upon', 'out', 'off', 'over', 'under', 'below', 'above', 'between', 'among',
    'through', 'during', 'before', 'after', 'since', 'until', 'while', 'as', 'like', 'about', 'against',
    'among', 'around', 'before', 'behind', 'beneath', 'beside', 'between', 'beyond', 'during', 'inside',
    'outside', 'underneath', 'within', 'without', 'throughout', 'along', 'across', 'toward', 'towards',
    'up', 'down', 'forward', 'backward', 'right', 'left', 'here', 'there', 'where', 'when', 'why', 'how',
    'what', 'which', 'who', 'whom', 'whose', 'whichever', 'whatever', 'whomever', 'whenever', 'wherever',
    'however', 'whyever', ',', ';',
]


def remove_stopwords(text):
    """Return ``text`` with every token found in ``list_skipped_words`` removed.

    The text is split on whitespace and tokens are compared exactly
    (case-sensitive, whole token), so "The" and "court," are kept.

    Args:
        text: The input string.

    Returns:
        The surviving tokens joined by single spaces.
    """
    # Build a set once per call so each membership test is O(1) rather than a
    # scan of the whole list; the list is still read at call time.
    skipped = set(list_skipped_words)
    return " ".join(word for word in text.split() if word not in skipped)

# Build a case_id -> cleaned full text lookup from every file in data_folder.
case_law = {}
for file in os.listdir(data_folder):
    with open(os.path.join(data_folder, file), "r", encoding="utf-8") as f:
        data = json.load(f)
    # The case id is the file name up to its first dot.
    case_id = file.split(".")[0]
    # "meta" first, then the paragraphs in order. A single join replaces the
    # repeated string concatenation; remove_stopwords re-splits on whitespace,
    # so the result is the same.
    full_text = " ".join([data["meta"], *data["paragraphs"]])
    case_law[case_id] = remove_stopwords(full_text)

# data_info maps each query id to a {candidate id: label} mapping.
# Opened with a context manager so the handle is closed, and as UTF-8 to match
# how the case files above are read.
with open(data_info_file, "r", encoding="utf-8") as f:
    data_info = json.load(f)

# One training example per (query, candidate) pair.
pairs = []
for query, candidates in data_info.items():
    query_text = case_law[query]
    for candidate, label in candidates.items():
        pairs.append({"text": [query_text, case_law[candidate]], "label": label})

In [None]:
# Display the first {"text": [query_text, candidate_text], "label": ...} entry as a sanity check.
pairs[0]

In [None]:
class CustomDataset(Dataset):
    """Dataset of (query, candidate) text pairs encoded for sequence classification.

    Each element of ``pairs`` is a dict with:
      * ``"text"``: a two-item list ``[query_text, candidate_text]``
      * ``"label"``: the class label of the pair
    """

    def __init__(self, pairs, tokenizer, max_length):
        """Store the pairs, the tokenizer and the fixed encoded length.

        Args:
            pairs: Sequence of ``{"text": [query, candidate], "label": ...}`` dicts.
            tokenizer: Callable accepting ``(text, text_pair, **kwargs)`` and
                returning a mapping with ``"input_ids"`` and ``"attention_mask"``.
            max_length: Length every encoded example is padded/truncated to.
        """
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.seq_len = max_length
        self.num_pairs = len(pairs)

    def __len__(self):
        """Return the number of pairs."""
        return self.num_pairs

    def __getitem__(self, idx):
        """Encode pair ``idx`` as a single two-segment sequence.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` (the tokenizer's
            output with its leading batch dimension removed) and ``"labels"``
            (a ``torch.long`` scalar tensor).
        """
        pair = self.pairs[idx]
        query_text, candidate_text = pair["text"]
        # Pass the two texts as (text, text_pair). Handing the tokenizer a
        # list of two strings would encode them as a batch of two independent
        # sequences instead of one joined pair.
        encoding = self.tokenizer(
            query_text,
            candidate_text,
            padding="max_length",
            truncation=True,
            max_length=self.seq_len,
            return_tensors="pt",
        )
        return {
            # squeeze(0) removes only the batch dimension added by return_tensors="pt".
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Class indices must be integer (long) tensors for the classification loss.
            "labels": torch.tensor(pair["label"], dtype=torch.long),
        }

In [None]:
# Separate the pairs into train (80%) and test (20%) sets.
# A fixed random_state makes the split reproducible across kernel restarts.
train_pairs, test_pairs = train_test_split(pairs, test_size=0.2, random_state=42)
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
# Every example is padded/truncated to this many tokens.
max_length = 4096
train_dataset = CustomDataset(train_pairs, tokenizer, max_length)
test_dataset = CustomDataset(test_pairs, tokenizer, max_length)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# Evaluation data is not shuffled, so batches follow the order of test_pairs.
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [None]:
# Spot-check: encode one randomly chosen training example and print its tensors.
result = train_dataset[random.randrange(0, len(train_dataset))]
print(result)

In [None]:
# Fine-tune Longformer as a two-class classifier on the encoded pairs.
model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096', num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# from_pretrained() returns the model in evaluation mode (dropout disabled);
# switch to training mode before fine-tuning.
model.train()
for epoch in range(3):
    total_loss = 0
    for batch in train_loader:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass; the batch includes "labels", so the output carries a loss
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for this epoch
    print(f"Epoch: {epoch}, Loss: {total_loss/len(train_loader)}")
    