In [None]:
# Show the GPUs on this machine (IPython shell escape) to help pick the index set in the next cell.
!nvidia-smi

In [None]:
import os
# Restrict TensorFlow to a single GPU. This runs before `import tensorflow` below, which is
# required for the setting to take effect.
# CHANGE the index to a GPU that is free on your machine (see the nvidia-smi output above).
os.environ.update(CUDA_VISIBLE_DEVICES="5")

In [None]:
import tensorflow as tf
tfkl = tf.keras.layers
import numpy as np
from matplotlib import pyplot as plt

from aldi.modeling.layers import DownLevel, UpLevel
from aldi.modeling.callbacks import ReconstructionPlotCallback, CodebookResetter
from aldi.modeling.vq import Autoencoder, RVQ

In [None]:
batch_size = 256
(train_images, _), (test_images, _) = tf.keras.datasets.cifar10.load_data()

# Scale pixel values into [0, 1] as float32.
train_images, test_images = (imgs.astype(np.float32) / 255. for imgs in (train_images, test_images))

# Training pipeline: shuffle with a buffer covering the whole training set, and drop the
# final partial batch so every step sees exactly `batch_size` images.
train_data = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(len(train_images))
    .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: fixed order, last partial batch kept.
test_data = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Preview the first 64 CIFAR-10 test images in an 8x8 grid.
# The images are RGB, so no colormap is passed: matplotlib ignores `cmap` for RGB data.
fig, axes = plt.subplots(8, 8, figsize=(15, 15))
for ax in axes.flat:
    ax.axis("off")
for ax, img in zip(axes.flat, test_images[:64]):
    ax.imshow(img)
fig.suptitle("CIFAR-10 test images (first 64)")
plt.show()

In [None]:
# Distribution of all test-set pixel values (all channels pooled).
# After the `/ 255.` scaling, pixels take the 256 levels k/255 (k = 0..255). 255 bins cannot
# give each level its own bin, which produces binning artifacts, so use 256 bins.
plt.hist(test_images.reshape(-1), bins=256)
plt.xlabel("Pixel value")
plt.ylabel("Count")
plt.title("Test-set pixel value distribution")
plt.show()

In [None]:
def encoder_stack(inputs, filters, strides, blocks_per_level):
    """Build the convolutional encoder trunk on top of a Keras tensor.

    A 3x3 "same"-padded stem convolution maps ``inputs`` to ``filters[0]`` channels.
    Then one ``DownLevel`` is applied per entry of ``strides``: level ``i`` uses
    ``filters[i + 1]`` filters and stride ``strides[i]``.

    Args:
        inputs: Keras tensor to build on (e.g. the output of ``tf.keras.Input``).
        filters: channel counts ``[stem, level_0, level_1, ...]``.
        strides: per-level strides; must have exactly ``len(filters) - 1`` entries.
        blocks_per_level: forwarded unchanged to every ``DownLevel``.

    Returns:
        The output tensor of the last level, or of the stem if ``strides`` is empty.

    Raises:
        ValueError: if ``len(filters) != len(strides) + 1``. Without this check,
            ``zip`` would silently drop the unmatched levels.
    """
    strides = list(strides)
    if len(filters) != len(strides) + 1:
        raise ValueError(
            "encoder_stack expects len(filters) == len(strides) + 1, got "
            "{} filters and {} strides".format(len(filters), len(strides)))

    inputs = tfkl.Conv2D(filters[0], 3, padding="same")(inputs)
    for level_ind, (level_filters, level_strides) in enumerate(zip(filters[1:], strides)):
        # Positional arguments are passed through unchanged.
        # NOTE(review): the meaning of the literals 2 and 3 is defined by
        # aldi.modeling.layers.DownLevel (presumably 3 is the kernel size) — confirm.
        inputs = DownLevel(2,
                           blocks_per_level,
                           level_filters,
                           level_filters,
                           3,
                           level_strides,
                           normalization=tfkl.BatchNormalization,
                           name="down_level" + str(level_ind))(inputs)

    return inputs


def decoder_stack(inputs, filters, strides, blocks_per_level):
    """Build the convolutional decoder trunk on top of a Keras tensor.

    Mirrors ``encoder_stack``. A 3x3 "same"-padded stem convolution maps ``inputs``
    to ``filters[0]`` channels. Then one ``UpLevel`` is applied per entry of
    ``strides``: level ``i`` uses ``filters[i + 1]`` filters and stride
    ``strides[i]``. The sweep calls it with the encoder's filters and strides reversed.

    Args:
        inputs: Keras tensor to build on (e.g. the latent ``tf.keras.Input``).
        filters: channel counts ``[stem, level_0, level_1, ...]``.
        strides: per-level strides; must have exactly ``len(filters) - 1`` entries.
        blocks_per_level: forwarded unchanged to every ``UpLevel``.

    Returns:
        The output tensor of the last level, or of the stem if ``strides`` is empty.

    Raises:
        ValueError: if ``len(filters) != len(strides) + 1``. Without this check,
            ``zip`` would silently drop the unmatched levels.
    """
    strides = list(strides)
    if len(filters) != len(strides) + 1:
        raise ValueError(
            "decoder_stack expects len(filters) == len(strides) + 1, got "
            "{} filters and {} strides".format(len(filters), len(strides)))

    inputs = tfkl.Conv2D(filters[0], 3, padding="same")(inputs)
    for level_ind, (level_filters, level_strides) in enumerate(zip(filters[1:], strides)):
        # Positional arguments are passed through unchanged.
        # NOTE(review): the meaning of the literals 2 and 3 is defined by
        # aldi.modeling.layers.UpLevel (presumably 3 is the kernel size) — confirm.
        inputs = UpLevel(2,
                         blocks_per_level,
                         level_filters,
                         level_filters,
                         3,
                         level_strides,
                         normalization=tfkl.BatchNormalization,
                         name="up_level" + str(level_ind))(inputs)

    return inputs

In [None]:
# Hyper-parameter sweep: train one VQ autoencoder per (codebook size, beta) pair.
# Results are nested lists indexed [codebook-size index][beta index].
models = []     # models[i][j]: trained Autoencoder for codebook_sizes[i], betas[j]
histories = []  # histories[i][j]: History.history dict of the same run

d = 4  # output channels of the final 1x1 encoder conv, i.e. the latent vector size
codebook_powers = range(1, 14)
codebook_sizes = [2**power for power in codebook_powers]
# Values passed as `beta` to Autoencoder (presumably the VQ commitment weight), scaled by 1/d.
betas = [b / d for b in [0.002, 0.02, 0.2, 2.]]

# Architecture and training schedule are identical for every run, so define them once.
blocks_per_level = 2
filters = [16, 32, 64, 128]
strides = [2, 2, 2]
n_kmeans_batches = 4  # training batches encoded to initialise the codebook with k-means
train_steps = 500000
n_data = len(train_images)
# train_data drops the last partial batch, so one epoch is n_data // batch_size steps.
n_epochs = train_steps // (n_data // batch_size)


def build_autoencoder(codebook_size, beta):
    """Build a fresh encoder / RVQ quantizer / decoder autoencoder for one sweep point.

    Returns:
        (model, encoder, quantizer): the ``Autoencoder``, plus its encoder and quantizer,
        which the sweep needs for codebook initialisation.
    """
    inp = tf.keras.Input(train_images.shape[1:])

    encoder_final = encoder_stack(inp, filters, strides, blocks_per_level)
    encoder_final = tfkl.Conv2D(d, 1, padding="same")(encoder_final)
    encoder = tf.keras.Model(inp, encoder_final, name="encoder")

    decoder_input = tf.keras.Input(encoder_final.shape[1:])
    decoder_output = decoder_stack(decoder_input, list(reversed(filters)),
                                   list(reversed(strides)), blocks_per_level)
    decoder_final = tfkl.Conv2D(3, 1, padding="same")(decoder_output)
    decoder = tf.keras.Model(decoder_input, decoder_final, name="decoder")

    quantizer = RVQ(codebook_size, encoder_final.shape[-1], 1)
    model = Autoencoder(inp, encoder, decoder, tf.keras.losses.MeanSquaredError(),
                        quantizer=quantizer, beta=beta, name="autoencoder")
    return model, encoder, quantizer


for cbs in codebook_sizes:
    models.append([])
    histories.append([])
    print("\n\n\nRUNNING codebook size = {}".format(cbs))
    for beta in betas:
        print("\n\n\nRUNNING beta = {}".format(beta))

        model, encoder, quantizer = build_autoencoder(cbs, beta)
        model.summary(expand_nested=True)

        # Initialise the codebook with k-means on the (still untrained) encoder's
        # outputs for a few training batches.
        dummyenc = tf.concat([encoder(batch) for batch in train_data.take(n_kmeans_batches)], axis=0)
        quantizer.init_with_k_means(dummyenc, n_init=1, batch_n_multiplier=4)

        model.compile(optimizer=tf.optimizers.Adam(), jit_compile=True)

        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=2, verbose=1, factor=0.5,
                                                         min_delta=0.000001)
        earlystop = tf.keras.callbacks.EarlyStopping(patience=6, verbose=1, restore_best_weights=True,
                                                     min_delta=0.000001)

        # Latent vectors per batch divided by the codebook size: how often each code
        # would be hit per batch under uniform usage. The number of latent positions per
        # image comes from the encoder output rather than a hard-coded 4 * 4.
        latent_positions = int(dummyenc.shape[1]) * int(dummyenc.shape[2])
        average_code_use = batch_size * latent_positions / cbs
        history = model.fit(train_data, validation_data=test_data, epochs=n_epochs,
                            callbacks=[ReconstructionPlotCallback(test_images[:16], 10, clip=True),
                                       # NOTE(review): the /256 factor is carried over unchanged;
                                       # confirm its meaning against CodebookResetter.
                                       CodebookResetter(1, threshold=average_code_use/256, iteration_source=train_data),
                                       reduce_lr, earlystop])

        models[-1].append(model)
        histories[-1].append(history.history)

In [None]:
# Evaluate every trained model on the test set. For each codebook size, keep the model
# (over all betas) whose first evaluate() value is lowest. Keras returns the loss first.
best_models = []
d_losses = np.zeros((len(codebook_sizes), len(betas)))
for row in range(len(codebook_sizes)):
    for col in range(len(betas)):
        eval_results = models[row][col].evaluate(test_data)
        d_losses[row, col] = eval_results[0]
    best_col = int(np.argmin(d_losses[row]))
    best_models.append(models[row][best_col])

In [None]:
# Best validation loss (minimum over beta) as a function of codebook size.
plt.semilogx(codebook_sizes, d_losses.min(axis=1), "-*")
plt.xlabel("Codebook size")
# NOTE(review): values are evaluate()[0]; labelled MSE on the assumption that the
# Autoencoder's reported loss is the reconstruction MSE — confirm.
plt.ylabel("Validation MSE")
plt.title("Best validation loss over beta (d = {})".format(d))
plt.show()

In [None]:
 # Row-wise minimum: best validation loss over beta for each codebook size.
 np.min(d_losses, axis=1)

In [None]:
# Persist the best loss per codebook size. The file name records the latent size d
# instead of hard-coding it.
np.save("losses_d{}.npy".format(d), d_losses.min(axis=1))

In [None]:
# Training curves. histories[i][j] is the History.history dict for codebook_sizes[i]
# and betas[j], so it must be unpacked two levels deep. For each codebook size and each
# logged metric, overlay the runs for all betas in one figure.
for cbs, cb_histories in zip(codebook_sizes, histories):
    # Union of metric names across this codebook size's runs, in first-seen order.
    metric_names = list(dict.fromkeys(key for run in cb_histories for key in run))
    for metric in metric_names:
        fig, ax = plt.subplots()
        for beta, run in zip(betas, cb_histories):
            if metric in run:
                ax.plot(run[metric], label="beta = {:g}".format(beta))
        ax.set(title="{} (codebook size {})".format(metric, cbs), xlabel="Epoch", ylabel=metric)
        ax.legend()
        plt.show()
        plt.close(fig)

In [None]:
# For each best model (one per codebook size), compute the mean cosine distance (1 - cos)
# between each encoder output vector and its quantized counterpart, over the first
# n_eval test images.
n_eval = 1000
errors = []
for model in best_models:
    # Encode only the images that are used, instead of the whole test set. test_data is
    # unshuffled, so these are the same rows the original slice selected.
    encoded = model.encoder.predict(test_images[:n_eval], batch_size=batch_size)
    quantized, _, _ = model.quantizer(encoded)
    quantized = quantized.numpy()

    norms = np.linalg.norm(encoded, axis=-1) * np.linalg.norm(quantized, axis=-1)
    cos_sim = (encoded * quantized).sum(axis=-1) / norms
    errors.append((1 - cos_sim).mean())


In [None]:
# One point per best model; errors were appended in codebook_sizes order.
plt.semilogx(codebook_sizes, errors, "-*")
plt.xlabel("Codebook size")
plt.ylabel("Mean cosine distance (1 - cos)")
plt.title("Encoder output vs. quantized vector")
plt.show()

In [None]:
# Latent distributions of the best model for each codebook size: a histogram of all
# latent values, then a scatter of the first two latent channels.
# `models` is nested [codebook size][beta], so iterate the flat `best_models` instead.
for cbs, model in zip(codebook_sizes, best_models):
    encoded = model.encoder.predict(test_data)

    plt.hist(encoded.reshape(-1), bins=250)
    plt.title("Latent values, codebook size {}".format(cbs))
    plt.show()

    # Use the actual channel count so the reshape does not depend on the global `d`.
    enc_flat = encoded.reshape((-1, encoded.shape[-1]))
    plt.scatter(enc_flat[:, 0], enc_flat[:, 1], marker=".", alpha=0.1)
    plt.gca().set_aspect("equal")
    plt.xlabel("Latent channel 0")
    plt.ylabel("Latent channel 1")
    plt.title("First two latent channels, codebook size {}".format(cbs))
    plt.show()

In [None]:
# Save every trained model. `models` is nested [codebook size][beta], so each name encodes
# d, the codebook size and beta to keep the runs distinct. The loop no longer rebinds the
# global `d`, and no longer references the undefined name `ds`.
for cbs, cb_models in zip(codebook_sizes, models):
    for beta, model in zip(betas, cb_models):
        model.save("basic_d{}_cb{}_beta{:g}".format(d, cbs, beta))

In [None]:
import pickle
# Pickle each run's History.history dict. `histories` is nested [codebook size][beta], so
# each file name encodes d, the codebook size and beta. The loop no longer rebinds the
# global `d`, and no longer references the undefined name `ds`.
for cbs, cb_histories in zip(codebook_sizes, histories):
    for beta, run_history in zip(betas, cb_histories):
        with open("basic_d{}_cb{}_beta{:g}.pkl".format(d, cbs, beta), "wb") as fh:
            pickle.dump(run_history, fh)