In [1]:
import torch
import transformers
import datasets

# Print the installed torch / transformers / datasets versions so the environment of this run is recorded.
print(f"Running on torch {torch.__version__}v, transformers {transformers.__version__}v, datasets {datasets.__version__}")

Running on torch 1.9.0v, transformers 4.8.1v, datasets 1.8.0


In [2]:
import numpy as np
import pandas as pd
from transformers import logging, BertConfig, BertTokenizer, BertModel
from datasets import load_dataset

import os
import sys
import random
import warnings
import gc; gc.enable()
from tqdm.notebook import tqdm
from IPython.display import clear_output

# Run-wide configuration: the RNG seed and the device used for model inference
# (CUDA when torch reports it as available, otherwise CPU).
SEED = 1618
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Keep cell output quiet: transformers is set to error-level verbosity and
# every Python warning is filtered out.
logging.set_verbosity_error()
warnings.filterwarnings(action='ignore')

In [3]:
def set_seed(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global generator, and torch (including
    the `torch.cuda` seed), switches cuDNN to deterministic mode with
    benchmarking turned off, and stores the seed in the PYTHONHASHSEED
    environment variable.

    Args:
        seed: Integer seed applied to all generators. Defaults to 0.

    Returns:
        A fresh `np.random.RandomState` initialised with `seed`.
    """
    # Python-level randomness.
    # NOTE(review): PYTHONHASHSEED is presumably only read at interpreter
    # start-up, so setting it here likely affects subprocesses only -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    # NumPy: the global generator plus a dedicated RandomState for the caller.
    np.random.seed(seed)
    rng = np.random.RandomState(seed)

    # PyTorch generators and cuDNN behaviour.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    return rng

# Seed all generators once with the notebook-wide SEED and keep the NumPy RandomState that set_seed returns.
random_state = set_seed(SEED)

In [4]:
def load_split_datasets(prefix, data_dir='../data'):
    """Read the train/test/validation CSVs for one text source into a DatasetDict.

    The files are expected at '<data_dir>/<prefix>_train_df.csv',
    '<data_dir>/<prefix>_test_df.csv' and '<data_dir>/<prefix>_val_df.csv'.

    Args:
        prefix: File-name prefix identifying the source ('claim' or 'main_text').
        data_dir: Directory holding the CSV files.

    Returns:
        A `datasets.DatasetDict` with the keys 'train', 'test' and
        'validation', inserted in that order.
    """
    splits = {}
    # The validation file uses the short 'val' tag on disk but is exposed under
    # the 'validation' key.
    for split_key, file_tag in (('train', 'train'), ('test', 'test'), ('validation', 'val')):
        frame = pd.read_csv(f'{data_dir}/{prefix}_{file_tag}_df.csv', low_memory=False)
        splits[split_key] = datasets.Dataset.from_pandas(frame)
    return datasets.DatasetDict(splits)


claim_ds = load_split_datasets('claim')
main_text_ds = load_split_datasets('main_text')

# The intermediate DataFrames were local to the helper; collect them now and
# clear anything printed while loading.
gc.collect()
clear_output()

In [5]:
# Display the claim DatasetDict to check its splits, feature names and row counts.
claim_ds

DatasetDict({
    train: Dataset({
        features: ['sentence', 'claim_id', 'label'],
        num_rows: 9814
    })
    test: Dataset({
        features: ['sentence', 'claim_id', 'label'],
        num_rows: 1235
    })
    validation: Dataset({
        features: ['sentence', 'claim_id', 'label'],
        num_rows: 1217
    })
})

In [6]:
# Display the main-text DatasetDict to check its splits, feature names and row counts.
main_text_ds

DatasetDict({
    train: Dataset({
        features: ['claim_id', 'sent_id', 'sentence'],
        num_rows: 333010
    })
    test: Dataset({
        features: ['claim_id', 'sent_id', 'sentence'],
        num_rows: 42961
    })
    validation: Dataset({
        features: ['claim_id', 'sent_id', 'sentence'],
        num_rows: 41122
    })
})

In [7]:
# Checkpoint name shared by the tokenizer and the encoder.
cp = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = BertTokenizer.from_pretrained(cp)

# Load the encoder, move it to the selected device and put it in eval() mode.
model = BertModel.from_pretrained(cp).to(DEVICE).eval()

# Clear anything the loading calls printed in this cell.
clear_output()

In [8]:
# Fixed token length that every encoded sentence is truncated or padded to.
MAX_LENGTH = 256


def tokenize_and_encode(examples):
    """Encode the 'sentence' entries of a batch with the module-level tokenizer.

    Args:
        examples: A batch mapping that contains a 'sentence' entry.

    Returns:
        The result of `tokenizer.batch_encode_plus`, with truncation enabled
        and padding to MAX_LENGTH.
    """
    encode_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': MAX_LENGTH,
    }
    return tokenizer.batch_encode_plus(examples['sentence'], **encode_options)

def _cls_embeddings_for(encoded_split, batch_size):
    """Run the module-level `model` over one tokenized split and collect the first-position hidden states."""
    cls_embeddings = []
    dataloader = torch.utils.data.DataLoader(encoded_split, batch_size=batch_size)
    for batch in tqdm(dataloader):
        inputs = {
            name: batch[name].to(DEVICE)
            for name in ('attention_mask', 'input_ids', 'token_type_ids')
        }
        # Inference only: no autograd graph is built, so the output can be
        # converted to NumPy without an explicit detach().
        with torch.no_grad():
            output = model(**inputs)
        # Keep the hidden state at sequence position 0 for each example;
        # extend() appends one 1-D array per example.
        cls_embeddings.extend(output.last_hidden_state[:, 0, :].cpu().numpy())
    return cls_embeddings


def generate_embeddings_dataframe_for(ds, feature_name, batch_size=100, output_dir='../data', num_proc=2):
    """Embed every split of `ds` and write one CSV of embeddings per split.

    Each split is tokenized with `tokenize_and_encode`, passed through the
    module-level `model` on `DEVICE`, and the first-position hidden state of
    each sentence is stored in columns named '0', '1', ... alongside the
    identifier columns. The file is written to
    '<output_dir>/<split>_<feature_name>_embeddings.csv'.

    Relies on the module-level `model`, `DEVICE` and `tokenize_and_encode`.
    NOTE(review): `model` is expected to already be in eval mode -- this
    function does not change it.

    Args:
        ds: DatasetDict with a 'train' split; its column names are the ones
            removed from every split after tokenization. Each split must
            provide 'sentence' and 'claim_id', plus 'sent_id' unless
            `feature_name` is "claim".
        feature_name: Tag used in the output file name. "claim" stores only
            'claim_id'; any other value also stores 'sent_id'.
        batch_size: Number of sentences per forward pass.
        output_dir: Directory the CSV files are written to.
        num_proc: Number of processes used by `ds.map` for tokenization.

    Returns:
        The string 'Embeddings files generated'.
    """
    cols = ds["train"].column_names
    encoded_datasets = ds.map(tokenize_and_encode, batched=True, remove_columns=cols, num_proc=num_proc)
    clear_output()
    encoded_datasets.set_format("torch")

    for key in encoded_datasets:
        cls_embeddings = _cls_embeddings_for(encoded_datasets[key], batch_size)

        # An empty split yields no embeddings; fall back to the model's hidden
        # size so the CSV still gets the full set of column headers instead of
        # failing with an IndexError.
        width = len(cls_embeddings[0]) if cls_embeddings else model.config.hidden_size
        column_names = [str(i) for i in range(width)]

        embeddings_df = pd.DataFrame(data=cls_embeddings, columns=column_names)
        embeddings_df['claim_id'] = ds[key]['claim_id']
        if feature_name != "claim":
            embeddings_df['sent_id'] = ds[key]['sent_id']

        filename = f"{key}_{feature_name}_embeddings.csv"
        embeddings_df.to_csv(os.path.join(output_dir, filename), index=False)
    return 'Embeddings files generated'

In [9]:
# Embed the claim sentences of every split; writes '<split>_claim_embeddings.csv' files under ../data.
generate_embeddings_dataframe_for(claim_ds, 'claim')

HBox(children=(FloatProgress(value=0.0, max=99.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))




'Embeddings files generated'

In [10]:
# Embed the main-text sentences of every split; writes '<split>_main_text_embeddings.csv' files under ../data.
generate_embeddings_dataframe_for(main_text_ds, 'main_text')

HBox(children=(FloatProgress(value=0.0, max=3331.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=430.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=412.0), HTML(value='')))




'Embeddings files generated'