# Cross-Stitch Network Using JAX

Implementation of: 
---
Misra, I., Shrivastava, A., Gupta, A., & Hebert, M. (2016). Cross-Stitch Networks for Multi-Task Learning.

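Misra et al. combine the activation maps of corresponding layers in two networks using learnable parameters $\alpha$ (the "cross-stitch" units). This notebook implements a variant of the idea: the $\alpha$ matrices are inserted as off-diagonal blocks of the combined second-layer weight matrix, so each network's second layer also receives a learned projection of the other network's first-layer activations.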
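In the paper, a cross-stitch unit takes the activations $x_A$ and $x_B$ of the same layer in networks A and B and outputs the linear combinations $\tilde{x}_A = \alpha_{AA} x_A + \alpha_{AB} x_B$ and $\tilde{x}_B = \alpha_{BA} x_A + \alpha_{BB} x_B$. The sketch below illustrates only that formula, assuming a single $2 \times 2$ $\alpha$ for the whole layer (a simplification made here); `cross_stitch_unit` is a hypothetical helper and is not used by the rest of this notebook.

```python
import jax.numpy as jnp

def cross_stitch_unit(x_a, x_b, alpha):
    """Linearly combine two activation vectors of the same shape.

    alpha = [[a_AA, a_AB], [a_BA, a_BB]]; one alpha for the whole layer is an
    assumption of this sketch.
    """
    stacked = jnp.stack([x_a, x_b])                 # shape (2, n)
    mixed = jnp.tensordot(alpha, stacked, axes=1)   # row k = sum_m alpha[k, m] * stacked[m]
    return mixed[0], mixed[1]

x_a = jnp.array([1.0, 2.0, 3.0])
x_b = jnp.array([4.0, 5.0, 6.0])
alpha = jnp.array([[0.9, 0.1],
                   [0.1, 0.9]])
new_a, new_b = cross_stitch_unit(x_a, x_b, alpha)   # new_a = 0.9*x_a + 0.1*x_b
```

With an identity $\alpha$ the two networks stay independent; off-diagonal entries control how much each network borrows from the other.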

Overview:
--- 
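Two fully connected (FC) networks with a [60, 20, 10] structure are trained, one on MNIST and one on Fashion MNIST.
The two networks are then combined with $\alpha$ into a single network with layer widths [120, 40, 10]. Only $\alpha$ is trained, on a small fraction (2,000 samples) of the combined dataset; the weights of the two original networks stay frozen.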

The result is a combined network that classifies both MNIST and Fashion MNIST with accuracy close to that of each network on its own (93.77% vs. 95.72% on MNIST, 84.34% vs. 85.58% on Fashion MNIST).

## Import Dependencies

In [1]:
import timeit

import jax
import jax.numpy as jnp
import numpy as onp
from jax import device_put, grad, jit, random, vmap

import tensorflow as tf  # used only for data loading
from keras.utils import to_categorical  # one-hot encoding of the labels
from sklearn.utils import shuffle  # shuffles data and labels together (combined dataset)

key = random.PRNGKey(0)  # not used below; parameters are initialized with NumPy's RNG

In [2]:
tf.config.experimental.set_visible_devices([], "GPU")  # keep TensorFlow off the GPU so it does not reserve memory JAX needs

## Helper Functions

### Activation and Loss Functions

In [3]:
def sigmoid_act(x):
    return jax.nn.sigmoid(x)

def softmax_act(x):
    return jax.nn.softmax(x)  # softmax over the last axis

def binary_crossentropy(x, y):  # x = predicted probability, y = target (not used below)
    return -y * jnp.log(x) - (1 - y) * jnp.log(1 - x)

def relu_act(x):
    return jax.nn.relu(x)

def NLL(x, y):  # x = probability vector for one sample, y = one-hot target
    return -jnp.log(x[jnp.argmax(y)])
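`NLL` takes the log of a softmax output, which returns infinity if a probability underflows to zero. A common alternative (not what this notebook uses) is to keep the logits and apply `jax.nn.log_softmax`. In the sketch below, `nll_from_logits` is a hypothetical helper; it agrees with `NLL` on a well-conditioned input and stays finite on an extreme one.

```python
import jax
import jax.numpy as jnp

def NLL(x, y):
    # Same as the notebook: x is a probability vector, y a one-hot target.
    return -jnp.log(x[jnp.argmax(y)])

def nll_from_logits(logits, y):
    # Hypothetical stable variant: log_softmax never takes log(0).
    return -jnp.sum(y * jax.nn.log_softmax(logits))

logits = jnp.array([2.0, 0.5, -1.0])
target = jnp.array([0.0, 1.0, 0.0])
loss_probs = NLL(jax.nn.softmax(logits), target)
loss_logits = nll_from_logits(logits, target)

# exp(-200) underflows to 0 in float32, so the probability-based loss becomes inf.
extreme = jnp.array([100.0, -100.0])
t2 = jnp.array([0.0, 1.0])
loss_extreme_probs = NLL(jax.nn.softmax(extreme), t2)
loss_extreme_logits = nll_from_logits(extreme, t2)
```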

### Initialization Function

In [4]:
def init_parameters_nobias(shapes,input_shape=784):
    # Returns one [W, b] pair per layer; W has shape (units_out, units_in) and is scaled by 1/10.
    trainable_v=[[]]
    # first layer (input -> shapes[0])
    trainable_v[0].append(onp.random.randn(shapes[0],input_shape)/10)
    trainable_v[0].append(onp.zeros(shapes[0]))
    for i in range(1,len(shapes)):
      trainable_v.append([])
      trainable_v[i].append(onp.random.randn(shapes[i],shapes[i-1])/10)
      trainable_v[i].append(onp.zeros(shapes[i]))
    # All biases start at zero; the output-layer bias stays zero (and is unused by the network),
    # which allows the output layers of two networks to be concatenated later.
    trainable_v[-1][1]=onp.zeros(shapes[-1])
    return trainable_v
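For the [60, 20, 10] configuration used below, the initializer returns one `[W, b]` pair per layer, with `W` of shape `(units_out, units_in)`. The self-contained check below restates the initializer compactly (same shapes and scaling, written as a single loop) and lists the resulting shapes and parameter count.

```python
import numpy as onp

def init_parameters_nobias(shapes, input_shape=784):
    # Compact restatement of the initializer above: one [W, b] pair per layer.
    params = []
    fan_in = input_shape
    for units in shapes:
        params.append([onp.random.randn(units, fan_in) / 10, onp.zeros(units)])
        fan_in = units
    return params

onp.random.seed(0)
params = init_parameters_nobias([60, 20, 10], input_shape=784)
shapes = [(W.shape, b.shape) for W, b in params]
# shapes == [((60, 784), (60,)), ((20, 60), (20,)), ((10, 20), (10,))]
n_weights = sum(W.size + b.size for W, b in params)  # 47100 + 1220 + 210
```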



### Update Weight Function

In [5]:
# `gradient` holds per-sample gradients from vmap: every leaf has a leading batch axis,
# i.e. shape [batch_size, *weight_shape]. The update sums over that axis (a sum, not a mean).
def update_weights(params, gradient ,lr=1.0):
  for i in range(len(params)): #iterate through the layers
    params[i][0]=params[i][0]-(lr*jnp.sum(gradient[i][0],axis=0))
    params[i][1]=params[i][1]-(lr*jnp.sum(gradient[i][1],axis=0))
  return params
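The same update can be written without in-place list mutation by mapping over the parameter tree. This is a sketch, not the notebook's code: `update_weights_tree` is a hypothetical name, and it assumes, like `update_weights`, that every gradient leaf has a leading batch axis to be summed.

```python
import jax
import jax.numpy as jnp

def update_weights_tree(params, per_sample_grads, lr=1.0):
    """Hypothetical functional equivalent of update_weights.

    per_sample_grads has the same nested-list structure as params, with an extra
    leading batch axis on every leaf (as produced by vmap).
    """
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * jnp.sum(g, axis=0), params, per_sample_grads)

params = [[jnp.ones((2, 3)), jnp.zeros(2)]]
grads = [[jnp.ones((4, 2, 3)), jnp.ones((4, 2))]]  # batch of 4 per-sample gradients
new_params = update_weights_tree(params, grads, lr=0.1)  # each leaf moves by 0.1 * 4
```

Because the gradients are summed rather than averaged, the effective step grows with the batch size: with 100-sample batches, `lr=0.01` is equivalent to a step of 1.0 on the mean gradient.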

## Network and Loss Function

In [6]:
def dense_network_list(params,input):
  # 784 -> shapes[0] -> shapes[1] -> 10; works on one sample or on a batch of rows
  l1=jnp.dot(input,params[0][0].T)+params[0][1]
  l1=sigmoid_act(l1)
  l2=jnp.dot(l1,params[1][0].T)+params[1][1]
  l2=sigmoid_act(l2)
  l3=jnp.dot(l2,params[2][0].T) #output-layer bias deliberately omitted
  l3=softmax_act(l3)
  return l3

def NLL_loss(params,i,t): #loss for a single sample; vmap below maps it over a batch
  pred=dense_network_list(params,i)
  final=NLL(pred,t)
  return final

gradient=grad(NLL_loss) #gradient with respect to params (argument 0)

#per-sample gradients: params shared (None), inputs and targets mapped over axis 0
vmap_backprop = vmap(gradient,in_axes=(None,0,0))
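`vmap_backprop` returns one gradient per sample (each leaf gains a leading batch axis), which `update_weights` then sums. The sketch below uses a hypothetical one-layer model in place of `dense_network_list` and checks that summing the per-sample gradients matches differentiating the summed batch loss directly.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

def model(params, x):
    # Tiny stand-in for dense_network_list: one linear layer plus softmax.
    W, b = params
    return jax.nn.softmax(jnp.dot(x, W.T) + b)

def nll_loss(params, x, t):
    # Single-sample NLL with a one-hot target, as in NLL_loss.
    return -jnp.log(model(params, x)[jnp.argmax(t)])

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_w, (3, 4)) * 0.1, jnp.zeros(3))
xs = jax.random.normal(key_x, (5, 4))
ts = jax.nn.one_hot(jnp.array([0, 1, 2, 1, 0]), 3)

# One gradient per sample: leaves have shapes (5, 3, 4) and (5, 3).
per_sample = vmap(grad(nll_loss), in_axes=(None, 0, 0))(params, xs, ts)
# Gradient of the summed loss over the batch.
summed = grad(lambda p: jnp.sum(vmap(nll_loss, in_axes=(None, 0, 0))(p, xs, ts)))(params)
```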

## Load Data

Load MNIST:

In [7]:
(mnist_train_data,mnist_train_labels),(mnist_test_data,mnist_test_labels)=tf.keras.datasets.mnist.load_data()
mnist_train_data=mnist_train_data.reshape(60000,784).astype('float32')/255.0
mnist_test_data=mnist_test_data.reshape(10000,784).astype('float32')/255.0
mnist_train_labels=to_categorical(mnist_train_labels)
mnist_test_labels=to_categorical(mnist_test_labels)
# move the training set to the default JAX device (the GPU when one is available)
mnist_train_data=device_put(mnist_train_data)
mnist_train_labels=device_put(mnist_train_labels)

Load Fashion MNIST:

In [8]:
(fashion_train_data,fashion_train_labels),(fashion_test_data,fashion_test_labels)=tf.keras.datasets.fashion_mnist.load_data()
fashion_train_data=fashion_train_data.reshape(60000,784).astype('float32')/255.0
fashion_test_data=fashion_test_data.reshape(10000,784).astype('float32')/255.0
fashion_train_labels=to_categorical(fashion_train_labels)
fashion_test_labels=to_categorical(fashion_test_labels)
# move the training set to the default JAX device (the GPU when one is available)
fashion_train_data=device_put(fashion_train_data)
fashion_train_labels=device_put(fashion_train_labels)
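`to_categorical` comes from Keras. If you prefer to avoid that dependency, `jax.nn.one_hot` produces the same kind of one-hot encoding for integer labels 0-9. A small check on hypothetical labels:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([5, 0, 4, 1, 9])                # hypothetical integer class labels
one_hot = jax.nn.one_hot(labels, num_classes=10)   # shape (5, 10), float32

# argmax recovers the integer labels, which is how the notebook scores predictions.
recovered = jnp.argmax(one_hot, axis=1)
```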

## Training

### Initialize Parameters for the MNIST Network

In [9]:
onp.random.seed(1000)
parameters_mnist = init_parameters_nobias([60,20,10],input_shape=784)

### Train the First Network on MNIST

In [10]:
dense_jit=jit(dense_network_list)
jit_backprop=jit(vmap_backprop)
jit_update_weights=jit(update_weights)

ctr=0
start_time = timeit.default_timer()
for i in range(5): #epochs
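  for j in range(500): #500 batches of 100 = the first 50,000 training samples
    # per-sample gradients for one batch; the slice width (100) is the batch size
    dparams=jit_backprop(parameters_mnist,mnist_train_data[ctr:ctr+100],mnist_train_labels[ctr:ctr+100])
    parameters_mnist= jit_update_weights(parameters_mnist,dparams,lr=0.01)
    ctr=ctr+100
    if ctr+100>50000: #samples from 50,000 on are never trained on (50,000-50,099 are used for validation)
      ctr=0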
  pred=dense_jit(parameters_mnist,mnist_train_data[50000:50100])
  pred=jnp.argmax(pred,axis=1)

  targets=jnp.argmax(mnist_train_labels[50000:50100],axis=1)
  print (f'epoch: ', i+1)
  print(f'Validation Accuracy: ', len(jnp.where(pred == targets)[0])/100*100)
elapsed = timeit.default_timer() - start_time
print (f'elapsed time: ', elapsed, 's')  

epoch:  1
Validation Accuracy:  88.0
epoch:  2
Validation Accuracy:  93.0
epoch:  3
Validation Accuracy:  96.0
epoch:  4
Validation Accuracy:  97.0
epoch:  5
Validation Accuracy:  98.0
elapsed time:  13.200043748999633 s


#### Test the MNIST Network

In [11]:
#Test: 
pred_1=dense_jit(parameters_mnist,mnist_test_data)
pred_1=jnp.argmax(pred_1,axis=1)
target_1=jnp.argmax(mnist_test_labels,axis=1)
print(f'Test Accuracy:',len(jnp.where(pred_1 == target_1)[0])/10000 *100)

Test Accuracy: 95.72
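The accuracy expression above counts matching indices with `jnp.where`; an equivalent, shorter form is the mean of the boolean match vector. A sketch with a hypothetical `accuracy` helper, checked against the notebook's expression on toy predictions:

```python
import jax.numpy as jnp

def accuracy(probs, one_hot_targets):
    # Percentage of rows whose highest-probability class matches the one-hot target.
    pred = jnp.argmax(probs, axis=1)
    target = jnp.argmax(one_hot_targets, axis=1)
    return jnp.mean(pred == target) * 100

probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4],
                   [0.5, 0.4, 0.1]])
targets = jnp.eye(3)[jnp.array([0, 1, 0, 0])]   # true classes 0, 1, 0, 0
acc = accuracy(probs, targets)                   # 3 of 4 correct

# The notebook's form, for comparison.
orig_style = len(jnp.where(jnp.argmax(probs, axis=1) == jnp.argmax(targets, axis=1))[0]) / 4 * 100
```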


### Second Network for Fashion MNIST

#### Initialize Parameters for the Fashion MNIST Network

In [12]:
onp.random.seed(1002)
parameters_fashion = init_parameters_nobias([60,20,10],input_shape=784)

In [13]:
ctr=0
start_time = timeit.default_timer()
for i in range(5): #epochs
  for j in range(600): #600 batches of 100 = all 60,000 training samples
    # per-sample gradients for one batch; the slice width (100) is the batch size
    dparams=jit_backprop(parameters_fashion,fashion_train_data[ctr:ctr+100],fashion_train_labels[ctr:ctr+100])
    parameters_fashion= jit_update_weights(parameters_fashion,dparams,lr=0.01)
    ctr=ctr+100
    if ctr+100>60000:
      ctr=0
  # note: samples 30,000-30,099 are inside the training range, so this "validation" accuracy is measured on training data
  pred=dense_jit(parameters_fashion,fashion_train_data[30000:30100])
  pred=jnp.argmax(pred,axis=1)

  targets=jnp.argmax(fashion_train_labels[30000:30100],axis=1)
  print (f'epoch: ', i+1)
  print(f'Validation Accuracy: ', len(jnp.where(pred == targets)[0])/100*100)
elapsed = timeit.default_timer() - start_time
print (f'elapsed time: ', elapsed, 's')  

epoch:  1
Validation Accuracy:  83.0
epoch:  2
Validation Accuracy:  86.0
epoch:  3
Validation Accuracy:  87.0
epoch:  4
Validation Accuracy:  87.0
epoch:  5
Validation Accuracy:  87.0
elapsed time:  11.376954624000064 s


#### Test the Fashion MNIST Network

In [14]:
#full test: 
pred_1=dense_jit(parameters_fashion,fashion_test_data)
pred_1=jnp.argmax(pred_1,axis=1)
target_1=jnp.argmax(fashion_test_labels,axis=1)
print(f'Test Accuracy:',len(jnp.where(pred_1 == target_1)[0])/10000 *100)

Test Accuracy: 85.58


### "Stitch" the Networks Together

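#### Compose Parameters Function
This function concatenates the parameters of the two networks, together with two learnable matrices collectively called $\alpha$, into the parameters of one larger network.

The function takes:

- $W^{params1}$, the parameters of the MNIST network
- $W^{params2}$, the parameters of the Fashion MNIST network
- $\alpha = (\alpha^{set1}, \alpha^{set2})$, two $20 \times 60$ matrices

It performs the cross stitch using the trainable parameters $\alpha$ and returns:

$$
\text{composed parameters} =
\begin{bmatrix}
W^{params1} & \alpha^{set1} \\
\alpha^{set2} & W^{params2}
\end{bmatrix}
$$

The $\alpha$ blocks are inserted only into the second layer (the $20 \times 60$ hidden-to-hidden weights), giving a $40 \times 120$ matrix. The first-layer weights are stacked along axis 0 (both halves see the same 784-pixel input, with no cross terms), and the output-layer weights are concatenated along axis 1, so the 10 output logits are the sum of the two halves' logits before the softmax.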
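Multiplying the composed second-layer matrix by the stacked first-layer activations $[h^{(1)}; h^{(2)}]$ gives $[W^{params1} h^{(1)} + \alpha^{set1} h^{(2)};\ \alpha^{set2} h^{(1)} + W^{params2} h^{(2)}]$: each network's own pre-activations plus a learned contribution from the other network. The sketch below checks this identity with random matrices of the shapes used here; all variable names are illustrative.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.PRNGKey(0), 6)
W1 = jax.random.normal(k1, (20, 60))   # network 1 (MNIST), second-layer weights
W2 = jax.random.normal(k2, (20, 60))   # network 2 (Fashion), second-layer weights
a1 = jax.random.normal(k3, (20, 60))   # alpha^{set1}: network 2 activations -> network 1
a2 = jax.random.normal(k4, (20, 60))   # alpha^{set2}: network 1 activations -> network 2
h1 = jax.random.normal(k5, (60,))      # network 1 first-layer activations
h2 = jax.random.normal(k6, (60,))      # network 2 first-layer activations

composed = jnp.block([[W1, a1],
                      [a2, W2]])       # shape (40, 120)
out = composed @ jnp.concatenate([h1, h2])

expected = jnp.concatenate([W1 @ h1 + a1 @ h2,
                            a2 @ h1 + W2 @ h2])
```

With $\alpha$ initialized to zero, the cross terms vanish and each half of the second hidden layer reproduces its original network's pre-activations.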







In [15]:
def compose_parameters(params1,params2,alpha):
  # Biases: concatenate the first two layers. The output layer keeps params1's all-zero bias,
  # which dense_network_list does not use anyway.
  layer0_cat_bias=jnp.concatenate([params1[0][1],params2[0][1]],axis=0) #(120,)
  layer1_cat_bias=jnp.concatenate([params1[1][1],params2[1][1]],axis=0) #(40,)
  layer2_cat_bias=params1[2][1] #(10,)

  # layer 0: stack along axis 0, no cross terms -> (120, 784)
  layer0_cat=jnp.concatenate([params1[0][0],params2[0][0]],axis=0)
  # layer 1: cross stitch [[W1, alpha_set1], [alpha_set2, W2]] -> (40, 120)
  temp1=jnp.concatenate([params1[1][0],alpha[0][0]],axis=1) #(20, 120)
  temp2=jnp.concatenate([alpha[0][1],params2[1][0]],axis=1) #(20, 120)
  layer1_cat=jnp.concatenate([temp1,temp2],axis=0)

  # layer 2 (output): concatenate along axis 1 -> (10, 40), so the logits of both halves are summed
  layer2_cat=jnp.concatenate([params1[2][0],params2[2][0]],axis=1)

  # same nested-list layout as init_parameters_nobias: [[W, b], [W, b], [W, b]]
  return [[layer0_cat,layer0_cat_bias],
          [layer1_cat,layer1_cat_bias],
          [layer2_cat,layer2_cat_bias]]

New loss function that takes the separate parameter sets of both networks plus $\alpha$, composes them, and evaluates the NLL:

In [16]:
#Takes the parameters of both networks and alpha, composes them, then computes the single-sample NLL.
def NLL_loss_alpha(params1,params2,alpha,i,t):
  new_parameters=compose_parameters(params1,params2,alpha)
  pred=dense_network_list(new_parameters,i) #traced by the outer jit, so no inner jit is needed
  final=NLL(pred,t)
  return final

In [17]:
#differentiate with respect to alpha only (argument 2); both trained networks stay frozen
gradient_alpha=grad(NLL_loss_alpha,argnums=2)

#per-sample alpha gradients over a batch of inputs and targets
vmap_backprop_alpha = vmap(gradient_alpha,in_axes=(None,None,None,0,0))


def update_alpha(alpha_,dalpha,lr=1.0):
  for i in range(len(alpha_)): #iterate through the stitched layers (only one here)
    alpha_[i][0]=alpha_[i][0]-(lr*jnp.sum(dalpha[i][0],axis=0))
    alpha_[i][1]=alpha_[i][1]-(lr*jnp.sum(dalpha[i][1],axis=0))
  return alpha_

#note: this rebinds jit_backprop, which from here on returns gradients for alpha
jit_backprop=jit(vmap_backprop_alpha)
jit_update_alpha=jit(update_alpha)
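Because `gradient_alpha` differentiates only with respect to argument 2 (`alpha`), the MNIST and Fashion MNIST weights receive no updates; only the $2 \times 20 \times 60 = 2{,}400$ entries of $\alpha$ are learned. The toy sketch below (a hypothetical scalar `loss`, not the notebook's) shows the `argnums` behavior, including the fact that a tuple `argnums=(2,)` returns a one-element tuple instead of a bare gradient.

```python
import jax
import jax.numpy as jnp

def loss(w_a, w_b, alpha, x):
    # Toy stand-in for NLL_loss_alpha: w_a and w_b play the frozen weights.
    return jnp.square(w_a * x + alpha * w_b * x)

# d loss / d alpha = 2 * (w_a*x + alpha*w_b*x) * (w_b*x) = 2 * 6 * 6
g = jax.grad(loss, argnums=2)(1.0, 2.0, 0.5, 3.0)
g_tuple = jax.grad(loss, argnums=(2,))(1.0, 2.0, 0.5, 3.0)
```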



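### Combine Datasets to Retrain the Network on Both Tasks
The combined set has 120,000 samples (60,000 from each dataset). Both datasets use labels 0-9, so the combined network has a single shared 10-way output: class k of MNIST and class k of Fashion MNIST map to the same output unit.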

In [18]:
combined_sets = jnp.concatenate([mnist_train_data,fashion_train_data],axis=0)
combined_labels=jnp.concatenate([mnist_train_labels,fashion_train_labels],axis=0)
onp.random.seed(1500)
shuf_combined_sets,shuf_combined_labels=shuffle(combined_sets,combined_labels)
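sklearn's `shuffle` applies the same random permutation to both arrays, so each image stays aligned with its label. A JAX-only alternative could use `jax.random.permutation`; `shuffle_together` below is a hypothetical helper, and it will not reproduce the order of the seeded sklearn call, so results would differ from those reported here.

```python
import jax
import jax.numpy as jnp

def shuffle_together(key, data, labels):
    # One random permutation applied to both arrays keeps (image, label) pairs aligned.
    perm = jax.random.permutation(key, data.shape[0])
    return data[perm], labels[perm]

data = jnp.arange(10.0).reshape(5, 2)   # row i is [2i, 2i + 1]
labels = jnp.arange(5)
shuf_data, shuf_labels = shuffle_together(jax.random.PRNGKey(1500), data, labels)
```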

### Initialize $\alpha$

In [19]:
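alpha=[[]]
alpha[0].append(jnp.zeros((20,60))) #alpha^{set1}: Fashion first-layer activations -> MNIST second layer
alpha[0].append(jnp.zeros((20,60))) #alpha^{set2}: MNIST first-layer activations -> Fashion second layer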

### Retrain the Combined Network (Only $\alpha$) on 2,000 Samples

In [20]:

ctr=0
start_time = timeit.default_timer()
for i in range(40): #epochs
  for j in range(20): #20 batches of 100 = the first 2,000 shuffled samples
    # per-sample alpha gradients for one batch; the slice width (100) is the batch size
    dalpha=jit_backprop(parameters_mnist,parameters_fashion,alpha,shuf_combined_sets[ctr:ctr+100],shuf_combined_labels[ctr:ctr+100])
    alpha= jit_update_alpha(alpha,dalpha,lr=0.01) #lr has to be very small
    ctr=ctr+100
    if ctr+100>2000:
      ctr=0
  new_params=compose_parameters(parameters_mnist,parameters_fashion,alpha)
  # validation uses 1,000 Fashion MNIST images only; the Fashion network was trained on these images
  pred=dense_jit(new_params,fashion_train_data[30000:31000])
  pred=jnp.argmax(pred,axis=1)

  targets=jnp.argmax(fashion_train_labels[30000:31000],axis=1)
  print (f'epoch: ', i+1)
  print(f'Validation Accuracy: ', len(jnp.where(pred == targets)[0])/1000)
elapsed = timeit.default_timer() - start_time
print (f'elapsed time: ', elapsed)  

epoch:  1
Validation Accuracy:  0.806
epoch:  2
Validation Accuracy:  0.825
epoch:  3
Validation Accuracy:  0.834
epoch:  4
Validation Accuracy:  0.838
epoch:  5
Validation Accuracy:  0.843
epoch:  6
Validation Accuracy:  0.848
epoch:  7
Validation Accuracy:  0.853
epoch:  8
Validation Accuracy:  0.851
epoch:  9
Validation Accuracy:  0.851
epoch:  10
Validation Accuracy:  0.852
epoch:  11
Validation Accuracy:  0.853
epoch:  12
Validation Accuracy:  0.853
epoch:  13
Validation Accuracy:  0.853
epoch:  14
Validation Accuracy:  0.855
epoch:  15
Validation Accuracy:  0.855
epoch:  16
Validation Accuracy:  0.853
epoch:  17
Validation Accuracy:  0.852
epoch:  18
Validation Accuracy:  0.851
epoch:  19
Validation Accuracy:  0.851
epoch:  20
Validation Accuracy:  0.85
epoch:  21
Validation Accuracy:  0.85
epoch:  22
Validation Accuracy:  0.85
epoch:  23
Validation Accuracy:  0.85
epoch:  24
Validation Accuracy:  0.85
epoch:  25
Validation Accuracy:  0.85
epoch:  26
Validation Accuracy:  0.85
...

### Examine Combined Network Accuracy on MNIST

In [21]:
super_params=compose_parameters(parameters_mnist,parameters_fashion,alpha) #combine all parameters into a new set of parameters.

In [22]:
pred_1=dense_jit(super_params,mnist_test_data)
pred_1=jnp.argmax(pred_1,axis=1)
target_1=jnp.argmax(mnist_test_labels,axis=1)
print(f'Test Accuracy:',len(jnp.where(pred_1 == target_1)[0])/10000 *100)

Test Accuracy: 93.77


### Examine Combined Network Accuracy on Fashion MNIST

In [23]:
pred_1=dense_jit(super_params,fashion_test_data)
pred_1=jnp.argmax(pred_1,axis=1)
target_1=jnp.argmax(fashion_test_labels,axis=1)
print(f'Test Accuracy:',len(jnp.where(pred_1 == target_1)[0])/10000 *100)

Test Accuracy: 84.34


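## Results

| Network | MNIST test accuracy | Fashion MNIST test accuracy |
|---|---|---|
| MNIST network alone | 95.72% | not measured |
| Fashion MNIST network alone | not measured | 85.58% |
| Cross-stitched network ($\alpha$ trained on 2,000 samples of the combined set) | 93.77% | 84.34% |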

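## Appendix

Layout of the second-layer weight matrices (each $20 \times 60$, so $i = 19$ and $j = 59$) before and after stitching.

Parameters of the MNIST network: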
$$
\begin{bmatrix}
w^{mnist}_{0,0} & w^{mnist}_{0,1} & \cdots & w^{mnist}_{0,j} \\
w^{mnist}_{1,0} & w^{mnist}_{1,1} & \cdots & w^{mnist}_{1,j} \\
\vdots & \vdots & \ddots & \vdots \\
w^{mnist}_{i,0} & w^{mnist}_{i,1} & \cdots & w^{mnist}_{i,j}
\end{bmatrix}
$$

Parameters of the Fashion MNIST network:
$$
\begin{bmatrix}
w^{fashion}_{0,0} & w^{fashion}_{0,1} & \cdots & w^{fashion}_{0,j} \\
w^{fashion}_{1,0} & w^{fashion}_{1,1} & \cdots & w^{fashion}_{1,j} \\
\vdots & \vdots & \ddots & \vdots \\
w^{fashion}_{i,0} & w^{fashion}_{i,1} & \cdots & w^{fashion}_{i,j}
\end{bmatrix}
$$

Full cross stitch ($40 \times 120$):

$$
\begin{bmatrix}
w^{mnist}_{0,0} & w^{mnist}_{0,1} & \cdots & w^{mnist}_{0,j} & \alpha^{set1}_{0,0} & \alpha^{set1}_{0,1} & \cdots & \alpha^{set1}_{0,j} \\
w^{mnist}_{1,0} & w^{mnist}_{1,1} & \cdots & w^{mnist}_{1,j} & \alpha^{set1}_{1,0} & \alpha^{set1}_{1,1} & \cdots & \alpha^{set1}_{1,j} \\
\vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots \\
w^{mnist}_{i,0} & w^{mnist}_{i,1} & \cdots & w^{mnist}_{i,j} & \alpha^{set1}_{i,0} & \alpha^{set1}_{i,1} & \cdots & \alpha^{set1}_{i,j} \\
\alpha^{set2}_{0,0} & \alpha^{set2}_{0,1} & \cdots & \alpha^{set2}_{0,j} & w^{fashion}_{0,0} & w^{fashion}_{0,1} & \cdots & w^{fashion}_{0,j} \\
\alpha^{set2}_{1,0} & \alpha^{set2}_{1,1} & \cdots & \alpha^{set2}_{1,j} & w^{fashion}_{1,0} & w^{fashion}_{1,1} & \cdots & w^{fashion}_{1,j} \\
\vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots \\
\alpha^{set2}_{i,0} & \alpha^{set2}_{i,1} & \cdots & \alpha^{set2}_{i,j} & w^{fashion}_{i,0} & w^{fashion}_{i,1} & \cdots & w^{fashion}_{i,j}
\end{bmatrix}
$$
