In [None]:
## Code originally in https://arxiv.org/pdf/2101.08176.pdf, slightly adapted for our milestone

import base64
import io
import pickle
import numpy as np
import torch
import packaging.version
%matplotlib inline
# Select the compute device. NOTE: unlike the reference implementation, the
# default torch tensor type is NOT changed here (torch.set_default_tensor_type
# is never called), so tensors keep PyTorch's defaults unless created otherwise;
# later cells move tensors to `torch_device` explicitly. `float_dtype` is only
# the matching numpy dtype.
if torch.cuda.is_available():
    torch_device = 'cuda'
    float_dtype = np.float32  # single precision on GPU
else:
    torch_device = 'cpu'
    float_dtype = np.float64  # double precision on CPU
# Emit the environment info that this cell's recorded output shows.
print(f'TORCH VERSION: {torch.__version__}')
print(f'TORCH DEVICE: {torch_device}')
from IPython.display import display

TORCH VERSION: 1.8.1+cu101
TORCH DEVICE: cuda


In [None]:
def torch_mod(x):
    """Reduce angles into [0, 2*pi) (torch.remainder takes the divisor's sign)."""
    period = 2 * np.pi
    return torch.remainder(x, period)
def torch_wrap(x):
    """Wrap angles into [-pi, pi): shift by pi, reduce mod 2*pi, shift back."""
    shifted = torch_mod(x + np.pi)
    return shifted - np.pi
def grab(var):
    """Detach `var` from autograd and return it as a NumPy array on the host."""
    host_tensor = var.detach().cpu()
    return host_tensor.numpy()

In [None]:
class SimpleNormal:  # prior distribution
    """Factorized normal distribution over lattice-shaped configurations.

    ``loc`` and ``var`` are flattened and handed to ``torch.distributions.Normal``
    as its ``loc`` and ``scale``. NOTE: despite its name, ``var`` is therefore used
    as the per-site *standard deviation*, not the variance.
    """
    def __init__(self, loc, var):
        self.shape = loc.shape
        self.dist = torch.distributions.normal.Normal(loc.flatten(), var.flatten())
    def log_prob(self, x):
        """Return the summed per-site log-density for each sample along axis 0."""
        per_site = self.dist.log_prob(x.reshape(x.shape[0], -1))
        return per_site.sum(dim=1)
    def sample_n(self, batch_size):
        """Draw ``batch_size`` samples, each reshaped to the lattice shape."""
        flat = self.dist.sample((batch_size,))
        return flat.reshape(batch_size, *self.shape)

In [None]:
def apply_flow_to_prior(prior, coupling_layers, *, batch_size): # normalizing flow
    """Sample the prior and push the samples through every coupling layer.

    ``prior`` must provide ``sample_n``/``log_prob`` and each layer's ``forward``
    must return ``(x, logJ)``. Returns ``(x, logq)`` on ``torch_device``, where
    ``logq`` is the prior log-density minus the accumulated log-Jacobians, i.e.
    the model log-density of the returned samples.
    """
    samples = prior.sample_n(batch_size)
    # Move once up front instead of on every layer; values are unchanged.
    log_density = prior.log_prob(samples).to(torch_device)
    for coupling in coupling_layers:
        samples, log_jac = coupling.forward(samples)
        log_density = log_density - log_jac.to(torch_device)
    return samples.to(torch_device), log_density.to(torch_device)

In [None]:
class ScalarPhi4Action:  # action for phi^4 2D scalar field theory
    """Lattice action for phi^4 scalar field theory with periodic boundaries.

    S[phi] = sum_x [ M2*phi^2 + lam*phi^4
                     + sum_mu (2*phi^2 - phi(x)*phi(x+mu) - phi(x)*phi(x-mu)) ]

    Axis 0 of ``cfgs`` is the batch; every remaining axis is a lattice direction.
    """
    def __init__(self, M2, lam):
        self.M2, self.lam = M2, lam
    def __call__(self, cfgs):
        """Return the action of each configuration (summed over lattice axes)."""
        # potential term
        density = self.M2 * cfgs**2 + self.lam * cfgs**4
        lattice_axes = tuple(range(1, cfgs.dim()))
        # kinetic term (discrete Laplacian); torch.roll gives periodic neighbours
        for mu in lattice_axes:
            forward_nbr = torch.roll(cfgs, -1, mu)
            backward_nbr = torch.roll(cfgs, 1, mu)
            density = density + 2 * cfgs**2 - cfgs * forward_nbr - cfgs * backward_nbr
        return density.sum(dim=lattice_axes)

In [None]:
def make_checker_mask(shape, parity):  # checkerboard pattern
    """Return a uint8 checkerboard mask on ``torch_device``.

    Sites whose row and column indices are both even or both odd hold ``parity``;
    all others hold ``1 - parity``. Only the first two axes are indexed, so
    ``shape`` is expected to be 2D.
    """
    board = torch.ones(shape, dtype=torch.uint8) - parity
    for start in (0, 1):
        board[start::2, start::2] = parity
    return board.to(torch_device)

In [None]:
class AffineCoupling(torch.nn.Module):  # Affine coupling layer
    """Checkerboard affine coupling layer of a normalizing flow.

    Sites where the checkerboard mask is 1 are left unchanged ("frozen") and fed to
    ``net``; the other ("active") sites are transformed as ``x -> exp(s) * x + t``
    with ``s = net_out[:, 0]`` and ``t = net_out[:, 1]``. ``net`` receives the
    frozen field with an added channel axis (``unsqueeze(1)``) and must return at
    least two channels.
    """
    def __init__(self, net, *, mask_shape, mask_parity):
        super().__init__()
        # Plain tensor attribute (not a registered buffer): placed on torch_device
        # here and not included in state_dict().
        self.mask = make_checker_mask(mask_shape, mask_parity).to(torch_device)
        self.net = net.to(torch_device)
    def forward(self, x):
        """Map x -> fx; return ``(fx, logJ)`` with per-sample log|det dfx/dx|."""
        x = x.to(torch_device)
        x_frozen = self.mask * x
        x_active = (1 - self.mask) * x
        net_out = self.net(x_frozen.unsqueeze(1))
        s, t = net_out[:, 0], net_out[:, 1]
        fx = (1 - self.mask) * t + x_active * torch.exp(s) + x_frozen
        axes = tuple(range(1, s.dim()))
        logJ = torch.sum((1 - self.mask) * s, dim=axes)
        return fx, logJ
    def reverse(self, fx):
        """Invert ``forward``; return ``(x, logJ)`` with logJ of the inverse map."""
        # Move to the mask's device like forward() does; otherwise a CPU input on a
        # CUDA machine fails with a device mismatch.
        fx = fx.to(torch_device)
        fx_frozen = self.mask * fx
        fx_active = (1 - self.mask) * fx
        net_out = self.net(fx_frozen.unsqueeze(1))
        s, t = net_out[:, 0], net_out[:, 1]
        x = (fx_active - (1 - self.mask) * t) * torch.exp(-s) + fx_frozen
        axes = tuple(range(1, s.dim()))
        logJ = torch.sum((1 - self.mask) * (-s), dim=axes)
        return x, logJ

In [None]:
def make_conv_net(*, hidden_sizes, kernel_size, in_channels, out_channels, use_final_tanh):   # CNN
    """Build a stack of 'same'-padded, circularly padded 2D convolutions.

    Channel widths go in_channels -> *hidden_sizes -> out_channels. A LeakyReLU
    follows every convolution except the last, which is optionally followed by a
    Tanh when ``use_final_tanh`` is true. ``hidden_sizes`` must be a list.
    """
    channels = [in_channels] + hidden_sizes + [out_channels]
    last_idx = len(channels) - 2
    modules = []
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        modules.append(torch.nn.Conv2d(
            c_in, c_out, kernel_size, padding=kernel_size // 2,
            stride=1, padding_mode='circular'))
        if idx != last_idx:
            modules.append(torch.nn.LeakyReLU())
        elif use_final_tanh:
            modules.append(torch.nn.Tanh())
    return torch.nn.Sequential(*modules)

In [None]:
def make_phi4_affine_layers(*, n_layers, lattice_shape, hidden_sizes, kernel_size):
    """Create ``n_layers`` affine coupling layers with alternating mask parity.

    Each layer gets its own CNN mapping 1 input channel to 2 outputs (s, t) with a
    final Tanh. Layer k uses checkerboard parity ``k % 2``.
    """
    couplings = torch.nn.ModuleList()
    for layer_idx in range(n_layers):
        conv = make_conv_net(
            in_channels=1, out_channels=2, hidden_sizes=hidden_sizes,
            kernel_size=kernel_size, use_final_tanh=True)
        couplings.append(
            AffineCoupling(conv, mask_shape=lattice_shape, mask_parity=layer_idx % 2))
    return couplings

In [None]:
def calc_dkl(logp, logq):   # reverse loss
    """Reverse-KL training loss: batch mean of ``logq - logp``."""
    log_ratio = logq - logp
    return log_ratio.mean()
def train_step(model, action, loss_fn, optimizer, metrics, *, batch_size=None):
    """Run one optimizer step on a freshly sampled batch and record metrics.

    Args:
        model: dict with 'layers' (the coupling layers) and 'prior'.
        action: callable returning the action S(x); the target is logp = -S(x).
        loss_fn: callable ``loss_fn(logp, logq)`` returning a scalar loss
            (e.g. ``calc_dkl``). Previously ignored in favour of a hard-coded
            ``calc_dkl``.
        optimizer: optimizer over the layers' parameters.
        metrics: dict of lists with keys 'loss', 'logp', 'logq', 'ess'; one
            NumPy entry is appended to each.
        batch_size: number of samples per step. When omitted, falls back to the
            notebook-level ``batch_size`` global, which the original version read
            implicitly.

    Raises:
        ValueError: if no batch size is given and no global ``batch_size`` exists.
    """
    if batch_size is None:
        batch_size = globals().get('batch_size')
        if batch_size is None:
            raise ValueError('train_step needs a batch_size (pass it explicitly)')
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()
    x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size)
    logp = -action(x)
    loss = loss_fn(logp, logq)
    loss.backward()
    optimizer.step()
    metrics['loss'].append(grab(loss))
    metrics['logp'].append(grab(logp))
    metrics['logq'].append(grab(logq))
    metrics['ess'].append(grab(compute_ess(logp, logq)))

In [None]:
def compute_ess(logp, logq): # ESS metric
    """Effective sample size per configuration of the weights w = p/q.

    Computes (sum w)^2 / (N * sum w^2) in log space for stability; the result
    lies between 1/N and 1.
    """
    log_weights = logp - logq
    log_numerator = 2 * torch.logsumexp(log_weights, dim=0)
    log_denominator = torch.logsumexp(2 * log_weights, dim=0)
    return torch.exp(log_numerator - log_denominator) / len(log_weights)
def print_metrics(history, avg_last_N_epochs, era=None, epoch=None):
    """Print the mean of the last ``avg_last_N_epochs`` entries of each metric.

    ``era``/``epoch`` label the header. When omitted they fall back to the
    notebook-level ``era``/``epoch`` loop variables, which the original version
    read implicitly (and failed with NameError if undefined).
    """
    if era is None:
        era = globals().get('era')
    if epoch is None:
        epoch = globals().get('epoch')
    print(f'== Era {era} | Epoch {epoch} metrics ==')
    for key, val in history.items():
        avgd = np.mean(val[-avg_last_N_epochs:])
        print(f'\t{key} {avgd:g}')

In [None]:
def serial_sample_generator(model, action, batch_size, N_samples):
    """Yield ``N_samples`` flow samples one at a time as ``(x, logq, logp)``.

    Samples are drawn in batches of ``batch_size`` and then served individually.
    Sampling runs under ``torch.no_grad()``: these samples only feed MCMC, and
    keeping the autograd graph would keep every batch's graph alive for as long
    as any yielded sample is stored. The layers are switched to eval mode while
    sampling, and their previous training mode is restored when the generator
    finishes or is closed.
    """
    layers, prior = model['layers'], model['prior']
    was_training = layers.training
    layers.eval()
    try:
        for i in range(N_samples):
            batch_i = i % batch_size
            if batch_i == 0:
                # Keep no_grad scoped to the sampling itself: a context held open
                # across `yield` would also disable grad in the consumer's code.
                with torch.no_grad():
                    x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size)
                    logp = -action(x)
            yield x[batch_i], logq[batch_i], logp[batch_i]
    finally:
        layers.train(was_training)

def make_mcmc_ensemble(model, action, batch_size, N_samples): # build the ensemble used to measure the trained model's acceptance rate
    """Run an independence Metropolis-Hastings chain driven by flow proposals.

    A proposal (x', logq', logp') is accepted with probability
    min(1, exp[(logp' - logq') - (logp - logq)]) relative to the current state;
    on rejection the current state is repeated. The first proposal is always
    accepted.

    Returns a dict of per-step lists keyed 'x', 'logq', 'logp', 'accepted'.
    """
    chain = {'x' : [], 'logq' : [], 'logp' : [], 'accepted' : []}
    proposals = serial_sample_generator(model, action, batch_size, N_samples)
    for cand_x, cand_logq, cand_logp in proposals:
        accepted = True
        if chain['logp']:
            cur_logp, cur_logq = chain['logp'][-1], chain['logq'][-1]
            log_ratio = (cand_logp - cand_logq) - (cur_logp - cur_logq)
            p_accept = min(1, torch.exp(log_ratio))
            u = torch.rand(1).to(torch_device)  # ~ U[0, 1)
            if not (u < p_accept):
                # Rejected: repeat the current state.
                accepted = False
                cand_x, cand_logp, cand_logq = chain['x'][-1], cur_logp, cur_logq
        chain['logp'].append(cand_logp)
        chain['logq'].append(cand_logq)
        chain['x'].append(cand_x)
        chain['accepted'].append(accepted)
    return chain
      

In [None]:
# Parameters defined here
SEED = 0  # fixed seed so weight initialization and sampling are reproducible
torch.manual_seed(SEED)
np.random.seed(SEED)

L = 8  # lattice size
lattice_shape = (L,L)
M2 = -4.0  # coefficient of the phi^2 (mass) term
lam = 8.0  # quartic coupling constant
phi4_action = ScalarPhi4Action(M2=M2, lam=lam)
# Prior: zero mean and unit scale on every lattice site.
prior = SimpleNormal(torch.zeros(lattice_shape), torch.ones(lattice_shape))
n_layers = 16  # number of coupling layers (mask parity alternates per layer)
hidden_sizes = [8,8]  # hidden channel widths of each coupling CNN
kernel_size = 3 # CNN kernel size
layers = make_phi4_affine_layers(
    lattice_shape=lattice_shape, n_layers=n_layers,
    hidden_sizes=hidden_sizes, kernel_size=kernel_size)
model = {'layers': layers, 'prior': prior}
base_lr = .001  # learning rate
optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr) # Adam optimizer

In [None]:
N_epoch_list = np.arange(10, 210, 10)  # test acceptance rate for different n_epoch
N_era = 25
batch_size = 64  # train_step reads this notebook-level value
plot_freq = 1
ensemble_size = 8192
acceptance_rate = []
for N_epoch in N_epoch_list:
    print_freq = N_epoch
    # Start every N_epoch from a freshly initialized flow and optimizer. Reusing
    # one model would accumulate training across iterations, so each reported
    # rate would reflect all earlier runs as well, not N_epoch epochs per era.
    layers = make_phi4_affine_layers(
        lattice_shape=lattice_shape, n_layers=n_layers,
        hidden_sizes=hidden_sizes, kernel_size=kernel_size)
    model = {'layers': layers, 'prior': prior}
    optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr)
    history = {'loss': [], 'logp': [], 'logq': [], 'ess': []}

    for era in range(N_era):
        for epoch in range(N_epoch):
            train_step(model, phi4_action, calc_dkl, optimizer, history)

    serialized_model = io.BytesIO()
    torch.save(model['layers'].state_dict(), serialized_model)

    phi4_ens = make_mcmc_ensemble(model, phi4_action, batch_size, ensemble_size)
    rate = np.mean(phi4_ens['accepted'])
    acceptance_rate.append(rate)
    print(N_epoch, "- accept rate:", rate)  # one acceptance rate per N_epoch tried

110 - accept rate: 0.5189208984375
120 - accept rate: 0.581787109375
130 - accept rate: 0.6044921875
140 - accept rate: 0.596923828125
150 - accept rate: 0.593505859375
160 - accept rate: 0.5836181640625
170 - accept rate: 0.6302490234375
180 - accept rate: 0.6224365234375
190 - accept rate: 0.6031494140625
200 - accept rate: 0.6529541015625
