In [None]:
import yaml
import torch
import torch.nn as nn
import numpy as np
import tqdm
import os
import argparse
from torch.backends import cudnn

# Let cuDNN benchmark candidate algorithms and keep the fastest one. This speeds
# up runs with fixed input shapes but can make results non-deterministic even
# though a random seed is set later in the notebook.
cudnn.benchmark = True

def delete_module(name):
    """Drop ``name`` from ``sys.modules`` so the next ``import`` re-executes it.

    Used to pick up edits to a local module (e.g. ``scorenet``) without
    restarting the kernel. A module that has not been imported yet is ignored
    instead of raising ``KeyError``, so the helper is safe on a fresh kernel.

    Note: names already bound from the old module (e.g. a class pulled in via
    ``from scorenet import Scorenet``) keep pointing at the old code until they
    are imported again.
    """
    import sys
    sys.modules.pop(name, None)

In [None]:
# delete_module('scorenet')
# from scorenet import Scorenet

In [None]:
# delete_module('critic')
# from critic import Criticnet

In [None]:
# Run on the GPU when CUDA is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# YAML experiment config, resolved relative to the kernel's working directory.
config_file = './test_config.yaml'

def dict2namespace(config):
    """Recursively turn a mapping into nested ``argparse.Namespace`` objects.

    Nested dicts become nested namespaces so settings read as attributes
    (``cfg.trainer.epochs``). Any non-dict value, including lists, is stored
    unchanged. Keys keep their insertion order.
    """
    return argparse.Namespace(**{
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    })

# Parse the experiment config. safe_load is used because yaml.load() without an
# explicit Loader raises TypeError on PyYAML >= 6 and, on older releases, can
# construct arbitrary Python objects from tagged YAML.
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)

# dict2namespace expects a mapping; an empty file (None) or a top-level list
# would otherwise fail with an unhelpful AttributeError.
if not isinstance(config, dict):
    raise ValueError(
        "%s must contain a YAML mapping at the top level, got %s"
        % (config_file, type(config).__name__)
    )

cfg = dict2namespace(config)
cfg

In [None]:
# `data` section: passed to get_data() when the training points are loaded.
cfg.data

In [None]:
# `trainer` section: seed, noise levels, optimizer settings and epoch count
# read by the training cells.
cfg.trainer

In [None]:
# `models` section. NOTE(review): not passed to Scorenet()/Criticnet() in the
# training cell, which construct both networks with no arguments.
cfg.models

In [None]:
# `inference` section (not referenced by the training cell).
cfg.inference

In [None]:
# `viz` section (not referenced by the training cell).
cfg.viz

# Train: alternate critic updates (maximize the Stein objective) with score-network updates (minimize it)

In [None]:
from scorenet import Scorenet
from critic import Criticnet
from utils import get_opt, approx_jacobian_trace, set_random_seed
from data_loader import get_data
import random

In [None]:
# Seed before loading data so runs are repeatable. NOTE(review): exactly which
# RNGs set_random_seed covers is defined in utils — confirm.
DEFAULT_SEED = 666
set_random_seed(getattr(cfg.trainer, "seed", DEFAULT_SEED))

# Load data. NOTE(review): the meaning of get_data's second argument (0) is
# defined in data_loader — confirm.
train_data = get_data(cfg.data, 0)

# Add a leading batch dimension of 1, move to the training device, and enable
# gradients. Presumably this is needed because the training loop differentiates
# the critic output w.r.t. points derived from these tensors
# (approx_jacobian_trace). Assigning the results also keeps the cell from
# dumping the full tensor as its output.
tr_pts = train_data['tr_points'].unsqueeze(0).to(device).requires_grad_()
te_pts = train_data['te_points'].unsqueeze(0).to(device).requires_grad_()

In [None]:
# Noise levels used to perturb the training points. They come either from an
# explicit `trainer.sigmas` list or as a log-spaced sequence from sigma_begin
# to sigma_end with sigma_num entries.
if hasattr(cfg.trainer, "sigmas"):
    # Coerce to float like the branch below: PyYAML reads exponent-only
    # literals such as `1e-3` as strings, which would break the later
    # torch.tensor(np.array(sigmas)) conversion.
    sigmas = np.array([float(s) for s in cfg.trainer.sigmas])
else:
    sigma_begin = float(cfg.trainer.sigma_begin)
    sigma_end = float(cfg.trainer.sigma_end)
    num_classes = int(cfg.trainer.sigma_num)
    # Log-spaced, i.e. a constant ratio between consecutive noise levels.
    sigmas = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), num_classes))
print("Sigmas:", sigmas)

In [None]:
# Score network and critic; both are called with the perturbed points and the
# chosen noise level.
score_net = Scorenet()
critic_net = Criticnet()
critic_net.to(device)
score_net.to(device)

print(score_net)
print(critic_net)

# Optimizers. The critic uses its own `trainer.opt_criticnet` section when the
# config has one, and otherwise falls back to the score-net settings (the
# previous behavior).
opt_scorenet, scheduler_scorenet = get_opt(score_net.parameters(), cfg.trainer.opt_scorenet)
opt_criticnet, scheduler_criticnet = get_opt(
    critic_net.parameters(),
    getattr(cfg.trainer, "opt_criticnet", cfg.trainer.opt_scorenet),
)
# TODO(review): scheduler_scorenet / scheduler_criticnet are never stepped in
# this loop — confirm whether the learning rate is meant to stay constant.

# Alternating schedule: within each cycle of k_iters + e_iters iterations the
# critic takes k_iters steps, then the score network takes e_iters steps.
k_iters = 10
e_iters = 1
# Weight of the L2 penalty on the critic output inside the objective.
critic_l2_weight = 0.001

# Noise-level lookup table, built once on the same device as the data. The old
# per-iteration CPU tensor indexed with device-resident labels plus .cuda()
# failed on CPU-only machines and mixed devices on GPU.
sigmas_t = torch.tensor(np.array(sigmas), dtype=torch.float32, device=tr_pts.device)

# training
start_epoch = 0
print("Start epoch: %d End epoch: %d" % (start_epoch, cfg.trainer.epochs))
for epoch in range(start_epoch, cfg.trainer.epochs):
    score_net.train()
    critic_net.train()
    opt_scorenet.zero_grad()
    opt_criticnet.zero_grad()

    # Draw one noise level for this iteration and perturb every point with it.
    labels = torch.randint(0, len(sigmas), (1,), device=tr_pts.device)
    used_sigmas = sigmas_t[labels].view(1, 1)

    perturbed_points = tr_pts + torch.randn_like(tr_pts) * used_sigmas

    score_pred = score_net(perturbed_points, used_sigmas)
    critic_output = critic_net(perturbed_points, used_sigmas)

    # Stein term: <score, critic> plus a Jacobian-trace term of the critic
    # w.r.t. the perturbed points (estimator lives in utils.approx_jacobian_trace).
    t1 = (score_pred * critic_output).sum(-1)
    t2 = approx_jacobian_trace(critic_output, perturbed_points)
    stein = t1 + t2
    l2_penalty = (critic_output * critic_output).sum(-1).mean()
    loss = stein.mean() - critic_l2_weight * l2_penalty

    cycle_iter = epoch % (k_iters + e_iters)
    if cycle_iter < k_iters:
        # Critic step: ascend the penalized Stein objective.
        (-loss).backward()
        opt_criticnet.step()
    else:
        # Score step: descend it (the penalty depends only on the critic).
        loss.backward()
        opt_scorenet.step()

    cpu_loss = loss.detach().cpu().item()
    print("Epoch %d Loss %2.5f" % (epoch, cpu_loss))