In [1]:
import numpy
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

In [2]:
# Input artefacts for this notebook, keyed by role. 'tensor' is the path of
# the .npz archive that is opened with numpy.load further down.
specs = dict(
    tensor='../urbangrammar_samba/' 'spatial_signatures/chips/sample.npz',
)

In [3]:
# Open the archive referenced by specs["tensor"].
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code if the file is not trusted. The following cells only
# read "chips" and "labels"; if both are plain numeric arrays this flag can
# probably be dropped -- confirm against the archive contents.
data = numpy.load(file=specs["tensor"], allow_pickle=True)

In [4]:
# NOTE(review): `start`/`stop` are defined but never used to slice the arrays
# in the visible cells; they are kept in case another cell relies on them.
start = 0
stop = start + 30000

# Pull the image chips and their class labels out of the loaded archive.
chips = data["chips"]
labels = data["labels"]

# The classifier head has one output unit per class and the sparse loss
# indexes that output by label value, so the class ids must be exactly
# 0, 1, ..., n_classes - 1. Comparing against arange() also rejects negative
# ids, which a plain `max() + 1 == count` check would let through
# (e.g. {-1, 0, 2} has 3 unique values and max() + 1 == 3).
unique_labels = numpy.unique(labels)
n_classes = unique_labels.shape[0]
assert n_classes > 0 and numpy.array_equal(
    unique_labels, numpy.arange(n_classes)
), "labels must cover every class id in 0..n_classes-1 with no gaps"

In [5]:
batch_size = 32

# Positional 80/20 split: the first 80% of the chips are used for training and
# the remainder for evaluation.
# NOTE(review): the split follows the storage order of the archive; if chips
# are stored grouped by class or area, the two parts will have different
# label distributions -- confirm the archive is already randomised.
split = int(chips.shape[0] * 0.8)

train_dataset = tf.data.Dataset.from_tensor_slices((chips[:split], labels[:split]))
test_dataset = tf.data.Dataset.from_tensor_slices((chips[split:], labels[split:]))

# Model.fit ignores its own `shuffle` argument for tf.data inputs, so the
# training samples have to be shuffled here; otherwise every epoch sees the
# same batches in storage order. The buffer spans the whole training set so
# the shuffle is uniform (this keeps one extra copy of the training samples in
# memory), and the fixed seed keeps runs repeatable. Shuffling happens before
# batching so batch composition changes from epoch to epoch.
train_dataset = (
    train_dataset
    .shuffle(buffer_size=max(split, 1), seed=0)
    .batch(batch_size=batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# The evaluation data keeps its original order; prefetch only overlaps input
# preparation with model execution.
test_dataset = test_dataset.batch(batch_size=batch_size).prefetch(tf.data.AUTOTUNE)

2021-11-26 15:11:47.775607: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6684 MB memory:  -> device: 0, name: Quadro RTX 4000, pci bus id: 0000:21:00.0, compute capability: 7.5


In [6]:
# In-model preprocessing: resize every chip to 224x224 (the input size the
# backbone is built with) and then rescale the pixel values by 1/32.
# No augmentation layer (e.g. a random horizontal flip) is currently enabled.
# NOTE(review): resnet50.preprocess_input is applied right after this stack
# and is documented for inputs in [0, 255]; confirm that scale=1/32 puts the
# chips into that range.
preprocessing_and_augmentation = keras.Sequential()
preprocessing_and_augmentation.add(
    layers.Resizing(height=224, width=224, crop_to_aspect_ratio=True)
)
preprocessing_and_augmentation.add(layers.Rescaling(scale=1 / 32))

In [7]:
# ImageNet-pretrained ResNet50 used as a feature extractor: the original
# classification head is left out and all backbone weights are frozen.
base_model = keras.applications.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
)
base_model.trainable = False

# The model takes raw 32x32 RGB chips; resizing/rescaling and the
# ResNet50-specific input preprocessing happen inside the graph.
inputs = keras.Input(shape=(32, 32, 3))
resized = preprocessing_and_augmentation(inputs)
backbone_ready = preprocess_input(resized)

# training=False pins the backbone (including its batch-normalisation layers)
# to inference mode, so it stays that way even if the backbone is unfrozen
# later for fine-tuning.
feature_maps = base_model(backbone_ready, training=False)

# New classification head: global average pooling, one hidden layer and a
# softmax over the n_classes labels.
pooled = layers.GlobalAveragePooling2D()(feature_maps)
hidden = layers.Dense(units=128, activation="relu")(pooled)
predictions = layers.Dense(units=n_classes, activation="softmax")(hidden)

model = keras.Model(inputs=inputs, outputs=predictions)
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
sequential (Sequential)      (None, 224, 224, 3)       0         
_________________________________________________________________
tf.__operators__.getitem (Sl (None, 224, 224, 3)       0         
_________________________________________________________________
tf.nn.bias_add (TFOpLambda)  (None, 224, 224, 3)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               262272

In [8]:
# Integer class ids are used as targets, hence the sparse variants of the loss
# and the accuracy metric; the model already outputs softmax probabilities, so
# the loss keeps its default (non-logit) setting. Adam runs with its defaults.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(),
)

In [9]:
%%time

epochs = 10
history = model.fit(train_dataset, epochs=epochs, validation_data=test_dataset)

Epoch 1/10


2021-11-26 15:11:52.868752: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2021-11-26 15:11:53.703997: I tensorflow/stream_executor/cuda/cuda_dnn.cc:381] Loaded cuDNN version 8300


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
CPU times: user 8min 36s, sys: 59.1 s, total: 9min 35s
Wall time: 22min 2s


In [10]:
%%time
pred = model.predict(chips[-10:])
pred

CPU times: user 1.43 s, sys: 0 ns, total: 1.43 s
Wall time: 1.41 s


array([[5.08274615e-01, 1.17278344e-03, 2.25971016e-04, 1.58391667e-05,
        2.30911319e-05, 2.91605983e-02, 2.07503736e-02, 2.02414487e-03,
        1.37885145e-05, 4.38326538e-01, 1.22393385e-05, 2.77678271e-12,
        7.05190507e-12, 4.80482376e-10],
       [2.06201062e-01, 1.87824294e-03, 5.91045537e-05, 1.19773958e-05,
        2.30076103e-06, 3.72314900e-02, 2.65959352e-02, 5.99628501e-03,
        2.89296386e-05, 7.21963108e-01, 3.15322322e-05, 1.40392564e-09,
        4.73208972e-11, 1.22205774e-08],
       [4.09130096e-01, 2.15269509e-03, 1.05033114e-05, 7.18942147e-06,
        7.61205740e-07, 1.17917219e-02, 4.98673022e-02, 5.45486948e-03,
        7.85293651e-06, 5.21574914e-01, 2.11375254e-06, 1.70734542e-12,
        1.56410269e-12, 8.20332802e-10],
       [3.04485947e-01, 7.42831384e-04, 1.00997080e-04, 1.26559953e-05,
        5.26443273e-07, 2.85908543e-02, 3.63336802e-02, 2.50972196e-04,
        2.53830854e-06, 6.29477382e-01, 1.63504433e-06, 1.05667807e-12,
        6.187

In [20]:
# Most probable class id for each predicted chip.
numpy.argmax(pred, axis=1)

array([0, 9, 9, 9, 0, 9, 9, 0, 0, 9])

In [23]:
# Stored labels of the last ten chips (first column of the label array), for
# comparison with the predicted class ids.
labels[-10:, 0]

array([0, 9, 0, 6, 9, 1, 9, 9, 0, 9], dtype=int8)