In [189]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import json
import pandas as pd
import time
import faiss

In [140]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
bi_model = AutoModel.from_pretrained("models/bi_model")

In [141]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

Преобразования

In [142]:
def bi_mean_pool(token_embeds: torch.tensor, attention_mask: torch.tensor) -> torch.tensor:
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool


def bi_encode(input_texts, tokenizer: AutoTokenizer, model: AutoModel, device: str = "cpu"
) -> torch.tensor:

    model.eval()
    tokenized_texts = tokenizer(input_texts, max_length=128,
                                padding='max_length', truncation=True, return_tensors="pt")
    token_embeds = bi_model(tokenized_texts["input_ids"].to(device),
                         tokenized_texts["attention_mask"].to(device)).last_hidden_state
    pooled_embeds = bi_mean_pool(token_embeds, tokenized_texts["attention_mask"].to(device))
    return pooled_embeds

Выбор k кандидатов с помощью BI-енкодера

In [145]:
import json 
def get_top_k(query, corpus, top_k=5):
    """
      Выбор k кандидатов. Bi-Encoder
    """
    bi_pooled_embeds = torch.tensor(corpus['pooled_embeds'].apply(json.loads))

    bi_pooled_embeds_query = bi_encode(query, tokenizer, bi_model, device)
    bi_pooled_embeds_query = bi_pooled_embeds_query.cpu().detach().numpy() 

    similarities = cosine_similarity(bi_pooled_embeds_query, bi_pooled_embeds)

    sim_indexies = np.argsort(similarities)[0, ::-1]
    sim_indexies = sim_indexies[:top_k]
    return corpus.iloc[sim_indexies], similarities[0, sim_indexies]

Подгружаем данные и ищем релевантных кандидатов

In [166]:
import pandas as pd

house_answers = pd.read_csv('data/house_answers.csv')

In [167]:
house_answers.head()

Unnamed: 0,name,line,responder,response,token_embeds,pooled_embeds
0,James,You can't go in there.,House,"Who are you, and why are you wearing a tie?",0,"[0.343617707490921, 0.14434118568897247, -0.02..."
1,James,I'm Dr. Cuddy's new assistant. Can I tell her...,House,Yes. I would like to know why she gets a secr...,0,"[0.07738427072763443, -0.024036986753344536, 0..."
2,James,"I'm her assistant, not her secretary. I gradu...",House,Hmm. I didn't know they had a secretarial sch...,0,"[0.18388719856739044, -0.1642012745141983, -0...."
3,Cuddy,"Dr. House, we are in the middle of a meeting.",House,What's with hiring a male secretary? JDate no...,0,"[-0.20013472437858582, -0.08547274768352509, 0..."
4,Stacy,He is cute. Be careful.,House,She's not like you. She can't just walk into ...,0,"[0.061828095465898514, -0.13099761307239532, 0..."


In [147]:
query = "You can't go in there."
candidates = get_top_k(query, house_answers) 

In [148]:
candidates

(          name                                         line responder  \
 0        James                       You can't go in there.     House   
 15587    Cuddy           You can't ride that thing in here.     House   
 12424  Beasley   You can't go up there. It's yard time now.     House   
 3310   Cameron            You can't just be walking around.     House   
 3786   Cameron                           You can't do that.     House   
 
                                                 response  \
 0            Who are you, and why are you wearing a tie?   
 15587   Speaking of things, (He looks through the sta...   
 12424                                  Put it on my tab.   
 3310                                Well, then, stop me.   
 3786    Can't do what? Administer a prescription pain...   
 
                                             token_embeds  \
 0      [[ 0.34750035  0.039516   -0.07297491 ... -0.1...   
 15587  [[ 0.18516737  0.18311346 -0.09143019 ... -0.4...   
 12

# Cross Encoder

In [126]:
MAX_LENGTH = 128
class CrossEncoderBert(torch.nn.Module):
    def __init__(self, max_length: int = MAX_LENGTH):
        super().__init__()
        self.max_length = max_length
        self.bert_model = AutoModel.from_pretrained('distilbert-base-uncased')
        self.bert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # Use the CLS token's output
        return self.linear(pooled_output)

In [None]:
ce_model = CrossEncoderBert()
ce_model.load_state_dict(torch.load('models/CE_model', weights_only=True))
ce_model.to(device)
ce_model.eval()

CrossEncoderBert(
  (bert_model): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1):

In [163]:
def get_ranked_docs(
    query: str, candidates,
    tokenizer: AutoTokenizer, 
    finetuned_ce: CrossEncoderBert 
) -> None:
    corpus = candidates['line'].to_list()

    queries = [query] * len(corpus)
    tokenized_texts = tokenizer(
        queries, corpus, max_length=MAX_LENGTH, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Finetuned CrossEncoder model scoring
    with torch.no_grad():
        ce_scores = finetuned_ce(tokenized_texts['input_ids'], tokenized_texts['attention_mask']).squeeze(-1)
        ce_scores = torch.sigmoid(ce_scores)  # Apply sigmoid if needed

    # Process scores for finetuned model
    print(f"Query - {query} [Finetuned Cross-Encoder]\n---")
    scores = ce_scores.cpu().numpy()
    scores_ix = np.argsort(scores)[::-1]
    for ix in scores_ix:  # Limit to corpus size
        print(f"{scores[ix]: >.2f}\t{corpus[ix]}")
        
    return candidates.iloc[scores_ix], scores[scores_ix]

In [165]:
ranked_candidates = get_ranked_docs(query, candidates[0], tokenizer, ce_model)

Query - You can't go in there. [Finetuned Cross-Encoder]
---
0.98	 You can't go in there.
0.92	 You can't ride that thing in here.
0.90	 You can't go up there. It's yard time now.
0.90	 You can't do that.
0.84	 You can't just be walking around.


In [168]:
def get_answer(query: str):
    candidates = get_top_k(query, house_answers) 
    ranked_candidate = get_ranked_docs(query, candidates[0], tokenizer, ce_model)
    return ranked_candidate

In [169]:
query = "hello"
get_answer(query)

Query - hello [Finetuned Cross-Encoder]
---
0.92	 Hello?
0.92	 Hello?
0.92	 Hello?
0.92	 Hello?
0.81	 Hi.


(          name     line responder  \
 9159     Cuddy   Hello?     House   
 14994    Chase   Hello?     House   
 6117   Cameron   Hello?     House   
 2870   Foreman   Hello?     House   
 16135     Park      Hi.     House   
 
                                                 response  token_embeds  \
 9159                Don't hang up. What was the verdict?             0   
 14994                                              Yeah.             0   
 6117                               He's not a sociopath.             0   
 2870    [in a hazmat suit on a hands-free phone] I'm ...             0   
 16135   I'm not interested in another department's sl...             0   
 
                                            pooled_embeds  
 9159   [-0.26658883690834045, -0.49618983268737793, -...  
 14994  [-0.26658883690834045, -0.49618983268737793, -...  
 6117   [-0.26658883690834045, -0.49618983268737793, -...  
 2870   [-0.26658883690834045, -0.49618983268737793, -...  
 16135  [0.1605694

In [170]:
!pip3.11 install faiss-cpu

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-macosx_11_0_arm64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


In [186]:
import faiss

index = faiss.IndexFlatIP((768))
index.add(torch.tensor(house_answers['pooled_embeds'].apply(json.loads)))

In [198]:
def get_answer_faiss(query: str):
    start_time = time.time()
    bi_pooled_embeds_query = bi_encode(query, tokenizer, bi_model, device)
    bi_pooled_embeds_query = bi_pooled_embeds_query.cpu().detach().numpy() 
    # candidates = get_top_k(query, house_answers) 
    candidates = index.search(bi_pooled_embeds_query, k=10)
    
    print(candidates[1])
    candidates = house_answers.iloc[candidates[1][0]]
    
    ranked_candidate = get_ranked_docs(query, candidates, tokenizer, ce_model)
    end_time = time.time() - start_time
    return *ranked_candidate, end_time

In [199]:
query = "You can't go in there."
# print(get_answer(query))
print(get_answer_faiss(query))

[[    0 10405  1633  8123 12501 13547  7608   975 12240 11388]]
Query - You can't go in there. [Finetuned Cross-Encoder]
---
0.98	 You can't go in there.
0.93	 You were just in there.
0.92	 You can't.
0.88	 You're not staying here.
0.83	 Don't do this.
0.83	 Don't do this.
0.83	 Just leave it alone.
0.82	 Don't do it.
0.80	 You're alone.
0.73	 This is not okay. Use your own bathroom.
(          name                                       line responder  \
0        James                     You can't go in there.     House   
7608    Kutner                    You were just in there.     House   
1633   Cameron                                 You can't.     House   
10405   Wilson                   You're not staying here.     House   
11388    Cuddy                             Don't do this.     House   
12240    Amber                             Don't do this.     House   
975     Kalvin                       Just leave it alone.     House   
12501    Alvie                              