In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import transformers
from transformers import BertModel, BertTokenizer
from transformers import get_linear_schedule_with_warmup
from torchtext.vocab import Vectors
import collections
from collections import Counter
import csv 
import numpy as np
from tqdm import tqdm
import math
import warnings
warnings.filterwarnings('ignore')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [5]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

In [152]:
# vectors = Vectors(name='enwiki_100d.txt')

In [7]:
class DataProcess():
    
    def __init__(self, root, text_id_root, entity_id_root, labels_root, entity_vector_root):
        self.root = root
        self.text_id_root = text_id_root
        self.entity_id_root = entity_id_root
        self.labels_root = labels_root
        self.entity_vector_root = entity_vector_root
    
    def prepare_data(self):
        text_list = []
        entity_list= []
        label_list = []
        with open(self.root, 'r') as f:
            reader = csv.reader(f, delimiter='\t')
            for line in reader:
                text_list.append(line[0])
                entity_list.append(line[1].split('|'))
                label_list.append(int(line[2]))
        return text_list, entity_list, label_list

    # Function to get token ids for a list of texts 
    def encode_text(self):
        text_list, _, label_list = self.prepare_data()
        all_input_ids = []    
        num = 0
        for text in text_list:
            num += 1
            if num % 10000 == 0:
                print(num)
            input_ids = tokenizer.encode(
                            text,                      
                            add_special_tokens = True,             
                            truncation=True,
                            padding = 'max_length',     
                            return_tensors = 'pt'       
                       )
            all_input_ids.append(input_ids)    
        all_input_ids = torch.cat(all_input_ids, dim=0)
        labels = torch.tensor(label_list, dtype=torch.float)
        # Save tensor
        torch.save(all_input_ids, self.text_id_root)
        torch.save(labels,self.labels_root)
        print("Saved success!")
        return all_input_ids, labels
    
    def encode_entity(self, en_vocab_size, en_pad_size):
        _, entity_list, _= self.prepare_data()
        # get all entity
        entity_list_all = [en for entity in entity_list for en in entity]
        # build entity vocab
        entity_vocab = collections.OrderedDict(Counter(entity_list_all).most_common(en_vocab_size-2))
        entity_list_uniq = [entity for entity in entity_vocab.keys()]
        entity_to_index = {entity : i+2 for i, entity in enumerate(entity_list_uniq)}
        entity_to_index['<unk>'] = 0
        entity_to_index['<pad>'] = 1
        entity_to_index = collections.OrderedDict(sorted(entity_to_index.items(), key=lambda entity_to_index: entity_to_index[1]))
        index_to_entity = [entity for i, entity in enumerate(entity_to_index)]
        # build entity vector
        idx_to_vector=[]
        print("Entity vocab size: ", len(entity_to_index))
        for entity in entity_to_index.keys():
            if 0 == (vectors["ENTITY/"+entity.replace(' ','_')] == vectors["<unk>"]).sum():
                idx_to_vector.append(vectors["ENTITY/"+entity.replace(' ','_')])
            else:
                idx_to_vector.append(vectors[entity.lower().replace(' ','_')])
        entity_vector = torch.stack(idx_to_vector)
        torch.save(entity_vector, self.entity_vector_root)
        # build entity index
        all_entity_ids = []
        for entities in entity_list:
            entity_ids = [entity_to_index.get(entity, entity_to_index["<unk>"]) for entity in entities][:en_pad_size]
            for i in range(en_pad_size - len(entity_ids)):
                entity_ids.append(entity_to_index["<pad>"])
            all_entity_ids.append(entity_ids)
        all_entity_ids = torch.tensor(all_entity_ids)
        torch.save(all_entity_ids, self.entity_id_root)
        print("Saved success!")
        return all_entity_ids, entity_vector, entity_to_index, index_to_entity
        
    
    def load_data(self, ratio, batch_size):
        all_input_ids = torch.load(self.text_id_root)
        all_entity_ids = torch.load(self.entity_id_root)
        labels = torch.load(self.labels_root)
        # Split data into train and validation
        dataset = TensorDataset(all_input_ids, all_entity_ids, labels)
        train_size = int(ratio * len(dataset))
        valid_size = len(dataset) - train_size
        train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

        # Create train and validation dataloaders
        train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
        valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size, shuffle = False)

        return train_dataloader, valid_dataloader

In [13]:
ratio = 0.8
batch_size = 32
en_vocab_size = 200000
en_pad_size = 12
root = "data/all_data_1015.tsv"
text_id_root = "data/text_ids_1015.pt"
entity_id_root = "data/entity_ids.pt"
labels_root = "data/labels_1015.pt"
entity_vector_root = "data/entity_vectors.pt"

In [14]:
processor = DataProcess(root, text_id_root, entity_id_root, labels_root, entity_vector_root)

In [15]:
# all_input_ids, labels = processor.encode_text()

In [11]:
all_entity_ids, entity_vector, entity_to_index, index_to_entity = processor.encode_entity(en_vocab_size, en_pad_size)

Entity vocab size:  200000
Saved success!


In [16]:
train_dataloader, valid_dataloader = processor.load_data(ratio, batch_size)

In [17]:
print("Num of train_dataloader: ", len(train_dataloader))
print("Num of valid_dataloader: ", len(valid_dataloader))

Num of train_dataloader:  4248
Num of valid_dataloader:  1062


In [13]:
# en_vocab_size = 100
# en_pad_size = 12

In [14]:
# entity_list_all = [en for entity in entity_list for en in entity]
# entity_vocab = collections.OrderedDict(Counter(entity_list_all).most_common(en_vocab_size-2))
# # entity_vocab

In [15]:
# entity_list_uniq = [entity for entity in entity_vocab.keys()]
# entity_to_index = {entity : i+2 for i, entity in enumerate(entity_list_uniq)}
# entity_to_index['<unk>'] = 0
# entity_to_index['<pad>'] = 1
# entity_to_index = collections.OrderedDict(sorted(entity_to_index.items(), key=lambda entity_to_index: entity_to_index[1]))
# index_to_entity = [entity for i, entity in enumerate(entity_to_index)]

In [16]:
# idx_to_vector=[]
# for entity in entity_to_index.keys():
#     if 0 == (vectors["ENTITY/"+entity.replace(' ','_')] == vectors["<unk>"]).sum():
#         idx_to_vector.append(vectors["ENTITY/"+entity.replace(' ','_')])
#     else:
#         idx_to_vector.append(vectors[entity.lower().replace(' ','_')])

In [18]:
# entity_vector = torch.stack(idx_to_vector)

In [19]:
# entity_vector.shape

In [20]:
# all_entity_ids = []
# for entities in entity_list:
#     entity_ids = [entity_to_index.get(entity, entity_to_index["<unk>"]) for entity in entities][:en_pad_size]
#     for i in range(en_pad_size - len(entity_ids)):
#         entity_ids.append(entity_to_index["<pad>"])
#     all_entity_ids.append(entity_ids)
# all_entity_ids = torch.tensor(all_entity_ids)

In [21]:
# all_entity_ids.shape

In [180]:
class Model(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained(config.model_name)
        self.en_encoder = EntityEncoder(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.fc1 = nn.Linear(self.bert.pooler.dense.weight.shape[0]+self.en_encoder.mlp[2].weight.shape[0], config.fc_hidden_size)
        self.fc2 = nn.Linear(config.fc_hidden_size, config.output_size)
        
        
    def configure_optimizers(self, train_config):
#         param_optimizer = list(model.named_parameters())  
#         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
#         optimizer_grouped_parameters = [
#                 {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': train_config.weight_decay},
#                 {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
#         optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=train_config.learning_rate, betas=train_config.betas)
        optimizer = torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, input_ids, entity_ids=None, labels=None, token_type_ids=None, attention_mask=None):
        _, bert_output = self.bert(input_ids, token_type_ids, attention_mask,)
        en_encoder_output = self.en_encoder(entity_ids)
        x = torch.cat((bert_output, pooled_output),dim=1)
        # pooled_output: [batch_size, dim=768]
        x = self.dropout(x)
        y_pred = self.fc(x).squeeze(-1)
        # y_pred: [batch_size, output_dim]
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(y_pred, labels)
            return y_pred, loss
        else:
            return y_pred 

In [181]:
class EntityEncoder(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.en_embeddings = nn.Embedding(config.en_vocab_size, config.en_embd_dim, padding_idx=config.pad_index)

        self.ln1 = nn.LayerNorm(config.en_embd_dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.mlp = nn.Sequential(
            nn.Linear(config.en_embd_dim, config.en_hidden_size1),
            nn.GELU(),
            nn.Linear(config.en_hidden_size1, config.en_hidden_size2),
            nn.Dropout(config.dropout_prob),
        )
        self.avgpool = torch.nn.AvgPool1d(config.en_pad_size)
        self.ln2 = nn.LayerNorm(config.en_hidden_size2, eps=1e-12)

    def forward(self, input_ids):
        input_shape = input_ids.size()
        embeddings = self.en_embeddings(input_ids)
        
        x = self.ln1(embeddings)
        x = self.dropout(x)
        
        x = self.mlp(embeddings).transpose(1,2)
        x = self.avgpool(x).squeeze(-1)
        x = self.ln2(x)
        
        return x

In [182]:
class ModelConfig:
    """ base GPT config, params common to all GPT versions """
    dropout_prob = 0.1
    
    def __init__(self, output_size, model_name, en_vocab_size, en_embd_dim, en_hidden_size1, en_hidden_size2, en_pad_size, pad_index, **kwargs):
        self.output_size = output_size
        self.model_name = model_name
        self.en_vocab_size = en_vocab_size
        self.en_embd_dim = en_embd_dim
        self.en_hidden_size1 = en_hidden_size1
        self.en_hidden_size2 = en_hidden_size2
        self.en_pad_size = en_pad_size
        self.pad_index = pad_index
        for k, v in kwargs.items():
            setattr(self, k, v)

In [183]:
class Trainer:

    def __init__(self, model, train_loader, test_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)
        
    def binary_accuracy(self, preds, y):
        rounded_preds = torch.round(torch.sigmoid(preds))
        correct = (rounded_preds == y).float()
        acc = correct.sum() / len(correct)
        return acc

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            loader = self.train_loader if is_train else self.test_loader
            
            losses = []
            all_y = []
            all_y_pred = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (text_ids, entity_ids, y) in pbar:
                # place data on the correct device
                text_ids = text_ids.to(self.device)
                entity_ids = entity_ids.to(device)
                y = y.to(self.device)
                # forward the model
                with torch.set_grad_enabled(is_train):
                    y_pred, loss = model(text_ids, entity_ids, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())
                    step_score = self.binary_accuracy(y_pred, y)
                    all_y.extend(y)
                    all_y_pred.extend(y_pred)
                
                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. score {step_score:.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                all_y = torch.stack(all_y, dim=0)
                all_y_pred = torch.stack(all_y_pred, dim=0)
                test_score = self.binary_accuracy(all_y_pred, all_y)
                logger.info("test loss: %f", test_loss)
                logger.info("test score: %f", test_score)
                return test_loss

        self.tokens = 0 # counter used for learning rate decay
        best_loss = float('inf')
#         best_loss = run_epoch('test')
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_loader is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_loader is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                self.save_checkpoint()

In [184]:
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = 'local-likely-model.pt'
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            print(k,v)
            setattr(self, k, v)

In [185]:
output_size = 1
en_embd_dim = 100
en_hidden_size1 = 128
en_hidden_size2 = 256
pad_index = 0

In [186]:
mconf = ModelConfig(output_size, model_name, en_vocab_size, en_embd_dim, en_hidden_size1, 
                    en_hidden_size2, en_pad_size, pad_index,fc_hidden_size=512)

In [187]:
model = Model(mconf)

In [None]:
model

In [128]:
for par in model.bert.embeddings.parameters():
    par.requires_grad = False
for par in model.bert.encoder.layer[:11].parameters():
    par.requires_grad = False

In [129]:
model.en_encoder.en_embeddings.weight.data.copy_(torch.load(entity_vector_root))
model.en_encoder.en_embeddings.weight.requires_grad = False

In [130]:
print('{} : all params: {:4f}M'.format(model._get_name(), sum(p.numel() for p in model.parameters()) / 1000 / 1000))
print('{} : need grad params: {:4f}M'.format(model._get_name(), sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000 / 1000))

Model : all params: 129.529929M
Model : need grad params: 7.726153M


In [131]:
tconf = TrainerConfig(max_epochs=2, learning_rate=6e-4, lr_decay=True, 
                      warmup_tokens=32*200, final_tokens=2*batch_size*len(train_dataloader),
                      num_workers=1)

max_epochs 2
learning_rate 0.0006
lr_decay True
warmup_tokens 6400
final_tokens 271872
num_workers 1


In [132]:
trainer = Trainer(model, train_dataloader, valid_dataloader, tconf)

In [133]:
trainer.train()

epoch 1 iter 4247: train loss 0.08797. score 0.94444. lr 3.114076e-04: 100%|██████████| 4248/4248 [45:44<00:00,  1.55it/s]
10/19/2020 12:07:26 - test loss: 0.122334
10/19/2020 12:07:26 - test score: 0.956888
10/19/2020 12:07:26 - saving local-likely-model.pt
epoch 2 iter 4247: train loss 0.11281. score 0.94444. lr 6.000000e-05: 100%|██████████| 4248/4248 [45:38<00:00,  1.55it/s]
10/19/2020 13:02:30 - test loss: 0.107255
10/19/2020 13:02:30 - test score: 0.963185
10/19/2020 13:02:30 - saving local-likely-model.pt


In [134]:
class Predict:
    
    def __init__(self, model):
        self.model = model.to(device)
    
    def predict(self, text, entities):
        input_ids = tokenizer.encode(
                        text,                      
                        add_special_tokens = True,             
                        truncation=True,
                        padding = 'max_length',     
                        return_tensors = 'pt'       
                   ).to(device)
        
        entity_ids = [entity_to_index.get(entity, entity_to_index["<unk>"]) for entity in entities][:en_pad_size]
        for i in range(en_pad_size - len(entity_ids)):
            entity_ids.append(entity_to_index["<pad>"])
        entity_ids = torch.tensor(entity_ids).unsqueeze(0).to(device)
        self.model.eval()
        pred = torch.sigmoid(self.model(input_ids, entity_ids)[0])
        return pred.item()
    
    def count_acc(self, text_list, local):
        result = []
        for text in text_list:
            result.append(self.predict(text))
        result = torch.tensor(result, dtype = torch.float)
        if local:
            acc = sum(result > 0.5).item()/len(result)
        else:
            acc = sum(result < 0.5).item()/len(result)
        return result, acc
        

In [135]:
predict = Predict(model)

In [136]:
test_text_list = []
test_entity_list = []
with open('data/test_data_1k.tsv') as f:
    reader= csv.reader(f, delimiter='\t')
    for line in reader:
        test_text_list.append(line[0])
        test_entity_list.append(line[1].split('|'))

In [137]:
# test_text_list[:2]

In [139]:
# model_predict = []
# for text, entities in zip(test_text_list,test_entity_list):
#     prob = predict.predict(text,entities)
#     model_predict.append(prob)

In [142]:
print(len(model_predict))
fout = open('model-predict.tsv','w')
for prob in model_predict:
    fout.write('{}\n'.format(prob))

999
