# Imports


In [4]:
import jax.numpy as jnp
import jax

# Jax Tutorial Basics

In [2]:
jnp.zeros(shape=(10,5))

Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

In [13]:
@jax.jit
def softmax(array, temp=2.0):
    arr = jnp.exp(array / temp)
    return arr/ jnp.sum(arr, axis=-1, keepdims=True)
# Cant jhust jit enveryhting, no while loops and conditionals for now

key = jax.random.key(50)
k, sk = jax.random.split(key)
print(softmax(jax.random.uniform(sk, shape=(100000,20000))))
print(softmax(jax.random.uniform(sk, shape=(100000,20000))).sum(axis=-1))

[[4.5338445e-05 4.1128937e-05 4.3717992e-05 ... 4.6037436e-05
  3.9925937e-05 3.9466235e-05]
 [4.0832503e-05 5.2905860e-05 3.8572816e-05 ... 4.3123215e-05
  6.3363841e-05 4.0338124e-05]
 [4.5396722e-05 5.3326970e-05 3.9350503e-05 ... 5.8048146e-05
  4.2621217e-05 4.6469559e-05]
 ...
 [4.3514236e-05 5.5165216e-05 5.5767734e-05 ... 4.7878770e-05
  4.5894845e-05 5.1567571e-05]
 [5.8635400e-05 3.8818544e-05 5.1258208e-05 ... 5.6500277e-05
  5.4009601e-05 4.4502205e-05]
 [4.1771575e-05 4.4697372e-05 4.6291963e-05 ... 4.1309508e-05
  5.5091899e-05 4.0407263e-05]]
[1.0000001  0.99999994 1.         ... 1.0000001  0.99999994 1.        ]


In [91]:
# Experiments with vmap
# Convolution 

# Input 3D array (batch, dim1, dim2), kernel (batch, k1, k2)

@jax.jit
def convolve(array, kernel, answer):
    b,k1,k2 = kernel.shape
    for j in range(answer.shape[1]):
        for k in range(answer.shape[2]):
            cov = jnp.multiply(array[:, j:j+k1, k:k+k2], kernel).sum(axis=-1).sum(axis=-1)
            answer = answer.at[:, j, k].set(cov)
    
    return answer   

def convolve2D(array, kernel, padding =  False):
    if kernel.ndim == 2:
        kernel = jnp.tile(kernel, (array.shape[0], 1, 1)) 
        print(kernel.shape)
    
    if padding:
        answer = jnp.zeros(shape = (array.shape))
        pad_width = [(0,0)] + [((k-1)//2, (k-1)//2 + (1 if k%2==0 else 0)) for k in kernel.shape[1:]]
        array = jnp.pad(array, pad_width=pad_width, mode='constant', constant_values=0)
        print(array.shape)
        

    else:
        shape = [array.shape[0], array.shape[1] - 2*((kernel.shape[1]-1)//2), array.shape[2] - 2*((kernel.shape[1]-1)//2)]
        answer = jnp.zeros(shape= shape)
        print(answer.shape)
    
    return convolve(array, kernel, answer)

@jax.jit
def convolve2(array, kernel, answer):
    k1,k2 = kernel.shape
    for j in range(answer.shape[0]):
        for k in range(answer.shape[1]):
            cov = jnp.multiply(array[j:j+k1, k:k+k2], kernel).sum(axis=-1).sum(axis=-1)
            answer = answer.at[j, k].set(cov)
    
    return answer 

def convolve2D2(array, kernel, padding=False):

    if kernel.ndim == 2:
        kernel = jnp.tile(kernel, (array.shape[0], 1, 1)) 
        print(kernel.shape)
        
    if padding:
        answer = jnp.zeros(shape = (array.shape))
        pad_width = [(0,0)] + [((k-1)//2, (k-1)//2 + (1 if k%2==0 else 0)) for k in kernel.shape[1:]]
        array = jnp.pad(array, pad_width=pad_width, mode='constant', constant_values=0)
        print(array.shape)
        

    else:
        shape = [array.shape[0], array.shape[1] - 2*((kernel.shape[1]-1)//2), array.shape[2] - 2*((kernel.shape[1]-1)//2)]
        answer = jnp.zeros(shape= shape)
        print(answer.shape)
    
    convolve2batch = jax.vmap(convolve2)
    return convolve2batch(array, kernel, answer)

array = jnp.tile(jnp.arange(10), (10000, 10, 1))

kernel = jnp.array([[1,1,1],[1,1,1], [1,1,1]])

import time
start_time = time.time()
result = convolve2D(array, kernel, padding= True)
print(f"Time taken to compute: {time.time() - start_time} seconds")


# convolve2Dbatch = jax.vmap(convolve2D2)

import time
start_time = time.time()
result = convolve2D2(array, kernel, padding= True)
print(f"Time taken to compute: {time.time() - start_time} seconds")

# Speedup is insane, mrola of the story do element wise computation, for each batch basically just do a vmap
# Recurrent neural network


(10000, 3, 3)
(10000, 12, 12)
Time taken to compute: 0.5043902397155762 seconds
(10000, 3, 3)
(10000, 12, 12)
Time taken to compute: 0.6596851348876953 seconds


In [None]:
# Experiments with vmap
# Recurrent Neural Network
# Input 3D array (batch, dim1, dim2), kernel (batch, k1, k2)



# Networks
