In [1]:
import sys
sys.path.append("/work01/home/wxxie/project/drug-gen/mollvae/MolLVAE/code")
from dataset import DatasetSplit
from opt import get_parser
from model.model import LVAE





from moses.utils import CircularBuffer, Logger

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import _LRScheduler

from tqdm import tqdm
import numpy as np
import random
import math

############ utils func and class

def get_trainable_params(model):
    
    return (p for p in model.parameters() if p.requires_grad)

def get_lr_annealer(optimizer, config):
    
    if config.lr_anr_type == "SGDR":
        return CosineAnnealingLRWithRestart(optimizer, config)
    
def get_kl_annealer(n_epoch, kl_anr_type):
    
    if kl_anr_type == "const":
        return KLAnnealer(n_epoch, config.kl_e_start, config.kl_w_start, config.kl_w_start)
    elif kl_anr_type == "linear_inc":
        return KLAnnealer(n_epoch, config.kl_e_start, config.kl_w_start, config.kl_w_end)
    
def get_n_epoch(lr_annealer_type):
    
    if lr_annealer_type == "SGDR":
        return sum(config.lr_period * (config.lr_mult_coeff ** i)
            for i in range(config.lr_n_restarts))
    elif lr_annealer_type == "const":
        return config.n_epoch

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    
    
class CosineAnnealingLRWithRestart(_LRScheduler):
    def __init__(self, optimizer, config):
        self.n_period = config.lr_period
        self.n_mult = config.lr_mult_coeff
        self.lr_end = config.lr_end

        self.current_epoch = 0
        self.t_end = self.n_period

        # Also calls first epoch
        super().__init__(optimizer, -1)

    def get_lr(self):
        return [self.lr_end + (base_lr - self.lr_end) *
                (1 + math.cos(math.pi * self.current_epoch / self.t_end)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        if self.current_epoch == self.t_end:
            self.current_epoch = 0
            self.t_end = self.n_mult * self.t_end
    
    
class KLAnnealer:
    """
    Control KL loss weight to increase linearly
    Adapted from `moses`
    """
    
    def __init__(self, n_epoch, kl_e_start, kl_w_start, kl_w_end):
        self.i_start = kl_e_start
        self.w_start = kl_w_start
        self.w_max = kl_w_end
        self.n_epoch = n_epoch

        self.inc = (self.w_max - self.w_start) / (self.n_epoch - self.i_start)

    def __call__(self, i):
        k = (i - self.i_start) if i >= self.i_start else 0
        return self.w_start + k * self.inc
    

def train_epoch(model, epoch, data, kl_weight, optimizer=None):
    
    if optimizer is None:
        model.eval()
    else:
        model.train()
        
    kl_loss_values = CircularBuffer(config.loss_buf_sz)
    recon_loss_values = CircularBuffer(config.loss_buf_sz)
    loss_values = CircularBuffer(config.loss_buf_sz)    
    for input_batch in data:
        
        input_batch = (input_batch[0].to(model.device()), input_batch[1])
        
        ## forward
        kl_loss, recon_loss = model(input_batch)
        loss = kl_weight * kl_loss + recon_loss
        
        ## backward
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(get_trainable_params(model), config.clip_grad)
            optimizer.step()
            
        ## log with buffer: average losses of the last m batches
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())
        lr = (optimizer.param_groups[0]['lr']
                  if optimizer is not None
                  else 0)
        ## print out progress
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        postfix = [f'loss={loss_value:.5f}',
                   f'(kl={kl_loss_value:.5f}',
                   f'recon={recon_loss_value:.5f})',
                   f'klw={kl_weight:.5f} lr={lr:.5f}']
        

        data.set_postfix_str(' '.join(postfix))
        
        if str(loss_value) == "nan":
            print(input_batch)
            input_batch = (input_batch[0].cpu(), input_batch[1])
            import pickle
            with open("../tmp/test_batch.bsz-512.pkl","wb") as fo:
                pickle.dump(input_batch, fo)
            break
                
            
    ## return results for this epoch (for tensorboard)
    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Eval' if optimizer is None else 'Train'}

    return postfix
            
    
    

def train(model, config, train_dataloader, valid_dataloader=None, logger=None):
    
    device = model.device
    
    ## get optimizer and annealer
    n_epoch = get_n_epoch(config.lr_anr_type)
    optimizer = torch.optim.Adam(get_trainable_params(model),
                                 lr=config.lr_start)
    lr_annealer = get_lr_annealer(optimizer, config) #! [to be done] mechanism
    kl_annealer = get_kl_annealer(n_epoch, config.kl_anr_type)
    
    ## iterative training
    model.zero_grad()
    for epoch in range(n_epoch):
        
        
        #tqdm_data = train_dataloader
        tqdm_data = tqdm(train_dataloader,
                         desc='Training (epoch #{})'.format(epoch))
        
        ## training
        kl_weight = kl_annealer(epoch)
        postfix = train_epoch(model, epoch, tqdm_data, kl_weight, optimizer)
        
        if logger is not None:
            logger.append(postfix)
            logger.save(config.log_path)
            
        #? debug
        break
        
        ## validation
        if valid_dataloader is not None:
            
            #? debug
            #tqdm_data = tqdm(valid_dataloader,
            #             desc='Validation (epoch #{})'.format(epoch))
            tqdm_data = valid_dataloader
            train_epoch(model, epoch, tqdm_data, kl_weight)
        
        ## save model
        if (config.model_save is not None) and \
            (epoch % config.save_frequency == 0):
            model = model.to("cpu")
            torch.save(model.state_dict(),
                       config.model_save[:-3] + "_{:03d}.pt".format(epoch))
            model = model.to(device)
        
        lr_annealer.step()

        
        









In [2]:
############ config

parser = get_parser()
config = parser.parse_args("--device cuda:0 --enc_bidirectional True --n_epoch 2 --train_bsz 512 --kl_e_start=1".split())

device = torch.device(config.device)

set_seed(config.seed)

config

Namespace(clip_grad=50, dec_hid_sz=256, dec_n_layer=1, dec_type='lstm', device='cuda:0', dropout=0.1, emb_sz=128, enc_bidirectional=True, enc_hidden_size=256, enc_num_layers=1, enc_sorted_seq=True, enc_type='lstm', kl_anr_type='const', kl_e_start=1, kl_w_end=1.0, kl_w_start=1.0, ladder_d_size=[512, 256, 128, 64, 32], ladder_z2z_layer_size=[8, 16, 32, 64], ladder_z_size=[64, 32, 16, 8, 4], log_path=None, loss_buf_sz=1000, lr_anr_type='const', lr_end=0.00030000000000000003, lr_mult_coeff=1, lr_n_restarts=10, lr_period=10, lr_start=0.00030000000000000003, model_save=None, n_epoch=2, save_frequency=10, seed=56, test_load='../data/valid.csv', train_bsz=512, train_load='../data/train.csv', valid_load='../data/valid.csv')

In [3]:
############ load training data

print("Loading training set...")
train_split = DatasetSplit("train", config.train_load)
train_dataloader = train_split.get_dataloader(batch_size=config.train_bsz)

if config.valid_load is not None:
    print("Loading validation set...")
    valid_split = DatasetSplit("valid", config.valid_load)
    valid_dataloader = valid_split.get_dataloader(batch_size=config.train_bsz, shuffle=False)
    
vocab = train_split._vocab




Loading training set...
Loading vocab...
Loading validation set...
Loading vocab...


In [4]:
############ get model and train

print("Initializing model...")
model = LVAE(vocab, config).to(device)


## log training process to csv file
logger = Logger() if config.log_path is not None else None

print("Start training...")
train(model, config, train_dataloader, valid_dataloader, logger)

Initializing model...


Training (epoch #0):   0%|          | 0/1741 [00:00<?, ?it/s]

Start training...


Training (epoch #0):  19%|█▉        | 329/1741 [00:32<02:16, 10.32it/s, loss=nan (kl=nan recon=nan) klw=1.00000 lr=0.00030]             

(tensor([[33, 16, 26,  ..., 26,  5, 34],
        [33, 16, 16,  ...,  5, 34, 35],
        [33, 16, 26,  ..., 34, 35, 35],
        ...,
        [33, 16, 16,  ..., 35, 35, 35],
        [33, 20, 16,  ..., 35, 35, 35],
        [33, 16, 16,  ..., 35, 35, 35]], device='cuda:0'), [69, 68, 67, 67, 67, 65, 65, 65, 64, 64, 64, 64, 64, 63, 62, 62, 62, 62, 61, 61, 61, 61, 61, 60, 60, 60, 60, 59, 59, 59, 59, 59, 58, 58, 58, 57, 57, 57, 57, 57, 57, 57, 57, 56, 56, 56, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47

In [5]:
## test nan batch
import pickle
with open("test_batch.bsz-512.pkl","rb") as fi:
    test_batch=pickle.load(fi)

In [7]:
## trained model
model((test_batch[0].to(device), test_batch[1]))

(tensor(nan, device='cuda:0', grad_fn=<AddBackward0>),
 tensor(nan, device='cuda:0', grad_fn=<NllLossBackward>))

In [8]:
## new model on cpu
device2 = torch.device("cpu")
model2 = LVAE(vocab, config).to(device2)
model2((test_batch[0].to(device2), test_batch[1]))

(tensor(12.5467, grad_fn=<AddBackward0>),
 tensor(3.6465, grad_fn=<NllLossBackward>))

In [9]:
## new model on gpu
model3 = LVAE(vocab, config).to(device)
model3((test_batch[0].to(device), test_batch[1]))

(tensor(12.6349, device='cuda:0', grad_fn=<AddBackward0>),
 tensor(3.5910, device='cuda:0', grad_fn=<NllLossBackward>))