In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [2]:
import jax
import flax
import flax.linen as nn
import jax.numpy as jnp
from jax import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from jax_resnet import pretrained_resnet, Sequential, slice_variables, ResNet50
from scipy.optimize import linear_sum_assignment
import lib.DETR_jax as lb
import torch

In [3]:
# Based on the class defined above, we create training and validation datasets.
from transformers import DetrFeatureExtractor
from lib.DETR import CocoDetection
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")

feature_extractor.max_size = 256
feature_extractor.size = 128

DATA_BASE = 'data/custom/'
train_dataset = CocoDetection(img_folder=f'{DATA_BASE}/train', feature_extractor=feature_extractor)
val_dataset = CocoDetection(img_folder=f'{DATA_BASE}/val', feature_extractor=feature_extractor, train=False)

from torch.utils.data import DataLoader

def collate_fn(batch):
  pixel_values = [item[0] for item in batch]
  encoding = feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")
  labels = [item[1] for item in batch]
  batch = {}
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = labels
  return batch


def to_jax(batch):
    batch['pixel_values'] = jnp.array(batch['pixel_values'].numpy().transpose(0,2,3,1))
    batch['pixel_mask'] = jnp.array(batch['pixel_mask'])
    batch['labels'] = [{k: jnp.array(v) for k,v in n.items()} for n in batch['labels']]
    return batch

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=2, shuffle=True, num_workers = 3)
val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=2, shuffle=True, num_workers = 3)
cats = val_dataset.coco.cats
# Use this for the # classes
id2label = {k: v['name'] for k,v in cats.items()}

t_it = iter(train_dataloader)
batch = to_jax(next(t_it))

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [6]:
config = lb.TestConfig(output_dim=len(id2label))
rng = jax.random.PRNGKey(config.seed)
rng, init_rng = jax.random.split(rng)

model = lb.DETR_transformer(config)
init_rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
params = jax.jit(model.init)(init_rngs, batch['pixel_values'])
#params = lb.load_pretrained_resnet(params)

In [9]:
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [7]:
outputs = model.apply(params, batch['pixel_values'], rngs = init_rngs)

RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3990): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'

In [None]:

    
matcher = lb.DetrHungarianMatcher()
loss = lb.DetrLoss(matcher, 3, 0.2, ["labels", "boxes", "cardinality"])
loss.forward(outputs, batch['labels'])

In [15]:
batch['labels']

[{'boxes': DeviceArray([[0.5       , 0.14988428, 0.8625    , 0.1932871 ],
               [0.11046723, 0.32689232, 0.08343446, 0.04487854],
               [0.2702295 , 0.24954374, 0.09999999, 0.0538542 ],
               [0.40902254, 0.34259975, 0.37758607, 0.07629338],
               [0.11046716, 0.4361332 , 0.08343447, 0.04487854],
               [0.11046716, 0.5386497 , 0.08343447, 0.04487848],
               [0.11046723, 0.64789057, 0.08343446, 0.04487848],
               [0.11874992, 0.24954374, 0.1       , 0.0538542 ],
               [0.67334294, 0.24954374, 0.12710054, 0.0538542 ],
               [0.82334286, 0.24954374, 0.12710054, 0.0538542 ],
               [0.40902254, 0.4518406 , 0.37758607, 0.07629335],
               [0.40902254, 0.5543571 , 0.37758607, 0.07629335],
               [0.40902254, 0.663598  , 0.37758607, 0.07629335],
               [0.6893622 , 0.34259975, 0.15913929, 0.07629338],
               [0.6893622 , 0.4518406 , 0.15913929, 0.07629335],
               [

In [15]:
params['params'].keys()

dict_keys(['backbone', 'col_embed', 'feature_conv', 'linear_bbox', 'linear_class', 'queries', 'row_embed', 'transformer'])

In [16]:
model.apply(params, batch['pixel_values'], rngs = init_rngs)['pred_bbox']

DeviceArray([[[ 2.67246246e-01, -1.81899405e+00, -1.08822155e+00,
                7.24082530e-01],
              [ 1.31203067e+00, -1.42437947e+00, -1.59083635e-01,
               -5.95721126e-01],
              [ 5.95161140e-01, -1.66342878e+00, -8.07054579e-01,
                1.47104368e-01],
              [ 6.07963622e-01, -1.54283071e+00, -4.09295470e-01,
               -1.45494771e+00],
              [ 3.59880537e-01, -1.69999206e+00, -1.13488531e+00,
               -1.37670922e+00],
              [ 1.50208938e+00, -8.55092466e-01,  4.60382514e-02,
               -1.69564724e+00],
              [ 1.20096898e+00, -1.75632977e+00, -4.52695668e-01,
               -4.06302452e-01],
              [ 9.26728666e-01, -1.19546139e+00, -3.07611674e-02,
               -7.62400508e-01],
              [ 1.42015553e+00, -1.22326767e+00, -2.17968225e-01,
               -3.94537091e-01],
              [ 1.41192114e+00, -1.19288182e+00, -6.44537985e-01,
               -1.32771552e+00],
          

In [212]:
mutable_params = {k:v for k,v in params.items()}
for k,v in params['batch_stats']:

{'batch_stats': FrozenDict({
    backbone: {
        layers_0: {
            ConvBlock_0: {
                BatchNorm_0: {
                    mean: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                 0., 0., 0., 0.], dtype=float32),
                    var: DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                 1., 1., 1., 1.], dtype=

In [215]:
ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
idx = 18 # gives B, 7,7,2048
backbone, backbone_variables = Sequential(model.layers[0:idx]), slice_variables(variables, end=idx)
for k,v in backbone_variables.items():
    print(k)

params
batch_stats


In [15]:
for q,t in params.items():
    print(q)
    for k,v in params['params'].items():
        print(k)

batch_stats
backbone
col_embed
feature_conv
linear_bbox
linear_class
queries
row_embed
transformer
params
backbone
col_embed
feature_conv
linear_bbox
linear_class
queries
row_embed
transformer


In [20]:
# standard PyTorch mean-std input image normalization
import torchvision.transforms as T
from PIL import Image
import requests
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
img = transform(im).unsqueeze(0).numpy().transpose(0,2,3,1)

In [None]:
import optax
tx = optax.sgd(learning_rate=alpha)
opt_state = tx.init(params)


In [54]:
model.apply(initial_variables, test_inputs, test_targets, rngs = init_rngs).shape

(2, 100, 32)

In [None]:

# apply an optimizer to this tree
optimizer_def = optim.Adam(
  config.learning_rate,
  beta1=0.9,
  beta2=0.98,
  eps=1e-9,
  weight_decay=config.weight_decay)
optimizer = optimizer_def.create(initial_variables["params"])




In [197]:
ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
idx = 18 # gives B, 7,7,2048
backbone, backbone_variables = Sequential(model.layers[0:idx]), slice_variables(variables, end=idx) 
output = backbone.apply(backbone_variables, jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.


In [216]:
for k,v in backbone_variables['params'].items():
    print(k)

layers_0
layers_2
layers_3
layers_4
layers_5
layers_6
layers_7
layers_8
layers_9
layers_10
layers_11
layers_12
layers_13
layers_14
layers_15
layers_16
layers_17


In [6]:
test = jnp.array([[[1,2], [1,2], [1,2],[0,0]], [[1,2], [1,2], [1,2],[0,0]], [[1,2], [1,2], [1,2],[0,0]]])
test, test.shape

(DeviceArray([[[1, 2],
               [1, 2],
               [1, 2],
               [0, 0]],
 
              [[1, 2],
               [1, 2],
               [1, 2],
               [0, 0]],
 
              [[1, 2],
               [1, 2],
               [1, 2],
               [0, 0]]], dtype=int32),
 (3, 4, 2))

In [17]:
nn.make_attention_mask(test > 0, test > 0 ).squeeze()[0]

DeviceArray([[[1., 1.],
              [1., 1.]],

             [[1., 1.],
              [1., 1.]],

             [[1., 1.],
              [1., 1.]],

             [[0., 0.],
              [0., 0.]]], dtype=float32)