In [1]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from torch import optim
from tqdm import tqdm_notebook as tqdm

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

from models import GMM_teacher, GMM
from loss import WLoss
from plotting_utils import plot_observables
from models import Walkers, WeightedWalkers

In [2]:
#This can be written in a smarter way like in the algorithm on the paper        

seed=0
teach=False
sched=False
resamp=True
max_var=1
contr=False
zeros=True
L2=0
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
torch.manual_seed(seed)
dim = 2 # dimension
n_iter = int(4e3) # total number of GD iterations
hx = 1e-3 # time stepsize for walkers x
n_walker = 5000  # number of walkers
lr =5e-6


In [3]:
teacher = GMM_teacher(dim,device=device) # teacher model
model = GMM(dim,hidden_dim = 7, device=device).to(device) # MLP model 
model2 = GMM(dim,hidden_dim = 7, device=device).to(device) # MLP model
wloss_func = WLoss()

data = teacher.sample(1e3).to(device) # sample data
data_KL = teacher.sample(3e3).to(device)

In [4]:
#optimizers
optimizer = optim.Adam(model.parameters(), lr= lr,weight_decay=L2) # ADAM 
optimizer_compare = optim.Adam(model2.parameters(), lr = lr,weight_decay=L2) # ADAM 

initial_lr = 1
min_lr = 2e-7
fact=0.99

#schedulers

# lambda1 = lambda epoch: max(0.99998 ** epoch, min_lr) # exponential decay 
lambda1 = lambda epoch: (initial_lr - epoch*(initial_lr - min_lr)/n_iter) # linear decay
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',verbose=True,patience=1500,factor=fact)#torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
scheduler_compare = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',verbose=True,patience=1500,factor=fact)#torch.optim.lr_scheduler.LambdaLR(optimizer_compare, lr_lambda=lambda1)


In [5]:
total_norm_all = torch.zeros([n_iter, 1],device=device) # monitor the L1 norm of gradients of parameters

W = WeightedWalkers(init_data=model.sample(n_walker), hx=hx, device=device, clip_lim=100, use_resampling=True, max_var=max_var)
W_PCD = Walkers(init_data=model2.sample(n_walker), hx=hx, device=device, clip_lim = 100) # not sure about the 100

# approximation of integral of \rho log \rho as sum on modes of p_i\log(p_i)+p_i Integral(\rho_i \log (\rho_i))
# https://gregorygundersen.com/blog/2020/09/01/gaussian-entropy/


In [6]:
log_weight = torch.zeros((n_walker,),device=device)
logwal=torch.tensor([np.log(n_walker)],device=device)

In [7]:
n_sample = 800 # construct a uniform grid in R^2

x_limit = 100
y_limit = 100

KL = torch.zeros(int(n_iter/1000)) 
KL2 = torch.zeros(int(n_iter/1000)) 

ce=torch.zeros(int(n_iter/1000)) 
ce2=torch.zeros(int(n_iter/1000))
old_energy = model(W.walkers).clone().detach()
part_func=torch.zeros(int(n_iter), requires_grad=False)

x1_sample = np.linspace(-x_limit,x_limit,num=n_sample)
x2_sample = np.linspace(-x_limit,x_limit,num=n_sample)

x1_meshgrid,x2_meshgrid = np.meshgrid(x1_sample,x2_sample)

x1 = np.reshape(x1_meshgrid,(1,np.power(n_sample,2)))
x2 = np.reshape(x2_meshgrid,(1,np.power(n_sample,2)))

test_all = torch.t(torch.from_numpy(np.concatenate((x1,x2,np.zeros((dim-2,x1.shape[1]))))).float()).to(device)


In [8]:
q=0

foldername = "NN_fromzero_savez_"+str(zeros)+"_sched_"+str(sched)+"resampl"+str(resamp)+"var"+str(max_var)+"_teachstudGMM"+str(dim)+"d_seed"+str(seed)+"_lr"+str(lr)+"_ULA"+str(hx)+"_plat"+str(fact)+"_reg_+"+str(L2)+"/" # folder for saving all 
#the data
if not os.path.exists(foldername): # if not exists, create one
    os.makedirs(foldername)

log_Zres=torch.zeros(1,device=device)


In [9]:
for t in tqdm(range(n_iter)): # main loop

    # Preparing for Langevin on walkers and data
    
    data.requires_grad = True
    data=teacher.sample(1e3).to(device)
    data.requires_grad = False 

    #   Weighted-CD : using Jarczinski
    model.requires_grad(False)
    W.Langevin_step(model)
    log_weight_update_1 = W.compute_delta(model)#alpha k (x_k,x_{k+1}) 
    normalized_weights = W.get_normalized_weigths()
    model.requires_grad(True)
    loss = wloss_func(model(data), model(W.old_walkers), normalized_weights)
    loss.backward() # optimize the parameters 
    optimizer.step()
    optimizer.zero_grad() # clean the grad of parameters 
    part_func[t]=torch.logsumexp(W.log_weights,0)-logwal+log_Zres   
    model.requires_grad(False)
    log_weight_update_2 = W.compute_delta(model)     #alpha_{k+1}(x_{k+1},x_k)
    W.update_weights(log_weight_update_2 - log_weight_update_1)
    W.resample() #resampling step

    #    PCD : Langevin ULA 
    model2.requires_grad(False) 
    W_PCD.Langevin_step(model2)
    model2.requires_grad(True)
    loss_compare = wloss_func(model2(data), model2(W_PCD.old_walkers))
    loss_compare.backward()
    optimizer_compare.step()
    optimizer_compare.zero_grad() # clean the grad of parameters 

    if t%200==0: 
        print(loss_compare, loss)

    #Scheduler
#     if sched==True and t>n_iter/2:
#         scheduler.step(loss) # adjust the learning rate 
#         scheduler_compare.step(loss_compare) # adjust the learning rate 
         
    # plotting 
    if t % 20000== 0 and t>0:
        
        plot_observables(
                model, 
                model2, 
                teacher, 
                test_all, 
                x1_sample, 
                x2_sample, 
                n_sample, 
                W.walkers, 
                W_PCD.walkers, 
                foldername)


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for t in tqdm(range(n_iter)): # main loop


  0%|          | 0/4000 [00:00<?, ?it/s]

tensor(203.4473, device='cuda:0', grad_fn=<SubBackward0>) tensor(204.0476, device='cuda:0', grad_fn=<SubBackward0>)
tensor(202.9598, device='cuda:0', grad_fn=<SubBackward0>) tensor(203.5896, device='cuda:0', grad_fn=<SubBackward0>)
tensor(202.9134, device='cuda:0', grad_fn=<SubBackward0>) tensor(203.5756, device='cuda:0', grad_fn=<SubBackward0>)
tensor(202.9581, device='cuda:0', grad_fn=<SubBackward0>) tensor(203.6526, device='cuda:0', grad_fn=<SubBackward0>)
tensor(202.2712, device='cuda:0', grad_fn=<SubBackward0>) tensor(202.9957, device='cuda:0', grad_fn=<SubBackward0>)
tensor(202.7197, device='cuda:0', grad_fn=<SubBackward0>) tensor(203.4774, device='cuda:0', grad_fn=<SubBackward0>)


KeyboardInterrupt: 

In [None]:
data_foldername=foldername+'data/' 
if not os.path.exists(data_foldername): # if not exists, create one
    os.makedirs(data_foldername)

else:
  np.save(data_foldername+'teacher_mean',teacher.mean.cpu().numpy())
  np.save(data_foldername+'teacher_logstd',teacher.log_std.cpu().numpy())
  np.save(data_foldername+'teacher_mass',teacher.mix_logits.cpu().numpy())
  np.save(data_foldername+'walkers_jarz',W.walkers.cpu().numpy())
  np.save(data_foldername+'walkers_nojarz',W_PCD.walkers.cpu().numpy())
  np.save(data_foldername+'weights',log_weight.cpu().numpy())
  np.save(data_foldername+'KLJ',KL.cpu().numpy()) 
  np.save(data_foldername+'KLnoJ',KL2.cpu().numpy()) 
  np.save(data_foldername+'ceJ',ce.cpu().numpy()) 

  np.save(data_foldername+'partition',part_func.detach().cpu().numpy()) 
  torch.save(model, data_foldername+'modelJ')
  torch.save(teacher, data_foldername+'teacher')
  torch.save(model2, data_foldername+'modelnoJ')



In [None]:


    
fig, axes = plt.subplots(1,3, figsize=(12,4))
area = 10
Energy_landscape = (model(test_all)-model(test_all).min()).detach() 
levels = np.linspace(0,30,20)
axes[0].set_title("Energy Function (w/ Jarzynski)")
contour = axes[0].contourf(x1_sample,x2_sample,Energy_landscape.cpu().numpy().reshape((n_sample,n_sample)),levels = levels)
cbar3 = plt.colorbar(contour)
axes[0].set_xlim(-100,100)
axes[0].set_ylim(-100,100)

Energy_landscape = (model2(test_all)-model2(test_all).min()).detach() 
contour = axes[1].contourf(x1_sample,x2_sample,Energy_landscape.cpu().numpy().reshape((n_sample,n_sample)),levels = levels)
axes[1].set_title("Energy Function (w/o Jarzynski)")
cbar3 = plt.colorbar(contour)
axes[1].set_xlim(-100,100)
axes[1].set_ylim(-100,100)

Energy_landscape = (teacher(test_all)-teacher(test_all).min()).detach() 
contour = axes[2].contourf(x1_sample,x2_sample,Energy_landscape.cpu().numpy().reshape((n_sample,n_sample)),levels = levels)
axes[2].set_title("True Energy Function")
axes[2].set_xlim(-100,100)
axes[2].set_ylim(-100,100)
filename = str(t) + "_data.png"
plt.savefig(foldername + filename, dpi=300, bbox_inches='tight', transparent=True,facecolor='w')
plt.close()

# fig4 = plt.figure(4);fig4.clf
# min1=np.abs(np.min(ce.detach().cpu().numpy()[0:t]))
# min2=np.abs(np.min(ce2.detach().cpu().numpy()[0:t]))
# shift=np.max((min1,min2))+0.1
# plt.semilogy(ce.detach().cpu().numpy()[0:t]+shift,label = "W/ Jarzynski")
# plt.semilogy(ce2.detach().cpu().numpy()[0:t]+shift,label = "W/o Jarzynski")
# plt.title("Loss")
# plt.xlabel("Adam Iterations")
# plt.ylabel("Loss")
# plt.legend()
# filename = str(t) + "_ce.png"
# plt.savefig(foldername + filename, dpi=300, bbox_inches='tight', transparent=True,facecolor='w')
# plt.close()




