In [4]:
from transformers import AutoTokenizer
from constants import MODEL
from siamese_sbert import SiameseSBERT
import torch
import torch.nn.functional as F

In [5]:
# Parallelization/Concurency
# Use CUDA if available, else use MPS if available. Fallback is CPU
device = torch.device("cuda" if torch.cuda.is_available()
                      else (
                        "mps"
                        if torch.backends.mps.is_available()
                        else "cpu"
                      ))

In [6]:
def run_inference(checkpoint_path, text_pair, tokenizer, device='cuda'):
    """
    Load a saved model and run inference on a text pair.

    :param checkpoint_path: Path to the saved checkpoint
    :type checkpoint_path: str

    :param text_pair: Tuple of (text1, text2) to compare
    :type text_pair: tuple

    :param tokenizer: The BERT tokenizer
    :param tokenizer: transformers.Autotokenizer

    :param device: Device to run inference on
    :type device: str

    :returns: Similarity score between 0-1
    :rtype: float
    """
    # Load model
    model = SiameseSBERT(MODEL, device).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Tokenize inputs
    text1, text2 = text_pair
    tokens1 = tokenizer(text1,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512).to(device)
    tokens2 = tokenizer(text2,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512).to(device)

    # Run inference
    with torch.no_grad():
        embeddings1, embeddings2 = model(
            tokens1['input_ids'],
            tokens1['attention_mask'],
            tokens2['input_ids'],
            tokens2['attention_mask']
        )

        # Calculate similarity
        similarity = F.cosine_similarity(embeddings1, embeddings2)
        # Scale from [-1,1] to [0,1]
        scaled_similarity = (similarity + 1) / 2

    return scaled_similarity.item()

In [9]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Load and run inference
checkpoint_path = "/Users/zacbolton/dev/BSc/FP/historical_av_with_SBERT/model_out/model_saving_exp_2/undistorted/model_saving_exp_2_fold_1_epoch_0.pt"
text1 = "First piece of text to compare"
text2 = "Second piece of text to compare"

similarity = run_inference(checkpoint_path,
                           (text1, text2),
                           tokenizer,
                           device=device)

print(f"Similarity score: {similarity:.4f}")

  checkpoint = torch.load(checkpoint_path, map_location=device)


Similarity score: 0.9587
Predicted same author: True
