In [1]:
TOP_K = 5
N_RECALLS = 10
MAX_SEQ_LEN = 512


MODEL_NAME = "output_simcse_model"

import warnings
warnings.simplefilter('ignore')

import os
import re
import gc
import sys
import multiprocessing

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
from tqdm.auto import tqdm
from copy import deepcopy
import torch
import blingfire as bf
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup, DataCollatorWithPadding
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from pathlib import Path
from glob import glob
# 加载数据

DATA_DIR = './learning-equality-curriculum-recommendations'


def read_csv(file_path):
    df = pd.read_csv(file_path)
    if 'id' in df.columns:
        df  = df.drop(columns="id")
    return df

# 验证集 topic_id
final_name = './external_train_data/stem_dataset_gpt4.csv'

df_train = pd.concat([
    read_csv(i) for i in glob('./external_train_data/*') if 'stem_dataset_gpt4'  in i  or \
                                                          'stem_dataset.csv' in i or \
                                                          '15k_gpt3.5-turbo.csv' in i
])
df_train['is_train'] = 1
df_eval = read_csv('train.csv')
df_eval['is_train'] = 0
df_train = pd.concat([df_train, df_eval])
# df = pd.read_csv('retrive_dataset.csv')
# dev_ids = np.load('dev_id.npy',allow_pickle=True)
# dev_df =  df[df['url'].isin(dev_ids)]
dev_df = df_train.copy()
dev_df.reset_index(drop=True, inplace=True)
final_res = deepcopy(dev_df)
files = list(map(str, Path("./wiki_sci").glob("*.parquet")))
ds = load_dataset("parquet", data_files=files, split="train")
content_df = pd.DataFrame(ds)
# 加载预训练模型

# ====================================================
# Model
# ====================================================
class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings

class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights=None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        self.layer_weights = layer_weights if layer_weights is not None \
            else nn.Parameter(
            torch.tensor([1] * (num_hidden_layers + 1 - layer_start), dtype=torch.float)
        )

    def forward(self, all_hidden_states):
        all_layer_embedding = all_hidden_states[self.layer_start:, :, :, :]
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())
        weighted_average = (weight_factor * all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
        return weighted_average

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
            # self.config.hidden_dropout = 0.
            # self.config.hidden_dropout_prob = 0.
            # self.config.attention_dropout = 0.
            # self.config.attention_probs_dropout_prob = 0.
        else:
            self.config = torch.load(config_path)

        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel.from_config(self.config)
        # if self.cfg.gradient_checkpointing:
        #     self.model.gradient_checkpointing_enable

        self.pool = MeanPooling()
        self.fc_dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, inputs['attention_mask'])
        #feature = F.normalize(feature, p=2, dim=1)
        return feature

    
def get_sentences(document):
    res = []
    _, sentence_offsets = bf.text_to_sentences_and_offsets(document)
    for o in sentence_offsets:
        if o[1]-o[0] < 20:
            continue
        sentence = document[o[0]:o[1]]
        res.append(sentence)
    return res
#tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = 'sentence-transformers/all-mpnet-base-v1'
tokenizer = AutoTokenizer.from_pretrained(model)

model = CustomModel(cfg=None, config_path=MODEL_NAME + '/config.pth', pretrained=False)
state = torch.load(MODEL_NAME + '/sentence-transformers-all-mpnet-base-v1_fold0_best.pth',
                   map_location=torch.device('cpu'))
model.load_state_dict(state['model'])

device = torch.device('cuda:1') if torch.cuda.device_count() > 1 else torch.device('cuda:0')
model.eval()
model.to(device)


class TestDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = self.texts[item]
        # text = self.texts[item].replace('[SEP]', '</s>')
        inputs = tokenizer(text,
                           max_length=512,
                           pad_to_max_length=True,
                           add_special_tokens=True,
                           return_offsets_mapping=False)

        for k, v in inputs.items():
            inputs[k] = torch.tensor(v, dtype=torch.long)
        return inputs

def get_model_feature(model, texts):
    feature_outs_all = []
    test_dataset = TestDataset(texts)
    test_loader = DataLoader(test_dataset,
                             batch_size=128,
                             shuffle=False,
                             collate_fn=DataCollatorWithPadding(tokenizer=tokenizer, padding='longest'),
                             num_workers=0, pin_memory=True, drop_last=False)

    # tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tqdm(test_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            feature_outs = model(inputs)
            feature_outs_all.append(feature_outs.cpu())

    feature_outs_all_final = torch.cat(feature_outs_all, dim=0)
    #print(feature_outs_all_final.shape)

    return feature_outs_all_final


tokenizers.__version__: 0.13.3
transformers.__version__: 4.32.0


Found cached dataset parquet (/root/.cache/huggingface/datasets/parquet/default-8fbfeb4bcfeca33a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [2]:
content_df['sentence'] = content_df['text'].apply(lambda x:get_sentences(x))

In [3]:
content_df = content_df.explode('sentence')
content_df.shape

(4086356, 4)

In [4]:
content_df.reset_index(drop=True, inplace=True)
content_df

Unnamed: 0,text,url,title,sentence
0,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Stomatin also known as human erythrocyte integ...
1,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Clinical significance \n\nStomatin is a 31 kDa...
2,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,This gene encodes a member of a highly conserv...
3,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,The encoded protein localizes to the cell memb...
4,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Loss of localization of the encoded protein is...
...,...,...,...,...
4086351,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,"IEE Proceedings - Electric Power Applications,..."
4086352,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,Digital Object Identifier:10.1049/ip-epa:19990...
4086353,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,"Santiago, A. Birchenough. (2005)."
4086354,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,Single Phase Passive Rectification versus Acti...


In [34]:
#test = pd.read_csv('test.csv')

In [5]:
topic_embedding_list = get_model_feature(model, dev_df['prompt'].values)
print('question embedding done')
print(topic_embedding_list.shape)

  0%|          | 0/439 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
You're using a MPNetTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 439/439 [07:28<00:00,  1.02s/it]

question embedding done
torch.Size([56095, 768])





In [35]:
# test_embedding_list = get_model_feature(model, test['prompt'].values)
# print('question embedding done')
# print(test_embedding_list.shape)

100%|██████████| 2/2 [00:01<00:00,  1.27it/s]

question embedding done
torch.Size([200, 768])





In [6]:
corpus_embeddings = get_model_feature(model, content_df['sentence'].values)
print('content embedding done')
print(corpus_embeddings.shape)

100%|██████████| 31925/31925 [9:08:47<00:00,  1.03s/it]  


content embedding done
torch.Size([4086356, 768])


In [66]:
#corpus_embeddings = torch.as_tensor(np.load('text_embedding.npy')).to('cuda')

In [7]:
np.save('text_sentence_embedding', corpus_embeddings.cpu().numpy())

In [8]:
content_df.to_parquet('wiki_sci_text_sentence.parquet')

In [None]:
N_RECALLS= 30
pred_final = []
pred_text = []
for idx, row in tqdm(dev_df.iterrows(), total=len(dev_df)):

    query_embedding = topic_embedding_list[idx, :]

    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    top_k = min([N_RECALLS, len(corpus_embeddings)])
    top_results = torch.topk(cos_scores, k=top_k)
    #print(top_results)
    indics = top_results[1].cpu().numpy()

    # threshold = 0.8
    # score_top = top_results[0].cpu().numpy()
    # in_use = np.where(score_top > threshold)
    # indics = indics[in_use]

    #pid = content_dict[lang]['id'][indics]
    try:
        pid = content_df['url'][indics]
        pred_final.append(' '.join(pid))

        pid = content_df['sentence'][indics]
        pred_text.append('<recall_wiki_text>'.join(pid))
    except:
        pred_final.append('')
        pred_text.append('')

dev_df['recall_ids'] = pred_final
dev_df['recall_text'] = pred_text


 60%|██████    | 33693/56095 [17:20:00<11:43:02,  1.88s/it]

In [36]:
N_RECALLS= 30
pred_final = []
pred_text = []
for idx, row in tqdm(test.iterrows(), total=len(test)):

    query_embedding = test_embedding_list[idx, :]

    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    top_k = min([N_RECALLS, len(corpus_embeddings)])
    top_results = torch.topk(cos_scores, k=top_k)
    #print(top_results)
    indics = top_results[1].cpu().numpy()

    # threshold = 0.8
    # score_top = top_results[0].cpu().numpy()
    # in_use = np.where(score_top > threshold)
    # indics = indics[in_use]

    #pid = content_dict[lang]['id'][indics]
    try:
        pid = content_df['url'][indics]
        pred_final.append(' '.join(pid))

        pid = content_df['sentence'][indics]
        pred_text.append('<recall_wiki_text>'.join(pid))
    except:
        pred_final.append('')
        pred_text.append('')

test['recall_ids'] = pred_final
test['recall_text'] = pred_text


  4%|▍         | 8/200 [00:16<06:38,  2.08s/it]


KeyboardInterrupt: 

In [41]:
test.iloc[0,1]

'Which of the following statements accurately describes the impact of Modified Newtonian Dynamics (MOND) on the observed "missing baryonic mass" discrepancy in galaxy clusters?'

In [39]:
print(pred_text[0].split('<recall_wiki_text>')) 

["MOND \n\nModified Newtonian Dynamics (MOND) is a relatively modern proposal to explain the galaxy rotation problem based on a variation of Newton's Second Law of Dynamics at low accelerations.", 'Modified Newtonian Dynamics (MOND)\nAccording to Brada and Milgrom,\n\nReferences\n\nObservational astronomy\nExtragalactic astronomy', 'The corresponding force acting on a test mass  is\n\nTo account for the anomalous rotation curves of spiral galaxies, Milgrom proposed a modification of this force law in the form\n\nwhere  is an arbitrary function subject to the following conditions:\n\nIn this form, MOND is not a complete theory: for instance, it violates the law of momentum conservation.', 'This unusual-looking galaxy appears to be one partner in a cosmic collision, and appeared to show dynamics consistent with a dark galaxy (and apparently inconsistent with the predictions of the Modified Newtonian Dynamics (MOND) theory).', 'MOND has had a considerable amount of success in predicting t

In [14]:
len(pred_final)

53190

In [13]:
len(dev_df)

56095

In [21]:
dev_df

Unnamed: 0,prompt,A,B,C,D,E,answer,is_train,recall_ids,recall_text
0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,1,https://en.wikipedia.org/wiki/Stomatin https:/...,Although the wide distribution of stomatin and...
1,What is the potential role of Stomatin in cell...,Stomatin may play a role in the formation of m...,Stomatin may play a role in the formation of c...,Stomatin may play a role in the formation of c...,Stomatin may play a role in the formation of c...,Stomatin may play a role in the formation of c...,A,1,https://en.wikipedia.org/wiki/Stomatin https:/...,Although the wide distribution of stomatin and...
2,What is the condition associated with the loss...,Loss of localization of the Stomatin protein i...,Loss of localization of the Stomatin protein i...,Loss of localization of the Stomatin protein i...,Loss of localization of the Stomatin protein i...,Loss of localization of the Stomatin protein i...,A,1,https://en.wikipedia.org/wiki/Stomatin https:/...,Although the wide distribution of stomatin and...
3,What is the encoded protein by the STOM gene i...,The STOM gene in humans encodes for the protei...,The STOM gene in humans encodes for the protei...,The STOM gene in humans encodes for the protei...,The STOM gene in humans encodes for the protei...,The STOM gene in humans encodes for the protei...,C,1,https://en.wikipedia.org/wiki/GLUT1 https://en...,This protein interacts with STOM.<recall_wiki_...
4,Where does the Stomatin protein localize in th...,Stomatin protein localizes to the cell nucleus...,Stomatin protein localizes to the cell mitocho...,Stomatin protein localizes to the cell ribosom...,Stomatin protein localizes to the cell membran...,Stomatin protein localizes to the cell lysosom...,D,1,https://en.wikipedia.org/wiki/Stomatin https:/...,Although the wide distribution of stomatin and...
...,...,...,...,...,...,...,...,...,...,...
56090,What is the relation between the three moment ...,The three moment theorem expresses the relatio...,The three moment theorem is used to calculate ...,The three moment theorem describes the relatio...,The three moment theorem is used to calculate ...,The three moment theorem is used to derive the...,C,0,https://en.wikipedia.org/wiki/Theorem%20of%20t...,The deflection of k1 and k2 relative to the po...
56091,"What is the throttling process, and why is it ...",The throttling process is a steady flow of a f...,The throttling process is a steady adiabatic f...,The throttling process is a steady adiabatic f...,The throttling process is a steady flow of a f...,The throttling process is a steady adiabatic f...,B,0,https://en.wikipedia.org/wiki/Joule%E2%80%93Th...,Throttling is a fundamentally irreversible pro...
56092,What happens to excess base metal as a solutio...,"The excess base metal will often solidify, bec...",The excess base metal will often crystallize-o...,"The excess base metal will often dissolve, bec...","The excess base metal will often liquefy, beco...","The excess base metal will often evaporate, be...",B,0,https://en.wikipedia.org/wiki/Enthalpy%20chang...,The temperature of the solution eventually dec...
56093,"What is the relationship between mass, force, ...",Mass is a property that determines the weight ...,Mass is an inertial property that determines a...,Mass is an inertial property that determines a...,Mass is an inertial property that determines a...,Mass is a property that determines the size of...,D,0,https://en.wikipedia.org/wiki/Force https://en...,Newton's second law asserts the direct proport...


In [None]:
dev_df.to_parquet('all_train_with_retrive.parquet')

In [25]:
dev_df.iloc[123,0]

"What is the characteristic of the marginal distributions of Mardia's Multivariate Pareto distribution of the First Kind?"

In [26]:
print(dev_df.iloc[123,-1].split('<recall_wiki_text>'))

["Multivariate Pareto distributions\n\nMultivariate Pareto distribution of the first kind\n\nMardia's Multivariate Pareto distribution of the First Kind has the joint probability density function given by\n\nThe marginal distributions have the same form as (1), and the one-dimensional marginal distributions have a Pareto Type I distribution.", 'Bivariate Pareto distributions\n\nBivariate Pareto distribution of the first kind\n\nMardia (1962) defined a bivariate distribution with cumulative distribution function (CDF) given by\n\nand joint density function\n\n \nThe marginal distributions are Pareto Type 1 with density functions\n\n \n\nThe means and variances of the marginal distributions are\n\nand for a > 2, X1 and X2 are positively correlated with\n\nBivariate Pareto distribution of the second kind\n\nArnold   suggests representing the bivariate Pareto Type I complementary CDF by\n\nIf the location and scale parameter are allowed to differ, the complementary CDF is\n\nwhich has Pare

In [27]:
dev_df['length'] = dev_df['recall_text'].apply(lambda x:len(x.split()))

In [30]:
dev_df['recall_text'].isna().sum()

0

In [31]:
dev_df['length'].describe()

count    56095.000000
mean       565.878599
std        274.623093
min         78.000000
25%        448.000000
50%        531.000000
75%        629.000000
max      16483.000000
Name: length, dtype: float64

In [4]:
prompt_values = dev_df['prompt'].values.tolist()

In [24]:
len(dev_df)

56095

In [6]:
# 算分环节
dev_df['recall_ids'] = pred_final
# df_metric = dev_df.copy()
# df_metric['content_ids'] = df_metric['url']

In [32]:
corpus = pd.read_parquet('wiki_sci_text_sentence.parquet')

In [33]:
corpus

Unnamed: 0,text,url,title,sentence
0,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Stomatin also known as human erythrocyte integ...
1,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Clinical significance \n\nStomatin is a 31 kDa...
2,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,This gene encodes a member of a highly conserv...
3,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,The encoded protein localizes to the cell memb...
4,Stomatin also known as human erythrocyte integ...,https://en.wikipedia.org/wiki/Stomatin,Stomatin,Loss of localization of the encoded protein is...
...,...,...,...,...
4086351,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,"IEE Proceedings - Electric Power Applications,..."
4086352,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,Digital Object Identifier:10.1049/ip-epa:19990...
4086353,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,"Santiago, A. Birchenough. (2005)."
4086354,"Active rectification, or synchronous rectifica...",https://en.wikipedia.org/wiki/Active%20rectifi...,Active rectification,Single Phase Passive Rectification versus Acti...


In [None]:
corpus.to_parquet('wiki_sci_text_sentence_without_text.parquet')

In [7]:
# def get_pos_score(y_true, y_pred, top_n):
#     y_true = y_true.apply(lambda x: set(x.split()))
#     y_pred = y_pred.apply(lambda x: set(x.split()[:top_n]))
#     int_true = np.array([len(x[0] & x[1]) / len(x[0]) for x in zip(y_true, y_pred)])
#     return round(np.mean(int_true), 5)

# pos_score = get_pos_score(df_metric['content_ids'], df_metric['recall_ids'], 50)
# print(f'Our max positive score top 50 is {pos_score}')

# pos_score = get_pos_score(df_metric['content_ids'], df_metric['recall_ids'], 70)
# print(f'Our max positive score top 70 is {pos_score}')

# pos_score = get_pos_score(df_metric['content_ids'], df_metric['recall_ids'], 100)
# print(f'Our max positive score top 100 is {pos_score}')

# pos_score = get_pos_score(df_metric['content_ids'], df_metric['recall_ids'], 150)
# print(f'Our max positive score top 150 is {pos_score}')

# pos_score = get_pos_score(df_metric['content_ids'], df_metric['recall_ids'], 200)
# print(f'Our max positive score top 200 is {pos_score}')

# df_metric['content_ids'] = df_metric['content_ids'].astype(str).apply(lambda x: x.split())
# df_metric['recall_ids'] = df_metric['recall_ids'].astype(str).apply(lambda x: x.split())
# f2_scores = []

# N_RECALLS = [3, 5, 10, 30, 50, 100, 200, 300, 400, 500, 600]
# N_TOP_F2 = [5, 10, 15]
# # for n_top in N_TOP_F2:
# #     for _, row in tqdm(df_metric.iterrows(), total=len(df_metric)):
# #         true_ids = set(row['content_ids'])
# #         pred_ids = set(row['recall_ids'][:n_top])
# #         tp = len(true_ids.intersection(pred_ids))
# #         fp = len(pred_ids - true_ids)
# #         fn = len(true_ids - pred_ids)
# #         if pred_ids:
# #             precision = tp / (tp + fp)
# #             recall = tp / (tp + fn)
# #             f2 = tp / (tp + 0.2 * fp + 0.8 * fn)
# #         else:
# #             f2 = 0
# #         f2_scores.append(f2)
# #     print(f'Average F2@{n_top}:', np.mean(f2_scores))
# for n_recall in N_RECALLS:
#     total = 0
#     correct = 0
#     for _, row in tqdm(df_metric.iterrows(), total=len(df_metric)):
#         y_trues = row['content_ids']
#         y_preds = row['recall_ids'][:n_recall]
#         for y_true in y_trues:
#             total += 1
#             if y_true in y_preds:
#                 correct += 1
#     print(f'hitrate@{n_recall}:', correct/total)

In [8]:
def split_long_doc(text):
    max_lenth = 128
    window_size=16
    text_list = [i for i in text.split() if i]
    res = []
    
    i = 0
    while i + window_size < len(text_list):
        res.append(' '.join(text_list[i:min(i+max_lenth, len(text_list))]))
        i += 64
    
    return res

In [9]:
dev_df['recall_text'] = dev_df['recall_text'].apply(lambda x:x.split('<recall_wiki_text>'))

In [10]:
dev_df = dev_df.explode('recall_text')

In [11]:
dev_df['recall_sentence'] = dev_df['recall_text'].apply(lambda x:split_long_doc(x))

In [12]:
dev_df = dev_df.explode('recall_sentence')
dev_df

Unnamed: 0,id,prompt,A,B,C,D,E,answer,recall_ids,recall_text,recall_sentence
0,0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,https://en.wikipedia.org/wiki/Stomatin https:/...,Stomatin also known as human erythrocyte integ...,Stomatin also known as human erythrocyte integ...
0,0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,https://en.wikipedia.org/wiki/Stomatin https:/...,Stomatin also known as human erythrocyte integ...,"of red blood cells and other cell types, where..."
0,0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,https://en.wikipedia.org/wiki/Stomatin https:/...,Stomatin also known as human erythrocyte integ...,where it may regulate ion channels and transpo...
0,0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,https://en.wikipedia.org/wiki/Stomatin https:/...,Stomatin also known as human erythrocyte integ...,suggests a possible structural role for this p...
0,0,What is the molecular weight of the Stomatin p...,Stomatin is a 21 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,Stomatin is a 41 kDa integral membrane protein...,Stomatin is a 31 kDa integral membrane protein...,B,https://en.wikipedia.org/wiki/Stomatin https:/...,SH3 and cysteine-rich domain-containing protei...,SH3 and cysteine-rich domain-containing protei...
...,...,...,...,...,...,...,...,...,...,...,...
35894,35894,In which conditions was the expression of Smr3...,"Bacterial growth in TY, minimal medium (MM) an...","Bacterial growth in TY, minimal medium (MM) an...",Bacterial growth in TY and minimal medium (MM)...,Bacterial growth in TY and luteolin-MM broth o...,Bacterial growth in minimal medium (MM) and lu...,A,https://en.wikipedia.org/wiki/%CE%91r35%20RNA ...,Smaug is a RNA-binding protein in Drosophila t...,Smaug is a RNA-binding protein in Drosophila t...
35894,35894,In which conditions was the expression of Smr3...,"Bacterial growth in TY, minimal medium (MM) an...","Bacterial growth in TY, minimal medium (MM) an...",Bacterial growth in TY and minimal medium (MM)...,Bacterial growth in TY and luteolin-MM broth o...,Bacterial growth in minimal medium (MM) and lu...,A,https://en.wikipedia.org/wiki/%CE%91r35%20RNA ...,Smaug is a RNA-binding protein in Drosophila t...,"by maternal mRNAs like Hsp83, nanos, string, P..."
35894,35894,In which conditions was the expression of Smr3...,"Bacterial growth in TY, minimal medium (MM) an...","Bacterial growth in TY, minimal medium (MM) an...",Bacterial growth in TY and minimal medium (MM)...,Bacterial growth in TY and luteolin-MM broth o...,Bacterial growth in minimal medium (MM) and lu...,A,https://en.wikipedia.org/wiki/%CE%91r35%20RNA ...,Smaug is a RNA-binding protein in Drosophila t...,"of three miRNAs – miR-3, miR-6, miR-309 and mi..."
35894,35894,In which conditions was the expression of Smr3...,"Bacterial growth in TY, minimal medium (MM) an...","Bacterial growth in TY, minimal medium (MM) an...",Bacterial growth in TY and minimal medium (MM)...,Bacterial growth in TY and luteolin-MM broth o...,Bacterial growth in minimal medium (MM) and lu...,A,https://en.wikipedia.org/wiki/%CE%91r35%20RNA ...,Smaug is a RNA-binding protein in Drosophila t...,recruiting a protein called Cup (an eIF4E-bind...


In [13]:
dev_df = dev_df.fillna("")

In [None]:
sentence_embeddings = get_model_feature(model, dev_df['recall_sentence'].values)

 58%|█████▊    | 9050/15637 [5:29:44<4:02:35,  2.21s/it]

In [None]:
dev_df.reset_index(drop=True, inplace=True)

In [None]:
pred_final = []
N_RECALLS = 3
prompt_length = len(prompt_values)
for idx in tqdm(range(prompt_length)):

    query_text = prompt_values[idx]
    query_embedding = topic_embedding_list[idx, :]
    sentence_embeddings_index = dev_df[dev_df['question'] == query_text].index
    cos_scores = util.cos_sim(query_embedding.cuda(), sentence_embeddings[sentence_embeddings_index].cuda())[0]
    top_k = min([N_RECALLS, len(corpus_embeddings)])
    top_results = torch.topk(cos_scores, k=top_k)
    #print(top_results)
    indics = top_results[1].cpu().numpy()
    
    pid = dev_df['recall_sentence'][sentence_embeddings_index[indics]]
    pred_final.append('<new_recall_wiki_sep>'.join(pid))
    

In [None]:
len(pred_final)

In [None]:
final_res['recall_info'] = pred_final

In [112]:
final_res.to_csv(f'{final_name}_with_retrive.csv', index=None)

['conditions are not verified? if we assume or , we get the following differential equation (it has the same form in both cases, we will use only the notation of the temporal soliton): This equation has soliton-like solutions. For the first order (N = 1): The plot of is shown in the picture on the right. For higher order solitons () we can use the following closed form expression: It is a soliton, in the sense that it propagates without changing its shape, but it is not made by a normal pulse; rather, it is a lack of energy in a continuous time beam. The intensity is constant, but for a short time during which it jumps to zero and back again, thus generating a "dark pulse"\'. Those',
 'In optics, the term soliton is used to refer to any optical field that does not change during propagation because of a delicate balance between nonlinear and linear effects in the medium. There are two main kinds of solitons: spatial solitons: the nonlinear effect can balance the diffraction. The electro