In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
from tqdm import tqdm
import time
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import os
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm
2025-04-16 22:40:10.714717: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 22:40:10.836319: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744823410.894686 2684402 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744823410.912598 2684402 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744823411.023066 2684402 computation_placer.cc:177] computation placer already r

In [2]:
# Configuration
MAX_LEN = 256         # padding/truncation length of every encoded example
BATCH_SIZE = 16
EPOCHS = 3
LEARNING_RATE = 2e-5  # NOTE: not referenced by the optimizer cell, which sets per-group rates
SEED = 42

# Seed torch (head initialisation, dropout, default DataLoader shuffling) and
# numpy so that a Restart & Run All reproduces the same training run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Per-class CrossEntropyLoss weights. Index order must match the integer label
# ids produced by ArgumentDataset.label_map.
CLASS_WEIGHTS = {
    'relation': torch.tensor([0.05, 0.3, 0.65]),  # no-relation, support, attack (adjust to class distribution)
    'source': torch.tensor([0.1, 0.2, 0.7]),      # non-arg, premise, conclusion
    'target': torch.tensor([0.1, 0.2, 0.7])       # non-arg, premise, conclusion
}

In [3]:
class ArgumentDataset(Dataset):
    """Map-style dataset of (source, target) argument-unit pairs for multi-task training.

    Each row of ``dataframe`` must provide the columns ``source_text``,
    ``target_text``, ``relation``, ``source_type`` and ``target_type``. String
    labels are converted to integer ids through ``label_map``.

    Args:
        dataframe: pandas DataFrame holding one example per row.
        tokenizer: Hugging Face tokenizer; called with a (text, text_pair) pair.
        max_len: padding/truncation length of every encoded example.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.max_len = max_len
        self.label_map = {
            'relation': {'no-relation': 0, 'support': 1, 'attack': 2},
            'source_type': {'non-argumentative': 0, 'premise': 1, 'conclusion': 2},
            'target_type': {'non-argumentative': 0, 'premise': 1, 'conclusion': 2}
        }

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _clean_text(value):
        """Return ``value`` as a string, mapping missing cells (NaN/None) to ""."""
        return "" if pd.isna(value) else str(value)

    def _encode_label(self, field, value):
        """Map a string label of ``field`` to its integer id, with a readable error."""
        mapping = self.label_map[field]
        try:
            return mapping[value]
        except KeyError:
            raise KeyError(
                f"Unknown {field} label {value!r}; expected one of {sorted(mapping)}"
            ) from None

    def __getitem__(self, index):
        """Return the tokenized pair and its three integer labels as tensors."""
        row = self.data.iloc[index]

        # Encode as a text *pair*: the tokenizer inserts RoBERTa's own pair
        # separator, and truncation=True ('longest_first') trims the longer
        # segment, so a long source can no longer push the target out entirely.
        inputs = self.tokenizer(
            self._clean_text(row['source_text']),
            self._clean_text(row['target_text']),
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )

        return {
            'input_ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'relation_label': torch.tensor(self._encode_label('relation', row['relation']), dtype=torch.long),
            'source_label': torch.tensor(self._encode_label('source_type', row['source_type']), dtype=torch.long),
            'target_label': torch.tensor(self._encode_label('target_type', row['target_type']), dtype=torch.long)
        }


In [4]:
class MultiTaskRoberta(torch.nn.Module):
    """RoBERTa encoder shared by three linear heads: relation, source type, target type.

    Every head reads the dropout-regularised final hidden state of the first
    token of each sequence.

    Args:
        model_name: pretrained checkpoint to load the encoder from.
        num_labels: number of classes per head (all three heads share it).
        dropout: dropout probability applied before the heads.
    """

    def __init__(self, model_name="roberta-base", num_labels=3, dropout=0.1):
        super().__init__()
        # Only the encoder (`.roberta`) of the sequence-classification model is
        # kept; its randomly initialised classifier head is discarded.
        self.roberta = RobertaForSequenceClassification.from_pretrained(model_name).roberta
        self.dropout = torch.nn.Dropout(dropout)

        # Read the width from the encoder config instead of hard-coding 768
        # (roberta-base), so larger checkpoints work unchanged.
        hidden_size = self.roberta.config.hidden_size

        # Task-specific heads
        self.relation_classifier = torch.nn.Linear(hidden_size, num_labels)
        self.source_classifier = torch.nn.Linear(hidden_size, num_labels)
        self.target_classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return a (relation_logits, source_logits, target_logits) tuple."""
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # First-token (<s>) representation used as the pooled sequence vector.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        return (
            self.relation_classifier(pooled_output),
            self.source_classifier(pooled_output),
            self.target_classifier(pooled_output)
        )


In [5]:
# Load the labelled pairs and hold out 20% for validation, stratified on the
# relation label so both splits keep the same relation distribution.
df = pd.read_csv('balanced_relations_dataset.csv')

# Fail fast with a clear message instead of a KeyError deep inside a DataLoader worker.
REQUIRED_COLUMNS = {'source_text', 'target_text', 'relation', 'source_type', 'target_type'}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
assert not missing_columns, f"dataset is missing columns: {sorted(missing_columns)}"

train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['relation'], random_state=42)
assert len(train_df) + len(val_df) == len(df)

In [6]:
# One shared tokenizer for both splits so train and validation encodings are consistent.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
train_dataset, val_dataset = (
    ArgumentDataset(split_frame, tokenizer, MAX_LEN)
    for split_frame in (train_df, val_df)
)

In [7]:
# A dedicated, seeded generator makes the training shuffle order reproducible
# across runs; the validation loader keeps the split's original row order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")  # surfaces a silent CPU fallback early
model = MultiTaskRoberta().to(device)

# The pretrained encoder is fine-tuned gently; the freshly initialised heads
# use a 20x larger learning rate.
ENCODER_LR = 1e-5
HEAD_LR = 2e-4

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0) because
# torch's own defaults differ (1e-8, 1e-2). Separate param groups are kept so
# optimizer.state_dict() has the same structure as before.
optimizer = torch.optim.AdamW(
    [
        {'params': model.roberta.parameters(), 'lr': ENCODER_LR},
        {'params': model.relation_classifier.parameters(), 'lr': HEAD_LR},
        {'params': model.source_classifier.parameters(), 'lr': HEAD_LR},
        {'params': model.target_classifier.parameters(), 'lr': HEAD_LR},
    ],
    eps=1e-6,
    weight_decay=0.0,
)

  return torch._C._cuda_getDeviceCount() > 0
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# One class-weighted cross-entropy per head; the weight tensors are moved to
# the same device as the logits they will be applied to.
relation_criterion, source_criterion, target_criterion = (
    torch.nn.CrossEntropyLoss(weight=CLASS_WEIGHTS[task_key].to(device))
    for task_key in ('relation', 'source', 'target')
)


In [10]:
# Function to plot and save training curves
def plot_training_curves(history, save_dir, current_epoch):
    """Plot the training history as a 2x2 figure and save it alongside a metrics CSV.

    Args:
        history: dict of per-epoch lists with keys 'train_loss', 'val_loss',
            'relation_f1', 'source_acc', 'target_acc', plus 'relation_f1_classes'
            mapping 'no-relation' / 'support' / 'attack' to per-epoch lists.
            All these lists are expected to have the same length.
        save_dir: output directory (created if it does not exist).
        current_epoch: zero-based epoch index; only used in the PNG file name.

    Side effects:
        Writes ``training_curves_epoch_<current_epoch+1>.png`` and overwrites
        ``training_metrics.csv`` (the full history so far) inside ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    epochs = list(range(1, len(history['train_loss']) + 1))
    relation_classes = history['relation_f1_classes']

    # Explicit Figure/Axes objects instead of pyplot's implicit "current figure",
    # so exactly this figure is saved and closed even if others are open.
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    ax_loss, ax_relation, ax_source, ax_target = axes.flat

    # Panel 1: loss curves
    ax_loss.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    ax_loss.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    ax_loss.set(title='Training and Validation Loss', xlabel='Epochs', ylabel='Loss')

    # Panel 2: relation macro-F1 plus per-class F1
    ax_relation.plot(epochs, history['relation_f1'], 'g-', label='Macro F1')
    for key, fmt, label in (('no-relation', 'c--', 'No-relation'),
                            ('support', 'm--', 'Support'),
                            ('attack', 'y--', 'Attack')):
        ax_relation.plot(epochs, relation_classes[key], fmt, label=label)
    ax_relation.set(title='Relation Classification F1 Scores', xlabel='Epochs', ylabel='F1 Score')

    # Panels 3-4: source / target type accuracy
    ax_source.plot(epochs, history['source_acc'], 'b-', label='Source Accuracy')
    ax_source.set(title='Source Type Classification Accuracy', xlabel='Epochs', ylabel='Accuracy')
    ax_target.plot(epochs, history['target_acc'], 'r-', label='Target Accuracy')
    ax_target.set(title='Target Type Classification Accuracy', xlabel='Epochs', ylabel='Accuracy')

    for ax in axes.flat:
        ax.legend()
        ax.grid(True)
    fig.tight_layout()

    try:
        fig.savefig(os.path.join(save_dir, f'training_curves_epoch_{current_epoch+1}.png'), dpi=300)
    finally:
        # Always release the figure; otherwise one figure per epoch accumulates.
        plt.close(fig)

    # Save metrics as CSV
    metrics_df = pd.DataFrame({
        'epoch': epochs,
        'train_loss': history['train_loss'],
        'val_loss': history['val_loss'],
        'relation_f1': history['relation_f1'],
        'relation_f1_no_relation': relation_classes['no-relation'],
        'relation_f1_support': relation_classes['support'],
        'relation_f1_attack': relation_classes['attack'],
        'source_acc': history['source_acc'],
        'target_acc': history['target_acc']
    })
    metrics_df.to_csv(os.path.join(save_dir, 'training_metrics.csv'), index=False)

In [11]:
# Create a timestamped results directory for this run's checkpoints, plots and metrics.
results_dir = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

# Task order matches the order of MultiTaskRoberta's output tuple; class-name
# order matches the integer label ids produced by ArgumentDataset.
TASKS = ('relation', 'source', 'target')
REPORT_CLASS_NAMES = {
    'relation': ['no-relation', 'support', 'attack'],
    'source': ['non-arg', 'premise', 'conclusion'],
    'target': ['non-arg', 'premise', 'conclusion'],
}


def multitask_loss(outputs, inputs):
    """Return the weighted multi-task loss: 0.6 * relation + 0.2 * source + 0.2 * target.

    ``outputs`` is the (relation, source, target) logits tuple from the model;
    ``inputs`` holds the matching ``*_label`` tensors on the same device.
    """
    return (0.6 * relation_criterion(outputs[0], inputs['relation_label']) +
            0.2 * source_criterion(outputs[1], inputs['source_label']) +
            0.2 * target_criterion(outputs[2], inputs['target_label']))


def task_reports(labels, preds, class_names):
    """Return ``(text_report, dict_report)`` from sklearn's classification_report.

    classification_report returns a plain string unless ``output_dict=True``,
    so the dict form is needed to read individual metrics. ``labels`` is pinned
    to every class id so ``target_names`` stays aligned (instead of raising
    ValueError) when a class is absent from both labels and predictions.
    """
    common = dict(labels=list(range(len(class_names))), target_names=class_names, zero_division=0)
    return (classification_report(labels, preds, **common),
            classification_report(labels, preds, output_dict=True, **common))


def to_device(batch):
    """Move every tensor in a collated batch to ``device``; non-tensor fields are dropped."""
    return {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}


# Per-epoch tracking. "*_acc_classes" hold per-class recall, i.e. the fraction
# of each class's validation examples that were predicted correctly.
history = {
    'train_loss': [],
    'val_loss': [],
    'relation_f1': [],
    'source_acc': [],
    'target_acc': [],
    'relation_f1_classes': {name: [] for name in REPORT_CLASS_NAMES['relation']},
    'source_acc_classes': {name: [] for name in REPORT_CLASS_NAMES['source']},
    'target_acc_classes': {name: [] for name in REPORT_CLASS_NAMES['target']},
}

# Training Loop with Enhanced Logging and Checkpointing
for epoch in range(EPOCHS):
    print(f"\n{'='*40}")
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"{'='*40}\n")

    # ---- Training phase ----
    model.train()
    total_loss = 0.0
    batch_time = 0.0
    start_time = time.time()

    train_pbar = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=f"Epoch {epoch+1} Training",
                      unit="batch")

    for batch_idx, batch in train_pbar:
        batch_start = time.time()

        optimizer.zero_grad()
        inputs = to_device(batch)
        outputs = model(inputs['input_ids'], inputs['attention_mask'])
        loss = multitask_loss(outputs, inputs)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single host sync per batch
        total_loss += loss_value
        elapsed = time.time() - batch_start
        batch_time += elapsed

        train_pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'avg_loss': f"{total_loss/(batch_idx+1):.4f}",
            'batch_time': f"{elapsed:.2f}s"
        })

    epoch_time = time.time() - start_time
    avg_train_loss = total_loss / len(train_loader)

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    all_preds = {task: [] for task in TASKS}
    all_labels = {task: [] for task in TASKS}

    val_pbar = tqdm(enumerate(val_loader),
                    total=len(val_loader),
                    desc=f"Epoch {epoch+1} Validation",
                    unit="batch")

    with torch.no_grad():
        for batch_idx, batch in val_pbar:
            batch_start = time.time()

            inputs = to_device(batch)
            outputs = model(inputs['input_ids'], inputs['attention_mask'])
            loss = multitask_loss(outputs, inputs)
            loss_value = loss.item()
            val_loss += loss_value

            for task, logits in zip(TASKS, outputs):
                all_preds[task].extend(logits.argmax(dim=1).cpu().tolist())
                all_labels[task].extend(inputs[f'{task}_label'].cpu().tolist())

            val_pbar.set_postfix({
                'val_loss': f"{loss_value:.4f}",
                'avg_val_loss': f"{val_loss/(batch_idx+1):.4f}",
                'batch_time': f"{time.time()-batch_start:.2f}s"
            })

    avg_val_loss = val_loss / len(val_loader)
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)

    # Text reports for printing, dict reports for metric extraction.
    text_reports, dict_reports = {}, {}
    for task in TASKS:
        text_reports[task], dict_reports[task] = task_reports(
            all_labels[task], all_preds[task], REPORT_CLASS_NAMES[task]
        )

    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    print(f"Time: {epoch_time:.2f}s ({batch_time/len(train_loader):.2f}s/batch)")

    print("\nRelation Classification Report:")
    print(text_reports['relation'])

    print("Source Type Classification Report:")
    print(text_reports['source'])

    print("Target Type Classification Report:")
    print(text_reports['target'])

    print(f"{'='*40}\n")

    relation_dict = dict_reports['relation']
    history['relation_f1'].append(relation_dict['macro avg']['f1-score'])
    for name in REPORT_CLASS_NAMES['relation']:
        history['relation_f1_classes'][name].append(relation_dict[name]['f1-score'])

    history['source_acc'].append(accuracy_score(all_labels['source'], all_preds['source']))
    history['target_acc'].append(accuracy_score(all_labels['target'], all_preds['target']))
    for task in ('source', 'target'):
        for name in REPORT_CLASS_NAMES[task]:
            history[f'{task}_acc_classes'][name].append(dict_reports[task][name]['recall'])

    # Save model checkpoint
    checkpoint_path = os.path.join(results_dir, f"checkpoint_epoch_{epoch+1}.pt")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

    plot_training_curves(history, results_dir, epoch)



Epoch 1/3



Epoch 1 Training:  36%|███▌      | 87/245 [11:22<20:40,  7.85s/batch, loss=0.8792, avg_loss=0.8445, batch_time=8.28s] 


KeyboardInterrupt: 

In [None]:
# Persist the final weights next to this run's per-epoch checkpoints (rather than
# the working directory) so successive runs do not overwrite each other's model.
final_model_path = os.path.join(results_dir, 'roberta_multi_task_model.pth')
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved at {final_model_path}")
