In [None]:
from transformers import BertModel, BertTokenizer, BertConfig
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [None]:
# Checkpoint name shared by the tokenizer and the encoder.
pretrained_Model = 'bert-base-uncased'  # BERT
# pretrained_Model = 'allenai/scibert_scivocab_uncased'

# Tokenizer and model are both built from the same checkpoint name; the
# tokenizer is asked to lower-case its input.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_Model,
    do_lower_case=True,
)
BERT_model = BertModel.from_pretrained(pretrained_Model)

In [None]:
class SentDataset(Dataset):
    """Text dataset with one document per line, tokenized on access.

    Args:
        tokenizer: Callable applied to a single line of text; its result must
            expose ``input_ids``, ``token_type_ids`` and ``attention_mask``
            attributes.
        max_length: Length every sequence is padded / truncated to.
        data_dir: Path of the UTF-8 text file to read (a file, despite the
            name).
    """

    def __init__(self, tokenizer, max_length, data_dir):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_dir = data_dir
        # The whole file is read eagerly; tokenization happens per item.
        self.data = self.read_file()

    def read_file(self):
        """Return the lines of ``self.data_dir`` (line endings kept)."""
        with open(self.data_dir, 'r', encoding='utf-8') as handle:  # 50000_WoS
            return handle.readlines()

    def __len__(self):
        """Number of lines, i.e. documents, in the file."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize line ``index`` and return three tensors.

        Returns:
            Tuple ``(input_ids, token_type_ids, attention_mask)``.
        """
        encoded = self.tokenizer(
            self.data[index],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        fields = ('input_ids', 'token_type_ids', 'attention_mask')
        return tuple(torch.tensor(getattr(encoded, name)) for name in fields)
        

In [None]:
# Corpus dataset: one document per line, padded / truncated to 512 tokens.
dataset = SentDataset(
    tokenizer=tokenizer,
    max_length=512,
    data_dir='../datasets/50000_WoS.txt',
)

In [None]:
# Smoke check: index the first document to confirm __getitem__ yields the
# three tensors. These names are rebound by the batch cells further down.
input_ids, token_type_ids, attention_mask = dataset[0]

In [None]:
# Sequential, unshuffled batches so that row i of the collected embeddings
# corresponds to line i of the corpus file; the last partial batch is kept.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Pull a single batch to check collation; this rebinds the three names from
# the single-sample check above to batched tensors.
input_ids, token_type_ids, attention_mask = next(iter(dataloader))

In [None]:
# Inspect the batch shape; with batch_size=256 and max_length=512 this is
# expected to be torch.Size([256, 512]) as long as the corpus has >= 256 lines.
input_ids.shape

In [None]:
# Encode every document with BERT and keep two fixed-size vectors per
# document: the model's pooler output and the attention-masked mean of the
# last hidden states. Both lists are re-created here so the cell can be re-run.
doc_vectors_cls = []
doc_vectors_meanpool = []

#BERT_model = torch.nn.DataParallel(BERT_model, device_ids=[0, 1, 2, 3])

# Inference only: make sure dropout is off, and feed each batch on whatever
# device the model currently lives on (still the CPU unless it was moved).
BERT_model.eval()
device = next(BERT_model.parameters()).device

# tqdm wraps the dataloader itself (not enumerate) so it knows the total and
# can show an ETA.
# A batch is 256, so the number of iterations is 50,000/256 ≈ 196
for batch in tqdm(dataloader):
    input_ids, token_type_ids, attention_mask = (t.to(device) for t in batch)
    with torch.no_grad():
        out = BERT_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Masked mean: zero the padding positions, sum over the sequence
        # axis, then divide by the number of real tokens per document.
        mask = attention_mask.unsqueeze(-1)
        token_counts = attention_mask.sum(dim=1, keepdim=True)
        mean_out = (out.last_hidden_state * mask).sum(1) / token_counts

    # Extending with a 2-D array appends one 1-D numpy vector per document;
    # .cpu() is a no-op on CPU and required if the model runs on a GPU.
    doc_vectors_cls.extend(out.pooler_output.cpu().numpy())
    doc_vectors_meanpool.extend(mean_out.cpu().numpy())


In [None]:
# Both lists should hold one vector per document; print the two counts on
# separate lines.
print(len(doc_vectors_cls), len(doc_vectors_meanpool), sep='\n')

In [None]:
# Disabled snippet kept as a bare string literal: it saves an array to
# 'sentence_vectors.npy', loads it back and prints it.
# NOTE(review): `sentence_features` is not defined in any cell visible here —
# presumably it should be one of the doc_vectors_* lists; confirm before
# re-enabling.
'''
np.save('sentence_vectors.npy',sentence_features)

import numpy as np
arr = np.load('sentence_vectors.npy')
print(arr)
'''

In [None]:
from collections import Counter

from sklearn.cluster import KMeans

# Number of clusters to fit.
# NOTE(review): presumably this should equal the number of classes in the
# reference label file — confirm per dataset.
classNumber = 10
# Fixed seed so that re-running the notebook reproduces the same clustering
# (and therefore the same scores).
RANDOM_SEED = 42

# Choose which document representation to cluster:
# doc_vectors_meanpool
# doc_vectors_cls
doc_vectors = doc_vectors_meanpool
# doc_vectors = doc_vectors_cls

# n_init is pinned because its default differs between scikit-learn versions;
# 10 restarts is the long-standing default.
kmean_model = KMeans(
    n_clusters=classNumber,
    n_init=10,
    random_state=RANDOM_SEED,
).fit(doc_vectors)
labels = kmean_model.labels_

# Cluster sizes: how many documents were assigned to each cluster id.
center_dict = Counter(labels)
center_dict

In [None]:
def get_ground_truth_label(label_path='../datasets/50000_WoS_WC.txt'):
    """Load the reference class ids used to score the clustering.

    Args:
        label_path: UTF-8 text file holding one integer class id per line.
            The file is expected to be aligned line-by-line with the corpus
            file that was embedded. Label files mentioned in this notebook:
            50000_WoS_WC.txt, 50000_MedLine_Label.txt, 56529_CV_Label.txt.
            The default is the WoS file because the corpus loaded in this
            notebook is 50000_WoS.txt; pass another path when the corpus is
            switched.

    Returns:
        List of ints, one per line of the file, in file order.

    Raises:
        ValueError: If a line (including a blank one) is not an integer.
    """
    with open(label_path, 'r', encoding='utf-8') as f:
        # int() ignores surrounding whitespace, so the trailing newline does
        # not need to be stripped first.
        return [int(line) for line in f]
ground_truth_label = get_ground_truth_label()
# The scores are only meaningful when there is exactly one reference label per
# embedded document (50000 for the 50000-document corpora), so check it here
# instead of relying on a manual len() inspection.
assert len(ground_truth_label) == len(dataset), (
    len(ground_truth_label), len(dataset))

In [None]:
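# Previously recorded scores per dataset — presumably in the order printed by
# the metrics cell (adjusted Rand index, Fowlkes-Mallows score, adjusted
# mutual information); confirm before citing.
# wos: 0.7733 0.7960 0.7896
# med: 0.4123 0.4718 0.5092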

In [None]:
from sklearn import metrics

# scikit-learn declares these metrics as (labels_true, labels_pred). The three
# used here are symmetric, but passing the reference labels first keeps the
# cell correct if an asymmetric metric (e.g. homogeneity) is added later.
# Printed in order: adjusted Rand index, Fowlkes-Mallows score, adjusted
# mutual information.
print(metrics.adjusted_rand_score(ground_truth_label, labels))
print(metrics.fowlkes_mallows_score(ground_truth_label, labels))
print(metrics.adjusted_mutual_info_score(ground_truth_label, labels))