# Jax YoloV1

In [1]:
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git

[K     |████████████████████████████████| 1.7 MB 5.2 MB/s 
[K     |████████████████████████████████| 850 kB 46.3 MB/s 
[K     |████████████████████████████████| 62.2 MB 1.2 MB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
     |████████████████████████████████| 126 kB 5.3 MB/s            
     |████████████████████████████████| 65 kB 2.6 MB/s             
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [2]:
#for GPU support in colab
!dpkg -i "/content/libcudnn8_8.1.0.77-1+cuda11.2_amd64.deb"

(Reading database ... 155225 files and directories currently installed.)
Preparing to unpack .../libcudnn8_8.1.0.77-1+cuda11.2_amd64.deb ...
Unpacking libcudnn8 (8.1.0.77-1+cuda11.2) over (8.0.5.39-1+cuda11.1) ...
Setting up libcudnn8 (8.1.0.77-1+cuda11.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.3) ...
/sbin/ldconfig.real: /usr/local/lib/python3.7/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link



In [4]:
from jax import nn as jnn, numpy as jnp, random
from flax import linen as nn
from typing import Sequence

#Each tuple represents a convolutional or maxpool layer in the YoloV1 architecture
#(Features, Dimension of Kernel Size, Stride, Padding on either side) for convolutional layers
#(0,) represents 2 X 2 max pool layer with stride of 2
MODEL_ARCHITECTURE = [
  (64, 7, 2, 3),
  (0,),
  (192, 3, 1, 1),
  (0,),
  (128, 1, 1, 0),
  (256, 3, 1, 1),
  (256, 1, 1, 0),
  (512, 3, 1, 1),
  (0,),
  (256, 1, 1, 0),
  (512, 3, 1, 1),
  (256, 1, 1, 0),
  (512, 3, 1, 1),
  (256, 1, 1, 0),
  (512, 3, 1, 1),
  (256, 1, 1, 0),
  (512, 3, 1, 1),
  (512, 1, 1, 0),
  (1024, 3, 1, 1),
  (0,),
  (512, 1, 1, 0),
  (1024, 3, 1, 1),
  (512, 1, 1, 0),
  (1024, 3, 1, 1),
  (1024, 3, 2, 1),
  (1024, 3, 1, 1),
  (1024, 3, 1, 1)
]

#Each max pool layer in YoloV1 is identical
def max_pool_layer(x):
  return nn.max_pool(x, (2, 2), (2, 2))

class YoloV1(nn.Module):
  #Properties of each convolutional layer
  conv_structures: Sequence[tuple]

  #split size
  S: int

  #number of bounding boxes per grid position
  B: int

  #number of classes
  C: int

  #For N X N output grid, split_size represents N
  #num_boxes represents number of bounding boxes per grid position
  #num_classes represents how many object classes the model can detect
  #split_size X split_size is number of grid positions
  #num_boxes * (x, y, w, h, and confidence for each box) + one probability value for each class per box is number of values per grid position
  #multiply values per grid spot by number of grid spots for number of neurons in output layer
  def get_output_length(self, split_size, num_boxes, num_classes):
    return split_size * split_size * (5 * num_boxes + num_classes)

  def setup(self):
    #converting the model architecture to flax layers
    self.conv_layers = [nn.Conv(conv_structure[0], (conv_structure[1], conv_structure[1]), (conv_structure[2], conv_structure[2]), [(conv_structure[3], conv_structure[3]), (conv_structure[3], conv_structure[3])]) if len(conv_structure)==4 else max_pool_layer for conv_structure in self.conv_structures]
    
    #actual model has hidden layer with 4096 neurons, using 496 to make training/inference time more reasonable
    self.dense_layers = [nn.Dense(496), nn.Dense(self.get_output_length(self.S, self.B, self.C))]

  def __call__(self, inputs):
    x = inputs
    for conv_layer in self.conv_layers:
      x = conv_layer(x)
      print(x.shape) #make sure the shapes of each layer match the paper's model architecture
      
      #activation function for each convolutional layer
      if conv_layer != max_pool_layer:
        x = jnn.leaky_relu(x, 0.1)
    
    #flattening to pass into dense layers
    x = jnp.ravel(x)

    for i, dense_layer in enumerate(self.dense_layers):
      x = dense_layer(x)
      if i != len(self.dense_layers) - 1:
        x = jnn.leaky_relu(x)
    return x

#initializing using default values from paper
model = YoloV1(conv_structures=MODEL_ARCHITECTURE, S=7, B=2, C=20)
key1, key2 = random.split(random.PRNGKey(1000), 2)

#model works with 448 X 448 RGB images
inputs = random.uniform(key1, (448, 448, 3))
params = model.init(key2, inputs)

(224, 224, 64)
(112, 112, 64)
(112, 112, 192)
(56, 56, 192)
(56, 56, 128)
(56, 56, 256)
(56, 56, 256)
(56, 56, 512)
(28, 28, 512)
(28, 28, 256)
(28, 28, 512)
(28, 28, 256)
(28, 28, 512)
(28, 28, 256)
(28, 28, 512)
(28, 28, 256)
(28, 28, 512)
(28, 28, 512)
(28, 28, 1024)
(14, 14, 1024)
(14, 14, 512)
(14, 14, 1024)
(14, 14, 512)
(14, 14, 1024)
(7, 7, 1024)
(7, 7, 1024)
(7, 7, 1024)


In [5]:
from IPython.display import clear_output
y = model.apply(params, inputs)
clear_output()
#number of output neurons = 7 * 7 * (20 + 2 * 5) = 1470
y.shape

(1470,)