In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from SimCLR_data_util import preprocess_for_train
from resnet_small import ResNet18
from tensorflow.keras.layers import Dense
from SimCLR import SimCLR

In [2]:
from logistic_regression import *

In [3]:
from datasets.cifar_10 import get_unsupervised_dataset
# Unlabelled CIFAR-10 pipeline for SimCLR pre-training.
# Increase the batch size if GPU memory is available.
BATCH_SIZE = 64
# NOTE(review): not used by the cells shown here (SimCLR weights are restored
# from a checkpoint below). A distinct name keeps it from being confused with
# the supervised `dataset` loaded later.
unsupervised_dataset = get_unsupervised_dataset(batch_size=BATCH_SIZE)

In [4]:
class MyAugmentation(tf.keras.layers.Layer):
    """Keras layer that applies `preprocess_for_train` to every image in a batch.

    Each image along the leading (batch) axis is processed independently via
    `tf.map_fn`, so any randomness inside `preprocess_for_train` is drawn
    per image.

    Args:
        height: first size argument passed to `preprocess_for_train`
            (default 32, the CIFAR-10 image size used in this notebook).
        width: second size argument passed to `preprocess_for_train`
            (default 32).
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, height=32, width=32, **kwargs):
        super().__init__(**kwargs)
        self.height = height
        self.width = width

    def _augment_one(self, image):
        """Augment a single image (no batch dimension)."""
        return preprocess_for_train(image, self.height, self.width)

    def call(self, x):
        """Return `x` with every image independently augmented."""
        # NOTE(review): map_fn's default output signature is x's dtype, so this
        # assumes preprocess_for_train returns the dtype it receives — confirm.
        return tf.map_fn(self._augment_one, x)


augmentation = MyAugmentation()

In [5]:
def get_projection_head():
    """Build the MLP projection head that is handed to the SimCLR model.

    Architecture: Dense(1024, relu) -> Dropout(0.5) -> Dense(512, relu)
    -> Dropout(0.5) -> Dense(256) with a linear output. The sizes must stay in
    sync with the SimCLR checkpoint restored later in this notebook.
    """
    head_layers = []
    for hidden_units in (1024, 512):
        head_layers += [Dense(hidden_units, activation='relu'),
                        tf.keras.layers.Dropout(0.5)]
    head_layers.append(Dense(256))  # linear projection output
    return tf.keras.Sequential(head_layers)


projection_head = get_projection_head()

In [6]:
def get_encoder():
    """Return ResNet-18 with its last layer removed, as a new Sequential model.

    `ResNet18(10)` is instantiated with 10 outputs, but the final layer is
    dropped; the remaining layers (ending in a Flatten, per the summary
    printed later) form the feature encoder.
    """
    resnet = ResNet18(10)
    feature_layers = resnet.layers[:-1]  # everything but the final layer
    return tf.keras.Sequential(feature_layers)


encoder = get_encoder()

In [7]:
# Temperature of the SimCLR contrastive loss.
# NOTE(review): presumably must match the value used when the checkpoint
# below was trained — confirm.
TEMPERATURE = 0.07
sim_clr_model = SimCLR(encoder, augmentation, projection_head, temperature=TEMPERATURE)

In [8]:
# Restore the pre-trained SimCLR weights and keep a handle to its encoder for
# downstream (linear) evaluation. TF-format checkpoints may restore lazily,
# i.e. when the layers' variables are first created.
SIMCLR_CHECKPOINT = './cifar_10_experiment/simclr_weights/ckpt1'
sim_clr_model.load_weights(SIMCLR_CHECKPOINT)
encoder = sim_clr_model.encoder

In [9]:
from datasets.cifar_10 import get_supervised_dataset
# Labelled CIFAR-10: a pair of (images, labels) splits, as inspected below —
# presumably (train, test) given the 50000 / 10000 sizes.
dataset = get_supervised_dataset()

In [10]:
dataset[0][0].shape  # first split's images (numpy array)

(50000, 32, 32, 3)

In [11]:
dataset[0][1].shape  # first split's labels, width 10; a TF tensor (hence TensorShape)

TensorShape([50000, 10])

In [12]:
len(dataset[1])  # second split is also an (images, labels) pair

2

In [13]:
dataset[1][0].shape  # second split's images

(10000, 32, 32, 3)

In [14]:
dataset[1][1].shape  # second split's labels

TensorShape([10000, 10])

In [16]:
# NOTE(review): TrainClassifier comes from the `logistic_regression` star
# import; 10 presumably is the number of CIFAR-10 classes — confirm.
classifier = TrainClassifier(10, encoder)

In [32]:
# Train the classifier on the first (training) split.
# NOTE(review): if TrainClassifier outputs raw logits (no softmax), the loss
# needs from_logits=True — confirm against logistic_regression.
classifier.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                   loss=tf.keras.losses.categorical_crossentropy,
                   # Pass a metric *instance*: in TF2 Keras a bare class is
                   # treated as a fn(y_true, y_pred) and called as a constructor.
                   metrics=[tf.keras.metrics.CategoricalAccuracy()])
x_train = tf.convert_to_tensor(dataset[0][0])
y_train = dataset[0][1]
classifier.fit(x=x_train, y=y_train)



KeyboardInterrupt: 

In [28]:
dataset[0][0].shape, dataset[0][1].shape  # first split: (images shape, labels shape)

((50000, 32, 32, 3), TensorShape([50000, 10]))

In [21]:
type(dataset[0][0])  # images are a numpy array; labels are a TF tensor (TensorShape above)

numpy.ndarray

In [16]:
from tensorflow.keras import Input

In [17]:
# Symbolic input for 32x32 RGB images; the batch dimension is left unspecified.
inp = Input((32, 32, 3))

In [18]:
x = encoder(inp)  # symbolic call: encoder features for `inp` (shape shown below)

In [19]:
x.shape  # flattened encoder features per image

TensorShape([None, 8192])

In [20]:
encoder.summary()  # layer-by-layer breakdown; the final Flatten produces the 8192-d features

Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 32, 32, 64)        1728      
_________________________________________________________________
batch_normalization (BatchNo (None, 32, 32, 64)        256       
_________________________________________________________________
sequential_3 (Sequential)    (None, 32, 32, 64)        148480    
_________________________________________________________________
sequential_6 (Sequential)    (None, 16, 16, 128)       526848    
_________________________________________________________________
sequential_9 (Sequential)    (None, 8, 8, 256)         2102272   
_________________________________________________________________
sequential_12 (Sequential)   (None, 4, 4, 512)         8398848   
_________________________________________________________________
flatten (Flatten)            (None, 8192)            

In [21]:
# Linear 10-way head (Dense's default activation is None) -> raw logits.
linear_head = tf.keras.layers.Dense(units=10)
logits = linear_head(x)

In [22]:
logits  # symbolic (None, 10) output of the linear head

<KerasTensor: shape=(None, 10) dtype=float32 (created by layer 'dense_5')>

In [23]:
# Functional linear-probe model: image -> encoder features -> 10 logits.
model = tf.keras.Model(inp, logits)

In [24]:
# NOTE(review): the encoder shows up as entirely non-trainable here although no
# visible cell sets that — presumably frozen earlier (e.g. by TrainClassifier);
# confirm before relying on it.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
sequential_13 (Sequential)   (None, 8192)              11178432  
_________________________________________________________________
dense_5 (Dense)              (None, 10)                81930     
Total params: 11,260,362
Trainable params: 81,930
Non-trainable params: 11,178,432
_________________________________________________________________


In [25]:
# The same linear probe as a Sequential model.
# Freeze the encoder explicitly. The functional model's summary shows all
# encoder params as non-trainable, but no cell here sets that, so the frozen
# state would otherwise depend on hidden kernel state.
# With trainable=False, TF2 Keras also runs BatchNormalization in inference mode.
encoder.trainable = False
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),  # build immediately so summary()/fit work
    encoder,
    tf.keras.layers.Dense(10, activation=None),  # raw logits -> use from_logits=True in the loss
])