In [None]:
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing
 
from easydict import EasyDict as edict
import matplotlib.pyplot as plt

# modules
from dataloader import Cifar10DataLoader, MnistDataLoader
from dnn import DNN
from unet import CUNet
from diffusion import DiffusionUnet

In [None]:
# Data pipeline configuration: batch size, number of epochs, and the "da" flag
# (presumably "data augmentation" -- confirm against dataloader.py).
dataloader_args = edict({"batch_size": 128, "epochs": 50, "da": True})
dataloader = MnistDataLoader(dataloader_args=dataloader_args)
# load_dataset() yields the three splits used throughout the notebook.
train_dataset, valid_dataset, test_dataset = dataloader.load_dataset()

def display(display_list):
  """Show up to three images side by side in a single row.

  The panels are titled, in order, 'Input Image', 'Predicted Image' and
  'Test Image'; passing more than three images raises IndexError because
  there is no title for the extra panels.

  Args:
    display_list: sequence of image arrays/tensors accepted by
      `tf.keras.utils.array_to_img`.
  """
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'Predicted Image', 'Test Image']
  n_panels = len(display_list)

  for position, image in enumerate(display_list, start=1):
    plt.subplot(1, n_panels, position)
    plt.title(title[position - 1])
    plt.imshow(tf.keras.utils.array_to_img(image))
    plt.axis('off')
  plt.show()

def plot_train(display_list):
  """Plot up to two curves on one axis, labelled 'Train' then 'Test'.

  Args:
    display_list: sequence of series to plot; the first is labelled
      'Train', the second 'Test'. A third series raises IndexError.
  """
  # plt.figure(figsize=(10, 10))
  label = ['Train', 'Test']
  for position, curve in enumerate(display_list):
    plt.plot(curve, label=label[position])
  plt.legend()
  plt.show()
  

In [None]:
# Classifier configuration: four dense layers ending in a 10-way softmax.
# The same settings are reused later to rebuild weight snapshots of the model.
model_args = edict({"units":[128,64,32,10], "activations":["relu","relu","relu","softmax"]})
model = DNN(units=model_args.units, activations=model_args.activations)

In [None]:
# Training loss, plus a running mean of it. Note that `mt_loss_fn` is a Mean
# metric (an accumulator), not a loss function, despite its name.
train_loss_fn = tf.keras.losses.CategoricalCrossentropy()
mt_loss_fn = tf.keras.metrics.Mean()
# Evaluation counterpart: loss object and its running-mean accumulator.
test_loss_fn = tf.keras.losses.CategoricalCrossentropy()
mte_loss_fn = tf.keras.metrics.Mean()
# Functional form of the same loss; used where the loss value itself is
# recorded as a regression target for the model snapshots.
opt_loss_fn = tf.keras.losses.categorical_crossentropy

# Accuracy accumulators for training and evaluation, and the classifier optimizer.
train_metrics = tf.keras.metrics.CategoricalAccuracy()
test_metrics = tf.keras.metrics.CategoricalAccuracy()
optimizer = tf.keras.optimizers.SGD(0.1)

In [None]:
# @tf.function(experimental_relax_shapes=True, experimental_compile=None)
def _train_step(inputs, labels, first_batch=False):
    """Run one optimisation step of the global `model` on a single batch.

    Args:
        inputs: batch of model inputs.
        labels: matching one-hot labels.
        first_batch: unused; kept so existing callers keep working.

    Returns:
        (loss, metrics, feat): the batch loss, the running training accuracy
        after this batch has been accumulated into `train_metrics`, and the
        second output of `model` (its feature tensor).
    """
    # Only the forward pass and the loss need to be recorded by the tape;
    # gradient computation and the optimizer update happen after the context
    # has closed.
    with tf.GradientTape() as tape:
        predictions, feat = model(inputs)
        loss = train_loss_fn(labels, predictions)

    # Calling the metric both updates its state and returns the current value.
    metrics = tf.reduce_mean(train_metrics(labels, predictions))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    mt_loss_fn.update_state(loss)

    return loss, metrics, feat

def _test_step(inputs, labels):
    """Evaluate the global `model` on one batch without updating its weights.

    Side effects: accumulates the batch into `test_metrics` and
    `mte_loss_fn`.

    Returns:
        (loss, metrics, opt_loss): the loss from `test_loss_fn`, the running
        test accuracy, and the loss as computed by the functional
        `opt_loss_fn` on the same predictions.
    """
    predictions, _ = model(inputs)

    batch_loss = test_loss_fn(labels, predictions)
    functional_loss = opt_loss_fn(labels, predictions)
    accuracy = tf.reduce_mean(test_metrics(labels, predictions))

    mte_loss_fn.update_state(batch_loss)
    return batch_loss, accuracy, functional_loss

In [None]:
# One iterator per dataset split.
iter_train = iter(train_dataset)
iter_valid = iter(valid_dataset)
iter_test = iter(test_dataset)
# A single test batch is pulled once here and reused everywhere below as the
# fixed reference batch for evaluation.
test_data =  iter_test.get_next()
# display([test_data["inputs"][0],test_data["inputs"][1],test_data["inputs"][2]])

In [None]:
# Accumulators appended to in lock-step by collect_model_operator():
model_opt = []   # snapshot models rebuilt from the classifier's weights
opt_label = []   # loss value recorded alongside each snapshot
feat_const = []  # feature tensor recorded alongside each snapshot
def collect_model_operator(variables, loss, feat):
    """Freeze the current weights into a new DNN and record it.

    A fresh `DNN` with the notebook's `model_args` architecture is built from
    the numpy values of `variables`, then appended to `model_opt`; `loss` and
    `feat` are appended to `opt_label` and `feat_const` at the same index.

    Args:
        variables: iterable of variables whose `.numpy()` values seed the copy.
        loss: loss value to associate with this snapshot.
        feat: feature tensor to associate with this snapshot.
    """
    snapshot = DNN(
        units=model_args.units,
        activations=model_args.activations,
        init_value=[var.numpy() for var in variables],
    )
    opt_label.append(loss)
    model_opt.append(snapshot)
    feat_const.append(feat)

# Training history filled in by obtain_model_opts(): one value per epoch in each list.
records = edict({'epoch':[],'train_loss':[],'test_loss':[],'train_metric':[],'test_metric':[]})
def obtain_model_opts(sample_start=30, sample_gap=20):
    """Train the classifier and periodically snapshot it as a "model operator".

    Runs `dataloader.info.epochs` epochs of `dataloader.info.train_step`
    steps each. Once the epoch index reaches `sample_start`, every
    `sample_gap`-th global step evaluates the model on the fixed `test_data`
    batch and stores a snapshot via `collect_model_operator`. After each
    epoch the running losses/accuracies are appended to `records` and printed.

    Args:
        sample_start: first epoch (0-based) from which snapshots are taken.
        sample_gap: spacing between snapshots, counted in global training steps.
    """
    steps_per_epoch = dataloader.info.train_step

    for e in range(dataloader.info.epochs):
        # Start every epoch with clean running statistics.
        for tracker in (mt_loss_fn, train_metrics, mte_loss_fn, test_metrics):
            tracker.reset_states()

        for step in range(steps_per_epoch):
            batch = iter_train.get_next()
            _, _, feat = _train_step(inputs=batch["inputs"], labels=batch["labels"])

            global_step = e * steps_per_epoch + step
            if global_step % sample_gap == 0 and e >= sample_start:
                # NOTE: this evaluation also accumulates into the per-epoch
                # test loss/accuracy reported below.
                _, _, opt_loss = _test_step(inputs=test_data["inputs"], labels=test_data["labels"])
                collect_model_operator(model.trainable_variables, opt_loss, feat)

        # End-of-epoch evaluation on the fixed test batch.
        _test_step(inputs=test_data["inputs"], labels=test_data["labels"])

        records.epoch.append(e)
        records.train_loss.append(mt_loss_fn.result().numpy())
        records.train_metric.append(train_metrics.result().numpy())
        records.test_loss.append(mte_loss_fn.result().numpy())
        records.test_metric.append(test_metrics.result().numpy())

        print("".join("{}: {} ".format(k, v[-1]) for k, v in records.items()))

In [None]:

def init_model_opt(raw_model_opt, data):
    """Run every model operator once on `data` and print its loss and accuracy.

    Presumably this forward pass also builds each model so it can be saved or
    used right after loading -- confirm against the DNN implementation.

    Side effects: resets and then updates the shared `mte_loss_fn` and
    `test_metrics` accumulators for every operator.

    Args:
        raw_model_opt: sequence of callables returning `(predictions, feat)`.
        data: mapping with "inputs" and "labels" for the evaluation batch.
    """

    def opt_test_step(opt, inputs, labels):
        predictions, _ = opt(inputs)
        loss = test_loss_fn(labels, predictions)
        metrics = tf.reduce_mean(test_metrics(labels, predictions))
        mte_loss_fn.update_state(loss)
        return loss, metrics

    for idx, opt in enumerate(raw_model_opt):
        mte_loss_fn.reset_states()
        test_metrics.reset_states()
        # Evaluate on the batch the caller passed in (previously the argument
        # was silently replaced by the global `test_data`).
        opt_test_step(opt=opt, inputs=data["inputs"], labels=data["labels"])
        print("Init: opt_id:{}, Test loss:{}, Test acc:{}".format(idx,
                                                        mte_loss_fn.result().numpy(),
                                                        test_metrics.result().numpy()))
            
def hard_save_model_opt(online_model_opt, path="./model_opt"):
    """Persist the model operators and their recorded losses under `path`.

    Writes the global `opt_label` values to `<path>/lab.npy` and each operator
    to `<path>/opt_<index>` in TensorFlow SavedModel format. Every operator is
    first run once on the global `test_data` batch via `init_model_opt`.

    Args:
        online_model_opt: sequence of Keras models to save.
        path: target directory; created if it does not exist.
    """
    # np.save does not create missing directories.
    os.makedirs(path, exist_ok=True)

    labels = [lab.numpy() for lab in opt_label]
    lab_path = os.path.join(path, "lab")  # np.save appends ".npy"
    np.save(lab_path, np.asarray(labels))

    init_model_opt(online_model_opt, test_data)
    # NOTE(review): `opt_<index>` directories left over from an earlier,
    # larger save are not removed here.
    for idx, opt in enumerate(online_model_opt):
        mpath = os.path.join(path, "opt_{}".format(idx))
        opt.save(mpath, overwrite=True, save_format='tf')

def load_model_opt(path="./model_opt"):
    """Load the saved model operators and their recorded losses from `path`.

    Expects the layout written by `hard_save_model_opt`: one `opt_<index>`
    entry per operator plus `lab.npy`. Each loaded operator is run once on the
    global `test_data` batch via `init_model_opt`.

    Args:
        path: directory holding the saved operators and labels.

    Returns:
        (offline_model_opt, label_opts): the loaded models in index order and
        a list with one `tf.constant` per row of `lab.npy`.
    """
    prefix = "opt_"
    # Pick out the operator entries explicitly instead of assuming that
    # everything except one file is an operator; stray entries such as
    # `.ipynb_checkpoints` would otherwise break the count.
    opt_names = [
        name for name in os.listdir(path=path)
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    # Numeric order, so that operator i lines up with row i of the labels.
    opt_names.sort(key=lambda name: int(name[len(prefix):]))

    offline_model_opt = [
        tf.keras.models.load_model(os.path.join(path, name))
        for name in opt_names
    ]
    init_model_opt(offline_model_opt, test_data)

    labels = np.load(os.path.join(path, "lab.npy"))
    label_opts = [tf.constant(lab) for lab in list(labels)]

    return offline_model_opt, label_opts

def sample_from_dict(test_label, sample_pool):
    """Draw one random pool sample per label and resize the batch to 32x32.

    For every entry of `test_label` the class index is taken as its argmax
    along axis 0, and a random element of `sample_pool[str(class_index)]` is
    chosen.

    Args:
        test_label: iterable of label vectors (argmax gives the class).
        sample_pool: mapping from class index (as a string) to a sequence of
            samples of that class.

    Returns:
        dict with "inputs" (the resized samples) and "labels" (`test_label`
        unchanged).
    """
    def draw(label_vector):
        candidates = sample_pool[str(np.argmax(label_vector, axis=0))]
        position = random.sample(range(len(candidates)), 1)[0]
        return candidates[position]

    picked = [draw(lab) for lab in test_label]

    resize = tf.keras.Sequential([preprocessing.Resizing(32,32)])
    return {"inputs":resize(tf.constant(picked)), "labels":test_label}

In [None]:
# Option 1 (disabled): train the classifier from scratch and collect the
# model operators in memory.
# training
# obtain_model_opts(sample_start=30, sample_gap=20) 
# print(len(model_opt))

# Option 2 (active): reuse operators saved by an earlier run; this requires
# the ./model_opt directory to exist already.
# load model opt
model_opt, opt_label = load_model_opt()

In [None]:
# Number of consecutive generator updates performed on each sampled batch.
GIAO_EPOCH = 11
# Number of model operators drawn at random for each generator update.
GIAO_BATCH = 1

# Optimizer for the generator U-Net.
giao_optimizer = tf.keras.optimizers.Adam(1e-4)
# Only referenced by the commented-out variant of _opt_train_step below.
giao_loss_fn = tf.keras.losses.MeanSquaredError()
def _opt_train_step(unet, regs, train_inputs, train_labels, labels):
    """Update the generator so its outputs reproduce the operators' recorded losses.

    For each operator `regs[idx]` the generator output is classified by that
    operator, the functional loss against `train_labels` is computed, and the
    mean absolute difference to the recorded target `labels[idx]` is
    minimised. Gradients are averaged over the operators and applied once to
    `unet.model`.

    Args:
        unet: generator; callable on `train_inputs`, with its trainable
            weights exposed as `unet.model.trainable_variables`.
        regs: non-empty sequence of operators returning `(predictions, feat)`.
        train_inputs: batch fed to the generator.
        train_labels: labels used to compute each operator's loss.
        labels: target loss values, index-aligned with `regs`.

    Returns:
        (reduced_loss, pseudo_inputs): the loss averaged over the operators
        and the generator output from the last operator's forward pass.

    Raises:
        ValueError: if `regs` is empty.
    """
    if not regs:
        raise ValueError("regs must contain at least one model operator")

    summed_grads = None
    losses = []
    for idx, reg in enumerate(regs):
        # Record only the forward pass; the gradient is taken after the tape
        # context has closed.
        with tf.GradientTape() as tape:
            pseudo_inputs = unet(train_inputs)
            predictions, _ = reg(pseudo_inputs)
            reg_loss = opt_loss_fn(train_labels, predictions)
            giao_loss = tf.reduce_mean(tf.abs(labels[idx]-reg_loss))

        losses.append(giao_loss)
        # The variable list is read after the forward pass, in case the
        # generator creates its weights lazily on first call.
        grad = tape.gradient(giao_loss, unet.model.trainable_variables)
        if summed_grads is None:
            summed_grads = grad
        else:
            summed_grads = [new + acc for new, acc in zip(grad, summed_grads)]

    # Average over the operators actually used rather than the global
    # GIAO_BATCH, so the step stays correct for any number of operators.
    n_regs = len(regs)
    reduced_grads = [g/n_regs for g in summed_grads]
    giao_optimizer.apply_gradients(zip(reduced_grads, unet.model.trainable_variables))
    reduced_loss = sum(losses)/n_regs
    return reduced_loss, pseudo_inputs

# def _opt_train_step(unet, regs, train_inputs, train_labels, labels):
#     gradients = []
#     losses = []
#     for idx in range(len(regs)):
#         with tf.GradientTape() as tape:
#             pseudo_inputs = unet(train_inputs)
#             # pseudo_inputs = train_inputs
#             _, feat_reg = regs[idx](train_inputs)
#             predictions, feat = regs[idx](pseudo_inputs)
#             reg_loss = opt_loss_fn(train_labels, predictions)
#             # print(reg_loss, labels[idx])
#             giao_loss = giao_loss_fn(labels[idx], reg_loss)
#             feat_loss = tf.norm(feat-feat_reg)/128
#             # giao_loss = tf.math.reduce_sum(labels[idx]-reg_loss)
#             losses.append(giao_loss+feat_loss)
#             # print(giao_loss)
#             grad = tape.gradient(giao_loss+feat_loss, unet.model.trainable_variables)
#             if gradients == []:
#                 gradients = grad
#             else:
#                 gradients = [sg1+sg2 for sg1,sg2 in zip(grad, gradients)]
        
#     reduced_grads = [g/GIAO_BATCH for g in gradients]
#     giao_optimizer.apply_gradients(zip(reduced_grads, unet.model.trainable_variables))
#     reduced_loss = sum(losses)/GIAO_BATCH
#     return reduced_loss, pseudo_inputs

In [None]:
# Generator network trained in the next cell. The commented lines are
# alternative architectures; note that `UNet` is not imported in this notebook.
# unet = UNet(input_shape=[32, 32, 3])
# unet = CUNet(input_shape=[32, 32, 1])
unet = DiffusionUnet(img_channels=1)

In [None]:
# Generator training: for each outer iteration draw a fresh batch whose
# classes match the fixed test batch, then refine it GIAO_EPOCH times, feeding
# the generator's output back in as its next input.
sample_pool, _ = dataloader.load_datadict()

for j in range(5000):
    train_data = sample_from_dict(test_label=test_data["labels"], sample_pool=sample_pool)
    pseudo_inputs = train_data["inputs"]
    for i in range(GIAO_EPOCH):
        # Random subset of operators and their recorded losses for this step.
        idx = random.sample(range(len(model_opt)), GIAO_BATCH)
        labels = [opt_label[k] for k in idx]
        regs = [model_opt[k] for k in idx]
        giao_train_loss, pseudo_inputs = _opt_train_step(
            unet, regs, pseudo_inputs, test_data["labels"], labels
        )
        # Report and visualise only every 10th update.
        if (j*GIAO_EPOCH + i) % 10 != 0:
            continue
        print("Epoch:{}, GIAO_Train_Loss:{}".format(j, giao_train_loss))
        display([train_data["inputs"].numpy()[0], pseudo_inputs.numpy()[0], test_data["inputs"][0]])
        

In [None]:
# Visual check: for the first 10 samples show the sampled input, the
# generator output from the last training step, and the reference test image.
# Relies on `train_data` and `pseudo_inputs` left behind by the training cell.
for i in range(10):
    display([train_data["inputs"].numpy()[i], pseudo_inputs.numpy()[i], test_data["inputs"].numpy()[i]])

In [None]:
# Persist the current model operators (and the global `opt_label` values)
# under ./model_opt.
hard_save_model_opt(model_opt)

In [None]:
# Reload the operators and their recorded losses from ./model_opt, replacing
# the in-memory `model_opt` and `opt_label`.
model_opt, opt_label = load_model_opt()

In [None]:
# Training curves. `records` is only populated when obtain_model_opts() has
# been run in this session.
plot_train([records.train_loss, records.test_loss])
# NOTE(review): the same series is passed twice, so the 'Train' and 'Test'
# curves coincide -- confirm this is intended.
plot_train([opt_label, opt_label])

In [None]:
# Sanity check: push one training batch through a freshly built U-Net and show
# an input next to the corresponding output.
# NOTE: this cell rebinds the notebook-level `train_data` and `unet`; re-run
# the cell that builds the generator before training again.
train_data = iter_train.get_next()

# `UNet` is never imported in this notebook, so the original call raised
# NameError. Use the imported CUNet with the single-channel 32x32 shape that
# the rest of the notebook works with (assumes CUNet returns one tensor --
# TODO confirm against unet.py).
unet = CUNet(input_shape=[32, 32, 1])
output = unet(train_data["inputs"])
print(output.shape)
display([train_data["inputs"].numpy()[0],output.numpy()[0]])