In [1]:
from transformers import AutoTokenizer
from constants import MODEL
from siamese_sbert import SiameseSBERT
from lila_dataset import LILADataset
import torch
import torch.nn.functional as F
import os

In [2]:
# Device selection
# Prefer CUDA when it is available, then the MPS backend; CPU is the last resort
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"

device = torch.device(_device_name)

In [5]:
# Locate the normalized dataset: the undistorted texts, the metadata file and
# one sub-directory per distorted view.
dataset_path = '../data/normalized'
undistorted_path = os.path.join(dataset_path, 'undistorted')
assert os.path.exists(undistorted_path), \
    f"Missing undistorted view directory: {undistorted_path}"
metadata_path = os.path.join(dataset_path, 'metadata.csv')
assert os.path.exists(metadata_path), \
    f"Missing metadata file: {metadata_path}"

# Create list to store all views to process (undistorted always comes first)
views = [undistorted_path]

# os.listdir() returns entries in arbitrary order; sorting keeps the
# processing order (and the order of the printed report) reproducible
# across machines and runs.
for view_dir in sorted(os.listdir(dataset_path)):
    view_path = os.path.join(dataset_path, view_dir)
    # Skip the undistorted view (already added), plain files and hidden
    # entries such as '.DS_Store'
    if view_dir == 'undistorted' or \
       not os.path.isdir(view_path) or \
       view_dir.startswith('.'):
        continue
    # Every remaining directory must be named like a distorted view
    assert view_dir.startswith(('DV-SA-k-', 'DV-MA-k-')), \
        f"Unexpected view directory name: {view_dir}"
    views.append(view_path)

# Root directory that holds one checkpoint sub-directory per view.
# Defaults to the original author's location; set SBERT_CHECKPOINT_ROOT to
# run this cell on another machine without editing it.
checkpoint_root = os.environ.get(
    "SBERT_CHECKPOINT_ROOT",
    "/Users/zacbolton/dev/BSc/FP/"
    "historical_av_with_SBERT/saved_experiments/full_run",
)
checkpoint_name = "full_run_fold_4_epoch_2.pt"

# Maps each view name to the list of scaled similarity scores produced for
# that view's pairs
view_preds = {}
for view_path in views:
    # Get simple view string
    # Adapted from:
    # https://stackoverflow.com/a/3925147
    view = os.path.basename(os.path.normpath(view_path))
    # Add an entry to view_preds to store predictions for this
    # particular view
    view_preds[view] = []

    # Fail fast with a clear message before the expensive dataset/model
    # construction if this view has no checkpoint
    checkpoint_path = os.path.join(checkpoint_root, view, checkpoint_name)
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"No checkpoint found for view '{view}' at {checkpoint_path}")

    # Reset any existing splits
    LILADataset.reset_splits()

    # Instantiate the full LILA dataset
    inference_dataset = LILADataset(view_path,
                                    metadata_path,
                                    cnk_size=512,
                                    num_pairs=0,
                                    seed=0,
                                    letters=True)

    # Load the model weights trained for this view
    model = SiameseSBERT(MODEL, device).to(device)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True once confirmed that
    # the checkpoint contains nothing but tensors/plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Gradients are never needed for inference, so autograd is disabled for
    # the whole pass over the pairs
    with torch.no_grad():
        # NOTE(review): this reads the dataset's private `_pairs` attribute;
        # each element is indexed as pair[0] / pair[1], both holding
        # 'input_ids' and 'attention_mask' tensors.
        for pair in inference_dataset._pairs:
            # Move input tensors to device
            input_ids_1 = pair[0]['input_ids'].to(device)
            attention_mask_1 = pair[0]['attention_mask'].to(device)
            input_ids_2 = pair[1]['input_ids'].to(device)
            attention_mask_2 = pair[1]['attention_mask'].to(device)

            embeddings1, embeddings2 = model(
                input_ids_1,
                attention_mask_1,
                input_ids_2,
                attention_mask_2
            )

            # Calculate similarity
            similarity = F.cosine_similarity(embeddings1, embeddings2)
            # Scale from [-1,1] to [0,1]
            scaled_similarity = (similarity + 1) / 2

            # .item() requires a single-element tensor, i.e. one similarity
            # value per pair
            view_preds[view].append(scaled_similarity.item())

Token indices sequence length is longer than the specified maximum sequence length for this model (21972 > 512). Running this sequence through the model will result in indexing errors
  checkpoint = torch.load(checkpoint_path, map_location=device)
Token indices sequence length is longer than the specified maximum sequence length for this model (58898 > 512). Running this sequence through the model will result in indexing errors
  checkpoint = torch.load(checkpoint_path, map_location=device)
Token indices sequence length is longer than the specified maximum sequence length for this model (24306 > 512). Running this sequence through the model will result in indexing errors
  checkpoint = torch.load(checkpoint_path, map_location=device)
Token indices sequence length is longer than the specified maximum sequence length for this model (37976 > 512). Running this sequence through the model will result in indexing errors
  checkpoint = torch.load(checkpoint_path, map_location=device)


In [12]:
import numpy as np

# Decision thresholds on the scaled similarity score:
#   score < p1         -> counted as different-author
#   score > p2         -> counted as same-author
#   p1 <= score <= p2  -> counted as undecided
p1=0.45
p2=0.54

# Print summary statistics of the predictions for every view
for key in view_preds:
    preds = np.array(view_preds[key])

    # np.min / np.max raise ValueError on an empty array, so report views
    # without predictions explicitly instead of crashing the whole report
    if preds.size == 0:
        print(f"""
VIEW {key}

Number of predictions: 0 (no statistics available)

---\n
""")
        continue

    print(f"""
VIEW {key}

Number of predictions: {len(preds)}
Mean:                  {np.mean(preds)}
STD:                   {np.std(preds)}
Median:                {np.median(preds)}

Same-author:           {np.mean(preds > p2)}
Different-author:      {np.mean(preds < p1)}
Undecided:             {np.mean((preds >= p1) & (preds <= p2))}

Quartiles:             {np.percentile(preds, [25, 50, 75])}
Min:                   {np.min(preds)}
Max:                   {np.max(preds)}

---\n
""")


VIEW undistorted

Mumber of predictions: 820
Mean:                  0.7086946412920951
STD:                   0.29838421972626655
Median:                0.8612443208694458

Same-author:           0.7439024390243902
Different-author:      0.2475609756097561
Undecided:             0.00853658536585366

Quartiles:             [0.45598616 0.86124432 0.92561182]
Min:                   0.05405956506729126
Max:                   0.9899935126304626

---



VIEW DV-MA-k-300

Mumber of predictions: 2160
Mean:                  0.5635489236287496
STD:                   0.271862259741844
Median:                0.6489334106445312

Same-author:           0.5777777777777777
Different-author:      0.375
Undecided:             0.04722222222222222

Quartiles:             [0.29706125 0.64893341 0.81157312]
Min:                   0.044049471616744995
Max:                   0.9589341282844543

---



VIEW DV-MA-k-20000

Mumber of predictions: 880
Mean:                  0.6612847168675878
STD:               