In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/xmax/projects/nn_ansatz/src')

In [63]:
from jax import random as rnd
from collections import OrderedDict

key = rnd.PRNGKey(1)
n_atom, n_up = 1, 1
in_lin, out_lin = 3, 10
n_det = 2
n_walkers = 2

walkers = rnd.normal(key, (n_walkers, 1, 3))
r_atoms = jnp.array([[0.0, 0.0, 0.0]])
params = OrderedDict()
params['linear'] = rnd.normal(key, (in_lin, out_lin))
params['env_linear'] = rnd.normal(key, (n_det, n_up, out_lin))
params['env_sigma'] = rnd.normal(key, (n_det, n_up, n_atom, 3, 3))
params['env_pi'] = rnd.normal(key, (n_det, n_up, n_atom))


def compute_ae_vectors_i(walkers: jnp.array,
                         r_atoms: jnp.array) -> jnp.array:
    r_atoms = jnp.expand_dims(r_atoms, axis=0)
    walkers = jnp.expand_dims(walkers, axis=1)
    ae_vectors = walkers - r_atoms
    return ae_vectors

def slogdet(x):
  if x.shape[-1] == 1:
    sign = jnp.sign(x[..., 0, 0])
    logdet = jnp.log(jnp.abs(x[..., 0, 0]))
  else:
    sign, logdet = jnp.linalg.slogdet(x)

  return sign, logdet


def logdet_matmul(xs) -> jnp.ndarray:
  dets = [x.reshape(*x.shape[:2], -1) for x in xs if x.shape[-1] == 1]
  dets = functools.reduce(
    lambda a, b: a*b, dets
  ) if len(dets) > 0 else 1.

  slogdets = [slogdet(x) for x in xs if x.shape[-1] > 1]
  maxlogdet = 0
  if len(slogdets) > 0:
    sign_in, logdet = functools.reduce(
      lambda a, b: (a[:, 0]*b[:, 0], a[:, 1]+b[:, 1]), slogdets
    )

    maxlogdet = jnp.max(logdet)
    det = sign_in * dets * jnp.exp(logdet - maxlogdet)
  else:
    det = dets

  result = jnp.sum(det, axis=1)

  sign_out = jnp.sign(result)
  log_out = jnp.log(jnp.abs(result)) + maxlogdet
  return log_out


def second_derivative(wf):

    def _lapl_over_f(params, walkers):
        walkers = walkers.reshape(n_walkers, -1)
        n = walkers.shape[-1]
        eye = jnp.eye(n, dtype=walkers.dtype)[None, ...].repeat(n_walkers, axis=0)
        grad_f = jax.grad(wf, argnums=1)
        grad_f_closure = lambda y: grad_f(params, y) 

        def _body_fun(i, val):
            print(walkers.shape, eye.shape, eye[..., i].shape)
            primal, tangent = jax.jvp(grad_f_closure, (walkers,), (eye[..., i],))
            return val + primal[:, i]**2 + tangent[:, i]

        return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)
    return _lapl_over_f


def create_wf(r_atoms):

    def _wf(params, walkers):

        

        ae_vectors = walkers[:, None, ...] - r_atoms[None, None, ...] 

        linear = jnp.tanh(walkers @ params['linear'])
        
        print(linear.shape)
        factor = jnp.einsum('njf,kif->nkij', linear, params['env_linear'])

        argument = jnp.einsum('njmv,kimvc->njkimc', ae_vectors, params['env_sigma'])
        print(argument.shape)
        exponent = jnp.linalg.norm(argument, axis=-1)
        exponential = jnp.exp(-exponent)
        print(exponential.shape)
        orbitals = factor * jnp.einsum('njkim,kim->nkij', exponential, params['env_pi']) 
        print(orbitals.shape)
        log_psi = logdet_matmul([orbitals, orbitals])

        return log_psi.squeeze()
    
    return _wf

wf = create_wf(r_atoms)
log_psi = wf(params, walkers)
print(log_psi.shape)
grad = second_derivative(wf)

grad(params, walkers)








(2, 1, 10)
(2, 1, 2, 1, 1, 3)
(2, 1, 2, 1, 1)
(2, 2, 1, 1)
(2,)
(2, 3) (2, 3, 3) (2, 3)
(2, 10)


ValueError: Einstein sum subscript 'njf' does not contain the correct number of indices for operand 0.

In [38]:
from nn_ansatz import *
from jax import vmap
import jax.numpy as jnp
from jax import vmap, lax
import numpy as np
import functools


lr, damping, nc = 1e-4, 1e-4, 1e-4
n_pre_it = 500
n_walkers = 512
n_layers = 2
n_sh = 64
n_ph = 16
n_det = 8
n_it = 1000
seed = 1


cfg = setup(system='LiSolid',
               n_pre_it=500,
               n_walkers=8,
               n_layers=2,
               n_sh=16,
               n_ph=4,
               opt='kfac',
               n_det=4,
               print_every=1,
               save_every=5000,
               lr=lr,
               n_it=1000,
               norm_constraint=nc,
               damping=damping)

mol = SystemAnsatz(**cfg)

def create_wf(mol):
    
    n_up, n_down, r_atoms, n_el, min_cell_width = mol.n_up, mol.n_down, mol.r_atoms, mol.n_el, mol.min_cell_width
    masks = create_masks(mol.n_atoms, mol.n_el, mol.n_up, mol.n_layers, mol.n_sh, mol.n_ph)

    compute_inputs_i = create_compute_inputs_i(mol)

    def _wf_orbitals(params, walkers, d0s):

        if len(walkers.shape) == 1:  # this is a hack to get around the jvp
            walkers = walkers.reshape(n_up + n_down, 3)

        activations = []

        ae_vectors = compute_ae_vectors_i(walkers, r_atoms)

        single, pairwise = compute_inputs_i(walkers, ae_vectors)

        single_mixed, split = mixer_i(single, pairwise, n_el, n_up, n_down, *masks[0])

        split = linear_split(params['split0'], split, activations, d0s['split0'])
        single = linear(params['s0'], single_mixed, split, activations, d0s['s0'])
        pairwise = linear_pairwise(params['p0'], pairwise, activations, d0s['p0'])

        for (split_params, s_params, p_params), (split_per, s_per, p_per), mask \
                in zip(params['intermediate'], d0s['intermediate'], masks[1:]):
            single_mixed, split = mixer_i(single, pairwise, n_el, n_up, n_down, *mask)

            split = linear_split(split_params, split, activations, split_per)
            single = linear(s_params, single_mixed, split, activations, s_per) + single
            pairwise = linear_pairwise(p_params, pairwise, activations, p_per) + pairwise

        ae_up, ae_down = jnp.split(ae_vectors, [n_up], axis=0)
        data_up, data_down = jnp.split(single, [n_up], axis=0)

        factor_up = env_linear_i(params['envelopes']['linear'][0], data_up, activations, d0s['envelopes']['linear'][0])
        factor_down = env_linear_i(params['envelopes']['linear'][1], data_down, activations, d0s['envelopes']['linear'][1])

        exp_up = env_sigma_i(params['envelopes']['sigma']['up'], ae_up, activations, d0s['envelopes']['sigma']['up'], min_cell_width)
        exp_down = env_sigma_i(params['envelopes']['sigma']['down'], ae_down, activations, d0s['envelopes']['sigma']['down'], min_cell_width)

        orb_up = env_pi_i(params['envelopes']['pi'][0], factor_up, exp_up, activations, d0s['envelopes']['pi'][0])
        orb_down = env_pi_i(params['envelopes']['pi'][1], factor_down, exp_down, activations, d0s['envelopes']['pi'][1])
        return orb_up, orb_down, activations

    def _wf(params, walkers, d0s):

        orb_up, orb_down, _ = _wf_orbitals(params, walkers, d0s)
        log_psi = logdet_matmul([orb_up, orb_up])
        return log_psi

    def _orbs(params, walkers, d0s):

        orb_up, orb_down, _ = _wf_orbitals(params, walkers, d0s)
        return orb_up, orb_down

    return _wf, _orbs

def compute_ae_vectors_i(walkers: jnp.array,
                         r_atoms: jnp.array) -> jnp.array:
    r_atoms = jnp.expand_dims(r_atoms, axis=0)
    walkers = jnp.expand_dims(walkers, axis=1)
    ae_vectors = walkers - r_atoms
    return ae_vectors

def create_compute_inputs_i(mol):

    def _compute_ee_vectors(walkers):
        re1 = jnp.expand_dims(walkers, axis=1)
        re2 = jnp.transpose(re1, [1, 0, 2])
        ee_vectors = re1 - re2
        return ee_vectors

    def _compute_ee_vectors_periodic(walkers):
        unit_cell_walkers = walkers.dot(mol.inv_real_basis)  # translate to the unit cell
        unit_cell_ee_vectors = _compute_ee_vectors(unit_cell_walkers)
        min_image_unit_cell_ee_vectors = unit_cell_ee_vectors - (2 * unit_cell_ee_vectors).astype(int) * 1. 
        min_image_ee_vectors = min_image_unit_cell_ee_vectors.dot(mol.real_basis)
        return min_image_ee_vectors

    compute_ee_vectors = _compute_ee_vectors_periodic

    def compute_inputs_i(walkers, ae_vectors):

        n_electrons, n_atoms = ae_vectors.shape[:2]

        ae_distances = jnp.linalg.norm(ae_vectors, axis=-1, keepdims=True)
        single_inputs = jnp.concatenate([ae_vectors, ae_distances], axis=-1)
        single_inputs = single_inputs.reshape(n_electrons, 4 * n_atoms)

        ee_vectors = compute_ee_vectors(walkers)
        ee_vectors = drop_diagonal_i(ee_vectors)
        ee_distances = jnp.linalg.norm(ee_vectors, axis=-1, keepdims=True)
        pairwise_inputs = jnp.concatenate([ee_vectors, ee_distances], axis=-1)

        return single_inputs, pairwise_inputs

    return compute_inputs_i

def env_linear_i(params: jnp.array,
                 data: jnp.array,
                 activations: list,
                 d0: jnp.array) -> jnp.array:
    n_spins = data.shape[0]

    bias = jnp.ones((n_spins, 1))
    activation = jnp.concatenate((data, bias), axis=1)
    activations.append(activation)
    pre_activations = jnp.matmul(activation, params) + d0
    pre_activations = jnp.transpose(pre_activations).reshape(-1, n_spins, n_spins)

    return pre_activations

def env_sigma_i(sigmas: jnp.array,
                ae_vectors: jnp.array,
                activations: list,
                d0s: jnp.array,
                min_cell_width: float) -> jnp.array:
    
    # SIGMA BROADCAST VERSION
    n_spin, n_atom, _ = ae_vectors.shape
    ae_vectors = [jnp.squeeze(x) for x in jnp.split(ae_vectors, n_atom, axis=1)]
    outs = []
    for ae_vector, sigma, d0 in zip(ae_vectors, sigmas, d0s):
        activations.append(ae_vector)
        
        # pre_activation = jnp.matmul(ae_vector, sigma) + d0
        # exponent = pre_activation.reshape(n_spin, 3, -1, n_spin, 1, order='F')
        # exponent = jnp.linalg.norm(exponent, axis=1)
        # out = jnp.exp(-exponent)

        pre_activation = jnp.matmul(ae_vector, sigma) + d0
        exponent = pre_activation.reshape(n_spin, 3, -1, n_spin, 1, order='F')
        exponent = jnp.linalg.norm(exponent, axis=1)
        out = jnp.exp(-exponent)

        mask = lax.stop_gradient((exponent < min_cell_width / 2.).astype(out.dtype))
        print(mask)

        out = out * mask 
        # out = jnp.where(exponent < min_cell_width / 2., out, jnp.zeros_like(out))

        outs.append(out)
    return jnp.concatenate(outs, axis=-1)


def slogdet(x):
  if x.shape[-1] == 1:
    sign = jnp.sign(x[..., 0, 0])
    logdet = jnp.log(jnp.abs(x[..., 0, 0]))
  else:
    sign, logdet = jnp.linalg.slogdet(x)

  return sign, logdet


def logdet_matmul(xs) -> jnp.ndarray:
  dets = [x.reshape(-1) for x in xs if x.shape[-1] == 1]
  dets = functools.reduce(
    lambda a, b: a*b, dets
  ) if len(dets) > 0 else 1.

  slogdets = [slogdet(x) for x in xs if x.shape[-1] > 1]
  maxlogdet = 0
  if len(slogdets) > 0:
    sign_in, logdet = functools.reduce(
      lambda a, b: (a[0]*b[0], a[1]+b[1]), slogdets
    )

    maxlogdet = jnp.max(logdet)
    det = sign_in * dets * jnp.exp(logdet - maxlogdet)
  else:
    det = dets

  result = jnp.sum(det)

  sign_out = jnp.sign(result)
  log_out = jnp.log(jnp.abs(result)) + maxlogdet
  return log_out

def local_kinetic_energy_i(wf):

    def _lapl_over_f(params, walkers, d0s):
        walkers = walkers.reshape(-1)
        n = walkers.shape[0]
        eye = jnp.eye(n, dtype=walkers.dtype)
        grad_f = jax.grad(wf, argnums=1)
        grad_f_closure = lambda y: grad_f(params, y, d0s) 

        def _body_fun(i, val):

            primal, tangent = jax.jvp(grad_f_closure, (walkers,), (eye[..., i],))
            return val + primal[i]**2 + tangent[i]

        return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)

    return _lapl_over_f

def create_energy_fn(wf, mol):

    r_atoms, z_atoms = mol.r_atoms, mol.z_atoms

    local_kinetic_energy = vmap(local_kinetic_energy_i(wf), in_axes=(None, 0, 0))

    def _compute_local_energy(params, walkers, d0s):
        kinetic_energy = local_kinetic_energy(params, walkers, d0s)
        return kinetic_energy

    return _compute_local_energy

wf, orbs = create_wf(mol)

compute_energy = vmap(create_energy_fn(wf, mol), in_axes=(None, 0, 0))  # this replaces the pmap!

key = rnd.PRNGKey(1)
params = initialise_params(key, mol)
d0s = initialise_d0s(mol, cfg['n_devices'], cfg['n_walkers_per_device'])
walkers = mol.initialise_walkers(walkers=None, **cfg)

# e_locs = compute_energy(params, walkers, d0s)

orbs = vmap(vmap(orbs, in_axes=(None, 0, 0)), in_axes=(None, 0, 0))
ups, downs = orbs(params, walkers, d0s)

print(ups)


print(downs)

# print(e_locs)




               


version 		 090621
seed 		 369
n_devices 		 1
save_every 		 5000
print_every 		 1
exp_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/junk/kfac_1lr-4_1d-4_1nc-4_m8_s16_p4_l2_det4/run0
events_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/junk/kfac_1lr-4_1d-4_1nc-4_m8_s16_p4_l2_det4/run0/events
models_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/junk/kfac_1lr-4_1d-4_1nc-4_m8_s16_p4_l2_det4/run0/models
opt_state_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/junk/kfac_1lr-4_1d-4_1nc-4_m8_s16_p4_l2_det4/run0/models/opt_state
pre_path 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/pretrained/s16_p4_l2_det4_1lr-4_i500.pk
timing_dir 		 /home/xmax/projects/nn_ansatz/src/scripts/debugging/nans/experiments/LiSolid/junk/kfac_1lr-4_1d-4_1nc-4_m8_s16_p4_l2_det4/run0/events/timing
system 		 LiSolid
r_atoms 		 [[0.1666666

In [37]:
ups.shape
downs.shape
downs[0, -1, :, ...]

DeviceArray([[[ 0.]],

             [[-0.]]], dtype=float32)

In [None]:



# inputs 

# layers

# generate orbitals

# logabssumdet

# take derivatives








In [43]:
import jax.numpy as jnp
import jax.random as rnd
from jax import vmap, lax, jit
import jax
key = rnd.PRNGKey(0)

In [10]:
x = rnd.normal(key, (10,3))
dtype = x.dtype
x.astype(dtype)
y = jnp.array([True, True]).astype(dtype)
y

DeviceArray([1., 1.], dtype=float32)

In [None]:
activations.append(ae_vector)
        
pre_activation = jnp.matmul(ae_vector, sigma) + d0
exponent = pre_activation.reshape(n_spin, 3, -1, n_spin, 1, order='F')
exponent = jnp.linalg.norm(exponent, axis=1)
out = jnp.exp(-exponent)

mask = (exponent < min_cell_width / 2.).astype(out.dtype)
out = out * mask 

# out = jnp.where(exponent < min_cell_width / 2., out, jnp.zeros_like(out))

In [68]:
L = 3.
x = rnd.normal(key, (10,3))
z = jnp.linalg.norm(x, axis=-1)
w = rnd.normal(key, (3, 1))
mask = z < L
print(jnp.sum(mask))

def linear(x):
    z = x @ w
    return jnp.tanh(z)

def wf(x):
    z = jnp.linalg.norm(w.squeeze() * x, axis=-1)
    out = linear(x) * jnp.exp(-z)
    mask = (z < (L / 2.)).astype(x.dtype)
    # assert jnp.sum(mask) > 0.
    out = mask * out
    out = jnp.tanh(linear(x) * out)
    return out.squeeze()

# second order derivs

def second(wf):
    def _lapl_over_f(x):
        n = x.shape[0]
        print(n)
        eye = jnp.eye(n, dtype=x.dtype)
        print(eye.shape)
        grad_f = jax.grad(wf)
        # grad_f_closure = lambda y: grad_f(y)  # ensuring the input can be just x

        def _body_fun(i, val):
            # primal is the first order evaluation
            # tangent is the second order
            primal, tangent = jax.jvp(grad_f, (x,), (eye[:, i],))
            print(i, primal.shape, tangent.shape)
            return val + primal[i]**2 + tangent[i]
        return -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)

    return vmap(_lapl_over_f, in_axes=(0,))

grad = second(wf)
grad(x)


10
3
(3, 3)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)> (3,) (3,)


DeviceArray([ 1.0752215 , -0.        ,  0.59370816,  0.8124089 ,
             -0.        ,  1.5249966 ,  1.1376866 , -0.93327606,
             -0.        ,  0.10634256], dtype=float32)

In [61]:

n = x.shape[1]
print(n)
eye = jnp.eye(n, dtype=x.dtype)
print(eye.shape)
grad_f = jax.grad(wf)
# grad_f_closure = lambda y: grad_f(y)  # ensuring the input can be just x

i = 0
val = 0.0
primal, tangent = jax.jvp(grad_f, (x,), (eye[None, :, i].repeat(x.shape[0], axis=0),))
print(i)
print(primal, '\n', tangent, '\n', mask)

d = val + primal[:, i]**2 + tangent[:, i]

# _body_fun(0, 0.0)


# -0.5 * lax.fori_loop(0, n, _body_fun, 0.0)

3
(3, 3)


TypeError: Gradient only defined for scalar-output functions. Output had shape: (10, 10).

In [59]:
x = x[0]
eye = jnp.eye(x.shape[0], dtype=x.dtype)

def _body_fun(i, val):
    # primal is the first order evaluation
    # tangent is the second order
    primal, tangent = jax.jvp(wf, (x,), (eye[:, i],))
    print(i, primal.shape, tangent.shape)
    return val + primal[i]**2 + tangent[i]

_body_fun(0, 0.0)

IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.

In [27]:

def create_masks(n_atom, n_electrons, n_up, n_layers, n_sh, n_ph):

    n_sh_in, n_ph_in = 4 * n_atom, 4

    masks = [create_masks_layer(n_sh_in, n_ph_in, n_electrons, n_up)]

    for i in range(n_layers):
        masks.append(create_masks_layer(n_sh, n_ph, n_electrons, n_up))

    return masks


def create_masks_layer(n_sh, n_ph, n_electrons, n_up):
    # single spin masks
    eye_mask = ~np.eye(n_electrons, dtype=bool)
    n_down = n_electrons - n_up
    n_pairwise = n_electrons ** 2 - n_electrons

    tmp1 = jnp.ones((n_up, n_sh))
    tmp2 = jnp.zeros((n_down, n_sh))
    single_up_mask = jnp.concatenate((tmp1, tmp2), axis=0)
    single_down_mask = (jnp.concatenate((tmp1, tmp2), axis=0) - 1.) * -1.

    # pairwise spin masks
    ups = np.ones(n_electrons)
    ups[n_up:] = 0
    downs = (ups - 1.) * -1.

    pairwise_up_mask = []
    pairwise_down_mask = []
    mask = np.zeros((n_electrons, n_electrons))

    for electron in range(n_electrons):
        mask_up = np.copy(mask)
        mask_up[electron, :] = ups
        mask_up = mask_up[eye_mask].reshape(-1)

        mask_down = np.copy(mask)
        mask_down[electron, :] = downs
        mask_down = mask_down[eye_mask].reshape(-1)

        pairwise_up_mask.append(mask_up)
        pairwise_down_mask.append(mask_down)

    pairwise_up_mask = jnp.array(pairwise_up_mask).reshape((n_electrons, n_pairwise, 1))
    pairwise_up_mask = jnp.repeat(pairwise_up_mask, n_ph, axis=-1)

    pairwise_down_mask = jnp.array(pairwise_down_mask).reshape((n_electrons, n_pairwise, 1))
    pairwise_down_mask = jnp.repeat(pairwise_down_mask, n_ph, axis=-1)
    return single_up_mask, single_down_mask, pairwise_up_mask, pairwise_down_mask

def drop_diagonal_i(square):
    n = square.shape[0]
    split1 = jnp.split(square, n, axis=0)
    upper = [jnp.split(split1[i], [j], axis=1)[1] for i, j in zip(range(0, n), range(1, n))]
    lower = [jnp.split(split1[i], [j], axis=1)[0] for i, j in zip(range(1, n), range(1, n))]
    arr = [ls[i] for i in range(n-1) for ls in (upper, lower)]
    result = jnp.concatenate(arr, axis=1)
    return jnp.squeeze(result)

def linear(p: jnp.array,
           data: jnp.array,
           split: jnp.array,
           activations: list,
           d0: jnp.array) -> jnp.array:

    bias = jnp.ones((*data.shape[:-1], 1))
    activation = jnp.concatenate([data, bias], axis=-1)
    activations.append(activation)

    pre_activation = jnp.dot(activation, p) + d0
    return jnp.tanh(pre_activation + split)

def linear_pairwise(p: jnp.array,
                    data: jnp.array,
                    activations: list,
                    d0: jnp.array) -> jnp.array:

    bias = jnp.ones((*data.shape[:-1], 1))
    activation = jnp.concatenate([data, bias], axis=-1)
    activations.append(activation)

    pre_activation = jnp.dot(activation, p) + d0
    return jnp.tanh(pre_activation)

def linear_split(p: jnp.array,
                 data: jnp.array,
                 activations: list,
                 d0: jnp.array) -> jnp.array:

    activation = data
    activations.append(activation)
    pre_activation = jnp.dot(data, p) + d0
    return pre_activation

def mixer_i(single: jnp.array,
            pairwise: jnp.array,
            n_el,
            n_up,
            n_down,
            single_up_mask,
            single_down_mask,
            pairwise_up_mask,
            pairwise_down_mask):
    # single (n_samples, n_el, n_single_features)
    # pairwise (n_samples, n_pairwise, n_pairwise_features)

    # --- Single summations
    # up
    sum_spin_up = single_up_mask * single
    sum_spin_up = jnp.sum(sum_spin_up, axis=0, keepdims=True) / float(n_up)
    #     sum_spin_up = jnp.repeat(sum_spin_up, n_el, axis=1)  # not needed in split

    # down
    sum_spin_down = single_down_mask * single
    sum_spin_down = jnp.sum(sum_spin_down, axis=0, keepdims=True) / float(n_down)
    #     sum_spin_down = jnp.repeat(sum_spin_down, n_el, axis=1) # not needed in split

    # --- Pairwise summations
    sum_pairwise = jnp.repeat(jnp.expand_dims(pairwise, axis=0), n_el, axis=0)

    # up
    sum_pairwise_up = pairwise_up_mask * sum_pairwise
    sum_pairwise_up = jnp.sum(sum_pairwise_up, axis=1) / float(n_up)

    # down
    sum_pairwise_down = pairwise_down_mask * sum_pairwise
    sum_pairwise_down = jnp.sum(sum_pairwise_down, axis=1) / float(n_down)

    single = jnp.concatenate((single, sum_pairwise_up, sum_pairwise_down), axis=1)
    split = jnp.concatenate((sum_spin_up, sum_spin_down), axis=1)
    return single, split

def env_pi_i(pis: jnp.array,
            factor: jnp.array,
            exponential: jnp.array,
            activations: list,
            d0s) -> jnp.array:

    n_spins, n_det = exponential.shape[:2]

    exponential = [jnp.squeeze(x, axis=(1, 2))
                   for y in jnp.split(exponential, n_spins, axis=2)
                   for x in jnp.split(y, n_det, axis=1)]

    [activations.append(x) for x in exponential]
    orbitals = jnp.stack([(e @ pi) + d0 for pi, e, d0 in zip(pis, exponential, d0s)], axis=-1)

    return factor * jnp.transpose(orbitals.reshape(n_spins, n_det, n_spins), (1, 2, 0))
