In [8]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import torch
import torch.nn as nn

class SiameseClassifier(nn.Module):
    """Binary classifier over a pair of texts using one shared sentence encoder.

    Both texts are embedded by the same SentenceTransformer (the "Siamese"
    part). The head sees the concatenation (u, v, |u - v|) of the two sentence
    embeddings and produces two class scores.

    Args:
        base_model: Model name or path passed to ``SentenceTransformer``.
    """

    def __init__(self, base_model='all-MiniLM-L12-v2'):
        super().__init__()
        self.bert = SentenceTransformer(base_model)
        # The input size will be 3x the embedding dimension because we concatenate
        # U, V, and |U-V|
        embedding_dim = self.bert.get_sentence_embedding_dimension()
        # SentenceTransformer may place itself on a GPU at construction
        # (version-dependent); build the head on the same device so forward()
        # never mixes CPU and GPU tensors.
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim * 3, 2),  # Binary classification
            nn.Softmax(dim=1)
        ).to(self.bert.device)

    def _embed(self, texts):
        """Return differentiable sentence embeddings, one row per input text.

        ``SentenceTransformer.encode`` returns numpy arrays by default and runs
        without gradient tracking, so it can neither be combined with torch ops
        nor be fine-tuned. Instead, tokenize and run the encoder's forward pass.
        """
        if isinstance(texts, str):
            texts = [texts]  # a single string would otherwise give a 1-D embedding
        # The head and the encoder live on the same device (see __init__ / .to()).
        device = next(self.classifier.parameters()).device
        features = self.bert.tokenize(list(texts))
        features = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        return self.bert(features)['sentence_embedding']

    def forward(self, text_a, text_b, return_logits=False):
        """Classify each (text_a[i], text_b[i]) pair.

        Args:
            text_a: A string or a sequence of strings.
            text_b: A string or a sequence of strings, same length as ``text_a``.
            return_logits: If True, return the raw Linear outputs instead of the
                softmax probabilities. Use this with ``nn.CrossEntropyLoss``,
                which applies log-softmax itself.

        Returns:
            Tensor of shape (batch, 2): class probabilities (default) or logits.
            In this notebook's training data, class 1 means "same author".

        Raises:
            ValueError: If ``text_a`` and ``text_b`` have different batch sizes.
        """
        # Generate embeddings
        u = self._embed(text_a)
        v = self._embed(text_b)
        if u.shape[0] != v.shape[0]:
            raise ValueError(
                f"text_a and text_b must have the same number of texts "
                f"({u.shape[0]} != {v.shape[0]})"
            )

        # Create concatenated feature vector
        abs_diff = torch.abs(u - v)
        combined = torch.cat((u, v, abs_diff), dim=1)

        # Pass through classifier; skip the trailing Softmax when logits are wanted.
        if return_logits:
            return self.classifier[:-1](combined)
        return self.classifier(combined)
        
# Seed torch so the randomly initialised classifier head (and any dropout used
# during training) is reproducible across Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
model = SiameseClassifier()

In [5]:
# Pairs of texts to compare; label 1 = same author, 0 = different authors.
# The texts below are placeholders for real passages.
training_examples = [
    # Known same author (label=1)
    InputExample(texts=['LaSalle text 1', 'LaSalle text 2'], label=1.0),
    # Known different authors (label=0)
    InputExample(texts=['LaSalle text', 'Imposter text'], label=0.0)
]


def collate_pairs(batch):
    """Collate a list of two-text InputExamples into (texts_a, texts_b, labels).

    The default collate_fn of torch's DataLoader cannot batch InputExample
    objects, so each batch is unpacked here into the two lists of strings that
    SiameseClassifier.forward takes, plus the labels as int64 class indices
    (the dtype nn.CrossEntropyLoss expects for class targets).
    """
    texts_a = [example.texts[0] for example in batch]
    texts_b = [example.texts[1] for example in batch]
    labels = torch.tensor([int(example.label) for example in batch], dtype=torch.long)
    return texts_a, texts_b, labels


BATCH_SIZE = 16
# A plain torch DataLoader; collate_pairs turns each batch of pairs into tensors/lists.
train_dataloader = DataLoader(
    training_examples,
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=collate_pairs,
)