In [1]:
# Data from:  https://www.kaggle.com/datasets/raingo/tumblr-gif-description-dataset?resource=download
# Function to put the data in a list
def form_list(FILE_NAME):
  list = []
  with open(FILE_NAME) as f:
    for line in f: 
        line = line.strip()
        list.append(line)
  return list

# Function to put the annotation data in a list
def form_annotation_list(FILE_NAME):
  list = []
  with open(FILE_NAME) as f:
    for line in f: 
        line = line.strip()
        # Remove the first and last characters from a string
        line = line[1:-1]
        annotation = []
        for label in line.split(','):
          annotation.append(int(label))
        list.append(annotation)
  return list

In [2]:
!pip install nltk
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [3]:
import nltk
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import re


In [4]:
from nltk.corpus import stopwords

def remove_stop_words(caption):
  filtered_text = ' '.join([word for word in caption.split() if word not in (stopwords.words('english'))])
  return filtered_text


In [5]:
# stemming of words
from nltk.stem.porter import PorterStemmer
def stem(tokens):
  porter = PorterStemmer()
  stemmed = [porter.stem(word) for word in tokens]
  return stemmed

In [6]:
def cleanup(caption):
    caption = caption.lower()
    # remove extra spaces in between
    caption = re.sub(' +', ' ', caption)
    # remove punctuation
    caption = re.sub('[^a-zA-Z]', ' ', caption)
    # remove underscores
    caption = re.sub(r'_', ' ', caption)
    # remove stop words
    caption = remove_stop_words(caption)
    # stem
    return ' '.join(stem(caption.split()))


In [7]:
text = "eating flying dogs"
print(cleanup(text))

eat fli dog


In [8]:
# Loading data 
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

ANNOTATIONS_BASE_FILE_NAME = '/content/annotations_1.txt'
ANNOTATIONS_SYNEXP_FILE_NAME = '/content/annotations_2.txt'
ANNOTATIONS_RELFEED_FILE_NAME = '/content/annotations_3.txt'
LABELS_FILE_NAME = '/content/tgif-v1.0.tsv'

annotations_base = form_annotation_list(ANNOTATIONS_BASE_FILE_NAME)
annotations_synexp = form_annotation_list(ANNOTATIONS_SYNEXP_FILE_NAME)
annotations_rf = form_annotation_list(ANNOTATIONS_RELFEED_FILE_NAME)

print(annotations_base[0])
labels = {}
original_labels = {}

file = open(LABELS_FILE_NAME, 'r')
Lines = file.readlines()
  
# Strips the newline character
for line in Lines:
    split_line = line.split("\t",1)
    link = split_line[0]
    caption = split_line[1]
    caption = caption.replace("\n", "")
    caption = caption.replace(".", "") # remove period at the end
    original_labels[link] = original_labels.get(link, "") + " " + caption
    original_labels[link] = re.sub(' +', ' ', original_labels[link])

    labels[link] = labels.get(link, "") + " " + cleanup(caption)
    labels[link] = re.sub(' +', ' ', labels[link])

# print(labels[training_list[0]])

Mounted at /content/drive
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1]


In [9]:
print(original_labels["https://31.media.tumblr.com/00e5de7a267faca618b337af77a99aac/tumblr_n9qcwdJTNX1tib6auo1_400.gif"])

 a woman wearing a black outfit moves her head around a game show contestant is thinking about her answer a woman in black is about to cry


In [10]:
print(len(original_labels))

102068


In [11]:
## Queries 
queries = form_list('/content/queries.txt')

print(queries)

['salute', 'apple', 'run fast', 'shoot gun', 'happy dancing', 'baby laughing', 'eating pizza', 'screaming loud', 'cats fighting', 'driving a car', 'bike', 'sleeping in bed', 'christmas holidays', 'santa clause', 'fall on floor', 'eating ice cream', 'reading a book', 'harry potter']


In [12]:
def combine_dictionaries(big_dict, small_dict):
  keys = list(small_dict.keys())
  for key in keys:
    value = big_dict.get(key, 0) + small_dict[key]
    big_dict[key] = value

  # Updated big_dict 
  return big_dict

In [13]:
def get_term2freq_for_gif(gif, caption):
    term2freq = {}
    # Get sentence
    words = caption.split()
    for word in words:
      term2freq[word] = term2freq.get(word, 0) + 1

    return term2freq

In [14]:
def get_term2freq_for_query(query):
    term2freq = {}
    # Get sentence
    words = query.split()
    for word in words:
      term2freq[word] = term2freq.get(word, 0) + 1

    return term2freq

In [15]:
# Build an inverted index for the gifs
def build_inverted_index(labels):
    invertedIndex = {}
    for gif in labels.keys(): 
      caption = labels[gif]
      term2freq_for_gif = get_term2freq_for_gif(gif, caption)

      words = list(term2freq_for_gif.keys())

      for word in words:
        if word not in invertedIndex.keys():
            gifIDCount = {gif : term2freq_for_gif[word]}
            invertedIndex[word] = gifIDCount
        else:
            gifIDCount = {gif : term2freq_for_gif[word]}
            invertedIndex[word].update(gifIDCount)
  
    return invertedIndex

In [16]:
inverted_index = build_inverted_index(labels)

In [17]:
print(inverted_index['rais'])

{'https://38.media.tumblr.com/a5f27adf06dac45af78b4976707a2f55/tumblr_mrmmwo38f01qc03ipo1_250.gif': 1, 'https://38.media.tumblr.com/b27e943b5267ca7b4323907f02fcb747/tumblr_mkn76n09cL1ryp9xpo1_500.gif': 1, 'https://38.media.tumblr.com/634701cb36814d10a4db6679552dd8cd/tumblr_nnv3paZgju1tqq823o1_500.gif': 1, 'https://38.media.tumblr.com/6176943d5d86679c6821bca330d69a57/tumblr_nph5z5to4E1uv0y2ro1_250.gif': 1, 'https://38.media.tumblr.com/914ba3f9340178e1482367d16094ac51/tumblr_nqdtt93Mpf1uwuiu2o1_250.gif': 1, 'https://38.media.tumblr.com/4298051c401bff504d17abf44aa6b9b8/tumblr_n9xkasJJ4C1r18t02o1_500.gif': 1, 'https://38.media.tumblr.com/2fc206b7d4915f05d6002a6811562779/tumblr_noz2mrcvzi1qzt7d8o1_500.gif': 1, 'https://38.media.tumblr.com/de649d1e0a0a98b92aae1d7db40b305a/tumblr_novvdnSPAI1ruma88o1_500.gif': 1, 'https://38.media.tumblr.com/cf6215e6902e59a79b75d2024ea044ea/tumblr_n3x8r7dwbR1rsgsgho1_r3_500.gif': 1, 'https://38.media.tumblr.com/e98281172ea24271db4b279a8d29356d/tumblr_nq1uexbpw

In [18]:
import math 
def get_idf(labels):
    num_of_gifs = len(labels.keys())
    idf = {}
    gif_urls = list(labels.keys())
    for gif_url in gif_urls:
        caption = labels[gif_url]
        words = list(set(caption.split()))
        for word in words:
          idf[word] = idf.get(word, 0) + 1

    # From: https://en.wikipedia.org/wiki/Okapi_BM25
    for word in list(idf.keys()):
        idf[word] = math.log((num_of_gifs - idf[word] + 0.5) / (idf[word] + 0.5))
    return idf

In [19]:
idf = get_idf(labels)
print(idf['appl'])

7.783483786514411


In [20]:
# Get average caption length
def get_avg_caption_len(labels):
  gif_urls = list(labels.keys())
  total_num_of_words = 0

  for gif_url in gif_urls:
    caption = labels[gif_url]
    num_of_words = len(list(caption.split()))
    total_num_of_words += num_of_words

  avg_caption_len = total_num_of_words / len(gif_urls)
  return avg_caption_len

avg_caption_len = get_avg_caption_len(labels)
print(f'avg_caption_len: {avg_caption_len}')

avg_caption_len: 6.890396598346201


In [21]:
def bm25(query, gif, idf, avg_caption_len=avg_caption_len):
    k1 = 1.2
    k2 = 1
    b = 0.75
    score = 0.0
    for word in query:
        if gif.get(word) == None:
            continue
        f_i = gif[word]
        qf_i = query[word]
        doc_len = sum(gif.values())
        K = k1 * (1 - b + b * doc_len / avg_caption_len)
        R1 = f_i * (k1 + 1) / (f_i + K)
        R2 = qf_i * (k2 + 1) / (qf_i + k2) # Might not need R2
        R = R1 * R2
        score += idf[word] * R
    return score

In [22]:
def GetScore(query, gif_url, idf, labels):
    query = get_term2freq_for_query(query)
    gif = get_term2freq_for_gif(gif_url, labels[gif_url])
    score = bm25(query, gif, idf)
    return score

In [23]:
def search(query, k, idf, labels):
  results = []
  gif_urls = list(labels.keys())

  for gif_url in gif_urls:
     gif_score = GetScore(query, gif_url, idf, labels)
     results.append((gif_url, gif_score))

  # sort results
  results = sorted(results, key = lambda x: x[1], reverse=True)
  return results[:k]

In [24]:
from IPython.display import display, Image

def display_gif(gif_url):
  display(Image(url=gif_url))

In [25]:
def display_results(results):
  for result in results:
    gif_url = result[0]
    caption = original_labels[gif_url]
    display_gif(gif_url)
    print(caption)


In [55]:
## For testing
results = search("run fast run talli test fast fast debauch", 15, idf, labels)
print(results)
display_results(results)

[('https://38.media.tumblr.com/1f288181de4561a6dd2d7477658152c9/tumblr_niquu7RXao1rw4ylno1_500.gif', 17.56417826772382), ('https://31.media.tumblr.com/88d5faacbdeb262920ffc1b39b63252c/tumblr_nqmangoIJt1thaub3o1_400.gif', 16.30534961304674), ('https://38.media.tumblr.com/1414b751a27823351703d8df1ba0e22a/tumblr_npx54aVWl01tlfqo1o1_500.gif', 16.30534961304674), ('https://33.media.tumblr.com/4d3671077ffe19629a9497ad1cdba361/tumblr_nq4ti4hkpl1tlq8yzo1_400.gif', 16.30534961304674), ('https://33.media.tumblr.com/f78734767b60782af5579c682e4f1e33/tumblr_ni9n6zeZrD1r60dsbo1_500.gif', 16.30534961304674), ('https://31.media.tumblr.com/f90de865921d24fb8c6c6db5436d1162/tumblr_na2jgwVY0O1to7f2mo1_500.gif', 16.30534961304674), ('https://38.media.tumblr.com/73723201f3c6d067930b6dceb1155668/tumblr_nhtnf6wTEt1tx8mn0o1_400.gif', 16.30534961304674), ('https://38.media.tumblr.com/8c14aa571911985d6d774762f8159452/tumblr_n4gk7efcMw1spy7ono1_400.gif', 15.214894650029972), ('https://31.media.tumblr.com/5a640ea1

 men are running as fast as they can


 a red car is running so fast


 a boy with a handicap is running so fast


 this is a woman running very fast on a sidewalk


 a fast car is running in a road


 a woman runs very fast and weird


 a fast car is running through all the road


 the tiger is running really fast in the garden


 a very fast person runs into a smoke tornado


 many people is running fast on the road


 red car is running on the street fast


 a man is running very fast around a track


 A kangaroo is running through the woods fast and falling


 the car blue is running on road fast


 a woman is running fast and looking behind her


In [None]:
idf = get_idf(labels)
k = 15
for query in queries:
  # clean up query
  query = cleanup(query)
  
  results = search(query, k, idf, labels)
  print(results)
  display_results(results)

In [28]:
query2rel = {}
for i in range(len(queries)):
  query = cleanup(queries[i])
  annotation = annotations_base[i]
  query2rel[query] = annotation

print(query2rel)

{'salut': [1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'appl': [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'run fast': [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1], 'shoot gun': [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'happi danc': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'babi laugh': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'eat pizza': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'scream loud': [1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1], 'cat fight': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'drive car': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bike': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'sleep bed': [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'christma holiday': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'santa claus': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'fall floor': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'eat ice cream': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'read book': [1, 1, 1, 1, 1, 1, 1, 1

**MAP**

In [29]:
def calculate_average_precision(results, relevance_list):
  sum_of_relevant_precisions = 0
  num_of_relevant_seen_so_far = 0

  for i in range(len(results)):
    gif_url = results[i][0]
    relevance = relevance_list[i]
    # If gif is relevant then we need to compute precision
    if relevance > 0:
      num_of_relevant_seen_so_far += 1
      doc_precision = num_of_relevant_seen_so_far / (i + 1)
      sum_of_relevant_precisions += doc_precision
  if num_of_relevant_seen_so_far == 0:
    return 0
  return sum_of_relevant_precisions / num_of_relevant_seen_so_far

In [30]:
idf = get_idf(labels)
k = 15
sum_of_average_precisions = 0
num_of_queries = len(queries)

for query in queries:
  # clean up query
  query = cleanup(query)
  
  results = search(query, k, idf, labels)
  average_precision = calculate_average_precision(results, query2rel[query])
  sum_of_average_precisions += average_precision

mean_average_precision = sum_of_average_precisions / num_of_queries
print(f"MAP@{k}: {mean_average_precision}")


MAP@15: 0.9532268773034951


Query Expansion using Synonyms only

In [31]:
from nltk.corpus import wordnet

def expand_query(query):
  # add synonyms to query to get more terms
  synonyms = [query]
  count = 0
  for x in query.split():
      for syn in wordnet.synsets(x):
          for l in syn.lemmas():
              if(count < 3):
                  if l.name() not in synonyms:
                      synonyms.append(cleanup(l.name()))
                      count += 1
                      
      count=0

  return ' '.join(synonyms)
  

In [47]:
expand_query(cleanup("run fast"))

'run fast run talli test fast fast debauch'

In [None]:
## For testing
results = search(expand_query(cleanup("harry potter")), 15, idf, labels)
print(results)
display_results(results)

In [None]:
idf = get_idf(labels)
k = 15
for query in queries:
  # clean up query
  query = cleanup(query)
  # expand query
  query = expand_query(query)
  results = search(query, k, idf, labels)
  print(results)
  display_results(results)

In [35]:
query2rel_synexp = {}
for i in range(len(queries)):
  query = expand_query(cleanup(queries[i]))
  annotation = annotations_synexp[i]
  query2rel_synexp[query] = annotation

print(query2rel_synexp)

{'salut': [1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'appl': [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'run fast run talli test fast fast debauch': [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1], 'shoot gun shoot hit pip gun artilleri heavi weapon': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'happi danc': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'babi laugh laugh laughter joke': [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0], 'eat pizza eat feed eat pizza pizza pie': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'scream loud scream scream shriek loud brassi cheap': [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1], 'cat fight cat true cat guy battl conflict fight': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'drive car drive thrust drive forc car auto automobil': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bike motorcycl bicycl wheel': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'sleep bed sleep slumber sopor bed bottom seam': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [36]:
idf = get_idf(labels)
k = 15
sum_of_average_precisions = 0
num_of_queries = len(queries)

for query in queries:
  # clean up query
  query = cleanup(query)
  # expand query
  query = expand_query(query)

  results = search(query, k, idf, labels)
  average_precision = calculate_average_precision(results, query2rel_synexp[query])
  sum_of_average_precisions += average_precision

mean_average_precision = sum_of_average_precisions / num_of_queries
print(f"MAP@{k}: {mean_average_precision}")


MAP@15: 0.9113117708050554


**Relevance Feedback Part**

In [37]:
def queryFrequency(query, invertedIndex):
    queryFreq = {}
    for term in query.split():
        if term in queryFreq.keys():
            queryFreq[term] += 1
        else:
            queryFreq[term] = 1
    for term in invertedIndex:
        if term not in queryFreq.keys():
            queryFreq[term] = 0
    return queryFreq

In [38]:
import math
def findDocMagnitude(docIndex):
    mag = 0
    for term in docIndex:
        mag += float(docIndex[term]**2)
        mag = float(math.sqrt(mag))
    return mag

In [39]:
def findDocs(sortedBM25Score, invertedIndex, labels, annotations):
    relIndex = {}
    nonRelIndex = {}

    for i in range(len(sortedBM25Score)):
      gif_url,gif_score = sortedBM25Score[i]
      caption = labels[gif_url]
      inverted_index_for_gif = get_term2freq_for_gif(gif_url, caption)
      annotation = annotations[i]
      # Relevant
      if annotation == 1:
        relIndex = combine_dictionaries(relIndex, inverted_index_for_gif)

      else:
        nonRelIndex = combine_dictionaries(nonRelIndex, inverted_index_for_gif)

    for term in invertedIndex:
        if term not in relIndex.keys():
            relIndex[term] = 0

    for term in invertedIndex:
        if term not in nonRelIndex.keys():
            nonRelIndex[term] = 0 

    return relIndex, nonRelIndex

In [40]:
ALPHA = 1
BETA = 0.75
GAMMA = 0.15
def findRocchioScore(term, queryFreq, relDocMag, relIndex, nonRelMag, nonRelIndex):
    Q1 = ALPHA * queryFreq[term] 
    Q2 = (BETA/relDocMag) * relIndex[term]
    Q3 = (GAMMA/nonRelMag) * nonRelIndex[term]
    rocchioScore = ALPHA * queryFreq[term] + (BETA/relDocMag) * relIndex[term] - (GAMMA/nonRelMag) * nonRelIndex[term]
    return rocchioScore

In [41]:
def findNewQuery(query, old_results, invertedIndex, labels):
    queryFreq = queryFrequency(query, invertedIndex)
    
    relIndex, nonRelIndex = findDocs(old_results, invertedIndex, labels, query2rel[query])
    relDocMag = findDocMagnitude(relIndex)
    nonRelMag = findDocMagnitude(nonRelIndex)

    if nonRelMag == 0:
       # no unrelevant docs
       return query

    updatedQuery = {}
    newQuery = query
    for term in invertedIndex:
        updatedQuery[term] = findRocchioScore(term, queryFreq, relDocMag, relIndex, nonRelMag, nonRelIndex)
    sortedUpdatedQuery = sorted(updatedQuery.items(), key=lambda x:x[1], reverse=True)
    print(sortedUpdatedQuery)
    if len(sortedUpdatedQuery) < 3:
        print("Less than three")
        loopRange = len(sortedUpdatedQuery)
    else:
        loopRange = 3
    for i in range(loopRange):
        term,frequency = sortedUpdatedQuery[i]
        if term not in query:
            newQuery +=  " "
            newQuery +=  term
    return newQuery

In [42]:
## Lile the original score of the BM25 baseline model, a higher score should mean a better result.
def new_ranking_function(query, old_results, labels, k=15):
  invertedIndex = build_inverted_index(labels)
  idf = get_idf(labels)

  
  newQuery = findNewQuery(query, old_results, invertedIndex, labels)
  print("OLD QUERY: {}".format(query))
  print("NEW QUERY: {}".format(newQuery))
  results = search(query, k, idf, labels)

  return results

In [43]:
query2rel_rf = {}
for i in range(len(queries)):
  query = cleanup(queries[i])
  annotation = annotations_rf[i]
  query2rel_rf[query] = annotation

print(query2rel_rf)

{'salut': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'appl': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'run fast': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'shoot gun': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'happi danc': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'babi laugh': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'eat pizza': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'scream loud': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'cat fight': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'drive car': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bike': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'sleep bed': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'christma holiday': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'santa claus': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'fall floor': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'eat ice cream': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'read book': [1, 1, 1, 1, 1, 1, 1, 1

In [44]:
idf = get_idf(labels)
k = 15
sum_of_average_precisions = 0
num_of_queries = len(queries)

for query in queries:
  # clean up query
  query = cleanup(query)
  
  old_results = search(query, k, idf, labels)
  results = new_ranking_function(query, old_results, labels, k=15)
  average_precision = calculate_average_precision(results, query2rel_rf[query]) #TODO
  sum_of_average_precisions += average_precision

mean_average_precision = sum_of_average_precisions / num_of_queries
print(f"MAP@{k}: {mean_average_precision}")


[('salut', 14.2), ('woman', 9.6), ('man', 6.0), ('make', 2.25), ('hand', 2.25), ('walk', 1.5), ('sing', 1.5), ('talk', 1.5), ('finger', 1.5), ('sign', 1.5), ('sunglass', 0.75), ('motion', 0.75), ('room', 0.75), ('laugh', 0.75), ('light', 0.75), ('young', 0.75), ('away', 0.75), ('smile', 0.75), ('microphon', 0.75), ('singer', 0.75), ('friend', 0.75), ('person', 0.75), ('turn', 0.75), ('point', 0.75), ('two', 0.75), ('pink', 0.75), ('give', 0.75), ('audienc', 0.75), ('glass', 0.75), ('mirror', 0.75), ('suit', 0.75), ('ladi', 0.75), ('use', 0.75), ('track', 0.75), ('somebodi', 0.75), ('spotlight', 0.75), ('glare', 0.0), ('someon', 0.0), ('appear', 0.0), ('cat', 0.0), ('tri', 0.0), ('catch', 0.0), ('mous', 0.0), ('tablet', 0.0), ('dress', 0.0), ('red', 0.0), ('danc', 0.0), ('anim', 0.0), ('come', 0.0), ('close', 0.0), ('anoth', 0.0), ('jungl', 0.0), ('hat', 0.0), ('adjust', 0.0), ('tie', 0.0), ('weird', 0.0), ('face', 0.0), ('put', 0.0), ('wrap', 0.0), ('paper', 0.0), ('bow', 0.0), ('brune

In [48]:
idf = get_idf(labels)
k = 15
sum_of_average_precisions = 0
num_of_queries = len(queries)


query = cleanup("salute")

old_results = search(query, k, idf, labels)
results = new_ranking_function(query, old_results, labels, k=15)




[('salut', 14.2), ('woman', 9.6), ('man', 6.0), ('make', 2.25), ('hand', 2.25), ('walk', 1.5), ('sing', 1.5), ('talk', 1.5), ('finger', 1.5), ('sign', 1.5), ('sunglass', 0.75), ('motion', 0.75), ('room', 0.75), ('laugh', 0.75), ('light', 0.75), ('young', 0.75), ('away', 0.75), ('smile', 0.75), ('microphon', 0.75), ('singer', 0.75), ('friend', 0.75), ('person', 0.75), ('turn', 0.75), ('point', 0.75), ('two', 0.75), ('pink', 0.75), ('give', 0.75), ('audienc', 0.75), ('glass', 0.75), ('mirror', 0.75), ('suit', 0.75), ('ladi', 0.75), ('use', 0.75), ('track', 0.75), ('somebodi', 0.75), ('spotlight', 0.75), ('glare', 0.0), ('someon', 0.0), ('appear', 0.0), ('cat', 0.0), ('tri', 0.0), ('catch', 0.0), ('mous', 0.0), ('tablet', 0.0), ('dress', 0.0), ('red', 0.0), ('danc', 0.0), ('anim', 0.0), ('come', 0.0), ('close', 0.0), ('anoth', 0.0), ('jungl', 0.0), ('hat', 0.0), ('adjust', 0.0), ('tie', 0.0), ('weird', 0.0), ('face', 0.0), ('put', 0.0), ('wrap', 0.0), ('paper', 0.0), ('bow', 0.0), ('brune