In [1]:
%load_ext autoreload
%autoreload 2
import os 
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.85'
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from itertools import product
import jax.numpy as jnp
import numpy as np


In [3]:
def generate_k_points(n_shells=3):
    img_range = jnp.arange(-3, 3+1)  # preset, when are we ever going to use more
    img_sets = jnp.array(list(product(*[img_range, img_range, img_range])))
    norms = jnp.linalg.norm(img_sets, axis=-1)
    idxs = jnp.argsort(norms)
    img_sets, norms = img_sets[idxs], norms[idxs]
    norm = 0.
    k_shells = {norm: [jnp.array([0.0, 0.0, 0.0])]}  # leacing the dictionary logic in case we ever need this data structure
    for k_point, norm_tmp in zip(img_sets[1:], norms[1:]):
        if norm_tmp > norm:
            if len(k_shells) == n_shells:
                break
            norm = norm_tmp
            k_shells[norm] = [k_point]
        else:
            if np.any([(k_point == x).all() for x in k_shells[norm]]):
                continue # because we include the opposite k_point in the sequence this statement avoids repeats
            k_shells[norm].append(k_point)
        k_shells[norm].append(-k_point)
    k_points = []
    for k, v in k_shells.items():
        for k_point in v:
            k_points.append(k_point)
    return jnp.array(k_points)

In [44]:
n_el = 7
k_points = generate_k_points(n_shells=2)
r = np.array(np.random.normal(0, 1, (n_el, 3)))
args = r @ k_points.T

pf = [np.cos, np.sin]
print(k_points[0, :].shape)
# [funcs.append(lambda r: pf[i % 2](r @ k_point)) for i, k_point in enumerate(k_points[1:])]

funcs = [lambda args: pf[i % 2](args) for i, _ in enumerate(k_points)]

# for i, k_point in enumerate(k_points[1:, :]):
#     print((r @ k_point).shape)


print(args.shape)
apply_vectorized = np.vectorize(lambda f, r: f(r), otypes=[object])
apply_vectorized(funcs, args)

(3,)
(7, 7)


array([[1.0, 0.6344616092018893, 0.6344616092018893, 0.811961856855451,
        0.811961856855451, 0.9937892417847294, 0.9937892417847294],
       [1.0, 0.6069575930046737, 0.6069575930046737, 0.704771573711915,
        0.704771573711915, 0.85344632729614, 0.85344632729614],
       [1.0, 0.36440243987712917, 0.36440243987712917,
        0.7286697981896225, 0.7286697981896225, 0.04283602839010839,
        0.04283602839010839],
       [1.0, 0.8575554997531586, 0.8575554997531586, -0.6190227647924236,
        -0.6190227647924236, 0.8642369404415623, 0.8642369404415623],
       [1.0, 0.9536559738890455, 0.9536559738890455, 0.8477255468018986,
        0.8477255468018986, 0.9973363532683985, 0.9973363532683985],
       [1.0, 0.9452302983213836, 0.9452302983213836, 0.9909773286726588,
        0.9909773286726588, 0.5557778346912491, 0.5557778346912491],
       [1.0, 0.9917832757505337, 0.9917832757505337, 0.9624410642087046,
        0.9624410642087046, 0.21549951494230224, 0.21549951494230224]

In [46]:
n_el = 7
k_points = generate_k_points(n_shells=2)
r = jnp.array(np.random.normal(0, 1, (n_el, 3)))
args = r @ k_points.T  # (n_el, n_el)


pf = [jnp.cos, jnp.sin]
print(k_points[0, :].shape)
# [funcs.append(lambda r: pf[i % 2](r @ k_point)) for i, k_point in enumerate(k_points[1:])]

funcs = [lambda args: pf[i % 2](args) for i, _ in enumerate(k_points)]

# for i, k_point in enumerate(k_points[1:, :]):
#     print((r @ k_point).shape

print(args.shape)
apply_vectorized = jnp.vectorize(lambda f, r: f(r))  # essentially a for loop
apply_vectorized(funcs, args)

(3,)
(7, 7)


TypeError: JAX only supports number and bool dtypes, got dtype object in array

In [20]:
n_el = 7
k_points = generate_k_points(n_shells=2)
r = jnp.array(np.random.normal(0, 1, (n_el, 3)))
args = r @ k_points.T  # (n_el, n_el)
args = jnp.split(args, n_el, axis=1)
pf = [jnp.cos, jnp.sin]
dets = []
for i, arg in enumerate(args):
    column = pf[i % 2](arg)
    dets.append(column)
dets = jnp.concatenate(dets, axis=-1)


print(dets.shape)
print(args[0].shape)
for det, arg in zip(dets, args):
    print(det, '\n')
    print(jnp.cos(arg).squeeze(), jnp.sin(arg).squeeze())



(7, 7)
(7, 1)
[ 1.         -0.98416805  0.17723767 -0.09954972  0.9950326  -0.03557487
  0.999367  ] 

[1. 1. 1. 1. 1. 1. 1.] [ 0.  0.  0.  0.  0. -0. -0.]
[ 1.          0.8215039   0.5702029  -0.27049577  0.96272117 -0.24291566
  0.9700474 ] 

[ 0.17723767  0.5702029  -0.46631873  0.88131523  0.3305313   0.20196325
  0.96472317] [-0.98416805  0.8215039  -0.8846168  -0.47252882  0.943795    0.9793931
  0.26326644]
[ 1.         -0.8846168  -0.46631873 -0.9913951   0.1309035   0.00277381
  0.9999961 ] 

[ 0.17723767  0.5702029  -0.46631873  0.88131523  0.3305313   0.20196325
  0.96472317] [ 0.98416805 -0.8215039   0.8846168   0.47252882 -0.943795   -0.9793931
 -0.26326644]
[ 1.         -0.47252882  0.88131523 -0.93845093  0.34541255 -0.39606145
  0.918224  ] 

[0.9950326  0.96272117 0.1309035  0.34541255 0.59707147 0.15365563
 0.8012553 ] [-0.09954972 -0.27049577 -0.9913951  -0.93845093 -0.80218804  0.98812443
  0.5983226 ]
[ 1.          0.943795    0.3305313  -0.80218804  0.59707147 -0.

In [18]:
n_el = 5
k_points = generate_k_points()
r = jnp.array(np.random.normal(0, 1, (n_el, 3)))
pf = [jnp.cos, jnp.sin]
funcs = [lambda r: k_points[0] @ r]
[funcs.append(lambda r: pf[i % 2](r @ k_point)) for i, k_point in enumerate(k_points[1:])]
apply_vectorized = jnp.vectorize(lambda f, r: f(r))
apply_vectorized(funcs, r)

# print(len(k_points))
# for i, k_point in enumerate(k_points[1:]):
#     print(k_point.shape)
#     funcs.append(lambda r: pf[i%2](r @ k_point))

print(funcs)

TypeError: JAX only supports number and bool dtypes, got dtype object in array