In [22]:
%load_ext autoreload
%autoreload 2

# Expose only GPU 0 to this process. The variable is set here, before the JAX
# imports that follow in this notebook.
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Disable JIT compilation globally so everything runs eagerly (useful while
# debugging). Note that `@jax.jit` decorators further down become no-ops
# while this flag is set.
from jax.config import config
config.update('jax_disable_jit', True)

path exists, leaving alone


In [None]:
from pathlib import Path
from typing import Callable
from functools import partial, reduce
import wandb

import jax
from jax import numpy as jnp
from jax import random as rnd
from flax.training.train_state import TrainState
from flax.training import common_utils
from flax import linen as nn, jax_utils
import optax

from pyfig import Pyfig

# Notebook-wide configuration object (project class Pyfig), created with
# wandb_mode='disabled' and notebook=True; `c` is read as a global by many
# of the functions below.
c = Pyfig(wandb_mode='disabled', notebook=True)
c.print()

class FermiNet(nn.Module):
  """FermiNet-style network mapping electron positions ``x`` to ``log|psi|``.

  Work in progress: the stream layers run, but the orbital/envelope section at
  the end of ``__call__`` still carries open NOTE(review) items.

  Fields:
    n_sv, n_pv: output widths of the Dense layers applied to the single-electron
      stream (``x_s_var``) and to the pairwise stream (``x_p_var``).
    n_fb: number of stream-update iterations.
    n_fb_out: width of the Dense layer applied before the orbital layer.
    n_det: number of determinants.
    af_fb, af_fb_out, af_pseudo: activation functions.
    compute_s_emb, compute_p_emb: callables building the initial streams from ``x``.
    n_e, n_u: total number of electrons and number of spin-up electrons.
    n_d: number of spin-down electrons; defaults to ``n_e - n_u``.
    compute_s_perm: optional override of the permutation-equivariant mixing
      function; by default the module-level ``compute_s_perm`` is bound to the
      masks produced by ``create_masks(n_e, n_u)``.
    pbc: stored but not read in this class.
  """

  n_sv: int
  n_pv: int
  n_fb: int
  n_fb_out: int
  n_det: int

  af_fb: Callable
  af_fb_out: Callable
  af_pseudo: Callable

  compute_s_emb: Callable
  compute_p_emb: Callable

  n_e: int
  n_u: int
  # Derived values cannot be computed in the class body: annotated fields
  # without a default are not bound names there, so `n_e - n_u` raised a
  # NameError at class-creation time. They stay available as optional
  # overrides and are resolved inside __call__ instead.
  n_d: int = None
  compute_s_perm: Callable = None

  pbc:          bool = False

  @nn.compact
  def __call__(_i, x):
    """Return ``log|psi|`` for the electron configuration ``x``."""
    n_e, n_u, n_det = _i.n_e, _i.n_u, _i.n_det
    n_d = (n_e - n_u) if _i.n_d is None else _i.n_d

    s_perm = _i.compute_s_perm
    if s_perm is None:
      # Inside a method the bare name resolves to the module-level function.
      p_mask_u, p_mask_d = create_masks(n_e, n_u)
      s_perm = partial(compute_s_perm, p_mask_u=p_mask_u, p_mask_d=p_mask_d, n_u=n_u)

    x_s_var = _i.compute_s_emb(x)
    x_p_var = _i.compute_p_emb(x)

    # Residual accumulators; 0. makes the first iteration a plain layer.
    x_s_res = x_p_res = 0.
    for _ in range(_i.n_fb):
        x_p_var = x_p_res = _i.af_fb(nn.Dense(_i.n_pv)(x_p_var)) + x_p_res

        x_s_var = s_perm(x_s_var, x_p_var)
        if isinstance(x_s_var, (list, tuple)):
            # The mixing function hands the stream back split by spin; the
            # Dense layer below acts on all electrons at once.
            x_s_var = jnp.concatenate(x_s_var, axis=0)
        x_s_var = x_s_res = _i.af_fb(nn.Dense(_i.n_sv)(x_s_var)) + x_s_res

    # NOTE(review): compute_s_perm treats the pairwise stream as a per-pair
    # tensor; confirm x_s_var and x_p_var have compatible leading dimensions
    # for this concatenation.
    x_w = jnp.concatenate([x_s_var, x_p_var], axis=-1)
    x_w = _i.af_fb_out(nn.Dense(_i.n_fb_out)(x_w))
    x_w = _i.af_pseudo(nn.Dense(n_det*n_e)(x_w))
    # A single split index yields exactly the two spin blocks.
    x_wu, x_wd = jnp.split(x_w, [n_u], axis=0)

    # NOTE(review): the two nn.Dense modules below are constructed but never
    # applied to an input, so negating them raises a TypeError. The intended
    # envelope input cannot be determined from this file. jnp.split also
    # returns a list, while logabssumdet reads `.shape` on its arguments.
    orb_u = jnp.split(x_wu * jnp.exp(-nn.Dense(n_u*n_det)), n_det, axis=1)
    orb_d = jnp.split(x_wd * jnp.exp(-nn.Dense(n_d*n_det)), n_det, axis=1)

    log_psi, sgn = logabssumdet(orb_u, orb_d)

    return log_psi

# Build the network from the config and run one forward pass on a dummy
# configuration of n_e electrons in 3 dimensions.
model = FermiNet(**c.dict)
n_e = 10  # NOTE(review): hard-coded; confirm it matches the n_e stored in `c`
rng_model = rnd.PRNGKey(1)
# jnp.ones takes the shape as ONE tuple; jnp.ones(n_e, 3) passed 3 as the dtype.
x = jnp.ones((n_e, 3))
params = model.init(rng_model, x)
model.apply(params, x)

In [29]:

def compute_s_perm(x, x_p, p_mask_u, p_mask_d, n_u):
    """Append spin-wise averaged features to every row of the single stream.

    Each output row is the concatenation (last axis) of
      * the electron's own features from ``x``,
      * the mean of the first ``n_u`` rows of ``x`` (spin up),
      * the mean of the remaining rows of ``x`` (spin down),
      * the masked sums of ``x_p`` under ``p_mask_u`` / ``p_mask_d`` divided by
        the respective electron count.

    Args:
        x: single-electron stream, rows index electrons (spin-up rows first).
        x_p: pairwise stream; assumed (n_e, n_e, features) to match the masks
            built by ``create_masks`` -- TODO confirm against the caller.
        p_mask_u, p_mask_d: masks as returned by ``create_masks``.
        n_u: number of spin-up electrons (Python int).

    Returns:
        ``[rows_up, rows_down]``: the augmented stream split along axis 0 at
        ``n_u``.
    """
    # The electron counts come from the argument and the row count of x; the
    # previous version overwrote n_u from the shapes, which made n_d zero.
    n_e = x.shape[0]
    n_d = n_e - n_u

    xu, xd = jnp.split(x, [n_u], axis=0)
    # sum / max(count, 1): an empty spin channel contributes zeros, not NaN.
    mean_xu = jnp.sum(xu, axis=0, keepdims=True) / float(max(n_u, 1))
    mean_xd = jnp.sum(xd, axis=0, keepdims=True) / float(max(n_d, 1))
    # Repeat the channel means for every electron so that they can be
    # concatenated row-wise with x.
    mean_xu = jnp.broadcast_to(mean_xu, x.shape)
    mean_xd = jnp.broadcast_to(mean_xd, x.shape)

    x_p = jnp.expand_dims(x_p, axis=0)
    sum_p_u = (p_mask_u * x_p).sum((1, 2)) / float(max(n_u, 1))
    sum_p_d = (p_mask_d * x_p).sum((1, 2)) / float(max(n_d, 1))

    x = jnp.concatenate((x, mean_xu, mean_xd, sum_p_u, sum_p_d), axis=-1)
    # One split index -> exactly two pieces (up, down).
    return jnp.split(x, [n_u], axis=0)

def logabssumdet(orb_u, orb_d=None):
    """Log-magnitude and sign of a sum over determinant products.

    Blocks whose last dimension is 1 are treated as scalars (their determinant
    is the entry itself); larger blocks go through ``jnp.linalg.slogdet``. The
    largest log-determinant is factored out before exponentiating and added
    back to the result.

    Args:
        orb_u: orbital block(s) for the first spin channel.
        orb_d: optional orbital block(s) for the second spin channel.

    Returns:
        ``(log_psi, sgn_psi)``: log of the absolute value of the summed
        determinants, and its sign.
    """
    if orb_d is None:
        blocks = [orb_u]
    else:
        blocks = [orb_u, orb_d]

    # 1x1 blocks: flatten and multiply directly.
    scalar_blocks = [m.reshape(-1) for m in blocks if m.shape[-1] == 1]
    scalar_det = 1
    if scalar_blocks:
        scalar_det = scalar_blocks[0]
        for extra in scalar_blocks[1:]:
            scalar_det = scalar_det * extra

    matrix_blocks = [jnp.linalg.slogdet(m) for m in blocks if m.shape[-1] > 1]

    if matrix_blocks:  # at least one block with two or more electrons
        sign_total, logdet_total = matrix_blocks[0]
        for sign_k, logdet_k in matrix_blocks[1:]:
            sign_total = sign_total * sign_k
            logdet_total = logdet_total + logdet_k
        shift = jnp.max(logdet_total)
        terms = sign_total * scalar_det * jnp.exp(logdet_total - shift)
    else:
        shift = 0
        terms = scalar_det

    total = jnp.sum(terms)
    return jnp.log(jnp.abs(total)) + shift, jnp.sign(total)

def create_masks(n_electrons, n_up):
    """Build the spin masks used to sum the pairwise stream per electron.

    For electron ``e`` the up mask is an (n_electrons, n_electrons) matrix that
    is zero everywhere except row ``e``, which holds 1 for the first ``n_up``
    columns; the down mask holds 1 for the remaining columns instead.

    JAX arrays are immutable, so the masks are built with broadcasting rather
    than by in-place item assignment (which raises a TypeError).

    Args:
        n_electrons: total number of electrons.
        n_up: number of spin-up electrons (these come first).

    Returns:
        ``(pairwise_up_mask, pairwise_down_mask)``, each of shape
        (n_electrons, n_electrons, n_electrons, 1).
    """
    ups = jnp.where(jnp.arange(n_electrons) < n_up, 1., 0.)
    downs = 1. - ups

    # own_row[e, i] == 1 only for i == e: selects electron e's own row.
    own_row = jnp.eye(n_electrons)
    pairwise_up_mask = own_row[:, :, None] * ups[None, None, :]
    pairwise_down_mask = own_row[:, :, None] * downs[None, None, :]
    return pairwise_up_mask[..., None], pairwise_down_mask[..., None]

from typing import Callable

In [26]:
def compute_metric(d:dict, axis_name:str='batch'):
    """Average a pytree of metrics over a mapped (pmap/vmap) axis.

    Must be called inside a transformation that binds ``axis_name``.

    Args:
        d: dict (pytree) of per-device metric values.
        axis_name: name of the mapped axis to average over. Defaults to
            'batch', the axis name used by the pmap calls in this notebook.

    Returns:
        A dict with the same keys holding the cross-device means.
    """
    # Previously this read an unassigned `metrics`, used an un-imported `lax`
    # and returned the undefined name `metric`.
    metrics = jax.lax.pmean(d, axis_name=axis_name)
    return metrics


def compute_energy():
    """Placeholder for the energy computation; not implemented yet.

    Returns:
        None.
    """
    return None

In [28]:
# Dummy batch of 2 walkers with n_e electrons in 3 dimensions, split into two
# equal halves along the electron axis.
n_e = 10
x = jnp.ones((2, n_e, 3))
# Use the module-level jnp.split: the array `.split` method is not part of the
# documented jax.Array API.
xu, xd = jnp.split(x, 2, axis=1)

def init_walker(xu, xd):
    """Draw standard-normal initial positions shaped like ``xu`` and ``xd``.

    The PRNG key is derived from the global config value ``c.seed`` on every
    call, so repeated calls with the same shapes return the same samples.

    Returns:
        The two samples concatenated along axis 1.
    """
    key_up, key_down = rnd.split(rnd.PRNGKey(c.seed), 2)
    walkers_up = rnd.normal(key_up, xu.shape)
    walkers_down = rnd.normal(key_down, xd.shape)
    return jnp.concatenate([walkers_up, walkers_down], axis=1)

from jax import pmap

class SampleState():
    """Stateful sampler wrapper that adapts the proposal step size.

    Each call runs ``sample`` twice: once with the current ``move_std`` and
    once with a slightly perturbed value, and keeps whichever step size gave an
    acceptance value closer to ``acc_target``.

    Author's notes (kept from the original, unverified):
        @struct.dataclass
        struct.PyTreeNode means jax transformations *do not affect it* eg pmap
        fn.apply << apply vs fn() << apply_fn
    """

    # Lower bound for the proposal std so that a perturbation can never make
    # it zero or negative. NOTE(review): value chosen as a safe floor -- tune.
    min_move_std = 1e-6

    def __init__(_i, step=0, move_std=0.02, corr_len=20, rng=None):
        """
        Args:
            step: initial step counter (stored, not updated by this class).
            move_std: initial standard deviation of the proposal moves.
            corr_len: correlation length; half of it is handed to each of the
                two ``sample`` calls made per step.
                NOTE(review): ``sample`` halves its ``corr_len`` again --
                confirm the double halving is intended.
            rng: optional PRNG key; defaults to ``PRNGKey(0)``. Previously the
                key was never initialised and the first call failed.
        """
        _i.acc_target = 0.5
        _i.corr_len = corr_len//2
        _i.move_std = move_std
        _i.step = step
        _i.rng = rnd.PRNGKey(0) if rng is None else rng

    def __call__(_i, x, state:TrainState):
        """Advance the walkers ``x`` and update ``move_std``; returns the new ``x``."""
        _i.rng, rng_0, rng_1, rng_move = rnd.split(_i.rng, 4)

        x, acc_0 = sample(rng_0, x, state, _i.corr_len, _i.move_std)
        # jnp.clip without bounds does not constrain anything (and is rejected
        # by some JAX versions); floor the candidate step size explicitly.
        move_std_1 = jnp.maximum(_i.move_std + 0.001*rnd.normal(rng_move), _i.min_move_std)
        x, acc_1 = sample(rng_1, x, state, _i.corr_len, move_std_1)

        # Keep the current step size only if its acceptance was strictly
        # closer to the target than the candidate's.
        keep_current = (_i.acc_target-acc_0)**2 < (_i.acc_target-acc_1)**2
        _i.move_std = jnp.where(keep_current, _i.move_std, move_std_1)
        return x

def sample(rng, x, state: "TrainState | Callable", corr_len, move_std):
    """Metropolis random-walk sampling from ``|psi|^2``.

    Args:
        rng: PRNG key.
        x: walker positions; assumed to have two trailing axes per walker
            (electrons, dimensions) -- the acceptance mask is expanded by two
            axes before it is applied to ``x``.
        state: either a callable ``x -> log_psi`` or a flax ``TrainState``, in
            which case ``state.apply_fn({'params': state.params}, x)`` is used.
            NOTE(review): the model output is assumed to be ``log_psi`` only.
        corr_len: ``corr_len // 2`` Metropolis steps are performed.
        move_std: standard deviation of the Gaussian proposal.

    Returns:
        ``(x, acc)``: the updated walkers and the acceptance rate averaged
        over walkers and steps.
    """
    if callable(state):
        log_psi_fn = state
    else:
        # A TrainState is not callable; go through its apply_fn.
        log_psi_fn = lambda x_in: state.apply_fn({'params': state.params}, x_in)

    n_step = corr_len//2

    # Work with log-probabilities: exp(log_psi)**2 under/overflows quickly,
    # which turns the acceptance ratio into 0/0.
    log_p = 2. * log_psi_fn(x)

    acc = 0.0
    for _ in range(n_step):
        # Three keys are needed; the default split only returns two.
        rng, rng_move, rng_alpha = rnd.split(rng, 3)

        x_1 = x + rnd.normal(rng_move, x.shape)*move_std
        log_p_1 = 2. * log_psi_fn(x_1)

        # Equivalent to p_1 / p > u for u ~ U[0, 1).
        accept = (log_p_1 - log_p) > jnp.log(rnd.uniform(rng_alpha, log_p_1.shape))
        log_p = jnp.where(accept, log_p_1, log_p)
        # Broadcast the per-walker decision over the two trailing axes of x
        # (the old code expanded the probabilities instead of the mask).
        x = jnp.where(jnp.expand_dims(accept, axis=(-1, -2)), x_1, x)

        acc += jnp.mean(accept)

    # Report a rate in [0, 1] so it is comparable with an acceptance target.
    return x, acc / max(n_step, 1)


# x = init_walker(xu, xd)

NameError: name 'n_u' is not defined

DeviceArray([[[[0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               ...,
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ]],

              [[0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               ...,
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ]],

              [[0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               ...,
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.        , 0.9456457 ],
               [0.31068158, 1.     

In [None]:

# jax.tree_map(lambda x: x.shape, params) # Check the parameters





def create_train_state(c: Pyfig):
  """Initialise model parameters and wrap them, with an SGD optimiser, in a TrainState.

  Reads ``c.rng_init`` (PRNG key for initialisation) and ``c.lr`` (learning rate).

  NOTE(review): ``CNN`` is not defined in the visible cells, and the dummy
  input is a 1x28x28x1 array -- confirm this is the intended model.
  """
  net = CNN()
  dummy_input = jnp.ones([1, 28, 28, 1])
  variables = net.init(c.rng_init, dummy_input)
  return TrainState.create(
    apply_fn=net.apply,
    params=variables['params'],
    tx=optax.sgd(c.lr),
  )

# Build the train state and pass it through flax's jax_utils.replicate
# before it is handed to the pmap-ed functions.
state = create_train_state(c)
state = jax_utils.replicate(state)

@jax.jit
def train_step(params, state, b):
  """Run one optimisation step on batch ``b`` and return ``(state, metrics)``.

  Args:
    params: not used inside this function.
      NOTE(review): the training loop passes a PRNG key in this position.
    state: TrainState providing ``apply_fn``, ``params`` and ``apply_gradients``.
    b: input batch forwarded to the model.

  Returns:
    The updated state and the metrics returned by ``compute_metric``.
  """
  # NOTE(review): compute_energy is a zero-argument stub in this notebook;
  # this call needs updating once the energy is implemented.
  b_energy = compute_energy(state)

  def loss_fn(p):
    # With has_aux=True, log_psi is the differentiated value and sgn the
    # auxiliary output.
    # NOTE(review): the gradient is taken of log_psi itself, not of an energy.
    log_psi, sgn = state.apply_fn({'params':p}, b)
    return log_psi, sgn

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (log_psi, sgn), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)

  variables = {
    'E batch'           : b_energy,
    r'$\log\psi$ batch' : log_psi,
    r'sgn$(\cdot)$'     : sgn
  }

  # compute_metric takes a single dict; the labelled dict built above was
  # previously unused and three positional arguments were passed instead.
  metrics = compute_metric(variables)
  return state, metrics

# Parallelise the step over devices; the mapped axis is named 'batch'.
# NOTE(review): the training loop calls `p_train_step`, which is not defined in
# the visible cells, while this line rebinds `train_step` -- confirm which
# name is intended.
train_step = jax.pmap(train_step, axis_name='batch')

In [None]:
def get_posterior_mean_variance(img, t, x0, v, ddpm):
    """Mean and log-variance of the DDPM posterior q(x_{t-1} | x_t, x_0).

    Args:
        img: current noisy sample x_t.
        t: timestep index into the schedule arrays.
            NOTE(review): for t == 0 the ``t-1`` lookups wrap around to the
            last schedule entry -- confirm callers only rely on t > 0.
        x0: predicted clean sample.
        v: not used in this function.
        ddpm: object exposing ``betas``, ``alphas``, ``alphas_bar`` and
            ``sqrt_alphas_bar`` arrays indexed by timestep.

    Returns:
        ``(posterior_mean, posterior_log_variance)``.
    """
    # Read the schedule from the `ddpm` argument; the previous version ignored
    # it and silently depended on a global `ddpm_param`.
    beta = ddpm.betas[t,None,None,None] # only needed when t > 0
    alpha = ddpm.alphas[t,None,None,None]
    alpha_bar = ddpm.alphas_bar[t,None,None,None]
    alpha_bar_last = ddpm.alphas_bar[t-1,None,None,None]
    sqrt_alpha_bar_last = ddpm.sqrt_alphas_bar[t-1,None,None,None]

    coef_x0 = beta * sqrt_alpha_bar_last / (1. - alpha_bar)
    coef_xt = (1. - alpha_bar_last) * jnp.sqrt(alpha) / (1. - alpha_bar)
    posterior_mean = coef_x0 * x0 + coef_xt * img

    posterior_variance = beta * (1. - alpha_bar_last) / (1. - alpha_bar)
    # Floor before the log; jnp.maximum avoids the clip keyword names, which
    # differ between JAX versions.
    posterior_log_variance = jnp.log(jnp.maximum(posterior_variance, 1e-20))

    return posterior_mean, posterior_log_variance


def ddpm_sample_step(state, rng, x, t, x0_last, ddpm_param, self_condition=None, is_pred_x0=None):
    """Perform one reverse-diffusion step x_t -> x_{t-1}.

    Args:
        state: model state forwarded to ``model_predict``.
        rng: PRNG key for the posterior noise.
        x: current sample x_t; axis 0 is the batch axis.
        t: scalar timestep, broadcast to one entry per batch element.
        x0_last: previous x0 estimate, used only when self-conditioning.
        ddpm_param: diffusion schedule parameters.
        self_condition: whether to feed ``x0_last`` to the model. ``None``
            (default) falls back to the global ``c.ddpm.self_condition``.
        is_pred_x0: accepted so that this keyword can be bound by callers; it
            is not consulted in this function.
            NOTE(review): forward it to ``model_predict`` if that is needed.

    Returns:
        ``(x, x0)``: the sample for the previous timestep and the clipped x0
        estimate.
    """
    if self_condition is None:
        self_condition = c.ddpm.self_condition

    batched_t = jnp.ones((x.shape[0],), dtype=jnp.int32) * t

    x0_cond = x0_last if self_condition else None
    x0, v = model_predict(state, x, x0_cond, batched_t, ddpm_param, use_ema=True)

    x0 = jnp.clip(x0,-1.,1.) # make sure x0 between [-1,1]
    posterior_mean, posterior_log_variance = get_posterior_mean_variance(x, t, x0, v, ddpm_param)
    # std = exp(0.5 * log variance)
    x = posterior_mean + jnp.exp(0.5 *  posterior_log_variance) * jax.random.normal(rng, x.shape)
    return x, x0

# Bind the schedule and the static config flags, then parallelise over devices.
# `partial` is the name imported at the top of the notebook; the `functools`
# module itself is never imported.
# NOTE(review): confirm that ddpm_sample_step's signature accepts the
# `self_condition` and `is_pred_x0` keywords bound here.
sample_step = partial(
    ddpm_sample_step, 
    ddpm_param=ddpm_param, 
    self_condition=c.ddpm.self_condition, 
    is_pred_x0=c.ddpm.is_pred_x0
)

p_sample_step = jax.pmap(sample_step, axis_name='batch')

In [None]:
def copy_params_to_ema(state):
    """Return ``state`` with ``params_ema`` replaced by its current ``params``."""
    return state.replace(params_ema=state.params)

def apply_ema_decay(state, ema_decay):
    """Update the exponential moving average of the parameters.

    Every leaf becomes ``p_ema * ema_decay + p * (1 - ema_decay)``.

    Args:
        state: object with ``params``, ``params_ema`` and a ``replace`` method.
        ema_decay: decay factor applied to the running average.

    Returns:
        A copy of ``state`` with the updated ``params_ema``.
    """
    def blend(p_ema, p):
        return p_ema * ema_decay + p * (1. - ema_decay)

    # jax.tree_util.tree_map is available in old and new JAX releases; the
    # top-level jax.tree_map alias is not.
    params_ema = jax.tree_util.tree_map(blend, state.params_ema, state.params)
    return state.replace(params_ema=params_ema)

# pmap-wrapped EMA helpers. For p_apply_ema, in_axes=(0, None) maps the state
# over axis 0 and broadcasts the scalar decay to every device.
p_apply_ema = jax.pmap(apply_ema_decay, in_axes=(0, None), axis_name='batch')
p_copy_params_to_ema = jax.pmap(copy_params_to_ema, axis_name='batch')


In [None]:
from pprint import pprint
from pathlib import Path

def to_wandb_config(d, parent_key: str = '', sep: str ='.'):
    """Flatten a nested dict into a single level.

    Nested keys are joined with ``sep`` (prefixed by ``parent_key`` when it is
    non-empty) and ``Path`` values are converted to ``str``.

    Args:
        d: possibly nested dictionary.
        parent_key: key prefix for all entries of ``d``.
        sep: separator placed between key levels.

    Returns:
        A flat dict mapping joined keys to leaf values.
    """
    flat = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if isinstance(value, dict):
            # Recurse; the nested entries are merged under the joined prefix.
            flat.update(to_wandb_config(value, full_key, sep=sep))
            continue
        flat[full_key] = str(value) if isinstance(value, Path) else value
    return flat

# Start the wandb run. The nested config is flattened first (dotted keys,
# Path values as strings) via to_wandb_config.
wandb.init(
    job_type=c.wandb.job_type,
    entity=c.wandb.entity,
    project=c.project,
    config=to_wandb_config(c.dict),
    settings=wandb.Settings(start_method='fork'),  # idk why this is an issue
    dir=c.exp_path,
)

# Use the logged "train/step" value as the x-axis for every metric.
wandb.define_metric("*", step_metric="train/step")

# # Display a project workspace
# %wandb USERNAME/PROJECT
# # Display a single run
# %wandb USERNAME/PROJECT/runs/RUN_ID
# # Display a sweep
# %wandb USERNAME/PROJECT/sweeps/SWEEP_ID
# # Display a report
# %wandb USERNAME/PROJECT/reports/REPORT_ID
# # Specify the height of embedded iframe
# %wandb USERNAME/PROJECT -h 2048

In [None]:
def get_data_stats():
    """Compute the global pixel mean and std of the FashionMNIST training set.

    Downloads the dataset into ``c.data_dir`` if necessary, stacks all training
    images into one tensor and reduces over every pixel.

    Returns:
        ``(mean, std)`` as one-element tensors.
    """
    data_tr = datasets.FashionMNIST(
        root=c.data_dir, 
        download=True, 
        train=True,
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(28),
            transforms.CenterCrop(28),
        ]),
        # One-hot encode the label; the labels are not used for the statistics.
        target_transform = transforms.Compose([
        transforms.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
    ])
    )
    imgs = torch.stack([img_t for img_t, _ in data_tr], dim=3)
    # Flatten once and reuse the view for both reductions.
    flat = imgs.view(1,-1)
    mean = flat.mean(dim=1)
    std = flat.std(dim=1)
    print('Data mean: ', mean, 'Data std: ', std)
    return mean, std
# Dataset statistics, passed on to the loaders as mean/std.
mean, std = get_data_stats()
loader_tr, loader_test = get_fashion_loader(c.data.b_size, c.data_dir, mean=mean, std=std)  # is not an iterator or list
# Pull a single batch (images, labels) for inspection.
img, l = next(iter(loader_tr))

In [None]:
import numpy as np
# Build an image from the first sample: multiply by std, add mean, scale by 255
# and cast to uint8.
# NOTE(review): this is not the last expression of its cell, so the notebook
# will not display it; `Image` is not imported in the visible cells.
Image.fromarray(np.uint8((img[0, 0]*std+mean).cpu().numpy()*255))

def image_grid(imgs, rows, cols):
    """Paste ``rows*cols`` images into one grid image, filling row by row.

    The tile size is taken from the first image. The canvas is created with
    mode '1'.

    Args:
        imgs: sequence of images exposing ``.size``; its length must equal
            ``rows*cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        The assembled grid image.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('1', size=(cols*tile_w, rows*tile_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))

    return canvas
# Rebind the statistics from tensors to plain Python floats.
# NOTE(review): this makes the cell non-idempotent -- on a second run `std` is
# already a float and has no `.cpu()` method.
std = float(std.cpu().numpy())
mean = float(mean.cpu().numpy())

In [None]:
# Main training loop: one pmap-ed train step per batch, EMA maintenance, and
# periodic metric/image logging to wandb.
# NOTE(review): `rng`, `p_train_step` and `ema_decay_fn` are not defined in the
# visible cells.
train_metrics = []

for ep in range(c.n_epoch):
    # `step` restarts at 0 every epoch, so the EMA warm-up branch and the
    # logged "train/step" value repeat each epoch -- NOTE(review).
    for step, (b, target) in enumerate(loader_tr):
        b = jnp.squeeze(jnp.expand_dims(jnp.array(b.numpy()), axis=(0,-1)), axis=2) # p, B, H, W, C
        # One fresh key per local device, plus the carried-over `rng`.
        rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
        train_step_rng = jnp.array(train_step_rng)
        
        state, metrics = p_train_step(train_step_rng, state, b)

        # Up to and including `update_after_step` the EMA weights simply track
        # the raw parameters; afterwards they are decayed every `update_every` steps.
        if step <= c.ema.update_after_step:
            state = p_copy_params_to_ema(state)

        elif step % c.ema.update_every == 0:
            ema_decay = ema_decay_fn(step)
            state =  p_apply_ema(state, ema_decay)

        if step % c.log_metric_step == 0:

            # Only the metrics of the current step are collected here; the
            # list is reset again right after the summary is built.
            train_metrics.append(metrics)
            train_metrics = common_utils.get_metrics(train_metrics)
            
            summary = {
                f'train/{k}': v
                for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()
            }

            train_metrics = []
            # Undo the normalisation (b*std+mean), scale by 255 and cast to uint8;
            # the first 9 images of device slot 0 are logged.
            b = ((np.array(b)*std+mean)*255).astype(np.uint8)
            imgs = [wandb.Image(Image.fromarray(b[0, i].reshape(28, 28))) for i in range(9)]
            
            wandb.log({
                    "train/step": step, 
                    'train/sample': imgs,
                    **summary
            })
 
    print('Epoch: ', ep)

In [None]:
# Scratch check: maximum pixel value of the first image in `b`, which is left
# over from the training loop.
# Image.fromarray(b[0, 0].reshape(28, 28, 1))
b[0, 0].max()