## Low Rank Approximations

### Imports

In [None]:
!pip install dm-haiku optax

In [None]:
from typing import Iterator, Mapping, Tuple
from copy import deepcopy
import time
from absl import app
import haiku as hk
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
import tensorflow as tf
from functools import partial
import math
from sklearn.decomposition import TruncatedSVD          # To calculate accuracy
from statistics import mean

# A batch is a pair of arrays; the functions below unpack it as (images, labels).
Batch = Tuple[np.ndarray, np.ndarray]

In [None]:
# Per-channel (3-value) statistics read by normalize() after it scales pixels by 1/255.
# presumably the CIFAR-10 training-set RGB mean and std — TODO confirm the source.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

def net_fn(batch: Batch) -> jnp.ndarray:
  """Forward pass of the baseline convolutional + fully connected classifier.

  Args:
    batch: pair whose first element holds the input images; the second
      element is not used here.

  Returns:
    The output of the final 10-unit linear layer for every input image.
  """
  images = normalize(batch[0])

  # Convolutional feature extractor (modules are created in network order,
  # which is what determines their Haiku parameter names).
  feature_layers = [
      hk.Conv2D(output_channels=6*3, kernel_shape=(5,5)),
      jax.nn.relu,
      hk.AvgPool(window_shape=(2,2), strides=(2,2), padding='VALID'),
      jax.nn.relu,
      hk.Conv2D(output_channels=16*3, kernel_shape=(5,5)),
      jax.nn.relu,
      hk.AvgPool(window_shape=(2,2), strides=(2,2), padding='VALID'),
      hk.Flatten(),
  ]

  # Fully connected classifier head.
  classifier_layers = [
      hk.Linear(3000), jax.nn.relu,
      hk.Linear(2000), jax.nn.relu,
      hk.Linear(2000), jax.nn.relu,
      hk.Linear(1000), jax.nn.relu,
      hk.Linear(10),
  ]

  model = hk.Sequential(feature_layers + classifier_layers)
  return model(images)

def load_dataset(split: str, *, is_training: bool, batch_size: int) -> Iterator[tuple]:
  """Load a CIFAR-10 split as an endless iterator of NumPy batches.

  Args:
    split: split specification forwarded to `tfds.load`.
    is_training: when true, examples are shuffled (buffer of
      10 * batch_size, fixed seed 0) before batching.
    batch_size: number of examples per batch.

  Returns:
    An iterator over the batched dataset converted with `tfds.as_numpy`.
    The dataset is cached and repeated before any shuffling or batching.
  """
  dataset = tfds.load('cifar10', split=split, as_supervised=True)
  dataset = dataset.cache()
  dataset = dataset.repeat()

  if is_training:
    shuffle_buffer = 10 * batch_size
    dataset = dataset.shuffle(shuffle_buffer, seed=0)

  batched = dataset.batch(batch_size)
  return iter(tfds.as_numpy(batched))

def compute_loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Compute the training loss: mean cross-entropy plus L2 weight decay.

  Args:
    params: parameters passed to `net.apply`.
    batch: (images, labels) pair; labels are integer class ids that are
      one-hot encoded over 10 classes.

  Returns:
    Scalar loss: batch-averaged softmax cross-entropy plus 1e-4 times
    half the sum of squared parameter values.
  """
  _, y = batch
  logits = net.apply(params, batch)
  log_probs = jax.nn.log_softmax(logits)
  labels = jax.nn.one_hot(y, 10)

  # Average over the batch so the value (and the relative weight of the
  # regulariser) does not depend on the batch size.
  softmax_xent = -jnp.sum(labels * log_probs) / labels.shape[0]

  # L2 regularisation penalises the *parameters*. The previous code took a
  # squared error between log-probabilities and one-hot labels instead,
  # which is not a weight penalty.
  l2_loss = 0.5 * sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))

  return softmax_xent + 1e-4 * l2_loss

@jax.jit
def compute_accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Compute the classification accuracy of the network on one batch.

  Args:
    params: parameters passed to `net.apply`.
    batch: (images, labels) pair; labels are integer class ids.

  Returns:
    Scalar fraction of examples whose arg-max network output equals the label.
  """
  _, y = batch
  logits = net.apply(params, batch)
  # arg-max over the class axis; a softmax would not change the arg-max,
  # so none is applied.
  return jnp.mean(jnp.argmax(logits, axis=-1) == y)

@jax.jit
def update(params: hk.Params, opt_state: optax.OptState, batch: Batch,) -> Tuple[hk.Params, optax.OptState]:
  """Run one optimisation step on a batch.

  Args:
    params: current parameters of the network.
    opt_state: current state of the optimiser `opt`.
    batch: (images, labels) pair used to compute the loss gradient.

  Returns:
    new_params: parameters after applying the optimiser's updates.
    opt_state: the optimiser state returned by `opt.update`.
  """
  grads = jax.grad(compute_loss)(params, batch)
  updates, opt_state = opt.update(grads, opt_state)
  # Optax gradient transformations such as `optax.adam` already return
  # updates scaled by -learning_rate, so they are added as-is. Negating them
  # would ascend the loss, and `-1 * updates` is not even defined for a
  # parameter mapping.
  new_params = optax.apply_updates(params, updates)

  return new_params, opt_state

@jax.jit
def ema_update(params, avg_params):
  """Blend the latest parameters into their running average.

  Args:
    params: current parameters of the network.
    avg_params: running average of the parameters so far.

  Returns:
    The value of `optax.incremental_update(params, avg_params, step_size=0.001)`,
    i.e. the updated running average (a single object, not a pair).
  """
  ema_step_size = 0.001
  updated_average = optax.incremental_update(
      params, avg_params, step_size=ema_step_size)
  return updated_average

def normalize(images):
  """Scale pixel values to [0, 1] and standardise them per channel.

  Args:
    images: array of pixel values; it is cast to float32 and divided by 255.
      The last axis must broadcast against the 3-element CIFAR10_MEAN /
      CIFAR10_STD tuples (assumes channels-last images — TODO confirm).

  Returns:
    float32 array `(images / 255 - CIFAR10_MEAN) / CIFAR10_STD`.
  """
  # float32 statistics keep the result in float32 for NumPy inputs as well.
  mean = np.asarray(CIFAR10_MEAN, dtype=np.float32)
  std = np.asarray(CIFAR10_STD, dtype=np.float32)

  x = images.astype(jnp.float32) / 255.
  # The original line `x /- mean` evaluated x / (-mean) and threw the result
  # away, so the mean was never subtracted.
  x = (x - mean) / std

  return x

### Training

In [None]:
# Turn net_fn into a Haiku transformed function; compute_loss / compute_accuracy
# read this module-level `net` and call `net.apply(params, batch)`.
net = hk.without_apply_rng(hk.transform(net_fn))

# Optimiser: Adam, constructed with 1e-3 as its learning-rate argument.
opt = optax.adam(1e-3)
# Alternative that was tried and left disabled:
# opt = optax.chain(optax.adam(1e-3), optax.scale_by_adam(), optax.scale(-1.0))

# NOTE(review): training reads the last 20% of the 'train' split while validation
# reads the first 80% — confirm these two ranges are not swapped.
train = load_dataset("train[80%:]", is_training=True, batch_size=1000)
validation = load_dataset("train[0%:80%]", is_training=False, batch_size=10000)
test = load_dataset("test", is_training=False, batch_size=10000)

# `params` and `avg_params` both start from the same initial parameters; the
# first training batch is consumed here only to initialise the network.
params = avg_params = net.init(jax.random.PRNGKey(42), next(train))
opt_state = opt.init(params)

# Do not alter the number of steps
for step in range(10001):      # steps 0..10000 inclusive

  if step % 1000 == 0:         # report metrics every 1000 steps, using avg_params
    # NOTE(review): accuracy and loss each pull their own `next(validation)`
    # batch, so the two reported numbers come from different batches.
    val_accuracy = compute_accuracy(avg_params, next(validation))
    val_loss = compute_loss(avg_params, next(validation))
    test_accuracy = compute_accuracy(avg_params, next(test))
    val_accuracy, test_accuracy = jax.device_get(
        (val_accuracy, test_accuracy))
    print(f"[Step {step}] Validation / Test accuracy: "
          f"{val_accuracy:.3f} / {test_accuracy:.3f}.")
    print(f"[Step {step}] Loss: " f"{val_loss:.3f}.")

  # One optimiser step on the next training batch, then fold the new
  # parameters into the running average used for evaluation.
  params, opt_state = update(params, opt_state, next(train))
  avg_params = ema_update(params, avg_params)

### Metrics and Functions

In [None]:
def compute_eval_metrics(params, batch, n_samples):
  """Measure accuracy and wall-clock latency of `compute_accuracy`.

  Args:
    params: parameters forwarded to `compute_accuracy`.
    batch: batch forwarded to `compute_accuracy`.
    n_samples: number of timed evaluations to run.

  Returns:
    accuracy_list: the `n_samples` accuracy values, one per timed call.
    duration_list: the `n_samples` durations in seconds, one per timed call.
  """
  if n_samples > 0:
    # Untimed warm-up: the first call for a given parameter structure triggers
    # JIT compilation, which would otherwise dominate the first sample.
    compute_accuracy(params, batch).block_until_ready()

  accuracy_list = []
  duration_list = []
  for _ in range(n_samples):
    start = time.perf_counter()
    acc = compute_accuracy(params, batch)
    # JAX dispatches work asynchronously; wait for the result so the timer
    # covers the actual computation and not just the dispatch.
    acc.block_until_ready()
    duration_list.append(time.perf_counter() - start)
    accuracy_list.append(acc)

  return accuracy_list, duration_list

In [None]:
def rank_approximated_weight(weight: jnp.ndarray, rank_fraction: float):
  """Factorise a weight matrix into its best rank-k approximation via SVD.

  The retained rank is `k = int(rank_fraction * min(weight.shape))`, clamped
  to the range [0, min(weight.shape)].

  Args:
    weight: 2-D weight matrix of shape (rows, cols).
    rank_fraction: fraction of the full rank `min(rows, cols)` to keep.

  Returns:
    u: array of shape (rows, k): the first k left singular vectors, each
      scaled by its singular value.
    v: array of shape (k, cols): the first k right singular vectors (rows).
    s: array of shape (k,): the retained singular values.
    `u @ v` is the rank-k approximation of `weight`.
  """
  matrix = np.asarray(weight)
  left, singular_values, right = np.linalg.svd(matrix, full_matrices=False)

  # Full rank is bounded by the smaller matrix dimension (not by the lengths
  # of the first two rows, which are both just the column count).
  full_rank = min(matrix.shape)
  k = int(rank_fraction * full_rank)
  k = max(0, min(k, full_rank))

  # Fold the singular values into the left factor so that u @ v reconstructs
  # the truncated matrix directly; no second SVD is needed.
  u = left[:, :k] * singular_values[:k]
  v = right[:k]

  return u, v, singular_values[:k]

### Evaluations 
At different ranks



In [None]:
# Sweep the rank fraction from 1.0 down to 0.1, replacing every non-conv
# weight with its low-rank reconstruction and measuring accuracy / latency.
rank_truncated_params = deepcopy(params)
ranks_and_accuracies = []
ranks_and_times = []
# we compute metrics over 50 samples to reduce noise in the measurement.
n_samples = 50
for rank_fraction in np.arange(1.0, 0.0, -0.1):

  print(f"Evaluating the model at {rank_fraction}")
  for layer in params:
    if 'conv' in layer:
      continue
    # Always truncate the original trained weight, not a previously
    # truncated one.
    weight = params[layer]['w']
    u, v, _ = rank_approximated_weight(weight, rank_fraction)
    rank_truncated_params[layer]['w'] = u @ v

  # Use the batch that was fetched instead of discarding it and pulling a
  # second one from the iterator.
  test_batch = next(test)
  test_accuracy, latency = compute_eval_metrics(rank_truncated_params, test_batch, n_samples)
  mean_accuracy = np.mean(test_accuracy)
  mean_latency = np.mean(latency)
  print(f"Rank Fraction / Test accuracy: "
          f"{rank_fraction:.2f} / {mean_accuracy:.3f}.")
  ranks_and_accuracies.append((rank_fraction, mean_accuracy))
  print(f"Rank Fraction / Duration: "
          f"{rank_fraction:.2f} / {mean_latency:.4f}.")
  ranks_and_times.append((rank_fraction, mean_latency))

### Plot relationships

In [None]:
# Accuracy vs Rank Percentage

# ranks_and_accuracies is a list of (rank_fraction, accuracy) pairs; split it
# into the two columns to plot (indexing [0] and [1] would plot the first
# pair against the second pair).
rank_fractions, accuracies = zip(*ranks_and_accuracies)
plt.plot(rank_fractions, accuracies)
# pyplot has no `label` function; the figure heading is set with `title`.
plt.title("Accuracy vs Rank Percentage")
plt.xlabel("Rank percentage (%)")
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Rank Time vs Rank Percentage

# ranks_and_times is a list of (rank_fraction, mean_duration) pairs; split it
# into the two columns to plot (indexing [0] and [1] would plot the first
# pair against the second pair).
rank_fractions, durations = zip(*ranks_and_times)
plt.plot(rank_fractions, durations)
# pyplot has no `label` function; the figure heading is set with `title`.
plt.title("Rank Time vs Rank Percentage")
plt.xlabel("Rank Percentage (%)")
plt.ylabel("Rank Times")
plt.show()

### Evaluations
At Factorized space

In [None]:
def low_rank_net_fn(batch: Batch, rank: float) -> jnp.ndarray:
  """Forward pass of the classifier with every dense layer factorised in two.

  Each dense layer of the baseline network is replaced by a bias-free
  bottleneck linear layer followed by a linear layer of the original width.

  Args:
    batch: pair whose first element holds the input images.
    rank: fraction used to size each bottleneck layer.

  Returns:
    The output of the final 10-unit linear layer for every input image.
  """
  images = normalize(batch[0])
  # Number of values per normalised input example.
  per_example_size = np.prod(images.shape[1:])

  def factorized_linear(full_rank, out_features):
    # Bottleneck is created first, then the full-width layer, so Haiku
    # numbers the modules in network order.
    return [
        hk.Linear(int(rank * full_rank), with_bias=False),
        hk.Linear(out_features),
    ]

  layers = [
      hk.Conv2D(output_channels=6*3, kernel_shape=(5,5)),
      jax.nn.relu,
      hk.AvgPool(window_shape=(2,2), strides=(2,2), padding='VALID'),
      jax.nn.relu,
      hk.Conv2D(output_channels=16*3, kernel_shape=(5,5)),
      jax.nn.relu,
      hk.AvgPool(window_shape=(2,2), strides=(2,2), padding='VALID'),
      hk.Flatten(),
  ]
  layers += factorized_linear(min(per_example_size, 3000), 3000) + [jax.nn.relu]
  layers += factorized_linear(2000, 2000) + [jax.nn.relu]
  layers += factorized_linear(2000, 2000) + [jax.nn.relu]
  layers += factorized_linear(1000, 1000) + [jax.nn.relu]
  layers += factorized_linear(10, 10)

  model = hk.Sequential(layers)
  return model(images)

In [None]:
# Maps each layer of the baseline network to the layer(s) holding its weights
# in the factorised network: conv layers map to themselves, every dense layer
# maps to a [bottleneck, full-width] pair.
vanilla_to_low_rank_map = {
    'conv2_d': 'conv2_d',
    'conv2_d_1': 'conv2_d_1',
    'linear': ['linear', 'linear_1'],
    'linear_1': ['linear_2', 'linear_3'],
    'linear_2': ['linear_4', 'linear_5'],
    'linear_3': ['linear_6', 'linear_7'],
    'linear_4': ['linear_8', 'linear_9']
}

ranks_and_accuracies = []
ranks_and_times = []
for rank_fraction in np.arange(1.0, 0.0, -0.1):
  low_rank_net_fn_partial = partial(low_rank_net_fn, rank=rank_fraction)
  # Rebinds the module-level `net`, which compute_accuracy reads when it is
  # traced for the new parameter structure.
  net = hk.without_apply_rng(hk.transform(low_rank_net_fn_partial))
  # Initialise only to obtain a parameter structure of the right shapes;
  # every entry listed in the map is overwritten below.
  low_rank_params = net.init(jax.random.PRNGKey(42), next(train))

  print(f"Evaluating the model at {rank_fraction:.2f}")

  for layer, low_rank_layers in vanilla_to_low_rank_map.items():
    if 'conv' in layer:
      low_rank_params[low_rank_layers] = params[layer]
      continue
    bottleneck_layer, full_layer = low_rank_layers
    weight = params[layer]['w']
    # rank_approximated_weight returns three values; the singular values are
    # not needed here.
    u, v, _ = rank_approximated_weight(weight, rank_fraction)
    low_rank_params[bottleneck_layer]['w'] = u
    low_rank_params[full_layer]['w'] = v
    low_rank_params[full_layer]['b'] = params[layer]['b']

  test_accuracy, duration = compute_eval_metrics(low_rank_params, next(test), 50)
  ranks_and_times.append((rank_fraction, np.mean(duration)))
  ranks_and_accuracies.append((rank_fraction, np.mean(test_accuracy)))
  print(f"Rank Fraction / Test accuracy: "
          f"{rank_fraction:.2f} / {np.mean(test_accuracy):.3f}.")
  print(f"Rank Fraction / Duration: "
          f"{rank_fraction:.2f} / {np.mean(duration):.4f}.")



---

