In [1]:
import sys
import os
from torch.utils.data import DataLoader
import torch
from constants import ROOT_DIR, MODEL, LEARNING_RATE

# Let the tokenizer library use the machine's parallelism. The HuggingFace
# `tokenizers` library reads this variable; set it *before* importing the
# project modules so it is already in effect if any of them builds a
# tokenizer at import time.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Add src directory to sys.path so the project modules below can be imported.
# Adapted from Taras Alenin's answer on StackOverflow at:
# https://stackoverflow.com/a/55623567
src_path = os.path.join(ROOT_DIR, 'src')
if src_path not in sys.path:
    # Prepend so project modules take precedence over same-named packages.
    sys.path.insert(0, src_path)

# Import custom modules (must come after the sys.path tweak above)
from prototype import SiameseBERT, ContrastiveLoss  # noqa: E402
from lila_dataset import LILADataset  # noqa: E402

In [2]:
#############################################################################
# MODEL INSTANTIATION
#############################################################################

# Seed torch's global RNG before building the model so that any randomly
# initialised weights, and anything else later drawing from torch's global
# generator, are reproducible under Restart & Run All.
# NOTE(review): if LILADataset samples pairs with `random`/`numpy`, seed those
# as well — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Instantiate custom Siamese SBERT model
model = SiameseBERT(MODEL)

# Margin passed to the contrastive loss; named so it can be tuned in one place.
CONTRASTIVE_MARGIN = 10.0

# Instantiate custom contrastive loss function
# TODO: Consider implementing 'modified contrastive loss' from
# Hadsell, Chopra & LeCun (2006)
# https://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf [18]
# and
# Tyo et al. (2021) [15]
loss_function = ContrastiveLoss(margin=CONTRASTIVE_MARGIN)

# Instantiate Adam optimizer over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [3]:
#############################################################################
# DATASET/LOADER INSTANTIATION
#############################################################################

# Training corpus directory, relative to the notebook's working directory.
TRAIN_DATA_DIR = '../data/normalized/undistorted'
# Second positional argument to LILADataset — presumably the chunk/sequence
# length in tokens (it matches the model's 512-token limit reported in the
# warning below); TODO confirm against LILADataset's signature.
TRAIN_CHUNK_LEN = 512

train_ds = LILADataset(TRAIN_DATA_DIR, TRAIN_CHUNK_LEN)

Token indices sequence length is longer than the specified maximum sequence length for this model (21972 > 512). Running this sequence through the model will result in indexing errors


In [4]:
def _flatten(nested):
    """Return one list holding every item of every inner iterable of `nested`, in order."""
    flat = []
    for inner in nested:
        flat.extend(inner)
    return flat


# Collapse the per-document chunk lists into flat lists of chunks.
A_docs_chunked = _flatten(train_ds._A_docs_chunked)
notA_docs_chunked = _flatten(train_ds._notA_docs_chunked)

# Report the A / notA chunk counts and the number of pairs the dataset built.
print(f"A: {len(A_docs_chunked)}")
print(f"notA: {len(notA_docs_chunked)}")
print(f"pairs: {len(train_ds._pairs)}")

A: 602
notA: 4042
pairs: 2614185
