In [1]:
from tqdm import tqdm
import datasets
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = datasets.load_from_disk(dataset_path='/home/tobias/CAS Machine Learning/taproject/General-Knowledge/train')


In [3]:
df = pd.DataFrame(data, columns=['Question', 'Answer'])
df

Unnamed: 0,Question,Answer
0,What is Artificial Intelligence?,Artificial Intelligence refers to the developm...
1,What are the two main categories of Artificial...,The two main categories of Artificial Intellig...
2,What is Machine Learning?,Machine Learning is a subset of Artificial Int...
3,What is Deep Learning?,Deep Learning is a subset of Machine Learning ...
4,What is Natural Language Processing?,Natural Language Processing is a subset of Art...
...,...,...
37630,Did Viv Richards have a stellar batting averag...,"Yes, Viv Richards had a stellar batting averag..."
37631,Has Michel Platini won the UEFA European Champ...,"Yes, Michel Platini has won the UEFA European ..."
37632,Did Brian Lara hold the record for the highest...,"Yes, Brian Lara held the record for the highes..."
37633,Was Johan Cruyff known for his influential pla...,"Yes, Johan Cruyff was known for his influentia..."


In [4]:
df['Question'] = df['Question'].apply(lambda x: 'query: '+str(x))
df['Answer'] = df['Answer'].apply(lambda x: 'passage: '+str(x))
df

Unnamed: 0,Question,Answer
0,query: What is Artificial Intelligence?,passage: Artificial Intelligence refers to the...
1,query: What are the two main categories of Art...,passage: The two main categories of Artificial...
2,query: What is Machine Learning?,passage: Machine Learning is a subset of Artif...
3,query: What is Deep Learning?,passage: Deep Learning is a subset of Machine ...
4,query: What is Natural Language Processing?,passage: Natural Language Processing is a subs...
...,...,...
37630,query: Did Viv Richards have a stellar batting...,"passage: Yes, Viv Richards had a stellar batti..."
37631,query: Has Michel Platini won the UEFA Europea...,"passage: Yes, Michel Platini has won the UEFA ..."
37632,query: Did Brian Lara hold the record for the ...,"passage: Yes, Brian Lara held the record for t..."
37633,query: Was Johan Cruyff known for his influent...,"passage: Yes, Johan Cruyff was known for his i..."


In [5]:
def process_chunk(chunk, tokenizer, model, batch_size=32):
    embeddings_list = []
    for i in range(0, len(chunk), batch_size):
        batch = chunk[i:i+batch_size]
        input_texts = [f"query: {text}" if idx % 2 == 0 else f"passage: {text}" for idx, text in enumerate(batch)]
        
        batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
        
        with torch.no_grad():
            outputs = model(**batch_dict)
        
        chunk_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        chunk_embeddings = F.normalize(chunk_embeddings, p=2, dim=1)
        embeddings_list.append(chunk_embeddings)
    
    return torch.cat(embeddings_list)



In [6]:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


In [7]:
chunk_size = 3000
chunks = np.array_split(df,chunk_size)

  return bound(*args, **kwds)


In [8]:
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-small-v2')
model = AutoModel.from_pretrained('intfloat/e5-small-v2')

all_embeddings = []
texts = []
for chunk in tqdm(chunks):
    input_text = []
    input_text.extend(chunk['Question'])
    input_text.extend(chunk['Answer'])
    chunk_embeddings = process_chunk(input_text, tokenizer, model)
    all_embeddings.append(chunk_embeddings)
    texts.append(input_text)

final_embeddings = torch.cat(all_embeddings)


100%|██████████| 3000/3000 [16:36<00:00,  3.01it/s]


In [9]:
torch.save(final_embeddings, 'final_embeddings.pt')

In [10]:
limit = None # Testzwecke
# Berechnet die Scores für den ersten Eintrag gegen den Rest
scores = (final_embeddings[:1] @ final_embeddings[1:limit].T) * 100
print(len(scores[0]))
print(scores.tolist())

75269
[[87.85655212402344, 92.95586395263672, 85.57814025878906, 87.42521667480469, 83.62891387939453, 93.98435974121094, 83.9970932006836, 93.98435974121094, 81.0186767578125, 85.54678344726562, 89.12947082519531, 89.11884307861328, 90.93838500976562, 84.83082580566406, 88.03341674804688, 85.26959991455078, 85.87922668457031, 86.71925354003906, 88.07791137695312, 83.60580444335938, 88.22415924072266, 82.2174301147461, 82.4678955078125, 88.29364013671875, 86.01194763183594, 86.88536071777344, 80.99850463867188, 91.44918823242188, 79.96177673339844, 94.26539611816406, 86.17044830322266, 89.2818832397461, 87.96067810058594, 83.87419128417969, 86.17100524902344, 80.29085540771484, 86.59498596191406, 84.33937072753906, 84.50930786132812, 84.04639434814453, 88.09358215332031, 77.0896987915039, 88.08671569824219, 84.06803894042969, 83.71202087402344, 84.34056854248047, 82.3663330078125, 85.9171142578125, 79.44811248779297, 85.38343048095703, 79.93467712402344, 89.83026885986328, 84.873069763

In [11]:
len(final_embeddings)

75270

In [13]:
final_embeddings.shape

torch.Size([75270, 384])