In [146]:
import torch
import csv 
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import transformers
from transformers import BertModel, BertTokenizer
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [144]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [3]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

In [91]:
def prepare_data(filename):
    text_list = []
    label_list = []
    f = open(filename, 'r')
    reader = csv.reader(f, delimiter='\t')
    for line in reader:
        text_list.append(line[0])
        label_list.append(int(line[1]))
    return text_list, label_list

In [92]:
# Function to get token ids for a list of texts 
def encode_fn(text_list):
    all_input_ids = []    
    for text in text_list:
        input_ids = tokenizer.encode(
                        text,                      
                        add_special_tokens = True,  
                        max_length = 512,           
                        truncation=True,
                        pad_to_max_length = True,     
                        return_tensors = 'pt'       
                   )
        all_input_ids.append(input_ids)    
    all_input_ids = torch.cat(all_input_ids, dim=0)
    return all_input_ids

In [93]:
def split_data(ratio, batch_size, filename):
    
    text_list, label_list = prepare_data(filename)
    all_input_ids = encode_fn(text_list)
    labels = torch.tensor(label_list, dtype=torch.float)
    # Split data into train and validation
    dataset = TensorDataset(all_input_ids, labels)
    train_size = int(ratio * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

    # Create train and validation dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
    valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size, shuffle = False)
    
    return train_dataloader, valid_dataloader

In [94]:
ratio = 0.8
batch_size = 4
filename = "data/head.tsv"
train_dataloader, valid_dataloader = split_data(ratio, batch_size, filename)

In [95]:
print("Num of train_dataloader: ", len(train_dataloader))
print("Num of valid_dataloader: ", len(valid_dataloader))

20
5


In [125]:
class Bert(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.output_size = config.output_size
        self.bert = BertModel.from_pretrained(config.model_name)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.output_size)
        
    def configure_optimizers(self, train_config):
        optimizer = torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, input_ids, labels=None, token_type_ids=None, attention_mask=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,)
        # pooled_output: [batch_size, dim=768]
        x = self.dropout(pooled_output)
        y_pred = self.classifier(x).squeeze(-1)
        # y_pred: [batch_size, output_dim]
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(y_pred, labels)
            return y_pred, loss
        else:
            return y_pred 

In [126]:
class BertConfig:
    """ base GPT config, params common to all GPT versions """
    dropout_prob = 0.1

    def __init__(self, hidden_size, output_size, model_name, **kwargs):
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.model_name = model_name
        for k, v in kwargs.items():
            setattr(self, k, v)

In [128]:
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum() / len(correct)
    return acc

In [129]:
HIDDEN_SIZE = 768
OUTPUT_SIZE = 1

In [130]:
mconf = BertConfig(HIDDEN_SIZE, OUTPUT_SIZE, model_name)

In [131]:
model = Bert(mconf).to(device)

In [132]:
# model

In [133]:
for par in model.parameters():
    par.requires_grad = False
model.classifier.weight.requires_grad = True

In [134]:
print('{} : all params: {:4f}M'.format(model._get_name(), sum(p.numel() for p in model.parameters()) / 1000 / 1000))
print('{} : need grad params: {:4f}M'.format(model._get_name(), sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000 / 1000))

Bert : all params: 109.483009M
Bert : need grad params: 0.000768M


In [135]:
optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 4
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

In [136]:
# for epoch in range(epochs):
#     model.train()
#     total_loss, total_val_loss = 0, 0
#     total_eval_accuracy = 0
#     for it, (x,y) in enumerate(train_dataloader):
#         x = x.to(device)
#         y = y.to(device)
#         model.zero_grad()
#         y_pred, loss = model(x, y, token_type_ids=None, attention_mask=None)
#         total_loss += loss.item()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step() 
#         scheduler.step()
        
#     model.eval()
#     for it, (x,y) in enumerate(valid_dataloader):
#         x = x.to(device)
#         y = y.to(device)
#         with torch.no_grad():
#             y_pred, loss = model(x, y, token_type_ids=None, attention_mask=None)
                
#             total_val_loss += loss.item()
#             total_eval_accuracy += binary_accuracy(y_pred, y)
            
#     avg_train_loss = total_loss / len(train_dataloader)
#     avg_val_loss = total_val_loss / len(valid_dataloader)
#     avg_val_accuracy = total_eval_accuracy / len(valid_dataloader)
    
#     print(f'Train loss     : {avg_train_loss}')
#     print(f'Validation loss: {avg_val_loss}')
#     print(f'Accuracy: {avg_val_accuracy:.2f}')
#     print('\n')

In [137]:
class Trainer:

    def __init__(self, model, train_loader, test_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)
        
    def binary_accuracy(self, preds, y):
        rounded_preds = torch.round(torch.sigmoid(preds))
        correct = (rounded_preds == y).float()
        acc = correct.sum() / len(correct)
        return acc

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            loader = self.train_loader if is_train else self.test_loader
            
            losses = []
            all_y = []
            all_y_pred = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:
                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)
                # forward the model
                with torch.set_grad_enabled(is_train):
                    y_pred, loss = model(x, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())
                    step_score = self.binary_accuracy(y_pred, y)
                    all_y.extend(y)
                    all_y_pred.extend(y_pred)
                
                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. score {step_score:.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                all_y = torch.stack(all_y, dim=0)
                all_y_pred = torch.stack(all_y_pred, dim=0)
                test_score = self.binary_accuracy(all_y_pred, all_y)
                logger.info("test loss: %f", test_loss)
                logger.info("test score: %f", test_score)
                return test_loss

        self.tokens = 0 # counter used for learning rate decay
        best_loss = run_epoch('test')
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_loader is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_loader is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                self.save_checkpoint()

In [138]:
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = 'bert-model.pt'
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            print(k,v)
            setattr(self, k, v)

In [139]:
tconf = TrainerConfig(max_epochs=4, learning_rate=6e-4,lr_decay=True, 
                      warmup_tokens=512*20, final_tokens=2*batch_size*len(train_dataloader),
                      num_workers=1)

max_epochs 4
learning_rate 0.0006
lr_decay True
warmup_tokens 10240
final_tokens 160
num_workers 1


In [140]:
trainer = Trainer(model, train_dataloader, valid_dataloader, tconf)

In [147]:
trainer.train()

10/12/2020 14:32:31 - test loss: 0.759015
10/12/2020 14:32:31 - test score: 0.400000
epoch 1 iter 19: train loss 0.63812. score 0.75000. lr 4.687500e-06: 100%|██████████| 20/20 [00:01<00:00, 16.00it/s]
10/12/2020 14:32:33 - test loss: 0.816740
10/12/2020 14:32:33 - test score: 0.400000
epoch 2 iter 19: train loss 0.58460. score 0.50000. lr 9.375000e-06: 100%|██████████| 20/20 [00:01<00:00, 16.07it/s]
10/12/2020 14:32:35 - test loss: 0.816679
10/12/2020 14:32:35 - test score: 0.400000
epoch 3 iter 19: train loss 0.67329. score 0.50000. lr 1.406250e-05: 100%|██████████| 20/20 [00:01<00:00, 16.02it/s]
10/12/2020 14:32:36 - test loss: 0.811993
10/12/2020 14:32:36 - test score: 0.400000
epoch 4 iter 19: train loss 0.85680. score 0.25000. lr 1.875000e-05: 100%|██████████| 20/20 [00:01<00:00, 16.00it/s]
10/12/2020 14:32:38 - test loss: 0.809926
10/12/2020 14:32:38 - test score: 0.400000
