In [1]:
import sys
import os
from transformers import AutoTokenizer
from definitions import ROOT_DIR, MODEL
from torch.utils.data import DataLoader

# Get the relevant directories
src_path = os.path.join(ROOT_DIR, 'src')

# Add src directory to sys.path
# Adapted from Taras Alenin's answer on StackOverflow at:
# https://stackoverflow.com/a/55623567
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Import my custom modules
import dataset  # noqa: E402

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [4]:
#############################################################################
# DATA PREPROCESSING
#############################################################################

# Toy Dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
toy_data = [
    ("This is a test sentence A.",
     "This is a test sentence B.",
     1),  # Similar pair
    ("This is a test sentence C.",
     "Sentence D is completely different.",
     0)  # Dissimilar pair
]


def tokenize_pair(text_a, text_b, tokenizer, max_length=128):
    """
    Tokenize two input sentences with a given tokenizer and max_length
    argument.

    :param text_a, text_b: The raw input text
    :type text_a, text_b: string
    :param tokenizer: The tokenizer
    :type tokenizer: transformers.models
    :param max_length: Max length in tokens that the output embedding should
        be. Defaults to 128
    :type max_length: int
    """

    # Truncate inputs that tokenize to tensors longer than `max_length`
    # Add special characters from BERT encoders, like [CLS] and [SEP]
    # TODO: perhaps revisit the above as these might add unhelpful noise
    # for our task
    tokens_a = tokenizer(text_a,
                         return_tensors="pt",
                         padding="max_length",
                         truncation=True,
                         max_length=max_length,
                         add_special_tokens=True)

    tokens_b = tokenizer(text_b,
                         return_tensors="pt",
                         padding="max_length",
                         truncation=True,
                         max_length=max_length,
                         add_special_tokens=True)

    return tokens_a, tokens_b


# Tokenize inputs
tokenized_pairs = []

for text_a, text_b, _ in toy_data:
    # Tokenize each pair of texts
    tokens_a, tokens_b = tokenize_pair(text_a, text_b, tokenizer)

    # Store the tokenized outputs
    tokenized_pairs.append((tokens_a, tokens_b))

# Create batched tensors

# Collate a tensor of row vectors containing indices into our pre-trained
# model's vocabulary, representing the sentences in position 0 (known works)
# ordered based on the ordering of the tokenized sentence.
known_author_input_ids = torch.cat([pair[0]['input_ids']
                                    for pair in tokenized_pairs])
# Do the same for tokenized sentences in position 1 (works to verify).
verification_text_input_ids = torch.cat([pair[1]['input_ids']
                                         for pair in tokenized_pairs])
# Collate the attention masks for sentences in position 0 similarly.
known_author_attention_mask = torch.cat([pair[0]['attention_mask']
                                         for pair in tokenized_pairs])
# Collate the attention masks for sentences in position 1 similarly.
verification_text_attention_mask = torch.cat([pair[1]['attention_mask']
                                              for pair in tokenized_pairs])

# Collate labels tensor, preserving ordering relative to input ids and
# attention masks.
labels = torch.tensor([label for _, _, label in toy_data])


In [2]:
ds = dataset.CustomDataset('data/test')
len(ds)

12

In [3]:
dataloader = DataLoader(
    ds,
    batch_size=6,
    shuffle=True,
    num_workers=4
)

In [5]:
for i in dataloader:
    print(i)

[{'input_ids': tensor([[[  101, 29379, 29379,   102]],

        [[  101,  1040,  1040,   102]],

        [[  101, 29379, 29379,   102]],

        [[  101,  3347,  3347,   102]],

        [[  101,  3347,  3347,   102]],

        [[  101, 29379, 29379,   102]]]), 'attention_mask': tensor([[[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]]])}, {'input_ids': tensor([[[ 101, 1040, 1040,  102]],

        [[ 101, 1038, 1038,  102]],

        [[ 101, 8670, 2480,  102]],

        [[ 101, 1040, 1040,  102]],

        [[ 101, 1037, 1037,  102]],

        [[ 101, 3347, 3347,  102]]]), 'attention_mask': tensor([[[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]],

        [[1, 1, 1, 1]]])}, tensor([1, 0, 0, 1, 1, 1])]
[{'input_ids': tensor([[[  101,  1037,  1037,   102]],

        [[  101,  1037,  1037,   102]],

        [[  101,  1040,  1040,   102