## Installing `sentence-transformers`

In [1]:
!pip install sentence-transformers

# Also check the GPU model when running with a GPU kernel
!nvidia-smi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.0.tar.gz (79 kB)
[K     |████████████████████████████████| 79 kB 7.3 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 38.6 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 28.9 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 6.6 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 16.3 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYA

In [2]:
# All the necessary imports

import pandas as pd
import ast
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util

In [3]:
# Retrieval models: a bi-encoder for the fast candidate search and a
# cross-encoder that re-ranks those candidates.
bi_encoder_model, cross_encoder_model = (
    "msmarco-distilbert-base-v4",
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
)

# Number of bi-encoder candidates handed to the cross-encoder, and the
# maximum number of distinct movies a query returns.
pre_cross_encode_k, results_to_show = 100, 10

In [4]:
# Mount Google Drive and load the datasets plus the precomputed corpus embeddings.

from google.colab import drive
drive.mount("/content/gdrive")

DATA_DIR = "/content/gdrive/MyDrive"

# The 'to_embed' column is stored as a Python-literal string; parse it back with literal_eval.
plots = pd.read_csv(f"{DATA_DIR}/wiki_with_revenue.csv", compression="zip", converters={'to_embed': ast.literal_eval})
test_queries = pd.read_csv(f"{DATA_DIR}/summaries_test.csv", compression="zip")
id_and_summary = pd.read_csv(f"{DATA_DIR}/id_and_summary.csv", compression="zip")

# Place the embeddings on the GPU when one is available and on the CPU otherwise,
# so this cell runs unchanged on both GPU and non-GPU kernels.
# NOTE: torch.load unpickles the file; only load embeddings from a trusted source.
embedding_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
corpus_embeddings = torch.load(f"{DATA_DIR}/corpus_embeddings.pt", map_location=embedding_device)

Mounted at /content/gdrive


In [5]:
# Function to query and return top `results_to_show` with associated score

def semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset,
                   top_k=None, max_results=None):
    """Retrieve and re-rank the movies whose summary chunks best match a free-text query.

    Two-stage retrieval: the bi-encoder embeds the query and
    `util.semantic_search` selects the `top_k` nearest chunks in
    `corpus_embeddings`; the cross-encoder then scores every
    (query, chunk text) pair and the chunks are re-ranked by that score.
    Several chunks can belong to the same movie, so results are
    de-duplicated on (stripped title, release year).

    Args:
        query_string: Free-text query.
        corpus_embeddings: Chunk embeddings; row i must correspond to row i
            (by position) of `id_and_summary`.
        bi_encoder: Model exposing `encode(text, convert_to_tensor=True)`.
        cross_encoder: Model exposing `predict(list of [query, passage] pairs)`.
        id_and_summary: DataFrame with a 'to_embed' column (chunk text) and a
            'MovieId' column (row label of the movie in `wiki_dataset`).
        wiki_dataset: DataFrame with 'Title' and 'Release Year' columns.
        top_k: Number of bi-encoder candidates to re-rank. Defaults to the
            notebook-level `pre_cross_encode_k`.
        max_results: Maximum number of distinct movies to return. Defaults to
            the notebook-level `results_to_show`.

    Returns:
        List of `((title, year), cross_encoder_score)` tuples, best first,
        with at most `max_results` entries. Empty if there are no candidates.
    """
    if top_k is None:
        top_k = pre_cross_encode_k
    if max_results is None:
        max_results = results_to_show

    query_embedding = bi_encoder.encode(query_string, convert_to_tensor=True)
    # semantic_search returns one hit list per query; we sent exactly one query.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        # Nothing to re-rank (e.g. empty corpus); avoid cross_encoder.predict([]).
        return []

    # corpus_id is a row position in corpus_embeddings, so look it up positionally.
    summaries = id_and_summary['to_embed']
    movie_ids = id_and_summary['MovieId']
    cross_inp = [[query_string, summaries.iloc[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    # sorted() is stable even with reverse=True, so ties keep semantic_search's order.
    ranked = sorted(enumerate(cross_scores), key=lambda pair: pair[1], reverse=True)

    results = []
    seen = set()
    for hit_index, score in ranked:
        if len(results) >= max_results:
            break
        movie_id = movie_ids.iloc[hits[hit_index]['corpus_id']]
        # MovieId is used as a row label into wiki_dataset.
        key = (wiki_dataset.at[movie_id, 'Title'].strip(), wiki_dataset.at[movie_id, 'Release Year'])
        if key not in seen:
            seen.add(key)
            results.append((key, score))
    return results

# Function to test performance on a query dataset

def measure_accuracy(query_dataset, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset,
                     match_year=False):
    """Return the fraction of queries whose target movie appears in `semantic_query`'s results.

    For each row of `query_dataset`, 'SummaryFragment' is run through
    `semantic_query` and the row counts as correct when the movie labelled by
    'MovieId' in `wiki_dataset` is among the returned results.

    Args:
        query_dataset: DataFrame with 'SummaryFragment' (query text) and
            'MovieId' (row label of the expected movie in `wiki_dataset`).
        corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset:
            Passed straight through to `semantic_query`.
        match_year: If False (default, the original behavior), a hit matches on
            stripped title alone, so a same-titled remake also counts as correct.
            If True, the release year must match as well.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: If `query_dataset` has no rows (accuracy is undefined).
    """
    total = len(query_dataset)
    if total == 0:
        raise ValueError("query_dataset is empty; accuracy is undefined")

    correct = 0
    for query_string, movie_id in zip(query_dataset['SummaryFragment'], query_dataset['MovieId']):
        hits = semantic_query(query_string, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, wiki_dataset)
        expected_title = wiki_dataset.at[movie_id, 'Title'].strip()
        expected_year = wiki_dataset.at[movie_id, 'Release Year']
        for (hit_title, hit_year), _score in hits:
            if hit_title.strip() == expected_title and (not match_year or hit_year == expected_year):
                correct += 1
                break

    return correct / total

In [6]:
# Re-ranking model: scores each (query, passage) pair jointly.
cross_encoder = CrossEncoder(cross_encoder_model)

# Candidate-retrieval model: embeds the query for the nearest-neighbour search.
# Inputs longer than 256 tokens are truncated.
bi_encoder = SentenceTransformer(bi_encoder_model)
bi_encoder.max_seq_length = 256

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.71k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/319 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/794 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.7M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/316 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
# Example free-text query against the full movie corpus.
query = "couple walks through paris all night"

semantic_query(
    query_string=query,
    corpus_embeddings=corpus_embeddings,
    bi_encoder=bi_encoder,
    cross_encoder=cross_encoder,
    id_and_summary=id_and_summary,
    wiki_dataset=plots,
)

[(('Target Unknown', 1951), 2.5969682),
 (('The Man Who Reclaimed His Head', 1934), 2.3785949),
 (('Midnight in Paris', 2011), 2.1639729),
 (('Before Sunset', 2004), 2.0388565),
 (('This Was Paris', 1942), 1.1195954),
 (('The Temptress', 1926), -1.8462456),
 (('Catacombs', 2007), -2.6320286),
 (('Target', 1985), -2.6779432),
 (('A Korean in Paris', 2016), -2.6810822),
 (('Can-Can', 1960), -2.9730914)]

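## Testing performance on an artificial query set

Each test row pairs a query (`SummaryFragment`) with the `MovieId` of the movie it should retrieve. `measure_accuracy` counts a query as correct when that movie's title appears among the (at most `results_to_show`) distinct movies returned by `semantic_query`.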

In [None]:
# Full artificial query set: 'SummaryFragment' is the query, 'MovieId' the movie measure_accuracy expects it to retrieve.
test_queries

Unnamed: 0,MovieId,PlotFragments,SummaryFragment,summary_length
0,25782,"It is also known that Prince Vijay, nephew of ...","the film now moves to Kiran's hotel, where ama...",18
1,17655,"One Valentine's evening a group of single, dat...",Brett (Guy Pearce) is a science journalist for...,12
2,19695,Charlie's friends won't tell him where Maggie ...,Charlie's friends won't tell him where Maggie ...,18
3,20660,"Thomas Smithers (Postlethwaite), who has made ...",Thomas Smithers (Postlethwaite) hires the famo...,21
4,22022,The plot revolves around the life of aspiring ...,plot revolves around the life of aspiring writ...,15
...,...,...,...,...
24743,11406,Bill's wishes end up causing more trouble due ...,bill's wishes end up causing more trouble due ...,12
24744,6930,Jesse (Robert Wagner) and Frank James (Jeffrey...,Jesse (Robert Wagner) and Frank James (Jeffrey...,16
24745,30788,"Aadhi feels that Dhana has changed a lot, so A...",aadhi attempts to send Dhana in jail for a mur...,16
24746,25561,Amateur boxer Ajay Mehra (Sunny Deol) is livin...,boxer is living with his brother and sister-in...,11


In [None]:
# First 1000 queries: a cheaper subset for a quick accuracy estimate.
test_queries_small = test_queries.iloc[:1000]

In [None]:
# Preview the 1000-query subset used for the quicker, timed accuracy run.
test_queries_small

Unnamed: 0,MovieId,PlotFragments,SummaryFragment,summary_length
0,25782,"It is also known that Prince Vijay, nephew of ...","the film now moves to Kiran's hotel, where ama...",18
1,17655,"One Valentine's evening a group of single, dat...",Brett (Guy Pearce) is a science journalist for...,12
2,19695,Charlie's friends won't tell him where Maggie ...,Charlie's friends won't tell him where Maggie ...,18
3,20660,"Thomas Smithers (Postlethwaite), who has made ...",Thomas Smithers (Postlethwaite) hires the famo...,21
4,22022,The plot revolves around the life of aspiring ...,plot revolves around the life of aspiring writ...,15
...,...,...,...,...
995,22211,"After their first-born baby, Pierre (Patrick G...",Pierre (Patrick Goyette) and Élisabeth (Suzie ...,17
996,14638,He hits the road looking for refuge in his pas...,"he visits his mother, who he hasn't seen in 30...",12
997,29,Hoax rushes to scene of the crime where he dis...,the tramp runs away and Hoax gives chase .,9
998,14120,Barry Egan is a single man who owns a company ...,"he calls a phone-sex line, but the operator at...",13


In [None]:
%time measure_accuracy(test_queries_small, corpus_embeddings, bi_encoder, cross_encoder, id_and_summary, plots)

CPU times: user 3min 44s, sys: 1.08 s, total: 3min 45s
Wall time: 3min 22s


0.723

In [8]:
# Accuracy on the full artificial query set.
measure_accuracy(
    query_dataset=test_queries,
    corpus_embeddings=corpus_embeddings,
    bi_encoder=bi_encoder,
    cross_encoder=cross_encoder,
    id_and_summary=id_and_summary,
    wiki_dataset=plots,
)

0.7287457572329077