In [1]:
import nltk
import torch
from torch.utils.data import Dataset
from nltk.tokenize import sent_tokenize, word_tokenize
import torch.nn as nn
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef, cohen_kappa_score
import ast
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers.convert_slow_tokenizer")
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [2]:
def preprocess_boundaries(df):
    def adjust_boundaries(row):
        text = str(row['hybrid_text'])
        sentences = sent_tokenize(text)
        num_sentences = len(sentences)
        boundaries = row['boundary_ix']
        if isinstance(boundaries, str):
            try:
                boundaries = ast.literal_eval(boundaries)
            except (ValueError, SyntaxError):
                boundaries = [0]
        if not isinstance(boundaries, list):
            boundaries = [boundaries]
        boundaries = [min(int(b), num_sentences-1) for b in boundaries]
        return boundaries

    df['boundary_ix'] = df.apply(adjust_boundaries, axis=1)
    return df

In [3]:
class MixedTextDataset(Dataset):
    def __init__(self, texts, labels, author_seqs, tokenizer):
        self.texts = texts
        self.labels = labels
        self.author_seqs = author_seqs
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        author_seq = self.author_seqs[idx]
        boundaries = self.labels[idx]
        if not isinstance(boundaries, list):
            boundaries = [boundaries]
        boundaries = [int(b) for b in boundaries]
        sentences = sent_tokenize(text)
        word_labels = []
        words = []
        for sent_idx, sent in enumerate(sentences):
            current_words = word_tokenize(sent)
            words.extend(current_words)
            label = self._get_label_for_sentence(sent_idx, author_seq, boundaries)
            word_labels.extend([label] * len(current_words))
        input_ids = []
        attention_mask = []
        token_labels = []
        for word, label in zip(words, word_labels):
            encoded_word = self.tokenizer(
                word,
                add_special_tokens=False,
                return_attention_mask=True,
                return_tensors="pt"
            )
            word_ids = encoded_word["input_ids"].squeeze()
            word_ids = [word_ids.item()] if word_ids.dim() == 0 else word_ids.tolist()
            input_ids.extend(word_ids)
            attention_mask.extend([1] * len(word_ids))
            token_labels.extend([label] * len(word_ids))
        input_ids = [self.tokenizer.cls_token_id] + input_ids + [self.tokenizer.sep_token_id]
        attention_mask = [1] + attention_mask + [1]
        token_labels = [-100] + token_labels + [-100]
        if any(l not in [-100, 0, 1] for l in token_labels):
            print(f"Invalid labels at index {idx}: {token_labels}")
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(token_labels, dtype=torch.long)
        }

    def _get_label_for_sentence(self, sent_idx, author_seq, boundaries):
        if author_seq == 'M_H':
            return 1 if sent_idx < boundaries[0] else 0
        elif author_seq == 'H_M':
            return 0 if sent_idx < boundaries[0] else 1
        elif author_seq == 'H_M_H':
            if sent_idx < boundaries[0]:
                return 0
            elif sent_idx < boundaries[1]:
                return 1
            else:
                return 0
        elif author_seq == 'M_H_M':
            if sent_idx < boundaries[0]:
                return 1
            elif sent_idx < boundaries[1]:
                return 0
            else:
                return 1
        elif author_seq == 'H_M_H_M':
            if sent_idx < boundaries[0]:
                return 0
            elif sent_idx < boundaries[1]:
                return 1
            elif sent_idx < boundaries[2]:
                return 0
            else:
                return 1
        elif author_seq == 'M_H_M_H':
            if sent_idx < boundaries[0]:
                return 1
            elif sent_idx < boundaries[1]:
                return 0
            elif sent_idx < boundaries[2]:
                return 1
            else:
                return 0
        else:
            raise ValueError(f"Unknown author_seq: {author_seq}")


In [4]:
class BiGRUCRFTagger(nn.Module):
    def __init__(self, input_dim, num_labels, embedding_dim=768, hidden_dim=512, num_layers=2, dropout=0.3):
        super(BiGRUCRFTagger, self).__init__()
        self.num_labels = num_labels
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.embed_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2, 
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.hidden2tag = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        for name, param in self.named_parameters():
            if 'weight' in name:
                if len(param.shape) > 1:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.normal_(param, mean=0, std=0.01)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, input_ids, attention_mask, labels=None):
        if input_ids.numel() == 0:
            raise ValueError("Input tensor is empty.")
        embedded = self.embedding(input_ids)
        embedded = self.embed_dropout(embedded)
        gru_out, _ = self.gru(embedded)
        logits = self.hidden2tag(gru_out)
        if labels is not None:
            mask = attention_mask.bool()
            crf_labels = labels.clone()
            crf_labels[crf_labels == -100] = 0
            loss = -self.crf(logits, crf_labels, mask=mask, reduction='mean')
            return loss
        else:
            mask = attention_mask.bool()
            predictions = self.crf.decode(logits, mask=mask)
            return predictions

In [5]:
def train_model(model, data_loader, optimizer, scheduler, device, clip_value=1.0):
    model.train()
    total_loss = 0
    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        loss = model(input_ids, attention_mask, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)
        optimizer.step()
        if scheduler:
            scheduler.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)

def evaluate_model(model, data_loader, device):
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            mask = attention_mask.bool()
            predictions = model(input_ids, attention_mask)
            for pred_seq, label_seq, mask_seq in zip(predictions, labels, attention_mask):
                true_len = mask_seq.sum().item()
                valid_label = label_seq[:true_len].cpu().numpy()
                non_special_label_indices = (label_seq[:true_len] != -100).nonzero(as_tuple=True)[0]
                if non_special_label_indices.numel() == 0:
                    continue
                pred_seq_np = np.array(pred_seq)
                valid_pred = pred_seq_np[non_special_label_indices.cpu().numpy()]
                valid_label = label_seq[non_special_label_indices].cpu().numpy()
                all_predictions.extend(valid_pred)
                all_labels.extend(valid_label)
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    if len(all_labels) == 0:
        print("Warning: No valid tokens to evaluate. Returning default metrics.")
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    absolute_errors = np.abs(all_predictions - all_labels)
    mae = np.mean(absolute_errors)
    std_dev = np.std(absolute_errors)
    accuracy = accuracy_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_predictions)
    kappa = cohen_kappa_score(all_labels, all_predictions)
    return accuracy, precision, recall, f1, mcc, mae, std_dev, kappa

def custom_collate_fn(batch):
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids_batch = []
    attention_mask_batch = []
    labels_batch = []
    for item in batch:
        seq_len = len(item["input_ids"])
        pad_len = max_len - seq_len
        input_ids_batch.append(torch.cat([item["input_ids"], torch.tensor([0] * pad_len, dtype=torch.long)]))
        attention_mask_batch.append(torch.cat([item["attention_mask"], torch.tensor([0] * pad_len, dtype=torch.long)]))
        labels_batch.append(torch.cat([item["labels"], torch.tensor([-100] * pad_len, dtype=torch.long)]))
    return {
        "input_ids": torch.stack(input_ids_batch),
        "attention_mask": torch.stack(attention_mask_batch),
        "labels": torch.stack(labels_batch)
    }


In [6]:
config = {
    'batch_size': 16,
    'learning_rate': 1e-3,
    'hidden_dim': 512,
    'embedding_dim': 768,
    'num_layers': 2,
    'dropout': 0.3,
    'epochs': 3,
    'weight_decay': 0.01,
    'gradient_clip': 1.0
}

In [7]:
train_df = pd.read_csv('aaai_train.csv')
dev_df = pd.read_csv('aaai_valid.csv')
test_df = pd.read_csv('aaai_test.csv')

train_df = preprocess_boundaries(train_df)
dev_df = preprocess_boundaries(dev_df)
test_df = preprocess_boundaries(test_df)

def get_author_seqs(df, default_seq='H_M'):
    if 'author_seq' not in df.columns:
        print(f"Warning: 'author_seq' column missing in dataframe. Using default sequence '{default_seq}' for all samples.")
        return [default_seq] * len(df)
    return df["author_seq"].tolist()

train_texts = train_df["hybrid_text"].tolist()
train_labels = train_df["boundary_ix"].tolist()
train_author_seqs = get_author_seqs(train_df)
dev_texts = dev_df["hybrid_text"].tolist()
dev_labels = dev_df["boundary_ix"].tolist()
dev_author_seqs = get_author_seqs(dev_df)
test_texts = test_df["hybrid_text"].tolist()
test_labels = test_df["boundary_ix"].tolist()
test_author_seqs = get_author_seqs(test_df)

MODEL_NAME = 'microsoft/deberta-v3-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = MixedTextDataset(train_texts, train_labels, train_author_seqs, tokenizer)
dev_dataset = MixedTextDataset(dev_texts, dev_labels, dev_author_seqs, tokenizer)
test_dataset = MixedTextDataset(test_texts, test_labels, test_author_seqs, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
num_labels = 2

model = BiGRUCRFTagger(
    input_dim=vocab_size,
    num_labels=num_labels,
    embedding_dim=config['embedding_dim'],
    hidden_dim=config['hidden_dim'],
    num_layers=config['num_layers'],
    dropout=config['dropout']
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
total_steps = len(train_loader) * config['epochs']
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['learning_rate'],
                                               total_steps=total_steps, pct_start=0.1, anneal_strategy='cos')

In [8]:
best_kappa = -float('inf')
best_epoch = 0
patience = 2
patience_counter = 0
best_model_state = None
train_losses = []
val_metrics = {'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'mcc': [], 'mae': [], 'std_dev': [], 'kappa': []}

print("Starting training...")
print(f"Training on device: {device}")
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(dev_dataset)}")
print(f"Number of test examples: {len(test_dataset)}")

for epoch in range(config['epochs']):
    print(f"\nEpoch {epoch + 1}/{config['epochs']}")
    train_loss = train_model(model, train_loader, optimizer, scheduler, device, config['gradient_clip'])
    train_losses.append(train_loss)
    val_accuracy, val_precision, val_recall, val_f1, val_mcc, val_mae, val_std_dev, val_kappa = evaluate_model(model, dev_loader, device)
    val_metrics['accuracy'].append(val_accuracy)
    val_metrics['precision'].append(val_precision)
    val_metrics['recall'].append(val_recall)
    val_metrics['f1'].append(val_f1)
    val_metrics['mcc'].append(val_mcc)
    val_metrics['mae'].append(val_mae)
    val_metrics['std_dev'].append(val_std_dev)
    val_metrics['kappa'].append(val_kappa)
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Metrics:")
    print(f"Accuracy: {val_accuracy:.4f}")
    print(f"Precision: {val_precision:.4f}")
    print(f"Recall: {val_recall:.4f}")
    print(f"F1 Score: {val_f1:.4f}")
    print(f"MCC: {val_mcc:.4f}")
    print(f"Kappa: {val_kappa:.4f}")
    print(f"MAE: {val_mae:.2f}±{val_std_dev:.2f}")
    if val_kappa > best_kappa:
        best_kappa = val_kappa
        best_epoch = epoch + 1
        patience_counter = 0
        best_model_state = {
            'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
            'epoch': epoch + 1,
            'metrics': {
                'kappa': val_kappa,
                'f1': val_f1,
                'accuracy': val_accuracy,
                'precision': val_precision,
                'recall': val_recall,
                'mcc': val_mcc,
                'mae': val_mae,
                'std_dev': val_std_dev
            }
        }
        print(f"New best model with Kappa Score: {val_kappa:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs")
            break

print("\nTraining completed!")
print(f"Best model at epoch {best_epoch} with Kappa Score: {best_kappa:.4f}")

if best_model_state is not None:
    print("\nEvaluating test data...")
    model.load_state_dict(best_model_state['model_state_dict'])
    model.to(device)
    test_accuracy, test_precision, test_recall, test_f1, test_mcc, test_mae, test_std_dev, test_kappa = evaluate_model(model, test_loader, device)
    print("\nTest Metrics:")
    print(f"Accuracy: {test_accuracy:.4f}")
    print(f"Precision: {test_precision:.4f}")
    print(f"Recall: {test_recall:.4f}")
    print(f"F1 Score: {test_f1:.4f}")
    print(f"MCC: {test_mcc:.4f}")
    print(f"Kappa: {test_kappa:.4f}")
    print(f"MAE: {test_mae:.2f}±{test_std_dev:.2f}")
else:
    print("\nNo valid model state found. Skipping test evaluation.")

Starting training...
Training on device: cuda
Number of training examples: 12049
Number of validation examples: 2527
Number of test examples: 2560

Epoch 1/3
Train Loss: 69.0634
Validation Metrics:
Accuracy: 0.9371
Precision: 0.9380
Recall: 0.9371
F1 Score: 0.9373
MCC: 0.8608
Kappa: 0.8605
MAE: 0.06±0.24
New best model with Kappa Score: 0.8605

Epoch 2/3
Train Loss: 17.0914
Validation Metrics:
Accuracy: 0.9507
Precision: 0.9507
Recall: 0.9507
F1 Score: 0.9504
MCC: 0.8890
Kappa: 0.8883
MAE: 0.05±0.22
New best model with Kappa Score: 0.8883

Epoch 3/3
Train Loss: 7.4398
Validation Metrics:
Accuracy: 0.9515
Precision: 0.9516
Recall: 0.9515
F1 Score: 0.9512
MCC: 0.8908
Kappa: 0.8900
MAE: 0.05±0.21
New best model with Kappa Score: 0.8900

Training completed!
Best model at epoch 3 with Kappa Score: 0.8900

Evaluating test data...

Test Metrics:
Accuracy: 0.9509
Precision: 0.9510
Recall: 0.9509
F1 Score: 0.9506
MCC: 0.8910
Kappa: 0.8904
MAE: 0.05±0.22
