In [8]:
from transformers import BertTokenizer, RobertaTokenizer
import torch
from gensim.models import Word2Vec
import os
import uuid
import json
import nltk

import sys
sys.path.append(os.path.join('..', 'src'))
from utils import get_sha256, clean_text, remove_non_word_chars, clean_text, tokens_to_embeddings

#### Prompt user for query
- Specify the tokenizer, consistant with Q&A model and parsed pdfs (TOKENIZER)

In [9]:
# Set the Tokenizer for your specific BERT model variant
# TOKENIZER = 'bert'
TOKENIZER = 'roberta'

bert_base_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
roberta_tokenizer = RobertaTokenizer.from_pretrained("deepset/roberta-base-squad2")

tokenizers = {'bert': bert_base_tokenizer, 'roberta': roberta_tokenizer}

tokenizer = tokenizers[TOKENIZER]

In [10]:
# Prompt the user for an input query
user_query = input("Enter your query: ")
user_query = user_query.lower()

# clean query for BERT input
user_query = clean_text(user_query)

# clean query for candidate search
user_query_for_search = remove_non_word_chars(user_query)

# Tokenize the query for BERT input
tokenized_query = tokenizer.tokenize(user_query)

# Tokenize the query for candidate search
tokenized_query_for_search = tokenizer.tokenize(user_query_for_search)

# Remove the stop words for the tokenized query for search
nltk_stop_words = nltk.corpus.stopwords.words('english')
nltk_stop_words.extend(["Ġ" + word for word in nltk_stop_words])  # Add the roberta modified tokens
tokenized_query_for_search_less_sw = [token for token in tokenized_query_for_search if token not in nltk_stop_words]

# Pad or truncate the query to a fixed length of 20 tokens (BERT input)
max_query_length = 20
if len(tokenized_query) > max_query_length:
    tokenized_query = tokenized_query[:max_query_length]
else:
    padding_length = max_query_length - len(tokenized_query)
    tokenized_query = tokenized_query + [tokenizer.pad_token] * padding_length

# Convert the tokenized query to input IDs and attention mask
input_ids_query = tokenizer.convert_tokens_to_ids(tokenized_query)
attention_mask_query = [1] * len(input_ids_query)

# Convert to tensors
input_ids_query = torch.tensor(input_ids_query).unsqueeze(0)  # Add batch dimension
attention_mask_query = torch.tensor(attention_mask_query).unsqueeze(0)  # Add batch dimension

print("Tokenized query:\n", tokenized_query, "\n")
print("Tokenized query for seach:\n", tokenized_query_for_search, "\n")
print("Tokenized query for seach less stop words:\n", tokenized_query_for_search_less_sw, "\n")
print("Input IDs query:\n", input_ids_query, "\n")
print("Attention mask query:\n", attention_mask_query, "\n")

Enter your query:  What was generalized in the paper "On The Definition of Higher Gamma functions"?


Tokenized query:
 ['What', 'Ġwas', 'Ġgeneralized', 'Ġin', 'Ġthe', 'Ġpaper', 'Ġ"', 'On', 'ĠThe', 'ĠDefinition', 'Ġof', 'ĠHigher', 'ĠGamma', 'Ġfunctions', '"?', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'] 

Tokenized query for seach:
 ['What', 'Ġwas', 'Ġgeneralized', 'Ġin', 'Ġthe', 'Ġpaper', 'ĠOn', 'ĠThe', 'ĠDefinition', 'Ġof', 'ĠHigher', 'ĠGamma', 'Ġfunctions'] 

Tokenized query for seach less stop words:
 ['What', 'Ġgeneralized', 'Ġpaper', 'ĠOn', 'ĠThe', 'ĠDefinition', 'ĠHigher', 'ĠGamma', 'Ġfunctions'] 

Input IDs query:
 tensor([[ 2264,    21, 44030,    11,     5,  2225,    22,  4148,    20, 38764,
             9, 13620, 43566,  8047, 24681,     1,     1,     1,     1,     1]]) 

Attention mask query:
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) 



##### Add the embeddings
- Specify the embedding model

In [11]:
# Load your trained Word2Vec model
# model_fname = os.path.join("..", "models", "word_embeddings", "word2vec_model.bin")
model_fname = os.path.join("..", "models", "word_embeddings", "roberta_word2vec_model.bin")
model = Word2Vec.load(model_fname)

In [12]:
# Get the query embeddings for the candidate document search
query_embeddings = tokens_to_embeddings(tokenized_query, model, RANDOM=False)
query_embeddings_search = tokens_to_embeddings(tokenized_query_for_search, model, RANDOM=False)
query_embeddings_less_sw = tokens_to_embeddings(tokenized_query_for_search_less_sw, model, RANDOM=False)

print("\t\t\t\t\t\tTokens Length\t\tEmbeddings Shape")
print(f"\t\t\t   Query embeddings:\t      {len(tokenized_query)}\t\t     {query_embeddings.shape}")
print(f"\t\tQuery embeddings for search:\t      {len(tokenized_query_for_search)}\t\t\t     {query_embeddings_search.shape}")
print(f" Query embeddings for search less stopwords:\t      {len(tokenized_query_for_search_less_sw)}\t\t\t     {query_embeddings_less_sw.shape}")

						Tokens Length		Embeddings Shape
			   Query embeddings:	      20		     (20, 100)
		Query embeddings for search:	      13			     (13, 100)
 Query embeddings for search less stopwords:	      9			     (9, 100)


##### Store the output the the query directory, filename is hash of query

In [13]:
# store the query
query_data = {
    "query": user_query,
    "input_ids_query":input_ids_query.tolist(),
    "attention_mask_query": attention_mask_query.tolist(),
    "query_search":user_query_for_search,
    "tokenized_query":tokenized_query,
    "tokenized_query_search":tokenized_query_for_search,
    "tokenized_query_search_less_sw":tokenized_query_for_search_less_sw,
    "query_embedding": query_embeddings_search.tolist(), # Just used for the candidate search
    "query_embedding_search": query_embeddings_search.tolist(), # Just used for the candidate search, cleaned
    "query_embedding_search_less_sw": query_embeddings_less_sw.tolist() # Just used for the candidate search, cleaned more
}

json_string = json.dumps(query_data['query'], indent=2)
# print(json_string)

# Specify the directory path
directory_path = os.path.join("..", 'query')

# Check if the directory exists, if not create the directory
if not os.path.exists(directory_path):
    os.makedirs(directory_path)

# Generate a UUID
# unique_id = uuid.uuid4()
unique_id = get_sha256(json_string)
print(unique_id)

fname = os.path.join(directory_path, str(unique_id)+'.json')
print(fname)

with open(fname, 'w') as j_file:
    json.dump(query_data, j_file, indent=4)

7ff4d51a240429a1e936dbe31c4c79924ad112923dcc4e0d1a7de488400dde79
../query/7ff4d51a240429a1e936dbe31c4c79924ad112923dcc4e0d1a7de488400dde79.json
