https://github.com/ayulockin/SwAV-TF  
https://github.com/ayulockin/SwAV-TF/blob/master/Train_SwAV_40_epochs.ipynb

# SwAV-TF
- 本家の実装とは少し違う

    - キューを使用しない。本家の実装では、小さなバッチで学習する際にプロトタイプの数がバッチサイズより大きくなる場合に備えてキューを維持している
    - 224x224の解像度の2つのクロップと、96x96の解像度の3つのクロップを使用。マルチクロップの提案された設定とは違う
    - 15個のプロトタイプを使用。元の論文では、著者はImageNetデータセットに3000のプロトタイプを使用 
    - 基本学習率0.1のコサイン減衰スケジュールとともにSGDを使用。本家の実装は学習率スケジュールにウォームアップとコサイン減衰の組み合わせを使用



SwAVは、同じ画像からのクロップ画像たちが属するクラスター(プロトタイプベクトル)は同じ、というのを学習

このクラスターのラベル(ソフトラベル)はSinkhorn-Knoppアルゴリズムと呼ばれるアルゴリズムで各バッチで毎回生成

https://qiita.com/omiita/items/a7429ec42e4eef4b6a4d

https://wandb.ai/authors/swav-tf/reports/Unsupervised-Visual-Representation-Learning-with-SwAV--VmlldzoyMjg3Mzg

SwAVのアーキテクチャ

![swav](https://i.ibb.co/TtSW4Fd/figure-3.png)

![swav2](https://i.ibb.co/2FGDvd6/figure-6.png)

![swav3](https://i.ibb.co/jgm7J81/figure-7.png)

- 同じ画像をランダムにトリミングして高解像度（例：224x224）低解像度（例：96x96）のマルチクロップ画像にcolor distortion, random flipping, and random grayscalingなどのaugmentation を順番に適用（SimCLR のaugmentation）
- CNN（ResNet50）で埋め込み（最後のグローバル平均プーリングレイヤーからの出力）ベクトルを取得
- この埋め込みベクトルを浅い非線形ネットワークに送り、その出力が射影ベクトルZ
- 射影ベクトルZを単一の線形層に渡す。つまり、このレイヤーには非線形性が含まれない。レイヤーの出力は、Zとプロトタイプの間の内積。このレイヤーの関連する「重み」マトリックス（バックプロパゲーション中に更新されるもの）は、学習可能なプロトタイプバンクと見なすことができる
- Sinkhorn Knopp アルゴリズムを使用して同じ画像の2つの別々のビュー間でスワップされた予測問題を設定

In [None]:
#!git clone https://github.com/ayulockin/SwAV-TF.git

In [None]:
# https://github.com/ayulockin/SwAV-TF/tree/master/utils
# multicrop_dataset.py
import tensorflow as tf
import random

# Sentinel handed to `num_parallel_calls` in the dataset `map` calls below.
AUTO = tf.data.experimental.AUTOTUNE

# Reference: https://github.com/google-research/simclr/blob/master/data_util.py

@tf.function
def gaussian_blur(image, kernel_size=23, padding='SAME'):
    """Blur `image` with a separable Gaussian kernel of random width.

    The standard deviation is drawn uniformly from [0.1, 2.0) on every call.
    The blur is applied as a horizontal pass followed by a vertical pass.

    Args:
        image: tensor fed to `tf.nn.depthwise_conv2d`. A rank-3 input gets a
            temporary leading batch axis that is removed before returning.
        kernel_size: requested kernel width; it is replaced by the odd value
            `2 * radius + 1`, where `radius` is `kernel_size / 2` cast to int.
        padding: padding mode forwarded to `tf.nn.depthwise_conv2d`.

    Returns:
        The blurred tensor, with the same rank as the input.
    """
    sigma = tf.random.uniform((1,)) * 1.9 + 0.1

    half_width = tf.cast(kernel_size / 2, tf.int32)
    kernel_size = half_width * 2 + 1
    offsets = tf.cast(tf.range(-half_width, half_width + 1), tf.float32)
    variance = tf.pow(tf.cast(sigma, tf.float32), 2.0)
    weights = tf.exp(-tf.pow(offsets, 2.0) / (2.0 * variance))
    # Normalise so the kernel sums to one.
    weights /= tf.reduce_sum(weights)

    # The same 1-D kernel is laid out once as a row and once as a column,
    # then replicated for every input channel.
    channels = tf.shape(image)[-1]
    row_kernel = tf.tile(
        tf.reshape(weights, [1, kernel_size, 1, 1]), [1, 1, channels, 1])
    column_kernel = tf.tile(
        tf.reshape(weights, [kernel_size, 1, 1, 1]), [1, 1, channels, 1])

    needs_batch_axis = image.shape.ndims == 3
    batched = tf.expand_dims(image, axis=0) if needs_batch_axis else image

    unit_strides = [1, 1, 1, 1]
    blurred = tf.nn.depthwise_conv2d(
        batched, row_kernel, strides=unit_strides, padding=padding)
    blurred = tf.nn.depthwise_conv2d(
        blurred, column_kernel, strides=unit_strides, padding=padding)

    return tf.squeeze(blurred, axis=0) if needs_batch_axis else blurred

@tf.function
def color_jitter(x, s=0.5):
    """Randomly perturb brightness, contrast, saturation and hue of `x`.

    Args:
        x: image tensor accepted by the `tf.image.random_*` ops.
        s: jitter strength; scales every perturbation range.

    Returns:
        The jittered image clipped to the range [0, 1].
    """
    strength = 0.8 * s
    lower_bound, upper_bound = 1 - strength, 1 + strength

    jittered = tf.image.random_brightness(x, max_delta=strength)
    jittered = tf.image.random_contrast(
        jittered, lower=lower_bound, upper=upper_bound)
    jittered = tf.image.random_saturation(
        jittered, lower=lower_bound, upper=upper_bound)
    jittered = tf.image.random_hue(jittered, max_delta=0.2 * s)
    return tf.clip_by_value(jittered, 0, 1)

@tf.function
def color_drop(x):
    """Convert an RGB image to grayscale while keeping three channels.

    The grayscale channel is tiled three times along the last axis. The
    `[1, 1, 3]` multiples have three entries, so `x` must be a rank-3 image.
    """
    gray = tf.image.rgb_to_grayscale(x)
    return tf.tile(gray, [1, 1, 3])

@tf.function
def custom_augment(image):
    """Apply the stochastic augmentation chain to a single image.

    Each transform is applied independently with its own probability, in this
    order: horizontal flip, Gaussian blur, colour jitter, grayscale.
    """
    # (transform, probability of applying it), in application order.
    pipeline = (
        (tf.image.flip_left_right, 0.5),
        (gaussian_blur, 0.5),
        (color_jitter, 0.8),
        (color_drop, 0.2),
    )
    for transform, probability in pipeline:
        image = random_apply(transform, image, p=probability)

    return image

@tf.function
def random_resize_crop(image, min_scale, max_scale, crop_size):
    """Take a random square crop of `image` and resize it to `crop_size`.

    The image is first resized to a square working resolution: 260 when
    `crop_size` is 224, otherwise 160. The crop side is then drawn uniformly
    between `min_scale` and `max_scale` times that resolution.

    Args:
        image: 3-channel image tensor (the crop shape hard-codes 3 channels).
        min_scale: lower bound of the crop side, as a fraction of the
            working resolution.
        max_scale: upper bound of the crop side, same unit.
        crop_size: side length of the returned square crop.

    Returns:
        The crop resized to `(crop_size, crop_size)`.
    """
    image_shape = 260 if crop_size == 224 else 160
    image = tf.image.resize(image, (image_shape, image_shape))

    # Sample the crop side as a float, then truncate it to an int scalar.
    side = tf.random.uniform(
        shape=(1,),
        minval=min_scale * image_shape,
        maxval=max_scale * image_shape,
        dtype=tf.float32)
    side = tf.cast(side, tf.int32)[0]

    patch = tf.image.random_crop(image, (side, side, 3))
    return tf.image.resize(patch, (crop_size, crop_size))

@tf.function
def random_apply(func, x, p):
    """Return `func(x)` with probability `p`, otherwise return `x` unchanged."""
    draw = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
    should_apply = tf.less(draw, tf.cast(p, tf.float32))
    return tf.cond(should_apply, lambda: func(x), lambda: x)

@tf.function
def scale_image(image):
    """Convert `image` to float32 with `tf.image.convert_image_dtype`."""
    return tf.image.convert_image_dtype(image, tf.float32)

@tf.function
def tie_together(image, min_scale, max_scale, crop_size):
    """Turn one dataset element into a single augmented crop.

    Args:
        image: dataset element; only its `'image'` entry is read.
        min_scale: forwarded to `random_resize_crop`.
        max_scale: forwarded to `random_resize_crop`.
        crop_size: forwarded to `random_resize_crop`.

    Returns:
        The float image after random resized cropping and `custom_augment`.
    """
    pixels = scale_image(image['image'])
    crop = random_resize_crop(pixels, min_scale, max_scale, crop_size)
    return custom_augment(crop)

def get_multires_dataset(dataset,
    size_crops,
    num_crops,
    min_scale,
    max_scale,
    options=None,
    shuffle_seed=0):
    """Build one augmented dataset per requested view.

    For every crop configuration `i`, `num_crops[i]` datasets are created,
    each producing crops of side `size_crops[i]` with scale range
    `[min_scale[i], max_scale[i]]`.

    The returned datasets are zipped element-by-element by the caller, so
    they must all visit the source images in the same order. Every view
    therefore shuffles with the same `shuffle_seed`; an unseeded shuffle per
    view pairs crops that come from different images.
    NOTE(review): alignment also requires `dataset` itself, and any
    `options` applied here, to yield elements deterministically - confirm
    at the call site.

    Args:
        dataset: source `tf.data.Dataset` whose elements carry an `'image'`
            entry (read by `tie_together`).
        size_crops: crop side length per configuration.
        num_crops: number of views per configuration.
        min_scale: minimum crop scale per configuration.
        max_scale: maximum crop scale per configuration.
        options: optional `tf.data.Options` applied to every view dataset.
        shuffle_seed: seed shared by the shuffle of every view. Pass `None`
            to get the previous unseeded behaviour.

    Returns:
        A tuple with one dataset per view, in configuration order.
    """
    def make_view_fn(low, high, crop_size):
        # Bind the per-configuration values now rather than relying on the
        # loop variables still holding them when the function is called.
        return lambda x: tie_together(x, low, high, crop_size)

    loaders = []
    for i, num_crop in enumerate(num_crops):
        view_fn = make_view_fn(min_scale[i], max_scale[i], size_crops[i])
        for _ in range(num_crop):
            loader = (
                dataset
                .shuffle(1024, seed=shuffle_seed)
                .map(view_fn, num_parallel_calls=AUTO)
            )
            if options is not None:
                loader = loader.with_options(options)
            loaders.append(loader)

    return tuple(loaders)

def shuffle_zipped_output(a,b,c,d,e):
    """Return the five arguments as a tuple in a random order.

    The permutation comes from Python's `random.shuffle`.
    """
    views = [a, b, c, d, e]
    random.shuffle(views)
    return tuple(views)

In [None]:
# https://github.com/ayulockin/SwAV-TF/tree/master/utils
# architecture.py
from tensorflow.keras import layers
import tensorflow as tf

def get_resnet_backbone():
    """Build the feature extractor: ResNet50 plus global average pooling.

    The ResNet50 is created without its classification head and without
    pretrained weights, and accepts images of any spatial size. Inside the
    returned model the base network is called with `training=True`.

    Returns:
        A Keras model mapping an image batch to the pooled ResNet50 features.
    """
    base_model = tf.keras.applications.ResNet50(
        include_top=False, weights=None, input_shape=(None, None, 3)
    )
    # The attribute was misspelled `trainabe`, which only created an unused
    # attribute and never touched the real flag.
    base_model.trainable = True

    inputs = layers.Input((None, None, 3))
    h = base_model(inputs, training=True)
    h = layers.GlobalAveragePooling2D()(h)
    backbone = tf.keras.models.Model(inputs, h)

    return backbone

def get_projection_prototype(dense_1=1024, dense_2=96, prototype_dimension=10):
    """Build the projection head followed by the prototype layer.

    Args:
        dense_1: units of the hidden Dense layer (followed by batch norm and
            ReLU).
        dense_2: units of the projection Dense layer, i.e. the size of the
            projection vector.
        prototype_dimension: units of the bias-free `'prototype'` Dense layer.

    Returns:
        A Keras model taking 2048-dimensional embeddings and returning
        `[l2-normalised projection, prototype layer output]`.
    """
    embedding = layers.Input((2048, ))

    hidden = layers.Dense(dense_1)(embedding)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Activation("relu")(hidden)

    z = layers.Dense(dense_2)(hidden)
    z_unit = tf.math.l2_normalize(z, axis=1, name='projection')

    # A linear layer without bias: its output is the product of the
    # normalised projection with the layer's weight matrix.
    scores = layers.Dense(
        prototype_dimension, use_bias=False, name='prototype')(z_unit)

    return tf.keras.models.Model(inputs=embedding, outputs=[z_unit, scores])

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np
import random
import time
import os

from itertools import groupby
from tqdm import tqdm

# Seed TensorFlow and NumPy with the same value.
SEED = 666
tf.random.set_seed(SEED)
np.random.seed(SEED)

tfds.disable_progress_bar()

print(f"Tensorflow version {tf.__version__}")

# params

In [None]:
# 25 epochs instead of 40, to stay within the 9-hour limit (one epoch takes
# roughly 20 minutes).
epochs = 25

BATCH_SIZE = 32

n_classes = 5  # the cassava dataset has 5 classes
# Number of prototypes, using the paper's ratio of 3000 prototypes for 1000
# classes. Integer arithmetic keeps the count an int (the previous
# `n_classes * (3000 / 1000)` produced the float 15.0).
n_prototype = n_classes * 3000 // 1000

# Set DEBUG = True for a quick test run with fewer epochs.
DEBUG = False
if DEBUG:
    epochs = 2
    print("DEBUG")

# Cassava data
- https://www.kaggle.com/jessemostipak/getting-started-tpus-cassava-leaf-disease

In [None]:
def decode_image(image):
    """Decode JPEG-encoded bytes into a 3-channel image tensor."""
    return tf.image.decode_jpeg(image, channels=3)

In [None]:
def read_tfrecord(example, labeled):
    """Parse one serialized example into an `{"image", "label"}` dict.

    Args:
        example: serialized example, as read from a TFRecord file.
        labeled: when true the record's int64 `target` is read and returned
            as an int32 label; otherwise the string `image_name` is returned
            in the label slot.

    Returns:
        Dict with the decoded `"image"` and the `"label"` described above.
    """
    schema = {"image": tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        schema["target"] = tf.io.FixedLenFeature([], tf.int64)
    else:
        schema["image_name"] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example, schema)
    image = decode_image(parsed['image'])

    if labeled:
        label = tf.cast(parsed['target'], tf.int32)
    else:
        label = parsed['image_name']
    return {"image": image, "label": label}

In [None]:
from functools import partial

# Sentinel used below for `num_parallel_reads` and `num_parallel_calls`.
AUTOTUNE = tf.data.experimental.AUTOTUNE

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord paths handed to `tf.data.TFRecordDataset`.
        labeled: forwarded to `read_tfrecord`.
        ordered: when false, deterministic ordering is switched off in the
            dataset options, trading element order for speed.

    Returns:
        A dataset of `{"image", "label"}` dicts.
    """
    read_options = tf.data.Options()
    if not ordered:
        # Use records as soon as they stream in rather than in file order.
        read_options.experimental_deterministic = False

    parse_example = partial(read_tfrecord, labeled=labeled)
    # Reads from multiple files are interleaved automatically.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    return (
        records
        .with_options(read_options)
        .map(parse_example, num_parallel_calls=AUTOTUNE)
    )

In [None]:
import re

def count_data_items(filenames):
    """Sum the item counts encoded in the given file names.

    Each name must contain a `-<digits>.` segment (for example
    `ld_train00-1338.tfrec`); the digits of the first such segment are taken
    as that file's item count.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file = [
        int(count_pattern.search(name).group(1)) for name in filenames
    ]
    return np.sum(per_file)

In [None]:
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split

GCS_PATH = KaggleDatasets().get_gcs_path()

train_tfrecs = tf.io.gfile.glob(f"{GCS_PATH}/train_tfrecords/ld_train*.tfrec")

# Keep the 65% "train" side of a fixed file-level split; the remaining files
# are not used here.
TRAINING_FILENAMES = train_test_split(
    train_tfrecs, test_size=0.35, random_state=5)[0]

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)

print(f"NUM_TRAINING_IMAGES: {NUM_TRAINING_IMAGES}")

# Multi Crop Resize Data Augmentation

In [None]:
# Configs
# Multi-crop configuration: one entry per crop resolution.
SIZE_CROPS = [224, 96]
NUM_CROPS = [2, 3]
MIN_SCALE = [0.14, 0.05]
MAX_SCALE = [1., 0.14]

# Experimental options
options = tf.data.Options()
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.apply_default_optimizations = True
# `experimental_deterministic` is deliberately NOT set to False here. These
# options are attached to every per-view pipeline, and those pipelines are
# zipped element-by-element afterwards. Non-deterministic ordering would let
# the views of one training sample come from different images.
options.experimental_threading.max_intra_op_parallelism = 1

In [None]:
# ordered=True: each view pipeline built below re-reads the TFRecords
# independently and the views are later zipped element-by-element, so the
# read order has to be reproducible for the crops of one sample to come from
# the same image.
dataset = load_dataset(TRAINING_FILENAMES, labeled=True, ordered=True)

trainloaders = get_multires_dataset(dataset,
    size_crops=SIZE_CROPS,
    num_crops=NUM_CROPS,
    min_scale=MIN_SCALE,
    max_scale=MAX_SCALE,
    options=options)

trainloaders

In [None]:
# Prepare the final data loader

AUTO = tf.data.experimental.AUTOTUNE

# Combine the per-view datasets so each element holds all five views, then
# batch and prefetch to obtain the final training loader.
trainloaders_zipped = (
    tf.data.Dataset.zip(trainloaders)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Pull one batch to inspect the shapes of the five views.
first_batch = next(iter(trainloaders_zipped))
im1, im2, im3, im4, im5 = first_batch
print(im1.shape, im2.shape, im3.shape, im4.shape, im5.shape)

In [None]:
plt.figure(figsize=(15, 15))
# 4 x 5 grid: each column is one view, each row one of the first four
# samples of the batch.
N_ROWS, N_COLS = 4, 5
for col, view_batch in enumerate([im1, im2, im3, im4, im5]):
    for row in range(N_ROWS):
        plt.subplot(N_ROWS, N_COLS, N_COLS * row + col + 1)
        plt.imshow(view_batch[row])

# Model Architecture

In [None]:
# Build the ResNet50 feature extractor and print its layer summary.
feature_backbone = get_resnet_backbone()
feature_backbone.summary()

In [None]:
# The prototype count must be passed by keyword: the first positional
# parameter of get_projection_prototype is the hidden layer width (dense_1),
# so a positional call silently resized the hidden layer instead.
projection_prototype = get_projection_prototype(
    prototype_dimension=int(n_prototype))
projection_prototype.summary()

In [None]:
# Sanity check: embed the first view batch and show the embedding shape.
embedding_batch = feature_backbone(im1)
embedding_batch.shape

In [None]:
# Sanity check: run the embeddings through the projection/prototype head and
# show both output shapes.
projection, prototype = projection_prototype(embedding_batch)
projection.shape, prototype.shape

# Sinkhorn Knopp for Cluster Assignment
Reference: A.1 from https://arxiv.org/abs/2006.09882

In [None]:
def sinkhorn(sample_prototype_batch, n_iters=3, epsilon=0.05):
    """Turn prototype scores into soft assignments with Sinkhorn-Knopp.

    The scores are exponentiated with temperature `epsilon`. The resulting
    matrix is then rescaled in alternation so that its rows sum to `1/K` and
    its columns to `1/B`, for `n_iters` rounds. Finally each sample's
    assignment is normalised to sum to one.

    Args:
        sample_prototype_batch: rank-2 score tensor with a statically known
            shape; assumed to be (batch, prototypes) - TODO confirm at the
            call site.
        n_iters: number of row/column normalisation rounds.
        epsilon: temperature dividing the scores before exponentiation
            (previously hard-coded to 0.05).

    Returns:
        Tensor with the same shape as the input whose rows each sum to one.
    """
    Q = tf.transpose(tf.exp(sample_prototype_batch / epsilon))
    Q /= tf.keras.backend.sum(Q)
    K, B = Q.shape

    # Target marginals: uniform over the K rows and over the B columns.
    r = 1.0 / K
    c = 1.0 / B

    for _ in range(n_iters):
        u = tf.keras.backend.sum(Q, axis=1)
        Q *= tf.expand_dims(r / u, axis=1)
        Q *= tf.expand_dims(c / tf.keras.backend.sum(Q, axis=0), 0)

    # Make every column (one per sample) sum to one, then restore the
    # original orientation.
    final_quantity = Q / tf.keras.backend.sum(Q, axis=0, keepdims=True)
    final_quantity = tf.transpose(final_quantity)

    return final_quantity

In [None]:
# Check
# Sanity check: compute assignments for the sample batch and show the shape.
final_q = sinkhorn(prototype)
final_q.shape

# Train Step

In [None]:
# @tf.function
# Reference: https://github.com/facebookresearch/swav/blob/master/main_swav.py
def train_step(input_views, feature_backbone, projection_prototype,
               optimizer, crops_for_assign, temperature):
    """Run one SwAV optimisation step on a batch of multi-crop views.

    Views with the same resolution are concatenated and embedded together.
    For every view listed in `crops_for_assign`, cluster assignments `q` are
    computed with `sinkhorn` (outside the gradient tape) and every *other*
    view is trained to predict them through a temperature-scaled softmax
    over its prototype scores.

    Args:
        input_views: sequence of image batches, one per view. Views of equal
            resolution are expected to be adjacent, and all views are assumed
            to share the same batch size - TODO confirm for the last batch.
        feature_backbone: model mapping images to embeddings.
        projection_prototype: model returning `(projection, prototype scores)`.
        optimizer: optimizer used to apply the gradients.
        crops_for_assign: indices of the views that provide assignments.
        temperature: softmax temperature for the predictions.

    Returns:
        The scalar loss of this step.
    """
    # ============ retrieve input data ... ============
    inputs = list(input_views)
    n_views = len(inputs)
    batch_size = inputs[0].shape[0]

    # ============ create crop entries with same shape ... ============
    crop_sizes = [inp.shape[1] for inp in inputs]  # crop size of every view
    # Lengths of the runs of equal crop size (torch.unique_consecutive
    # equivalent); their cumulative sum gives the end index of each run.
    # Plain integers are used so they can slice the Python list directly.
    run_lengths = [len(list(group)) for _, group in groupby(crop_sizes)]
    idx_crops = np.cumsum(run_lengths)

    # ============ multi-res forward passes ... ============
    start_idx = 0
    with tf.GradientTape() as tape:
        embeddings = []
        for end_idx in idx_crops:
            # Views of the same resolution go through the backbone together.
            concat_input = tf.stop_gradient(
                tf.concat(inputs[start_idx:end_idx], axis=0))
            embeddings.append(feature_backbone(concat_input))
            start_idx = end_idx
        embeddings = tf.concat(embeddings, axis=0)

        # Only the prototype scores enter the loss.
        _, prototype = projection_prototype(embeddings)

        # ============ swav loss ... ============
        # https://github.com/facebookresearch/swav/issues/19
        loss = 0
        for crop_id in crops_for_assign:
            with tape.stop_recording():
                out = prototype[batch_size * crop_id: batch_size * (crop_id + 1)]

                # get assignments
                q = sinkhorn(out)

            # cluster assignment prediction
            subloss = 0
            for v in range(n_views):
                if v == crop_id:
                    continue
                # log_softmax avoids log(0) when a probability underflows.
                log_p = tf.nn.log_softmax(
                    prototype[batch_size * v: batch_size * (v + 1)] / temperature)
                subloss -= tf.math.reduce_mean(
                    tf.math.reduce_sum(q * log_p, axis=1))
            loss += subloss / (n_views - 1)

        loss /= len(crops_for_assign)

    # ============ backprop ... ============
    variables = feature_backbone.trainable_variables + projection_prototype.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss

# Training Loop

In [None]:
def train_swav(feature_backbone,
               projection_prototype,
               dataloader,
               optimizer,
               crops_for_assign,
               temperature,
               epochs=50):
    """Train the backbone and projection/prototype head with SwAV.

    At the start of every epoch the kernel of the `'prototype'` layer is
    l2-normalised column-wise (one column per prototype). `train_step` is
    then run on every batch of `dataloader`.
    NOTE(review): the normalisation happens once per epoch, not once per
    step - confirm this matches the intended training recipe.

    Args:
        feature_backbone: model mapping images to embeddings.
        projection_prototype: model containing a layer named `'prototype'`.
        dataloader: iterable yielding the multi-view batches for `train_step`.
        optimizer: optimizer passed to `train_step`.
        crops_for_assign: passed to `train_step`.
        temperature: passed to `train_step`.
        epochs: number of passes over `dataloader`.

    Returns:
        `(epoch_wise_loss, [feature_backbone, projection_prototype])`, where
        `epoch_wise_loss[i]` is the mean step loss of epoch `i`.
    """
    epoch_wise_loss = []
    prototype_layer = projection_prototype.get_layer('prototype')

    for epoch in tqdm(range(epochs)):
        # get_weights() returns a list; the bias-free layer holds only its
        # kernel, whose columns are the prototype vectors.
        kernel = prototype_layer.get_weights()[0]
        kernel = tf.math.l2_normalize(kernel, axis=0)
        prototype_layer.set_weights([kernel.numpy()])

        # Reset per epoch so the reported value is this epoch's mean, not a
        # running mean over every step since training started.
        step_wise_loss = []
        for inputs in dataloader:
            loss = train_step(inputs, feature_backbone, projection_prototype,
                              optimizer, crops_for_assign, temperature)
            step_wise_loss.append(float(loss))

        epoch_loss = np.mean(step_wise_loss)
        epoch_wise_loss.append(epoch_loss)
        #wandb.log({'epoch': epoch, 'loss': epoch_loss})

        if epoch % 5 == 0:
            print("epoch: {} loss: {:.3f}".format(epoch + 1, epoch_loss))

    return epoch_wise_loss, [feature_backbone, projection_prototype]

In [None]:
%%time
# ============ initialize the networks and the optimizer ... ============
feature_backbone = get_resnet_backbone()
# Pass the prototype count by keyword. The previous positional call
# `get_projection_prototype(10)` set the hidden layer width (dense_1) to 10
# and ignored the `n_prototype` setting.
projection_prototype = get_projection_prototype(
    prototype_dimension=int(n_prototype))

# The learning rate decays from 0.1 to 0.001 over the first 500 optimizer
# steps.
# NOTE(review): confirm that 500 decay steps is intended for a run of
# `epochs` epochs.
lr_decayed_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=500,
    end_learning_rate=0.001,
    power=0.5)
opt = tf.keras.optimizers.SGD(learning_rate=lr_decayed_fn)

# ================= initialize wandb ======================
#wandb.init(entity='authors', project='swav-tf', id='40-epochs')

# ============ train for `epochs` epochs ... ============
epoch_wise_loss, models = train_swav(feature_backbone,
    projection_prototype,
    trainloaders_zipped,
    opt,
    crops_for_assign=[0, 1],
    temperature=0.1,
    epochs=epochs
)

In [None]:
# Plot the loss value recorded for each epoch.
plt.plot(epoch_wise_loss)
plt.show()

In [None]:
# Serialize the models
feature_backbone, projection_prototype = models
feature_backbone.save_weights('feature_backbone.h5')
projection_prototype.save_weights('projection_prototype.h5')