# Embed and rerank

## Outline of classification algorithm

This notebook implements the "embed-and-rerank" algorithm, a three-step process.

1. Take the corpus of all Wikipedia plot summaries and break each one into chunks of at most 256 tokens, the maximum sequence length we feed to the embedding model. This yields a large collection of plot summary fragments, each labelled with the movie it comes from.
2. Embed the corpus of fragments in $\mathbb{R}^{768}$ using a context-sensitive sentence embedder (the bi-encoder). We use a BERT-derived model trained on the MS MARCO dataset. Using the same embedder, we also embed the search query string into the same vector space, and then pick out the 100 closest corpus entries under cosine similarity. These 100 fragments are an initial guess for the movie the query string is referencing.
3. Finally, run the query string together with each of the 100 candidates through a cross-encoder, a different neural network that reads the two texts as a pair and outputs a semantic similarity score. We return the 10 highest-scoring distinct movies as the search results for the input query.
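
The two-stage structure of the steps above can be illustrated with a toy sketch in plain Python. It is not the notebook's implementation: 2-dimensional vectors stand in for the 768-dimensional embeddings, and `toy_cross_scores` is a hypothetical lookup table standing in for the cross-encoder.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def retrieve_then_rerank(query_vec, corpus_vecs, rerank_score, first_stage_k, final_k):
    """Stage 1: keep the first_stage_k corpus entries closest to the query by cosine.
    Stage 2: re-score only those candidates with rerank_score and keep the best final_k."""
    by_cosine = sorted(range(len(corpus_vecs)),
                       key=lambda i: cosine(query_vec, corpus_vecs[i]),
                       reverse=True)
    candidates = by_cosine[:first_stage_k]
    reranked = sorted(candidates, key=rerank_score, reverse=True)
    return reranked[:final_k]

# Toy data: 2-d vectors stand in for the real embeddings.
toy_corpus = [(1.0, 0.0), (0.9, 0.1), (0.0, 1.0), (0.7, 0.7)]
toy_query = (1.0, 0.05)
# Hypothetical stand-in for the cross-encoder: a fixed score per corpus index.
toy_cross_scores = {0: 0.2, 1: 0.9, 2: 0.99, 3: 0.5}

top = retrieve_then_rerank(toy_query, toy_corpus, toy_cross_scores.get,
                           first_stage_k=3, final_k=2)
print(top)  # [1, 3]
```

Note that entry 2 has the highest reranking score but never reaches the second stage, because it is not among the 3 nearest neighbours. The reranker can only reorder what the first stage retrieves.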

## Outline of testing method

We scraped IMDB for single-sentence summaries of movies and took them to be representative of what people might type when searching for a movie. We then ran the classifier on this IMDB query dataset and computed its accuracy, counting a query as correct when the intended movie's title appears among the returned results. We also looked at examples of misclassifications to see what kinds of queries the classifier gets wrong.
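
As a rough sketch of the accuracy measure described above, the function below computes the fraction of queries whose expected title appears in the first `k` results. The name `top_k_accuracy` and the movie titles are made up for illustration.

```python
def top_k_accuracy(expected_titles, ranked_results, k=10):
    """Fraction of queries whose expected title appears in the first k results."""
    if not expected_titles:
        return 0.0
    hits = 0
    for expected, results in zip(expected_titles, ranked_results):
        if expected.strip() in [title.strip() for title in results[:k]]:
            hits += 1
    return hits / len(expected_titles)

# Made-up titles, purely for illustration.
expected = ["Movie A", "Movie B", "Movie C", "Movie D"]
results = [
    ["Movie A", "Movie X"],    # correct at rank 1
    ["Movie Y", "Movie B "],   # correct at rank 2 (trailing space is stripped)
    ["Movie X", "Movie Y"],    # miss
    ["Movie X", "Movie D"],    # correct at rank 2
]
print(top_k_accuracy(expected, results, k=10))  # 0.75
print(top_k_accuracy(expected, results, k=1))   # 0.25
```

Because the comparison is on titles only, two different movies that share a title would be treated as the same movie by a measure of this kind.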

## Installing and importing packages

In [1]:
!pip install sentence-transformers

# Also check the GPU model when running with a GPU kernel
!nvidia-smi

In [2]:
# All the necessary imports

import pandas as pd
from tqdm import tqdm
import ast
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util

## Setting hyperparameters and data preprocessing

The plot summaries are stored pre-split into shorter fragments. We flatten them into one row per fragment, and generate a separate DataFrame for the test queries.
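
The splitting itself is not shown in this notebook, since the fragments are loaded ready-made from the `to_embed` column. A minimal sketch of one way to do it is given below, under a loud assumption: whitespace-separated words stand in for the model's subword tokens, so a 256-word chunk will generally be longer than 256 model tokens. `chunk_tokens` is a hypothetical helper, not the code used to build the dataset.

```python
def chunk_tokens(text, max_tokens=256):
    """Split text into consecutive chunks of at most max_tokens whitespace tokens."""
    tokens = text.split()
    return [" ".join(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)]

# A synthetic 600-word "plot" to show the chunk sizes.
toy_plot = " ".join(f"word{i}" for i in range(600))
fragments = chunk_tokens(toy_plot, max_tokens=256)
print([len(f.split()) for f in fragments])  # [256, 256, 88]
```

A faithful version would count tokens with the bi-encoder's own tokenizer instead of `str.split`.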

In [3]:
# Model choice and some hyperparameters

bi_encoder_model = "msmarco-distilbert-base-v4"
cross_encoder_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"

pre_cross_encode_k = 100
results_to_show = 10

In [4]:
# Mount drive and load datasets and model

from google.colab import drive
drive.mount("/content/gdrive")

plots = pd.read_csv("/content/gdrive/MyDrive/imdb_plots.csv", compression="zip", converters={'to_embed': ast.literal_eval})

plots['MovieId'] = plots.index
plots = plots.drop(['Unnamed: 0', 'Unnamed: 0.1'], axis=1)

Mounted at /content/gdrive


In [5]:
# Flatten to one row per fragment, remembering which movie each fragment came from
movie_ids = []
to_embed = []
for _, row in plots.iterrows():
  for frag in row['to_embed']:
    movie_ids.append(row['MovieId'])
    to_embed.append(frag)

id_and_summary = pd.DataFrame({'MovieId': movie_ids, 'to_embed': to_embed})

In [6]:
# Build the test set: each movie contributes up to two IMDB one-sentence summaries
movie_ids = []
queries = []
for _, row in plots.iterrows():
  for column in ('imdb_1', 'imdb_2'):
    summary = row[column]
    if not pd.isna(summary):
      movie_ids.append(row['MovieId'])
      queries.append(summary)

test_queries = pd.DataFrame({'MovieId': movie_ids, 'summary': queries})

## Embedding corpus in $\mathbb{R}^{768}$

In [7]:
bi_encoder = SentenceTransformer(bi_encoder_model)
bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens

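# Pass a plain list of strings so that positions in corpus_embeddings match row positions in id_and_summary
corpus_embeddings = bi_encoder.encode(id_and_summary['to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)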

Batches:   0%|          | 0/934 [00:00<?, ?it/s]

## Functions to perform the search, compute accuracy and look for misclassifications

In [9]:
# Function to query and return top `results_to_show` with associated score

def semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset):
    query_embedding = bi_encoder.encode(query_string, convert_to_tensor=True)
    pre_cross_encode_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=pre_cross_encode_k)

    cross_inp = [[query_string, id_and_summary['to_embed'][hit['corpus_id']]] for hit in pre_cross_encode_hits[0]]
    cross_scores = cross_encoder.predict(cross_inp, activation_fct=torch.nn.Sigmoid())
    cross_encoder_res = sorted(enumerate(cross_scores), key=lambda x: x[1], reverse=True)

    res_movie_title_and_year = []
    res_score = []

    for res in cross_encoder_res:
        if len(res_movie_title_and_year) >= results_to_show:
            break

        index = res[0]
        score = res[1]
        corpus_id = pre_cross_encode_hits[0][index]['corpus_id']
        movie_id = id_and_summary['MovieId'][corpus_id]
        movie_title = wiki_dataset['Title'][movie_id]
        movie_year = wiki_dataset['Release Year'][movie_id]
        # Several fragments can belong to the same movie; keep only its best-scoring one
        if (movie_title.strip(), movie_year) not in res_movie_title_and_year:
            res_movie_title_and_year.append((movie_title.strip(), movie_year))
            res_score.append(score)
    return list(zip(res_movie_title_and_year, res_score))

# Function to test performance on a query dataset

def measure_accuracy(query_dataset, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset):
  total = 0
  correct = 0

  for row in tqdm(query_dataset.iterrows()):
    query_string = row[1]['summary']
    movie_id = row[1]['MovieId']
    hits = semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset)
    movie_title = wiki_dataset['Title'][movie_id]
    if movie_title.strip() in map(lambda x: x[0][0].strip(), hits):
      correct += 1
    total += 1
    #print(correct/total)

  return correct/total

# Function to show misclassifications
def show_misclassifications(count, query_dataset, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset):
  misclassifications = []

  while len(misclassifications) < count:
    row = query_dataset.sample()
    #print(row)
    query_string = row.iloc[0]['summary']
    actual_movie_id = row.iloc[0]['MovieId']
    actual_movie_title = wiki_dataset['Title'][actual_movie_id]
    hits = semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset)
    if not actual_movie_title.strip() in map(lambda x: x[0][0].strip(), hits):
      misclassifications.append((query_string, actual_movie_title, wiki_dataset['Plot'][actual_movie_id], hits))

  return misclassifications
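
Since the corpus consists of fragments rather than whole movies, several of the reranked candidates can belong to the same movie, and `semantic_query` keeps only the best-scoring one per movie. The standalone sketch below isolates that deduplication step on made-up data. `best_per_movie` is a hypothetical helper written for this illustration.

```python
def best_per_movie(scored_fragments, limit):
    """scored_fragments: iterable of (movie_key, score) pairs, one per fragment.
    Returns up to `limit` (movie_key, score) pairs, highest score first,
    keeping only the best-scoring fragment of each movie."""
    results = []
    seen = set()
    for movie_key, score in sorted(scored_fragments, key=lambda p: p[1], reverse=True):
        if len(results) >= limit:
            break
        if movie_key in seen:
            continue  # a better-scoring fragment of this movie is already listed
        seen.add(movie_key)
        results.append((movie_key, score))
    return results

# Made-up (title, year) keys and scores.
toy = [(("Film A", 1990), 0.4), (("Film B", 2001), 0.9),
       (("Film A", 1990), 0.8), (("Film C", 1975), 0.1)]
print(best_per_movie(toy, limit=2))
# [(('Film B', 2001), 0.9), (('Film A', 1990), 0.8)]
```

Using a set for the membership check avoids rescanning the result list for every candidate, which matters little for 100 candidates but keeps the intent explicit.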

In [10]:
cross_encoder = CrossEncoder(cross_encoder_model)

## An example

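Below are a few example queries. Note that they are reasonably vague, i.e. they don't name any of the movie's characters, and yet the movie I was thinking of shows up in the results. It is not always ranked first: for the second query, *Get Rich or Die Tryin'* appears in fourth place, and all of that query's cross-encoder scores are below 0.005.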

In [11]:
queries = [
           "In World War I France, a pilot falls in love with the wife of his friend and superior officer.",
           "A tale of an inner city drug dealer who turns away from a life of crime to pursue his passion of rap music.",
           "A soft and hesitant young man is in danger when tries to break toxic relationships with a mysterious stranger claiming to be his friend.",
           "Two noble Scottish brothers deliberately take opposite sides when Bonnie Prince Charlie returns to claim the throne of Scotland in order to preserve the family fortune.",
           "Coming together to solve a series of murders in New York City are a police detective and an assassin, who will be hunted by the police, the mob, and a ruthless corporation.",
           "A small town's women give birth to unfriendly alien children posing as humans.",
]
for query in queries:
  print(semantic_query(query, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots))

[(('The Other Side of Midnight', 1977), 0.9817696), (('The Woman I Love', 1937), 0.7452528), (('The Hunters', 1958), 0.55296546), (('Half Shot at Sunrise', 1930), 0.40081862), (('The Deep Blue Sea', 2011), 0.21304908), (('The White Cliffs of Dover', 1944), 0.19599923), (('The Four Horsemen of the Apocalypse', 1921), 0.14886513), (('Cavalcade', 1933), 0.12770672), (('Little Boy Lost', 1953), 0.08603963), (('Mrs. Miniver', 1942), 0.07356316)]
[(('Friday', 1995), 0.0043068533), (('Some Kind of Hero', 1982), 0.0029227412), (('Prince of the City', 1981), 0.0020679652), (("Get Rich or Die Tryin'", 2005), 0.0019445709), (("Don't Be a Menace to South Central While Drinking Your Juice in the Hood", 1996), 0.0017086571), (('Out for Justice', 1991), 0.0013602166), (('Once Upon a Time in America', 1984), 0.00088909507), (('Cool as Ice', 1991), 0.00061771146), (('Thief', 1981), 0.00048288517), (('Dillinger', 1973), 0.00045474397)]
[(("Alice Doesn't Live Here Anymore", 1974), 0.11200594), (('The Bod

## Testing performance on IMDB query set

In [None]:
measure_accuracy(test_queries, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots)

9987it [42:34,  3.91it/s]


0.8419945929708621

## Analyzing misclassifications

In [None]:
misclassifications = show_misclassifications(10, test_queries, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots)
misclassifications

[("A mentally disturbed young woman takes a job at a posh country club and falls in with a clique of wealthy college kids where she's taken under the wing of the clique's twisted leader, who harbors some dark secrets too terrifying to tell.",
  'The In Crowd',
  'When Adrien Williams is released from a mental institution, her former doctor, Henry Thompson immediately tries to get her back on her feet by getting her a job at a country club on the East Coast, where she is introduced to the lifestyle of the snooty, rich "beautiful people".\r\nBrittany Foster, a young woman who lives in the area, befriends Adrien and takes her under her wing, accepting her as part of a clique of wealthy teenagers. Brittany\'s group of friends make comments about how much Adrien looks like Brittany\'s older sister who had moved away. At first she enjoys being a close confidant of Brittany but Adrien soon begins to discover how twisted Brittany actually is when Matt Curtis, an object of Brittany\'s affection

Observe that in the misclassified queries sampled above, the query describes the global structure of the plot, i.e. it relates events that happen at the beginning and at the end of the movie. However, the bi-encoder only embeds fragments of up to 256 tokens, so for longer plot summaries the beginning and the end land in distinct fragments. The query then sits "between" them in the embedding space, and is therefore close to neither.
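
The geometric intuition can be made concrete with a toy calculation. The vectors below are hand-made word-count vectors over a 4-word vocabulary, not BERT embeddings, so this only illustrates the argument and is not evidence about the actual model.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Toy vocabulary: [setup_1, setup_2, ending_1, ending_2]
beginning_fragment = [1, 1, 0, 0]   # fragment covering the start of the plot
ending_fragment = [0, 0, 1, 1]      # fragment covering the end of the plot
whole_plot = [1, 1, 1, 1]           # the unsplit summary
query = [1, 0, 0, 1]                # mentions one event from each end

print(round(cosine(query, beginning_fragment), 3))  # 0.5
print(round(cosine(query, ending_fragment), 3))     # 0.5
print(round(cosine(query, whole_plot), 3))          # 0.707
```

In this toy setting the query is more similar to the whole plot than to either fragment, which matches the failure mode described above.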