## Data preprocessing

In [1]:
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from torch.optim.lr_scheduler import StepLR
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from collections import defaultdict
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_df = pd.read_csv('./Data/train.csv')
test_df = pd.read_csv('./Data/test.csv')

train_df['fact_with_parties'] = 'First Party: ' + train_df['first_party'] + ', Second Party: ' + train_df['second_party'] + ', Legal Fact: ' + train_df['facts']
test_df['fact_with_parties'] = 'First Party: ' + test_df['first_party'] + ', Second Party: ' + test_df['second_party'] + ', Legal Fact: ' + test_df['facts']

train_facts = train_df['fact_with_parties'].tolist()
train_labels = train_df['first_party_winner'].astype(int).tolist()

test_facts = test_df['fact_with_parties'].tolist()

train_facts, val_facts, train_labels, val_labels = train_test_split(train_facts, train_labels, test_size=0.2, random_state=1004)

In [3]:
class LegalCaseDataset(Dataset):
    def __init__(self, facts, labels, tokenizer, max_length):
        self.facts = facts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.facts)
    
    def __getitem__(self, idx):
        fact = self.facts[idx]
        label = self.labels[idx]
        
        # encode_plus: 토커나이징 + 인코딩
        encoding = self.tokenizer.encode_plus(
            fact,
            add_special_tokens=True, # [CLS], [SEP] 토큰 추가
            max_length=self.max_length,
            padding='max_length', # max_length보다 짧은 문장을 padding
            return_token_type_ids=False, # 토큰 타입 id는 사용하지 않음
            truncation=True, # max_length보다 긴 문장을 자르는 옵션
            return_attention_mask=True, # 어텐션 마스크 생성
            return_tensors='pt',
        )
        
        return {
            'fact_text': fact,
            'input_ids': encoding['input_ids'].flatten(), # encoding에서 input_ids을 flatten
            'attention_mask': encoding['attention_mask'].flatten(), # encoding에서 attention_mask을 flatten
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [4]:
class LegalCaseDatasetWithoutLabels(Dataset):
    def __init__(self, facts, tokenizer, max_length):
        self.facts = facts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.facts)

    def __getitem__(self, idx):
        fact = str(self.facts[idx])
        encoding = self.tokenizer.encode_plus(
            fact,
            add_special_tokens=True, # [CLS], [SEP] 토큰 추가
            max_length=self.max_length,
            padding='max_length', # max_length보다 짧은 문장을 padding
            return_token_type_ids=False, # 토큰 타입 id는 사용하지 않음
            truncation=True, # max_length보다 긴 문장을 자르는 옵션
            return_attention_mask=True, # 어텐션 마스크 생성
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base').to(device)

optimizer = Adam(model.parameters(), lr=1e-5)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
loss_fn = nn.CrossEntropyLoss().to(device)
epochs = 20

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.dense.weight']
You should pr

In [6]:
train_dataset = LegalCaseDataset(train_facts, train_labels, tokenizer, max_length=512)
val_dataset = LegalCaseDataset(val_facts, val_labels, tokenizer, max_length=512)

train_data_loader = DataLoader(train_dataset, batch_size=32, num_workers=2)
val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=2)

In [7]:
test_dataset = LegalCaseDatasetWithoutLabels(test_facts, tokenizer, max_length=512)

test_data_loader = DataLoader(test_dataset, batch_size=32)

In [8]:
def train_model(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    
    model = model.train()
    losses = []
    correct_predictions = 0
    
    for d in tqdm(data_loader):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs.logits, dim=1)
        loss = loss_fn(outputs.logits, targets)

        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # gradient clipping: gradient exploding 방지
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
    return correct_predictions.double() / n_examples, np.mean(losses)

In [9]:
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()
    losses = []
    correct_predictions = 0
    
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs.logits, dim=1)
            loss = loss_fn(outputs.logits, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)

In [10]:
def get_predictions(model, data_loader, device):
    model = model.eval()
    predictions = []
    prediction_probs = []
    real_values = []
    
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs.logits, dim=1)

            predictions.extend(preds)
            prediction_probs.extend(outputs.logits)
            real_values.extend(targets)
            
    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    
    return predictions, prediction_probs, real_values

In [11]:
history = defaultdict(list)
best_accuracy = 0

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    print('-' * 10)

    train_acc, train_loss = train_model(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(train_dataset)
    )

    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(val_dataset)
    )

    print(f'Val   loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc

Epoch 1/20
----------


  0%|          | 0/62 [00:00<?, ?it/s]

100%|██████████| 62/62 [00:43<00:00,  1.42it/s]

Train loss 0.6741956885783903 accuracy 0.6579212916246217





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 2/20
----------


100%|██████████| 62/62 [00:42<00:00,  1.44it/s]

Train loss 0.6738038476436369 accuracy 0.656912209889001





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 3/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.44it/s]

Train loss 0.6740864505690913 accuracy 0.6528758829465187





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 4/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.44it/s]

Train loss 0.6746679438698676 accuracy 0.6508577194752775





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 5/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.44it/s]

Train loss 0.6750546914915885 accuracy 0.6533804238143289





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 6/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6766791007211131 accuracy 0.6528758829465187





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 7/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6741413156832418 accuracy 0.649848637739657





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 8/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6732426737585375 accuracy 0.6579212916246217





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 9/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6729510090043468 accuracy 0.656912209889001





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 10/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6734569457269484 accuracy 0.6533804238143289





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 11/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6722271875027688 accuracy 0.6599394550958628





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 12/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6751634170932155 accuracy 0.6518668012108981





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 13/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6743647994533661 accuracy 0.6564076690211907





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 14/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6774291117345134 accuracy 0.6508577194752775





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 15/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.44it/s]

Train loss 0.6724115408235981 accuracy 0.6604439959636731





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 16/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.674364649480389 accuracy 0.6528758829465187





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 17/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6740090568219462 accuracy 0.6478304742684158





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 18/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6758202477808921 accuracy 0.648335015136226





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 19/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.43it/s]

Train loss 0.6730416119098663 accuracy 0.6453077699293643





Val   loss 0.6756043061614037 accuracy 0.657258064516129

Epoch 20/20
----------


100%|██████████| 62/62 [00:43<00:00,  1.44it/s]

Train loss 0.673704965460685 accuracy 0.656912209889001





Val   loss 0.6756043061614037 accuracy 0.657258064516129

