In [None]:
%load_ext autoreload
%autoreload 2

import os
# Expose only GPU 0 to this process. Set before jax is imported so the
# restriction is already in place when JAX initialises its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jax
# Disable jit so jitted functions run op-by-op (easier debugging).
# `jax.config.update` is used instead of `from jax.config import config`
# because that submodule import path is not available in newer JAX releases.
jax.config.update('jax_disable_jit', True)
config = jax.config  # keep the `config` name bound for any later cell that uses it

path exists, leaving alone


In [None]:
from functools import partial, reduce
from pathlib import Path
from typing import Callable, Optional

import jax
import optax
import wandb
from flax import linen as nn, jax_utils
from flax.training import common_utils
from flax.training.train_state import TrainState
from jax import numpy as jnp
from jax import random as rnd

from pyfig import Pyfig
from hwat import create_masks, logabssumdet

# Build the experiment configuration object.
# wandb_mode: 'online' = on | 'disabled' = off | 'offline' = local only.
# NOTE(review): the original comment left the meaning of notebook=True blank — confirm against Pyfig.
c = Pyfig(wandb_mode='disabled', notebook=True)
# Show the configuration (behaviour is defined by Pyfig.print, which is not visible in this notebook).
c.print()

class FermiNet(nn.Module):
  """Flax module that maps an input `x` to `log_psi` via `logabssumdet`.

  The forward pass embeds `x` into a single stream and a pairwise stream,
  runs `n_fb` residual Dense blocks over both streams, merges them, and
  feeds per-spin orbital blocks to `logabssumdet`.

  Fields left at `None` are resolved inside `__call__`:
    * `n_d` defaults to `n_e - n_u`.
    * `compute_s_perm` defaults to the module-level `compute_s_perm`
      partially applied with the masks from `create_masks(n_e, n_u)`.
  These defaults cannot be computed in the class body because dataclass
  fields declared without a value are not bound as names there.
  """

  n_sv: int       # output width of the single-stream Dense layers
  n_pv: int       # output width of the pairwise-stream Dense layers
  n_fb: int       # number of residual blocks in the loop
  n_fb_out: int   # output width of the Dense layer applied after the streams are concatenated
  n_det: int      # number of chunks the orbital matrices are split into

  af_fb: Callable       # activation applied inside each residual block
  af_fb_out: Callable   # activation applied after the `n_fb_out` Dense layer
  af_pseudo: Callable   # activation applied after the `n_det*n_e` Dense layer

  compute_s_emb: Callable   # x -> initial single-stream features
  compute_p_emb: Callable   # x -> initial pairwise-stream features

  n_e: int   # presumably the total number of electrons — TODO confirm
  n_u: int   # presumably the number of spin-up electrons — TODO confirm
  n_d: Optional[int] = None                   # None -> n_e - n_u
  compute_s_perm: Optional[Callable] = None   # None -> default built in __call__

  pbc: bool = False   # not read anywhere in this class

  @nn.compact
  def __call__(self, x):
    """Return `log_psi` for input `x`; the sign from `logabssumdet` is discarded."""
    n_e, n_u, n_det = self.n_e, self.n_u, self.n_det
    n_d = n_e - n_u if self.n_d is None else self.n_d

    s_perm = self.compute_s_perm
    if s_perm is None:
      p_mask_u, p_mask_d = create_masks(n_e, n_u)
      # NOTE(review): `compute_s_perm` must exist at module level; no import
      # or definition of it is visible in this notebook.
      s_perm = partial(compute_s_perm, p_mask_u=p_mask_u, p_mask_d=p_mask_d, n_u=n_u)

    x_s_var = self.compute_s_emb(x)
    x_p_var = self.compute_p_emb(x)

    # Residual accumulators start at 0. and are updated alongside the streams.
    x_s_res = x_p_res = 0.
    for _ in range(self.n_fb):
      x_p_var = x_p_res = self.af_fb(nn.Dense(self.n_pv)(x_p_var)) + x_p_res

      x_s_var = s_perm(x_s_var, x_p_var)
      x_s_var = x_s_res = self.af_fb(nn.Dense(self.n_sv)(x_s_var)) + x_s_res

    # Fixed typo: the original called the non-existent `jnp.concatentate`.
    x_w = jnp.concatenate([x_s_var, x_p_var], axis=-1)
    x_w = self.af_fb_out(nn.Dense(self.n_fb_out)(x_w))
    x_w = self.af_pseudo(nn.Dense(n_det*n_e)(x_w))
    # Split once at row n_u so exactly two pieces are produced; the original
    # indices [n_u, n_d] yield three pieces, which cannot be unpacked into two names.
    x_wu, x_wd = jnp.split(x_w, [n_u], axis=0)

    # NOTE(review): the two lines below negate an *unapplied* nn.Dense module,
    # which raises TypeError. The intended envelope input is not recoverable
    # from this notebook, so the expressions are left unchanged — TODO supply it.
    orb_u = jnp.split(x_wu * jnp.exp(-nn.Dense(n_u*n_det)), n_det, axis=1)
    orb_d = jnp.split(x_wd * jnp.exp(-nn.Dense(n_d*n_det)), n_det, axis=1)

    log_psi, _ = logabssumdet(orb_u, orb_d)

    return log_psi

# Build the model from the config and run one init/apply round trip as a smoke test.
model = FermiNet(**c.dict)
# NOTE(review): hard-coded; presumably should agree with the n_e inside c.dict — confirm.
n_e = 10
rng_model = rnd.PRNGKey(1)
# The shape must be a single tuple: `jnp.ones(n_e, 3)` passes 3 as the dtype argument.
x = jnp.ones((n_e, 3))
params = model.init(rng_model, x)
model.apply(params, x)


# p_sample_step = jax.pmap(sample_step, axis_name='batch')

In [None]:

# jax.tree_map(lambda x: x.shape, params) # Check the parameters

def create_train_state(c: Pyfig):
  """Create a Flax `TrainState` with freshly initialised parameters and an SGD optimiser.

  Reads `c.rng_init` (PRNG key for initialisation) and `c.lr` (SGD learning rate).

  NOTE(review): `CNN` is not defined in any visible cell, and the dummy input
  has shape [1, 28, 28, 1] — confirm this is the intended model.
  """
  net = CNN()
  dummy_input = jnp.ones([1, 28, 28, 1])
  variables = net.init(c.rng_init, dummy_input)
  return TrainState.create(
    apply_fn=net.apply,
    params=variables['params'],
    tx=optax.sgd(c.lr),
  )

# Create the train state and replicate it with flax's jax_utils for use with the pmap'd functions.
state = jax_utils.replicate(create_train_state(c))

@jax.jit
def train_step(params, state, b):
  """Run one gradient update on batch `b` and return the new state with metrics.

  Args:
    params: not read by the body. NOTE(review): the training loop passes
      per-device PRNG keys in this position — confirm the intended signature.
    state: train state providing `apply_fn`, `params` and `apply_gradients`.
    b: input batch forwarded to `state.apply_fn`.

  Returns:
    `(state, metrics)`: the state after `apply_gradients`, and whatever
    `compute_metric(b_energy, log_psi, sgn)` returns.

  NOTE(review): `compute_energy` and `compute_metric` are not defined in any
  visible cell. With `has_aux=True`, `jax.value_and_grad` differentiates the
  first element returned by `loss_fn` (here `log_psi`), which must be a
  scalar; `b_energy` does not enter the differentiated value — confirm this
  is the intended loss.
  """
  b_energy = compute_energy(state)

  def loss_fn(p):
    # The model output is expected to unpack into exactly (log_psi, sgn).
    log_psi, sgn = state.apply_fn({'params': p}, b)
    return log_psi, sgn

  (log_psi, sgn), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)

  metrics = compute_metric(b_energy, log_psi, sgn)
  return state, metrics

# The training loop calls `p_train_step`, matching the `p_` prefix used for the
# other pmap'd helpers, so bind that name here.
p_train_step = jax.pmap(train_step, axis_name='batch')
# Keep the original rebinding so `train_step` still refers to the pmap'd function.
train_step = p_train_step

In [None]:
def copy_params_to_ema(state):
   """Return a copy of `state` whose `params_ema` is set to the current `params`."""
   return state.replace(params_ema=state.params)

def apply_ema_decay(state, ema_decay):
    """Blend the current parameters into the EMA parameters.

    Each leaf is updated as `p_ema * ema_decay + p * (1 - ema_decay)`.

    Args:
        state: object exposing `params`, `params_ema` (pytrees with the same
            structure) and a `replace(**fields)` method.
        ema_decay: weight given to the existing EMA value.

    Returns:
        A copy of `state` with the updated `params_ema`; `params` is unchanged.
    """
    def blend(p_ema, p):
        return p_ema * ema_decay + p * (1. - ema_decay)

    # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
    params_ema = jax.tree_util.tree_map(blend, state.params_ema, state.params)
    return state.replace(params_ema=params_ema)

# pmap'd EMA helpers, both mapped under the 'batch' axis name.
p_copy_params_to_ema = jax.pmap(copy_params_to_ema, axis_name='batch')
# in_axes=(0, None): `state` is mapped over its leading axis, `ema_decay` is broadcast unchanged.
p_apply_ema = jax.pmap(apply_ema_decay, in_axes=(0, None), axis_name='batch')


In [None]:
from pprint import pprint
from pathlib import Path

def to_wandb_config(d, parent_key: str = '', sep: str ='.'):
    """Flatten a nested dict into a single-level dict with `sep`-joined keys.

    Nested dicts are expanded recursively and their keys are prefixed with the
    parent key. `Path` values are converted to `str`; all other values are
    passed through untouched. Empty nested dicts contribute no entries.

    Args:
        d: mapping to flatten.
        parent_key: prefix for every key produced at this level ('' for none).
        sep: separator placed between a parent key and a child key.

    Returns:
        A new flat dict, in traversal order.
    """
    flat = {}
    for key, value in d.items():
        full_key = key if not parent_key else parent_key + sep + key
        if isinstance(value, dict):
            flat.update(to_wandb_config(value, full_key, sep=sep))
            continue
        flat[full_key] = str(value) if isinstance(value, Path) else value
    return flat

# Start the Weights & Biases run; the run config is the flattened Pyfig dict.
wandb.init(
    job_type=c.wandb.job_type,
    entity=c.wandb.entity,
    project=c.project,
    config=to_wandb_config(c.dict),
    settings=wandb.Settings(start_method='fork'),  # NOTE(review): original author was unsure why 'fork' is needed here — confirm it is still required
    dir=c.exp_path,
)

# Use the logged "train/step" value as the step metric for every metric ("*").
wandb.define_metric("*", step_metric="train/step")

# # Display a project workspace
# %wandb USERNAME/PROJECT
# # Display a single run
# %wandb USERNAME/PROJECT/runs/RUN_ID
# # Display a sweep
# %wandb USERNAME/PROJECT/sweeps/SWEEP_ID
# # Display a report
# %wandb USERNAME/PROJECT/reports/REPORT_ID
# # Specify the height of embedded iframe
# %wandb USERNAME/PROJECT -h 2048

In [None]:
def get_data_stats():
    """Compute the global pixel mean and std of the FashionMNIST training split.

    The dataset is read from (and downloaded to, if missing) `c.data_dir`.
    Every image is converted to a tensor, resized and centre-cropped to 28,
    then all images are stacked and reduced over every pixel.

    Returns:
        (mean, std): one-element torch tensors.

    NOTE(review): `datasets`, `transforms` and `torch` are not imported in any
    visible cell — presumably torchvision / torch.
    """
    # Labels are discarded below, so no target_transform is applied.
    data_tr = datasets.FashionMNIST(
        root=c.data_dir,
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(28),
            transforms.CenterCrop(28),
        ]),
    )
    imgs = torch.stack([img_t for img_t, _ in data_tr], dim=3)
    # Flatten once to a single row so mean and std reduce over all pixels.
    flat = imgs.view(1, -1)
    mean = flat.mean(dim=1)
    std = flat.std(dim=1)
    print('Data mean: ', mean, 'Data std: ', std)
    return mean, std
# Dataset-wide statistics, passed to the loader factory below.
mean, std = get_data_stats()
# NOTE(review): `get_fashion_loader` is not defined in any visible cell.
loader_tr, loader_test = get_fashion_loader(c.data.b_size, c.data_dir, mean=mean, std=std)  # is not an iterator or list
# Pull one batch (images, labels) for the inspection cell that follows.
img, l = next(iter(loader_tr))

In [None]:
import numpy as np
# Rescale the first image of the batch with std/mean, scale by 255 and wrap it as an 8-bit image.
# NOTE(review): `Image` is not imported in any visible cell (presumably PIL.Image), and this
# expression is not the last statement of the cell, so its value is not displayed.
Image.fromarray(np.uint8((img[0, 0]*std+mean).cpu().numpy()*255))

def image_grid(imgs, rows, cols):
    """Paste `imgs` row by row into a single `rows` x `cols` grid image.

    Every tile is assumed to have the size of `imgs[0]`. The grid is created
    with mode '1'.

    Args:
        imgs: sequence of images; its length must equal `rows*cols`.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        The new grid image.

    NOTE(review): `Image` is not imported in any visible cell — presumably PIL.Image.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('1', size=(cols*tile_w, rows*tile_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))

    return canvas
# Convert the one-element statistics to plain Python floats.
# `float(x)` works on a one-element tensor and on an existing float, so re-running
# the cell is harmless; it also avoids converting a 1-d NumPy array to a scalar.
std = float(std)
mean = float(mean)

In [None]:
train_metrics = []
# Counts optimisation steps across all epochs. `step` restarts at 0 every epoch,
# so using it for the EMA schedule would re-copy params into the EMA at the start
# of each epoch and would log a non-monotonic "train/step".
global_step = 0

# NOTE(review): `rng`, `ema_decay_fn` and `Image` are not defined in any visible
# cell — they must be in scope before this cell runs.
for ep in range(c.n_epoch):
    for step, (b, target) in enumerate(loader_tr):
        # Add a leading device axis and a trailing channel axis, then drop axis 2
        # (which must have size 1 for the squeeze to succeed).
        b = jnp.squeeze(jnp.expand_dims(jnp.array(b.numpy()), axis=(0,-1)), axis=2) # p, B, H, W, C
        # One fresh key per local device, plus the carried-over key.
        rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
        train_step_rng = jnp.array(train_step_rng)

        state, metrics = p_train_step(train_step_rng, state, b)

        if global_step <= c.ema.update_after_step:
            # Warm-up: the EMA simply tracks the raw parameters.
            state = p_copy_params_to_ema(state)

        elif global_step % c.ema.update_every == 0:
            ema_decay = ema_decay_fn(global_step)
            state = p_apply_ema(state, ema_decay)

        if global_step % c.log_metric_step == 0:

            train_metrics.append(metrics)
            train_metrics = common_utils.get_metrics(train_metrics)

            # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
            summary = {
                f'train/{k}': v
                for k, v in jax.tree_util.tree_map(lambda m: m.mean(), train_metrics).items()
            }

            train_metrics = []
            b = ((np.array(b)*std+mean)*255).astype(np.uint8)
            # Log at most 9 samples; a short batch may hold fewer.
            n_show = min(9, b.shape[1])
            imgs = [wandb.Image(Image.fromarray(b[0, i].reshape(28, 28))) for i in range(n_show)]

            wandb.log({
                    "train/step": global_step,
                    'train/sample': imgs,
                    **summary
            })

        global_step += 1

    print('Epoch: ', ep)

In [None]:
# Debug cell: inspect the maximum value of the first sample of the last batch
# processed by the training loop.
# Image.fromarray(b[0, 0].reshape(28, 28, 1))
b[0, 0].max()