In [1]:
# Install required libraries
!pip install transformers torch torchvision scikit-learn



In [2]:
# Import necessary libraries
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm

In [3]:
# Load and preprocess the dataset
dataset_path = '/kaggle/input/legal-text-classification-dataset/legal_text_classification.csv'
df = pd.read_csv(dataset_path)

# Check for missing values and handle them
df.dropna(inplace=True)

# Encode labels
label_encoder = LabelEncoder()
df['case_outcome_encoded'] = label_encoder.fit_transform(df['case_outcome'])

# Split the dataset into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['case_outcome_encoded'], random_state=42)

In [4]:
# Define the custom dataset class
class LegalDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        case_text = str(self.data.iloc[index]['case_text'])
        case_outcome = self.data.iloc[index]['case_outcome_encoded']

        encoding = self.tokenizer.encode_plus(
            case_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(case_outcome, dtype=torch.long)
        }

In [5]:
# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_encoder.classes_))

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Define hyperparameters
max_length = 256
batch_size = 16
epochs = 6
learning_rate = 2e-5

# Create datasets and dataloaders
train_dataset = LegalDataset(train_df, tokenizer, max_length)
val_dataset = LegalDataset(val_df, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=WeightedRandomSampler(
    weights=torch.tensor([1.0 / count for count in np.bincount(train_df['case_outcome_encoded'])], dtype=torch.float),
    num_samples=len(train_dataset),
    replacement=True
))

val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [7]:
# Set up optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=learning_rate, correct_bias=False)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)



In [8]:
# Define the training and evaluation functions
def train_epoch(model, data_loader, optimizer, device, scheduler):
    model = model.train()
    losses = []
    correct_predictions = 0

    for d in tqdm(data_loader):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        labels = d['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        preds = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()

    return correct_predictions / len(data_loader.dataset), np.mean(losses)

def eval_model(model, data_loader, device):
    model = model.eval()
    losses = []
    correct_predictions = 0
    y_preds = []
    y_true = []

    with torch.no_grad():
        for d in tqdm(data_loader):
            input_ids = d['input_ids'].to(device)
            attention_mask = d['attention_mask'].to(device)
            labels = d['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            losses.append(loss.item())
            preds = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(preds == labels).item()

            y_preds.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    return correct_predictions / len(data_loader.dataset), np.mean(losses), y_preds, y_true

In [9]:
# Training loop with early stopping
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

best_accuracy = 0
patience = 2
epochs_without_improvement = 0

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    
    train_acc, train_loss = train_epoch(model, train_loader, optimizer, device, scheduler)
    val_acc, val_loss, y_preds, y_true = eval_model(model, val_loader, device)
    
    print(f"Train loss: {train_loss}, Train accuracy: {train_acc}")
    print(f"Validation loss: {val_loss}, Validation accuracy: {val_acc}")
    
    # Check for early stopping
    if val_acc > best_accuracy:
        best_accuracy = val_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping")
            break

Epoch 1/6


100%|██████████| 1241/1241 [14:03<00:00,  1.47it/s]
100%|██████████| 311/311 [01:59<00:00,  2.59it/s]


Train loss: 0.016443010778609126, Train accuracy: 0.997480727565879
Validation loss: 4.234352544189649, Validation accuracy: 0.4419588875453446
Epoch 2/6


100%|██████████| 1241/1241 [14:03<00:00,  1.47it/s]
100%|██████████| 311/311 [02:00<00:00,  2.59it/s]


Train loss: 0.00047968278067636003, Train accuracy: 1.0
Validation loss: 4.613440788443832, Validation accuracy: 0.44397420395002013
Epoch 3/6


100%|██████████| 1241/1241 [13:58<00:00,  1.48it/s]
100%|██████████| 311/311 [01:59<00:00,  2.59it/s]


Train loss: 0.00024890176737581593, Train accuracy: 1.0
Validation loss: 4.8754301979610775, Validation accuracy: 0.44578798871422814
Epoch 4/6


100%|██████████| 1241/1241 [14:03<00:00,  1.47it/s]
100%|██████████| 311/311 [02:00<00:00,  2.59it/s]


Train loss: 0.00015642073875480564, Train accuracy: 1.0
Validation loss: 5.081053636465042, Validation accuracy: 0.44740024183796856
Epoch 5/6


100%|██████████| 1241/1241 [14:02<00:00,  1.47it/s]
100%|██████████| 311/311 [01:59<00:00,  2.60it/s]


Train loss: 0.00010782079853477312, Train accuracy: 1.0
Validation loss: 5.231888178077158, Validation accuracy: 0.44740024183796856
Epoch 6/6


100%|██████████| 1241/1241 [13:56<00:00,  1.48it/s]
100%|██████████| 311/311 [01:59<00:00,  2.60it/s]

Train loss: 8.556773154572325e-05, Train accuracy: 1.0
Validation loss: 5.295195315811795, Validation accuracy: 0.44740024183796856
Early stopping





In [10]:
# Load the best model and evaluate
model.load_state_dict(torch.load('best_model.pt'))
val_acc, val_loss, y_preds, y_true = eval_model(model, val_loader, device)
print("Final evaluation")
print(classification_report(y_true, y_preds, target_names=label_encoder.classes_))

100%|██████████| 311/311 [01:59<00:00,  2.60it/s]

Final evaluation
               precision    recall  f1-score   support

     affirmed       0.00      0.00      0.00        21
      applied       0.06      0.00      0.01       488
     approved       0.00      0.00      0.00        21
        cited       0.49      0.87      0.62      2422
   considered       0.00      0.00      0.00       340
    discussed       0.00      0.00      0.00       204
distinguished       0.00      0.00      0.00       121
     followed       0.00      0.00      0.00       450
  referred to       0.19      0.13      0.15       873
      related       0.00      0.00      0.00        22

     accuracy                           0.45      4962
    macro avg       0.07      0.10      0.08      4962
 weighted avg       0.28      0.45      0.33      4962




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
