In [None]:
# Install the notebook's third-party dependencies into the running kernel's environment.
# NOTE(review): versions are unpinned, so a fresh install may not match the environment that
# produced the saved outputs — consider pinning (e.g. `transformers==X.Y.Z`) for reproducibility.
%pip install nltk transformers torch annoy seaborn matplotlib scikit-learn PyPDF2 plotly

In [1]:
import nltk
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import torch
from transformers import BertTokenizer, BertModel
from nltk.tokenize import word_tokenize, sent_tokenize
from multiprocessing import Pool
from tqdm import tqdm
from collections import defaultdict
import torch
from torch.nn import DataParallel
from paper_processing_for_embeddings import preprocess_and_read_sentences

# Fetch the NLTK 'punkt' and 'stopwords' data packages; packages that are
# already present are just reported as up to date.
for _nltk_resource in ('punkt', 'stopwords'):
    nltk.download(_nltk_resource)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/aayushgupta/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/aayushgupta/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [2]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

# With more than one GPU, wrap the model in DataParallel; in every case the
# (possibly wrapped) model is then moved onto `device`.
model = (DataParallel(model) if torch.cuda.device_count() > 1 else model).to(device)

In [3]:
def process_pdfs_in_parallel(folder_path, n, num_workers=8):
    """Run ``preprocess_and_read_sentences`` over up to ``n`` PDFs from ``folder_path``.

    Args:
        folder_path: Directory to scan for ``*.pdf`` regular files.
        n: Maximum number of PDFs to process; ``None`` processes all of them.
        num_workers: Upper bound on the process-pool size (``None`` means
            ``os.cpu_count()``, as with ``multiprocessing.Pool``).

    Returns:
        List of ``preprocess_and_read_sentences`` results, in the same (sorted
        path) order as the files were selected. Empty if no PDFs were selected.
    """
    # Sort so that "the first n papers" is the same set on every run and OS;
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # The isfile check matches count_pdfs_in_directory, so a directory that
    # happens to be named *.pdf is never handed to a worker.
    all_files = sorted(
        path
        for path in (os.path.join(folder_path, name) for name in os.listdir(folder_path) if name.endswith('.pdf'))
        if os.path.isfile(path)
    )
    processed_files = all_files[:n]
    if not processed_files:
        # Nothing to do: avoid spawning a pool (Pool(0) would also raise).
        return []

    # Don't start more worker processes than there are files to process.
    requested = num_workers if num_workers is not None else (os.cpu_count() or 1)
    pool_size = min(requested, len(processed_files))

    with Pool(pool_size) as pool:
        results = list(tqdm(pool.imap(preprocess_and_read_sentences, processed_files), total=len(processed_files)))

    return results

def count_pdfs_in_directory(directory_path):
    """Return how many regular files directly inside ``directory_path`` end with '.pdf'.

    The suffix match is case-sensitive and subdirectories are not counted,
    even if their names end with '.pdf'.
    """
    total = 0
    for entry_name in os.listdir(directory_path):
        candidate = os.path.join(directory_path, entry_name)
        if entry_name.endswith('.pdf') and os.path.isfile(candidate):
            total += 1
    return total

In [4]:
def embed_text_batch(text_list, batch_size=100):
    """Embed each text as the mean of its token vectors, ``batch_size`` texts per forward pass.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``. Padding
    positions are excluded from the mean via the attention mask, so a text's
    embedding no longer depends on how long the other texts in its batch are.

    Args:
        text_list: Sequence of strings to embed.
        batch_size: Number of texts per forward pass; must be >= 1.

    Returns:
        List of 1-D numpy arrays, one per input text, in input order.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (a negative value would
            otherwise make the loop silently do nothing).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings = []
    for start in range(0, len(text_list), batch_size):
        batch_texts = text_list[start:start + batch_size]
        inputs = tokenizer(batch_texts, padding='longest', return_tensors='pt', truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        hidden = outputs.last_hidden_state
        # attention_mask is 1 for real tokens and 0 for padding; add a trailing
        # axis so it broadcasts across the hidden dimension.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for a row with no real tokens.
        counts = mask.sum(dim=1).clamp(min=1)
        all_embeddings.extend((summed / counts).cpu().numpy())
    return all_embeddings


def get_word_embeddings(sentence, tokenizer, model):
    """Return a ``{token: embedding}`` map for the tokens of one sentence.

    The sentence is encoded once and run through ``model``. Each token is paired
    with the ``last_hidden_state`` row at its own position. Tokens the tokenizer
    added itself (e.g. [CLS]/[SEP]) are skipped. If a token occurs more than
    once, its last occurrence wins.

    Args:
        sentence: Text to embed. Empty or non-string input is reported and yields ``{}``.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: torch module whose output exposes ``last_hidden_state``.

    Returns:
        Dict mapping token strings to 1-D numpy arrays, or ``{}`` on invalid
        input or on any tokenizer/model error (the error is printed, not raised).
    """
    if not isinstance(sentence, str) or not sentence:
        print(f"Invalid sentence: {sentence}")
        return {}

    try:
        # Send the inputs to wherever the model's weights live. A CPU tensor fed
        # to a CUDA model raises, which previously made every sentence return {}.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device('cpu')

        # Truncate by tokens, not characters: slicing the string could cut a word
        # in half and emit a bogus word piece.
        encoded = tokenizer(sentence, return_tensors='pt', truncation=True, max_length=512,
                            return_special_tokens_mask=True)
        # 1 marks tokens the tokenizer inserted ([CLS], [SEP]); the model does not accept this key.
        added_by_tokenizer = encoded.pop('special_tokens_mask').squeeze(0).tolist()
        token_ids = encoded['input_ids'].squeeze(0).tolist()
        model_inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        with torch.no_grad():  # inference only; don't build an autograd graph
            outputs = model(**model_inputs)
        hidden = outputs.last_hidden_state.squeeze(0).detach().cpu().numpy()

        # Read the tokens back from the encoded ids so the row indices line up.
        # tokenizer.tokenize() omits the leading [CLS], which shifted every word
        # onto its neighbour's vector.
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        return {
            # .copy() so each stored vector doesn't keep the whole sentence matrix alive.
            token: hidden[position].copy()
            for position, (token, is_added) in enumerate(zip(tokens, added_by_tokenizer))
            if not is_added
        }
    except Exception as e:
        print(f"Error tokenizing sentence: {sentence}. Error: {e}")
        return {}


In [5]:
def create_word_embeddings(folder_path, num_papers=None, num_workers=8):
    """Build a token -> {'embedding', 'file'} lookup from the PDFs in ``folder_path``.

    Args:
        folder_path: Directory containing the PDFs to process.
        num_papers: How many PDFs to process; ``None`` means the number
            reported by ``count_pdfs_in_directory``.
        num_workers: Process-pool size passed to ``process_pdfs_in_parallel``.

    Returns:
        ``defaultdict(dict)`` keyed by token. Each value holds the ``'embedding'``
        and source ``'file'`` of the token's most recent occurrence; later
        sentences overwrite earlier ones.
    """
    if num_papers is None:
        num_papers = count_pdfs_in_directory(folder_path)
    papers = process_pdfs_in_parallel(folder_path, num_papers, num_workers)

    word_index = defaultdict(dict)
    # Each preprocessed paper is unpacked as (sentences, file_path).
    for paper_sentences, paper_path in tqdm(papers):
        for sent in paper_sentences:
            token_vectors = get_word_embeddings(sent, tokenizer, model)
            for token, vector in token_vectors.items():
                word_index[token].update(embedding=vector, file=paper_path)
    return word_index


In [7]:
# Embed two PDFs from the papers folder and save the resulting
# {token: {'embedding': ..., 'file': ...}} mapping to disk.
papers_path = './PapersDirectory/papers'
embeddings_dict = create_word_embeddings(papers_path, num_papers=2)

with open('new_word_embeddings.pkl', mode='wb') as pickle_file:
    pickle.dump(embeddings_dict, pickle_file)

100%|██████████| 2/2 [00:01<00:00,  1.89it/s]
100%|██████████| 2/2 [00:35<00:00, 17.84s/it]
