In [3]:
import nltk
from nltk.corpus import wordnet as wn
import re

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import pickle
from scipy import spatial
import os
# Show the working directory: the model and data paths used in the cells below are relative to it.
print(os.getcwd())

C:\Users\yiehy\OneDrive\Desktop\cs425-nlc-project\9.Query Expansion


In [1]:
!pip install sentence-transformers



# Best performing bi-encoder (retrieve and re-rank)

In [4]:
# Model identifiers kept in one place so they are easy to swap.
BI_ENCODER_PATH = "../8.Fine-tuned Models/finetuned-bertbase-1epoch"
CROSS_ENCODER_NAME = 'cross-encoder/ms-marco-MiniLM-L-12-v2'

# Bi-encoder produces the query embeddings; the cross-encoder re-scores the retrieved hits.
bi_encoder_model = SentenceTransformer(BI_ENCODER_PATH)
cross_encoder_model = CrossEncoder(CROSS_ENCODER_NAME)



# On test set

In [5]:
# Read the precomputed answer embeddings for the test set.
# NOTE: unpickling can execute arbitrary code, so only load files produced by this project.
# The `with` block closes the file handle as soon as loading finishes.
with open("../4.Retrieval/finetuned_bertbase/finetuned_bertbase_test_answer_embeddings.pkl", 'rb') as embeddings_file:
    test_answer_embeddings = pickle.load(embeddings_file)

In [6]:
# Test split; later cells read its "qid", "docid", "question" and "answer" columns.
TEST_CSV_PATH = "../0.Datasets/train_test_split/test.csv"
test_df = pd.read_csv(TEST_CSV_PATH)

In [7]:
# qid -> list of docids paired with that question, in row order.
question_answer_index_map = {}
for _, record in test_df.iterrows():
    # setdefault creates the list the first time a qid is seen, then we append to it.
    question_answer_index_map.setdefault(record["qid"], []).append(record["docid"])

In [8]:
# One list of docids per question, in the insertion order of question_answer_index_map.
labels = list(question_answer_index_map.values())

In [9]:
question_map = {}  # qid -> question text (first row seen for that qid wins)
label_map = {}     # answer text -> docid (first row seen for that answer wins)
for _, record in test_df.iterrows():
    # setdefault only stores the value when the key is not present yet.
    question_map.setdefault(record["qid"], record["question"])
    label_map.setdefault(record["answer"], record["docid"])

In [36]:
# Basic query normalisation: lowercase, strip punctuation, collapse redundant whitespace.
def basic_cleaning(query):
    """Return a normalised copy of ``query``.

    The input is converted with ``str()``, lowercased, stripped of every
    character that is neither a word character nor whitespace, and runs of
    whitespace are collapsed to single spaces (leading/trailing removed).
    """
    lowered = str(query).lower()
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    tokens = without_punctuation.split()
    return ' '.join(tokens)

# Extract nouns
def nouns_only(query):
    """Return the distinct noun tokens of ``query`` in order of first appearance.

    ``query`` is split on whitespace and tagged with ``nltk.tag.pos_tag``;
    tokens tagged NN, NNS, NNP or NNPS are kept. If tagging fails for any
    reason (for example ``query`` is not a string) an empty list is returned.
    """
    noun_tags = {"NN", "NNS", "NNP", "NNPS"}
    try:
        tagged_text = nltk.tag.pos_tag(query.split())
        # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...))
        # would return the nouns in an order that can change from run to run.
        return list(dict.fromkeys(word for word, tag in tagged_text if tag in noun_tags))
    except Exception:
        # Best-effort: still tolerant of tagging errors, but unlike a bare
        # ``except`` this does not swallow KeyboardInterrupt / SystemExit.
        return []
    
def query_noun_mapping(query_nouns):
    """Map each noun to a synonym taken from its first WordNet synset.

    For every noun the lemmas of ``wn.synsets(noun)[0]`` are inspected from the
    second lemma onwards and the first one that differs from the noun itself
    (case-insensitively) is used. Underscores in multi-word lemma names are
    replaced by spaces. Nouns without a synset, or without a usable lemma, are
    left out of the result.

    Returns a dict ``{noun: synonym}``.
    """
    synonym_dict = {}
    for query_noun in query_nouns:
        try:
            lemmas = wn.synsets(query_noun)[0].lemmas()
            # The head lemma is skipped, as before; only later lemmas are candidates.
            candidates = [lemma.name() for lemma in lemmas[1:]]
            noun_lower = query_noun.lower()
        except Exception:
            # No synset / unusable input: skip this noun (best-effort), without
            # swallowing KeyboardInterrupt the way a bare ``except`` would.
            continue
        for candidate in candidates:
            # Guard against "expanding" a noun with itself (e.g. "auto" -> "auto").
            if candidate.lower() != noun_lower:
                synonym_dict[query_noun] = candidate.replace("_", " ")
                break
    return synonym_dict

def query_expansion(query):
    """Append WordNet synonyms of the query's nouns to the query.

    The query is cleaned with ``basic_cleaning``, its nouns are extracted with
    ``nouns_only`` and mapped to synonyms with ``query_noun_mapping``. For each
    noun that occurs in the lower-cased original query, `` and <synonym>`` is
    appended to the end of the query. Nouns that cannot be found in the
    original text (e.g. because cleaning removed punctuation inside them) are
    skipped individually instead of aborting the whole expansion.

    Returns the expanded query, or ``query`` unchanged when there is nothing
    to add or when any step fails.
    """
    try:
        clean_query = basic_cleaning(query)
        query_nouns_list = nouns_only(clean_query)
        if not query_nouns_list:
            return query
        synonym_dict = query_noun_mapping(query_nouns_list)
        # Match against the original text only, not against synonyms appended below.
        lowered_query = query.lower()
        expanded_query = query
        for noun, synonym in synonym_dict.items():
            if noun not in lowered_query:
                continue
            expanded_query += f" and {synonym}"
        return expanded_query
    except Exception:
        # Best-effort: fall back to the untouched query, but let
        # KeyboardInterrupt / SystemExit propagate.
        return query



In [37]:
# Number of bi-encoder candidates handed to the cross-encoder for re-ranking.
RETRIEVE_TOP_K = 20

test_answer_list = test_df["answer"].tolist()
# Stack the embeddings once so every question needs a single matrix-vector product.
# Assumes each stored embedding is a 1-D vector with the same dimension as the
# query embedding, which the original per-embedding np.dot also required.
answer_matrix = np.stack([np.asarray(embed) for embed in test_answer_embeddings])

predictions = []
for count, question in enumerate(question_map.values(), start=1):
    if count % 100 == 0:
        print(count)
    # Expand exactly once. Previously the query was expanded, reassigned, and then
    # expanded a second time before encoding. query_expansion already falls back
    # to the raw query on failure, so no extra try/except is needed here.
    expanded_question = query_expansion(question)

    # Stage 1: bi-encoder retrieval by dot-product similarity.
    question_embedding = bi_encoder_model.encode(expanded_question)
    similarity_scores = answer_matrix @ question_embedding
    # A stable sort on the negated scores gives descending order, ties in index order.
    top_indices = np.argsort(-similarity_scores, kind="stable")[:RETRIEVE_TOP_K]
    # Identical answer texts collapse to one candidate, keeping the best-ranked position.
    top_hits = list(dict.fromkeys(test_answer_list[i] for i in top_indices))

    # Stage 2: cross-encoder re-ranking, scored in one batched call.
    # As before, the cross-encoder sees the expanded query.
    cross_scores = cross_encoder_model.predict([[expanded_question, hit] for hit in top_hits])
    best_hit = top_hits[int(np.argmax(cross_scores))]
    predictions.append([label_map[best_hit]])

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500


In [38]:
# Save the gold labels and the predictions side by side in one pickle file.
results = dict(labels=labels, predictions=predictions)
with open("../7.Evaluate/query_expansion.pkl", 'wb') as results_file:
    pickle.dump(results, results_file, protocol=pickle.HIGHEST_PROTOCOL)

In [75]:
# Basic data cleaning on a query: lowercase, remove punctuation, collapse redundant whitespace.
# NOTE: this cell defines basic_cleaning again; whichever defining cell ran last is the one in effect.
def basic_cleaning(query):
    """Lowercase ``str(query)``, drop non-word/non-space characters, squeeze whitespace."""
    stripped = re.sub(r'[^\w\s]', '', str(query).lower())
    return ' '.join(stripped.split())

# Extract nouns
# NOTE: this cell defines nouns_only again; whichever defining cell ran last is the one in effect.
def nouns_only(query):
    """Return the distinct noun tokens of ``query`` in order of first appearance.

    ``query`` is split on whitespace and tagged with ``nltk.tag.pos_tag``;
    tokens tagged NN, NNS, NNP or NNPS are kept. If tagging fails for any
    reason (for example ``query`` is not a string) an empty list is returned.
    """
    noun_tags = {"NN", "NNS", "NNP", "NNPS"}
    try:
        tagged_text = nltk.tag.pos_tag(query.split())
        # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...))
        # would return the nouns in an order that can change from run to run.
        return list(dict.fromkeys(word for word, tag in tagged_text if tag in noun_tags))
    except Exception:
        # Best-effort: still tolerant of tagging errors, but unlike a bare
        # ``except`` this does not swallow KeyboardInterrupt / SystemExit.
        return []
    
# NOTE: this cell defines query_noun_mapping again; whichever defining cell ran last is the one in effect.
def query_noun_mapping(query_nouns):
    """Map each noun to a synonym taken from its first WordNet synset.

    For every noun the lemmas of ``wn.synsets(noun)[0]`` are inspected from the
    second lemma onwards and the first one that differs from the noun itself
    (case-insensitively) is used. Underscores in multi-word lemma names are
    replaced by spaces. Nouns without a synset, or without a usable lemma, are
    left out of the result.

    Returns a dict ``{noun: synonym}``.
    """
    synonym_dict = {}
    for query_noun in query_nouns:
        try:
            lemmas = wn.synsets(query_noun)[0].lemmas()
            # The head lemma is skipped, as before; only later lemmas are candidates.
            candidates = [lemma.name() for lemma in lemmas[1:]]
            noun_lower = query_noun.lower()
        except Exception:
            # No synset / unusable input: skip this noun (best-effort), without
            # swallowing KeyboardInterrupt the way a bare ``except`` would.
            continue
        for candidate in candidates:
            # Guard against "expanding" a noun with itself (e.g. "auto" -> "auto").
            if candidate.lower() != noun_lower:
                synonym_dict[query_noun] = candidate.replace("_", " ")
                break
    return synonym_dict

# NOTE: this cell defines query_expansion again; whichever defining cell ran last is the one in effect.
def query_expansion(query):
    """Insert a WordNet synonym next to each noun of the query.

    The query is cleaned with ``basic_cleaning``, its nouns are extracted with
    ``nouns_only`` and mapped to synonyms with ``query_noun_mapping``. Every
    whole-word, case-insensitive occurrence of a mapped noun in the original
    query is rewritten as ``<noun> and <synonym>``, keeping the original
    casing of the matched noun.

    All nouns are substituted in a single regex pass, so a synonym inserted
    for one noun is never rewritten again by another noun's substitution, and
    nouns are not matched inside longer words (e.g. "car" inside "scary").

    Returns the expanded query, or ``query`` unchanged when there is nothing
    to expand or when any step fails.
    """
    try:
        clean_query = basic_cleaning(query)
        query_nouns_list = nouns_only(clean_query)
        if not query_nouns_list:
            return query
        synonym_dict = query_noun_mapping(query_nouns_list)
        if not synonym_dict:
            return query

        lookup = {noun.lower(): synonym for noun, synonym in synonym_dict.items()}
        # Longest alternatives first so the regex prefers the longer noun on overlap.
        alternatives = "|".join(
            re.escape(noun) for noun in sorted(lookup, key=len, reverse=True)
        )
        pattern = re.compile(rf"\b(?:{alternatives})\b", flags=re.IGNORECASE)

        def _with_synonym(match):
            matched = match.group(0)
            synonym = lookup.get(matched.lower())
            # Leave the text untouched if case folding produced an unknown key.
            return matched if synonym is None else f"{matched} and {synonym}"

        return pattern.sub(_with_synonym, query)
    except Exception:
        # Best-effort: fall back to the untouched query, but let
        # KeyboardInterrupt / SystemExit propagate.
        return query

