In [7]:
from av_dataset import VALLAAVTrainDataset, VALLAAVValDataset
import sys
import os
from constants import ROOT_DIR, MODEL
import torch
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Add src directory to sys.path
# Adapted from Taras Alenin's answer on StackOverflow at:
# https://stackoverflow.com/a/55623567
src_path = os.path.join(ROOT_DIR, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Import custom modules
from siamese_sbert import SiameseSBERT
from modified_contrastive_loss import ModifiedContrastiveLoss

In [2]:
path_to_train_csv = './data/pan20-authorship-verification-training-small/processed/small/pan20_train.csv'
# path_to_val_csv = './data/pan20-authorship-verification-training-small/processed/small/pan20_AV_val.csv'
train_dataset = VALLAAVTrainDataset(path_to_train_csv)
# val_dataset = VALLAAVValDataset(path_to_val_csv)
print('Train len:', len(train_dataset))
# print('Val len:', len(val_dataset))

Train len: 83934


In [11]:
def valla_collate_fn(batch):
    """
    Collate function to properly batch the paired inputs.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    labels = torch.tensor([item.label for item in batch])
    anchor_texts = [tokenizer(item.texts[0],
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              max_length=512) for item in batch]
    other_texts = [tokenizer(item.texts[1],
                             return_tensors="pt",
                             padding=True,
                             truncation=True,
                             max_length=512) for item in batch]

    # Combine input_ids and attention_masks
    batched_a = {
        'input_ids': torch.cat([x['input_ids'] for x in anchor_texts]),
        'attention_mask': torch.cat([x['attention_mask']
                                     for x in anchor_texts])
    }
    batched_o = {
        'input_ids': torch.cat([x['input_ids'] for x in other_texts]),
        'attention_mask': torch.cat([x['attention_mask']
                                     for x in other_texts])
    }

    return batched_a, batched_o, labels



In [15]:
model = SiameseSBERT(MODEL, 'mps').to('mps')

# Instantiate custom contrastive loss fuction
# 'modified contrastive loss'
loss_function = ModifiedContrastiveLoss(margin_s=0.75,
                                        margin_d=0.25)

# Instantiate Adam optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr=2e-05,
                             eps=1e-6)


scheduler = LinearLR(optimizer, start_factor=1.0,
                     end_factor=0.1,
                     total_iters=1)

# Create a list to save fold losses
fold_losses = []

# Instantiate the dataloader for the train_dataset
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=valla_collate_fn,
)

In [16]:
# TRAIN
##############################################################
# Set model to training mode
model.train()

# Gradient Accumulation Implementation:
# Adapted from
# https://stackoverflow.com/a/78619879 [37]
# Initialize running total for gradients
optimizer.zero_grad()

# Iterate over the train_dataloader one batch at a time
for batch_idx, (batch_anchor,
                batch_other,
                labels) in enumerate(train_dataloader):
    # batch_content is a tuple containing three elements
    # coming from the PyTorch `DataLoader` object:
    # - batch_anchor at index 0 - the batch tensor of
    #   chunks to be fed through the 'left' side of the
    #   Siamese network.
    # - batch_other at index 1 - the batch tensor of
    #   chunks to be fed through the 'right' side of the
    #   Siamese network.
    # - labels at index 2 - the ground truths for the
    #   pairs:
    #     - 1 = same-author
    #     - 0 = different-author

    # Move batches to device (MPS/CPU)
    batch_anchor = {k: v.to('mps')
                    for k, v in batch_anchor.items()}
    batch_other = {k: v.to('mps')
                   for k, v in batch_other.items()}
    labels = labels.to('mps')

    # Forward pass
    anchor_embedding, other_embedding = model(
        batch_anchor['input_ids'],
        batch_anchor['attention_mask'],
        batch_other['input_ids'],
        batch_other['attention_mask']
    )
    # Calculate the contrastive loss of this batch and
    # normalize by accumulation steps
    loss = loss_function(anchor_embedding,
                         other_embedding,
                         labels) / accumulation_steps
    # Save the batch loss
    # unnormalized loss for reporting
    fold_losses.append(loss * accumulation_steps)

    loss.backward()

    # Clear MPS cache after each epoch
    torch.mps.empty_cache()
    if hasattr(torch.mps, 'synchronize'):
        torch.mps.synchronize()

    # Adam optimizer
    optimizer.step()
    # Clear out any existing gradients
    optimizer.zero_grad()

RuntimeError: MPS backend out of memory (MPS allocated: 17.97 GB, other allocations: 96.72 MB, max allowed: 18.13 GB). Tried to allocate 96.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).