# Embed and rerank

## Outline of classification algorithm

This notebook implements the "embed-and-rerank" algorithm, a three-step process.

1. Take the corpus of Wikipedia plot summaries and break each one into chunks of at most 256 tokens, the maximum sequence length we set for the embedding model. This yields a large collection of plot-summary fragments, each labelled with the movie it came from.
2. Embed the fragments in $\mathbb{R}^{768}$ using a context-sensitive sentence embedder (the bi-encoder); we use a BERT-derived model trained on the MS MARCO dataset. Embed the search query with the same model, then pick the 100 fragments closest to it under cosine similarity. These 100 fragments are the initial candidates for the movie the query refers to.
3. Score the query against each of the 100 candidates with a cross-encoder, a different neural network that reads the query and a fragment together and outputs a relevance score. The 10 highest-scoring distinct movies are returned as the search results.
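
Step 1 can be sketched as follows. This is a minimal illustration, not the preprocessing that produced the fragments used in this notebook: the function name `chunk_tokens` is hypothetical, and whitespace-separated words stand in for the model's subword tokens, so a real tokenizer would cut the text at different positions.

```python
from typing import List


def chunk_tokens(text: str, max_tokens: int = 256) -> List[str]:
    """Split text into consecutive fragments of at most max_tokens tokens.

    Assumption: whitespace-separated words approximate tokens. A real
    pipeline would count tokens with the embedding model's own tokenizer.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    words = text.split()
    return [
        " ".join(words[start:start + max_tokens])
        for start in range(0, len(words), max_tokens)
    ]


# Toy usage: a 600-word "plot" is split into three fragments.
plot = " ".join(f"w{i}" for i in range(600))
fragments = chunk_tokens(plot)
fragment_lengths = [len(fragment.split()) for fragment in fragments]
```

Only the last fragment can be shorter than the limit, and concatenating the fragments recovers the original word sequence.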

## Outline of testing method

We scraped IMDB for single-sentence summaries of movies and took them as representative of what people might type when searching for a movie. We then ran the classifier on this IMDB query dataset and computed its accuracy, counting a query as correct when the title of the movie it describes appears among the top 10 results. We also looked at examples of misclassified queries to see what kinds of queries the method gets wrong.
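
The accuracy measure described here is a top-k hit rate. A minimal sketch on made-up data follows; the function name `top_k_accuracy` and the toy examples are illustrative assumptions, not the notebook's evaluation code.

```python
from typing import Sequence, Tuple


def top_k_accuracy(examples: Sequence[Tuple[str, Sequence[str]]], k: int = 10) -> float:
    """Fraction of examples whose gold title appears in the first k results.

    Each example is (gold_title, ranked_titles), best result first.
    """
    if not examples:
        raise ValueError("need at least one example")
    hits = sum(1 for gold, ranked in examples if gold in list(ranked)[:k])
    return hits / len(examples)


# Toy data, made up for illustration only.
examples = [
    ("The Rose Tattoo", ["The Rose Tattoo", "F/X"]),
    ("Drive", ["The Transporter", "Drive"]),
    ("Heat", ["Ronin", "Collateral"]),
]
accuracy_at_10 = top_k_accuracy(examples, k=10)  # first two queries are hits
accuracy_at_1 = top_k_accuracy(examples, k=1)    # only the first query is a hit
```

A larger `k` can only raise this number, so a reported accuracy is meaningful only together with the cutoff that was used.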

## Installing and importing packages

In [None]:
!pip install sentence-transformers

# Also check the GPU model when running with a GPU kernel
!nvidia-smi

In [None]:
# All the necessary imports

import pandas as pd
from tqdm import tqdm
import ast
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util

## Setting hyperparameters and data preprocessing

We flatten the pre-chunked plot summaries into one row per fragment, and build a separate DataFrame of test queries.

In [None]:
# Model choice and some hyperparameters

bi_encoder_model = "msmarco-distilbert-base-v4"
cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"

pre_cross_encode_k = 100
results_to_show = 10

In [None]:
# Mount Google Drive and load the plot dataset

from google.colab import drive
drive.mount("/content/gdrive")

plots = pd.read_csv("/content/gdrive/MyDrive/imdb_plots.csv", compression="zip", converters={'to_embed': ast.literal_eval})

plots['MovieId'] = plots.index
plots = plots.drop(['Unnamed: 0', 'Unnamed: 0.1'], axis=1)

In [None]:
# Flatten to one row per fragment, remembering which movie each fragment came from
movie_ids = []
to_embed = []
for _, row in plots.iterrows():
    for frag in row['to_embed']:
        movie_ids.append(row['MovieId'])
        to_embed.append(frag)

id_and_summary = pd.DataFrame({'MovieId': movie_ids, 'to_embed': to_embed})

In [None]:
# Build the test queries: up to two IMDB summaries per movie, skipping missing ones
movie_ids = []
queries = []
for _, row in plots.iterrows():
    for summ in (row['imdb_1'], row['imdb_2']):
        if not pd.isna(summ):
            movie_ids.append(row['MovieId'])
            queries.append(summ)

test_queries = pd.DataFrame({'MovieId': movie_ids, 'summary': queries})

## Embedding corpus in $\mathbb{R}^{768}$

In [None]:
bi_encoder = SentenceTransformer(bi_encoder_model)
bi_encoder.max_seq_length = 256     # Truncate passages longer than 256 tokens

# Pass a plain list of strings so the encoder does not depend on the pandas index
corpus_embeddings = bi_encoder.encode(id_and_summary['to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)

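
Conceptually, the retrieval stage ranks the corpus embeddings by cosine similarity to the query embedding and keeps the best few. The sketch below shows that idea in plain Python on toy 3-dimensional vectors instead of 768-dimensional embeddings. The helper names `cosine_similarity` and `top_k_by_cosine` are hypothetical; in this notebook the same job is done on tensors by `util.semantic_search`.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_cosine(
    query: Sequence[float], corpus: Sequence[Sequence[float]], k: int
) -> List[Tuple[int, float]]:
    """Return (corpus_index, similarity) for the k most similar corpus vectors."""
    scored = ((idx, cosine_similarity(query, vec)) for idx, vec in enumerate(corpus))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy corpus: index 0 matches the query exactly, index 3 points the opposite way.
corpus = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [-1.0, 0.0, 0.0],
]
query = [1.0, 0.0, 0.0]
hits = top_k_by_cosine(query, corpus, k=2)
```

Because cosine similarity ignores vector length, only the direction of each embedding matters for this ranking.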

## Functions to perform the actual search as well as compute accuracy

In [None]:
# Function to query and return the top `results_to_show` movies with their scores.
# Reads the globals `pre_cross_encode_k` and `results_to_show` set above.

def semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset):
    # Stage 1: retrieve the nearest fragments by cosine similarity.
    query_embedding = bi_encoder.encode(query_string, convert_to_tensor=True)
    pre_cross_encode_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=pre_cross_encode_k)[0]

    # Stage 2: rescore each (query, fragment) pair with the cross-encoder.
    cross_inp = [[query_string, id_and_summary['to_embed'][hit['corpus_id']]] for hit in pre_cross_encode_hits]
    cross_scores = cross_encoder.predict(cross_inp)
    cross_encoder_res = sorted(enumerate(cross_scores), key=lambda x: x[1], reverse=True)

    res_movie_title_and_year = []
    res_score = []

    # Walk the fragments from best to worst, keeping each movie only once.
    for index, score in cross_encoder_res:
        if len(res_movie_title_and_year) >= results_to_show:
            break

        corpus_id = pre_cross_encode_hits[index]['corpus_id']
        movie_id = id_and_summary['MovieId'][corpus_id]
        title_and_year = (wiki_dataset['Title'][movie_id].strip(), wiki_dataset['Release Year'][movie_id])
        if title_and_year not in res_movie_title_and_year:
            res_movie_title_and_year.append(title_and_year)
            res_score.append(score)
    return list(zip(res_movie_title_and_year, res_score))

# Function to test performance on a query dataset.
# A query counts as correct if its movie's title appears among the returned results
# (only titles are compared, so the release year is ignored).

def measure_accuracy(query_dataset, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset):
    total = 0
    correct = 0

    for _, row in tqdm(query_dataset.iterrows()):
        query_string = row['summary']
        movie_id = row['MovieId']
        hits = semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset)
        movie_title = wiki_dataset['Title'][movie_id]
        # Each hit is ((title, year), score); collect the returned titles.
        hit_titles = {title for (title, _year), _score in hits}
        if movie_title.strip() in hit_titles:
            correct += 1
        total += 1

    return correct/total
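
The rerank step in `semantic_query` sorts fragments by score and then keeps each movie only once, so a movie with several strong fragments does not crowd out the others. The model-free sketch below isolates that logic. The names `rerank_unique` and `word_overlap` are hypothetical, and the word-overlap scorer is only a toy stand-in for the cross-encoder so that the example runs without downloading a model.

```python
from typing import Callable, List, Sequence, Tuple


def rerank_unique(
    query: str,
    candidates: Sequence[Tuple[str, str]],
    score_fn: Callable[[str, str], float],
    top_n: int,
) -> List[Tuple[str, float]]:
    """Score (label, fragment) candidates and keep the best fragment per label."""
    scored = sorted(
        ((label, score_fn(query, fragment)) for label, fragment in candidates),
        key=lambda pair: pair[1],
        reverse=True,
    )
    results: List[Tuple[str, float]] = []
    seen = set()
    for label, score in scored:
        if len(results) >= top_n:
            break
        if label in seen:
            continue  # a better fragment of this movie is already in the results
        seen.add(label)
        results.append((label, score))
    return results


def word_overlap(query: str, passage: str) -> float:
    """Toy scorer: number of shared lowercase words. NOT a cross-encoder."""
    return float(len(set(query.lower().split()) & set(passage.lower().split())))


# Made-up candidates; "Movie A" appears twice with different fragments.
candidates = [
    ("Movie A", "a truck driver is killed by police"),
    ("Movie B", "a detective hunts a killer"),
    ("Movie A", "the widow mourns the truck driver"),
    ("Movie C", "police chase a smuggling ring"),
]
results = rerank_unique("truck driver killed by police", candidates, word_overlap, top_n=2)
```

Each movie is represented by its best-scoring fragment, which matches the behaviour of walking the sorted list from the top and skipping movies already seen.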

In [None]:
cross_encoder = CrossEncoder(cross_encoder_model)

## An example

Below is an example query. Note that it is fairly vague: it does not name any characters from the movie, and yet the top result is indeed the movie I had in mind. It also scores well above the remaining results.

In [None]:
query = "italian truck driver is killed by police while smuggling"
semantic_query(query, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots)

[(('The Rose Tattoo', 1955), 2.5438912),
 (('F/X', 1986), -0.64561),
 (('Suspect Zero', 2004), -0.9597426),
 (("Carlito's Way", 1993), -1.1318713),
 (('Exit Wounds', 2001), -1.1443862),
 (('The Transporter', 2002), -1.6462729),
 (('Drive', 2011), -1.7802924),
 (('The Tourist', 2010), -2.5319207),
 (('Gang Related', 1997), -2.646069),
 (('Fast & Furious', 2009), -2.796188)]
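
The scores above are unbounded and mostly negative, which suggests that this cross-encoder returns raw logits rather than probabilities. That is an inference from the output, not something the notebook states. If a bounded score is easier to read, a logistic sigmoid maps each logit into (0, 1) without changing the ranking. The helper name `sigmoid` is ours.

```python
import math


def sigmoid(logit: float) -> float:
    """Numerically stable logistic function mapping a real logit into (0, 1)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)


# First three scores copied from the example output above.
logits = [2.5438912, -0.64561, -0.9597426]
bounded = [sigmoid(x) for x in logits]
```

The transformed values should be read as a convenient rescaling only; nothing here shows that they are calibrated probabilities of a correct match.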

## Testing performance on IMDB query set

In [None]:
measure_accuracy(test_queries, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots)

9987it [42:34,  3.91it/s]


0.8419945929708621
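
To get a rough sense of how precise this figure is, one can compute a binomial standard error from the accuracy and the number of queries shown in the progress line (9987). This treats the queries as independent, which is only approximately true here because one movie can contribute two queries, so the interval is a back-of-the-envelope estimate and not part of the original analysis.

```python
import math

accuracy = 0.8419945929708621  # value returned by measure_accuracy above
n_queries = 9987               # query count from the tqdm progress line

# Binomial standard error, assuming independent queries (an approximation).
std_error = math.sqrt(accuracy * (1.0 - accuracy) / n_queries)

# Normal-approximation 95% interval around the measured accuracy.
margin = 1.96 * std_error
lower = accuracy - margin
upper = accuracy + margin
print(f"accuracy = {accuracy:.3f}, 95% interval = [{lower:.3f}, {upper:.3f}]")
```

With roughly ten thousand queries the interval is narrow, so differences of several percentage points between model variants would be unlikely to come from sampling noise alone.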

## Analyzing misclassifications