In [None]:
!pip install transformers torch torchvision scikit-learn

In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
import torch.optim as optim

In [None]:
# Read the CSV and discard every row that has a missing value in any column.
dataset_path = '/kaggle/input/legal-text-classification-dataset/legal_text_classification.csv'
df = pd.read_csv(dataset_path).dropna()

# Map the outcome strings to integer ids. The fitted encoder is kept because its
# `classes_` are reused later (number of model labels, report class names).
label_encoder = LabelEncoder()
df = df.assign(case_outcome_encoded=label_encoder.fit_transform(df['case_outcome']))

# Rows per encoded class.
# NOTE(review): this is computed on the whole cleaned frame, before any train/validation
# split, although the printed label says "training data".
class_counts = df['case_outcome_encoded'].value_counts()
print("Class distribution in training data:\n", class_counts)

In [None]:
# Hold out 20% of the rows for validation, stratified on the encoded outcome.
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['case_outcome_encoded'], random_state=42)

# Random oversampling: draw (with replacement) as many rows from every class as the
# largest class has. Only the training split is resampled; val_df is left untouched.
majority_class_size = train_df['case_outcome_encoded'].value_counts().max()

# A single seeded generator shared by all classes makes the resampling reproducible
# on "Restart & Run All" (the original draw was unseeded).
oversample_rng = np.random.RandomState(42)

# Iterate over the groups explicitly instead of groupby(...).apply(...): apply operating on
# the grouping column is deprecated in recent pandas, and this code needs that column
# ('case_outcome_encoded') to survive in the result. Groups come back sorted by key, and
# ignore_index=True gives the same fresh 0..n-1 index as reset_index(drop=True).
train_df_balanced = pd.concat(
    [
        class_rows.sample(majority_class_size, replace=True, random_state=oversample_rng)
        for _, class_rows in train_df.groupby('case_outcome_encoded')
    ],
    ignore_index=True,
)

balanced_class_counts = train_df_balanced['case_outcome_encoded'].value_counts()
print("Class distribution after oversampling:\n", balanced_class_counts)

In [None]:
class LegalDataset(Dataset):
    """Dataset that tokenizes one row of a case frame per item.

    Args:
        data: object supporting ``len()`` and ``.iloc[i]``, where each row exposes
            the 'case_text' and 'case_outcome_encoded' entries.
        tokenizer: object providing ``encode_plus``.
        max_length: sequence length requested from the tokenizer (with
            ``padding='max_length'`` and ``truncation=True``).
    """

    def __init__(self, data, tokenizer, max_length):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the 'input_ids', 'attention_mask' and 'labels' tensors for row ``index``."""
        # Positional lookup, so the frame's index labels do not need to be 0..n-1.
        row = self.data.iloc[index]

        encoding = self.tokenizer.encode_plus(
            str(row['case_text']),
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # flatten() collapses each tokenizer tensor to 1-D before it is returned.
        sample = {key: encoding[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(row['case_outcome_encoded'], dtype=torch.long)
        return sample



In [None]:
# Checkpoint name shared by the tokenizer and the classifier.
legalbert_model_name = 'nlpaueb/legal-bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(legalbert_model_name)

# The classification head gets one label per class known to the fitted label encoder.
model = BertForSequenceClassification.from_pretrained(
    legalbert_model_name,
    num_labels=len(label_encoder.classes_),
)

In [None]:
# Data / batching settings: max_length is handed to LegalDataset as the tokenizer's
# max_length, batch_size is used by both DataLoaders.
max_length, batch_size = 256, 16

# Optimisation settings: epochs drives the training loop and the scheduler's step budget,
# learning_rate and weight_decay are passed to AdamW.
epochs = 6
learning_rate, weight_decay = 2e-5, 0.01

In [None]:
# Wrap the oversampled training frame and the untouched validation frame.
train_dataset, val_dataset = (
    LegalDataset(frame, tokenizer, max_length)
    for frame in (train_df_balanced, val_df)
)

# Only the training loader shuffles; validation keeps the frame's row order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [None]:
# 'balanced' class weights derived from the labels of the oversampled training frame.
# NOTE(review): every class in train_df_balanced was sampled to the same size, so these
# weights are presumably all equal — confirm whether train_df was meant here.
balanced_labels = train_df_balanced['case_outcome_encoded']
class_weights = torch.tensor(
    compute_class_weight(
        class_weight='balanced',
        classes=np.unique(balanced_labels),
        y=balanced_labels,
    ),
    dtype=torch.float,
)

# NOTE(review): the training and evaluation functions in this notebook read `outputs.loss`
# from the model and are never handed `loss_fn`, so this weighted loss is not applied
# anywhere in the visible code. The weight tensor is also created without an explicit device.
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# The scheduler is stepped once per batch, so its budget is batches-per-epoch times epochs.
total_steps = epochs * len(train_loader)

optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Linear schedule with no warm-up steps over the whole training run.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [None]:
def train_epoch(model, data_loader, optimizer, device, scheduler, loss_fn=None):
    """Run one training pass over ``data_loader``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output exposes ``.loss`` and ``.logits``.
        data_loader: iterable of batches with 'input_ids', 'attention_mask' and
            'labels' tensors.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
        loss_fn: optional callable ``loss_fn(logits, labels)``. When given it replaces
            the loss computed by the model (e.g. a class-weighted cross entropy);
            when omitted the behaviour is the model's own ``outputs.loss``.

    Returns:
        ``(accuracy, mean_loss)`` for the epoch. Accuracy is measured over the samples
        the loader actually yielded. If the loader yields nothing, ``(0.0, nan)``.
    """
    model = model.train()
    if isinstance(loss_fn, torch.nn.Module):
        # Buffers such as class weights must live on the same device as the logits.
        loss_fn = loss_fn.to(device)

    losses = []
    correct_predictions = 0
    seen_examples = 0

    for d in tqdm(data_loader):
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        labels = d['labels'].to(device)

        optimizer.zero_grad()
        if loss_fn is None:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
        else:
            # Skip the model's internal loss and apply the caller-supplied one.
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        preds = torch.argmax(logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()
        # Count what was really processed: with drop_last or a sampler this can differ
        # from len(data_loader.dataset), which would skew the accuracy.
        seen_examples += labels.size(0)

    if seen_examples == 0:
        # Nothing was yielded: avoid a division by zero / mean of an empty list.
        return 0.0, float('nan')

    return correct_predictions / seen_examples, np.mean(losses)


In [None]:
def eval_model(model, data_loader, device, loss_fn=None):
    """Evaluate ``model`` on ``data_loader`` and print a per-class report.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output exposes ``.loss`` and ``.logits``.
        data_loader: iterable of batches with 'input_ids', 'attention_mask' and
            'labels' tensors.
        device: device the batch tensors are moved to.
        loss_fn: optional callable ``loss_fn(logits, labels)`` used instead of the
            model's own ``outputs.loss``.

    Returns:
        ``(accuracy, mean_loss)`` over the samples the loader actually yielded.
        If the loader yields nothing, no report is printed and ``(0.0, nan)`` is returned.

    Note:
        Class names for the report are read from the notebook-level ``label_encoder``.
    """
    model = model.eval()
    if isinstance(loss_fn, torch.nn.Module):
        # Buffers such as class weights must live on the same device as the logits.
        loss_fn = loss_fn.to(device)

    losses = []
    correct_predictions = 0
    seen_examples = 0
    y_preds = []
    y_true = []

    with torch.no_grad():
        for d in tqdm(data_loader):
            input_ids = d['input_ids'].to(device)
            attention_mask = d['attention_mask'].to(device)
            labels = d['labels'].to(device)

            if loss_fn is None:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                logits = outputs.logits
            else:
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                loss = loss_fn(logits, labels)

            losses.append(loss.item())
            preds = torch.argmax(logits, dim=1)
            correct_predictions += torch.sum(preds == labels).item()
            # Count what was really processed instead of trusting len(dataset).
            seen_examples += labels.size(0)
            y_preds.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    if seen_examples == 0:
        # Nothing to score: avoid a division by zero and a report on empty inputs.
        return 0.0, float('nan')

    class_names = label_encoder.classes_
    report = classification_report(
        y_true,
        y_preds,
        # Pin the full label set (the encoder assigns ids 0..n_classes-1) so the report
        # does not fail when a class is absent from both y_true and y_preds.
        labels=list(range(len(class_names))),
        target_names=class_names,
        # Classes that were never predicted score 0 without emitting a warning.
        zero_division=0,
    )
    print(report)
    return correct_predictions / seen_examples, np.mean(losses)


In [None]:
# Prefer the GPU when one is available; the model is moved there before training starts.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)

separator = '-' * 20

# One training pass followed by one validation pass per epoch; eval_model also prints
# its per-class report on every call.
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    print(separator)

    train_metrics = train_epoch(model, train_loader, optimizer, device, scheduler)
    train_acc, train_loss = train_metrics
    print(f'Train loss {train_loss}, accuracy {train_acc}')

    val_metrics = eval_model(model, val_loader, device)
    val_acc, val_loss = val_metrics
    print(f'Validation loss {val_loss}, accuracy {val_acc}')
