In [None]:
import ast
import json
import os
from glob import glob
from os.path import join
from typing import Callable

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from ipywidgets import interact
from matplotlib.offsetbox import OffsetImage
from PIL import Image

from img_plots import make_badge_scatter

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Input, Flatten, Dense
from tensorflow.keras.layers import Conv2D, MaxPool2D, BatchNormalization, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [None]:
# Render matplotlib figures inline under each cell (already the default in recent Jupyter; kept explicit).
%matplotlib inline

In [None]:
# Root folder for every artefact this notebook reads (input joblib/CSV) or writes (resized cache).
DATA_FOLDER = "archive"

Load the data produced by the previous stage (the initial front/back clustering view).

In [None]:
# Inputs written by the previous (front/back clustering) stage.
LOAD_PATH = join(DATA_FOLDER, 'cropped_edge_imgs_order_fix.joblib')
LOAD_DF_INFO_PATH = join(DATA_FOLDER, 'cropped_edge_imgs_info.csv')

# `data` is a mapping of image key -> image array (the resize step iterates
# `data.items()` and groups keys by the text before the first '_').
# joblib.load unpickles, which can execute arbitrary code: only load files
# this pipeline produced itself.
with open(LOAD_PATH, 'rb') as pickled_file:
    data = joblib.load(pickled_file)

# Per-class info table; the resize step reads its 'class' and 'front_back_keys' columns.
cat_name_df = pd.read_csv(LOAD_DF_INFO_PATH)

# Resize each class's images to a fixed shape (the class's median height and width)

In [None]:
FORCE_LOAD = False  # True -> recompute the resized arrays even if SAVE_PATH exists
IS_SAVE = False     # True -> write freshly computed arrays to SAVE_PATH
SAVE_PATH = join(DATA_FOLDER, 'resized_data_arr.joblib')


def _parse_front_back_keys(raw_keys):
    """Turn one ``front_back_keys`` cell into a container for exact membership tests.

    ``pd.read_csv`` hands list-like cells back as strings, and ``key in some_str``
    is a *substring* test (so '3_1' would match inside "['3_10']").
    NOTE(review): the string is assumed to be a Python/JSON list literal such as
    "['3_1', '3_7']" -- confirm against the stage that wrote the CSV. Strings that
    do not parse to a list/tuple/set are returned unchanged, so behaviour falls
    back to the old substring check rather than crashing.

    Returns:
        A set of key strings, or the raw string when it cannot be parsed.
    """
    if isinstance(raw_keys, (list, tuple, set, frozenset, np.ndarray)):
        return {str(key) for key in raw_keys}
    if isinstance(raw_keys, str):
        try:
            parsed = ast.literal_eval(raw_keys)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, (list, tuple, set, frozenset)):
            return {str(key) for key in parsed}
        return raw_keys
    if pd.isna(raw_keys):
        # Empty CSV cell: nothing to exclude for this class.
        return set()
    return {str(raw_keys)}


def _resize_class_images(images, class_name):
    """Resize all images of one class to the class's median size and stack them.

    Args:
        images: list of image arrays for one class (inputs to ``cv2.resize``).
        class_name: class label, used only in the error message.

    Returns:
        int32 array with a leading sample axis and a trailing axis of size 1,
        i.e. (n_images, med_y, med_x, 1) for 2-D grayscale inputs.

    Raises:
        ValueError: if ``images`` is empty (the median size would be undefined).
    """
    if len(images) == 0:
        raise ValueError(
            f'class {class_name!r} has no images (after excluding front_back_keys)'
        )
    med_y = int(np.median([img.shape[0] for img in images]))
    med_x = int(np.median([img.shape[1] for img in images]))
    # cv2.resize takes dsize as (width, height).
    resized = [cv2.resize(img, (med_x, med_y)).astype(np.int32) for img in images]
    return np.stack(resized, axis=0)[..., np.newaxis]


if os.path.exists(SAVE_PATH) and not FORCE_LOAD:
    print('loading from file...')
    # NOTE: a cache written before the front_back_keys parsing fix may contain
    # wrongly excluded/included images -- set FORCE_LOAD = True once to rebuild.
    # joblib.load unpickles: only load caches this pipeline wrote itself.
    with open(SAVE_PATH, 'rb') as fin:
        data_arr = joblib.load(fin)
else:
    data_arr = {}
    for _, iter_row in cat_name_df.iterrows():
        iter_class = str(iter_row['class'])
        front_back_keys = _parse_front_back_keys(iter_row['front_back_keys'])

        # Keys look like '<class>_<...>'; keep this class's images minus the
        # front/back ones flagged by the previous stage.
        iter_elements = [
            el for k, el in data.items()
            if k.split('_')[0] == iter_class
            and k not in front_back_keys
        ]
        data_arr[iter_class] = _resize_class_images(iter_elements, iter_class)

    # Only freshly computed arrays need saving; a loaded cache is already on disk.
    if IS_SAVE:
        with open(SAVE_PATH, 'wb') as fout:
            joblib.dump(data_arr, fout)

In [None]:
class IICDataPairGenerator(Sequence):
    """Keras ``Sequence`` yielding (original, transformed) image pairs for IIC.

    Each batch stacks ``batch_size`` original samples on top of a randomly
    transformed copy of the same samples in the same order, giving
    ``2 * batch_size`` rows. ``MILoss`` relies on this layout (first half = Z,
    second half = Zbar). The batch is returned as both inputs and targets; the
    targets are ignored by ``MILoss``.

    Args:
        sample_X: samples indexed along axis 0. They are passed to
            ``ImageDataGenerator.flow``, which expects rank-4 (N, H, W, C) data.
        transform_generator: augmentation used to produce the paired copies.
        batch_size: number of original samples per batch (must be >= 1).
    """

    def __init__(self, sample_X,
                 transform_generator: ImageDataGenerator,
                 batch_size: int) -> None:
        # Keras 3 (PyDataset-backed Sequence) requires this call; harmless on
        # older tf.keras where Sequence defines no __init__.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.batch_size = batch_size
        self.sample_X = sample_X
        self.transform_generator = transform_generator

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, X)`` with X = [originals; transformed originals].

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        n_batches = len(self)
        if not 0 <= index < n_batches:
            raise IndexError(f'batch index {index} out of range for {n_batches} batches')
        start_index = index * self.batch_size
        end_index = start_index + self.batch_size
        originals = self.sample_X[start_index:end_index]

        # A single flow() call draws an independent random transform for every
        # sample; shuffle=False keeps transformed row i paired with original row i.
        X_transformed = next(
            self.transform_generator.flow(
                originals, batch_size=len(originals), shuffle=False
            )
        )
        result_X = np.concatenate([originals, X_transformed])
        return result_X, result_X

    def __len__(self):
        """Number of full batches per epoch.

        A trailing partial batch is dropped on purpose: ``MILoss`` splits
        ``y_pred`` assuming exactly ``batch_size`` rows per half.
        """
        return self.sample_X.shape[0] // self.batch_size

In [None]:
class MILoss(tf.keras.losses.Loss):
    """Negative mutual-information loss for Invariant Information Clustering.

    Expects ``y_pred`` laid out as produced by ``IICDataPairGenerator``: the first
    ``batch_size`` rows are the head's predictions Z for the original samples, the
    remaining rows are predictions Zbar for their transformed copies. ``y_true``
    is ignored.

    Args:
        batch_size: number of original samples per batch (size of the Z half).
        n_heads: number of output heads; each head contributes 1/n_heads of the
            total loss.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``).
    """

    def __init__(self, batch_size: int, n_heads: int, **kwargs) -> None:
        self.batch_size = batch_size
        self.n_heads = n_heads
        super().__init__(**kwargs)

    def get_config(self):
        """Include the constructor arguments so a saved model can be reloaded."""
        config = super().get_config()
        config.update({'batch_size': self.batch_size, 'n_heads': self.n_heads})
        return config

    def call(self, y_true, y_pred):
        """Return -MI(Z; Zbar) / n_heads as a scalar tensor."""
        size = self.batch_size
        n_labels = y_pred.shape[-1]
        # first `size` rows are Z, reshaped to column vectors (size, n_labels, 1)
        Z = K.expand_dims(y_pred[:size, :], axis=2)
        # remaining rows are Zbar, reshaped to row vectors (size, 1, n_labels)
        Zbar = K.expand_dims(y_pred[size:, :], axis=1)
        # compute joint distribution (Eq 10.3.2 & .3)
        P = K.batch_dot(Z, Zbar)
        P = K.sum(P, axis=0)
        # enforce symmetric joint distribution (Eq 10.3.4)
        P = (P + K.transpose(P)) / 2.0
        # normalization of total probability to 1.0
        P = P / K.sum(P)
        # marginal distributions (Eq 10.3.5 & .6)
        Pi = K.expand_dims(K.sum(P, axis=1), axis=1)
        Pj = K.expand_dims(K.sum(P, axis=0), axis=0)
        Pi = K.repeat_elements(Pi, rep=n_labels, axis=1)
        Pj = K.repeat_elements(Pj, rep=n_labels, axis=0)
        # clip away zeros so the logs below stay finite
        P = K.clip(P, K.epsilon(), np.finfo(float).max)
        Pi = K.clip(Pi, K.epsilon(), np.finfo(float).max)
        Pj = K.clip(Pj, K.epsilon(), np.finfo(float).max)
        # negative MI loss (Eq 10.3.7)
        neg_mi = K.sum((P * (K.log(Pi) + K.log(Pj) - K.log(P))))
        # each head contribute 1/n_heads to the total loss
        return neg_mi / self.n_heads

In [None]:
class IIC:
    """Thin wrapper that builds, compiles and trains a multi-head IIC encoder.

    Args:
        iic_data_gen: pair generator (e.g. ``IICDataPairGenerator``); its
            ``sample_X`` fixes the input shape and its ``batch_size`` is passed
            to ``MILoss``.
        n_heads: number of softmax-style output heads.
        n_labels: number of clusters (units per head).
        backbone: Keras model/layer applied to the input before the heads.
    """

    def __init__(self,
                 iic_data_gen,
                 n_heads,
                 n_labels,
                 backbone):
        self.backbone = backbone
        self._model = None
        self.train_gen = iic_data_gen
        self.n_labels = n_labels
        self.n_heads = n_heads

    def _require_model(self):
        """Return the built Keras model or fail with a clear message."""
        if self._model is None:
            raise RuntimeError('IIC model is not built yet; call build_model() first')
        return self._model

    def build_model(self, **dense_kwargs):
        """Build backbone -> Flatten -> ``n_heads`` Dense heads of ``n_labels`` units.

        Args:
            **dense_kwargs: forwarded to every head's ``Dense`` layer
                (e.g. ``activation``, ``kernel_regularizer``).
        """
        inputs = Input(shape=self.train_gen.sample_X.shape[1:], name='x')
        x = self.backbone(inputs)
        x = Flatten()(x)
        outputs = [
            Dense(self.n_labels, name="z_head%d" % i, **dense_kwargs)(x)
            for i in range(self.n_heads)
        ]
        self._model = Model(inputs, outputs, name='encoder')

    def compile(self, **compile_kwargs):
        """Compile with ``MILoss``; ``compile_kwargs`` go to ``Model.compile``.

        Raises:
            RuntimeError: if ``build_model`` has not been called.
        """
        model = self._require_model()
        mi_loss = MILoss(batch_size=self.train_gen.batch_size,
                         n_heads=self.n_heads)
        model.compile(loss=mi_loss, **compile_kwargs)

    def train(self, **fit_kwargs):
        """Fit on the pair generator and return the Keras ``History`` object.

        Raises:
            RuntimeError: if ``build_model`` has not been called.
        """
        return self._require_model().fit(x=self.train_gen, **fit_kwargs)

    @property
    def model(self):
        """The underlying Keras model (``None`` until ``build_model`` is called)."""
        return self._model

In [None]:
# Reproducibility: NumPy's global RNG drives the in-place shuffle below and the
# random rotations drawn by ImageDataGenerator; TensorFlow's seeds weight init.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

CLASS_ID = '3'         # class whose images are clustered
ROTATION_RANGE = 50    # max random rotation (degrees) for the transformed views
BATCH_SIZE = 10        # originals per batch; each batch holds 2 * BATCH_SIZE rows
N_HEADS = 1
N_LABELS = 2           # the results cell assumes exactly 2 labels ('0_prob', '1_prob')
LEARNING_RATE = 1e-3

rotate_image_generator = ImageDataGenerator(
    rotation_range=ROTATION_RANGE
)

# NOTE: data_to_generator is the *same array object* as data_arr[CLASS_ID], so
# this shuffles data_arr[CLASS_ID] in place. Later cells rely on that alignment
# (they predict on data_arr['3'] and fetch thumbnails via data_to_generator).
# Re-running this cell reshuffles the already-shuffled array.
data_to_generator = data_arr[CLASS_ID]
np.random.shuffle(data_to_generator)
iic_paired_gen = IICDataPairGenerator(data_to_generator,
                                      rotate_image_generator, BATCH_SIZE)

backbone = tf.keras.models.Sequential([
    Input(iic_paired_gen.sample_X.shape[1:]),
    Conv2D(filters=4, kernel_size=(16, 16)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(20, 20)),
    Conv2D(filters=16, kernel_size=(3, 3), kernel_initializer='he_normal'),
    BatchNormalization(),
    Activation('relu')
])

iic_model = IIC(iic_paired_gen, n_heads=N_HEADS, n_labels=N_LABELS, backbone=backbone)
iic_model.build_model(activation='softmax', kernel_regularizer=tf.keras.regularizers.l1(1e-3))
optimizer = Adam(learning_rate=LEARNING_RATE)
iic_model.compile(optimizer=optimizer)
iic_model.model.summary()

In [None]:
# Train for 50 epochs over the paired generator built above
# (IIC.train forwards its keyword arguments to Keras Model.fit).
iic_model.train(epochs=50)

In [None]:
# Cluster probabilities for every sample of class '3'. data_arr['3'] was
# shuffled in place when the generator was built, so row i lines up with
# data_to_generator[i] (used for the thumbnails below).
results_prob = pd.DataFrame(
    iic_model.model.predict(data_arr['3']),
    columns=['0_prob', '1_prob'],
)
results_prob['keys'] = results_prob.index

# Scatter layout: hard label +1 (0_prob <= 0.5) / -1 otherwise, then random
# jitter scaled by RANDOM_COEF pushes each cluster away from the origin.
RANDOM_COEF = 15
results_prob['final_label'] = results_prob['0_prob'].le(0.5).map({True: 1, False: -1})
results_prob['x'] = (
    results_prob['final_label']
    + results_prob['final_label'] * np.random.random(size=len(results_prob)) * RANDOM_COEF
)
# NOTE(review): y is offset from x (not from final_label), so points fall in a
# diagonal band rather than a square -- confirm this is intended.
results_prob['y'] = (
    results_prob['x']
    + results_prob['final_label'] * np.random.random(size=len(results_prob)) * RANDOM_COEF
)

In [None]:
# One thumbnail per point: each row's 'keys' value indexes data_to_generator;
# channel 0 of that (H, W, 1) image is drawn at 10 % zoom.
make_badge_scatter(
    df=results_prob,
    img_provider=lambda key: OffsetImage(
        data_to_generator[int(key)][..., 0],
        zoom=0.1,
        alpha=1,
    ),
)

In [None]:
# data_to_generator[0] is an (H, W, 1) grayscale crop: drop the channel axis
# and use a gray colormap so it is not rendered with the default false colours.
fig, ax = plt.subplots()
ax.imshow(data_to_generator[0][:, :, 0], cmap='gray')
ax.set_title('data_to_generator[0] (first shuffled sample)')
ax.axis('off');