In [1]:
from domain_adaptation.datasets import SwaVDataset
from domain_adaptation.archs import deepcluster
from domain_adaptation.models import resnet
import pathlib
import os
import matplotlib.pyplot as plt
from itertools import groupby, compress

import tensorflow as tf

In [2]:
def decode_img(img, size=(180, 180)):
    """Decode JPEG-encoded bytes into a 3-channel image and resize it.

    Args:
        img: The JPEG-encoded file contents (e.g. the result of
            ``tf.io.read_file``).
        size: Target ``(height, width)`` handed to ``tf.image.resize``.
            Defaults to ``(180, 180)``, the size this notebook has always used.

    Returns:
        The resized image, as returned by ``tf.image.resize``.
    """
    # Convert the compressed string to a 3D uint8 tensor.
    img = tf.image.decode_jpeg(img, channels=3)
    # Resize the image to the requested size.
    return tf.image.resize(img, size)

def process_path(file_path):
    """Read the image file at ``file_path`` and return it decoded and resized.

    Args:
        file_path: Path of the image file to load (a string tensor when this
            function is used with ``Dataset.map``).

    Returns:
        The image produced by ``decode_img``. No label is returned.
    """
    # An earlier revision also split the path to extract the parent directory
    # name as a label, but the value was never used or returned, so that dead
    # computation has been dropped.
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes)

# labeled_ds = ds.shuffle(1024).map(process_path)

In [3]:
# Fetch and untar the flower-photos archive with tf.keras.utils.get_file, and
# keep the returned location as a pathlib.Path so it can be joined with `/`.
flowers_root = pathlib.Path(
    tf.keras.utils.get_file(
        'flower_photos',
        'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True,
    )
)

In [4]:
# List every entry matching <root>/*/* and run each path through process_path.
flowers_ds = (
    tf.data.Dataset
    .list_files(str(flowers_root / '*/*'))
    .map(process_path)
)

In [5]:
# Crop settings forwarded to SwaVDataset.
# NOTE(review): the four lists presumably align by index (one entry per crop
# resolution) — confirm against the SwaVDataset implementation.
swav_crop_kwargs = dict(
    nmb_crops=[3, 4],
    size_crops=[224, 168],
    min_scale_crops=[0.14, 0.16],
    max_scale_crops=[1., 1.],
)

# Build the SwaV dataset wrapper from the first 128 elements of flowers_ds.
a = SwaVDataset.SwaVDataset(flowers_ds.take(128), **swav_crop_kwargs)

In [6]:
# Take the `.model` attribute of the project's Resnet50 wrapper.
model = resnet.Resnet50().model

# Hand that model to the DeepCluster wrapper together with its settings.
swav_mod = deepcluster.DeepCluster(
    model=model,
    p_d1=1024,
    feat_dim=128,
    nmb_prototypes=[50, 50, 50],
    crops_for_assign=[0, 1],
)

In [7]:
# Print the layer-by-layer summary of the prototype model held by swav_mod.
swav_mod.prototype_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 2048)]       0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 1024)         2098176     input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_53 (BatchNo (None, 1024)         4096        dense[0][0]                      
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 1024)         0           batch_normalization_53[0][0]     
______________________________________________________________________________________________

In [8]:
# Cosine learning-rate schedule starting at 0.1.
# NOTE(review): the progress bar reports 8 steps per epoch, so 10 epochs is
# roughly 80 optimizer steps — far fewer than decay_steps; confirm this is
# the intended schedule length.
decay_steps = 1000
lr_decayed_fn = tf.keras.experimental.CosineDecay(
    initial_learning_rate=0.1,
    decay_steps=decay_steps,
)

# Plain SGD driven by the schedule above.
opt = tf.keras.optimizers.SGD(learning_rate=lr_decayed_fn)

# Train for 10 epochs and keep whatever fit() returns.
hist = swav_mod.fit(
    a,
    optimizer=opt,
    epochs=10,
)

Initializing memory banks...
100%|██████████| 8/8 [00:45<00:00,  5.69s/it]
Epoch 1/10:   0%|          | 0/8 [00:16<?, ?it/s]


ValueError: No gradients provided for any variable: ['conv2d/kernel:0', 'conv2d/bias:0', 'batch_normalization/gamma:0', 'batch_normalization/beta:0', 'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'batch_normalization_1/gamma:0', 'batch_normalization_1/beta:0', 'conv2d_2/kernel:0', 'conv2d_2/bias:0', 'batch_normalization_2/gamma:0', 'batch_normalization_2/beta:0', 'conv2d_3/kernel:0', 'conv2d_3/bias:0', 'conv2d_4/kernel:0', 'conv2d_4/bias:0', 'batch_normalization_3/gamma:0', 'batch_normalization_3/beta:0', 'batch_normalization_4/gamma:0', 'batch_normalization_4/beta:0', 'conv2d_5/kernel:0', 'conv2d_5/bias:0', 'batch_normalization_5/gamma:0', 'batch_normalization_5/beta:0', 'conv2d_6/kernel:0', 'conv2d_6/bias:0', 'batch_normalization_6/gamma:0', 'batch_normalization_6/beta:0', 'conv2d_7/kernel:0', 'conv2d_7/bias:0', 'batch_normalization_7/gamma:0', 'batch_normalization_7/beta:0', 'conv2d_8/kernel:0', 'conv2d_8/bias:0', 'batch_normalization_8/gamma:0', 'batch_normalization_8/beta:0', 'conv2d_9/kernel:0', 'conv2d_9/bias:0', 'batch_normalization_9/gamma:0', 'batch_normalization_9/beta:0', 'conv2d_10/kernel:0', 'conv2d_10/bias:0', 'batch_normalization_10/gamma:0', 'batch_normalization_10/beta:0', 'conv2d_11/kernel:0', 'conv2d_11/bias:0', 'batch_normalization_11/gamma:0', 'batch_normalization_11/beta:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0', 'batch_normalization_12/gamma:0', 'batch_normalization_12/beta:0', 'conv2d_13/kernel:0', 'conv2d_13/bias:0', 'conv2d_14/kernel:0', 'conv2d_14/bias:0', 'batch_normalization_13/gamma:0', 'batch_normalization_13/beta:0', 'batch_normalization_14/gamma:0', 'batch_normalization_14/beta:0', 'conv2d_15/kernel:0', 'conv2d_15/bias:0', 'batch_normalization_15/gamma:0', 'batch_normalization_15/beta:0', 'conv2d_16/kernel:0', 'conv2d_16/bias:0', 'batch_normalization_16/gamma:0', 'batch_normalization_16/beta:0', 'conv2d_17/kernel:0', 'conv2d_17/bias:0', 'batch_normalization_17/gamma:0', 'batch_normalization_17/beta:0', 'conv2d_18/kernel:0', 
'conv2d_18/bias:0', 'batch_normalization_18/gamma:0', 'batch_normalization_18/beta:0', 'conv2d_19/kernel:0', 'conv2d_19/bias:0', 'batch_normalization_19/gamma:0', 'batch_normalization_19/beta:0', 'conv2d_20/kernel:0', 'conv2d_20/bias:0', 'batch_normalization_20/gamma:0', 'batch_normalization_20/beta:0', 'conv2d_21/kernel:0', 'conv2d_21/bias:0', 'batch_normalization_21/gamma:0', 'batch_normalization_21/beta:0', 'conv2d_22/kernel:0', 'conv2d_22/bias:0', 'batch_normalization_22/gamma:0', 'batch_normalization_22/beta:0', 'conv2d_23/kernel:0', 'conv2d_23/bias:0', 'batch_normalization_23/gamma:0', 'batch_normalization_23/beta:0', 'conv2d_24/kernel:0', 'conv2d_24/bias:0', 'batch_normalization_24/gamma:0', 'batch_normalization_24/beta:0', 'conv2d_25/kernel:0', 'conv2d_25/bias:0', 'batch_normalization_25/gamma:0', 'batch_normalization_25/beta:0', 'conv2d_26/kernel:0', 'conv2d_26/bias:0', 'conv2d_27/kernel:0', 'conv2d_27/bias:0', 'batch_normalization_26/gamma:0', 'batch_normalization_26/beta:0', 'batch_normalization_27/gamma:0', 'batch_normalization_27/beta:0', 'conv2d_28/kernel:0', 'conv2d_28/bias:0', 'batch_normalization_28/gamma:0', 'batch_normalization_28/beta:0', 'conv2d_29/kernel:0', 'conv2d_29/bias:0', 'batch_normalization_29/gamma:0', 'batch_normalization_29/beta:0', 'conv2d_30/kernel:0', 'conv2d_30/bias:0', 'batch_normalization_30/gamma:0', 'batch_normalization_30/beta:0', 'conv2d_31/kernel:0', 'conv2d_31/bias:0', 'batch_normalization_31/gamma:0', 'batch_normalization_31/beta:0', 'conv2d_32/kernel:0', 'conv2d_32/bias:0', 'batch_normalization_32/gamma:0', 'batch_normalization_32/beta:0', 'conv2d_33/kernel:0', 'conv2d_33/bias:0', 'batch_normalization_33/gamma:0', 'batch_normalization_33/beta:0', 'conv2d_34/kernel:0', 'conv2d_34/bias:0', 'batch_normalization_34/gamma:0', 'batch_normalization_34/beta:0', 'conv2d_35/kernel:0', 'conv2d_35/bias:0', 'batch_normalization_35/gamma:0', 'batch_normalization_35/beta:0', 'conv2d_36/kernel:0', 'conv2d_36/bias:0', 
'batch_normalization_36/gamma:0', 'batch_normalization_36/beta:0', 'conv2d_37/kernel:0', 'conv2d_37/bias:0', 'batch_normalization_37/gamma:0', 'batch_normalization_37/beta:0', 'conv2d_38/kernel:0', 'conv2d_38/bias:0', 'batch_normalization_38/gamma:0', 'batch_normalization_38/beta:0', 'conv2d_39/kernel:0', 'conv2d_39/bias:0', 'batch_normalization_39/gamma:0', 'batch_normalization_39/beta:0', 'conv2d_40/kernel:0', 'conv2d_40/bias:0', 'batch_normalization_40/gamma:0', 'batch_normalization_40/beta:0', 'conv2d_41/kernel:0', 'conv2d_41/bias:0', 'batch_normalization_41/gamma:0', 'batch_normalization_41/beta:0', 'conv2d_42/kernel:0', 'conv2d_42/bias:0', 'batch_normalization_42/gamma:0', 'batch_normalization_42/beta:0', 'conv2d_43/kernel:0', 'conv2d_43/bias:0', 'batch_normalization_43/gamma:0', 'batch_normalization_43/beta:0', 'conv2d_44/kernel:0', 'conv2d_44/bias:0', 'batch_normalization_44/gamma:0', 'batch_normalization_44/beta:0', 'conv2d_45/kernel:0', 'conv2d_45/bias:0', 'conv2d_46/kernel:0', 'conv2d_46/bias:0', 'batch_normalization_45/gamma:0', 'batch_normalization_45/beta:0', 'batch_normalization_46/gamma:0', 'batch_normalization_46/beta:0', 'conv2d_47/kernel:0', 'conv2d_47/bias:0', 'batch_normalization_47/gamma:0', 'batch_normalization_47/beta:0', 'conv2d_48/kernel:0', 'conv2d_48/bias:0', 'batch_normalization_48/gamma:0', 'batch_normalization_48/beta:0', 'conv2d_49/kernel:0', 'conv2d_49/bias:0', 'batch_normalization_49/gamma:0', 'batch_normalization_49/beta:0', 'conv2d_50/kernel:0', 'conv2d_50/bias:0', 'batch_normalization_50/gamma:0', 'batch_normalization_50/beta:0', 'conv2d_51/kernel:0', 'conv2d_51/bias:0', 'batch_normalization_51/gamma:0', 'batch_normalization_51/beta:0', 'conv2d_52/kernel:0', 'conv2d_52/bias:0', 'batch_normalization_52/gamma:0', 'batch_normalization_52/beta:0', 'dense/kernel:0', 'dense/bias:0', 'batch_normalization_53/gamma:0', 'batch_normalization_53/beta:0', 'projection/kernel:0', 'projection/bias:0', 'prototype_0/kernel:0', 
'prototype_1/kernel:0', 'prototype_2/kernel:0'].

In [22]:
# Manual single-batch forward pass, used to inspect why training reported
# "No gradients provided for any variable".
inputs = next(iter(a.dataset_swaved))
images = inputs
b_s = images[0].shape[0]   # leading dimension of the first crop tensor
n_crops = len(images)      # number of crop tensors in the batch
crop_sizes = [img.shape[1] for img in images]
# Cumulative end index of each run of equally-sized crops, so that crops of
# the same size are pushed through the model as one concatenated batch.
idx_crops = tf.math.cumsum(
    [len(list(g)) for _, g in groupby(crop_sizes)], axis=0)

with tf.GradientTape() as tape:
    # Trainable variables are watched by the tape automatically; the previous
    # tape.watch(...) calls on input / integer tensors had no effect on the
    # variable gradients and have been removed.
    embedding_chunks = []
    start = 0
    for end in idx_crops:
        # The inputs are data, not variables, so blocking their gradient does
        # not affect the gradients of the model variables.
        concat_input = tf.stop_gradient(
            tf.concat(values=inputs[start:end], axis=0))
        _embedding = swav_mod.model(concat_input)
        embedding_chunks.append(_embedding)
        start = end
    embeddings = tf.concat(values=embedding_chunks, axis=0)

    projection, prototypes = swav_mod.prototype_model(embeddings)
    # Do NOT wrap `prototypes` in tf.stop_gradient: the loss below depends on
    # the variables only through these outputs, so stopping the gradient here
    # disconnects the loss from every variable and tape.gradient() returns
    # None for all of them.
    scores = prototypes[0] / 0.1  # 0.1 is presumably a softmax temperature — confirm
    # Repeat the first b_s assignments once per crop so the labels line up with
    # the concatenated crops (was hard-coded as 0:16 and [7]).
    targets = tf.tile(swav_mod.assignements[0][0:b_s], [n_crops])
    # NOTE(review): only prototypes[0] feeds this loss, so variables that only
    # influence the other prototype outputs will still get a None gradient.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=targets, logits=scores))



In [23]:
# Gather the trainable variables of swav_mod.model and
# swav_mod.prototype_model into one flat list.
varrs = [
    *swav_mod.model.trainable_variables,
    *swav_mod.prototype_model.trainable_variables,
]

In [24]:
# Differentiate `loss` with respect to every variable in `varrs`, using the
# tape recorded in the earlier cell. With TensorFlow's default settings, a
# variable that is not connected to the loss yields None in the result.
gradients = tape.gradient(loss, varrs)