## Setup

In [None]:
import sys

assert sys.version_info >= (3, 7)

In [None]:
%pip install transformers "datasets>=1.17.0"

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━[0m [32m6.2/7.2 MB[0m [31m188.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m104.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=1.17.0
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.2-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.5/268.5 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-

In [None]:
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from tensorflow.keras import layers

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten,Resizing
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import image
import numpy as np
import pandas as pd
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout
from tensorflow.keras.models import  Model
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import cv2
from tensorflow.keras.preprocessing import image

In [None]:
from transformers import AutoImageProcessor, TFViTModel


In [None]:
# from tensorflow_addons.metrics import MultiLabelConfusionMatrix
from pathlib import Path
from time import strftime

In [None]:

def get_run_logdir(root_logdir="my_logs"):
 return Path(root_logdir) / strftime("run_%Y_%m_%d_%H_%M_%S")
run_logdir = get_run_logdir() # e.g., my_logs/run_2022_08_01_17_25_59

In [None]:
import matplotlib.pyplot as plt

plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
from PIL import Image

In [None]:
# Is this notebook running on Colab or Kaggle?
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules

if not tf.config.list_physical_devices('GPU'):
    print("No GPU was detected. Neural nets can be very slow without a GPU.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime and select a GPU hardware "
              "accelerator.")
    if IS_KAGGLE:
        print("Go to Settings > Accelerator and select GPU.")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
folder_url="/content/drive/MyDrive/multilabels_data/"
train_csv_file = folder_url+"train_labels.csv"
train_image_path =folder_url+"trainset"

train = pd.read_csv(train_csv_file)
train.head()

Unnamed: 0,name,HG,HT,TR,CTH,BD,VH,CTQ,DQT,KS,CVN
0,1.jpg,0,1,0,0,0,0,0,0,0,1
1,2.jpg,0,1,0,0,0,0,0,0,0,0
2,3.jpg,0,1,0,0,0,0,1,0,0,0
3,4.jpg,1,0,0,0,0,1,0,0,0,0
4,5.jpg,0,1,0,0,0,1,0,0,0,0


In [None]:
test_csv_file = folder_url+"test_labels.csv"
test_image_path =folder_url+"testset"

test = pd.read_csv(test_csv_file)
test.head()

Unnamed: 0,name,HG,HT,TR,CTH,BD,VH,CTQ,DQT,KS,CVN
0,1.jpg,0,1,0,0,0,0,1,0,0,0
1,2.jpg,0,1,0,0,0,0,1,0,0,0
2,3.jpg,0,1,0,0,0,0,1,0,0,0
3,4.jpg,0,1,0,0,0,0,1,0,0,0
4,5.jpg,0,1,0,0,0,0,0,1,0,0


In [None]:

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, dataset, batch_size, size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.size = size
        self.indices = np.arange(self.dataset.size) # for shuffle

    def __getitem__(self, i):
        # collect batch data
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size
        data = []
        for j in range(start, stop):
            inds=self.indices[j]
            data.append(self.dataset[inds])

        batch = [np.stack(samples, axis=0) for samples in zip(*data)]
        return tuple(batch)
    def on_epoch_end(self):
        np.random.shuffle(self.indices)

    def __len__(self):
        return self.size // self.batch_size

class Dataset:
    def __init__(self,image_path, data, label, w, h,size):
        # the paths of images
        self.data = data
        # the paths of segmentation images
        self.label = label
        self.image_path=image_path
        self.w = w
        self.h = h
        self.size=size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # read data
        img = image.load_img(self.image_path + '/' + str(self.data[i]),target_size=(IMAGE_W,IMAGE_H,3))
        img = image.img_to_array(img)
        img = img /255
        label = self.label[i]
        return img, label

In [None]:
y_train = np.array(train.drop(columns=["name"]))
X_train = np.array(train["name"])

y_val = np.array(test.drop(columns=["name"]))
X_val = np.array(test["name"])

In [None]:
# X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, random_state=42, test_size=0.5)

In [None]:
IMAGE_W = 224
IMAGE_H = 224

In [None]:
# Xay dung dataset va Dataloader
# Build dataaset
DOUBLE_IMAGE_W,DOUBLE_IMAGE_H=IMAGE_W*2,IMAGE_H*2 #bigger image before augmentation

train_dataset = Dataset(train_image_path,X_train, y_train,DOUBLE_IMAGE_W,DOUBLE_IMAGE_H,len(X_train))
val_dataset = Dataset(test_image_path,X_val, y_val,DOUBLE_IMAGE_W,DOUBLE_IMAGE_H,len(X_val))
# test_dataset = Dataset(test_image_path,X_test, y_test,DOUBLE_IMAGE_W,DOUBLE_IMAGE_H,len(X_test))

# Generator

train_generator = DataGenerator(train_dataset, 16, len(train_dataset))
val_generator = DataGenerator(val_dataset, 16, len(val_dataset))
# test_generator = DataGenerator(test_dataset, 16, len(test_dataset))

In [None]:

def get_zero_shot_stats(data_generator):
  m = tf.keras.metrics.BinaryAccuracy()
  total_binary_accuracy,total_val_binary_accuracy,total_instance=0,0,0

  for step, (x_batch_train, y_batch_train) in enumerate(data_generator):
      total_instance+=len(y_batch_train)
      for instance in y_batch_train:
          m.update_state(instance, [0]*10) #case full of zero
          total_binary_accuracy+=m.result().numpy()
  return total_binary_accuracy/total_instance

val_zero_shot=get_zero_shot_stats(val_generator)
val_zero_shot

0.813723093830049

In [None]:

def get_model():
    # model_base_conv = ResNet50(weights='imagenet', include_top=False)
    model_base_conv = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    # Dong bang cac layer
    # for layer in model_base_conv.layers:
    #     layer.trainable = False
    data_augmentation = tf.keras.Sequential([
          tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=42),
          tf.keras.layers.RandomRotation(factor=0.05, seed=42),
          tf.keras.layers.RandomContrast(factor=0.2, seed=42),
    ])
    # Tao model
    input = Input(shape=(None, None, 3), name='image_input')
    augmented_input=data_augmentation(input)
    resized_image=Resizing(height=IMAGE_H, width=IMAGE_W, crop_to_aspect_ratio=True)(augmented_input)
    permuted_image=tf.keras.layers.Permute((3,1,2))(resized_image)
    output_base_conv = model_base_conv.vit(permuted_image)[0][:,0,:]

    # Them cac layer FC va Dropout
    # x = GlobalAveragePooling2D(name='avg')(output_base_conv)
    x = Dense(2048, activation='relu', name='fc1')(output_base_conv)
    x = Dropout(0.2)(x)
    x = Dense(1024, activation='relu', name='fc2')(x)
    x = Dropout(0.2)(x)
    x = Dense(10, activation='sigmoid', name='predictions')(x)

    # Compile
    my_model = Model(inputs=input, outputs=x)
    # my_model.compile(loss='binary_crossentropy', optimizer='adam',
    #                  metrics=['Accuracy',"BinaryAccuracy","Precision","Recall",
    #                           "TruePositives","TrueNegatives","FalsePositives","FalseNegatives"])

    return my_model

model = get_model()

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Downloading tf_model.h5:   0%|          | 0.00/346M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFViTModel.

All the layers of TFViTModel were initialized from the model checkpoint at google/vit-base-patch16-224-in21k.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTModel for predictions without further training.


In [None]:
del model

In [None]:
model.save_weights(folder_url+'my_checkpoints/vit_original_check.h5')

In [None]:
model.load_weights(folder_url+'my_checkpoints/best_weight_check.h5')

In [None]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 image_input (InputLayer)    [(None, None, None, 3)]   0         
                                                                 
 sequential (Sequential)     (None, None, None, 3)     0         
                                                                 
 resizing (Resizing)         (None, 224, 224, 3)       0         
                                                                 
 permute (Permute)           (None, 3, 224, 224)       0         
                                                                 
 vit (TFViTMainLayer)        TFBaseModelOutputWithPoo  86389248  
                             ling(last_hidden_state=(            
                             None, 197, 768),                    
                              pooler_output=(None, 76            
                             8),                             

In [None]:
def create_accuracy(threshold=0.5):
    def accuracy_fn(y_true, y_pred):
        m=tf.keras.metrics.Accuracy()
        f=lambda x: np.where(x < 0.5, 0, 1)
        y_pred_binary=f(y_pred)
        return np.float32(np.all(y_true==y_pred_binary,axis=1))
    return accuracy_fn
class AccuracyMetric(tf.keras.metrics.Metric):
    def __init__(self, threshold=0.5, **kwargs):
        super().__init__(**kwargs) # handles base args (e.g., dtype)
        self.threshold = threshold
        self.accuracy_fn = create_accuracy(threshold)
        self.total = self.add_weight("total", initializer="zeros")
        self.count = self.add_weight("count", initializer="zeros")
    def update_state(self, y_true, y_pred, sample_weight=None):
        sample_metrics = self.accuracy_fn(y_true, y_pred)
        self.total.assign_add(tf.reduce_sum(sample_metrics))
        self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))
        # print("update",tf.shape(y_true)[0],tf.shape(y_pred)[0],tf.reduce_sum(sample_metrics))
    def result(self):
        return self.total / self.count
    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "threshold": self.threshold}

In [None]:
def print_status_bar(step, total, loss, metrics=None):
    step=step+1
    metrics = " - ".join([f"{m.name}: {m.result():.4f}" for m in [loss] + (metrics or [])])
    end = "" if step < total else "\n"
    print(f"\r{step}/{total} - " + metrics)

def print_metrics(epoch,total,loss,val_loss,metrics=None,val_metrics=None):
    metrics = " - ".join([f"{m.name}: {m.result():.4f}" for m in [loss] + (metrics or [])])
    val_metrics = " - val_".join([f"{m.name}: {m.result():.4f}" for m in [val_loss] + (val_metrics or [])])

    end = "" if epoch < total else "\n"
    print(f"\r{epoch}/{total} - " + metrics+" * "+val_metrics)

In [None]:
n_epochs = 20
batch_size = 16
train_steps = len(train_dataset) // batch_size
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
mean_loss = tf.keras.metrics.Mean(name="mean_loss")
val_mean_loss=tf.keras.metrics.Mean(name="mean_loss")
metrics = [tf.keras.metrics.BinaryCrossentropy(),tf.keras.metrics.BinaryAccuracy(),AccuracyMetric(0.5)]
val_metrics = [tf.keras.metrics.BinaryCrossentropy(),tf.keras.metrics.BinaryAccuracy(),AccuracyMetric(0.5)]

In [None]:
val_steps = len(val_dataset) // batch_size


In [None]:
a=np.array([[1, 0, 0],[1,1,1]])
b=np.array([[0.4, 0.1, 0.6],[0.9,0.8,0.9]])
loss_fn(a,b)

<tf.Tensor: shape=(), dtype=float64, numpy=0.39530093395045907>

In [None]:
import tensorflow as tf

class WeightedCrossEntropy(tf.keras.losses.Loss):
    def __init__(self, weight, epsilon=1e-7, name="weighted_cross_entropy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.weight = weight
        self.epsilon = epsilon
    @tf.function
    def call(self, y_true, y_pred):
        # losses=tf.constant([])
        y_pred = tf.clip_by_value(y_pred, self.epsilon, 1 - self.epsilon)
        return tf.map_fn(lambda t: -tf.reduce_mean(self.weight * (t[0] * tf.math.log(t[1]) + (1 - t[0]) * tf.math.log(1 - t[1]))),(tf.cast(y_true,tf.float32),tf.cast(y_pred,tf.float32)), dtype = tf.float32)
        # for y_true_element,y_pred_element in zip(y_true,y_pred):
        #     # y_pred_element = tf.clip_by_value(y_pred_element, self.epsilon, 1 - self.epsilon)
        #     loss = -tf.reduce_mean(self.weight * y_true_element * tf.math.log(y_pred_element) + (1 - y_true_element) * tf.math.log(1 - y_pred_element))
        #     losses=tf.concat([x, [y]], axis=0)
        # return losses

    def get_config(self):
        config = {
            'weight': self.weight,
            'epsilon': self.epsilon
        }
        base_config = super().get_config()
        return {**base_config, **config}

# wcbe=WeightedCrossEntropy(1)
# def cross_entropy(targets,predictions, epsilon=1e-12):
#     """
#     Computes cross entropy between targets (encoded as one-hot vectors)
#     and predictions.
#     Input: predictions (N, k) ndarray
#            targets (N, k) ndarray
#     Returns: scalar
#     """
#     predictions = np.clip(predictions, epsilon, 1. - epsilon)
#     N = predictions.shape[0]
#     ce = -np.sum(targets*np.log(predictions+1e-9))/N
#     return ce

# def class_weighted_cross_entropy_loss(y_true, y_pred):
#     class_weights=[1,1,1,1,1,1,1,1,1,1]
#     class_weights_tensor = tf.constant(class_weights, dtype=tf.float64)
#     losses=[]
#     for y_true_element,y_pred_element in zip(y_true,y_pred):
#         y_pred_element = tf.clip_by_value(y_pred_element, 1e-7, 1.0 - 1e-7)
#         loss = -tf.reduce_mean(class_weights_tensor * y_true_element * tf.math.log(y_pred_element)+class_weights_tensor*(1-y_true_element)*tf.math.log(1-y_pred_element), axis=-1)
#         losses=np.append(losses,loss)
#     print(losses)
#     return tf.reduce_mean(losses)

# a=tf.constant([[1, 0, 1,0,1,0,0,0,0,0],[0,1,0,0,0,0,0,0,0,1],[1,0,0,0,0,0,0,0,0,0]], tf.float32)
# b=tf.constant([[0.4, 0.1,0.2,0.5,0.7,0.8,0.3,0.9,0.2, 0.6],[0.9,0.5,0.4,0.3,0.3,0.4,0.9,0.2,0.8,0.9],[0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1]], tf.float32)

# print("wbce",wcbe(a,b))
# print("loss",loss_fn(a,b))

# # print("ce",class_weighted_cross_entropy_loss(a,b))

# losses=[]
# for i in (0,1,2):
#   loss =tf.math.reduce_mean( loss_fn(a[i],b[i]))
#   losses=np.append(losses,loss)
# print("cbe",losses,tf.math.reduce_mean(losses))
# print("natural",loss_fn(a,b))
# loss_fn=WeightedCrossEntropy(1)

In [None]:
x = tf.constant([1.0, 2.0], tf.float32)
y = tf.constant(9.0, tf.float32)
tf.concat([x, [y]], axis=0)
x


<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>

In [None]:
weight=[1,1,1,1,1,1,1,1,1,1]

In [None]:
# one_weights = [item * 3 for item in norm_weight]
# zero_weights = [item * 1 for item in norm_weight]


In [None]:


loss_fn=WeightedCrossEntropy(weight)



In [None]:
# loss_fn=WeightedCrossEntropy(1)
model.layers[4]

<transformers.models.vit.modeling_tf_vit.TFViTMainLayer at 0x7fc2c4d7ece0>

In [None]:
patience=10
wait=0
best=100000
for epoch in range(1, n_epochs + 1):
    print("Epoch {}/{}".format(epoch, n_epochs))
    for step in range(0, train_steps):
        X_batch, y_batch = train_generator.__getitem__(step)
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)
            # print(X_batch,y_batch,y_pred)

            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        mean_loss(loss)
        for metric in metrics:
            metric(y_batch, y_pred)
        print_status_bar(step, train_steps, mean_loss, metrics)

    #val
    for step in range(0, val_steps):
        X_val_batch, y_val_batch = val_generator.__getitem__(step)
        y_val_pred = model(X_val_batch)
        val_main_loss = tf.reduce_mean(loss_fn(y_val_batch, y_val_pred))
        val_loss = tf.add_n([val_main_loss]  + model.losses)
        val_mean_loss(val_loss)
        print("val_mean_loss",val_mean_loss.result(),val_loss)
        for metric in val_metrics:
            metric(y_val_batch, y_val_pred)


    print_metrics(epoch,n_epochs,mean_loss,val_mean_loss,metrics,val_metrics)

    wait += 1
    if val_main_loss < best:
      best = val_main_loss
      wait = 0
      model.save_weights(folder_url+'my_checkpoints/best_triple_one_weight_check.h5')
      print("save")
    if wait >= patience:
      break

    for metric in [mean_loss] + metrics + val_metrics:
        metric.reset_states()


Epoch 1/20




1/15 - mean_loss: 0.7108 - binary_crossentropy: 0.7108 - binary_accuracy: 0.4000 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.6101 - binary_crossentropy: 0.6101 - binary_accuracy: 0.6281 - accuracy_metric_7: 0.0312




3/15 - mean_loss: 0.5437 - binary_crossentropy: 0.5437 - binary_accuracy: 0.7042 - accuracy_metric_7: 0.0417




4/15 - mean_loss: 0.5742 - binary_crossentropy: 0.5742 - binary_accuracy: 0.7203 - accuracy_metric_7: 0.0312




5/15 - mean_loss: 0.5515 - binary_crossentropy: 0.5515 - binary_accuracy: 0.7337 - accuracy_metric_7: 0.0250




6/15 - mean_loss: 0.5429 - binary_crossentropy: 0.5429 - binary_accuracy: 0.7406 - accuracy_metric_7: 0.0312




7/15 - mean_loss: 0.5340 - binary_crossentropy: 0.5340 - binary_accuracy: 0.7473 - accuracy_metric_7: 0.0268




8/15 - mean_loss: 0.5195 - binary_crossentropy: 0.5195 - binary_accuracy: 0.7586 - accuracy_metric_7: 0.0391




9/15 - mean_loss: 0.5223 - binary_crossentropy: 0.5223 - binary_accuracy: 0.7514 - accuracy_metric_7: 0.0347




10/15 - mean_loss: 0.5122 - binary_crossentropy: 0.5122 - binary_accuracy: 0.7631 - accuracy_metric_7: 0.0312




11/15 - mean_loss: 0.5037 - binary_crossentropy: 0.5037 - binary_accuracy: 0.7716 - accuracy_metric_7: 0.0341




12/15 - mean_loss: 0.5018 - binary_crossentropy: 0.5018 - binary_accuracy: 0.7734 - accuracy_metric_7: 0.0417




13/15 - mean_loss: 0.4974 - binary_crossentropy: 0.4974 - binary_accuracy: 0.7755 - accuracy_metric_7: 0.0385




14/15 - mean_loss: 0.4910 - binary_crossentropy: 0.4910 - binary_accuracy: 0.7817 - accuracy_metric_7: 0.0357




15/15 - mean_loss: 0.4932 - binary_crossentropy: 0.4932 - binary_accuracy: 0.7763 - accuracy_metric_7: 0.0333
val_mean_loss tf.Tensor(0.44942424, shape=(), dtype=float32) tf.Tensor(0.44942424, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.43022, shape=(), dtype=float32) tf.Tensor(0.41101575, shape=(), dtype=float32)
1/20 - mean_loss: 0.4932 - binary_crossentropy: 0.4932 - binary_accuracy: 0.7763 - accuracy_metric_7: 0.0333 * mean_loss: 0.4302 - val_binary_crossentropy: 0.4302 - val_binary_accuracy: 0.8469 - val_accuracy_metric_8: 0.0938
save
Epoch 2/20




1/15 - mean_loss: 0.4418 - binary_crossentropy: 0.4418 - binary_accuracy: 0.8188 - accuracy_metric_7: 0.0625




2/15 - mean_loss: 0.4316 - binary_crossentropy: 0.4316 - binary_accuracy: 0.8344 - accuracy_metric_7: 0.0625




3/15 - mean_loss: 0.4263 - binary_crossentropy: 0.4263 - binary_accuracy: 0.8417 - accuracy_metric_7: 0.0625




4/15 - mean_loss: 0.4429 - binary_crossentropy: 0.4429 - binary_accuracy: 0.8234 - accuracy_metric_7: 0.0469




5/15 - mean_loss: 0.4474 - binary_crossentropy: 0.4474 - binary_accuracy: 0.8150 - accuracy_metric_7: 0.0375




6/15 - mean_loss: 0.4471 - binary_crossentropy: 0.4471 - binary_accuracy: 0.8167 - accuracy_metric_7: 0.0312




7/15 - mean_loss: 0.4516 - binary_crossentropy: 0.4516 - binary_accuracy: 0.8125 - accuracy_metric_7: 0.0446




8/15 - mean_loss: 0.4469 - binary_crossentropy: 0.4469 - binary_accuracy: 0.8109 - accuracy_metric_7: 0.0391




9/15 - mean_loss: 0.4530 - binary_crossentropy: 0.4530 - binary_accuracy: 0.8097 - accuracy_metric_7: 0.0347




10/15 - mean_loss: 0.4496 - binary_crossentropy: 0.4496 - binary_accuracy: 0.8150 - accuracy_metric_7: 0.0312




11/15 - mean_loss: 0.4446 - binary_crossentropy: 0.4446 - binary_accuracy: 0.8188 - accuracy_metric_7: 0.0341




12/15 - mean_loss: 0.4483 - binary_crossentropy: 0.4483 - binary_accuracy: 0.8167 - accuracy_metric_7: 0.0417




13/15 - mean_loss: 0.4480 - binary_crossentropy: 0.4480 - binary_accuracy: 0.8154 - accuracy_metric_7: 0.0385




14/15 - mean_loss: 0.4438 - binary_crossentropy: 0.4438 - binary_accuracy: 0.8188 - accuracy_metric_7: 0.0357




15/15 - mean_loss: 0.4535 - binary_crossentropy: 0.4535 - binary_accuracy: 0.8108 - accuracy_metric_7: 0.0333
val_mean_loss tf.Tensor(0.4536036, shape=(), dtype=float32) tf.Tensor(0.5003708, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.45227784, shape=(), dtype=float32) tf.Tensor(0.4483006, shape=(), dtype=float32)
2/20 - mean_loss: 0.4535 - binary_crossentropy: 0.4535 - binary_accuracy: 0.8108 - accuracy_metric_7: 0.0333 * mean_loss: 0.4523 - val_binary_crossentropy: 0.4743 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 3/20




1/15 - mean_loss: 0.4564 - binary_crossentropy: 0.4564 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0625




2/15 - mean_loss: 0.4604 - binary_crossentropy: 0.4604 - binary_accuracy: 0.7875 - accuracy_metric_7: 0.0625




3/15 - mean_loss: 0.4628 - binary_crossentropy: 0.4628 - binary_accuracy: 0.7812 - accuracy_metric_7: 0.0417




4/15 - mean_loss: 0.4516 - binary_crossentropy: 0.4516 - binary_accuracy: 0.7875 - accuracy_metric_7: 0.0312




5/15 - mean_loss: 0.4497 - binary_crossentropy: 0.4497 - binary_accuracy: 0.7850 - accuracy_metric_7: 0.0250




6/15 - mean_loss: 0.4496 - binary_crossentropy: 0.4496 - binary_accuracy: 0.7906 - accuracy_metric_7: 0.0208




7/15 - mean_loss: 0.4564 - binary_crossentropy: 0.4564 - binary_accuracy: 0.7911 - accuracy_metric_7: 0.0357




8/15 - mean_loss: 0.4522 - binary_crossentropy: 0.4522 - binary_accuracy: 0.7969 - accuracy_metric_7: 0.0469




9/15 - mean_loss: 0.4605 - binary_crossentropy: 0.4605 - binary_accuracy: 0.7854 - accuracy_metric_7: 0.0417




10/15 - mean_loss: 0.4592 - binary_crossentropy: 0.4592 - binary_accuracy: 0.7819 - accuracy_metric_7: 0.0437




11/15 - mean_loss: 0.4557 - binary_crossentropy: 0.4557 - binary_accuracy: 0.7830 - accuracy_metric_7: 0.0511




12/15 - mean_loss: 0.4538 - binary_crossentropy: 0.4538 - binary_accuracy: 0.7865 - accuracy_metric_7: 0.0521




13/15 - mean_loss: 0.4530 - binary_crossentropy: 0.4530 - binary_accuracy: 0.7880 - accuracy_metric_7: 0.0481




14/15 - mean_loss: 0.4498 - binary_crossentropy: 0.4498 - binary_accuracy: 0.7906 - accuracy_metric_7: 0.0446




15/15 - mean_loss: 0.4520 - binary_crossentropy: 0.4520 - binary_accuracy: 0.7854 - accuracy_metric_7: 0.0417
val_mean_loss tf.Tensor(0.45174915, shape=(), dtype=float32) tf.Tensor(0.44963434, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.44752887, shape=(), dtype=float32) tf.Tensor(0.4264276, shape=(), dtype=float32)
3/20 - mean_loss: 0.4520 - binary_crossentropy: 0.4520 - binary_accuracy: 0.7854 - accuracy_metric_7: 0.0417 * mean_loss: 0.4475 - val_binary_crossentropy: 0.4380 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 4/20




1/15 - mean_loss: 0.4423 - binary_crossentropy: 0.4423 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0625




2/15 - mean_loss: 0.4332 - binary_crossentropy: 0.4332 - binary_accuracy: 0.7969 - accuracy_metric_7: 0.0312




3/15 - mean_loss: 0.4351 - binary_crossentropy: 0.4351 - binary_accuracy: 0.7917 - accuracy_metric_7: 0.0208




4/15 - mean_loss: 0.4326 - binary_crossentropy: 0.4326 - binary_accuracy: 0.7953 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4345 - binary_crossentropy: 0.4345 - binary_accuracy: 0.7962 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4359 - binary_crossentropy: 0.4359 - binary_accuracy: 0.7990 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4401 - binary_crossentropy: 0.4401 - binary_accuracy: 0.7973 - accuracy_metric_7: 0.0268




8/15 - mean_loss: 0.4366 - binary_crossentropy: 0.4366 - binary_accuracy: 0.8023 - accuracy_metric_7: 0.0391




9/15 - mean_loss: 0.4451 - binary_crossentropy: 0.4451 - binary_accuracy: 0.7903 - accuracy_metric_7: 0.0347




10/15 - mean_loss: 0.4447 - binary_crossentropy: 0.4447 - binary_accuracy: 0.7869 - accuracy_metric_7: 0.0375




11/15 - mean_loss: 0.4418 - binary_crossentropy: 0.4418 - binary_accuracy: 0.7864 - accuracy_metric_7: 0.0455




12/15 - mean_loss: 0.4410 - binary_crossentropy: 0.4410 - binary_accuracy: 0.7891 - accuracy_metric_7: 0.0469




13/15 - mean_loss: 0.4409 - binary_crossentropy: 0.4409 - binary_accuracy: 0.7904 - accuracy_metric_7: 0.0433




14/15 - mean_loss: 0.4386 - binary_crossentropy: 0.4386 - binary_accuracy: 0.7915 - accuracy_metric_7: 0.0402




15/15 - mean_loss: 0.4402 - binary_crossentropy: 0.4402 - binary_accuracy: 0.7879 - accuracy_metric_7: 0.0375
val_mean_loss tf.Tensor(0.44645104, shape=(), dtype=float32) tf.Tensor(0.43998414, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.44304338, shape=(), dtype=float32) tf.Tensor(0.41918972, shape=(), dtype=float32)
4/20 - mean_loss: 0.4402 - binary_crossentropy: 0.4402 - binary_accuracy: 0.7879 - accuracy_metric_7: 0.0375 * mean_loss: 0.4430 - val_binary_crossentropy: 0.4296 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 5/20




1/15 - mean_loss: 0.4312 - binary_crossentropy: 0.4312 - binary_accuracy: 0.7750 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.4307 - binary_crossentropy: 0.4307 - binary_accuracy: 0.7781 - accuracy_metric_7: 0.0000




3/15 - mean_loss: 0.4330 - binary_crossentropy: 0.4330 - binary_accuracy: 0.7833 - accuracy_metric_7: 0.0000




4/15 - mean_loss: 0.4303 - binary_crossentropy: 0.4303 - binary_accuracy: 0.7969 - accuracy_metric_7: 0.0000




5/15 - mean_loss: 0.4321 - binary_crossentropy: 0.4321 - binary_accuracy: 0.7962 - accuracy_metric_7: 0.0000




6/15 - mean_loss: 0.4343 - binary_crossentropy: 0.4343 - binary_accuracy: 0.8000 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4392 - binary_crossentropy: 0.4392 - binary_accuracy: 0.8018 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4370 - binary_crossentropy: 0.4370 - binary_accuracy: 0.8070 - accuracy_metric_7: 0.0312




9/15 - mean_loss: 0.4434 - binary_crossentropy: 0.4434 - binary_accuracy: 0.7986 - accuracy_metric_7: 0.0278




10/15 - mean_loss: 0.4430 - binary_crossentropy: 0.4430 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0312




11/15 - mean_loss: 0.4399 - binary_crossentropy: 0.4399 - binary_accuracy: 0.7949 - accuracy_metric_7: 0.0341




12/15 - mean_loss: 0.4393 - binary_crossentropy: 0.4393 - binary_accuracy: 0.7964 - accuracy_metric_7: 0.0365




13/15 - mean_loss: 0.4388 - binary_crossentropy: 0.4388 - binary_accuracy: 0.7966 - accuracy_metric_7: 0.0337




14/15 - mean_loss: 0.4363 - binary_crossentropy: 0.4363 - binary_accuracy: 0.7996 - accuracy_metric_7: 0.0312




15/15 - mean_loss: 0.4383 - binary_crossentropy: 0.4383 - binary_accuracy: 0.7950 - accuracy_metric_7: 0.0292
val_mean_loss tf.Tensor(0.443333, shape=(), dtype=float32) tf.Tensor(0.44564995, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.44111615, shape=(), dtype=float32) tf.Tensor(0.42116427, shape=(), dtype=float32)
5/20 - mean_loss: 0.4383 - binary_crossentropy: 0.4383 - binary_accuracy: 0.7950 - accuracy_metric_7: 0.0292 * mean_loss: 0.4411 - val_binary_crossentropy: 0.4334 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 6/20




1/15 - mean_loss: 0.4309 - binary_crossentropy: 0.4309 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.4289 - binary_crossentropy: 0.4289 - binary_accuracy: 0.7906 - accuracy_metric_7: 0.0000




3/15 - mean_loss: 0.4357 - binary_crossentropy: 0.4357 - binary_accuracy: 0.7917 - accuracy_metric_7: 0.0000




4/15 - mean_loss: 0.4330 - binary_crossentropy: 0.4330 - binary_accuracy: 0.7922 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4350 - binary_crossentropy: 0.4350 - binary_accuracy: 0.7912 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4374 - binary_crossentropy: 0.4374 - binary_accuracy: 0.7948 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4422 - binary_crossentropy: 0.4422 - binary_accuracy: 0.7911 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4393 - binary_crossentropy: 0.4393 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0156




9/15 - mean_loss: 0.4451 - binary_crossentropy: 0.4451 - binary_accuracy: 0.7854 - accuracy_metric_7: 0.0139




10/15 - mean_loss: 0.4442 - binary_crossentropy: 0.4442 - binary_accuracy: 0.7831 - accuracy_metric_7: 0.0125




11/15 - mean_loss: 0.4407 - binary_crossentropy: 0.4407 - binary_accuracy: 0.7869 - accuracy_metric_7: 0.0170




12/15 - mean_loss: 0.4399 - binary_crossentropy: 0.4399 - binary_accuracy: 0.7896 - accuracy_metric_7: 0.0208




13/15 - mean_loss: 0.4398 - binary_crossentropy: 0.4398 - binary_accuracy: 0.7899 - accuracy_metric_7: 0.0192




14/15 - mean_loss: 0.4377 - binary_crossentropy: 0.4377 - binary_accuracy: 0.7924 - accuracy_metric_7: 0.0179




15/15 - mean_loss: 0.4392 - binary_crossentropy: 0.4392 - binary_accuracy: 0.7892 - accuracy_metric_7: 0.0167
val_mean_loss tf.Tensor(0.44107085, shape=(), dtype=float32) tf.Tensor(0.44061822, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.4394195, shape=(), dtype=float32) tf.Tensor(0.4212544, shape=(), dtype=float32)
6/20 - mean_loss: 0.4392 - binary_crossentropy: 0.4392 - binary_accuracy: 0.7892 - accuracy_metric_7: 0.0167 * mean_loss: 0.4394 - val_binary_crossentropy: 0.4309 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 7/20




1/15 - mean_loss: 0.4279 - binary_crossentropy: 0.4279 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0625




2/15 - mean_loss: 0.4269 - binary_crossentropy: 0.4269 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0312




3/15 - mean_loss: 0.4303 - binary_crossentropy: 0.4303 - binary_accuracy: 0.8021 - accuracy_metric_7: 0.0208




4/15 - mean_loss: 0.4298 - binary_crossentropy: 0.4298 - binary_accuracy: 0.7953 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4324 - binary_crossentropy: 0.4324 - binary_accuracy: 0.7987 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4349 - binary_crossentropy: 0.4349 - binary_accuracy: 0.7990 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4398 - binary_crossentropy: 0.4398 - binary_accuracy: 0.7964 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4375 - binary_crossentropy: 0.4375 - binary_accuracy: 0.8000 - accuracy_metric_7: 0.0156




9/15 - mean_loss: 0.4438 - binary_crossentropy: 0.4438 - binary_accuracy: 0.7917 - accuracy_metric_7: 0.0139




10/15 - mean_loss: 0.4427 - binary_crossentropy: 0.4427 - binary_accuracy: 0.7894 - accuracy_metric_7: 0.0125




11/15 - mean_loss: 0.4393 - binary_crossentropy: 0.4393 - binary_accuracy: 0.7892 - accuracy_metric_7: 0.0114




12/15 - mean_loss: 0.4387 - binary_crossentropy: 0.4387 - binary_accuracy: 0.7901 - accuracy_metric_7: 0.0156




13/15 - mean_loss: 0.4387 - binary_crossentropy: 0.4387 - binary_accuracy: 0.7918 - accuracy_metric_7: 0.0144




14/15 - mean_loss: 0.4366 - binary_crossentropy: 0.4366 - binary_accuracy: 0.7933 - accuracy_metric_7: 0.0134




15/15 - mean_loss: 0.4381 - binary_crossentropy: 0.4381 - binary_accuracy: 0.7887 - accuracy_metric_7: 0.0125
val_mean_loss tf.Tensor(0.4395135, shape=(), dtype=float32) tf.Tensor(0.44064158, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.4381407, shape=(), dtype=float32) tf.Tensor(0.4202941, shape=(), dtype=float32)
7/20 - mean_loss: 0.4381 - binary_crossentropy: 0.4381 - binary_accuracy: 0.7887 - accuracy_metric_7: 0.0125 * mean_loss: 0.4381 - val_binary_crossentropy: 0.4305 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 8/20




1/15 - mean_loss: 0.4312 - binary_crossentropy: 0.4312 - binary_accuracy: 0.8000 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.4290 - binary_crossentropy: 0.4290 - binary_accuracy: 0.7875 - accuracy_metric_7: 0.0000




3/15 - mean_loss: 0.4320 - binary_crossentropy: 0.4320 - binary_accuracy: 0.7958 - accuracy_metric_7: 0.0208




4/15 - mean_loss: 0.4300 - binary_crossentropy: 0.4300 - binary_accuracy: 0.7922 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4324 - binary_crossentropy: 0.4324 - binary_accuracy: 0.7925 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4355 - binary_crossentropy: 0.4355 - binary_accuracy: 0.7917 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4400 - binary_crossentropy: 0.4400 - binary_accuracy: 0.7902 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4369 - binary_crossentropy: 0.4369 - binary_accuracy: 0.7945 - accuracy_metric_7: 0.0156




9/15 - mean_loss: 0.4429 - binary_crossentropy: 0.4429 - binary_accuracy: 0.7868 - accuracy_metric_7: 0.0139




10/15 - mean_loss: 0.4418 - binary_crossentropy: 0.4418 - binary_accuracy: 0.7862 - accuracy_metric_7: 0.0188




11/15 - mean_loss: 0.4387 - binary_crossentropy: 0.4387 - binary_accuracy: 0.7875 - accuracy_metric_7: 0.0284




12/15 - mean_loss: 0.4376 - binary_crossentropy: 0.4376 - binary_accuracy: 0.7891 - accuracy_metric_7: 0.0260




13/15 - mean_loss: 0.4373 - binary_crossentropy: 0.4373 - binary_accuracy: 0.7899 - accuracy_metric_7: 0.0240




14/15 - mean_loss: 0.4347 - binary_crossentropy: 0.4347 - binary_accuracy: 0.7929 - accuracy_metric_7: 0.0268




15/15 - mean_loss: 0.4362 - binary_crossentropy: 0.4362 - binary_accuracy: 0.7908 - accuracy_metric_7: 0.0250
val_mean_loss tf.Tensor(0.43835872, shape=(), dtype=float32) tf.Tensor(0.44141084, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.437205, shape=(), dtype=float32) tf.Tensor(0.41989893, shape=(), dtype=float32)
8/20 - mean_loss: 0.4362 - binary_crossentropy: 0.4362 - binary_accuracy: 0.7908 - accuracy_metric_7: 0.0250 * mean_loss: 0.4372 - val_binary_crossentropy: 0.4307 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 9/20




1/15 - mean_loss: 0.4300 - binary_crossentropy: 0.4300 - binary_accuracy: 0.7875 - accuracy_metric_7: 0.0625




2/15 - mean_loss: 0.4291 - binary_crossentropy: 0.4291 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0312




3/15 - mean_loss: 0.4335 - binary_crossentropy: 0.4335 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0208




4/15 - mean_loss: 0.4309 - binary_crossentropy: 0.4309 - binary_accuracy: 0.8031 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4324 - binary_crossentropy: 0.4324 - binary_accuracy: 0.8025 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4353 - binary_crossentropy: 0.4353 - binary_accuracy: 0.8010 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4394 - binary_crossentropy: 0.4394 - binary_accuracy: 0.8009 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4366 - binary_crossentropy: 0.4366 - binary_accuracy: 0.8016 - accuracy_metric_7: 0.0156




9/15 - mean_loss: 0.4428 - binary_crossentropy: 0.4428 - binary_accuracy: 0.7924 - accuracy_metric_7: 0.0139




10/15 - mean_loss: 0.4419 - binary_crossentropy: 0.4419 - binary_accuracy: 0.7887 - accuracy_metric_7: 0.0188




11/15 - mean_loss: 0.4389 - binary_crossentropy: 0.4389 - binary_accuracy: 0.7909 - accuracy_metric_7: 0.0227




12/15 - mean_loss: 0.4384 - binary_crossentropy: 0.4384 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0260




13/15 - mean_loss: 0.4378 - binary_crossentropy: 0.4378 - binary_accuracy: 0.7947 - accuracy_metric_7: 0.0240




14/15 - mean_loss: 0.4356 - binary_crossentropy: 0.4356 - binary_accuracy: 0.7969 - accuracy_metric_7: 0.0268




15/15 - mean_loss: 0.4371 - binary_crossentropy: 0.4371 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0250
val_mean_loss tf.Tensor(0.43732652, shape=(), dtype=float32) tf.Tensor(0.43927103, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.43633413, shape=(), dtype=float32) tf.Tensor(0.4194637, shape=(), dtype=float32)
9/20 - mean_loss: 0.4371 - binary_crossentropy: 0.4371 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0250 * mean_loss: 0.4363 - val_binary_crossentropy: 0.4294 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 10/20




1/15 - mean_loss: 0.4271 - binary_crossentropy: 0.4271 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.4253 - binary_crossentropy: 0.4253 - binary_accuracy: 0.8031 - accuracy_metric_7: 0.0312




3/15 - mean_loss: 0.4278 - binary_crossentropy: 0.4278 - binary_accuracy: 0.8021 - accuracy_metric_7: 0.0208




4/15 - mean_loss: 0.4256 - binary_crossentropy: 0.4256 - binary_accuracy: 0.8031 - accuracy_metric_7: 0.0156




5/15 - mean_loss: 0.4287 - binary_crossentropy: 0.4287 - binary_accuracy: 0.8000 - accuracy_metric_7: 0.0125




6/15 - mean_loss: 0.4320 - binary_crossentropy: 0.4320 - binary_accuracy: 0.8042 - accuracy_metric_7: 0.0104




7/15 - mean_loss: 0.4370 - binary_crossentropy: 0.4370 - binary_accuracy: 0.8027 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4345 - binary_crossentropy: 0.4345 - binary_accuracy: 0.8062 - accuracy_metric_7: 0.0312




9/15 - mean_loss: 0.4408 - binary_crossentropy: 0.4408 - binary_accuracy: 0.7972 - accuracy_metric_7: 0.0278




10/15 - mean_loss: 0.4399 - binary_crossentropy: 0.4399 - binary_accuracy: 0.7962 - accuracy_metric_7: 0.0250




11/15 - mean_loss: 0.4369 - binary_crossentropy: 0.4369 - binary_accuracy: 0.7955 - accuracy_metric_7: 0.0227




12/15 - mean_loss: 0.4363 - binary_crossentropy: 0.4363 - binary_accuracy: 0.7974 - accuracy_metric_7: 0.0208




13/15 - mean_loss: 0.4366 - binary_crossentropy: 0.4366 - binary_accuracy: 0.7966 - accuracy_metric_7: 0.0192




14/15 - mean_loss: 0.4344 - binary_crossentropy: 0.4344 - binary_accuracy: 0.7987 - accuracy_metric_7: 0.0179




15/15 - mean_loss: 0.4361 - binary_crossentropy: 0.4361 - binary_accuracy: 0.7983 - accuracy_metric_7: 0.0167
val_mean_loss tf.Tensor(0.43645152, shape=(), dtype=float32) tf.Tensor(0.4385641, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.43559018, shape=(), dtype=float32) tf.Tensor(0.4192249, shape=(), dtype=float32)
10/20 - mean_loss: 0.4361 - binary_crossentropy: 0.4361 - binary_accuracy: 0.7983 - accuracy_metric_7: 0.0167 * mean_loss: 0.4356 - val_binary_crossentropy: 0.4289 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000
Epoch 11/20




1/15 - mean_loss: 0.4346 - binary_crossentropy: 0.4346 - binary_accuracy: 0.7688 - accuracy_metric_7: 0.0000




2/15 - mean_loss: 0.4292 - binary_crossentropy: 0.4292 - binary_accuracy: 0.7969 - accuracy_metric_7: 0.0000




3/15 - mean_loss: 0.4319 - binary_crossentropy: 0.4319 - binary_accuracy: 0.7958 - accuracy_metric_7: 0.0000




4/15 - mean_loss: 0.4287 - binary_crossentropy: 0.4287 - binary_accuracy: 0.7922 - accuracy_metric_7: 0.0000




5/15 - mean_loss: 0.4316 - binary_crossentropy: 0.4316 - binary_accuracy: 0.7937 - accuracy_metric_7: 0.0000




6/15 - mean_loss: 0.4339 - binary_crossentropy: 0.4339 - binary_accuracy: 0.7979 - accuracy_metric_7: 0.0000




7/15 - mean_loss: 0.4382 - binary_crossentropy: 0.4382 - binary_accuracy: 0.8000 - accuracy_metric_7: 0.0179




8/15 - mean_loss: 0.4359 - binary_crossentropy: 0.4359 - binary_accuracy: 0.8023 - accuracy_metric_7: 0.0234




9/15 - mean_loss: 0.4417 - binary_crossentropy: 0.4417 - binary_accuracy: 0.7951 - accuracy_metric_7: 0.0208




10/15 - mean_loss: 0.4409 - binary_crossentropy: 0.4409 - binary_accuracy: 0.7919 - accuracy_metric_7: 0.0250




11/15 - mean_loss: 0.4380 - binary_crossentropy: 0.4380 - binary_accuracy: 0.7932 - accuracy_metric_7: 0.0341




12/15 - mean_loss: 0.4373 - binary_crossentropy: 0.4373 - binary_accuracy: 0.7958 - accuracy_metric_7: 0.0365




13/15 - mean_loss: 0.4370 - binary_crossentropy: 0.4370 - binary_accuracy: 0.7966 - accuracy_metric_7: 0.0337




14/15 - mean_loss: 0.4345 - binary_crossentropy: 0.4345 - binary_accuracy: 0.7982 - accuracy_metric_7: 0.0312




15/15 - mean_loss: 0.4356 - binary_crossentropy: 0.4356 - binary_accuracy: 0.7962 - accuracy_metric_7: 0.0292
val_mean_loss tf.Tensor(0.43582925, shape=(), dtype=float32) tf.Tensor(0.44061112, shape=(), dtype=float32)
val_mean_loss tf.Tensor(0.4351536, shape=(), dtype=float32) tf.Tensor(0.4209654, shape=(), dtype=float32)
11/20 - mean_loss: 0.4356 - binary_crossentropy: 0.4356 - binary_accuracy: 0.7962 - accuracy_metric_7: 0.0292 * mean_loss: 0.4352 - val_binary_crossentropy: 0.4308 - val_binary_accuracy: 0.7719 - val_accuracy_metric_8: 0.0000


In [None]:
model.load_weights(folder_url+'my_checkpoints/best_second_weight_check.h5')

In [None]:
def evaluate(loss,metrics=None):
    metrics = " - ".join([f"{m.name}: {m.result():.4f}" for m in [loss] + (metrics or [])])

    print( metrics)

In [None]:
test_steps = len(test_dataset) // batch_size
test_metrics = [tf.keras.metrics.BinaryCrossentropy(),tf.keras.metrics.BinaryAccuracy(),AccuracyMetric(0.5)]
for step in range(0, test_steps):
        X_test_batch, y_test_batch = test_generator.__getitem__(step)
        y_test_pred = model(X_test_batch)
        test_main_loss = tf.reduce_mean(loss_fn(y_test_batch, y_test_pred))
        test_loss = tf.add_n([test_main_loss]  + model.losses)
        mean_loss(test_loss)
        for metric in test_metrics:
            print(y_test_batch.shape,y_test_pred.shape)
            metric(y_test_batch, y_test_pred)


evaluate(mean_loss,test_metrics)

(16, 10) (16, 10)
(16, 10) (16, 10)
(16, 10) (16, 10)
(16, 10) (16, 10)
(16, 10) (16, 10)
(16, 10) (16, 10)
mean_loss: 0.1179 - binary_crossentropy: 0.1786 - binary_accuracy: 0.9406 - accuracy_metric_6: 0.6562


In [None]:
mean_loss: 0.1638 - binary_crossentropy: 0.1600 - binary_accuracy: 0.9406 - accuracy_metric_12: 0.6250
mean_loss: 0.0601 - binary_crossentropy: 0.2150 - binary_accuracy: 0.9281 - accuracy_metric_15: 0.6562
mean_loss: 0.0356 - binary_crossentropy: 0.1448 - binary_accuracy: 0.9531 - accuracy_metric_14: 0.6875
mean_loss: 0.0601 - binary_crossentropy: 0.2150 - binary_accuracy: 0.9281 - accuracy_metric_15: 0.6562
mean_loss: 0.0529 - binary_crossentropy: 0.1502 - binary_accuracy: 0.9281 - accuracy_metric_19: 0.5938



In [None]:
X_test_batch, y_test_batch = val_generator.__getitem__(0)

In [None]:
tf.shape(X_test_batch)
import time

<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 16, 224, 224,   3], dtype=int32)>

In [None]:
import time
start_time = time.time()
y_test_pred = model(X_test_batch)
print("--- %s seconds ---" % (time.time() - start_time))

--- 0.9801642894744873 seconds ---


In [None]:
0.9801642894744873/64

0.015315067023038864

In [None]:
test_bce=tf.keras.metrics.BinaryCrossentropy()
test_bce(y_test_batch,y_test_pred)

<tf.Tensor: shape=(), dtype=float32, numpy=0.13443612>

In [None]:
### binary accuracy
# H_full = history_full.history
H_freeze=history_freeze.history
fig = plt.figure(figsize = (20, 7))
plt.subplot(121)
# plt.plot(H['Accuracy'], label = 'acc')
# plt.plot(H_full['val_binary_accuracy'], label = 'val_bi_acc')
# plt.plot(H_full['binary_accuracy'], label = 'bi_acc')
plt.plot(H_freeze['val_binary_accuracy'], label = 'val_bi_acc_freeze')
plt.plot(H_freeze['binary_accuracy'], label = 'bi_acc_freeze')
plt.axhline(y = val_zero_shot, color = 'r', linestyle = '-',label = 'val_zero_shot')

plt.grid()
plt.legend()
plt.title("binary accuracy")


#### accuracy
plt.subplot(122)
# plt.plot(H['loss'], label = 'loss')
plt.plot(H_freeze['val_Accuracy'], label = 'val_acc_freeze')
plt.plot(H_freeze['Accuracy'], label = 'acc_freeze')

plt.grid()
plt.legend()
plt.title(" accuracy")

In [None]:
# Dataset hyperparameters
unlabeled_dataset_size = 100000
labeled_dataset_size = 5000
image_size = 96
image_channels = 3

# Algorithm hyperparameters
num_epochs = 20
batch_size = 525  # Corresponds to 200 steps per epoch
width = 128
temperature = 0.1
# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
classification_augmentation = {"min_area": 0.75, "brightness": 0.3, "jitter": 0.1}

In [None]:

# def prepare_dataset():
#     # Labeled and unlabeled samples are loaded synchronously
#     # with batch sizes selected accordingly
#     steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
#     unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
#     labeled_batch_size = labeled_dataset_size // steps_per_epoch
#     print(
#         f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
#     )

#     unlabeled_train_dataset = (
#         tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=True)
#         .shuffle(buffer_size=10 * unlabeled_batch_size)
#         .batch(unlabeled_batch_size)
#     )
#     labeled_train_dataset = (
#         tfds.load("stl10", split="train", as_supervised=True, shuffle_files=True)
#         .shuffle(buffer_size=10 * labeled_batch_size)
#         .batch(labeled_batch_size)
#     )
#     test_dataset = (
#         tfds.load("stl10", split="test", as_supervised=True)
#         .batch(batch_size)
#         .prefetch(buffer_size=tf.data.AUTOTUNE)
#     )

#     # Labeled and unlabeled datasets are zipped together
#     train_dataset = tf.data.Dataset.zip(
#         (unlabeled_train_dataset, labeled_train_dataset)
#     ).prefetch(buffer_size=tf.data.AUTOTUNE)

#     return train_dataset, labeled_train_dataset, test_dataset


# # Load STL10 dataset
# train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

batch size is 500 (unlabeled) + 25 (labeled)
Downloading and preparing dataset 2.46 GiB (download: 2.46 GiB, generated: 1.86 GiB, total: 4.32 GiB) to /root/tensorflow_datasets/stl10/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/5000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/stl10/1.0.0.incompleteYD0949/stl10-train.tfrecord*...:   0%|          | 0/…

Generating test examples...:   0%|          | 0/8000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/stl10/1.0.0.incompleteYD0949/stl10-test.tfrecord*...:   0%|          | 0/8…

Generating unlabelled examples...:   0%|          | 0/100000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/stl10/1.0.0.incompleteYD0949/stl10-unlabelled.tfrecord*...:   0%|         …

Dataset stl10 downloaded and prepared to /root/tensorflow_datasets/stl10/1.0.0. Subsequent calls will reuse this data.


## Image augmentations

The two most important image augmentations for contrastive learning are the
following:

- Cropping: forces the model to encode different parts of the same image
similarly, we implement it with the
[RandomTranslation](https://keras.io/api/layers/preprocessing_layers/image_preprocessing/random_translation/)
and
[RandomZoom](https://keras.io/api/layers/preprocessing_layers/image_preprocessing/random_zoom/)
layers
- Color jitter: prevents a trivial color histogram-based solution to the task by
distorting color histograms. A principled way to implement that is by affine
transformations in color space.

In this example we use random horizontal flips as well. Stronger augmentations
are applied for contrastive learning, along with weaker ones for supervised
classification to avoid overfitting on the few labeled examples.

We implement random color jitter as a custom preprocessing layer. Using
preprocessing layers for data augmentation has the following two advantages:

- The data augmentation will run on GPU in batches, so the training will not be
bottlenecked by the data pipeline in environments with constrained CPU
resources (such as a Colab Notebook, or a personal machine)
- Deployment is easier as the data preprocessing pipeline is encapsulated in the
model, and does not have to be reimplemented when deploying it

In [None]:

# Distorts the color distibutions of images
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)

        self.brightness = brightness
        self.jitter = jitter

    def get_config(self):
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config

    def call(self, images, training=True):
        if training:
            batch_size = tf.shape(images)[0]

            # Same for all colors
            brightness_scales = 1 + tf.random.uniform(
                (batch_size, 1, 1, 1), minval=-self.brightness, maxval=self.brightness
            )
            # Different for all colors
            jitter_matrices = tf.random.uniform(
                (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
            )

            color_transforms = (
                tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
                + jitter_matrices
            )
            images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
        return images


# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    zoom_factor = 1.0 - math.sqrt(min_area)
    return keras.Sequential(
        [
            keras.Input(shape=(image_size, image_size, image_channels)),
            layers.Resizing(height=224, width=224, crop_to_aspect_ratio=True),
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
        ]
    )


def visualize_augmentations(num_images):
    # Sample a batch from a dataset
    images = next(iter(train_dataset))[0][0][:num_images]
    # Apply augmentations
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
    )
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(row_titles[row], loc="left")
            plt.axis("off")
    plt.tight_layout()


visualize_augmentations(num_images=8)

NameError: ignored

## Encoder architecture

In [None]:

# Define the encoder architecture
def get_encoder():
    # base_model=ResNet50(weights='imagenet', include_top=False)
    # for layer in base_model.layers:
    #       layer.trainable = False
    return keras.Sequential(
        [
            keras.Input(shape=(224, 224, image_channels)),
            # layers.Lambda(tf.keras.applications.resnet50.preprocess_input),
            # base_model,
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )


## Supervised baseline model

A baseline supervised model is trained using random initialization.

In [None]:
# Baseline supervised training with random initialization
baseline_model = keras.Sequential(
    [
        keras.Input(shape=(image_size, image_size, image_channels)),
        get_augmenter(**classification_augmentation),
        get_encoder(),
        layers.Dense(10),
    ],
    name="baseline_model",
)


In [None]:
# baseline_model.summary()
baseline_model.layers[1].summary()

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 111, 111, 128)     3584      
                                                                 
 conv2d_1 (Conv2D)           (None, 55, 55, 128)       147584    
                                                                 
 conv2d_2 (Conv2D)           (None, 27, 27, 128)       147584    
                                                                 
 conv2d_3 (Conv2D)           (None, 13, 13, 128)       147584    
                                                                 
 flatten (Flatten)           (None, 21632)             0         
                                                                 
 dense (Dense)               (None, 128)               2769024   
                                                                 
Total params: 3,215,360
Trainable params: 3,215,360
Non-tra

In [None]:
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Maximal validation accuracy: 31.41%


In [None]:
for layer in baseline_model.layers[1].layers[1].layers[150:]:
          layer.trainable = True
baseline_model.summary()

Model: "baseline_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_5 (Sequential)   (None, 224, 224, 3)       0         
                                                                 
 encoder (Sequential)        (None, 128)               26276352  
                                                                 
 dense_3 (Dense)             (None, 10)                1290      
                                                                 
Total params: 26,277,642
Trainable params: 12,680,074
Non-trainable params: 13,597,568
_________________________________________________________________


In [None]:
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)

Epoch 1/20
Epoch 2/20
Epoch 3/20


KeyboardInterrupt: ignored

## Self-supervised model for contrastive pretraining

We pretrain an encoder on unlabeled images with a contrastive loss.
A nonlinear projection head is attached to the top of the encoder, as it
improves the quality of representations of the encoder.

We use the InfoNCE/NT-Xent/N-pairs loss, which can be interpreted in the
following way:

1. We treat each image in the batch as if it had its own class.
2. Then, we have two examples (a pair of augmented views) for each "class".
3. Each view's representation is compared to every possible pair's one (for both
  augmented versions).
4. We use the temperature-scaled cosine similarity of compared representations as
  logits.
5. Finally, we use categorical cross-entropy as the "classification" loss

The following two metrics are used for monitoring the pretraining performance:

- [Contrastive accuracy (SimCLR Table 5)](https://arxiv.org/abs/2002.05709):
Self-supervised metric, the ratio of cases in which the representation of an
image is more similar to its differently augmented version's one, than to the
representation of any other image in the current batch. Self-supervised
metrics can be used for hyperparameter tuning even in the case when there are
no labeled examples.
- [Linear probing accuracy](https://arxiv.org/abs/1603.08511): Linear probing is
a popular metric to evaluate self-supervised classifiers. It is computed as
the accuracy of a logistic regression classifier trained on top of the
encoder's features. In our case, this is done by training a single dense layer
on top of the frozen encoder. Note that contrary to traditional approach where
the classifier is trained after the pretraining phase, in this example we
train it during pretraining. This might slightly decrease its accuracy, but
that way we can monitor its value during training, which helps with
experimentation and debugging.

Another widely used supervised metric is the
[KNN accuracy](https://arxiv.org/abs/1805.01978), which is the accuracy of a KNN
classifier trained on top of the encoder's features, which is not implemented in
this example.

In [None]:

# Define the contrastive model with model-subclassing
class ContrastiveModel(keras.Model):
    def __init__(self):
        super().__init__()

        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        self.encoder = get_encoder()
        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # Single dense layer for linear probing
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )

        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss will be defined as a method
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

    @property
    def metrics(self):
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]

    def contrastive_loss(self, projections_1, projections_2):
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        similarities = (
            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views
        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, tf.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data

        # Both labeled and unlabeled images are used, without labels
        images = tf.concat((unlabeled_images, labeled_images), axis=0)
        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        # Labels are only used in evalutation for an on-the-fly logistic regression
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        with tf.GradientTape() as tape:
            # the encoder is used in inference mode here to avoid regularization
            # and updating the batch normalization paramers if they are used
            features = self.encoder(preprocessed_images, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        labeled_images, labels = data

        # For testing the components are used with a training=False flag
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        # Only the probe metrics are logged at test time
        return {m.name: m.result() for m in self.metrics[2:]}


# Contrastive pretraining
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

pretraining_history = pretraining_model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lambda_2 (Lambda)           (None, 224, 224, 3)       0         
                                                                 
 resnet50 (Functional)       (None, None, None, 2048)  23587712  
                                                                 
 avg (GlobalAveragePooling2D  (None, 2048)             0         
 )                                                               
                                                                 
 fc1 (Dense)                 (None, 1024)              2098176   
                                                                 
 dropout_4 (Dropout)         (None, 1024)              0         
                                                                 
 fc2 (Dense)                 (None, 512)               524800    
                                                           

KeyboardInterrupt: ignored

## Supervised finetuning of the pretrained encoder

We then finetune the encoder on the labeled examples, by attaching
a single randomly initalized fully connected classification layer on its top.

In [None]:
# Supervised finetuning of the pretrained encoder
finetuning_model = keras.Sequential(
    [
        layers.Input(shape=(image_size, image_size, image_channels)),
        get_augmenter(**classification_augmentation),
        pretraining_model.encoder,
        layers.Dense(10),
    ],
    name="finetuning_model",
)
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

finetuning_history = finetuning_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)

## Comparison against the baseline

In [None]:

# The classification accuracies of the baseline and the pretraining + finetuning process:
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
        plt.figure(figsize=(8, 5), dpi=100)
        plt.plot(
            baseline_history.history[f"val_{metric_key}"], label="supervised baseline"
        )
        plt.plot(
            pretraining_history.history[f"val_p_{metric_key}"],
            label="self-supervised pretraining",
        )
        plt.plot(
            finetuning_history.history[f"val_{metric_key}"],
            label="supervised finetuning",
        )
        plt.legend()
        plt.title(f"Classification {metric_name} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {metric_name}")


plot_training_curves(pretraining_history, finetuning_history, baseline_history)