In [2]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import random as rnd
import numpy as np


In [1]:
import torch as tc
from torch import nn
tc.set_default_dtype(tc.float64)
tc.normal(0., 1., (1,))

tensor([-0.0218])

In [4]:
def compare(tc_arr, jnp_arr):
    tc_arr = tc_arr.detach().cpu().numpy()
    diff = tc_arr - jnp_arr
    sum_diff = jnp.sum(diff)
    abs_sum_diff = jnp.sum(jnp.abs(diff))
    print('abs sum diff %.8f, sum diff %.8f' % (abs_sum_diff, sum_diff))
    return sum_diff

In [11]:
n_samples = 1000
n_atoms = 1
n_electrons = 5
n_sh = 20
n_ph = 16
n_up = 3
n_down = n_electrons - n_up
n_determinants = 2

re = np.random.normal(0, 1, (n_samples, n_electrons, 3))
ra = np.random.normal(0, 1, (n_atoms, 3))
ra = np.concatenate([ra[None, ...] for _ in range(n_samples)], axis=0)

ra_tc = tc.from_numpy(ra)
re_tc = tc.from_numpy(re)

print(ra.shape, re.shape)

(1000, 1, 3) (1000, 5, 3)


In [11]:
def compute_ae_vectors(r_electrons: jnp.array, 
                       r_atoms: jnp.array) -> jnp.array:
    
    r_atoms = jnp.expand_dims(r_atoms, axis=1)
    r_electrons = jnp.expand_dims(r_electrons, axis=2)
    ae_vectors = r_electrons - r_atoms
    return ae_vectors

def compute_ae_vectors_tc(r_atoms: tc.Tensor, r_electrons: tc.Tensor) -> tc.Tensor:
    # ae_vectors (n_samples, n_electrons, n_atoms, 3)
    r_atoms = r_atoms.unsqueeze(1)
    r_electrons = r_electrons.unsqueeze(2)
    ae_vectors = r_electrons - r_atoms
    return ae_vectors

ae_vectors_tc = compute_ae_vectors_tc(ra_tc, re_tc)
ae_vectors = compute_ae_vectors(re, ra)
print(ae_vectors.shape)
ae_diff = compare(ae_vectors_tc, ae_vectors)


(1000, 5, 1, 3)
abs sum diff 0.00000000, sum diff 0.000000


In [15]:
def compute_inputs_tc(r_electrons, 
                   n_samples : int, 
                   ae_vectors, 
                   n_atoms : int, 
                   n_electrons : int):
    
    # r_atoms: (n_atoms, 3)
    # r_electrons: (n_samples, n_electrons, 3)
    # ae_vectors: (n_samples, n_electrons, n_atoms, 3)
    
    ae_distances = tc.norm(ae_vectors, dim=-1, keepdim=True)
    single_inputs = tc.cat((ae_vectors, ae_distances), dim=-1)
    single_inputs = single_inputs.view((-1, n_electrons, 4 * n_atoms))

    re1 = r_electrons.unsqueeze(2)
    re2 = re1.permute((0, 2, 1, 3))
    ee_vectors = re1 - re2

#     mask = tc.eye(n_electrons, dtype=tc.bool)
#     mask = ~mask.unsqueeze(0).unsqueeze(3).repeat((n_samples, 1, 1, 3))

#     ee_vectors = ee_vectors[mask]
#     ee_vectors = ee_vectors.view((-1, int(n_electrons ** 2 - n_electrons), 3))
    
    ee_vectors = ee_vectors.view((-1, int(n_electrons ** 2), 3))
    ee_distances = tc.norm(ee_vectors, dim=-1, keepdim=True)

    pairwise_inputs = tc.cat((ee_vectors, ee_distances), dim=-1)

    return single_inputs, pairwise_inputs


def compute_inputs(r_electrons, ae_vectors):
    n_samples, n_electrons, n_atoms = ae_vectors.shape[:3]

    ae_distances = jnp.linalg.norm(ae_vectors, axis=-1, keepdims=True)
    single_inputs = jnp.concatenate([ae_vectors, ae_distances], axis=-1)
    single_inputs = single_inputs.reshape(n_samples, n_electrons, 4 * n_atoms)
    
    re1 = jnp.expand_dims(r_electrons, axis=2)
    re2 = jnp.transpose(re1, [0, 2, 1, 3])
    
    ee_vectors = re1 - re2
    ee_distances = jnp.linalg.norm(ee_vectors, axis=-1, keepdims=True)
    pairwise_inputs = jnp.concatenate([ee_vectors, ee_distances], axis=-1)
    pairwise_inputs = pairwise_inputs.reshape(n_samples, n_electrons**2, 4)

    return single_inputs, pairwise_inputs

jit_compute_inputs = jax.jit(compute_inputs)

%timeit compute_inputs_tc(re_tc, n_samples, ae_vectors_tc, n_atoms, n_electrons)
%timeit compute_inputs(re, ae_vectors)[0].block_until_ready()
%timeit jit_compute_inputs(re, ae_vectors)[0].block_until_ready()

sin_tc, pin_tc = compute_inputs_tc(re_tc, n_samples, ae_vectors_tc, n_atoms, n_electrons)

single_inputs, pairwise_inputs = compute_inputs(re, ae_vectors)

compare(sin_tc, single_inputs)
compare(pin_tc, pairwise_inputs)

2.71 ms ± 8.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
968 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
75.7 µs ± 7.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [50]:


def create_masks(n_sh, n_ph, n_electrons, n_up):
    
    # single spin masks
    n_down = n_electrons - n_up
    n_pairwise = n_electrons**2

    tmp1 = jnp.ones((1, n_up, n_sh))
    tmp2 = jnp.zeros((1, n_down, n_sh))
    single_up_mask = jnp.concatenate((tmp1, tmp2), axis=1)
    single_down_mask = (jnp.concatenate((tmp1, tmp2), axis=1)-1.)*-1.

    # pairwise spin masks
    ups = np.ones(n_electrons)
    ups[n_up:] = 0
    downs = (ups-1.)*-1.

    pairwise_up_mask = []
    pairwise_down_mask = []
    mask = np.zeros((n_electrons, n_electrons))

    for electron in range(n_electrons):
        e_mask_up = np.zeros((n_electrons,))
        e_mask_down = np.zeros((n_electrons,))

        mask_up = np.copy(mask)
        mask_up[electron, :] = ups
        mask_up = mask_up.reshape(-1)

        mask_down = np.copy(mask)
        mask_down[electron, :] = downs
        mask_down = mask_down.reshape(-1)

        pairwise_up_mask.append(mask_up)
        pairwise_down_mask.append(mask_down)

    pairwise_up_mask = jnp.array(pairwise_up_mask).reshape((1, n_electrons, n_pairwise, 1))
    pairwise_up_mask = jnp.repeat(pairwise_up_mask, n_ph, axis=-1)

    pairwise_down_mask = jnp.array(pairwise_down_mask).reshape((1, n_electrons, n_pairwise, 1))
    pairwise_down_mask = jnp.repeat(pairwise_down_mask, n_ph, axis=-1)
    return single_up_mask, single_down_mask, pairwise_up_mask, pairwise_down_mask
    
    


def generate_pairwise_masks(n_electrons: int, 
                            n_pairwise: int, 
                            n_spin_up: int, 
                            n_pairwise_features: int):
    
    eye_mask = ~np.eye(n_electrons, dtype=bool)
    ups = np.ones(n_electrons, dtype=bool)
    ups[n_spin_up:] = False
    downs = ~ups

    spin_up_mask = []
    spin_down_mask = []
    mask = np.zeros((n_electrons, n_electrons), dtype=bool)

    for electron in range(n_electrons):
        e_mask_up = np.zeros((n_electrons,), dtype=bool)
        e_mask_down = np.zeros((n_electrons,), dtype=bool)

        mask_up = np.copy(mask)
        mask_up[electron, :] = ups
        mask_up = mask_up.reshape(-1)


        mask_down = np.copy(mask)
        mask_down[electron, :] = downs
        mask_down = mask_down.reshape(-1)

        spin_up_mask.append(mask_up)
        spin_down_mask.append(mask_down)

    spin_up_mask = tc.tensor(spin_up_mask, dtype=tc.bool)
    # (n_samples, n_electrons, n_electrons, n_pairwise_features)
    spin_up_mask = spin_up_mask.view((1, n_electrons, n_pairwise, 1))
    spin_up_mask = spin_up_mask.repeat((1, 1, 1, n_pairwise_features))

    spin_down_mask = tc.tensor(spin_down_mask, dtype=tc.bool)
    spin_down_mask = spin_down_mask.view((1, n_electrons, n_pairwise, 1))
    spin_down_mask = spin_down_mask.repeat((1, 1, 1, n_pairwise_features))

    return spin_up_mask, spin_down_mask


class Mixer(nn.Module):
    def __init__(self, n_single_features, n_pairwise_features, n_electrons, n_spin_up, device, dtype):
        super(Mixer, self).__init__()
        self.dv = device
        self.dt = dtype

        n_spin_down = n_electrons - n_spin_up
        n_pairwise = n_electrons**2

        self.n_electrons = n_electrons
        self.n_spin_up = float(n_spin_up)
        self.n_spin_down = float(n_spin_down)

        tmp1 = tc.ones((1, n_spin_up, n_single_features), dtype=tc.bool, device=device)
        tmp2 = tc.zeros((1, n_spin_down, n_single_features), dtype=tc.bool, device=device)
        self.spin_up_mask = tc.cat((tmp1, tmp2), dim=1).type(dtype)
        self.spin_down_mask = (~tc.cat((tmp1, tmp2), dim=1)).type(dtype)

        self.pairwise_spin_up_mask, self.pairwise_spin_down_mask = \
            generate_pairwise_masks(n_electrons, n_pairwise, n_spin_up, n_pairwise_features)
        self.pairwise_spin_up_mask = self.pairwise_spin_up_mask.type(dtype).to(device)
        self.pairwise_spin_down_mask = self.pairwise_spin_down_mask.type(dtype).to(device)

    def forward(self, single: tc.Tensor, pairwise: tc.Tensor):
        # single (n_samples, n_electrons, n_single_features)
        # pairwise (n_samples, n_electrons, n_pairwise_features)
        # spin_up_mask = self.spin_up_mask.repeat((n_samples, 1, 1))
        # spin_down_mask = self.spin_down_mask.repeat((n_samples, 1, 1))

        # --- Single summations
        # up
        sum_spin_up = self.spin_up_mask * single
        sum_spin_up = sum_spin_up.sum(1, keepdim=True) / self.n_spin_up
        sum_spin_up = sum_spin_up.repeat((1, self.n_electrons, 1))

        # down
        sum_spin_down = self.spin_down_mask * single
        sum_spin_down = sum_spin_down.sum(1, keepdim=True) / self.n_spin_down
        sum_spin_down = sum_spin_down.repeat((1, self.n_electrons, 1))

        # --- Pairwise summations
        sum_pairwise = pairwise.unsqueeze(1).repeat((1, self.n_electrons, 1, 1))

        # up
        sum_pairwise_up = self.pairwise_spin_up_mask * sum_pairwise
        sum_pairwise_up = sum_pairwise_up.sum(2) / self.n_spin_up

        # down
        sum_pairwise_down = self.pairwise_spin_down_mask * sum_pairwise
        sum_pairwise_down = sum_pairwise_down.sum(2) / self.n_spin_down

        features = tc.cat((single, sum_pairwise_up, sum_pairwise_down, sum_spin_up, sum_spin_down), dim=2)
        return features

single = np.random.normal(0, 1, (n_samples, n_electrons, n_sh))
pairwise = np.random.normal(0, 1, (n_samples, n_electrons**2, n_ph))

single_tc = tc.from_numpy(single)
pairwise_tc = tc.from_numpy(pairwise)

tc_mixer = Mixer(n_sh, n_ph, n_electrons, n_up, 'cpu', tc.float64)
features_tc = tc_mixer(single_tc, pairwise_tc)

smu, smd, pmu, pmd = create_masks(n_sh, n_ph, n_electrons, n_up)
features = mixer(single, pairwise, smu, smd, pmu, pmd, n_electrons, n_up, n_down)

compare(features_tc, features)



  spin_up_mask = tc.tensor(spin_up_mask, dtype=tc.bool)
  spin_down_mask = tc.tensor(spin_down_mask, dtype=tc.bool)


abs sum diff 0.00000000, sum diff -0.00000000


DeviceArray(-1.22780475e-15, dtype=float64)

In [22]:
class EnvelopeLinear():
    def __init__(self,
                 w,
                 n_hidden,
                 n_spin_det,

                 n_samples,
                 n_determinants):
        
        self.out_shape = (-1, n_spin_det, n_determinants, n_spin_det)
        self.w = nn.Parameter(w, requires_grad=True)

    def __call__(self, data):
        n_samples, n_spin_det = data.shape[:2]
        
        bias_data = tc.ones((n_samples, n_spin_det, 1))
        data_w_bias = tc.cat((data, bias_data), dim=-1)
        
        out = tc.einsum('njf,kif->njki', data_w_bias, self.w)
        return out.view(self.out_shape)

def env_linear(params: jnp.array, data: jnp.array):
    n_samples, n_spin_det = data.shape[:2]
    
    bias = jnp.ones((n_samples, n_spin_det, 1))
    data = jnp.concatenate((data, bias), axis=-1)
    return jnp.einsum('njf,kif->njki', data, params)

w = np.random.normal(0, 1, (n_determinants, n_up, n_sh+1))
data = np.random.normal(0, 1, (n_samples, n_up, n_sh))
data_tc = tc.from_numpy(data)
w_tc = tc.from_numpy(w)

env_lin_tc = EnvelopeLinear(w_tc, n_sh, n_up, n_samples, n_determinants)

e_tc = env_lin_tc(data_tc)
e = env_linear(w, data)

compare(e_tc, e)


abs sum diff 0.00000000, sum diff 0.00000000


DeviceArray(3.30568906e-14, dtype=float64)

In [71]:
def mixer(single: jnp.array, 
          pairwise: jnp.array,
          n_electrons,
          n_up,
          n_down,
          single_up_mask,
          single_down_mask,
          pairwise_up_mask,
          pairwise_down_mask):
    # single (n_samples, n_electrons, n_single_features)
    # pairwise (n_samples, n_electrons, n_pairwise_features)
    # spin_up_mask = self.spin_up_mask.repeat((n_samples, 1, 1))
    # spin_down_mask = self.spin_down_mask.repeat((n_samples, 1, 1))

    # --- Single summations
    # up
    sum_spin_up = single_up_mask * single
    sum_spin_up = jnp.sum(sum_spin_up, axis=1, keepdims=True) / n_up
    sum_spin_up = jnp.repeat(sum_spin_up, n_electrons, axis=1)  # not needed in split

    # down
    sum_spin_down = single_down_mask * single
    sum_spin_down = jnp.sum(sum_spin_down, axis=1, keepdims=True) / n_down
    sum_spin_down = jnp.repeat(sum_spin_down, n_electrons, axis=1) # not needed in split

    # --- Pairwise summations
    sum_pairwise = jnp.repeat(jnp.expand_dims(pairwise, axis=1), n_electrons, axis=1)

    # up
    sum_pairwise_up = pairwise_up_mask * sum_pairwise
    sum_pairwise_up = jnp.sum(sum_pairwise_up, axis=2) / n_up

    # down
    sum_pairwise_down = pairwise_down_mask * sum_pairwise
    sum_pairwise_down = jnp.sum(sum_pairwise_down, axis=2) / n_down

    features = jnp.concatenate((single, sum_pairwise_up, sum_pairwise_down, sum_spin_up, sum_spin_down), axis=2)
    # split = jnp.concatenate((sum_spin_up, sum_spin_down), axis=2)
    return features


In [89]:
def create_masks(n_atom, n_up, n_layers, n_sh, n_ph):
    n_sh_in = 4 * n_atom
    n_ph_in = 4
    
    masks = [create_masks_layer(n_sh_in, n_ph_in, n_electrons, n_up)]
    
    for i in range(n_layers):
        masks.append(create_masks_layer(n_sh, n_ph, n_electrons, n_up))
    
    return masks

def create_masks_layer(n_sh, n_ph, n_electrons, n_up):
    
    # single spin masks
    n_down = n_electrons - n_up
    n_pairwise = n_electrons**2

    tmp1 = jnp.ones((1, n_up, n_sh))
    tmp2 = jnp.zeros((1, n_down, n_sh))
    single_up_mask = jnp.concatenate((tmp1, tmp2), axis=1)
    single_down_mask = (jnp.concatenate((tmp1, tmp2), axis=1)-1.)*-1.

    # pairwise spin masks
    ups = np.ones(n_electrons)
    ups[n_up:] = 0
    downs = (ups-1.)*-1.

    pairwise_up_mask = []
    pairwise_down_mask = []
    mask = np.zeros((n_electrons, n_electrons))

    for electron in range(n_electrons):
        e_mask_up = np.zeros((n_electrons,))
        e_mask_down = np.zeros((n_electrons,))

        mask_up = np.copy(mask)
        mask_up[electron, :] = ups
        mask_up = mask_up.reshape(-1)

        mask_down = np.copy(mask)
        mask_down[electron, :] = downs
        mask_down = mask_down.reshape(-1)

        pairwise_up_mask.append(mask_up)
        pairwise_down_mask.append(mask_down)

    pairwise_up_mask = jnp.array(pairwise_up_mask).reshape((1, n_electrons, n_pairwise, 1))
    pairwise_up_mask = jnp.repeat(pairwise_up_mask, n_ph, axis=-1)

    pairwise_down_mask = jnp.array(pairwise_down_mask).reshape((1, n_electrons, n_pairwise, 1))
    pairwise_down_mask = jnp.repeat(pairwise_down_mask, n_ph, axis=-1)
    return single_up_mask, single_down_mask, pairwise_up_mask, pairwise_down_mask

def compute_ae_vectors(r_electrons: jnp.array, 
                       r_atoms: jnp.array) -> jnp.array:
    
    r_atoms = jnp.expand_dims(r_atoms, axis=1)
    r_electrons = jnp.expand_dims(r_electrons, axis=2)
    ae_vectors = r_electrons - r_atoms
    return ae_vectors

def compute_inputs(r_electrons, ae_vectors):
    n_samples, n_electrons, n_atoms = ae_vectors.shape[:3]

    ae_distances = jnp.linalg.norm(ae_vectors, axis=-1, keepdims=True)
    single_inputs = jnp.concatenate([ae_vectors, ae_distances], axis=-1)
    single_inputs = single_inputs.reshape(n_samples, n_electrons, 4 * n_atoms)
    
    re1 = jnp.expand_dims(r_electrons, axis=2)
    re2 = jnp.transpose(re1, [0, 2, 1, 3])
    
    ee_vectors = re1 - re2
    ee_distances = jnp.linalg.norm(ee_vectors, axis=-1, keepdims=True)
    pairwise_inputs = jnp.concatenate([ee_vectors, ee_distances], axis=-1)
    pairwise_inputs = pairwise_inputs.reshape(n_samples, n_electrons**2, 4)

    return single_inputs, pairwise_inputs

def linear(p: jnp.array, 
           data: jnp.array)-> jnp.array:
    bias = jnp.ones((*data.shape[:-1], 1))
    data = jnp.concatenate([data, bias], axis=-1)

    return jnp.tanh(jnp.dot(data, p))

def env_linear(params: jnp.array, 
               data: jnp.array) -> jnp.array:
    bias = jnp.ones((*data.shape[:-1], 1))
    data = jnp.concatenate((data, bias), axis=-1)
    
    return jnp.einsum('njf,kif->njki', data, params)
    
def env_sigma(sigma: jnp.array,
              ae_vectors: jnp.array) -> jnp.array:
    
    exponent = jnp.einsum('njmv,kimvc->njkimc', ae_vectors, sigma)
    return jnp.exp(-jnp.linalg.norm(exponent, axis=-1))
    
def env_pi(pi: jnp.array,
           factor: jnp.array,
           exponential: jnp.array) -> jnp.array:
    
    orbitals = factor * jnp.einsum('njkim,kim->njki', exponential, pi)
    return jnp.transpose(orbitals, [0, 2, 1, 3])

def logabssumdet(orb_up: jnp.array,
                 orb_down: jnp.array) -> jnp.array:
    
    s_up, log_up = jnp.linalg.slogdet(orb_up)
    s_down, log_down = jnp.linalg.slogdet(orb_down)
    
    logdet_sum = log_up + log_down
    logdet_max = jnp.max(logdet_sum)
    
    argument = s_up * s_down * jnp.exp(logdet_sum - logdet_max)
    
    return jnp.log(jnp.abs(jnp.sum(argument, axis=1))) + logdet_max

In [97]:
def count_mixed_features(n_sh, n_ph):
    #     n_sh_mix = 2 * n_ph + n_sh # change mixer
    return 3 * n_sh + 2 * n_ph

def initialise_params(key,
                      n_atom: int,
                      n_up: int,
                      n_down: int,
                      n_layers: int = 2,
                      n_sh: int = 16,
                      n_ph: int = 8,
                      n_det: int = 1):
    '''
    
    
    Notes:
    zip(*([iter(nums)]*2) nice idiom for iterating over in sets of 2
    '''
    
    # count the number of input features
    n_sh_in = 4 * n_atom
    n_ph_in = 4
    
    # count the features in the intermediate layers
    n_sh_mix = count_mixed_features(n_sh, n_ph)


    params = {'envelopes':{}}
    
    # initial layers
    key, subkey = rnd.split(key)
    params['s0'] = rnd.normal(subkey, (count_mixed_features(n_sh_in, n_ph_in) + 1, n_sh))
    
    key, subkey = rnd.split(key)
    params['p0'] = rnd.normal(subkey, (n_ph_in + 1, n_ph)) 
    
    # intermediate layers
    key, *subkeys = rnd.split(key, num=(n_layers*2))
    params['intermediate'] = [[rnd.normal(sk1, (n_sh_mix + 1, n_sh)), rnd.normal(sk2, (n_ph + 1, n_ph))] 
                              for sk1, sk2 in zip(*([iter(subkeys)]*2))]
    
    # env_linear
    key, *subkeys = rnd.split(key, num=3)
    params['envelopes']['linear'] = [rnd.normal(subkeys[0], (n_det, n_up, n_sh + 1)),
                                     rnd.normal(subkeys[1], (n_det, n_down, n_sh + 1))]

    # env_sigma
    key, *subkeys = rnd.split(key, num=3)
    params['envelopes']['sigma'] = [rnd.normal(subkeys[0], (n_det, n_up, n_atom, 3, 3)),
                                    rnd.normal(subkeys[1], (n_det, n_down, n_atom, 3, 3))]
    
    # env_pi
    key, *subkeys = rnd.split(key, num=3)
    params['envelopes']['pi'] = [rnd.normal(subkeys[0], (n_det, n_up, n_atom)),
                                 rnd.normal(subkeys[1], (n_det, n_down, n_atom))]
    
    return params


def model(params, r_electrons, r_atoms, masks):
    n_samples, n_electrons = r_electrons.shape[:2]
    n_atoms = r_atoms.shape[1]
    
    # ae_vectors
    ae_vectors = compute_ae_vectors(r_electrons, r_atoms)
    
    # compute the inputs
    single, pairwise = compute_inputs(r_electrons, ae_vectors)
    print(single.shape, pairwise.shape)
    
    # mix the inputs
    print([x.shape for x in masks[0]])
    single = mixer(single, pairwise, n_electrons, n_up, n_down, *masks[0])
    
    # initial streams s0 and p0
    print(single.shape, pairwise.shape, params['s0'].shape, params['p0'].shape)
    single = linear(params['s0'], single)
    pairwise = linear(params['p0'], pairwise)
    
    # intermediate layers including mix
    for (s_params, p_params), mask in zip(params['intermediate'], masks[1:]):
        print(s_params.shape, p_params.shape, [x.shape for x in mask])
        
        single_mixed = mixer(single, pairwise, n_electrons, n_up, n_down, *mask)
        print(single_mixed.shape)
        
        single = linear(s_params, single_mixed) + single
        pairwise = linear(p_params, pairwise) + pairwise
        print(single.shape, pairwise.shape)
    
    # split
    ae_up, ae_down = jnp.split(ae_vectors, [n_up], axis=1)
    data_up, data_down = jnp.split(single, [n_up], axis=1)
    
    # envelopes
    # linear
    factor_up = env_linear(params['envelopes']['linear'][0], data_up)
    factor_down = env_linear(params['envelopes']['linear'][1], data_down)
    
    # sigma
    exp_up = env_sigma(params['envelopes']['sigma'][0], ae_up)
    exp_down = env_sigma(params['envelopes']['sigma'][1], ae_down)
    
    # pi
    print(exp_up.shape, exp_down.shape, params['envelopes']['pi'][0].shape, params['envelopes']['pi'][1].shape)
    orb_up = env_pi(params['envelopes']['pi'][0], factor_up, exp_up)
    orb_down = env_pi(params['envelopes']['pi'][1], factor_down, exp_down)
        
    # logabssumdet
    print(orb_up.shape, orb_down.shape)
    log_psi = logabssumdet(orb_up, orb_down)
    
    return jnp.sum(log_psi)

def create_atom_batch(r_atoms, n_samples):
    return jnp.repeat(jnp.expand_dims(r_atoms, axis=0), n_samples, axis=0)



(100, 5, 3) (100, 2, 3)
(100, 5, 8) (100, 25, 4)
[(1, 5, 8), (1, 5, 8), (1, 5, 25, 4), (1, 5, 25, 4)]
(100, 5, 32) (100, 25, 4) (33, 20) (5, 10)
(81, 20) (11, 10) [(1, 5, 20), (1, 5, 20), (1, 5, 25, 10), (1, 5, 25, 10)]
(100, 5, 80)
(100, 5, 20) (100, 25, 10)
(100, 3, 5, 3, 2) (100, 2, 5, 2, 2) (5, 3, 2) (5, 2, 2)
(100, 5, 3, 3) (100, 5, 2, 2)
(100, 5, 8) (100, 25, 4)
[(1, 5, 8), (1, 5, 8), (1, 5, 25, 4), (1, 5, 25, 4)]
(100, 5, 32) (100, 25, 4) (33, 20) (5, 10)
(81, 20) (11, 10) [(1, 5, 20), (1, 5, 20), (1, 5, 25, 10), (1, 5, 25, 10)]
(100, 5, 80)
(100, 5, 20) (100, 25, 10)
(100, 3, 5, 3, 2) (100, 2, 5, 2, 2) (5, 3, 2) (5, 2, 2)
(100, 5, 3, 3) (100, 5, 2, 2)


TypeError: Cannot interpret '1' as a data type

In [50]:
key, *subkeys = rnd.split(key, num=8)
for k in subkeys[::2]:
    print(k)
    
nums = (669256.02, 6117662.09, 669258.61, 6117664.39, 669258.05, 6117665.08)
ls = [iter(nums)]*2
for x, y in zip(*[iter(nums)]*2):
        print(x, y)
        
print(list([list(l) for l in ls]))      
list(zip(*([iter(nums)]*2) ))

def create_iterator(lst, n=2):
#     for a, b in zip(*([iter(nums)]*n):  # hacky way to do this
    it = iter(lst)
    for a, b in zip(it, it):
        print(a, b) # goes in sets of 2

[3818937504 2085278926]
[ 798071906 3899491848]
[1393771798 3194384582]
[ 175757679 2955691801]
669256.02 6117662.09
669258.61 6117664.39
669258.05 6117665.08
[[669256.02, 6117662.09, 669258.61, 6117664.39, 669258.05, 6117665.08], []]


SyntaxError: can't use starred expression here (<ipython-input-50-c88e322d9ef0>, line 12)

In [34]:
import functools
x = np.random.normal(0, 1, (n_samples, n_determinants, n_up, n_up))
y = np.random.normal(0, 1, (n_samples, n_determinants, n_down, n_down))
print(x.shape, y.shape)
z = jnp.linalg.slogdet(x)
xs = [x, y]
slogdets = [jnp.linalg.slogdet(x) for x in xs]
sign_in, logdet = functools.reduce(
      lambda a, b: (a[0]*b[0], a[1]+b[1]), slogdets)

logdet.shape
slogdets[0]

(1000, 2, 3, 3) (1000, 2, 2, 2)


(Buffer([[ 1., -1.],
         [ 1.,  1.],
         [ 1.,  1.],
         ...,
         [-1., -1.],
         [-1., -1.],
         [ 1., -1.]], dtype=float64),
 Buffer([[-2.82008323, -0.09710386],
         [-2.47097002, -3.10294813],
         [ 1.54502232,  1.7057895 ],
         ...,
         [-0.41479045,  0.7050965 ],
         [-0.71533259,  1.07479973],
         [ 0.19064791,  0.66715483]], dtype=float64))

In [None]:
# import random
# import itertools

# import jax
# import jax.numpy as np
# # Current convention is to import original numpy as "onp"
# import numpy as onp

# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2, hidden) + b2)

# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])


def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  #b2
    ]

loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

In [None]:
# randomness
key = rnd.PRNGKey(1)
print(key)
key, subkey = rnd.split(key)
print(key)
print(subkey)
key, *subkeys = rnd.split(key, num=3)
print(key, subkeys)
rnd.split(subkeys[0])[0]

In [None]:
# a bias layer 
m = 2
ne = 3
v = 4
data = rnd.normal(key, (m, ne, v))
p = rnd.normal(key, (v + 1, v))
print(data)
bias = jnp.ones((*data.shape[:-1], 1))
data = jnp.concatenate([data, bias], axis=-1)
print(data.shape, p.shape)
out = jnp.dot(data, p)
print(out.shape)

In [9]:
# jax einsum
m, ne, v = 2, 5, 10
key = rnd.PRNGKey(1)
x = rnd.normal(key, (m, ne, v))
y = rnd.normal(key, (m, ne, v))
z = jnp.einsum('ijk,imn->jkmn', x, y)
print(z.shape)

(5, 10, 5, 10)


In [None]:
# block until ready only works on a single output
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0, x

x = jnp.ones((5000, 5000))
%timeit slow_f(x).block_until_ready()

fast_f = jit(slow_f)

# Results are the same
assert jnp.allclose(slow_f(x), fast_f(x))

%timeit fast_f(x).block_until_ready()

In [None]:
# Notes
# By default, jax.grad will find the gradient with respect to the first argument.
# To find the gradient with respect to a different argument (or several), you can set argnums
# jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Find gradient wrt both x & y

In [21]:
#
re.shape
x = jnp.repeat(re, 2, axis=1)

In [25]:
x.shape
re.shape

(1000, 5, 3)