In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from jax.config import config
config.update('jax_disable_jit', True)

from pathlib import Path
from typing import Callable
import wandb

import jax
from jax import numpy as jnp
from jax import random as rnd
from flax.training.train_state import TrainState
from flax.training import common_utils
from flax import linen as nn, jax_utils
import optax

from pyfig import Pyfig
from hwat import logabssumdet, create_masks

def wpr(d:dict):
    for k,v in d.items():
        typ = type(v) 
        has_shape = hasattr(v, 'shape')
        shape = v.shape if has_shape else None
        dtype = v.dtype if hasattr(v, 'dtype') else None
        mean = jnp.mean(v) if has_shape else v
        std = jnp.std(v) if has_shape else None
        print(k, f'\t mean={mean} \t std={std} \t shape={shape} \t dtype={dtype}') # \t type={typ}

# Method 1 completely refer to Pyfig:
    # - Can't get module in the args 
# Must have shape debug print

# 11am: 
# 1- Putting all variables into every Sub - done, it was a loopy mutable issue
# 2- Stop printing mask - done, moved masks to Ferminet
# 3- 

c = Pyfig(wandb_mode='disabled', debug=True) # online:on|disabled:off|offline:local, True: 
# c.d

2022-11-28 11:54:27.891218: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.6.55). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{}


In [3]:
class FermiNet(nn.Module):
    n_e: int = None
    n_u: int = None
    n_d: int = None
    compute_s_emb: Callable = None
    compute_p_emb: Callable = None
    compute_s_perm: Callable = None
    n_det: int = None
    n_fb: int = None
    n_fb_out: int = None
    n_pv: int = None
    n_sv: int = None

    @nn.compact
    def __call__(_i, x):

        p_mask_u, p_mask_d = create_masks(_i.n_e, _i.n_u)
        
        xu, xd = jnp.split(x, [_i.n_u,], axis=0)
        x_s_var = _i.compute_s_emb(x)
        x_p_var = _i.compute_p_emb(x)
        wpr(dict(x_s_var=x_s_var, x_p_var=x_p_var))

        x_s_res = x_p_res = 0.
        for _ in range(_i.n_fb):
            x_p_var = x_p_res = nn.tanh(nn.Dense(_i.n_pv)(x_p_var)) + x_p_res
            x_s_var = _i.compute_s_perm(x_s_var, x_p_var, p_mask_u, p_mask_d)
            x_s_var = x_s_res = nn.tanh(nn.Dense(_i.n_sv)(x_s_var)) + x_s_res
            wpr(dict(x_p_var=x_p_var, x_s_var=x_s_var))

        x_w = nn.tanh(nn.Dense(_i.n_fb_out)(x_s_var))
        x_wu, x_wd = jnp.split(x_w, [_i.n_u,], axis=0)
        x_wu = nn.tanh(nn.Dense(_i.n_det*_i.n_u)(x_wu))
        x_wd = nn.tanh(nn.Dense(_i.n_det*_i.n_d)(x_wd))
        wpr(dict(x_w=x_w, x_wu=x_wu, x_wd=x_wd))
        
        orb_u = jnp.stack((x_wu * jnp.exp(-nn.Dense(_i.n_u*_i.n_det)(-xu))).split(_i.n_det, axis=-1)) # (e, f(e)) (e, (f(e))*n_det)
        orb_d = jnp.stack((x_wd * jnp.exp(-nn.Dense(_i.n_d*_i.n_det)(-xd))).split(_i.n_det, axis=-1))
        wpr(dict(orb_u=orb_u, orb_d=orb_d))

        log_psi, sgn = logabssumdet(orb_u, orb_d)
        return log_psi

model = c.pass_arg(FermiNet)

rng = rnd.PRNGKey(1)
x = rnd.normal(rng, (c.data.n_e, 3))
params = model.init(rng, x)
model.apply(params, x)

from functools import partial

@partial(jax.pmap, axis_name='b')
def create_train_state(rng):
  model = c.pass_arg(FermiNet)
  x = rnd.normal(rng, (c.data.n_e, 3))
  params = model.init(rng, x)['params']
  tx = optax.sgd(c.opt.lr)
  return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

rng = rnd.split(rng, len(jax.devices()))
state = create_train_state(rng)
state = jax_utils.replicate(state)

@jax.jit
def train_step(params, state, b):
  
  b_energy = compute_energy(state)
  
  def loss_fn(p):
    model_out = state.apply_fn({'params':p}, b)
    log_psi, sgn = model_out
    return log_psi, sgn
  
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (log_psi, sgn), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  
  variables = {
    'E batch'           : b_energy,
    r'$\log\psi$ batch' : log_psi,
    r'sgn$(\cdot)$'     : sgn
  }
  
  metrics = compute_metric(b_energy, log_psi, sgn)
  return state, metrics

train_step = jax.pmap(train_step, axis_name='batch')

x_s_var 	 mean=1.9469865560531616 	 std=0.7550688982009888 	 shape=(10, 1) 	 dtype=float32
x_p_var 	 mean=2.524434804916382 	 std=1.3280136585235596 	 shape=(10, 10, 1) 	 dtype=float32
x_p_var 	 mean=-0.6352688670158386 	 std=0.47677209973335266 	 shape=(10, 10, 8) 	 dtype=float32
x_s_var 	 mean=0.09472795575857162 	 std=0.5215277075767517 	 shape=(10, 16) 	 dtype=float32
x_p_var 	 mean=-0.28023380041122437 	 std=0.4364957809448242 	 shape=(10, 10, 8) 	 dtype=float32
x_s_var 	 mean=0.23098309338092804 	 std=0.6535929441452026 	 shape=(10, 16) 	 dtype=float32
x_w 	 mean=-0.004265791270881891 	 std=0.5131012201309204 	 shape=(10, 64) 	 dtype=float32
x_wu 	 mean=-0.15545150637626648 	 std=0.505211353302002 	 shape=(5, 5) 	 dtype=float32
x_wd 	 mean=-0.07841941714286804 	 std=0.3952345550060272 	 shape=(5, 5) 	 dtype=float32
orb_u 	 mean=-0.7551324963569641 	 std=2.791072368621826 	 shape=(1, 5, 5) 	 dtype=float32
orb_d 	 mean=-0.0950351282954216 	 std=0.5227019190788269 	 shape=(1, 5, 5) 

In [None]:

wandb.define_metric("*", step_metric="train/step")

# # Display a project workspace
# %wandb USERNAME/PROJECT
# # Display a single run
# %wandb USERNAME/PROJECT/runs/RUN_ID
# # Display a sweep
# %wandb USERNAME/PROJECT/sweeps/SWEEP_ID
# # Display a report
# %wandb USERNAME/PROJECT/reports/REPORT_ID
# # Specify the height of embedded iframe
# %wandb USERNAME/PROJECT -h 2048

In [None]:
def get_data_stats():
    def normalize_to_neg_one_to_one(img):
        return img * 2 - 1  
    data_tr = datasets.FashionMNIST(
        root=c.data_dir, 
        download=True, 
        train=True,
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(28),
            transforms.CenterCrop(28),
        ]),
        target_transform = transforms.Compose([
        transforms.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
    ])
    )
    imgs = torch.stack([img_t for img_t, _ in data_tr], dim=3)
    mean = imgs.view(1,-1).mean(dim=1) 
    std = imgs.view(1,-1).std(dim=1)
    print('Data mean: ', mean, 'Data std: ', std)
    return mean, std
mean, std = get_data_stats()
loader_tr, loader_test = get_fashion_loader(c.data.b_size, c.data_dir, mean=mean, std=std)  # is not an iterator or list
img, l = next(iter(loader_tr))

In [None]:
import numpy as np
Image.fromarray(np.uint8((img[0, 0]*std+mean).cpu().numpy()*255))

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('1', size=(cols*w, rows*h))
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))

    return grid
std = float(std.cpu().numpy())
mean = float(mean.cpu().numpy())

In [None]:
train_metrics = []

for ep in range(c.n_epoch):
    for step, (b, target) in enumerate(loader_tr):
        b = jnp.squeeze(jnp.expand_dims(jnp.array(b.numpy()), axis=(0,-1)), axis=2) # p, B, H, W, C
        rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
        train_step_rng = jnp.array(train_step_rng)
        
        state, metrics = p_train_step(train_step_rng, state, b)

        if step <= c.ema.update_after_step:
            state = p_copy_params_to_ema(state)

        elif step % c.ema.update_every == 0:
            ema_decay = ema_decay_fn(step)
            state =  p_apply_ema(state, ema_decay)

        if step % c.log_metric_step == 0:

            train_metrics.append(metrics)
            train_metrics = common_utils.get_metrics(train_metrics)
            
            summary = {
                f'train/{k}': v
                for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()
            }

            train_metrics = []
            b = ((np.array(b)*std+mean)*255).astype(np.uint8)
            imgs = [wandb.Image(Image.fromarray(b[0, i].reshape(28, 28))) for i in range(9)]
            
            wandb.log({
                    "train/step": step, 
                    'train/sample': imgs,
                    **summary
            })
 
    print('Epoch: ', ep)

In [None]:
# Image.fromarray(b[0, 0].reshape(28, 28, 1))
b[0, 0].max()