# My Utils

In [1]:
#|default_exp utils

In [2]:
#|export
import matplotlib.pyplot as plt
from   matplotlib.collections import LineCollection
import numpy as np
import jax
import jax.numpy as jnp
import genjax
from   genjax._src.core.transforms.incremental import UnknownChange, NoChange, Diff

## JAX

In [3]:
#|export
key       = jax.random.PRNGKey(0)
logsumexp = jax.scipy.special.logsumexp

In [4]:
#|export
def keysplit(key, *ns):
    if len(ns) == 0:  
        return jax.random.split(key, 1)[0]
    elif len(ns) == 1:
        n, = ns
        if n == 1: return keysplit(key)
        else:      return jax.random.split(key, ns[0])
    else:
        keys = []
        for n in ns: keys.append(keysplit(key, n))
        return keys


In [5]:
(
    keysplit(key),
    keysplit(key, 1),
    keysplit(key, 10),
    keysplit(key, 1, 10),
)

(Array([ 928981903, 3453687069], dtype=uint32),
 Array([ 928981903, 3453687069], dtype=uint32),
 Array([[3668660785,  713825972],
        [1185646547, 2092858387],
        [4260797006,  129535844],
        [ 928977296, 1618649917],
        [2708837749, 4129373854],
        [ 652965180, 3955248629],
        [1312337421, 1285539814],
        [2974568872, 3669116123],
        [1997906629, 3379841639],
        [4278014892, 1203387755]], dtype=uint32),
 [Array([ 928981903, 3453687069], dtype=uint32),
  Array([[3668660785,  713825972],
         [1185646547, 2092858387],
         [4260797006,  129535844],
         [ 928977296, 1618649917],
         [2708837749, 4129373854],
         [ 652965180, 3955248629],
         [1312337421, 1285539814],
         [2974568872, 3669116123],
         [1997906629, 3379841639],
         [4278014892, 1203387755]], dtype=uint32)])

In [6]:
#|export
def bounding_box(arr, pad=0):
    """Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box."""
    return jnp.array([
        [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],
        [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]
    ])

In [7]:
#|export
def argmax_axes(a, axes=None):
    """Argmax along specified axes"""
    if axes is None: return jnp.argmax(a)
    
    n = len(axes)        
    axes_  = set(range(a.ndim))
    axes_0 = axes
    axes_1 = sorted(axes_ - set(axes_0))    
    axes_  = axes_0 + axes_1

    b = jnp.transpose(a, axes=axes_)
    c = b.reshape(np.prod(b.shape[:n]), -1)

    I = jnp.argmax(c, axis=0)
    I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))

    return  I

In [8]:
test_shape = (3, 99, 5, 9)
a = jnp.arange(np.prod(test_shape)).reshape(test_shape)

I = argmax_axes(a, axes=[0,1])
I.shape

(5, 9, 2)

# Poses and Geometry

In [9]:
#|export
def rot2d(hd): return jnp.array([
    [jnp.cos(hd), -jnp.sin(hd)], 
    [jnp.sin(hd),  jnp.cos(hd)]
    ]);

def pack_2dpose(x,hd): 
    return jnp.concatenate([x,jnp.array([hd])])

def apply_2dpose(p, ys): 
    return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]

def unit_vec(hd): 
    return jnp.array([jnp.cos(hd), jnp.sin(hd)])

def adjust_angle(hd):
    """Adjusts angle to lie in the interval [-pi,pi)."""
    return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi

## GenJAX

In [10]:
#|export
def argdiffs(args, other=None):
    return tuple(map(lambda v: Diff(v, UnknownChange), args))


In [13]:
#|export
from builtins import property as _property, tuple as _tuple
from typing import Any


class Args(tuple):
    def __new__(cls, *args, **kwargs):
        return _tuple.__new__(cls, list(args) + list(kwargs.values()))
    
    def __init__(self, *args, **kwargs):
        self._d = dict()
        for k,v in kwargs.items():
            self._d[k] = v
            setattr(self, k, v)

    def __getitem__(self, k: str) -> Any:
        return self._d[k]

In [15]:
type(Args(x=1))

__main__.Args

In [17]:
#|export
# 
# Monkey patching `sample` for `BuiltinGenerativeFunction`
# 
cls = genjax._src.generative_functions.builtin.builtin_gen_fn.BuiltinGenerativeFunction

def genjax_sample(self, key, *args, **kwargs):
    tr = self.simulate(key, args)
    return tr.get_retval()

setattr(cls, "sample", genjax_sample)


# 
# Monkey patching `sample` for `DeferredGenerativeFunctionCall`
# 
cls = genjax._src.generative_functions.builtin.builtin_gen_fn.DeferredGenerativeFunctionCall

def deff_gen_func_call(self, key, **kwargs):
    return self.gen_fn.sample(key, *self.args, **kwargs)

def deff_gen_func_logpdf(self, x, **kwargs):
    return self.gen_fn.logpdf(x, *self.args, **kwargs)

setattr(cls, "__call__", deff_gen_func_call)
setattr(cls, "sample", deff_gen_func_call)
setattr(cls, "logpdf", deff_gen_func_logpdf)