In [112]:
import numpy as np 
import pandas as pd 
from tqdm import tqdm 
import telegram
from transformers import AutoTokenizer, AutoModel
from transformers.optimization import get_linear_schedule_with_warmup, AdamW 
from sentence_transformers import SentenceTransformer, util, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import pickle
import re
import os 
import time 
import datetime
import faiss
import pickle 

In [172]:
test_model = SentenceTransformer("PatentSBERTa_900000_finetuned", device='cuda') 
test_model 

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

In [115]:
test_files = os.listdir("FGH_claim_triplet_v0.1s/test/")
tel_token = "5059732158:AAE87TaReNbDKH3_Fy-CAYCUuIO2qiUyK2I"
chat_id = 1720119057 
bot = telegram.Bot(token=tel_token)

text = "Creating test triplets..." 
bot.sendMessage(chat_id=chat_id, text=text)

L = 30 

q_test, p_test, n_test = [], [], [] 

for i in tqdm(range(len(test_files)), position=0, leave=True): 
    try:
        if i % 100000 == 0:
            bot.sendMessage(chat_id=chat_id, text="currently at iteration {}/{}".format(i, len(test_files)))
    except Exception as e:
        print(e)
            
    file = open("FGH_claim_triplet_v0.1s/test/" + str(test_files[i]), "r") 
    claims = file.readlines() 
    
    query, positive, negative = [], [], []
    q, p, n = False, False, False 
    for j in range(len(claims)):
        if q == False:
            query.append(claims[j]) 
        elif p == False:
            positive.append(claims[j]) 
        else:
            negative.append(claims[j])
        if j < len(claims)-1 and claims[j] == '\n' and claims[j+1] == '\n':
            if q == False:
                q = True
            elif p == False:
                p = True 
    
    clean_query, clean_positive, clean_negative = [],[],[] 
    for q in query:
        if len(q.strip()) < L: 
            continue
        t = re.sub(r'\w[.)]\s*', '', q)
        t = re.sub('[^A-Za-z]+',' ',t) 
        clean_query.append(t.strip()) 
    q_test.append(' '.join(clean_query)) 
    
    for p in positive:
        if len(p.strip()) < L:
            continue 
        t = re.sub(r'\w[.)]\s*', '', p)
        t = re.sub('[^A-Za-z]+',' ',t) 
        clean_positive.append(t.strip()) 
    p_test.append(' '.join(clean_positive))
    
    for n in negative:
        if len(n.strip()) < L:
            continue 
        t = re.sub(r'\w[.)]\s*', '', n)
        t = re.sub('[^A-Za-z]+',' ',t)
        clean_negative.append(t.strip()) 
    n_test.append(' '.join(clean_negative))

100%|██████████| 50000/50000 [00:57<00:00, 875.11it/s]


In [117]:
with open('tabnet_900000_finetuned.pkl', 'rb') as f:
    candidate_embeddings = pickle.load(f)
idx = 0 
q = q_test[idx]
q_embedding = test_model.encode(q) 
ranks = [] 
for i in tqdm(range(len(candidate_embeddings)),position=0, leave=True):
    dist = util.cos_sim(q_embedding, candidate_embeddings[i]) 
    ranks.append((i, dist)) 

ranks = sorted(ranks, key=lambda x: x[1], reverse=True)
for i in range(len(ranks)):
    if ranks[i][0] == idx:
        print("rank = {}, score = {}".format(i+1, ranks[i][1].item())) 
        break 

100%|██████████| 100000/100000 [00:13<00:00, 7374.20it/s]


rank = 1, score = 0.912592351436615


In [123]:
with open('tabnet_900000_finetuned.pkl', 'rb') as f:
    candidate_embeddings = pickle.load(f)

candidate_embeddings = np.array(candidate_embeddings)
    
index = faiss.IndexIDMap(faiss.IndexFlatIP(768))

faiss.normalize_L2(candidate_embeddings) 

index.train(candidate_embeddings)

index.add_with_ids(candidate_embeddings, np.array(range(0, len(candidate_embeddings))))

faiss.write_index(index, "test_index.index")

In [181]:
tel_token = "5059732158:AAE87TaReNbDKH3_Fy-CAYCUuIO2qiUyK2I"
chat_id = 1720119057 
bot = telegram.Bot(token=tel_token)

t0 = time.time() 
q = test_model.encode(q_test) 

faiss.normalize_L2(q)

D, I = index.search(q, len(candidate_embeddings))

elapsed = time.time() - t0

bot.sendMessage(chat_id=chat_id, text="Time elapsed: {}s".format(elapsed))
print(elapsed)

833.6708905696869


In [184]:
ranks = [] 
reciprocal_ranks = [] 
reciprocal_ranks_100 = [] 
p5, p10, p50, p100 = [], [], [], [] 

for idx in tqdm(range(len(D)), position=0, leave=True):
    for i in range(len(I[idx])):
        if I[idx][i] == idx:
            ranks.append(i+1) 
            reciprocal_ranks.append(1/(i+1)) 
            if i+1 > 100:
                reciprocal_ranks_100.append(0) 
            else:
                reciprocal_ranks_100.append(1/(i+1)) 
                
            if i+1 <= 5:
                p5.append(True) 
            else:
                p5.append(False) 
                
            if i+1 <= 10:
                p10.append(True)
            else:
                p10.append(False) 
            
            if i+1 <= 50:
                p50.append(True)
            else:
                p50.append(False) 
                
            if i+1 <= 100:
                p100.append(True) 
            else:
                p100.append(False)
            

100%|██████████| 50000/50000 [55:21<00:00, 15.05it/s]  


In [185]:
text = "MRR : {}".format(np.mean(reciprocal_ranks)) 
bot.sendMessage(chat_id=chat_id, text=text) 

text = "MRR@100 : {}".format(np.mean(reciprocal_ranks_100)) 
bot.sendMessage(chat_id=chat_id, text=text) 

text = "Average Rank : {}".format(np.mean(ranks)) 
bot.sendMessage(chat_id=chat_id, text=text) 

cnt = 0
for i in range(len(p5)):
    if p5[i]==True:
        cnt += 1 
text = "p5 : {}".format(cnt/len(p5))
bot.sendMessage(chat_id=chat_id, text=text) 


cnt = 0
for i in range(len(p10)):
    if p10[i]==True:
        cnt += 1 
text = "p10 : {}".format(cnt/len(p10))
bot.sendMessage(chat_id=chat_id, text=text) 


cnt = 0
for i in range(len(p50)):
    if p50[i]==True:
        cnt += 1 
text = "p50 : {}".format(cnt/len(p50))
bot.sendMessage(chat_id=chat_id, text=text) 


cnt = 0
for i in range(len(p100)):
    if p100[i]==True:
        cnt += 1 
text = "p100 : {}".format(cnt/len(p100))
bot.sendMessage(chat_id=chat_id, text=text) 



<telegram.message.Message at 0x31467a88>