In [35]:
import os
import re
import warnings

import altair as alt
import cohere
import numpy as np
import pandas as pd
import umap
from annoy import AnnoyIndex
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
# Silence every warning for the rest of the session; note that this also hides
# numerical warnings (e.g. NumPy's invalid-value RuntimeWarning).
warnings.filterwarnings('ignore')
# Show full cell contents instead of truncating long text columns
pd.set_option('display.max_colwidth', None)

In [36]:
# Fetch the training split of the "trec" dataset
dataset = load_dataset(path="trec", split="train")

Found cached dataset trec (/Users/gunjan07shinde/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2)


In [37]:

# Inspect the dimensions of the loaded split
dataset.shape

(5452, 3)

In [38]:
# Load into a pandas DataFrame, keeping only the first 1500 rows
df = pd.DataFrame(dataset).head(1500)
# Preview the first ten rows to confirm the data loaded correctly
df.head(10)

Unnamed: 0,text,coarse_label,fine_label
0,How did serfdom develop in and then leave Russia ?,2,26
1,What films featured the character Popeye Doyle ?,1,5
2,How can I find a list of celebrities ' real names ?,2,26
3,What fowl grabs the spotlight after the Chinese Year of the Monkey ?,1,2
4,What is the full form of .com ?,0,1
5,What contemptible scoundrel stole the cork from my lunch ?,3,29
6,What team did baseball 's St. Louis Browns become ?,3,28
7,What is the oldest profession ?,3,30
8,What are liver enzymes ?,2,24
9,Name the scar-faced bounty hunter of The Old West .,3,29


In [39]:
# Read the Cohere API key from the environment so the secret never lives in the
# notebook or its history. Create a key at dashboard.cohere.ai/welcome/register
# and export it as COHERE_API_KEY before starting the kernel.
# Any key that was ever pasted directly into this cell should be rotated.
api_key = os.environ["COHERE_API_KEY"]

co = cohere.Client(api_key)

# Get the embeddings
embeds = co.embed(texts=list(df['text']),
                  model='large',
                  truncate='LEFT').embeddings

In [40]:
# Build an Annoy index sized to the embedding width, using angular distance
embedding_dim = np.array(embeds).shape[1]
search_index = AnnoyIndex(embedding_dim, 'angular')
# Register every embedding under its position in `embeds`
for item_id, vector in enumerate(embeds):
    search_index.add_item(item_id, vector)
search_index.build(10)  # 10 trees
# Persist the index to disk; the call's return value is the cell output
search_index.save('test.ann')

True

In [41]:
# Pick a reference question; the questions closest to it are looked up below
example_id = 92
# Ask the index for 10 neighbours; with include_distances=True the result is
# an (ids, distances) pair
similar_item_ids = search_index.get_nns_by_item(
    example_id, 10, include_distances=True
)
neighbor_ids, neighbor_distances = similar_item_ids
# Tabulate texts and distances, then remove the row labelled example_id
neighbors = pd.DataFrame(data={'texts': df['text'].iloc[neighbor_ids],
                               'distance': neighbor_distances})
results = neighbors.drop(example_id)
print(f"Question:'{df.iloc[example_id]['text']}'\nNearest neighbors:")
results


Question:'What are bear and bull markets ?'
Nearest neighbors:


Unnamed: 0,texts,distance
614,What animals do you find in the stock market ?,0.882624
137,What are equity securities ?,1.057622
307,What does NASDAQ stand for ?,1.077819
547,Where can stocks be traded on-line ?,1.090763
513,What do economists do ?,1.121729
363,What does it mean `` Rupee Depreciates '' ?,1.13084
922,What is the difference between a median and a mean ?,1.132415
601,What is `` the bear of beers '' ?,1.143287
932,Why did the world enter a global depression in 1929 ?,1.152498


In [55]:
query = "AI is the new electricity. "

def getResult(query, n_neighbors=10):
    """Embed `query` with Cohere and look up its nearest neighbours in the index.

    Args:
        query: Free-text string to search for.
        n_neighbors: Number of neighbours to request from `search_index`.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        A DataFrame with a 'texts' column (the matching entries of
        `df['text']`) and a 'distance' column (the distances reported by the
        index). The frame is only returned, not stored: assign the call's
        result if it is needed after the cell finishes.
    """
    query_embed = co.embed(texts=[query],
                           model="large",
                           truncate="LEFT").embeddings

    # With include_distances=True the index returns an (ids, distances) pair
    neighbor_ids, neighbor_distances = search_index.get_nns_by_vector(
        query_embed[0], n_neighbors, include_distances=True
    )
    # Named `neighbors` so it does not shadow the notebook-level `results`
    neighbors = pd.DataFrame(data={'texts': df.iloc[neighbor_ids]['text'],
                                   'distance': neighbor_distances})

    print(f"Query:'{query}'\nNearest neighbors:")
    return neighbors

getResult(query)

Query:'AI is the new electricity. '
Nearest neighbors:


Unnamed: 0,texts,distance
635,What is artificial intelligence ?,1.004586
1065,How much electricity does the brain need to work ?,1.116887
153,Who discovered electricity ?,1.135489
612,What is the `` coppertop '' battery ?,1.17817
131,What is a transistor ?,1.182157
1374,What did Mr. Magoo flog on TV for General Electric ?,1.204926
1018,Who invented batteries ?,1.20689
878,Why are electric cars less efficient in the northeast than in California ?,1.209495
1454,What does IBM stand for ?,1.213416
650,What is the most advanced handheld calculator in the world ?,1.214575


In [56]:
# Column labels of the notebook-level `results` frame
li = list(results.columns)
li

['texts', 'distance']

In [57]:
# Column labels of the question frame
li = list(df.columns)
li

['text', 'coarse_label', 'fine_label']

In [58]:
# Show the first five rows of the question frame
df.head(n=5)

Unnamed: 0,text,coarse_label,fine_label
0,How did serfdom develop in and then leave Russia ?,2,26
1,What films featured the character Popeye Doyle ?,1,5
2,How can I find a list of celebrities ' real names ?,2,26
3,What fowl grabs the spotlight after the Chinese Year of the Monkey ?,1,2
4,What is the full form of .com ?,0,1


In [59]:

# UMAP projects the embeddings down to 2 dimensions (its default n_components)
# so they can be plotted. random_state pins the otherwise stochastic layout so
# the chart is reproducible across runs.
reducer = umap.UMAP(n_neighbors=20, random_state=42)
umap_embeds = reducer.fit_transform(embeds)
# Prepare the data for an interactive Altair visualization
df_explore = pd.DataFrame(data={'text': df['text']})
df_explore['x'] = umap_embeds[:,0]
df_explore['y'] = umap_embeds[:,1]

# Plot one point per question; hovering a point shows its text.
# zero=False lets each axis fit the data instead of being forced to include 0.
chart = alt.Chart(df_explore).mark_circle(size=60).encode(
    x=alt.X('x', scale=alt.Scale(zero=False)),
    y=alt.Y('y', scale=alt.Scale(zero=False)),
    tooltip=['text']
).properties(
    width=700,
    height=400
)
chart.interactive()

In [60]:
# NOTE(review): `results` here is the notebook-level frame built in the
# example_id nearest-neighbour cell. getResult() returns its frame without
# assigning it, so this is not the distance column for the
# "AI is the new electricity. " query — confirm which one was intended.
distance = [results['distance']]
distance

[614    0.882624
 137    1.057622
 307    1.077819
 547    1.090763
 513    1.121729
 363    1.130840
 922    1.132415
 601    1.143287
 932    1.152498
 Name: distance, dtype: float64]

In [61]:
query = "Welcome to new world. "

# Reuse getResult() from the earlier cell rather than redefining an identical
# copy; keep the returned frame in a variable so it stays available after the
# cell finishes.
results2 = getResult(query)
results2

Query:'Welcome to new world. '
Nearest neighbors:


Unnamed: 0,texts,distance
79,"What country did the Nazis occupy for 1 , CD NNS IN NNP NNP NNP .",1.185202
1479,What is after death ?,1.198188
170,Where are the 49 steps ?,1.21674
552,What is a wop ?,1.218052
1072,What novel has Big Brother watching ?,1.219108
60,Where is the Loop ?,1.219587
140,Where is the Orinoco ?,1.219869
872,Where 's the Petrified Forest ?,1.221254
592,What was the only country you were allowed to drive into Israel from in 1979 ?,1.222088
964,What century does Captain Video live in ?,1.224412


In [62]:
# NOTE(review): this reads the same notebook-level `results` frame that
# `distance` was built from, so `dist2` and `distance` hold identical values —
# confirm whether the frame returned for the second query was intended here.
dist2 = [results['distance']]

In [63]:
# Display the collected distances
dist2

[614    0.882624
 137    1.057622
 307    1.077819
 547    1.090763
 513    1.121729
 363    1.130840
 922    1.132415
 601    1.143287
 932    1.152498
 Name: distance, dtype: float64]

In [66]:
def cal_score(dist2, distance):
    """Turn the summed gap between two distance collections into a score.

    Both arguments are converted with np.array and subtracted element-wise
    (dist2 - distance). The summed difference d is mapped to
    3.4 * ln(6.1 - 6.4 * d) + 5. The sum and the score are printed before
    the score is returned.

    The logarithm's argument is 6.1 - 6.4 * d, so the result is not a finite
    number once d reaches 6.1 / 6.4.
    """
    gap_total = (np.array(dist2) - np.array(distance)).sum()
    print(gap_total)
    log_arg = 6.1 - 6.4 * gap_total
    score = 3.4 * np.log(log_arg) + 5
    print(score)
    return score
# NOTE(review): `dist2` and `distance` are both built from the same `results`
# frame, so the summed difference printed below is 0.0 and the score is the
# constant 3.4*ln(6.1)+5 — confirm which two result sets were meant to be
# compared.
xasasaa = cal_score(dist2, distance)

0.0
11.148181822009501
