In [63]:
# !pip install transformers

In [32]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel
import torch

In [46]:
#Text
def get_texts(path='asr.csv'):
  """Load ASR transcripts, keeping only rows that have text.

  Args:
    path: CSV file to read. It must contain the columns ``clip_id`` and
      ``clean_text``; any other columns are ignored. Defaults to
      ``'asr.csv'`` so existing zero-argument calls keep working.

  Returns:
    A DataFrame with the ``clip_id`` and ``clean_text`` columns, without the
    rows whose ``clean_text`` is null. The original row index is preserved.
  """
  data = pd.read_csv(path, usecols=['clip_id', 'clean_text'])
  # Return an independent copy: callers add new columns to the result, and
  # doing that on a boolean-filtered slice triggers SettingWithCopyWarning.
  return data[data['clean_text'].notnull()].copy()

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, ignoring masked-out positions.

    Args:
        model_output: Indexable model result; element 0 is used as the
            per-token embeddings (assumed (batch, seq_len, hidden) — confirm
            against the model in use).
        attention_mask: Mask with one entry per token; positions with 0 do
            not contribute to the average.

    Returns:
        Tensor of mask-weighted sums over dimension 1 divided by the mask
        totals. The divisor is clamped to at least 1e-9 so an all-zero mask
        yields zeros instead of a division by zero.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * weights).sum(1)
    token_counts = weights.sum(1).clamp(min=1e-9)
    return masked_total / token_counts

In [64]:
#Load ASR with preprocessing
data = get_texts()
asr_texts = data['clean_text'].tolist()

#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/sbert_large_nlu_ru")
model = AutoModel.from_pretrained("sberbank-ai/sbert_large_nlu_ru")

# Tokenize the whole corpus once; every text is padded/truncated to a common
# length of at most 24 tokens, so batches can be taken by plain slicing.
encoded_input = tokenizer(asr_texts, padding=True, truncation=True, max_length=24, return_tensors='pt')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only; a no-op if the model is already in eval mode
batch_size = 32
asr_embeddings = []

with torch.no_grad():
  for i in tqdm(range(0, len(asr_texts), batch_size)):
    texts_batch = encoded_input["input_ids"][i : i + batch_size].to(device)
    masks_batch = encoded_input["attention_mask"][i : i + batch_size].to(device)
    type_batch = encoded_input['token_type_ids'][i : i + batch_size].to(device)

    # Pass the tensors by keyword. For BERT-style models the positional order
    # of forward() is (input_ids, attention_mask, token_type_ids), so passing
    # (ids, types, mask) positionally would use the token-type ids as the
    # attention mask and the attention mask as segment ids.
    model_output = model(
        input_ids=texts_batch,
        attention_mask=masks_batch,
        token_type_ids=type_batch,
    )
    asr_embeddings_batch = mean_pooling(model_output, masks_batch).cpu().numpy()
    asr_embeddings.append(asr_embeddings_batch)

# Stack the per-batch arrays into one array with a row per input text.
asr_embeddings = np.concatenate(asr_embeddings, axis=0)


  0%|          | 0/235 [00:00<?, ?it/s]

In [58]:
# Attach one embedding vector per row. assign() builds a new frame rather than
# writing into `data` in place, which avoids SettingWithCopyWarning when `data`
# is a filtered slice and keeps this cell safe to re-run. pandas raises a
# ValueError if the number of embeddings does not match the number of rows.
data = data.assign(sbert_mean_pooling_1024_embedding=list(asr_embeddings))

In [62]:
# Serialise each embedding as a full Python list before writing. A CSV cell
# holding a numpy array is stored as str(array), and numpy abbreviates arrays
# above its print threshold (1000 elements by default) with "...", so the
# saved vectors (presumably 1024-dimensional, per the column name) would be
# truncated and unrecoverable. The list form can be read back with
# json.loads / ast.literal_eval.
embeddings_out = data[['clip_id', "sbert_mean_pooling_1024_embedding"]].assign(
    sbert_mean_pooling_1024_embedding=[
        np.asarray(vec).tolist() for vec in data["sbert_mean_pooling_1024_embedding"]
    ]
)
embeddings_out.to_csv('clip_id_emb_pooling')