In [61]:
import numpy as np
import pandas as pd
import rank_bm25 as rbm

In [62]:
dd = pd.read_pickle('data/eng_script_vectorized_v2.pkl')

In [63]:
dd.columns

Index(['series_num', 'ep_num', 'ep_name', 'phrase_rank', 'person', 'text',
       'person_orig', 'person_context', 'scene_context', 'text_5_prev',
       'person_5_prev', 'person_2_prev', 'text_1_shift', 'text_2_shift',
       'text_3_shift', 'text_4_shift', 'text_5_shift', 'glove_v_text',
       'glove_v_text_1_shift', 'glove_v_text_2_shift', 'glove_v_text_3_shift',
       'glove_v_text_4_shift', 'glove_v_text_5_shift', 'sbert_v_text_1_shift',
       'sbert_v_text', 'sbert_v_text_2_shift', 'sbert_v_text_3_shift',
       'sbert_v_text_4_shift', 'sbert_v_text_5_shift'],
      dtype='object')

In [64]:
ddd = dd[dd.person == 'Leonard'][['text','sbert_v_text', "text_1_shift","sbert_v_text_1_shift" ]].head(55).reset_index(drop=True)
# ddd.loc[0, "sbert_v_text"]
ddd

Unnamed: 0,text,sbert_v_text,text_1_shift,sbert_v_text_1_shift
0,"See, the liquid metal Terminators were created...","[0.01891789, -0.006340133, -0.028041823, -0.01...",,"[-0.023187159, 0.051497556, -0.0023922704, -0...."
1,Skynet is kinky? I don’t know.,"[0.046432924, 0.048894834, -0.00023074828, -0....","Okay, then riddle me this. Assuming all the go...","[0.01170853, 0.023849228, 0.004196144, -0.0166..."
2,"Alright, oh wait, they use it to in…","[0.030348463, 0.009559736, -0.03141301, -0.020...",Artificial intelligences do not have teen feti...,"[0.023357784, 0.08639555, -0.031002609, -0.013..."
3,What the hell is that?,"[0.012077914, -0.059381146, -0.012467875, -0.0...",Let’s go-oh-oh Ou-oooo-ut tonight. I have to g...,"[-0.054732356, 0.08357615, 0.039391242, -0.066..."
4,What? Oh we just had to… mail some letters and...,"[0.097821146, 0.05905508, -0.01299829, -0.0052...","You wanna prowl, be my night owl, (Leonard and...","[0.05469582, 0.025094619, -0.004739377, -0.042..."
5,"Oh, I give up.","[0.01933068, -0.0047633224, 0.0055105365, 0.04...",You’ll never guess what just happened.,"[-0.019691234, 0.10576699, 0.024145065, -0.023..."
6,"Believe it or not, personal growth. What happe...","[0.05228075, 0.11404847, 0.0035048623, -0.0318...",What was that?,"[-0.0009989009, -0.04978135, -0.020151217, 0.0..."
7,No you don’t. No he doesn’t.,"[-0.019268319, 0.14347932, -0.012406823, 0.021...",I have a conclusion based on an observation.,"[0.00017333697, 0.07467678, 0.0028149867, -0.0..."
8,"Oh, congratulations, what a lucky break.","[-0.043697108, 0.013066628, -0.018665018, -0.0...","Well, the girl they picked to play Mimi, she d...","[0.0457013, 0.008733222, 0.008208125, -0.03620..."
9,No you don’t. He doesn’t.,"[-0.035525065, 0.11832944, -0.005925776, 0.036...",I think I know.,"[0.055987854, 0.01944254, -0.020300826, -0.030..."


In [65]:
corp_q_text = ddd['text_1_shift'].apply(str).to_list()
corp_q_vect = ddd['sbert_v_text_1_shift'].to_list()

In [66]:
corp_a_vect = ddd['sbert_v_text'].to_list()
corp_a_text = ddd['text'].to_list()

In [117]:
from rank_bm25 import BM25Okapi # https://pypi.org/project/rank-bm25/
from sentence_transformers import SentenceTransformer, util, CrossEncoder
from torch import from_numpy
from pprint import pprint

class Ranker:

    def __init__(self, corpus: list, tokenizer = None, base_ranker: str = 'bm25kapi', base_ranker_path: str = None, extra_ranker: str = None, extra_ranker_path: str = None):
        '''
        inputs:
            corpus: list of strings (before tokenization) in case of 'bm25kapi' and list of vectors in case of 'bi_encoder'
            tokenizer: object to tokenize corpust as list[str] to list[list[str]] must have 
                        tokenize_corpus method or None if you don't need to tokenize corpus.
            base_ranker: ranker algorithm, one of 'bm25kapi' or 'bi_encoder'
            base_ranker_path: path to model if base_ranker = 'bi_encoder'
            extra_ranker: rerank results of base rnaker, only 'cross_encoder' is implemented
            extra_ranker_path: path to to the re-ranker
        '''
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.base_ranker = base_ranker
        self.base_ranker_path = base_ranker_path
        self.extra_ranker = extra_ranker
        self.extra_ranker_path = extra_ranker_path        
        if self.base_ranker =='bm25kapi': # https://pypi.org/project/rank-bm25/
            self.first_ranker = BM25Okapi(tk.tokenize_corpus(corpus))
        elif self.base_ranker =='bi_encoder': # https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/retrieve_rerank/retrieve_rerank_simple_wikipedia.ipynb
            self.first_ranker = SentenceTransformer(self.base_ranker_path)
        else:
            raise KeyError(f"{self.base_ranker=} not implemented yet, see documentation")
        if extra_ranker != None:
            if extra_ranker == 'cross_encoder':
                self.second_ranker = CrossEncoder(self.base_ranker_path)
            else:
                raise KeyError(f"{extra_ranker.base_ranker=} not implemented yet, see documentation")
            
            
    def get_top_n(self, query: str, top_n: int = 1) -> list[str]:
        '''
        return top answers for query as list of strings (the answers)
        '''
        # only one ranker
        if self.extra_ranker is None:
            if self.base_ranker =='bm25kapi':
                q_tokenized = self.tokenizer.tokenize_corpus([query])[0]
                return self.first_ranker.get_top_n(q_tokenized, self.corpus, top_n)
            elif self.base_ranker =='bi_encoder':
                q_vectrorized = self.first_ranker.encode(query)
                self.corpus = [from_numpy(i) for i in self.corpus]
                curr_res_long = util.semantic_search(q_vectrorized, self.corpus, top_k=top_n)
                curr_res = [ self.corpus[k] for k in [i['corpus_id'] for i in curr_res_long[0]]]
                return curr_res
            else:
                raise KeyError(f"{self.base_ranker=} not implemented yet, see documentation")
                
        # two consequtive rankers:       
        else:
            pass

# corp_q = ['nan',
#  'Okay, then riddle me this. Assuming all the good Terminators were originally evil Terminators created by Skynet but then reprogrammed by the future John Connor, why would Skynet, an artificial computer intelligence, bother to create a petite hot 17 year-old killer robot?',
#  'Artificial intelligences do not have teen fetishes. ',
#  'Let’s go-oh-oh Ou-oooo-ut tonight. I have to go-oh-oh-oh ou-ooooo-ut tonight. ',
#  'You wanna prowl, be my night owl, (Leonard and Sheldon reappear, running down the stairs) we’ll take my… (appearing) Hey guys, hi! Where you going?']

from my_tokenize_vectorize import Tokenizer
ss  = 'I need to speak to you.'

# test for bm25
print('================BM25 test===============')
tk = Tokenizer(tokenizer = 'bm25', lower = True, lang = 'english')
r = Ranker(corpus = corp_q_text, tokenizer = tk, base_ranker = 'bm25kapi')
res = r.get_top_n(ss, top_n=5)
pprint(res)
# for i in range(len(res)):
#     print(i, res[i], ddd[ddd["text_1_shift"]==res[i]]["text"].reset_index(drop=True).iat[0], sep = '\t')

# test for pure bi encoder
print('================bi_encoder test===============')
tk = Tokenizer(tokenizer = 'wordpunct_tokenize', lower = True, lang = 'english')
r = Ranker(corpus = corp_q_vect, base_ranker = 'bi_encoder', base_ranker_path = "model\\triple-20e-1000-fit-all-mpnet-base-v2")
res = r.get_top_n(ss, top_n=5)

for i in res:
    print(ddd[  ddd['sbert_v_text_1_shift'].apply(lambda x: (x == i.numpy()).all())]["text_1_shift"].tolist()[0])

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\satyr\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


['I need to speak to you.',
 'Oh, too bad, well, I got to get to rehearsal, see you guys. ',
 'Then I suppose you could have agreed to go. ',
 'To help you. ',
 'You just lied to Penny.']
I need to speak to you.
You wanna prowl, be my night owl, (Leonard and Sheldon reappear, running down the stairs) we’ll take my… (appearing) Hey guys, hi! Where you going?
To help you. 
It’s okay, Leonard (hugs him.)
Thanks. I just wanted to come by and wish you guys luck with your symposium. 


In [111]:
(res[0].numpy() == ddd['sbert_v_text_1_shift'].iloc[0]).all()

False

In [105]:
ddd['sbert_v_text_1_shift'].iloc[0].shape

(768,)

In [91]:
ddd['sbert_v_text_1_shift'].iloc[0].shape

(768,)

In [28]:
for i in range(len(res[0])):
    print(i, res[0][i]['score'], corp_a_text[res[0][i]['corpus_id']], sep = '\t')

0	0.9067506194114685	It’s two o’clock in the morning
1	0.36087754368782043	I’m sorry, I’m not seeing the help.
2	0.2903665602207184	Sheldon, what is it? 
3	0.26684990525245667	Oh, well, thankyou.
4	0.23647460341453552	That’s very true. 


In [36]:
corp_a_text[res[0][1]['corpus_id']]

'I’m sorry, I’m not seeing the help.'

In [9]:
util.semantic_search?

[1;31mSignature:[0m
[0mutil[0m[1;33m.[0m[0msemantic_search[0m[1;33m([0m[1;33m
[0m    [0mquery_embeddings[0m[1;33m:[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m,[0m[1;33m
[0m    [0mcorpus_embeddings[0m[1;33m:[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m,[0m[1;33m
[0m    [0mquery_chunk_size[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m100[0m[1;33m,[0m[1;33m
[0m    [0mcorpus_chunk_size[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m500000[0m[1;33m,[0m[1;33m
[0m    [0mtop_k[0m[1;33m:[0m [0mint[0m [1;33m=[0m [1;36m10[0m[1;33m,[0m[1;33m
[0m    [0mscore_function[0m[1;33m:[0m [0mCallable[0m[1;33m[[0m[1;33m[[0m[0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m,[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m[1;33m,[0m [0mtorch[0m[1;33m.[0m[0mTensor[0m[1;33m][0m [1;33m=[0m [1;33m<[0m[0mfunction[0m [0mcos_sim[0m [0mat[0m [1;36m0x0000021148D70D60[0m[1;33m>[0m[1;33m,[0m[1;33m
[0m[1;33m)[0m 

In [10]:
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')

In [11]:
bi_encoder

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
  (2): Normalize()
)