In [1]:
import numpy as np
import pandas
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the ground-truth questions used to score retrieval quality.
df_question = pandas.read_csv('../data/ground-truth-retrieval.csv')

In [3]:
# Preview the first rows to sanity-check the loaded columns.
df_question.head()


Unnamed: 0,id,question
0,39264,Who created the humanoid robot named Jet Jagua...
1,39264,What undersea race of people seized Jet Jaguar...
2,39264,Why did the Seatopians send Megalon to the sur...
3,39264,Which companies produced the movie Godzilla vs...
4,39264,What are some keywords associated with the mov...


In [4]:
# Turn each row into a plain dict (one record per question) for the evaluation loop.
ground_truth = df_question.to_dict(orient='records')

In [5]:
# Inspect one record to confirm which keys the evaluation code can rely on.
ground_truth[0]

{'id': 39264,
 'question': 'Who created the humanoid robot named Jet Jaguar in the movie Godzilla vs. Megalon?'}

In [6]:
def hit_rate(relevance_total):
    """Fraction of queries whose result list contains the expected document.

    Args:
        relevance_total: one sequence of flags per query; a flag equal to
            ``True`` marks a returned document that matches the expected one.

    Returns:
        Number of queries with at least one ``True`` flag divided by the
        number of queries. Returns 0.0 for an empty input instead of raising
        ``ZeroDivisionError``.
    """
    total = len(relevance_total)
    if total == 0:
        # Nothing was evaluated, so there are no hits to report.
        return 0.0

    hits = sum(1 for line in relevance_total if True in line)
    return hits / total

def mrr(relevance_total):
    """Mean reciprocal rank of the first relevant result across all queries.

    Args:
        relevance_total: one sequence of flags per query, ordered by rank; a
            flag equal to ``True`` marks a relevant result.

    Returns:
        Average over queries of ``1 / rank`` for the first relevant result
        (rank starts at 1). A query with no relevant result contributes 0.
        Returns 0.0 for an empty input instead of raising
        ``ZeroDivisionError``.
    """
    total = len(relevance_total)
    if total == 0:
        return 0.0

    total_score = 0.0
    for line in relevance_total:
        for rank, is_relevant in enumerate(line, start=1):
            # Equality test (not truthiness) mirrors hit_rate's `True in line`.
            if is_relevant == True:
                total_score += 1 / rank
                # Only the first relevant hit counts; without this break a
                # list with several matches would add more than one term and
                # could push the score above 1.
                break

    return total_score / total

In [7]:
from elasticsearch import Elasticsearch
# Client for the Elasticsearch instance at localhost:9200.
es_client = Elasticsearch('http://localhost:9200')

# Index that the search helpers in this notebook query.
index_name = 'movies'

def elastic_search(query):
    """Run the baseline fixed-boost search and return the matching documents.

    Sends a ``multi_match`` query (``best_fields``, fuzziness ``AUTO``) to the
    index named by ``index_name`` through ``es_client`` and asks for at most
    10 hits.

    Args:
        query: text to search for.

    Returns:
        List with the ``_source`` of each hit, in the order returned.
    """
    multi_match = {
        "query": query,
        "fields": ["title^3", "description^2", "overview^1.5", "genres", "keywords"],
        "type": "best_fields",
        "fuzziness": "AUTO"
    }
    request_body = {"size": 10, "query": {"multi_match": multi_match}}

    hits = es_client.search(index=index_name, body=request_body)['hits']['hits']
    return [hit['_source'] for hit in hits]

In [8]:
def evaluate(ground_truth, search_function):
    """Score a search function against the ground-truth records.

    Args:
        ground_truth: iterable of records; each must provide an ``'id'`` key.
        search_function: callable that receives a whole record and returns a
            list of documents, each providing an ``'id'`` key.

    Returns:
        Dict with ``'hit_rate'`` and ``'mrr'`` computed over all records.
    """
    def relevance_flags(record):
        # One flag per returned document: does it match the expected id?
        expected_id = record['id']
        return [doc['id'] == expected_id for doc in search_function(record)]

    relevance_total = [relevance_flags(record) for record in tqdm(ground_truth)]

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

In [9]:
# Baseline: score the fixed-boost query over the whole ground-truth set.
evaluate(ground_truth, lambda q: elastic_search(q['question']))

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [05:41<00:00,  2.93it/s]


{'hit_rate': 0.687, 'mrr': 0.6390579365079364}

## Finding the best parameters

In [10]:
from hyperopt import fmin, tpe, hp, Trials
from hyperopt.pyll.base import scope

In [11]:
# Every tunable boost is sampled uniformly from the same [1.0, 3.0] range;
# the hyperopt label of each entry is identical to its dict key.
search_space = {
    label: hp.uniform(label, 1.0, 3.0)
    for label in (
        'title_boost',
        'genres_boost',
        'overview_boost',
        'production_companies_boost',
        'tagline_boost',
        'credits_boost',
        'keywords_boost',
    )
}

In [12]:
def elastic_search_search(query, params):
    """Run the multi-field search with caller-supplied per-field boosts.

    Args:
        query: text to search for.
        params: mapping that must contain ``title_boost``, ``genres_boost``,
            ``overview_boost``, ``production_companies_boost``,
            ``tagline_boost``, ``credits_boost`` and ``keywords_boost``.

    Returns:
        List with the ``_source`` of each hit (at most 10 are requested), in
        the order returned.

    Raises:
        KeyError: if any of the boost keys is missing from ``params``.
    """
    # Read every boost up front so a missing key fails before a request is sent.
    boosts = [
        ("title", params['title_boost']),
        ("genres", params['genres_boost']),
        ("overview", params['overview_boost']),
        ("production_companies", params['production_companies_boost']),
        ("tagline", params['tagline_boost']),
        ("credits", params['credits_boost']),
        ("keywords", params['keywords_boost']),
    ]

    # Elasticsearch expresses a per-field boost as "<field>^<weight>".
    weighted_fields = [f"{field}^{boost}" for field, boost in boosts]

    request_body = {
        "size": 10,
        "query": {
            "multi_match": {
                "query": query,
                "fields": weighted_fields,
                "type": "best_fields",
                "fuzziness": "AUTO"
            }
        }
    }

    hits = es_client.search(index=index_name, body=request_body)['hits']['hits']
    return [hit['_source'] for hit in hits]

In [13]:
def objective(params):
    """Loss for hyperopt: negated MRR of the boosted search.

    Args:
        params: boost values forwarded unchanged to ``elastic_search_search``.

    Returns:
        ``-mrr`` measured on the first 200 ground-truth records, so that
        minimising this value maximises MRR.
    """
    tuning_subset = ground_truth[:200]
    scores = evaluate(
        tuning_subset,
        lambda q: elastic_search_search(q['question'], params),
    )
    return -scores['mrr']


In [16]:
# Fixed seed so TPE proposes the same candidate boosts on every run; without
# it each execution of this cell can report a different "best" set.
SEED = 42

# Initialize trials object to store the result history
trials = Trials()

# Run optimization (fmin minimises, which is why objective returns -MRR)
best_params = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=10,  # The number of iterations you want to run
    trials=trials,
    rstate=np.random.default_rng(SEED),
)

print("Best parameters:", best_params)

  0%|          | 0/10 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:01<03:47,  1.14s/it]
  1%|1         | 2/200 [00:01<03:06,  1.06it/s]
  2%|1         | 3/200 [00:02<02:55,  1.12it/s]
  2%|2         | 4/200 [00:03<02:23,  1.37it/s]
  2%|2         | 5/200 [00:03<02:21,  1.38it/s]
  3%|3         | 6/200 [00:05<03:06,  1.04it/s]
  4%|3         | 7/200 [00:05<02:32,  1.27it/s]
  4%|4         | 8/200 [00:06<02:30,  1.27it/s]
  4%|4         | 9/200 [00:07<02:15,  1.41it/s]
  5%|5         | 10/200 [00:08<02:24,  1.32it/s]
  6%|5         | 11/200 [00:09<02:42,  1.16it/s]
  6%|6         | 12/200 [00:09<02:16,  1.37it/s]
  6%|6         | 13/200 [00:10<02:15,  1.38it/s]
  7%|7         | 14/200 [00:10<02:07,  1.46it/s]
  8%|7         | 15/200 [00:11<01:54,  1.62it/s]
  8%|8         | 16/200 [00:11<01:56,  1.58it/s]
  8%|8         | 17/200 [00:12<01:36,  1.89it/s]
  9%|9         | 18/200 [00:12<01:27,  2.07it/s]
 10%|9         | 19/200 [00:12<01:11,  2.53it/s]
 10%|#         | 20/200 [00:13<01:13, 

 10%|█         | 1/10 [01:27<13:11, 87.97s/trial, best loss: -0.2545]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:26,  2.30it/s]
  1%|1         | 2/200 [00:00<01:13,  2.71it/s]
  2%|1         | 3/200 [00:01<01:14,  2.65it/s]
  2%|2         | 4/200 [00:01<01:01,  3.19it/s]
  2%|2         | 5/200 [00:01<00:59,  3.26it/s]
  3%|3         | 6/200 [00:02<01:22,  2.35it/s]
  4%|3         | 7/200 [00:02<01:11,  2.72it/s]
  4%|4         | 8/200 [00:02<01:07,  2.83it/s]
  4%|4         | 9/200 [00:03<01:05,  2.92it/s]
  5%|5         | 10/200 [00:03<01:05,  2.89it/s]
  6%|5         | 11/200 [00:04<01:16,  2.46it/s]
  6%|6         | 12/200 [00:04<01:03,  2.98it/s]
  6%|6         | 13/200 [00:04<01:00,  3.08it/s]
  7%|7         | 14/200 [00:04<00:59,  3.13it/s]
  8%|7         | 15/200 [00:05<00:56,  3.26it/s]
  8%|8         | 16/200 [00:05<01:08,  2.68it/s]
  8%|8         | 17/200 [00:05<01:03,  2.90it/s]
  9%|9         | 18/200 [00:06<01:03,  2.87it/s]
 10%|9         | 19/200 [00:06<00:55,  3.27it/s]
 10%|#         | 20/200 [00:07<01:06, 

 20%|██        | 2/10 [03:02<12:16, 92.05s/trial, best loss: -0.37589087301587326]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:33,  2.14it/s]
  1%|1         | 2/200 [00:00<01:17,  2.57it/s]
  2%|1         | 3/200 [00:01<01:22,  2.40it/s]
  2%|2         | 4/200 [00:01<01:05,  3.00it/s]
  2%|2         | 5/200 [00:01<01:02,  3.10it/s]
  3%|3         | 6/200 [00:02<01:31,  2.11it/s]
  4%|3         | 7/200 [00:02<01:21,  2.38it/s]
  4%|4         | 8/200 [00:03<01:16,  2.52it/s]
  4%|4         | 9/200 [00:03<01:15,  2.54it/s]
  5%|5         | 10/200 [00:03<01:14,  2.56it/s]
  6%|5         | 11/200 [00:04<01:28,  2.14it/s]
  6%|6         | 12/200 [00:04<01:10,  2.65it/s]
  6%|6         | 13/200 [00:05<01:06,  2.82it/s]
  7%|7         | 14/200 [00:05<01:06,  2.81it/s]
  8%|7         | 15/200 [00:05<01:07,  2.74it/s]
  8%|8         | 16/200 [00:06<01:18,  2.36it/s]
  8%|8         | 17/200 [00:06<01:14,  2.46it/s]
  9%|9         | 18/200 [00:07<01:14,  2.43it/s]
 10%|9         | 19/200 [00:07<01:03,  2.85it/s]
 10%|#         | 20/200 [00:07<01:14, 

 30%|███       | 3/10 [04:42<11:08, 95.50s/trial, best loss: -0.37589087301587326]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:26,  2.31it/s]
  1%|1         | 2/200 [00:00<01:11,  2.79it/s]
  2%|1         | 3/200 [00:01<01:10,  2.80it/s]
  2%|2         | 4/200 [00:01<00:58,  3.34it/s]
  2%|2         | 5/200 [00:01<00:57,  3.41it/s]
  3%|3         | 6/200 [00:02<01:24,  2.29it/s]
  4%|3         | 7/200 [00:02<01:14,  2.60it/s]
  4%|4         | 8/200 [00:02<01:09,  2.77it/s]
  4%|4         | 9/200 [00:03<01:08,  2.80it/s]
  5%|5         | 10/200 [00:03<01:08,  2.79it/s]
  6%|5         | 11/200 [00:04<01:22,  2.30it/s]
  6%|6         | 12/200 [00:04<01:06,  2.85it/s]
  6%|6         | 13/200 [00:04<01:02,  3.00it/s]
  7%|7         | 14/200 [00:05<01:02,  2.96it/s]
  8%|7         | 15/200 [00:05<01:01,  3.01it/s]
  8%|8         | 16/200 [00:05<01:12,  2.52it/s]
  8%|8         | 17/200 [00:06<01:10,  2.59it/s]
  9%|9         | 18/200 [00:06<01:10,  2.60it/s]
 10%|9         | 19/200 [00:06<01:00,  3.02it/s]
 10%|#         | 20/200 [00:07<01:09, 

 40%|████      | 4/10 [06:18<09:33, 95.66s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:11,  2.77it/s]
  1%|1         | 2/200 [00:00<01:03,  3.11it/s]
  2%|1         | 3/200 [00:00<01:04,  3.03it/s]
  2%|2         | 4/200 [00:01<00:53,  3.64it/s]
  2%|2         | 5/200 [00:01<00:52,  3.70it/s]
  3%|3         | 6/200 [00:01<01:09,  2.79it/s]
  4%|3         | 7/200 [00:02<00:58,  3.28it/s]
  4%|4         | 8/200 [00:02<00:57,  3.31it/s]
  4%|4         | 9/200 [00:02<00:54,  3.54it/s]
  5%|5         | 10/200 [00:03<00:57,  3.31it/s]
  6%|5         | 11/200 [00:03<01:05,  2.89it/s]
  6%|6         | 12/200 [00:03<00:54,  3.47it/s]
  6%|6         | 13/200 [00:03<00:53,  3.50it/s]
  7%|7         | 14/200 [00:04<00:50,  3.66it/s]
  8%|7         | 15/200 [00:04<00:47,  3.88it/s]
  8%|8         | 16/200 [00:04<00:56,  3.27it/s]
  8%|8         | 17/200 [00:05<00:51,  3.56it/s]
  9%|9         | 18/200 [00:05<00:51,  3.50it/s]
 10%|9         | 19/200 [00:05<00:46,  3.88it/s]
 10%|#         | 20/200 [00:05<00:53, 

 50%|█████     | 5/10 [07:40<07:34, 90.84s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:15,  2.64it/s]
  1%|1         | 2/200 [00:00<01:06,  2.97it/s]
  2%|1         | 3/200 [00:01<01:11,  2.76it/s]
  2%|2         | 4/200 [00:01<00:58,  3.34it/s]
  2%|2         | 5/200 [00:01<00:57,  3.41it/s]
  3%|3         | 6/200 [00:02<01:16,  2.54it/s]
  4%|3         | 7/200 [00:02<01:04,  2.98it/s]
  4%|4         | 8/200 [00:02<01:02,  3.07it/s]
  4%|4         | 9/200 [00:02<00:59,  3.23it/s]
  5%|5         | 10/200 [00:03<01:02,  3.06it/s]
  6%|5         | 11/200 [00:03<01:11,  2.64it/s]
  6%|6         | 12/200 [00:03<00:58,  3.19it/s]
  6%|6         | 13/200 [00:04<00:57,  3.23it/s]
  7%|7         | 14/200 [00:04<00:56,  3.32it/s]
  8%|7         | 15/200 [00:04<00:52,  3.52it/s]
  8%|8         | 16/200 [00:05<01:03,  2.88it/s]
  8%|8         | 17/200 [00:05<00:58,  3.13it/s]
  9%|9         | 18/200 [00:05<00:58,  3.12it/s]
 10%|9         | 19/200 [00:06<00:52,  3.42it/s]
 10%|#         | 20/200 [00:06<01:06, 

 60%|██████    | 6/10 [09:10<06:01, 90.34s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:13,  2.72it/s]
  1%|1         | 2/200 [00:00<01:04,  3.08it/s]
  2%|1         | 3/200 [00:00<01:04,  3.05it/s]
  2%|2         | 4/200 [00:01<00:54,  3.60it/s]
  2%|2         | 5/200 [00:01<00:53,  3.64it/s]
  3%|3         | 6/200 [00:02<01:12,  2.69it/s]
  4%|3         | 7/200 [00:02<01:00,  3.17it/s]
  4%|4         | 8/200 [00:02<00:58,  3.25it/s]
  4%|4         | 9/200 [00:02<00:55,  3.46it/s]
  5%|5         | 10/200 [00:03<00:58,  3.27it/s]
  6%|5         | 11/200 [00:03<01:06,  2.86it/s]
  6%|6         | 12/200 [00:03<00:54,  3.45it/s]
  6%|6         | 13/200 [00:03<00:54,  3.46it/s]
  7%|7         | 14/200 [00:04<00:52,  3.56it/s]
  8%|7         | 15/200 [00:04<00:48,  3.84it/s]
  8%|8         | 16/200 [00:04<00:56,  3.27it/s]
  8%|8         | 17/200 [00:05<00:50,  3.60it/s]
  9%|9         | 18/200 [00:05<00:51,  3.53it/s]
 10%|9         | 19/200 [00:05<00:46,  3.91it/s]
 10%|#         | 20/200 [00:05<00:54, 

 70%|███████   | 7/10 [10:31<04:22, 87.56s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:24,  2.35it/s]
  1%|1         | 2/200 [00:00<01:09,  2.83it/s]
  2%|1         | 3/200 [00:01<01:17,  2.56it/s]
  2%|2         | 4/200 [00:01<01:01,  3.18it/s]
  2%|2         | 5/200 [00:01<00:59,  3.30it/s]
  3%|3         | 6/200 [00:02<01:22,  2.34it/s]
  4%|3         | 7/200 [00:02<01:13,  2.62it/s]
  4%|4         | 8/200 [00:02<01:08,  2.78it/s]
  4%|4         | 9/200 [00:03<01:08,  2.80it/s]
  5%|5         | 10/200 [00:03<01:07,  2.82it/s]
  6%|5         | 11/200 [00:04<01:20,  2.35it/s]
  6%|6         | 12/200 [00:04<01:04,  2.90it/s]
  6%|6         | 13/200 [00:04<01:01,  3.05it/s]
  7%|7         | 14/200 [00:04<01:02,  2.99it/s]
  8%|7         | 15/200 [00:05<01:01,  3.02it/s]
  8%|8         | 16/200 [00:05<01:12,  2.54it/s]
  8%|8         | 17/200 [00:06<01:08,  2.68it/s]
  9%|9         | 18/200 [00:06<01:09,  2.63it/s]
 10%|9         | 19/200 [00:06<00:58,  3.07it/s]
 10%|#         | 20/200 [00:07<01:10, 

 80%|████████  | 8/10 [12:08<03:01, 90.51s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:25,  2.32it/s]
  1%|1         | 2/200 [00:00<01:10,  2.79it/s]
  2%|1         | 3/200 [00:01<01:15,  2.62it/s]
  2%|2         | 4/200 [00:01<01:00,  3.24it/s]
  2%|2         | 5/200 [00:01<00:58,  3.32it/s]
  3%|3         | 6/200 [00:02<01:24,  2.29it/s]
  4%|3         | 7/200 [00:02<01:14,  2.59it/s]
  4%|4         | 8/200 [00:02<01:09,  2.75it/s]
  4%|4         | 9/200 [00:03<01:09,  2.77it/s]
  5%|5         | 10/200 [00:03<01:08,  2.77it/s]
  6%|5         | 11/200 [00:04<01:21,  2.32it/s]
  6%|6         | 12/200 [00:04<01:05,  2.85it/s]
  6%|6         | 13/200 [00:04<01:03,  2.92it/s]
  7%|7         | 14/200 [00:05<01:03,  2.93it/s]
  8%|7         | 15/200 [00:05<01:02,  2.98it/s]
  8%|8         | 16/200 [00:05<01:13,  2.51it/s]
  8%|8         | 17/200 [00:06<01:07,  2.69it/s]
  9%|9         | 18/200 [00:06<01:08,  2.64it/s]
 10%|9         | 19/200 [00:06<00:59,  3.03it/s]
 10%|#         | 20/200 [00:07<01:10, 

 90%|█████████ | 9/10 [13:44<01:32, 92.31s/trial, best loss: -0.44173412698412706]

  0%|          | 0/200 [00:00<?, ?it/s]
  0%|          | 1/200 [00:00<01:14,  2.68it/s]
  1%|1         | 2/200 [00:00<01:04,  3.05it/s]
  2%|1         | 3/200 [00:00<01:04,  3.05it/s]
  2%|2         | 4/200 [00:01<00:54,  3.60it/s]
  2%|2         | 5/200 [00:01<00:53,  3.66it/s]
  3%|3         | 6/200 [00:02<01:14,  2.61it/s]
  4%|3         | 7/200 [00:02<01:02,  3.10it/s]
  4%|4         | 8/200 [00:02<01:00,  3.18it/s]
  4%|4         | 9/200 [00:02<00:57,  3.34it/s]
  5%|5         | 10/200 [00:03<00:58,  3.23it/s]
  6%|5         | 11/200 [00:03<01:07,  2.81it/s]
  6%|6         | 12/200 [00:03<00:54,  3.42it/s]
  6%|6         | 13/200 [00:04<00:53,  3.51it/s]
  7%|7         | 14/200 [00:04<00:52,  3.53it/s]
  8%|7         | 15/200 [00:04<00:49,  3.73it/s]
  8%|8         | 16/200 [00:04<00:59,  3.11it/s]
  8%|8         | 17/200 [00:05<00:54,  3.37it/s]
  9%|9         | 18/200 [00:05<00:54,  3.37it/s]
 10%|9         | 19/200 [00:05<00:47,  3.79it/s]
 10%|#         | 20/200 [00:06<00:55, 

100%|██████████| 10/10 [15:07<00:00, 90.79s/trial, best loss: -0.44173412698412706]
Best parameters: {'credits_boost': np.float64(1.3731919195889106), 'genres_boost': np.float64(1.3064193340115824), 'keywords_boost': np.float64(2.203553448501116), 'overview_boost': np.float64(1.588702891542049), 'production_companies_boost': np.float64(1.7860593122737678), 'tagline_boost': np.float64(2.5652702870580586), 'title_boost': np.float64(2.2135110497579085)}


In [15]:
# Show the tuned boosts returned by fmin.
# NOTE(review): the recorded output below differs from the values printed by
# the optimisation cell, and the execution counts are out of order — re-run
# the notebook on a fresh kernel so both come from the same search.
best_params

{'credits_boost': np.float64(1.2455632927531133),
 'genres_boost': np.float64(1.8690964556691751),
 'keywords_boost': np.float64(2.495763043344679),
 'overview_boost': np.float64(2.5734245250947705),
 'production_companies_boost': np.float64(2.364879048983533),
 'tagline_boost': np.float64(2.4531032723552704),
 'title_boost': np.float64(2.6316450728726513)}

In [43]:
# Final check: score the tuned boosts over the ground-truth set.
# NOTE(review): the recorded output below covers 225 queries while the baseline
# run covered 1000, so the two scores are not directly comparable. This set
# also contains the 200 records used for tuning. Re-run on a fresh kernel.
evaluate(ground_truth,lambda q: elastic_search_search(q['question'], best_params))

100%|██████████| 225/225 [01:38<00:00,  2.28it/s]


{'hit_rate': 0.6088888888888889, 'mrr': 0.5090828924162256}