In [1]:
from transformers import BertTokenizer, BertModel
from pymongo import MongoClient, ASCENDING, DESCENDING, HASHED
from tqdm.notebook import tqdm
import numpy as np
import torch

def set_device():
    """Choose the torch device string and report the choice on stdout.

    Returns:
        str: "cuda" when ``torch.cuda.is_available()`` is true, else "cpu".
    """
    if torch.cuda.is_available():
        print("GPU is enabled in this notebook.")
        return "cuda"

    # No CUDA device is visible: tell the user how to enable one, then
    # fall back to the CPU.
    print("WARNING: For this notebook to perform best, "
          "if possible, in the menu under `Runtime` -> "
          "`Change runtime type.`  select `GPU` ")
    return "cpu"


# Resolve the compute device once; later cells read this module-level name.
device = set_device()

# Load the pretrained "bert-base-uncased" weights and move the model onto
# the selected device in a single step.
model = BertModel.from_pretrained("bert-base-uncased").to(device)

GPU is enabled in this notebook.


In [2]:
# MongoClient() is called with no arguments, so PyMongo's default
# host/port are used.
conn = MongoClient()

# Handles to the database and the two collections this notebook touches:
# `tweets_bert_tokens` is read from, `tweets_bert` is written to.
process_fibvid = conn['process_fibvid']
tweets_bert_tokens = process_fibvid['tweets_bert_tokens']
tweets_bert = process_fibvid['tweets_bert']

# Indexes on the output collection: ascending on the two orderable fields,
# hashed on the two id fields.
tweets_bert.create_index([('num_tokens', ASCENDING)])
tweets_bert.create_index([('created', ASCENDING)])
tweets_bert.create_index([('tweetId', HASHED)])
# Kept as the cell's final expression so the notebook still displays the
# name of the last index created.
tweets_bert.create_index([('userId', HASHED)])

'userId_hashed'

In [3]:
# Number of documents in the source collection (empty filter matches all);
# passed to tqdm below as the progress-bar total.
total = tweets_bert_tokens.count_documents({})


def process(to_process, size, model=model, col=tweets_bert):
    """Run one batch of equally sized tokenised tweets through ``model``.

    Parameters
    ----------
    to_process : list of dict
        Tweet documents. Each must provide ``input_ids``, ``attention_mask``
        and ``token_type_ids`` (each assignable to a row of length ``size``)
        plus the metadata fields ``num_tokens``, ``created``, ``userId`` and
        ``tweetId``.
    size : int
        Row length shared by every input array in the batch.
    model : callable, optional
        Invoked as ``model(**inputs)``; the result must expose ``.items()``
        yielding ``(name, tensor)`` pairs whose first axis indexes the batch.
    col : optional
        Unused; kept so existing call sites remain valid.

    Returns
    -------
    list of dict
        One document per input tweet, in input order, holding the four
        metadata fields plus every model output converted to nested Python
        lists. An empty ``to_process`` yields ``[]`` without calling the
        model.
    """
    if len(to_process) == 0:
        return []

    # Send the inputs to wherever the supplied model actually lives, so a
    # caller-provided model on another device does not hit a device mismatch.
    # Fall back to the notebook-level `device` for models that expose no
    # parameters.
    get_params = getattr(model, 'parameters', None)
    first_param = next(iter(get_params()), None) if callable(get_params) else None
    target_device = device if first_param is None else first_param.device

    batch_len = len(to_process)
    in_data = {}
    for key in ('input_ids', 'attention_mask', 'token_type_ids'):
        stacked = np.empty((batch_len, size), dtype=np.int32)
        for row, tweet in enumerate(to_process):
            stacked[row, :] = tweet[key]
        in_data[key] = torch.from_numpy(stacked).to(target_device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        post = {name: tensor.cpu().numpy()
                for name, tensor in model(**in_data).items()}

    to_save = []
    for row, tweet in enumerate(to_process):
        data = {'num_tokens': tweet['num_tokens'],
                'created': tweet['created'],
                'userId': tweet['userId'],
                'tweetId': tweet['tweetId']}
        # Distinct loop names here: the original reused `v` and clobbered
        # the tweet variable of the enclosing loop.
        for name, values in post.items():
            data[name] = values[row, ...].tolist()
        to_save.append(data)
    return to_save

to_process = []       # tweets of the current length waiting for a forward pass
BATCH_SIZE = 128      # a model batch is flushed once it holds this many tweets
DB_BATCH_SIZE = 500   # results are written once more than this many are buffered
current_size = -1     # num_tokens of the tweets currently in `to_process`
db_batch = []         # processed documents waiting for insert_many


with conn.start_session() as session:
    # Tweets arrive sorted by num_tokens (descending), so tweets of equal
    # length are adjacent and can be batched together.
    cursor = tweets_bert_tokens.find(no_cursor_timeout=True,
                                     session=session).sort([('num_tokens', DESCENDING)])
    try:
        for tweet in tqdm(cursor, total=total):
            # Flush the pending model batch when the token length changes or
            # the batch is full. On the first iteration `to_process` is empty
            # and process() returns [] without calling the model.
            length_changed = current_size != tweet['num_tokens']
            if length_changed or len(to_process) == BATCH_SIZE:
                db_batch.extend(process(to_process, current_size))
                to_process = []
                current_size = tweet['num_tokens']
            if len(db_batch) > DB_BATCH_SIZE:
                tweets_bert.insert_many(db_batch)
                db_batch = []
            to_process.append(tweet)
    finally:
        # A cursor opened with no_cursor_timeout=True is not closed by an
        # idle timeout, so close it explicitly even if the loop raised.
        cursor.close()


# Process and store whatever is left after the cursor is exhausted.
db_batch.extend(process(to_process, current_size))
if len(db_batch) > 0:
    tweets_bert.insert_many(db_batch)
    db_batch = []

  0%|          | 0/299118 [00:00<?, ?it/s]