In [16]:
from claims_dataset import ClaimEvidenceDataset
from siamese_bert import SiameseBert
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

In [4]:
# Hold out 20% of the labelled training data for validation. The fixed
# random_state makes the split (and therefore the saved CSVs) reproducible.
labelled_df = pd.read_csv('data/train.csv')
train_df, val_df = train_test_split(labelled_df, test_size=0.2, random_state=42)

# Sanity check: the two splits partition the labelled rows with no overlap.
assert len(train_df) + len(val_df) == len(labelled_df)
assert train_df.index.intersection(val_df.index).empty

# Persist the splits; the training cell reads them back by these paths.
train_df.to_csv('train_split.csv', index=False)
val_df.to_csv('val_split.csv', index=False)

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

class Trainer:
    """Train a SiameseBert claim/evidence classifier and checkpoint the best epoch.

    Each epoch runs one pass over the training split, then one pass over the
    validation split. Whenever validation loss improves, the model's
    ``state_dict`` is saved. The learning rate is then decayed with ``StepLR``.
    """

    # Batch keys fed to the model, in the positional order its forward is called with.
    _INPUT_KEYS = (
        'claim_input_ids',
        'evidence_input_ids',
        'claim_attention_mask',
        'evidence_attention_mask',
    )

    def __init__(self, train_csv, validation_csv, epochs=3, batch_size=32,
                 lr=5e-5, lr_gamma=0.9, checkpoint_path='best_model.pth'):
        """Build datasets, loaders, model, loss, optimizer and LR scheduler.

        Args:
            train_csv: Path to the training split CSV (read by ClaimEvidenceDataset).
            validation_csv: Path to the validation split CSV.
            epochs: Number of passes over the training data.
            batch_size: Batch size for both the training and validation loaders.
            lr: Initial Adam learning rate.
            lr_gamma: Multiplicative LR decay applied after every epoch (StepLR, step_size=1).
            checkpoint_path: Where the best (lowest validation loss) state_dict is saved.
        """
        self.epochs = epochs
        self.checkpoint_path = checkpoint_path

        # NOTE(review): assumes each dataset item is a dict with the keys in
        # _INPUT_KEYS plus 'labels' (the keys _forward reads) — confirm
        # against ClaimEvidenceDataset.
        self.train_dataset = ClaimEvidenceDataset(train_csv)
        self.val_dataset = ClaimEvidenceDataset(validation_csv)
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = SiameseBert().to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=lr_gamma)

        self.best_val_loss = float('inf')

    def _forward(self, batch):
        """Move one batch to the device and return ``(logits, labels)``."""
        inputs = [batch[key].to(self.device) for key in self._INPUT_KEYS]
        labels = batch['labels'].to(self.device)
        return self.model(*inputs), labels

    def _run_epoch(self, loader, *, train):
        """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

        When ``train`` is True, the model is put in training mode and an
        optimizer step is taken per batch. Otherwise it runs in eval mode with
        gradients disabled.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        if len(loader) == 0:
            raise ValueError('cannot run an epoch over an empty DataLoader')

        self.model.train(train)
        loss_sum = 0.0
        correct = 0
        seen = 0

        with torch.set_grad_enabled(train):
            for batch in loader:
                output, labels = self._forward(batch)
                loss = self.criterion(output, labels)

                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Weight each batch's mean loss by its size so a short final
                # batch doesn't skew the epoch average (assumes the criterion
                # uses reduction='mean', which is CrossEntropyLoss's default).
                batch_size = labels.size(0)
                loss_sum += loss.item() * batch_size
                correct += (torch.argmax(output, dim=1) == labels).sum().item()
                seen += batch_size

        return loss_sum / seen, correct / seen

    def trainModel(self):
        """Train for ``self.epochs`` epochs, validating and checkpointing after each."""
        for epoch in range(self.epochs):
            train_loss, _ = self._run_epoch(self.train_loader, train=True)
            val_loss, val_accuracy = self._run_epoch(self.val_loader, train=False)
            print(f'Epoch {epoch+1}/{self.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

            # Keep only the checkpoint with the lowest validation loss seen so far.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.checkpoint_path)
                print('Saved best model')

            self.scheduler.step()
            print(f'Updated Learning Rate: {self.scheduler.get_last_lr()[0]:.6f}')


In [None]:
# Train on the split CSVs written earlier; Trainer.trainModel saves the
# checkpoint with the lowest validation loss.
model_trainer = Trainer(train_csv='train_split.csv', validation_csv='val_split.csv')
model_trainer.trainModel()