In [0]:
import os
import json
import gc
from keras_radam.training import RAdamOptimizer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [0]:
generator_batch_size = 8

def create_datagen(validation_split=0.9535):
    """Return the augmenting ImageDataGenerator used for the train directory.

    Pixels are rescaled to [0, 1] and images are randomly flipped horizontally and
    vertically. ``validation_split`` partitions the train directory into two subsets:
    per the Keras docs the ``'validation'`` subset receives this fraction of the files
    and the ``'training'`` subset the remainder. With the default of 0.9535 only about
    4.65% of the images form the labeled ``'training'`` subset; the notebook uses the
    ``'validation'`` subset as the unlabeled pool for MixMatch.

    Args:
        validation_split: fraction of images assigned to the ``'validation'`` subset.
    """
    return tf.keras.preprocessing.image.ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        validation_split=validation_split,
        rescale=1./255)

def create_unlabeled_gen(directory=None):
    """Return a shuffled, label-free flow of rescaled 256x256 images.

    Args:
        directory: image root to read from. Defaults to
            ``mammo-calc/analyzer_calc_data/train``, built with ``os.path.join`` so it
            resolves on both Windows and POSIX (a backslash literal only works on Windows).

    Returns:
        A Keras directory iterator yielding image batches only (``class_mode=None``).
    """
    if directory is None:
        directory = os.path.join('mammo-calc', 'analyzer_calc_data', 'train')
    return tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255).flow_from_directory(
        directory=directory,
        class_mode=None,
        batch_size=generator_batch_size,
        target_size=(256, 256),
        shuffle=True,
    )

def create_flow(datagen, subset, directory=None):
    """Return a shuffled binary-label flow over one subset of the train directory.

    Args:
        datagen: ImageDataGenerator configured with ``validation_split``.
        subset: ``'training'`` or ``'validation'``.
        directory: image root to read from. Defaults to
            ``mammo-calc/analyzer_calc_data/train``, built with ``os.path.join`` so it
            resolves on both Windows and POSIX (a backslash literal only works on Windows).

    Returns:
        A Keras directory iterator yielding (images, binary labels) batches of
        ``generator_batch_size`` 256x256 images.
    """
    if directory is None:
        directory = os.path.join('mammo-calc', 'analyzer_calc_data', 'train')
    return datagen.flow_from_directory(
        directory=directory,
        batch_size=generator_batch_size,
        target_size=(256, 256),
        shuffle=True,
        subset=subset,
        class_mode='binary'
    )

def create_test_flow(generator, directory):
    """Build an unshuffled binary-label flow over `directory`.

    Order is fixed (``shuffle=False``), so batch i always holds the same files, which
    makes evaluation deterministic.
    """
    flow_options = dict(
        directory=directory,
        target_size=(256, 256),
        batch_size=generator_batch_size,
        class_mode='binary',
        shuffle=False,
    )
    return generator.flow_from_directory(**flow_options)
def create_testgen():
    """Return a non-augmenting ImageDataGenerator that only rescales pixels by 1/255."""
    scale = 1./255
    image_module = tf.keras.preprocessing.image
    return image_module.ImageDataGenerator(rescale=scale)

# Generator objects shared by the flow-construction cell: an augmenting one for the
# train directory and a rescale-only one for test/validation.
data_generator = create_datagen()
test_gen = create_testgen()

In [0]:
# Labeled ('training') and unlabeled ('validation') subsets of the train directory,
# plus unshuffled flows over the held-out test and validation directories.
# os.path.join keeps the paths valid on both Windows and POSIX systems.
train_generator = create_flow(data_generator, 'training')
unlabeled_generator = create_flow(data_generator, 'validation')
test_generator = create_test_flow(test_gen, os.path.join('mammo-calc', 'analyzer_calc_data', 'test'))
validation_generator = create_test_flow(test_gen, os.path.join('mammo-calc', 'analyzer_calc_data', 'validation'))

In [0]:
#https://raw.githubusercontent.com/ntozer/mixmatch-tensorflow2.0/master/model.py
class Residual3x3Unit(tf.keras.layers.Layer):
    """Residual unit: [BN -> LeakyReLU] -> 3x3 conv -> BN -> LeakyReLU -> [dropout] -> 3x3 conv, plus a shortcut.

    The first conv uses `stride`; the second always uses stride 1. When
    `channels_in != channels_out` (`self.downsample`) the shortcut is a strided 1x1 conv;
    otherwise it is the identity.

    NOTE(review): the identity shortcut is not strided, so this assumes stride == 1
    whenever channels_in == channels_out; otherwise the final addition mismatches shapes.
    """
    def __init__(self, channels_in, channels_out, stride, droprate=0., activate_before_residual=False):
        super(Residual3x3Unit, self).__init__()
        # BN moving statistics update slowly (momentum 0.999).
        self.bn_0 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_0 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.conv_0 = tf.keras.layers.Conv2D(channels_out, kernel_size=3, strides=stride, padding='same', use_bias=False)
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_1 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.conv_1 = tf.keras.layers.Conv2D(channels_out, kernel_size=3, strides=1, padding='same', use_bias=False)
        # A channel change selects the 1x1 projection shortcut instead of the identity.
        self.downsample = channels_in != channels_out
        # Created unconditionally, but only called (and therefore only built) when downsample is True.
        self.shortcut = tf.keras.layers.Conv2D(channels_out, kernel_size=1, strides=stride, use_bias=False)
        self.activate_before_residual = activate_before_residual
        self.dropout = tf.keras.layers.Dropout(rate=droprate)
        self.droprate = droprate

    @tf.function
    def call(self, x, training=True):
        # `training` defaults to True: callers must pass training=False for inference,
        # otherwise BatchNormalization uses batch statistics.
        # Input paths:
        #  - downsample and activate_before_residual: x is replaced by BN+ReLU(x); the
        #    activated tensor feeds both conv_0 and the projection shortcut.
        #  - not downsample: BN+ReLU(x) feeds conv_0; the identity shortcut adds raw x.
        #  - downsample without activate_before_residual: raw x feeds both conv_0 and the
        #    shortcut; bn_0/relu_0 are unused.
        if self.downsample and self.activate_before_residual:
            x = self.relu_0(self.bn_0(x, training=training))
        elif not self.downsample:
            out = self.relu_0(self.bn_0(x, training=training))
        out = self.relu_1(self.bn_1(self.conv_0(x if self.downsample else out), training=training))
        if self.droprate > 0.:
            # NOTE(review): `training` is not passed to dropout explicitly; presumably this relies
            # on Keras call-context propagation. TODO confirm dropout is disabled at inference.
            out = self.dropout(out)
        out = self.conv_1(out)
        return out + (self.shortcut(x) if self.downsample else x)


class ResidualBlock(tf.keras.layers.Layer):
    """A stack of `n_units` residual units applied in sequence.

    Only the first unit maps channels_in -> channels_out and uses `stride`; every later
    unit maps channels_out -> channels_out with stride 1. `droprate` and
    `activate_before_residual` are forwarded unchanged to every unit.
    """
    def __init__(self, n_units, channels_in, channels_out, unit, stride, droprate=0., activate_before_residual=False):
        super(ResidualBlock, self).__init__()
        # List of layers; Keras tracks layers held in list attributes, so their weights belong to this block.
        self.units = self._build_unit(n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual)

    def _build_unit(self, n_units, unit, channels_in, channels_out, stride, droprate, activate_before_residual):
        # Instantiate `n_units` layers of class `unit` (e.g. Residual3x3Unit). Arguments are
        # passed positionally, so they must match that class's constructor order.
        units = []
        for i in range(n_units):
            units.append(unit(channels_in if i == 0 else channels_out, channels_out, stride if i == 0 else 1, droprate, activate_before_residual))
        return units

    @tf.function
    def call(self, x, training=True):
        # Chain the units; `training` is forwarded so BatchNormalization mode matches the caller's.
        for unit in self.units:
            x = unit(x, training=training)
        return x


class WideResNet(tf.keras.Model):
    """Wide ResNet producing `num_classes` raw logits per example.

    Layout: 3x3 stem conv (16 channels) -> three ResidualBlocks of (depth - 4) / 6 units
    with 16*width, 32*width and 64*width output channels and strides 1, 2, 2 ->
    BN -> LeakyReLU -> 8x8 average pool (stride 1) -> flatten -> Dense(num_classes).
    The Dense layer has no activation, so outputs are logits.

    NOTE(review): the 8x8 / stride-1 pool acts as global average pooling only when the last
    feature map is 8x8 (32x32 inputs). For the 256x256 inputs this notebook builds with,
    the map is 64x64, so the pool yields 57x57 positions and the Dense layer receives
    57 * 57 * 64 * width features. Changing it (e.g. to GlobalAveragePooling2D) would alter
    the architecture and invalidate saved weights.
    """
    def __init__(self, num_classes, depth=28, width=2, droprate=0., input_shape=(None, 32, 32, 3), **kwargs):
        # NOTE(review): `input_shape` is passed positionally to tf.keras.Model.__init__, which has
        # no such parameter for subclassed models; depending on the TF version it is ignored or
        # rejected. This notebook sets the shape with model.build(...) instead - TODO confirm.
        super(WideResNet, self).__init__(input_shape, **kwargs)
        # depth must have the form 6 * N + 4.
        assert (depth - 4) % 6 == 0
        # N residual units per block.
        N = int((depth - 4) / 6)
        channels = [16, 16 * width, 32 * width, 64 * width]

        self.conv_0 = tf.keras.layers.Conv2D(channels[0], kernel_size=3, strides=1, padding='same', use_bias=False)
        # Only the first block activates before the residual branch.
        self.block_0 = ResidualBlock(N, channels[0], channels[1], Residual3x3Unit, 1, droprate, True)
        self.block_1 = ResidualBlock(N, channels[1], channels[2], Residual3x3Unit, 2, droprate)
        self.block_2 = ResidualBlock(N, channels[2], channels[3], Residual3x3Unit, 2, droprate)
        self.bn_0 = tf.keras.layers.BatchNormalization(momentum=0.999)
        self.relu_0 = tf.keras.layers.LeakyReLU(alpha=0.1)
        self.avg_pool = tf.keras.layers.AveragePooling2D((8, 8), (1, 1))
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(num_classes)

    @tf.function
    def call(self, inputs, training=True):
        # `training` defaults to True: pass training=False for evaluation so that
        # BatchNormalization uses its moving statistics.
        x = inputs
        x = self.conv_0(x)
        x = self.block_0(x, training=training)
        x = self.block_1(x, training=training)
        x = self.block_2(x, training=training)
        x = self.relu_0(self.bn_0(x, training=training))
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

In [0]:
#Supervised benchmark
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

minimum_val_loss = None
BATCH_SIZE = 8
# Labeled images consumed per epoch. The labeled generator never ends, so an "epoch"
# here is a fixed number of optimizer steps rather than a pass over the data.
LABELED_IMAGES_PER_EPOCH = 400
steps_per_epoch = LABELED_IMAGES_PER_EPOCH // BATCH_SIZE
# Per-epoch results (Python floats) kept for plotting
train_loss_results = []
train_accuracy_results = []
model = WideResNet(1, depth=16, width=2)
model.build(input_shape=(None, 256, 256, 3))
num_epochs = 400

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(0.00001)
epoch_loss_avg = tf.keras.metrics.Mean()
epoch_accuracy = tf.keras.metrics.BinaryAccuracy()
avg_val_loss = tf.keras.metrics.Mean()
tmp_loss = tf.keras.metrics.Mean()
val_acc_metric = tf.keras.metrics.BinaryAccuracy()
global_step = 0
for epoch in range(num_epochs):
    epoch_start = tf.timestamp(name=None)

    # Training: exactly `steps_per_epoch` labeled batches per epoch.
    for _ in range(steps_per_epoch):
        x, y = next(train_generator)
        with tf.GradientTape() as tape:
            logits = model(x)
            loss_value = tf.keras.backend.binary_crossentropy(
                target=y, output=tf.squeeze(logits, axis=1), from_logits=True)
            loss_value = tf.reduce_mean(loss_value)

        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Track progress
        epoch_loss_avg(loss_value)
        epoch_accuracy(y, tf.nn.sigmoid(logits))
        global_step += 1
        tmp_loss(loss_value)

    # Validation: one deterministic full pass. Indexing the shuffle=False iterator always
    # starts at batch 0; iterating it with `for` would resume where the previous epoch
    # stopped, so every epoch would be scored on a different subset.
    for batch_index in range(len(validation_generator)):
        x_batch_val, y_batch_val = validation_generator[batch_index]
        val_logits = model(x_batch_val, training=False)
        val_acc_metric(y_batch_val, tf.nn.sigmoid(val_logits))
        val_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=y_batch_val, logits=tf.squeeze(val_logits, axis=1))
        # Per-example values, so a short final batch is weighted correctly.
        avg_val_loss(val_loss)
    epoch_end = tf.timestamp(name=None)
    print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}, Global steps: {:d}".format(
        epoch, float(epoch_loss_avg.result()), float(epoch_accuracy.result()), global_step))
    print(f'Validation loss: {float(avg_val_loss.result()):.4f}, '
          f'Validation Accuracy: {float(val_acc_metric.result()):.2%}, '
          f'Time: {int(epoch_end.numpy() - epoch_start.numpy())}s')

    # Record epoch results *before* resetting; after reset_states() the metrics read 0.
    train_loss_results.append(float(epoch_loss_avg.result()))
    train_accuracy_results.append(float(epoch_accuracy.result()))

    epoch_loss_avg.reset_states()
    epoch_accuracy.reset_states()
    avg_val_loss.reset_states()
    val_acc_metric.reset_states()
    tmp_loss.reset_states()



In [0]:
generator_batch_size = 8

def create_datagen(validation_split=0.9535):
    """Return the augmenting ImageDataGenerator used for the train directory.

    Pixels are rescaled to [0, 1] and images are randomly flipped horizontally and
    vertically. ``validation_split`` partitions the train directory into two subsets:
    per the Keras docs the ``'validation'`` subset receives this fraction of the files
    and the ``'training'`` subset the remainder. With the default of 0.9535 only about
    4.65% of the images form the labeled ``'training'`` subset; the notebook uses the
    ``'validation'`` subset as the unlabeled pool for MixMatch.

    Args:
        validation_split: fraction of images assigned to the ``'validation'`` subset.
    """
    return tf.keras.preprocessing.image.ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        validation_split=validation_split,
        rescale=1./255)

def create_unlabeled_gen(directory=None):
    """Return a shuffled, label-free flow of rescaled 256x256 images.

    Args:
        directory: image root to read from. Defaults to
            ``mammo-calc/analyzer_calc_data/train``, built with ``os.path.join`` so it
            resolves on both Windows and POSIX (a backslash literal only works on Windows).

    Returns:
        A Keras directory iterator yielding image batches only (``class_mode=None``).
    """
    if directory is None:
        directory = os.path.join('mammo-calc', 'analyzer_calc_data', 'train')
    return tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255).flow_from_directory(
        directory=directory,
        class_mode=None,
        batch_size=generator_batch_size,
        target_size=(256, 256),
        shuffle=True,
    )

def create_flow(datagen, subset, directory=None):
    """Return a shuffled binary-label flow over one subset of the train directory.

    Args:
        datagen: ImageDataGenerator configured with ``validation_split``.
        subset: ``'training'`` or ``'validation'``.
        directory: image root to read from. Defaults to
            ``mammo-calc/analyzer_calc_data/train``, built with ``os.path.join`` so it
            resolves on both Windows and POSIX (a backslash literal only works on Windows).

    Returns:
        A Keras directory iterator yielding (images, binary labels) batches of
        ``generator_batch_size`` 256x256 images.
    """
    if directory is None:
        directory = os.path.join('mammo-calc', 'analyzer_calc_data', 'train')
    return datagen.flow_from_directory(
        directory=directory,
        batch_size=generator_batch_size,
        target_size=(256, 256),
        shuffle=True,
        subset=subset,
        class_mode='binary'
    )

def create_test_flow(generator, directory):
    """Build an unshuffled binary-label flow over `directory`.

    Order is fixed (``shuffle=False``), so batch i always holds the same files, which
    makes evaluation deterministic.
    """
    flow_options = dict(
        directory=directory,
        target_size=(256, 256),
        batch_size=generator_batch_size,
        class_mode='binary',
        shuffle=False,
    )
    return generator.flow_from_directory(**flow_options)
def create_testgen():
    """Return a non-augmenting ImageDataGenerator that only rescales pixels by 1/255."""
    scale = 1./255
    image_module = tf.keras.preprocessing.image
    return image_module.ImageDataGenerator(rescale=scale)

# Fresh generator objects so the MixMatch section can be run without the cells above it.
data_generator = create_datagen()
test_gen = create_testgen()

In [0]:
# Re-create the flows so MixMatch training starts from fresh iterators.
# os.path.join keeps the paths valid on both Windows and POSIX systems.
train_generator = create_flow(data_generator, 'training')
unlabeled_generator = create_flow(data_generator, 'validation')
test_generator = create_test_flow(test_gen, os.path.join('mammo-calc', 'analyzer_calc_data', 'test'))
validation_generator = create_test_flow(test_gen, os.path.join('mammo-calc', 'analyzer_calc_data', 'validation'))

In [0]:
# Batch size used by the MixMatch helpers and training loop below; mirrors the image
# generators' batch size.
BATCH_SIZE = generator_batch_size

# https://github.com/google-research/mixmatch/blob/master/libml/layers.py
def interleave_offsets(batch, nu):
    """Return the nu + 2 boundaries that cut `batch` rows into nu + 1 near-equal groups.

    Each group gets ``batch // (nu + 1)`` rows; any remainder is handed out one row at a
    time to the *last* groups. The list starts at 0 and ends at ``batch``.

    Example: interleave_offsets(10, 2) == [0, 3, 6, 10].
    """
    n_groups = nu + 1
    base, remainder = divmod(batch, n_groups)
    sizes = [base + 1 if index >= n_groups - remainder else base for index in range(n_groups)]
    offsets = [0]
    running_total = 0
    for size in sizes:
        running_total += size
        offsets.append(running_total)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Swap row-groups between tensors so each output tensor mixes rows from all inputs.

    Every tensor in `xy` is cut at `interleave_offsets(batch, len(xy) - 1)`. Group k of
    tensor 0 is exchanged with group k of tensor k (k >= 1). Applying the function twice
    restores the original layout.
    """
    n_groups = len(xy)
    bounds = interleave_offsets(batch, n_groups - 1)
    pieces = []
    for tensor in xy:
        pieces.append([tensor[bounds[k]:bounds[k + 1]] for k in range(n_groups)])
    for k in range(1, n_groups):
        pieces[0][k], pieces[k][k] = pieces[k][k], pieces[0][k]
    return [tf.concat(row, axis=0) for row in pieces]


@tf.function
def sharpen(preds, temperature=0.5):
    """MixMatch temperature sharpening for multi-class probabilities.

    Raises each probability to 1 / temperature and renormalises along axis 1.

    NOTE: for a single-column input such as the (batch, 1) binary predictions in this
    notebook, the renormalisation makes every output 1 (or NaN for a zero input). The
    MixMatch code here calls `binary_sharpen` instead.
    """
    preds_target = tf.math.pow(preds, 1. / temperature)
    preds_target /= tf.math.reduce_sum(preds_target, axis=1, keepdims=True)
    return preds_target


@tf.function
def augment_labeled(images):
    """Randomly flip, contrast-jitter and brighten a batch, then clip back into [0, 1].

    NOTE(review): per the tf.image docs, random_contrast and random_brightness draw a
    single random factor/delta shared by the whole tensor, not one per image - verify
    that batch-wide jitter is intended.
    """
    pipeline = (
        tf.image.random_flip_left_right,
        tf.image.random_flip_up_down,
        lambda batch: tf.image.random_contrast(batch, 0.8, 1.0),
        lambda batch: tf.image.random_brightness(batch, max_delta=0.2),
    )
    augmented = images
    for transform in pipeline:
        augmented = transform(augmented)
    return tf.clip_by_value(augmented, 0.0, 1.0)


# https://github.com/google-research/mixmatch/blob/master/mixmatch.py
def mix_up(x_first, y_first, x_second, y_second, alpha=0.4):
    """MixUp two equally sized batches with per-example Beta(alpha, alpha) weights.

    Each weight is replaced by max(w, 1 - w), so the result stays closer to the *first*
    batch. Images use the (n, 1, 1, 1) weights directly; labels use the same weights
    sliced to (n, 1).

    Returns:
        (x_mix, y_mix)
    """
    assert x_first.shape[0] == x_second.shape[0], "Array sizes differ."

    n_examples = tf.shape(x_first)[0]
    lam = tf.compat.v1.distributions.Beta(alpha, alpha).sample([n_examples, 1, 1, 1])
    lam = tf.maximum(lam, 1 - lam)
    lam_labels = lam[:, :, 0, 0]

    x_mix = lam * x_first + (1 - lam) * x_second
    y_mix = lam_labels * y_first + (1 - lam_labels) * y_second
    return x_mix, y_mix

@tf.function
def binary_sharpen(preds, temperature=0.5):
    """Push independent binary probabilities toward 0 or 1.

    Computes (tanh((2 / T) * (p - 0.5)) + 1) / 2, which equals sigmoid((4 / T) * (p - 0.5)).
    p = 0.5 is a fixed point, and lower temperatures push other values harder toward
    0 or 1. Unlike `sharpen`, which renormalises over axis 1 and so returns all ones
    for single-column input, this treats each value on its own.

    Args:
        preds: probabilities in [0, 1], one value per example (e.g. shape (batch, 1)).
        temperature: sharpening temperature T > 0.

    Returns:
        Tensor of shape (num_examples, 1).
    """
    # Reshape with -1 instead of a fixed BATCH_SIZE so short batches work and the result
    # does not depend on a module-level constant captured at trace time.
    flat_preds = tf.reshape(preds, [-1])
    sharpened = (tf.nn.tanh((2 / temperature) * (flat_preds - 0.5)) + 1) / 2
    return tf.reshape(sharpened, [-1, 1])

def weight_decay(model, decay_rate):
    """Multiply every trainable variable of `model` in place by (1 - decay_rate).

    This touches every trainable variable (e.g. BatchNormalization scale/offset and
    Dense bias), not only conv/dense kernels.
    NOTE(review): the reference MixMatch implementation may restrict decay to kernels -
    confirm whether decaying the other parameters is intended.
    """
    shrink = 1 - decay_rate
    for variable in model.trainable_variables:
        variable.assign(variable * shrink)


@tf.function
def shuffle_tensors_together(tensor1, tensor2):
    """Apply one shared random permutation along axis 0 to both tensors.

    Row i of each result comes from the same source row, so example/label pairing is kept.
    """
    permutation = tf.random.shuffle(tf.range(tf.shape(tensor1)[0], dtype=tf.int32))
    return tf.gather(tensor1, permutation), tf.gather(tensor2, permutation)


@tf.function
def random_crop(images, pad=16, crop_size=(256, 256, 3)):
    """Reflect-pad each image, then take an independent random crop per image.

    Args:
        images: 4-D batch (batch, height, width, channels).
        pad: pixels reflected onto each spatial border (defaults to 16, as before).
        crop_size: (height, width, channels) of each crop; defaults to the 256x256 RGB
            inputs used in this notebook.

    Returns:
        Tensor of shape (batch, *crop_size).
    """
    padded_images = tf.pad(images, paddings=[(0, 0), (pad, pad), (pad, pad), (0, 0)], mode='REFLECT')
    # map_fn gives each image its own random crop offset.
    return tf.map_fn(lambda image: tf.image.random_crop(image, size=crop_size), padded_images)


@tf.function
def get_losses(labeled_logits, labeled_targets, unlabeled_logits, unlabeled_targets):
    """Return (labeled_loss, unlabeled_loss) for a MixMatch step.

    labeled_loss: mean sigmoid cross-entropy between labeled logits and their (mixed) targets.
    unlabeled_loss: mean squared error between sigmoid(unlabeled_logits) and the
        guessed/mixed targets.
    """
    labeled_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=labeled_targets, logits=labeled_logits))
    unlabeled_probs = tf.nn.sigmoid(unlabeled_logits)
    unlabeled_loss = tf.reduce_mean(tf.square(unlabeled_targets - unlabeled_probs))
    return labeled_loss, unlabeled_loss


def mixmatch(model, x_augmented_labeled_batch, y_labeled_batch_train, x_unlabeled_batch, unlabeled_augments,
             batch_size=BATCH_SIZE):
    """Build one MixMatch training batch: label guessing, sharpening, MixUp and interleaving.

    Args:
        model: classifier returning one logit per example.
        x_augmented_labeled_batch: augmented labeled images.
        y_labeled_batch_train: labeled targets, shape (batch, 1).
        x_unlabeled_batch: unlabeled images, same batch size as the labeled batch.
        unlabeled_augments: non-empty sequence of names from 'udflip', 'lrflip', 'brightness';
            one augmented copy of the unlabeled batch is made per entry.
        batch_size: rows per batch, used to compute the interleave offsets.

    Returns:
        (x_combined_mix_interleaved, y_combined_mix): a list of len(unlabeled_augments) + 1
        interleaved image tensors to feed to the model one by one, and the mixed targets of
        the un-interleaved combined batch (first `batch_size` rows labeled, the rest unlabeled).

    Raises:
        ValueError: if `unlabeled_augments` is empty or contains an unknown name.
    """
    def _augment(name):
        # Note: 'udflip' is a random per-call flip, while 'lrflip' always flips.
        if name == 'udflip':
            return tf.image.random_flip_up_down(x_unlabeled_batch)
        if name == 'lrflip':
            return tf.image.flip_left_right(x_unlabeled_batch)
        if name == 'brightness':
            return tf.clip_by_value(tf.image.random_brightness(x_unlabeled_batch, max_delta=0.2), 0.0, 1.0)
        raise ValueError(f'Unknown unlabeled augmentation: {name!r}')

    if not unlabeled_augments:
        raise ValueError('unlabeled_augments must contain at least one augmentation.')

    augmented_batches = []
    sharpened_guesses = []
    for aug in unlabeled_augments:
        augmented = _augment(aug)
        augmented_batches.append(augmented)
        sharpened_guesses.append(binary_sharpen(tf.nn.sigmoid(model(augmented))))

    # Average the sharpened guesses across augmentations; guessed labels are fixed targets.
    guessed_labels = tf.add_n(sharpened_guesses) / len(unlabeled_augments)
    guessed_labels = tf.stop_gradient(guessed_labels)

    x_unlabeled_all_augmentations = tf.concat(augmented_batches, axis=0)
    y_unlabeled_all_augmentations = tf.concat([guessed_labels] * len(augmented_batches), axis=0)

    x_combined_batch = tf.concat([x_augmented_labeled_batch, x_unlabeled_all_augmentations], axis=0)
    y_combined_batch = tf.concat([y_labeled_batch_train, y_unlabeled_all_augmentations], axis=0)

    # MixUp each example with a random partner drawn from the whole combined batch.
    shuffled_x, shuffled_y = shuffle_tensors_together(x_combined_batch, y_combined_batch)
    x_combined_mix, y_combined_mix = mix_up(x_combined_batch, y_combined_batch, shuffled_x, shuffled_y, alpha=0.75)
    x_combined_mix = tf.split(x_combined_mix, len(unlabeled_augments) + 1, axis=0)
    x_combined_mix_interleaved = interleave(x_combined_mix, batch_size)

    return x_combined_mix_interleaved, y_combined_mix


In [0]:
#Train with MixMatch
print(tf.__version__)
max_val_acc = None  # best validation accuracy seen so far (drives checkpointing)
batch_size = BATCH_SIZE
model = WideResNet(1, depth=16, width=2)
model.build(input_shape=(None, 256, 256, 3))

ramp_up_length = 10000      # steps over which the unlabeled-loss weight ramps linearly to its max
w_decay = 0.00001           # per-step multiplicative weight decay
unlabeled_augments = ['udflip', 'lrflip', 'brightness']
unlabeled_weight = 30       # maximum weight of the unlabeled (L2) loss term
num_epochs = 20
# Unlabeled images consumed per epoch. The generators never end, so an "epoch" here is
# a fixed number of steps rather than a pass over the data.
UNLABELED_IMAGES_PER_EPOCH = 2046
steps_per_epoch = UNLABELED_IMAGES_PER_EPOCH // BATCH_SIZE

# `learning_rate` is the supported keyword; `lr` is deprecated.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

# Per-epoch metrics (all reset at the end of every epoch).
train_acc_metric = tf.keras.metrics.BinaryAccuracy()
val_acc_metric = tf.keras.metrics.BinaryAccuracy()

avg_train_loss = tf.keras.metrics.Mean()
avg_labeled_loss = tf.keras.metrics.Mean()
avg_unlabeled_loss = tf.keras.metrics.Mean()
avg_val_loss = tf.keras.metrics.Mean()
epoch_metrics = [train_acc_metric, val_acc_metric, avg_train_loss,
                 avg_labeled_loss, avg_unlabeled_loss, avg_val_loss]

labeled_iter = iter(train_generator)

global_step = 0
ramp_up = 0.0

for epoch in range(num_epochs):
    epoch_start = tf.timestamp(name=None)

    for _ in range(steps_per_epoch):
        x_unlabeled_batch, _unused_labels = next(unlabeled_generator)
        # The last batch of a pass can be short; tf.split/interleave need full batches.
        if x_unlabeled_batch.shape[0] != BATCH_SIZE:
            x_unlabeled_batch, _unused_labels = next(unlabeled_generator)

        x_labeled_batch_train, y_labeled_batch_train = next(labeled_iter)
        if x_labeled_batch_train.shape[0] != BATCH_SIZE:
            x_labeled_batch_train, y_labeled_batch_train = next(labeled_iter)
        y_labeled_batch_train = np.expand_dims(y_labeled_batch_train, axis=1)

        x_augmented_labeled_batch = augment_labeled(x_labeled_batch_train)

        x_combined_mix_interleaved, y_combined_mix = mixmatch(model, x_augmented_labeled_batch,
                                                              y_labeled_batch_train,
                                                              x_unlabeled_batch,
                                                              unlabeled_augments,
                                                              BATCH_SIZE)

        with tf.GradientTape() as tape:
            # Run each interleaved chunk separately, then interleave again to restore the
            # original row order so logits line up with y_combined_mix.
            train_logits_interleaved = [model(chunk) for chunk in x_combined_mix_interleaved]
            train_logits = interleave(train_logits_interleaved, batch_size)
            labeled_logits = train_logits[0]
            unlabeled_logits = tf.concat(train_logits[1:], axis=0)

            labeled_loss, unlabeled_loss = get_losses(labeled_logits, y_combined_mix[:batch_size],
                                                      unlabeled_logits,
                                                      y_combined_mix[batch_size:])

            ramp_up = min(1.0, global_step / ramp_up_length)
            combined_loss = labeled_loss + ramp_up * unlabeled_weight * unlabeled_loss

        grads = tape.gradient(combined_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        global_step += 1
        weight_decay(model=model, decay_rate=w_decay)

        avg_labeled_loss(labeled_loss)
        avg_unlabeled_loss(unlabeled_loss)
        avg_train_loss(combined_loss)
        # Inference mode, so this extra forward pass does not update BatchNormalization statistics.
        train_acc_metric(y_labeled_batch_train,
                         tf.nn.sigmoid(model(x_labeled_batch_train, training=False)))
        del x_combined_mix_interleaved, y_combined_mix

    # Validation: one deterministic full pass. Indexing the shuffle=False iterator always
    # starts at batch 0; iterating it with `for` would resume where the previous epoch
    # stopped, so every epoch would be scored on a different subset.
    for batch_index in range(len(validation_generator)):
        x_batch_val, y_batch_val = validation_generator[batch_index]
        val_logits = model(x_batch_val, training=False)
        val_acc_metric(y_batch_val, tf.nn.sigmoid(val_logits))
        val_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=y_batch_val, logits=tf.squeeze(val_logits, axis=1))
        # Per-example values, so a short final batch is weighted correctly.
        avg_val_loss(val_loss)
    epoch_end = tf.timestamp(name=None)

    val_acc = float(val_acc_metric.result())
    print(f'Epoch: {epoch:4d}, Train loss: {float(avg_train_loss.result()):.4f}, '
          f'Train Acc: {float(train_acc_metric.result()):.3%},'
          f' Val loss: {float(avg_val_loss.result()):.4f}, Val Acc: {val_acc:.3%},'
          f' Lambda: {float(ramp_up * unlabeled_weight):.4f}, Total steps: {global_step}')
    print(f'Labeled loss: {float(avg_labeled_loss.result()):.4f}, '
          f'Unlabeled loss: {float(avg_unlabeled_loss.result()):.4f}, '
          f'Time: {int(epoch_end.numpy() - epoch_start.numpy())}s')

    # Checkpoint whenever this epoch's validation accuracy is the best so far.
    if max_val_acc is None or val_acc > max_val_acc:
        max_val_acc = val_acc
        model.save_weights(f'model_{epoch}_{max_val_acc:.4f}.h5')

    for metric in epoch_metrics:
        metric.reset_states()
