In [1]:
!pip install transformers torch torchvision scikit-learn



In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
import torch.optim as optim

In [3]:
dataset_path = '/kaggle/input/legal-text-classification-dataset/legal_text_classification.csv'
df = pd.read_csv(dataset_path)

df.dropna(inplace=True)

label_encoder = LabelEncoder()
df['case_outcome_encoded'] = label_encoder.fit_transform(df['case_outcome'])

class_counts = df['case_outcome_encoded'].value_counts()
print("Class distribution in training data:\n", class_counts)

Class distribution in training data:
 case_outcome_encoded
3    12110
8     4363
1     2438
7     2252
4     1699
5     1018
6      603
9      112
2      108
0      106
Name: count, dtype: int64


In [4]:
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['case_outcome_encoded'], random_state=42)

majority_class_size = train_df['case_outcome_encoded'].value_counts().max()
train_df_balanced = train_df.groupby('case_outcome_encoded', group_keys=False)\
                            .apply(lambda x: x.sample(majority_class_size, replace=True)).reset_index(drop=True)

balanced_class_counts = train_df_balanced['case_outcome_encoded'].value_counts()
print("Class distribution after oversampling:\n", balanced_class_counts)

Class distribution after oversampling:
 case_outcome_encoded
0    9688
1    9688
2    9688
3    9688
4    9688
5    9688
6    9688
7    9688
8    9688
9    9688
Name: count, dtype: int64


  .apply(lambda x: x.sample(majority_class_size, replace=True)).reset_index(drop=True)


In [5]:
class LegalDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        case_text = str(self.data.iloc[index]['case_text'])
        case_outcome = self.data.iloc[index]['case_outcome_encoded']

        encoding = self.tokenizer.encode_plus(
            case_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(case_outcome, dtype=torch.long)
        }



In [6]:
legalbert_model_name = 'nlpaueb/legal-bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(legalbert_model_name)
model = BertForSequenceClassification.from_pretrained(legalbert_model_name, num_labels=len(label_encoder.classes_))

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
max_length = 256
batch_size = 16
epochs = 6
learning_rate = 2e-5
weight_decay = 0.01

In [8]:
train_dataset = LegalDataset(train_df_balanced, tokenizer, max_length)
val_dataset = LegalDataset(val_df, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [9]:
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_df_balanced['case_outcome_encoded']), y=train_df_balanced['case_outcome_encoded'])
class_weights = torch.tensor(class_weights, dtype=torch.float)

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

In [10]:
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

In [11]:
def train_epoch(model, data_loader, optimizer, device, scheduler):
    model = model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        labels = d['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        preds = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()

    return correct_predictions / len(data_loader.dataset), np.mean(losses)


In [12]:
def eval_model(model, data_loader, device):
    model = model.eval()
    losses = []
    correct_predictions = 0
    y_preds = []
    y_true = []

    with torch.no_grad():
        for d in tqdm(data_loader):
            input_ids = d['input_ids'].to(device)
            attention_mask = d['attention_mask'].to(device)
            labels = d['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            losses.append(loss.item())
            preds = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(preds == labels).item()
            y_preds.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    report = classification_report(y_true, y_preds, target_names=label_encoder.classes_)
    print(report)
    return correct_predictions / len(data_loader.dataset), np.mean(losses)


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    print('-' * 20)

    train_acc, train_loss = train_epoch(model, train_loader, optimizer, device, scheduler)
    print(f'Train loss {train_loss}, accuracy {train_acc}')

    val_acc, val_loss = eval_model(model, val_loader, device)
    print(f'Validation loss {val_loss}, accuracy {val_acc}')


Epoch 1/6
--------------------


100%|██████████| 6055/6055 [1:13:44<00:00,  1.37it/s]


Train loss 0.8601962563760954, accuracy 0.6860239471511148


100%|██████████| 311/311 [01:59<00:00,  2.61it/s]


               precision    recall  f1-score   support

     affirmed       0.48      0.67      0.56        21
      applied       0.30      0.28      0.29       488
     approved       0.25      0.19      0.22        21
        cited       0.73      0.45      0.56      2422
   considered       0.20      0.45      0.28       340
    discussed       0.27      0.34      0.30       204
distinguished       0.28      0.52      0.37       121
     followed       0.26      0.46      0.33       450
  referred to       0.45      0.46      0.45       873
      related       0.52      0.59      0.55        22

     accuracy                           0.43      4962
    macro avg       0.37      0.44      0.39      4962
 weighted avg       0.52      0.43      0.46      4962

Validation loss 1.7857638886887162, accuracy 0.4349052801289803
Epoch 2/6
--------------------


100%|██████████| 6055/6055 [1:13:46<00:00,  1.37it/s]


Train loss 0.2976536213950209, accuracy 0.9038088356729975


100%|██████████| 311/311 [01:58<00:00,  2.62it/s]


               precision    recall  f1-score   support

     affirmed       0.47      0.71      0.57        21
      applied       0.36      0.32      0.34       488
     approved       0.18      0.24      0.20        21
        cited       0.68      0.69      0.69      2422
   considered       0.35      0.34      0.34       340
    discussed       0.30      0.29      0.29       204
distinguished       0.41      0.40      0.40       121
     followed       0.41      0.36      0.38       450
  referred to       0.51      0.50      0.51       873
      related       0.15      0.73      0.25        22

     accuracy                           0.54      4962
    macro avg       0.38      0.46      0.40      4962
 weighted avg       0.54      0.54      0.54      4962

Validation loss 1.8620954093154987, accuracy 0.5421201128577187
Epoch 3/6
--------------------


100%|██████████| 6055/6055 [1:13:50<00:00,  1.37it/s]


Train loss 0.18408205295104646, accuracy 0.9413398018166804


100%|██████████| 311/311 [01:58<00:00,  2.62it/s]


               precision    recall  f1-score   support

     affirmed       0.50      0.62      0.55        21
      applied       0.33      0.34      0.34       488
     approved       0.18      0.24      0.20        21
        cited       0.71      0.65      0.67      2422
   considered       0.36      0.31      0.33       340
    discussed       0.31      0.34      0.32       204
distinguished       0.45      0.40      0.42       121
     followed       0.43      0.35      0.39       450
  referred to       0.45      0.60      0.51       873
      related       0.50      0.59      0.54        22

     accuracy                           0.54      4962
    macro avg       0.42      0.44      0.43      4962
 weighted avg       0.55      0.54      0.54      4962

Validation loss 1.9887617013845413, accuracy 0.5376864167674325
Epoch 4/6
--------------------


100%|██████████| 6055/6055 [1:13:52<00:00,  1.37it/s]


Train loss 0.13618526202661133, accuracy 0.9537365813377374


100%|██████████| 311/311 [01:58<00:00,  2.62it/s]


               precision    recall  f1-score   support

     affirmed       0.47      0.67      0.55        21
      applied       0.34      0.41      0.37       488
     approved       0.16      0.24      0.19        21
        cited       0.73      0.63      0.68      2422
   considered       0.30      0.40      0.34       340
    discussed       0.29      0.28      0.28       204
distinguished       0.47      0.40      0.43       121
     followed       0.40      0.41      0.40       450
  referred to       0.48      0.55      0.51       873
      related       0.67      0.64      0.65        22

     accuracy                           0.54      4962
    macro avg       0.43      0.46      0.44      4962
 weighted avg       0.56      0.54      0.55      4962

Validation loss 2.1760138275155683, accuracy 0.5362756952841596
Epoch 5/6
--------------------


100%|██████████| 6055/6055 [1:13:51<00:00,  1.37it/s]


Train loss 0.11113160056562696, accuracy 0.9590730800990916


100%|██████████| 311/311 [02:00<00:00,  2.59it/s]


               precision    recall  f1-score   support

     affirmed       0.65      0.62      0.63        21
      applied       0.34      0.38      0.36       488
     approved       0.29      0.19      0.23        21
        cited       0.70      0.66      0.68      2422
   considered       0.34      0.33      0.34       340
    discussed       0.31      0.34      0.32       204
distinguished       0.45      0.36      0.40       121
     followed       0.41      0.37      0.39       450
  referred to       0.48      0.56      0.51       873
      related       0.54      0.64      0.58        22

     accuracy                           0.54      4962
    macro avg       0.45      0.44      0.44      4962
 weighted avg       0.55      0.54      0.54      4962

Validation loss 2.241248156767566, accuracy 0.5411124546553809
Epoch 6/6
--------------------


100%|██████████| 6055/6055 [1:13:48<00:00,  1.37it/s]


Train loss 0.09644658340290758, accuracy 0.9613645747316267


100%|██████████| 311/311 [01:58<00:00,  2.62it/s]

               precision    recall  f1-score   support

     affirmed       0.48      0.62      0.54        21
      applied       0.37      0.31      0.34       488
     approved       0.18      0.24      0.20        21
        cited       0.69      0.68      0.68      2422
   considered       0.33      0.35      0.34       340
    discussed       0.30      0.33      0.32       204
distinguished       0.43      0.35      0.39       121
     followed       0.41      0.39      0.40       450
  referred to       0.48      0.55      0.51       873
      related       0.56      0.64      0.60        22

     accuracy                           0.54      4962
    macro avg       0.42      0.44      0.43      4962
 weighted avg       0.55      0.54      0.54      4962

Validation loss 2.3132892131230456, accuracy 0.5447400241837969



