In [1]:
import sys
import os
from transformers import AutoTokenizer
import torch

# Get project root 
project_root = os.getcwd()
src_path = os.path.join(project_root, 'src')

# Add src directory to sys.path
# Adapted from Taras Alenin's answer on StackOverflow at:
# https://stackoverflow.com/a/55623567
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Import custom modules
from prototype import SiameseBERT, ContrastiveLoss  # noqa: E402


# Prototype

This demonstrates the Siamese SBERT model architecture for authorship verification (AV). For more details on the architecture and implementation, see the module docstring in `/src/prototype.py`.

Below is a modified version of my model architecture diagram showing the prototype architecture at training (left) and inference (right). It notes that the prototype is intended for a single forward and backward pass on a micro-batch of two sample pairs. It strikes out references to chunking, as that is not handled for this prototype.

_Figure 1_: **_Prototype_ architecture**
```mermaid
flowchart TD
    subgraph "SBERT Training"
        pab1["Pair (A, B),\n<s style='text-decoration-thickness: 4px; text-decoration-color: red'>chunk len = 256</s>"] --> chp1["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunking Pair</s>"]
        chp1 --> np1["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>N</s> 2 pairs"]
        np1 --> pnon1["Pair n out of <s style='text-decoration-thickness: 4px; text-decoration-color: red'>N</s> 2"]
        pnon1 --> can1["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunk</s> A.n"]
        pnon1 --> can2["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunk</s> B.n"]
        can1 --> bert1[BERT]
        can2 --> bert2[BERT]
        bert1 --> pool1[Mean Pooling]
        bert2 --> pool2[Mean Pooling]
        pool1 --> u1[U]
        pool2 --> v1[V]
        u1 --> contrastive[Contrastive Loss]
        v1 --> contrastive
    end

    subgraph "SBERT at Inference"
        pab2["Pair (A, B),\n<s style='text-decoration-thickness: 4px; text-decoration-color: red'>chunk len = 256</s>"] --> chp2["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunking Pair</s>"]
        chp2 --> np2["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>N</s> 2 pairs"]
        np2 --> pnon2["Pair n out of <s style='text-decoration-thickness: 4px; text-decoration-color: red'>N</s> 2"]
        pnon2 --> can3["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunk</s> A.n"]
        pnon2 --> can4["<s style='text-decoration-thickness: 4px; text-decoration-color: red'>Chunk</s> B.n"]
        can3 --> bert3[BERT]
        can4 --> bert4[BERT]
        bert3 --> pool3[Mean Pooling]
        bert4 --> pool4[Mean Pooling]
        pool3 --> u2[U]
        pool4 --> v2[V]
        u2 --> cos["Cosine-Similarity(U,V)"]
        v2 --> cos
        cos --> sim["Similarity Score n"]
    end
    
    note["For one micro-batch of 2 sample pairs"] --- pab1
    note --- pab2
    
    style contrastive fill:#9999ff
    style contrastive color:#000000
    style cos fill:#ff9999
    style cos color:#000000
    style sim fill:#ff9999
    style sim color:#000000
    style pab1 fill:#999999
    style pab1 color:#000000
    style pab2 fill:#999999
    style pab2 color:#000000
    linkStyle 27,28 stroke:#999999,stroke-width:1px,stroke-dasharray: 5 5
```

In [2]:
#############################################################################
# CONSTANTS
#############################################################################

# all-MiniLM-L12-v2 pretrained model from Hugging Face:
# https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"

# Learning rate of 2e-05 as per Ibrahim Et. Al (2023) [13:10]
# This is intentionally very small as we are fine-tuning a pre-trained model
LEARNING_RATE = 2e-05


In [3]:
#############################################################################
# MODEL INSTANTIATION
#############################################################################

# Instantiate custom Siamese SBERT model
model = SiameseBERT(MODEL_NAME)

# Instantiate custom contrastive loss function
# TODO: Consider implementing 'modified contrastive loss' from
# https://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf [18]
# and
# Tyo Et. Al (2021) [15]
loss_function = ContrastiveLoss(margin=1.0)

# Instantiate Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [4]:
#############################################################################
# DATA PREPROCESSING
#############################################################################

# Toy Dataset
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
toy_data = [
    ("This is a test sentence A.",
     "This is a test sentence B.",
     1),  # Similar pair
    ("This is a test sentence C.",
     "Sentence D is completely different.",
     0)  # Dissimilar pair
]


def tokenize_pair(text_a, text_b, tokenizer, max_length=128):
    """
    Tokenize two input sentences with a given tokenizer and max_length
    argument.

    :param text_a, text_b: The raw input text
    :type text_a, text_b: string
    :param tokenizer: The tokenizer
    :type tokenizer: transformers.models
    :param max_length: Max length in tokens that the output embedding should
        be. Defaults to 128
    :type max_length: int
    """

    # Truncate inputs that tokenize to tensors longer than `max_length`
    # Add special characters from BERT encoders, like [CLS] and [SEP]
    # TODO: perhaps revisit the above as these might add unhelpful noise
    # for our task
    tokens_a = tokenizer(text_a,
                         return_tensors="pt",
                         padding="max_length",
                         truncation=True,
                         max_length=max_length,
                         add_special_tokens=True)

    tokens_b = tokenizer(text_b,
                         return_tensors="pt",
                         padding="max_length",
                         truncation=True,
                         max_length=max_length,
                         add_special_tokens=True)

    return tokens_a, tokens_b


# Tokenize inputs
tokenized_pairs = []

for text_a, text_b, _ in toy_data:
    # Tokenize each pair of texts
    tokens_a, tokens_b = tokenize_pair(text_a, text_b, tokenizer)

    # Store the tokenized outputs
    tokenized_pairs.append((tokens_a, tokens_b))

# Create batched tensors

# Collate a tensor of row vectors containing indices into our pre-trained
# model's vocabulary, representing the sentences in position 0 (known works)
# ordered based on the ordering of the tokenized sentence.
known_author_input_ids = torch.cat([pair[0]['input_ids']
                                    for pair in tokenized_pairs])
# Do the same for tokenized sentences in position 1 (works to verify).
verification_text_input_ids = torch.cat([pair[1]['input_ids']
                                         for pair in tokenized_pairs])
# Collate the attention masks for sentences in position 0 similarly.
known_author_attention_mask = torch.cat([pair[0]['attention_mask']
                                         for pair in tokenized_pairs])
# Collate the attention masks for sentences in position 1 similarly.
verification_text_attention_mask = torch.cat([pair[1]['attention_mask']
                                              for pair in tokenized_pairs])

# Collate labels tensor, preserving ordering relative to input ids and
# attention masks.
labels = torch.tensor([label for _, _, label in toy_data])


In [5]:
#############################################################################
# FORWARD PASS
#############################################################################

# Forward pass through the model
u, v = model(known_author_input_ids, known_author_attention_mask,
             verification_text_input_ids, verification_text_attention_mask)


In [6]:
#############################################################################
# LOSS CALCULATION
#############################################################################

# Calculate loss
loss = loss_function(u, v, labels)
print(f"Loss: {loss.item()}")


Loss: 0.3593752980232239


In [7]:
#############################################################################
# BACKPROPAGATION
#############################################################################

# Clear out any existing gradients
optimizer.zero_grad()

# Backpropogation pass
loss.backward()

# Update weights using calculated gradients from Adam optimizer
optimizer.step()


In [8]:
#############################################################################
# INFERENCE
#############################################################################

# Set model to evaluation mode
model.eval()

# Run inference on the same toy training data
with torch.no_grad():  # Disable gradient calculation for inference
    # Forward pass through the model
    u_inference, v_inference = model(known_author_input_ids, known_author_attention_mask,
                                     verification_text_input_ids, verification_text_attention_mask)

    # Calculate cosine similarity between embeddings
    similarities = torch.nn.functional.cosine_similarity(u_inference,
                                                         v_inference)

    # Scale similarities from [-1,1] to [0,1] range
    scaled_similarities = (similarities + 1) / 2

    # Print results
    for i, (text_a, text_b, true_label) in enumerate(toy_data):
        print(f"\nPair {i+1}:")
        print(f"Text A: {text_a}")
        print(f"Text B: {text_b}")
        print(f"True Label: {true_label}")
        print(f"Similarity Score: {scaled_similarities[i]:.3f}")



Pair 1:
Text A: This is a test sentence A.
Text B: This is a test sentence B.
True Label: 1
Similarity Score: 0.998

Pair 2:
Text A: This is a test sentence C.
Text B: Sentence D is completely different.
True Label: 0
Similarity Score: 0.797
