<a href="https://colab.research.google.com/github/tubagokhan/OblgationClassfier/blob/main/ObligationClassfierV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [30]:
!gdown 1-Z7rDtJLh0m6QljxtfXSV4QdXRYbuxSf
#!gdown 1-1L2_l6q01aD7Zsw_XMeVQfKtwP031Y9
!gdown 1vJeCleL-o3m0TCTiKWNOxx4Gx2ZQOHwJ

Downloading...
From: https://drive.google.com/uc?id=1-Z7rDtJLh0m6QljxtfXSV4QdXRYbuxSf
To: /content/extended_obligation_classification_dataset.csv
100% 14.3M/14.3M [00:00<00:00, 178MB/s]
Downloading...
From: https://drive.google.com/uc?id=1vJeCleL-o3m0TCTiKWNOxx4Gx2ZQOHwJ
To: /content/annotated_obligation_classification_data_pure_regulations.csv
100% 1.69M/1.69M [00:00<00:00, 208MB/s]


In [31]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, get_linear_schedule_with_warmup, DistilBertConfig
import pandas as pd
from tqdm import tqdm
from torch.optim import AdamW
from sklearn.metrics import accuracy_score
import numpy as np
from torch.cuda.amp import GradScaler, autocast

# Load data
data = pd.read_csv('./annotated_obligation_classification_data_pure_regulations.csv')
data2 = pd.read_csv('./extended_obligation_classification_dataset.csv')

# Merge datasets and remove duplicates
#data = pd.concat([data, data2]).drop_duplicates().reset_index(drop=True)

# Ensure labels are numeric
data['label'] = data['label'].map({'obligation': 1, 'non-obligation': 0})

# Shuffle the data
data = data.sample(frac=1, random_state=42).reset_index(drop=True)


# Split into training, validation, and test sets
train_size = int(0.8 * len(data))
val_size = int(0.1 * len(data))
test_size = len(data) - train_size - val_size

train_data, val_data, test_data = np.split(data.sample(frac=1, random_state=42), [train_size, train_size + val_size])

# Print the lengths of the datasets to ensure they are not empty
print(f"Training set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

# Preprocess data
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 128  # Increased sequence length for better context
batch_size = 64  # Increased batch size

train_dataset = TextDataset(train_data['text'].to_list(), train_data['label'].to_list(), tokenizer, max_len)
val_dataset = TextDataset(val_data['text'].to_list(), val_data['label'].to_list(), tokenizer, max_len)
test_dataset = TextDataset(test_data['text'].to_list(), test_data['label'].to_list(), tokenizer, max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4)

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Modify DistilBERT configuration to include dropout
config = DistilBertConfig.from_pretrained('distilbert-base-uncased', num_labels=2, dropout=0.3, attention_dropout=0.3)
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', config=config)

if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 10  # Increased number of epochs
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()

# Training function
def train(model, train_loader, criterion, optimizer, scheduler, scaler, patience=2):
    model.train()
    total_loss = 0
    best_loss = np.inf
    patience_counter = 0
    best_model = None

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc="Training")

        for batch in progress_bar:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            with autocast():
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(batch_loss=loss.item())

        avg_loss = epoch_loss / len(train_loader)
        print(f'Average training loss: {avg_loss}')

        # Evaluate on validation set
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        print(f'Validation loss: {val_loss}, Validation accuracy: {val_acc}')

        # Early stopping
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            best_model = model.state_dict()  # Save the best model
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        total_loss += epoch_loss
    model.load_state_dict(best_model)  # Load the best model
    return total_loss / (epoch + 1)

# Evaluation function
def evaluate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    preds, true_labels = [], []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            with autocast():
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                val_loss += loss.item()

            logits = outputs.logits
            preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
    accuracy = accuracy_score(true_labels, preds) if len(preds) > 0 else 0
    return avg_loss, accuracy

# Training loop
train(model, train_loader, criterion, optimizer, scheduler, scaler, patience=2)

# Save model
torch.save(model.state_dict(), '/content/drive/Othercomputers/MBZUAI/MBZUAI/ADGM-Project/MyRetrievals/obligation_classification_model.pt')


Training set size: 5812
Validation set size: 726
Test set size: 728


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/10


Training: 100%|██████████| 91/91 [00:18<00:00,  4.90it/s, batch_loss=0.129]

Average training loss: 0.3208320517461378





Validation loss: 0.21308319146434465, Validation accuracy: 0.9297520661157025
Epoch 2/10


Training: 100%|██████████| 91/91 [00:18<00:00,  4.96it/s, batch_loss=0.154]

Average training loss: 0.16275908118420904





Validation loss: 0.14991482704256973, Validation accuracy: 0.953168044077135
Epoch 3/10


Training: 100%|██████████| 91/91 [00:18<00:00,  5.03it/s, batch_loss=0.0562]

Average training loss: 0.1116926891545018





Validation loss: 0.14776965665320554, Validation accuracy: 0.9462809917355371
Epoch 4/10


Training: 100%|██████████| 91/91 [00:17<00:00,  5.12it/s, batch_loss=0.0425]


Average training loss: 0.07293538810623872
Validation loss: 0.16200864881587526, Validation accuracy: 0.9504132231404959
Epoch 5/10


Training: 100%|██████████| 91/91 [00:17<00:00,  5.14it/s, batch_loss=0.0534]

Average training loss: 0.04379971509615144





Validation loss: 0.21646877378225327, Validation accuracy: 0.9504132231404959
Early stopping triggered.
