In [1]:
import sys
sys.path.append("/work01/home/wxxie/project/drug-gen/mollvae/MolLVAE/code")
from dataset import DatasetSplit
from opt import get_parser
from model.model import LVAE

In [5]:
from moses.utils import CircularBuffer, Logger

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import _LRScheduler

from tqdm import tqdm
import numpy as np
import random
import math

############ utils func and class

def get_trainable_params(model):
    """Return a one-shot iterator over the parameters of `model` that require gradients.

    `model` only needs a `parameters()` method yielding objects that expose a
    `requires_grad` attribute. The result is lazy and can be consumed once.
    """
    return filter(lambda param: param.requires_grad, model.parameters())

def get_lr_annealer(optimizer, config):
    """Build the learning-rate scheduler selected by `config.lr_anr_type`.

    Parameters
    ----------
    optimizer : torch optimizer whose learning rate is to be scheduled.
    config : object exposing `lr_anr_type` (and, for "SGDR", the fields read
        by `CosineAnnealingLRWithRestart`).

    Returns
    -------
    A scheduler object with a `step()` method:
      * "SGDR"  -> `CosineAnnealingLRWithRestart(optimizer, config)`
      * "const" -> a `LambdaLR` whose multiplicative factor is always 1.0,
                   i.e. the learning rate stays at its initial value.

    Raises
    ------
    ValueError
        If `config.lr_anr_type` is neither "SGDR" nor "const".
    """
    annealer_type = config.lr_anr_type

    if annealer_type == "SGDR":
        return CosineAnnealingLRWithRestart(optimizer, config)

    if annealer_type == "const":
        # Previously this branch fell through and returned None, which made the
        # caller's `lr_annealer.step()` fail. A factor of 1.0 keeps every
        # group's lr unchanged while still providing the scheduler interface.
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda _epoch: 1.0)

    raise ValueError("Invalid lr annealer type")

    
def get_kl_annealer(n_epoch, kl_anr_type):
    """Create the KL-weight schedule for a run of `n_epoch` epochs.

    `kl_anr_type` selects the schedule:
      * "const"      -> weight stays at `config.kl_w_start`
      * "linear_inc" -> weight goes from `config.kl_w_start` towards
                        `config.kl_w_end`

    Note: the start epoch and the weights are read from the notebook-level
    `config` object, not from an argument.

    Raises ValueError for any other `kl_anr_type`.
    """
    # Reject unknown types before touching the configuration.
    if kl_anr_type not in ("const", "linear_inc"):
        raise ValueError("Invalid kl annealer type")

    first_epoch = config.kl_e_start
    first_weight = config.kl_w_start
    # A constant schedule is a linear one whose end weight equals its start.
    last_weight = first_weight if kl_anr_type == "const" else config.kl_w_end

    return KLAnnealer(n_epoch, first_epoch, first_weight, last_weight)
    
def get_n_epoch(lr_annealer_type):
    """Return the number of training epochs implied by the lr schedule type.

    * "const": the value of `config.n_epoch`.
    * "SGDR":  the summed length of all restart cycles, where cycle `i`
               lasts `config.lr_period * config.lr_mult_coeff ** i` epochs
               for `i` in `range(config.lr_n_restarts)`; the total is also
               printed.

    Values are read from the notebook-level `config` object.
    Raises ValueError for any other type.
    """
    if lr_annealer_type == "const":
        return config.n_epoch

    if lr_annealer_type != "SGDR":
        raise ValueError("Invalid lr annealer type")

    # Add up the length of every cosine cycle.
    total = 0
    for restart in range(config.lr_n_restarts):
        total += config.lr_period * (config.lr_mult_coeff ** restart)

    print(f"Using SGDR annealer. Will train {total} epoches.")
    return total

def set_seed(seed):
    """Seed Python's `random`, NumPy and torch (CPU and all CUDA devices) with
    `seed`, and switch cuDNN to deterministic mode with benchmarking off."""
    # Python / NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # torch generators
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
    
    
class CosineAnnealingLRWithRestart(_LRScheduler):
    """Cosine learning-rate decay with restarts, advanced once per `step()` call.

    Within a cycle of length `t_end` the learning rate of each param group is
        lr_end + (base_lr - lr_end) * (1 + cos(pi * current_epoch / t_end)) / 2
    When `current_epoch` reaches `t_end` the counter is reset to 0 and the next
    cycle length becomes `lr_mult_coeff * t_end`.

    `config` must provide `lr_period` (first cycle length), `lr_mult_coeff`
    (cycle length multiplier) and `lr_end` (lower bound of the cosine curve).
    """

    def __init__(self, optimizer, config):
        self.n_period = config.lr_period
        self.n_mult = config.lr_mult_coeff
        self.lr_end = config.lr_end

        # Position inside the running cycle, and the length of that cycle.
        self.current_epoch = 0
        self.t_end = self.n_period

        # The base-class constructor also triggers the first step
        # (see the original note "Also calls first epoch").
        super().__init__(optimizer, -1)

    def get_lr(self):
        """Return one learning rate per entry of `self.base_lrs` for the current position."""
        cos_term = 1 + math.cos(math.pi * self.current_epoch / self.t_end)
        rates = []
        for base_lr in self.base_lrs:
            rates.append(self.lr_end + (base_lr - self.lr_end) * cos_term / 2)
        return rates

    def step(self, epoch=None):
        """Advance one epoch, write the new rates into the optimizer, and
        start a new (possibly longer) cycle when the current one ends.

        `epoch` only overrides `last_epoch`; the position inside the cycle is
        always advanced by exactly one.
        """
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        self.current_epoch += 1

        new_rates = self.get_lr()
        for group, rate in zip(self.optimizer.param_groups, new_rates):
            group['lr'] = rate

        # End of cycle: restart and scale the next cycle's length.
        if self.current_epoch == self.t_end:
            self.current_epoch = 0
            self.t_end = self.n_mult * self.t_end
    
    
class KLAnnealer:
    """
    Control KL loss weight to increase linearly
    Adapted from `moses`

    The weight is `kl_w_start` for every epoch index before `kl_e_start`;
    from `kl_e_start` on it grows by a fixed increment per epoch, where the
    increment is `(kl_w_end - kl_w_start) / (n_epoch - kl_e_start)`.
    Passing `kl_w_end == kl_w_start` gives a constant weight.

    Parameters
    ----------
    n_epoch : total number of epochs of the run.
    kl_e_start : epoch index at which the weight starts to change.
    kl_w_start : initial weight.
    kl_w_end : weight the linear ramp is aimed at.
    """

    def __init__(self, n_epoch, kl_e_start, kl_w_start, kl_w_end):
        self.i_start = kl_e_start
        self.w_start = kl_w_start
        self.w_max = kl_w_end
        self.n_epoch = n_epoch

        ramp_length = self.n_epoch - self.i_start
        if ramp_length == 0:
            # The ramp would start at the very end of the run; there is no
            # epoch left to increase over, so keep the weight flat instead of
            # dividing by zero.
            self.inc = 0.0
        else:
            self.inc = (self.w_max - self.w_start) / ramp_length

    def __call__(self, i):
        """Return the KL weight for epoch index `i` (counted from 0)."""
        # Number of epochs elapsed since the ramp started (0 before it starts).
        k = (i - self.i_start) if i >= self.i_start else 0
        return self.w_start + k * self.inc
    

def train_epoch(model, epoch, data, kl_weight, optimizer=None):
    """Run one pass over `data`: training if `optimizer` is given, evaluation otherwise.

    Parameters
    ----------
    model : callable module; `model(batch)` must return `(kl_loss, recon_loss)`
        and `model.device()` must return the device to move inputs to.
    epoch : epoch index, only copied into the returned summary.
    data : iterable of batches that also provides `set_postfix_str`
        (e.g. a tqdm-wrapped dataloader). Each batch is indexed as
        `batch[0]` (moved to the model device) and `batch[1]` (passed as is).
    kl_weight : multiplier applied to the KL term of the loss.
    optimizer : optimizer to update the model with; `None` selects
        evaluation mode (no gradient tracking, no parameter update).

    Returns
    -------
    dict with keys 'epoch', 'kl_weight', 'lr', 'kl_loss', 'recon_loss',
    'loss' and 'mode' ('Train' or 'Eval'). The loss values are the buffer
    means after the last batch; they are NaN if `data` yielded no batch.

    Note: the buffer size and the gradient-clipping threshold are read from
    the notebook-level `config` object (`loss_buf_sz`, `clip_grad`).
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()

    kl_loss_values = CircularBuffer(config.loss_buf_sz)
    recon_loss_values = CircularBuffer(config.loss_buf_sz)
    loss_values = CircularBuffer(config.loss_buf_sz)

    # Nothing inside this function changes the learning rate, so read it once.
    # Defining these before the loop also keeps the summary below valid when
    # `data` is empty (previously that raised UnboundLocalError).
    lr = optimizer.param_groups[0]['lr'] if is_training else 0
    kl_loss_value = recon_loss_value = loss_value = float('nan')

    for input_batch in data:
        input_batch = (input_batch[0].to(model.device()), input_batch[1])

        ## forward
        # Skip autograd bookkeeping during evaluation: no backward pass follows,
        # so building the graph would only cost memory and time.
        with torch.set_grad_enabled(is_training):
            kl_loss, recon_loss = model(input_batch)
            loss = kl_weight * kl_loss + recon_loss

        ## backward
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(get_trainable_params(model), config.clip_grad)
            optimizer.step()

        ## log with buffer: average losses of the last m batches
        kl_loss_values.add(kl_loss.item())
        recon_loss_values.add(recon_loss.item())
        loss_values.add(loss.item())

        ## print out progress
        kl_loss_value = kl_loss_values.mean()
        recon_loss_value = recon_loss_values.mean()
        loss_value = loss_values.mean()
        postfix = [f'loss={loss_value:.5f}',
                   f'(kl={kl_loss_value:.5f}',
                   f'recon={recon_loss_value:.5f})',
                   f'klw={kl_weight:.5f} lr={lr:.5f}']
        data.set_postfix_str(' '.join(postfix))

    ## return results for this epoch (for tensorboard)
    postfix = {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_value,
        'recon_loss': recon_loss_value,
        'loss': loss_value,
        'mode': 'Train' if is_training else 'Eval'}

    return postfix
            
    
    

def train(model, config, train_dataloader, valid_dataloader=None, logger=None):
    """Train `model` for the number of epochs implied by the lr schedule.

    Per epoch: one training pass, an optional validation pass (same KL
    weight), optional logging of both summaries, an optional checkpoint every
    `config.save_frequency` epochs, and one learning-rate scheduler step.

    Parameters
    ----------
    model : module exposing `device()`, `state_dict()` and `to(...)`, and
        usable by `train_epoch`.
    config : run configuration; this function reads `lr_anr_type`,
        `lr_start`, `kl_anr_type`, `log_path`, `model_save` and
        `save_frequency`.
    train_dataloader : iterable of training batches.
    valid_dataloader : iterable of validation batches, or None to skip
        validation.
    logger : object with `append(record)` and `save(path)`, or None to
        disable logging.

    Note: `get_n_epoch`, `get_kl_annealer` and `train_epoch` read the
    notebook-level `config` object, so that object should be the same one
    passed here.

    Checkpoints are written to `<model_save without extension>_<epoch:03d>.pt`.
    """
    import os  # local import: only needed to derive checkpoint file names

    device = model.device()

    ## get optimizer and annealer
    n_epoch = get_n_epoch(config.lr_anr_type)
    optimizer = torch.optim.Adam(get_trainable_params(model),
                                 lr=config.lr_start)
    lr_annealer = get_lr_annealer(optimizer, config)
    kl_annealer = get_kl_annealer(n_epoch, config.kl_anr_type)

    def _record(epoch_summary):
        """Append one epoch summary to the logger and rewrite the log file."""
        if logger is not None:
            logger.append(epoch_summary)
            logger.save(config.log_path)

    # Strip the real extension instead of blindly cutting three characters,
    # so paths such as "model.pth" or "model" are not mangled.
    ckpt_stem = None
    if config.model_save is not None:
        ckpt_stem = os.path.splitext(config.model_save)[0]

    ## iterative training
    model.zero_grad()
    for epoch in range(n_epoch):

        kl_weight = kl_annealer(epoch)

        ## training
        tqdm_data = tqdm(train_dataloader,
                         desc='Training (epoch #{})'.format(epoch))
        _record(train_epoch(model, epoch, tqdm_data, kl_weight, optimizer))

        ## validation
        if valid_dataloader is not None:
            tqdm_data = tqdm(valid_dataloader,
                             desc='Validation (epoch #{})'.format(epoch))
            _record(train_epoch(model, epoch, tqdm_data, kl_weight))

        ## save model
        if ckpt_stem is not None and epoch % config.save_frequency == 0:
            # Save from CPU, then move the model back to its training device.
            model = model.to("cpu")
            torch.save(model.state_dict(),
                       ckpt_stem + "_{:03d}.pt".format(epoch))
            model = model.to(device)

        # `get_lr_annealer` may hand back None for a schedule type it does not
        # implement; treat that as "keep the learning rate unchanged".
        if lr_annealer is not None:
            lr_annealer.step()

In [6]:
############ config

# Arguments for this run, written the way they would be given on the command
# line; everything not listed here keeps the parser's default.
cli_args = "--device cuda:2 \
                           --kl_anr_type linear_inc \
                           --lr_anr_type SGDR \
                           --model_save test_train/model.pt \
                           --log_path test_train/log.csv".split()

parser = get_parser()
config = parser.parse_args(cli_args)

# Device the model is moved to further down.
device = torch.device(config.device)

# Seed all random number generators before any data or model is created.
set_seed(config.seed)

# Show the resolved configuration as the cell output.
config

Namespace(clip_grad=50, dec_hid_sz=256, dec_n_layer=1, dec_type='lstm', device='cuda:2', dropout=0.1, emb_sz=128, enc_bidirectional=True, enc_hidden_size=256, enc_num_layers=1, enc_sorted_seq=True, enc_type='lstm', kl_anr_type='linear_inc', kl_e_start=0, kl_w_end=10.0, kl_w_start=1.0, ladder_d_size=[128, 64, 32], ladder_z2z_layer_size=[8, 16], ladder_z_size=[16, 8, 4], log_path='test_train/log.csv', loss_buf_sz=20, lr_anr_type='SGDR', lr_end=1e-06, lr_mult_coeff=1, lr_n_restarts=10, lr_period=10, lr_start=0.00030000000000000003, model_save='test_train/model.pt', n_epoch=100, save_frequency=10, seed=56, test_load='../data/test.csv', train_bsz=512, train_load='../data/train.csv', valid_load='../data/valid.csv')

In [7]:



############ load training data

print("Loading training set...")
train_split = DatasetSplit("train", config.train_load)
train_dataloader = train_split.get_dataloader(batch_size=config.train_bsz)

# Default to "no validation" so the train(...) call below still works when
# no validation file is configured (the name was previously left undefined).
valid_dataloader = None
if config.valid_load is not None:
    print("Loading validation set...")
    valid_split = DatasetSplit("valid", config.valid_load)
    valid_dataloader = valid_split.get_dataloader(batch_size=config.train_bsz, shuffle=False)

# The model is built on the training split's vocabulary.
# NOTE(review): `_vocab` is a private attribute; prefer a public accessor if DatasetSplit has one.
vocab = train_split._vocab


############ get model and train

print("Initializing model...")
model = LVAE(vocab, config).to(device)


## log training process to csv file
logger = Logger() if config.log_path is not None else None

print("Start training...")
train(model, config, train_dataloader, valid_dataloader, logger)

Loading training set...
Loading vocab...
Loading validation set...
Loading vocab...


Training (epoch #0):   0%|          | 1/1741 [00:00<04:52,  5.94it/s, loss=6.60455 (kl=3.00358 recon=3.60097) klw=1.00000 lr=0.00029]

Initializing model...
Start training...
Using SGDR annealer. Will train 100 epoches.


Training (epoch #0): 100%|██████████| 1741/1741 [02:54<00:00,  9.99it/s, loss=0.84469 (kl=0.00003 recon=0.84466) klw=1.00000 lr=0.00029]
Validation (epoch #0): 100%|██████████| 218/218 [00:07<00:00, 29.86it/s, loss=0.84079 (kl=0.00014 recon=0.84065) klw=1.00000 lr=0.00000]
Training (epoch #1): 100%|██████████| 1741/1741 [02:51<00:00, 10.13it/s, loss=0.77236 (kl=0.00000 recon=0.77236) klw=1.09000 lr=0.00027]
Validation (epoch #1): 100%|██████████| 218/218 [00:06<00:00, 33.00it/s, loss=0.77141 (kl=0.00000 recon=0.77141) klw=1.09000 lr=0.00000]
Training (epoch #2): 100%|██████████| 1741/1741 [02:53<00:00, 10.04it/s, loss=0.74144 (kl=0.00000 recon=0.74144) klw=1.18000 lr=0.00024]
Validation (epoch #2): 100%|██████████| 218/218 [00:06<00:00, 35.58it/s, loss=0.74014 (kl=0.00000 recon=0.74014) klw=1.18000 lr=0.00000]
Training (epoch #3): 100%|██████████| 1741/1741 [02:54<00:00,  9.97it/s, loss=0.72358 (kl=0.00000 recon=0.72358) klw=1.27000 lr=0.00020]
Validation (epoch #3): 100%|██████████| 2

Validation (epoch #29): 100%|██████████| 218/218 [00:06<00:00, 34.01it/s, loss=0.64358 (kl=-0.00000 recon=0.64358) klw=3.61000 lr=0.00000]
Training (epoch #30): 100%|██████████| 1741/1741 [02:49<00:00, 10.28it/s, loss=0.64429 (kl=-0.00000 recon=0.64429) klw=3.70000 lr=0.00029]
Validation (epoch #30): 100%|██████████| 218/218 [00:07<00:00, 30.79it/s, loss=0.64700 (kl=-0.00000 recon=0.64700) klw=3.70000 lr=0.00000]
Training (epoch #31): 100%|██████████| 1741/1741 [02:55<00:00,  9.94it/s, loss=0.63984 (kl=0.00000 recon=0.63983) klw=3.79000 lr=0.00027] 
Validation (epoch #31): 100%|██████████| 218/218 [00:06<00:00, 34.40it/s, loss=0.64357 (kl=0.00000 recon=0.64357) klw=3.79000 lr=0.00000]
Training (epoch #32): 100%|██████████| 1741/1741 [02:53<00:00, 10.05it/s, loss=0.63779 (kl=-0.00000 recon=0.63779) klw=3.88000 lr=0.00024]
Validation (epoch #32): 100%|██████████| 218/218 [00:06<00:00, 34.80it/s, loss=0.64125 (kl=-0.00000 recon=0.64125) klw=3.88000 lr=0.00000]
Training (epoch #33): 100%|█

KeyboardInterrupt: 

In [None]:
## plot training process
## Read the log file written during training (the path `train` passes to
## `logger.save`) into a DataFrame.
## NOTE(review): assumes the logger writes CSV -- confirm against `Logger.save`.
## This cell only loads the data; no plotting code is present here.
import pandas as pd
log_path = config.log_path
df = pd.read_csv(log_path)


In [None]:
## Check whether the sequences of the batch that produced NaN look unusual:
## decode a handful of its token-id rows back to strings and print them.
nan_batch_ids = [
    [33, 21, 14, 16,  5, 20, 16,  1, 14, 20, 26,  6, 26, 26, 26, 26,  1, 24,
     20,  3, 25,  1, 14, 21,  2, 24, 21,  4, 25,  2, 26,  6,  2, 23, 16,  5,
     14, 16, 26,  5, 26, 26, 26, 26, 28,  5, 34],
    [33, 21, 14, 16,  1, 16, 20, 16,  1, 14, 21,  2, 26,  5, 26, 26, 26, 29,
      5,  2, 20, 20, 14, 16, 26,  5, 26, 26, 26,  1, 16,  1, 14, 21,  2, 21,
      2, 26, 26,  5, 34, 35, 35, 35, 35, 35, 35],
    [33, 16, 21, 26,  5, 26, 26, 26, 26,  1, 21, 16,  2, 26,  5, 16,  1, 14,
     21,  2, 16, 14, 16, 26,  5, 26, 28,  1, 16,  2, 26,  6, 26, 26, 26, 26,
     26,  5,  6, 34, 35, 35, 35, 35, 35, 35, 35],
    [33, 16, 16,  1, 16,  2, 16,  1, 14, 21,  2, 20,  5, 16, 16, 26,  6, 26,
     26,  1,  4, 26,  7, 26, 32, 26,  1, 20,  2, 28,  7,  2, 26, 26, 26,  6,
      5, 34, 35, 35, 35, 35, 35, 35, 35, 35, 35],
    [33, 21, 14, 16,  1, 20, 26,  5, 28, 28, 26,  1,  4, 26,  6, 26, 26, 28,
     26, 26,  6,  2, 29,  5,  2, 26,  5, 26, 26, 26, 26, 26,  5, 34, 35, 35,
     35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35],
]

for token_ids in nan_batch_ids:
    print(vocab.ids2string(token_ids))

In [None]:
## test nan batch
## Feed the saved batch through the trained model, a freshly initialised model
## on the same device, and a freshly initialised model on the CPU, printing
## what each forward pass returns.
import pickle

# NOTE(review): pickle.load can execute arbitrary code -- only open files you created yourself.
with open("test_batch.bsz-512.pkl", "rb") as fi:
    test_batch = pickle.load(fi)


def _forward_on(net, target_device):
    """Run `net` on the saved batch after moving its first element to `target_device`."""
    return net((test_batch[0].to(target_device), test_batch[1]))


## trained model
print(_forward_on(model, device))

## new model on gpu
model2 = LVAE(vocab, config).to(device)
print(_forward_on(model2, device))

## new model on cpu
device2 = torch.device("cpu")
model3 = LVAE(vocab, config).to(device2)
print(_forward_on(model3, device2))



In [None]:
## vocab
## Display the vocabulary's `i2c` attribute as the cell output.
## NOTE(review): presumably the index -> character table used by `ids2string`; confirm in the vocab class.
vocab.i2c

In [None]:
## Display the parsed run configuration again for reference.
config