In [1]:
import torch
import numpy as np
from transformers import BertTokenizer
from wikipedia2vec import Wikipedia2Vec
import csv
import logging
import warnings

In [2]:
from utils import *
from data import DataProcess
from model import Model, ModelConfig

In [3]:
# Notebook-level logger, named after the running module.
logger = logging.getLogger(__name__)
# Blanket-ignore every warning category for the rest of the session.
warnings.simplefilter('ignore')

In [4]:
# set_seed is provided by the star-imported project helpers; it is called with a
# fixed value before anything else runs (presumably seeding the RNGs — confirm in utils).
set_seed(42)
# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device.type)

In [5]:
# Name of the pretrained BERT checkpoint. It selects the tokenizer loaded here and is
# passed again to ModelConfig when the model is configured further down.
# NOTE(review): in the cells visible here `tokenizer` is only referenced by the
# commented-out encode_text calls.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

In [6]:
# Load the pretrained Wikipedia2Vec model from a local file.
# NOTE(review): in the cells visible here `wiki2vec` is only referenced by the
# commented-out build_entity_vector call, so this load may be skippable when the
# cached entity vectors are reused — confirm against the rest of the notebook.
model_file = '/data/suyinpei/wiki_vector.model'
wiki2vec = Wikipedia2Vec.load(model_file)

## Data processing

In [7]:
# Configuration: split/batch hyper-parameters and the location of every data file.
# All paths hang off `data_dir`, so relocating the data is a one-line change.
data_dir = '/data/suyinpei' # root directory that holds every file listed below
ratio = 0.8 # fraction of the data used for training (the remaining 0.2 is the validation split)
batch_size = 32 # batch size
en_pad_size = 12 # max number of entities per example
en_embd_dim = 100 # entity embedding dim
idf_file = data_dir + '/idf_bigram5.txt'
entity_frep_file = data_dir + '/entity_frep.tsv'
data_root = data_dir + '/all_data_1028.tsv' # data: docid, text, entities, label
text_id_root = data_dir + '/text_ids_1028.pt' # data_size * 512
labels_root = data_dir + '/labels_1028.pt' # data_size
entity_id_root = data_dir + '/entity_ids_1028.pt' # data_size * 12
entity_length_root = data_dir + '/entity_length_1028.pt' # data_size
entity_score_root = data_dir + '/entity_score_1028.pt' # data_size * 3
entity_vector_root = data_dir + '/entity_vectors_1028.pt' # en_vocab_size * 100

In [8]:
# Hand the processor the raw data path plus the five tensor-file paths from the
# configuration cell (the entity-vector path is passed separately when it is loaded).
processor = DataProcess(
    data_root,
    text_id_root,
    labels_root,
    entity_id_root,
    entity_length_root,
    entity_score_root,
)

In [9]:
# load_idf returns two values from idf_file — judging by the names, an IDF lookup and a
# fallback value for unknown terms (TODO confirm in DataProcess.load_idf). In the cells
# visible here they are only consumed by the commented-out build_entity_vector call.
idf_dict, unk_idf = processor.load_idf(idf_file)
# The entity score dict feeds build_entity_score further down.
entity_score_dict = processor.load_entity_score_dict(entity_frep_file)

Entity Score vocab size:  308750


In [10]:
# # run this when using new data
# all_input_ids, labels = processor.encode_text(tokenizer)

In [11]:
# Build the entity vocabulary in both directions (entity -> index, index -> entity).
# Per the original note it is kept for prediction, so this cell runs even when the
# cached tensors are reused.
entity_to_index, index_to_entity = processor.encode_entity()

All Entity number:  7744598
Entity vocab size:  1600870


In [12]:
# # run this when using new data
# all_input_ids, labels = processor.encode_text(tokenizer)
# build_entity_vector = processor.build_entity_vector(entity_to_index, index_to_entity, wiki2vec, idf_dict, unk_idf, en_embd_dim, entity_vector_root)
# all_entity_ids, all_entity_length = processor.build_entity_id(entity_to_index, index_to_entity, en_pad_size)

In [13]:
# Keep only the mean and std of the entity scores; the first value returned by
# build_entity_score is discarded.
# NOTE(review): `_` is also IPython's "last output" variable, so this assignment shadows it.
_, entity_score_mean, entity_score_std = processor.build_entity_score(entity_score_dict)

Entity score mean:  tensor([[  0.0493, -15.2043,   0.2277]])
Entity score std:  tensor([[  6.7894, 332.7149,   2.0418]])


In [14]:
# Load the precomputed entity embedding matrix stored at entity_vector_root
# (one 100-dim row per vocabulary entry, per the shape printed by this cell).
entity_vector = processor.load_entity_vector(entity_vector_root) # get pretrained entity_vector

Entity vector shape:  torch.Size([1600870, 100])


In [15]:
# `ratio` and `batch_size` from the configuration cell are passed to load_data, which
# returns the train and validation dataloaders (their sizes are printed by this cell).
train_dataloader, valid_dataloader = processor.load_data(ratio, batch_size) # build train/valid dataloader

Num of train_dataloader:  12715
Num of valid_dataloader:  3179


## Model

In [16]:
# Model configuration: BERT checkpoint name, pretrained entity vectors and their
# dimension, followed by the entity-encoder settings passed by keyword.
mconf = ModelConfig(
    model_name,
    entity_vector,
    en_embd_dim,
    en_hidden_size1=128,
    en_hidden_size2=128,
    en_score_dim=3,
    use_en_encoder=True,
)

In [17]:
# Instantiate the network from the configuration object `mconf`.
# NOTE(review): no `.to(device)` appears in this part of the notebook — check that the
# model is moved to `device` before training/inference.
model = Model(mconf)