In [None]:
import torch
import torch.nn as nn
import yaml
import torch.distributions as distributions
import torch.optim as optim
from critic import Criticnet, SmallMLP
from scorenet import Scorenet
import os
from datasets import toy_data
import numpy as np 
import matplotlib
from utils import keep_grad, approx_jacobian_trace, exact_jacobian_trace, \
    set_random_seed, get_logger, dict2namespace, get_opt, visualize_2d
import importlib
import argparse
import matplotlib.pyplot as plt


%load_ext autoreload
%autoreload 2

# Config

In [None]:
# Prefer the first GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

config_file = './test_config_toy_2d.yaml'  # resolved relative to the notebook's working directory

def dict2namespace(config):
    """Recursively convert a (nested) dict into an ``argparse.Namespace``.

    Every key becomes an attribute. Values that are themselves dicts are
    converted recursively; all other values are stored unchanged.
    NOTE(review): this redefinition shadows the ``dict2namespace`` imported
    from ``utils`` at the top of the notebook.
    """
    ns = argparse.Namespace()
    for attr_name, raw_value in config.items():
        converted = dict2namespace(raw_value) if isinstance(raw_value, dict) else raw_value
        setattr(ns, attr_name, converted)
    return ns

# Load the experiment config. `yaml.safe_load` is used because bare `yaml.load(f)`
# (no Loader) raises TypeError on PyYAML >= 6, and on older versions the default
# loader could construct arbitrary Python objects from the file.
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)

cfg = dict2namespace(config)
cfg

In [None]:
def sample_data(data, n_points):
    """Draw ``n_points`` samples from the named toy dataset as a float32 tensor.

    ``data`` is a dataset name understood by ``toy_data.inf_train_gen``
    (e.g. ``'pinwheel'``). The NumPy result is wrapped with ``torch.from_numpy``,
    cast to float32 and moved to the notebook-level ``device``.
    """
    samples_np = toy_data.inf_train_gen(data, n_points=n_points)
    return torch.from_numpy(samples_np).type(torch.float32).to(device)

# Training

In [None]:
score_net = SmallMLP(n_dims=2, n_out=2)
critic_net = SmallMLP(n_dims=2, n_out=2)

critic_net.to(device)
score_net.to(device)

# NOTE(review): the critic previously reused the score net's `opt_scorenet` settings, which
# looks like a copy-paste slip. Use a dedicated `trainer.opt_criticnet` section when the
# config provides one; otherwise keep the old behaviour so existing configs still work.
critic_opt_cfg = getattr(cfg.trainer, 'opt_criticnet', cfg.trainer.opt_scorenet)
opt_scorenet, scheduler_scorenet = get_opt(score_net.parameters(), cfg.trainer.opt_scorenet)
opt_criticnet, scheduler_criticnet = get_opt(critic_net.parameters(), critic_opt_cfg)
# NOTE(review): the schedulers returned by get_opt are never stepped in this loop.

k_iters = 5       # critic updates per cycle
e_iters = 1       # score-net updates per cycle
epochs = 40000    # total iterations; each draws a fresh batch
log_freq = 100    # print a progress line every `log_freq` iterations (and on the last one)
itr = 0

# Per-iteration Stein loss, stored as plain Python floats. Appending the `loss` tensor
# itself would keep every iteration's autograd graph (and device memory) alive.
losses = []

for epoch in range(epochs):
    tr_pts = sample_data('pinwheel', 4096).view(-1, 2)

    score_net.train()
    critic_net.train()
    opt_scorenet.zero_grad()
    opt_criticnet.zero_grad()

    # Presumably required so exact_jacobian_trace can differentiate w.r.t. the inputs.
    tr_pts.requires_grad_()

    # No noise is added; the alias is kept for the commented-out visualisation code.
    perturbed_points = tr_pts

    score_pred = score_net(perturbed_points)
    critic_output = critic_net(perturbed_points)

    # Per-sample Stein term: <score, critic> plus (presumably) the trace of the critic's
    # Jacobian w.r.t. the inputs, as the helper's name suggests.
    t1 = (score_pred * critic_output).sum(-1)
    t2 = exact_jacobian_trace(critic_output, perturbed_points)

    stein = t1 + t2
    l2_penalty = (critic_output * critic_output).sum(-1).mean()
    loss = stein.mean()

    cpu_loss = loss.item()
    cpu_t1 = t1.mean().item()
    cpu_t2 = t2.mean().item()
    losses.append(cpu_loss)

    cycle_iter = itr % (k_iters + e_iters)

    if cycle_iter < k_iters:
        # Critic step: maximise the Stein loss, minus an L2 penalty on the critic output.
        (-loss + l2_penalty).backward()
        opt_criticnet.step()
        log_message = "Epoch %d (critic), Loss=%2.5f t1=%2.5f t2=%2.5f" % (epoch, cpu_loss, cpu_t1, cpu_t2)
    else:
        # Score step: minimise the Stein loss.
        loss.backward()
        opt_scorenet.step()
        log_message = "Epoch %d (score), Loss=%2.5f t1=%2.5f t2=%2.5f" % (epoch, cpu_loss, cpu_t1, cpu_t2)

    # Printing all 40k iterations floods the notebook output; log periodically instead.
    if epoch % log_freq == 0 or epoch == epochs - 1:
        print(log_message)
    itr += 1

#     if itr % cfg.log.save_freq == 0:
#         score_net.cpu()

#         torch.save({
#             'args': args,
#             'state_dict': score_net.state_dict(),
#         }, os.path.join(cfg.log.save_dir, 'checkpt.pth'))

#         score_net.to(device)

#     if itr % cfg.log.viz_freq == 0:
#         plt.clf()

#         #pt_cl, _ = langevin_dynamics(score_net, sigmas, dim=2, eps=1e-4, num_steps=cfg.inference.num_steps)
#         x_final = langevin_dynamics_lsd(score_net, l=1., e=.01, num_points=2048, n_steps=10)

#         visualize_2d(x_final)

#         fig_filename = os.path.join(cfg.log.save_dir, 'figs', 'sample-{:04d}.png'.format(itr))
#         os.makedirs(os.path.dirname(fig_filename), exist_ok=True)
#         plt.savefig(fig_filename)


#         visualize_2d(perturbed_points)

#         fig_filename = os.path.join(cfg.log.save_dir, 'figs', 'perturbed-{:04d}.png'.format(itr))
#         os.makedirs(os.path.dirname(fig_filename), exist_ok=True)
#         plt.savefig(fig_filename)

#     itr += 1

# Plot loss curve

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# float() accepts both Python numbers and 0-dim tensors (including CUDA tensors that require
# grad), so this plots correctly whether `losses` holds floats or loss tensors.
ax.plot([float(v) for v in losses])
ax.set(title='Stein loss per training iteration', xlabel='iteration', ylabel='loss');

# Plot learned gradient field

In [None]:
import numpy as np

# Evaluate the trained score network on a 30x30 grid covering [-4, 4]^2.
X = np.linspace(-4, 4, 30)
Y = np.linspace(-4, 4, 30)
grid = np.stack(np.meshgrid(X, Y), axis=-1).reshape(1, -1, 2)   # shape (1, 900, 2)
grid_tensor = torch.tensor(grid).float().to(device)

# Inference only: switch to eval mode and skip building an autograd graph.
score_net.eval()
with torch.no_grad():
    grad_field = score_net(grid_tensor)
grad_field_np = grad_field.detach().cpu().numpy()
grad_field_np = grad_field_np[0]   # drop the leading axis; presumably (900, 2) for a 2-output net

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
q = ax.quiver(grid[..., 0], grid[..., 1], grad_field_np[..., 0], grad_field_np[..., 1])
# The key arrow's label must report the same length that is passed as U.
key_length = 0.01
ax.quiverkey(q, X=0.01, Y=0.01, U=key_length,
             label='Quiver key, length = {}'.format(key_length), labelpos='E')
# Use a distinct name: assigning to `sample_data` here would shadow the sample_data()
# helper and break a re-run of the training cell.
pinwheel_pts = toy_data.inf_train_gen('pinwheel', n_points=4096)
ax.scatter(pinwheel_pts[:, 0], pinwheel_pts[:, 1])
ax.set(title='Learned score field over pinwheel samples', xlabel='x', ylabel='y')
plt.show()

# Visualize sampling

In [None]:
def get_prior(num_points, inp_dim):
    """Uniform samples on [-4, 4)^inp_dim with shape (num_points, inp_dim), on the CPU."""
    unit = torch.rand(num_points, inp_dim)   # U[0, 1)
    centered = unit * 2. - 1.                # U[-1, 1)
    return centered * 4

def langevin_dynamics_lsd_test(f, l=1., e=.01, num_points=2048, n_steps=100, anneal=None, device=None):
    """Run Langevin-style updates driven by ``f``, starting from ``get_prior``.

    Each step applies ``x <- x + lr * f(x) + e * N(0, I)`` to 2-D points.

    Args:
        f: callable returning a drift (e.g. a score network's output) with the same
            shape as its input.
        l: step size; the starting step size when annealing.
        e: standard deviation of the injected noise. It is also the final step size
            for the 'lin' and 'log' schedules.
        num_points: number of particles.
        n_steps: number of update steps, for every schedule.
        anneal: ``'lin'`` (linear l -> e), ``'log'`` (geometric l -> e) or ``None``
            (constant ``l``).
        device: where to run. Defaults to CUDA when available, else CPU.

    Returns:
        ``(final_samples, x_k_list)``. ``final_samples`` is the detached last state.
        ``x_k_list`` holds ``n_steps + 1`` snapshots: the initial state plus one per step.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x_k = get_prior(num_points, 2).to(device)

    # Step-size schedule. Every branch yields exactly n_steps entries; np.logspace needs
    # num= explicitly, since it defaults to 50.
    if anneal == "lin":
        lrs = list(reversed(np.linspace(e, l, n_steps)))
    elif anneal == "log":
        lrs = np.logspace(np.log10(l), np.log10(e), num=n_steps)
    else:
        lrs = [l] * n_steps

    x_k_list = [x_k.clone()]
    for this_lr in lrs:
        # Detach the drift so the trajectory does not accumulate an autograd graph across
        # steps. f may still use autograd internally to compute its output.
        drift = f(x_k).detach()
        x_k = x_k + float(this_lr) * drift + torch.randn_like(x_k) * e
        x_k_list.append(x_k.clone())
    final_samples = x_k.detach()
    return final_samples, x_k_list

In [None]:
x_final, all_samples = langevin_dynamics_lsd_test(
    score_net, l=0.004, e=.05, num_points=2048, n_steps=150,
)
# Bring every trajectory snapshot to host memory as a NumPy array for plotting/animation.
all_samples = list(map(lambda snap: snap.detach().cpu().numpy(), all_samples))

In [None]:
# Show the final Langevin samples on a fresh 10x10-inch figure.
fig = plt.figure(figsize=[10, 10])
visualize_2d(x_final)

In [None]:
from matplotlib import animation, rc
rc('animation', html='html5')

def animate(i, data, scat):
    """FuncAnimation callback: show frame ``i`` of ``data`` on the scatter artist ``scat``.

    Uses the ``scat`` argument passed through ``fargs`` rather than a global, so it
    updates whichever artist the caller supplies. Returns a 1-tuple containing the
    updated artist (FuncAnimation uses the returned artists when blitting).
    """
    scat.set_offsets(data[i])
    return scat,

fig = plt.figure(figsize=(10, 10))
plt.xlim(-6, 6)
plt.ylim(-6, 6)
initial_pts = all_samples[0]
scatter = plt.scatter(initial_pts[:, 0], initial_pts[:, 1])
# One frame per stored snapshot (the initial state plus one per Langevin step). The old
# `len(all_samples) - 1` silently dropped the final snapshot.
anim = animation.FuncAnimation(fig, animate, frames=range(len(all_samples)),
                               fargs=(all_samples, scatter), interval=200)
# Close the static figure so only the rendered animation is displayed below.
plt.close(fig)
anim