In [0]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow_probability as tfp 

In [0]:
#https://raw.githubusercontent.com/ntozer/mixmatch-tensorflow2.0/master/model.py

class Residual3x3Unit(tf.keras.layers.Layer):
    def __init__(self, channels_in, channels_out, stride, droprate=0., activate_before_residual=False):
        super(Residual3x3Unit, self).__init__()
        self.bn_0 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_0 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.conv_0 = tf.keras.layers.Conv2D(channels_out, kernel_size=3, strides=stride, padding='same', use_bias=False)
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_1 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.conv_1 = tf.keras.layers.Conv2D(channels_out, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.downsample = channels_in != channels_out
        self.shortcut = tf.keras.layers.Conv2D(channels_out, kernel_size=1, strides=stride, use_bias=False)
        self.activate_before_residual = activate_before_residual
        self.dropout = tf.keras.layers.Dropout(rate=droprate)
        self.droprate = droprate

    @tf.function
    def call(self, x, training=True):
        if self.downsample and self.activate_before_residual:
            x = self.relu_0(self.bn_0(x, training=training))
        elif not self.downsample:
            out = self.relu_0(self.bn_0(x, training=training))
        out = self.relu_1(self.bn_1(self.conv_0(x if self.downsample else out), training=training))
        if self.droprate > 0.:
            out = self.dropout(out)
        out = self.conv_1(out)
        return out + (self.shortcut(x) if self.downsample else x)


class ResidualBlock(tf.keras.layers.Layer):
    def __init__(self, n_units, channels_in, channels_out, unit, stride, droprate=0., activate_before_residual=False):
        super(ResidualBlock, self).__init__()
        self.units = self._build_unit(n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual)

    def _build_unit(self, n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual):
        units = []
        for i in range(n_units):
            units.append(unit(channels_in if i == 0 else channels_out, channels_out, stride if i == 0 else 1, droprate, activate_before_residual))
        return units

    @tf.function
    def call(self, x, training=True):
        for unit in self.units:
            x = unit(x, training=training)
        return x


class WideResNet(tf.keras.Model):
    def __init__(self, num_classes, depth=28, width=2, droprate=0., input_shape=(None, 32, 32, 3), **kwargs):
        super(WideResNet, self).__init__(input_shape, **kwargs)
        assert (depth - 4) % 6 == 0
        N = int((depth - 4) / 6)
        channels = [16, 16 * width, 32 * width, 64 * width]

        self.conv_0 = tf.keras.layers.Conv2D(channels[0], kernel_size=3, strides=1, padding='same', use_bias=False)
        self.block_0 = ResidualBlock(N, channels[0], channels[1], Residual3x3Unit, 1, droprate, True)
        self.block_1 = ResidualBlock(N, channels[1], channels[2], Residual3x3Unit, 2, droprate)
        self.block_2 = ResidualBlock(N, channels[2], channels[3], Residual3x3Unit, 2, droprate)
        self.bn_0 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_0 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.avg_pool = tf.keras.layers.AveragePooling2D((8, 8), (1, 1))
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_classes)

    @tf.function
    def call(self, inputs, training=True):
        x = inputs
        x = self.conv_0(x)
        x = self.block_0(x, training=training)
        x = self.block_1(x, training=training)
        x = self.block_2(x, training=training)
        x = self.relu_0(self.bn_0(x, training=training))
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

n_classes = 10

x_test = x_test.astype('float32')
x_test /= 255.
y_test = tf.keras.utils.to_categorical(y_test, n_classes)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
def training_validation_set_split(x_train, y_train, labeled_samples_per_class, fix_validation_size = True, validation_split = 0.2, shuffle = True, n_classes = 10):
    
    assert len(x_train) == len(y_train), "x_train and y_train lengths do not match."
    
    labeled_training_idx = []
    labeled_validation_idx = []
    unlabeled_training_idx = []
    
    if(fix_validation_size is True):
        n_train_per_class = labeled_samples_per_class
        n_val_per_class = 500
    else:
        n_val_per_class = int(labeled_samples_per_class * validation_split)
        n_train_per_class =  labeled_samples_per_class - n_val_per_class    

    for class_idx in range(n_classes):
        indices = np.where(y_train == class_idx)[0]
        if(shuffle == True):
            np.random.shuffle(indices)
        
        labeled_training_idx.extend(indices[:n_train_per_class])
        labeled_validation_idx.extend(indices[-n_val_per_class:])
        unlabeled_training_idx.extend(indices[n_train_per_class:-n_val_per_class])

    np.random.shuffle(labeled_training_idx)
    np.random.shuffle(labeled_validation_idx)
    np.random.shuffle(unlabeled_training_idx)
   
    x_labeled_training = x_train[labeled_training_idx]
    y_labeled_training = y_train[labeled_training_idx]
    
    x_labeled_validation = x_train[labeled_validation_idx]
    y_labeled_validation = y_train[labeled_validation_idx]
    
    x_unlabeled_training = x_train[unlabeled_training_idx]
    
    return x_labeled_training, y_labeled_training, x_labeled_validation, y_labeled_validation, x_unlabeled_training

In [0]:
seed = 2019
tf.random.set_seed(seed)
np.random.seed(seed)

x_labeled_training, y_labeled_training, x_labeled_validation, y_labeled_validation, x_unlabeled_training = \
    training_validation_set_split(x_train,y_train,labeled_samples_per_class=25, fix_validation_size=True,shuffle=True)
x_labeled_training = x_labeled_training.astype('float32')
x_labeled_validation = x_labeled_validation.astype('float32')
x_labeled_training /= 255.
x_labeled_validation /= 255.
x_unlabeled_training = x_unlabeled_training.astype('float32')
x_unlabeled_training /= 255.
y_labeled_training = tf.keras.utils.to_categorical(y_labeled_training, n_classes)
y_labeled_validation = tf.keras.utils.to_categorical(y_labeled_validation, n_classes)

In [0]:
@tf.function
def sharpen(preds, temperature = 0.5):
    preds_target = tf.math.pow(preds, 1. / temperature)
    preds_target /= tf.math.reduce_sum(preds_target, axis=1, keepdims=True)
    
    return preds_target

@tf.function
def augment_labeled(images):
    images = tf.image.random_flip_left_right(images)
    images = tf.image.random_contrast(images, 0.8, 1.0)
    images = tf.image.random_hue(images,max_delta = 0.2)
    images = tf.image.random_brightness(images,max_delta = 0.2)
    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    if(choice > 0.5):
        images = tf.pad(images, paddings=[(0, 0), (4, 4), (4, 4), (0, 0)], mode='REFLECT')
        images = tf.map_fn(lambda image: tf.image.random_crop(image, size=(32, 32, 3)), images)
    return tf.clip_by_value(images,0.0,1.0)

#https://github.com/google-research/mixmatch/blob/master/mixmatch.py
def mix_up(x_first, y_first, x_second, y_second, alpha = 0.75):
    assert x_first.shape[0] == x_second.shape[0], "Array sizes differ."
    
    mix = tfp.distributions.Beta(alpha, alpha).sample([tf.shape(x_first)[0], 1, 1, 1])
    mix = tf.maximum(mix, 1 - mix)
    
    x_mix = mix * x_first + (1-mix) * x_second
    y_mix = mix[:, :, 0, 0] * y_first + (1-mix[:, :, 0, 0]) * y_second
    
    return x_mix, y_mix

def weight_decay(model, decay_rate):
    for var in model.trainable_variables:
        var.assign(var * (1 - decay_rate))
        
@tf.function
def shuffle_tensors_together(tensor1, tensor2):    
    indices = tf.range(start=0, limit=tf.shape(tensor1)[0], dtype=tf.int32)
    shuffled_indices = tf.random.shuffle(indices)

    tensor1_shuffled = tf.gather(tensor1, shuffled_indices)
    tensor2_shuffled = tf.gather(tensor2, shuffled_indices)
    
    return tensor1_shuffled, tensor2_shuffled

@tf.function
def random_crop(images):
    padded_images = tf.pad(images, paddings=[(0, 0), (4, 4), (4, 4), (0, 0)], mode='REFLECT')
    cropped_images = tf.map_fn(lambda image: tf.image.random_crop(image, size=(32, 32, 3)), padded_images)
    return cropped_images

@tf.function
def get_losses(labeled_logits, labeled_targets, unlabeled_logits, unlabeled_targets):
    labeled_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labeled_targets, logits=labeled_logits)
    labeled_loss = tf.reduce_mean(labeled_loss)
    
    unlabeled_loss = tf.square(unlabeled_targets - tf.nn.softmax(unlabeled_logits))
    unlabeled_loss = tf.reduce_mean(unlabeled_loss)

    return labeled_loss, unlabeled_loss

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [tf.concat(v, axis=0) for v in xy]

In [0]:
## https://www.tensorflow.org/guide/keras/train_and_evaluate
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from keras_radam.training import RAdamOptimizer

ramp_up_length = 6000
w_decay = 0.000035
unlabeled_augments = ['random_crop','lrflip']
unlabeled_weight = 75

# Get the model.
model = WideResNet(n_classes, depth=28, width=2)
model.build(input_shape=(None, 32, 32, 3))

initial_learning_rate = 0.002

# Instantiate an optimizer to train the model.
optimizer = RAdamOptimizer(learning_rate=initial_learning_rate)

# Prepare the metrics.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

avg_train_loss = tf.keras.metrics.Mean()
avg_labeled_loss = tf.keras.metrics.Mean()
avg_unlabeled_loss =  tf.keras.metrics.Mean()
avg_val_loss = tf.keras.metrics.Mean()

# Prepare the training dataset.
batch_size = 64
labeled_train_dataset = tf.data.Dataset.from_tensor_slices((x_labeled_training, y_labeled_training))
labeled_train_dataset = labeled_train_dataset.batch(batch_size, drop_remainder=True)
labeled_iter = iter(labeled_train_dataset)

unlabeled_train_dataset = tf.data.Dataset.from_tensor_slices(x_unlabeled_training)
unlabeled_train_dataset = unlabeled_train_dataset.batch(batch_size, drop_remainder=True)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_labeled_validation, y_labeled_validation))
val_dataset = val_dataset.batch(64)

global_step = 0

# Iterate over epochs.
for epoch in range(400):
    epoch_start = tf.timestamp(name=None)

    # Iterate over the batches of the dataset.
    for x_unlabeled_batch in unlabeled_train_dataset:
        try: 
            x_labeled_batch_train, y_labeled_batch_train = next(labeled_iter)
        except:
            labeled_iter = iter(labeled_train_dataset)
            x_labeled_batch_train, y_labeled_batch_train = next(labeled_iter)

        # Augment labeled
        x_augmented_labeled_batch = augment_labeled(x_labeled_batch_train)

        # Unlabeled predictions
        logits = []
        
        for aug in unlabeled_augments:
            if(aug == 'random_crop'):
                augmented_unlabeled_batch_crop = random_crop(x_unlabeled_batch)
                crop_guessed_label = tf.nn.softmax(model(augmented_unlabeled_batch_crop))
                logits.append(crop_guessed_label)
            if(aug == 'lrflip'):
                augmented_unlabeled_batch_lr = tf.image.flip_left_right(x_unlabeled_batch)
                lr_guessed_label = tf.nn.softmax(model(augmented_unlabeled_batch_lr))
                logits.append(lr_guessed_label)

        # Aggregate unlabeled predictions and sharpen        
        logits = tf.add_n(logits) / len(unlabeled_augments)
        guessed_labels = sharpen(logits)
        guessed_labels = tf.stop_gradient(guessed_labels)

        # Prepare W
        x_unlabeled_all_augmentations = tf.concat([augmented_unlabeled_batch_crop, augmented_unlabeled_batch_lr], axis=0)
        y_unlabeled_all_augmentations = tf.concat([guessed_labels, guessed_labels], axis=0)
        
        x_combined_batch = tf.concat([x_augmented_labeled_batch, x_unlabeled_all_augmentations], axis=0)
        y_combined_batch = tf.concat([y_labeled_batch_train, y_unlabeled_all_augmentations], axis=0)

        shuffled_x, shuffled_y = shuffle_tensors_together(x_combined_batch,y_combined_batch)
        
        # MixUp on X+U, W
        XU, XUy = mix_up(x_combined_batch, y_combined_batch, \
                             shuffled_x, shuffled_y, alpha = 0.75)
        
        # Split to batches
        XU = tf.split(XU, len(unlabeled_augments) + 1, axis=0)

        # Mix samples within batches 
        # https://github.com/google-research/mixmatch/issues/5
        XU = interleave(XU, batch_size)
        
        # Predict
        with tf.GradientTape() as tape:
            train_logits = [model(XU[0])]
            for batch in XU[1:]:
                train_logits.append(model(batch))

            # Restore original batches
            train_logits = interleave(train_logits, batch_size)

            # Calculate loss
            labeled_logits = train_logits[0]
            unlabeled_logits = tf.concat(train_logits[1:], axis=0)

            labeled_loss, unlabeled_loss = get_losses(labeled_logits,XUy[:batch_size],unlabeled_logits,XUy[batch_size:])

            avg_unlabeled_loss(unlabeled_loss)
            avg_labeled_loss(labeled_loss)

            ramp_up = 1.0 if global_step > ramp_up_length else global_step / ramp_up_length
            combined_loss = labeled_loss + ramp_up * unlabeled_weight * unlabeled_loss

        # Calculate gradients, update metrics    
        grads = tape.gradient(combined_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        global_step = global_step + 1
        weight_decay(model = model, decay_rate = w_decay)
        avg_train_loss(combined_loss)



    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric(y_batch_val, tf.nn.softmax(val_logits))
        val_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_batch_val, logits=val_logits)
        avg_val_loss(tf.reduce_mean(val_loss))
        
    epoch_end = tf.timestamp(name=None)
    print(f'Epoch: {epoch:4d} , Train loss: {avg_train_loss.result():.4f}, Val loss: {avg_val_loss.result():.4f}, Val Acc: {val_acc_metric.result():.3%}, Lambda: {float(ramp_up * unlabeled_weight):.4f}' )
    print(f'Labeled loss: {avg_labeled_loss.result():.4f} , Unlabeled loss: {avg_unlabeled_loss.result():.4f}, LR: {learning_rate_fn(global_step).numpy():2f}, Time:: {str(int((epoch_end.numpy()-epoch_start.numpy())))}s')
    
    train_acc_metric.reset_states()    
    val_acc_metric.reset_states()
    avg_train_loss.reset_states()
    avg_val_loss.reset_states()
    avg_labeled_loss.reset_states()
    avg_unlabeled_loss.reset_states()

Using TensorFlow backend.


Epoch:    0 , Train loss: 1.2989, Val loss: 3.4729, Val Acc: 23.760%, Lambda: 8.7250
Labeled loss: 1.2478 , Unlabeled loss: 0.0110, LR: 0.001931, Time:: 604s
Epoch:    1 , Train loss: 1.0707, Val loss: 2.7451, Val Acc: 39.980%, Lambda: 17.4625
Labeled loss: 0.9307 , Unlabeled loss: 0.0107, LR: 0.001865, Time:: 596s
Epoch:    2 , Train loss: 1.1106, Val loss: 4.3641, Val Acc: 30.860%, Lambda: 26.2000
Labeled loss: 0.8928 , Unlabeled loss: 0.0100, LR: 0.001801, Time:: 594s
Epoch:    3 , Train loss: 1.1704, Val loss: 2.2529, Val Acc: 48.860%, Lambda: 34.9375
Labeled loss: 0.8765 , Unlabeled loss: 0.0096, LR: 0.001739, Time:: 595s
Epoch:    4 , Train loss: 1.2271, Val loss: 2.3782, Val Acc: 47.880%, Lambda: 43.6750
Labeled loss: 0.8600 , Unlabeled loss: 0.0093, LR: 0.001679, Time:: 593s
Epoch:    5 , Train loss: 1.2909, Val loss: 2.2606, Val Acc: 45.680%, Lambda: 52.4125
Labeled loss: 0.8533 , Unlabeled loss: 0.0091, LR: 0.001622, Time:: 593s
Epoch:    6 , Train loss: 1.3567, Val loss: 1.8