In [1]:
from av_dataset import VALLAAVTrainDataset, VALLAAVValDataset
import sys
import os
import torch
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import pandas as pd

# Add src directory to sys.path
# Adapted from Taras Alenin's answer on StackOverflow at:
# https://stackoverflow.com/a/55623567
src_path = os.path.join('..', 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Import custom modules
from constants import MODEL
from siamese_sbert import SiameseSBERT
from modified_contrastive_loss import ModifiedContrastiveLoss

In [2]:
path_to_train_csv = '../../pan20-authorship-verification-training-small/processed/small/pan20_train_DV_MA_k_20000.csv'
path_to_val_csv = '../../pan20-authorship-verification-training-small/processed/small/pan20_val_DV_MA_k_20000.csv'
train_dataset = VALLAAVTrainDataset(path_to_train_csv)
val_dataset = VALLAAVValDataset(path_to_val_csv)
print('Train len:', len(train_dataset))
print('Val len:', len(val_dataset))

Train len: 83934
Val len: 14202


In [3]:
def valla_collate_fn(batch):
    """
    Collate function to properly batch the paired inputs.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    labels = torch.tensor([item.label for item in batch])
    print(labels)
    anchor_texts = [tokenizer(item.texts[0],
                              return_tensors="pt",
                              padding='max_length',
                              truncation=True,
                              max_length=512) for item in batch]
    other_texts = [tokenizer(item.texts[1],
                             return_tensors="pt",
                             padding='max_length',
                             truncation=True,
                             max_length=512) for item in batch]

    # Combine input_ids and attention_masks
    batched_a = {
        'input_ids': torch.cat([x['input_ids'] for x in anchor_texts]),
        'attention_mask': torch.cat([x['attention_mask']
                                     for x in anchor_texts])
    }
    batched_o = {
        'input_ids': torch.cat([x['input_ids'] for x in other_texts]),
        'attention_mask': torch.cat([x['attention_mask']
                                     for x in other_texts])
    }

    return batched_a, batched_o, labels



In [4]:
model = SiameseSBERT(MODEL, 'cuda').to('cuda')

# Instantiate custom contrastive loss fuction
# 'modified contrastive loss'
loss_function = ModifiedContrastiveLoss(margin_s=0.75,
                                        margin_d=0.25)

# Instantiate Adam optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr=2e-05,
                             eps=1e-6)


scheduler = LinearLR(optimizer, start_factor=1.0,
                     end_factor=0.1,
                     total_iters=1)

# Create a list to save fold losses
fold_losses = []

# Instantiate the dataloader for the train_dataset
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=valla_collate_fn,
)

In [5]:
# TRAIN
##############################################################
# Set model to training mode
model.train()

# Gradient Accumulation Implementation:
# Adapted from
# https://stackoverflow.com/a/78619879 [37]
# Initialize running total for gradients
optimizer.zero_grad()

print(len(train_dataloader))
# Iterate over the train_dataloader one batch at a time
for batch_idx, (batch_anchor,
                batch_other,
                labels) in enumerate(train_dataloader):
    print(f'starting batch {batch_idx}')
    # batch_content is a tuple containing three elements
    # coming from the PyTorch `DataLoader` object:
    # - batch_anchor at index 0 - the batch tensor of
    #   chunks to be fed through the 'left' side of the
    #   Siamese network.
    # - batch_other at index 1 - the batch tensor of
    #   chunks to be fed through the 'right' side of the
    #   Siamese network.
    # - labels at index 2 - the ground truths for the
    #   pairs:
    #     - 1 = same-author
    #     - 0 = different-author

    # Move batches to device (MPS/CPU)
    batch_anchor = {k: v.to('cuda')
                    for k, v in batch_anchor.items()}
    batch_other = {k: v.to('cuda')
                   for k, v in batch_other.items()}
    labels = labels.to('cuda')

    # Forward pass
    anchor_embedding, other_embedding = model(
        batch_anchor['input_ids'],
        batch_anchor['attention_mask'],
        batch_other['input_ids'],
        batch_other['attention_mask']
    )
    # Calculate the contrastive loss of this batch and
    # normalize by accumulation steps
    loss = loss_function(anchor_embedding,
                         other_embedding,
                         labels) / 1
    # Save the batch loss
    # unnormalized loss for reporting
    fold_losses.append(loss * 1)

    loss.backward()

    torch.cuda.empty_cache()

    # Adam optimizer
    optimizer.step()
    # Clear out any existing gradients
    optimizer.zero_grad()

10492
tensor([1, 1, 1, 0, 1, 0, 0, 1])
starting batch 0
tensor([0, 1, 1, 0, 1, 0, 1, 1])
starting batch 1
tensor([1, 1, 1, 0, 1, 1, 0, 0])
starting batch 2
tensor([0, 1, 1, 1, 1, 1, 0, 0])
starting batch 3
tensor([1, 0, 0, 0, 0, 1, 1, 1])
starting batch 4
tensor([1, 1, 0, 1, 0, 1, 0, 1])
starting batch 5
tensor([1, 0, 0, 1, 0, 1, 1, 1])
starting batch 6
tensor([0, 0, 0, 1, 0, 0, 0, 1])
starting batch 7
tensor([1, 0, 1, 1, 0, 1, 0, 0])
starting batch 8
tensor([0, 0, 1, 1, 1, 1, 0, 1])
starting batch 9
tensor([1, 0, 0, 1, 0, 0, 1, 1])
starting batch 10
tensor([0, 0, 1, 1, 0, 0, 1, 0])
starting batch 11
tensor([1, 0, 1, 1, 1, 0, 0, 0])
starting batch 12
tensor([0, 1, 0, 0, 1, 1, 1, 0])
starting batch 13


KeyboardInterrupt: 