In [1]:
import jax
import jax.numpy as np
from jax import grad,vmap,jit
import numpy as onp
from jax import random
key = random.PRNGKey(0)
import timeit

from jax import experimental 
from jax.experimental import *
from jax.numpy import linalg

import tensorflow as tf
from keras.utils import to_categorical

import matplotlib.pyplot as py
from jax import device_put

import itertools

from jax.experimental import optimizers

Using TensorFlow backend.


In [2]:
# Hide every GPU from TensorFlow. TF is only used below to fetch MNIST, and
# keeping it off the GPU stops it from allocating GPU memory that JAX needs.
tf.config.experimental.set_visible_devices(devices=[], device_type="GPU")

In [135]:
def tanh_act(x):
  """Hyperbolic-tangent activation, applied element-wise with jax.numpy."""
  activated = np.tanh(x)
  return activated
def sigmoid_act(x):
    """Logistic sigmoid activation, delegated to ``jax.nn.sigmoid`` (element-wise)."""
    squashed = jax.nn.sigmoid(x)
    return squashed
def softmax_act(x, axis=-1):
  """Numerically stable softmax along ``axis`` (default: the last axis).

  For a 1-D input this equals ``exp(x) / sum(exp(x))``. For a batched 2-D
  input each row is normalized on its own, instead of the whole batch being
  divided by one global sum.

  Args:
    x: array of logits.
    axis: axis to normalize over.

  Returns:
    Array with the same shape as ``x`` whose entries along ``axis`` sum to 1.
  """
  # Softmax is invariant to shifting the logits. Subtracting the max keeps
  # exp() from overflowing to inf (which would give nan) for large logits.
  # stop_gradient because the shift contributes nothing to the derivative.
  x_max = jax.lax.stop_gradient(np.max(x, axis=axis, keepdims=True))
  exps = np.exp(x - x_max)
  return exps / np.sum(exps, axis=axis, keepdims=True)
def binary_crossentropy(x, y, eps=1e-7):
    """Element-wise binary cross-entropy.

    Args:
      x: predicted probabilities. This is the model's prediction, not the
        network input.
      y: target values, same shape as ``x`` (or broadcastable to it).
      eps: predictions are clipped to ``[eps, 1 - eps]`` before taking logs.

    Returns:
      Array of per-element losses ``-y*log(x) - (1-y)*log(1-x)``. Reduce it
      (sum/mean) yourself if a scalar loss is needed, e.g. for ``grad``.
    """
    # Clipping keeps log() finite when a prediction saturates at exactly 0 or 1.
    # Otherwise the loss is inf and its gradients become nan.
    p = np.clip(x, eps, 1 - eps)
    return -y * np.log(p) - (1 - y) * np.log(1 - p)
def relu_act(x):
    """Rectified linear unit, ``max(x, 0)`` element-wise, via ``jax.nn.relu``."""
    rectified = jax.nn.relu(x)
    return rectified
def normalize(x, axis=0, epsilon=1e-5):
  """Standardize ``x`` to zero mean and unit variance along ``axis``.

  This is a local implementation of what ``jax.nn.normalize`` computed:
  mean and variance over ``axis``, with ``epsilon`` added to the variance.
  That function was deprecated in favour of ``jax.nn.standardize`` and is
  missing from newer JAX releases, so calling it directly breaks there.

  Args:
    x: input array.
    axis: axis (or axes) to standardize over. The default 0 keeps the
      original behavior.
    epsilon: small constant added to the variance for numerical safety.

  Returns:
    Array with the same shape as ``x``.
  """
  mean = np.mean(x, axis=axis, keepdims=True)
  # Var[x] = E[x^2] - E[x]^2, the same formulation jax.nn.normalize used.
  variance = np.mean(np.square(x), axis=axis, keepdims=True) - np.square(mean)
  return (x - mean) * jax.lax.rsqrt(variance + epsilon)

In [121]:
def init_parameters(shapes, seed=1000):
    """Create randomly initialized ``[weight, bias]`` pairs for a dense network.

    Layer 0 is a square ``shapes[0] x shapes[0]`` layer, as wide as the input.
    Every later layer ``i`` maps ``shapes[i-1] -> shapes[i]``. So
    ``[784, 50, 10]`` yields three layers: 784->784, 784->50 and 50->10.

    Args:
      shapes: sequence of layer widths, starting with the input width.
      seed: seed for a private ``numpy.random.RandomState``. The default of
        1000 reproduces exactly what the former global ``onp.random.seed(1000)``
        produced.

    Returns:
      List of ``[W, b]`` lists. ``W`` has shape ``(out, in)`` and ``b`` has
      shape ``(out,)``; both are standard-normal samples.
    """
    # A local generator draws the same sequence as seeding the global one,
    # but leaves numpy's global RNG state untouched for the rest of the notebook.
    rng = onp.random.RandomState(seed)
    # Input widths: layer 0 is square, then each layer takes the previous width.
    fan_ins = [shapes[0]] + list(shapes[:-1])
    trainable_v = []
    for fan_out, fan_in in zip(shapes, fan_ins):
        weight = rng.randn(fan_out, fan_in)
        bias = rng.randn(fan_out)
        trainable_v.append([weight, bias])
    return trainable_v

def NLL(x, y):
  """Negative log of the probability ``x`` assigns to the class marked by ``y``.

  Only ``argmax(y)`` is used, so ``y`` is assumed to be one-hot. ``x`` is
  presumably a 1-D vector of probabilities (e.g. a softmax output).
  """
  target_index = np.argmax(y)
  target_prob = x[target_index]
  return -np.log(target_prob)

def BCE_loss(params,i,t):
  """Binary cross-entropy of the network's prediction for input ``i`` vs target ``t``.

  Returns the element-wise loss array from ``binary_crossentropy``. It is not
  reduced to a scalar, so sum/mean it before differentiating with ``grad``.
  """
  # dense_network_list is the forward pass defined in this notebook. No
  # `dense_network` is defined in the visible cells, so the old call failed
  # on a fresh kernel.
  pred = dense_network_list(params, i)
  final = binary_crossentropy(pred, t)
  return final

def update_weights(params, gradient, lr=1.0):
  """Take one SGD step using batch-averaged gradients.

  Args:
    params: list of per-layer ``[W, b]`` lists, as built by init_parameters.
    gradient: same nested structure as ``params``, but every leaf carries a
      leading batch axis (e.g. the output of ``vmap(grad(...))``). Each leaf
      is averaged over that axis before the step.
    lr: learning rate.

  Returns:
    A new nested list holding ``param - lr * mean(grad, axis=0)`` for every
    leaf. ``params`` itself is not modified.
  """
  # tree_map builds a fresh structure instead of assigning into params[i][j].
  # Eager calls therefore no longer mutate the caller's list, and every leaf
  # in each layer is updated, not just indices 0 and 1.
  return jax.tree_util.tree_map(
      lambda p, g: p - lr * np.mean(g, axis=0), params, gradient)





In [136]:
#Single Network

def dense_network_list(params,input):
  """Forward pass of a fully connected network.

  Each layer computes ``input @ W.T + b``. Every layer but the last applies
  ``sigmoid_act``; the last applies ``softmax_act``.

  Args:
    params: list of ``[W, b]`` pairs with ``W`` shaped ``(out, in)``.
    input: a single sample of shape ``(in,)`` (training vmaps over batches)
      or a batch of shape ``(batch, in)`` (used for evaluation).

  Returns:
    Output of the final softmax layer.
  """
  # Iterate over however many layers params holds instead of hard-coding
  # three, so networks built from a longer `shapes` list use all their layers.
  *hidden_layers, (w_out, b_out) = params
  activation = input
  for w, b in hidden_layers:
    activation = sigmoid_act(np.dot(activation, w.T) + b)
  return softmax_act(np.dot(activation, w_out.T) + b_out)



def NLL_loss(params,i,t):
  """Negative log-likelihood of the network's prediction for one sample.

  ``i`` is a single input sample and ``t`` its one-hot target. Batching is
  done outside this function, by ``vmap`` over its gradient.
  """
  pred = dense_network_list(params, i)
  # Call NLL directly: this function is already traced by grad/vmap/jit, so
  # wrapping NLL in its own jit only added a redundant nested jit boundary.
  return NLL(pred, t)

# Per-sample gradient of NLL_loss with respect to params (argument 0).
gradient = jax.grad(NLL_loss, argnums=0)

# Vectorize over a batch. params are shared (None); inputs and targets are
# mapped along their leading axis (0), so every gradient leaf gains a batch axis.
vmap_backprop = jax.vmap(gradient, in_axes=(None, 0, 0))


### Data

In [6]:
# Load MNIST, flatten each image into a 784-vector, scale pixels into [0, 1]
# and one-hot encode the labels.
(train_data,train_labels),(test_data,test_labels)=tf.keras.datasets.mnist.load_data()
# reshape(-1, 784) infers the sample count instead of hard-coding 60000/10000.
train_data = train_data.reshape(-1, 784).astype('float32') / 255.0
test_data = test_data.reshape(-1, 784).astype('float32') / 255.0
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

In [7]:
# Put the training arrays on JAX's default device: a GPU when that backend is
# available, otherwise the CPU. jax.devices('gpu') raises on CPU-only machines.
compute_device = jax.devices()[0]
train_data = device_put(train_data, compute_device)
train_labels = device_put(train_labels, compute_device)

### Testing

In [137]:
# Layer widths: 784 input pixels -> 50 hidden units -> 10 output classes.
# init_parameters also builds a square 784x784 first layer from shapes[0].
LAYER_SHAPES = [784, 50, 10]
parameters = init_parameters(LAYER_SHAPES)

In [138]:
# Training-loop settings.
N_EPOCHS = 10
BATCH_SIZE = 100
N_TRAIN = 30000        # samples [0, N_TRAIN) are used for training
N_VAL = 100            # samples [N_TRAIN, N_TRAIN + N_VAL) give a quick accuracy check
LEARNING_RATE = 1.0
BATCHES_PER_EPOCH = N_TRAIN // BATCH_SIZE  # one pass over the training split

dense_jit = jit(dense_network_list)  # speed up
jit_backprop = jit(vmap_backprop)
jit_update_weights = jit(update_weights)

# The validation slice never changes, so compute it once outside the loop.
val_data = train_data[N_TRAIN:N_TRAIN + N_VAL]
targets = np.argmax(train_labels[N_TRAIN:N_TRAIN + N_VAL], axis=1)

ctr = 0
start_time = timeit.default_timer()  # elapsed time includes JIT compilation
for i in range(N_EPOCHS):
  for _ in range(BATCHES_PER_EPOCH):
    batch = slice(ctr, ctr + BATCH_SIZE)
    dparams = jit_backprop(parameters, train_data[batch], train_labels[batch])
    parameters = jit_update_weights(parameters, dparams, lr=LEARNING_RATE)
    ctr += BATCH_SIZE
    # Wrap around once the *next* batch would spill past the training split.
    # The old `ctr+10>30000` test only worked because 30000 is a multiple of
    # 100; other batch sizes could have trained on validation samples.
    if ctr + BATCH_SIZE > N_TRAIN:
      ctr = 0

  pred = np.argmax(dense_jit(parameters, val_data), axis=1)
  print('epoch: ', i + 1)
  print(int(np.sum(pred == targets)) / N_VAL)
elapsed = timeit.default_timer() - start_time
print(elapsed)

epoch:  1
0.73
epoch:  2
0.75
epoch:  3
0.79
epoch:  4
0.84
epoch:  5
0.84
epoch:  6
0.86
epoch:  7
0.86
epoch:  8
0.87
epoch:  9
0.89
epoch:  10
0.88
10.315771481999946


In [139]:
# Accuracy on the full test set, in percent.
pred_1 = np.argmax(dense_jit(parameters, test_data), axis=1)
target_1 = np.argmax(test_labels, axis=1)
# Divide by the actual number of test samples instead of a hard-coded 10000.
print(100 * int(np.sum(pred_1 == target_1)) / len(target_1))


91.05
