In [1]:
import os
import re

import numpy as np
import seaborn as sns
import dataframe_image as dfi
import matplotlib.pylab as plt
import pandas as pd


class MLExplain:
    """Compare consecutive per-epoch Keras checkpoints and export the results.

    For every pair of consecutive epochs found in ``checkpoint_dir`` the class
    writes, into ``<export_dir>/epoch<N>_epoch<N+1>/``:

    * ``delta.txt``    - absolute difference of the last layer's kernel weights,
    * ``out.png``      - the same delta rendered as a colour-graded table,
    * ``preds_<N>.txt`` / ``preds_<N+1>.txt`` - predicted class per test sample,
    * ``metrics.txt``  - loss, accuracy and confusion matrix of both epochs.

    TensorFlow is imported lazily inside the methods that need it, so the class
    can be defined before the notebook's TensorFlow import cell has run.
    """

    # Mirrors the original "second token after splitting on '-' or '.'" rule,
    # but only accepts a purely numeric token (e.g. "epoch-01.ckpt" -> "01").
    _EPOCH_PATTERN = re.compile(r'^[^-.]*[-.](\d+)(?:[-.]|$)')
    # Shard suffixes of TensorFlow-format weight checkpoints; load_weights()
    # expects the common prefix, not the individual shard file.
    _SHARD_SUFFIX = re.compile(r'\.(?:index|data-\d+-of-\d+)$')

    def __init__(self,
                 checkpoint_dir: str,
                 test_data,
                 test_labels,
                 export_dir: str = './ml_explain',
                 model_builder=None
                ):
        """Store the configuration and create the export directory.

        Args:
            checkpoint_dir: Directory the per-epoch checkpoints are written to
                and read from.
            test_data: Inputs passed to ``evaluate`` / ``predict``.
            test_labels: Ground-truth labels matching ``test_data``.
            export_dir: Root directory for all exported artefacts.
            model_builder: Zero-argument callable returning a fresh, compiled
                model. When ``None`` the notebook-level ``create_model``
                function is looked up at comparison time.
        """
        self.checkpoint_dir = checkpoint_dir
        self.export_dir = export_dir
        self.test_data = test_data
        self.test_labels = test_labels
        self.model_builder = model_builder
        self.callback = None

        os.makedirs(self.export_dir, exist_ok=True)

    def _find_checkpoints(self) -> dict:
        """Return ``{epoch: checkpoint_path}`` sorted by epoch.

        Directory entries whose name does not carry an epoch number (for
        example ``.DS_Store`` or TensorFlow's ``checkpoint`` index file) are
        ignored instead of raising.
        """
        found = {}
        for name in os.listdir(self.checkpoint_dir):
            match = self._EPOCH_PATTERN.match(name)
            if match is None:
                continue
            prefix = self._SHARD_SUFFIX.sub('', name)
            found[int(match.group(1))] = f"{self.checkpoint_dir}/{prefix}"
        return dict(sorted(found.items()))

    def _load_model(self, weights_path: str):
        """Build a fresh model and load the weights stored at ``weights_path``."""
        builder = self.model_builder
        if builder is None:
            # Fall back to the factory defined at notebook level; resolved
            # only now so the class can be defined before that cell runs.
            builder = create_model
        model = builder()
        model.load_weights(weights_path)
        return model

    def compare_models(self) -> dict:
        """Compare every pair of consecutive epoch checkpoints.

        Returns:
            Mapping of epoch number to checkpoint path for every checkpoint
            discovered in ``checkpoint_dir``.

        Raises:
            FileNotFoundError: If ``checkpoint_dir`` does not exist.
        """
        file_dict = self._find_checkpoints()

        for epoch in file_dict:
            # Only compare when the directly following epoch was saved too.
            if epoch + 1 not in file_dict:
                continue

            _dir = f"{self.export_dir}/epoch{epoch}_epoch{epoch+1}"
            os.makedirs(_dir, exist_ok=True)

            model1 = self._load_model(file_dict[epoch])
            model2 = self._load_model(file_dict[epoch + 1])

            # get_weights()[0] is the kernel of the final layer (bias is [1]).
            model1_weights = model1.layers[-1].get_weights()[0]
            model2_weights = model2.layers[-1].get_weights()[0]
            delta = np.abs(model1_weights - model2_weights)
            np.savetxt(f'{_dir}/delta.txt', delta, delimiter=',')

            self.make_visualization(_dir, delta)
            self.make_preds(model1, model2, _dir, epoch)

        return file_dict

    def _predict_classes(self, model):
        """Return the predicted class index for every sample in ``test_data``."""
        import tensorflow as tf

        proba_model = tf.keras.Sequential(
            [
                model,
                tf.keras.layers.Softmax()
            ])
        return np.argmax(proba_model.predict(self.test_data), axis=1)

    def make_preds(self,
                   model1,
                   model2,
                   out_dir,
                   epoch
                  ):
        """Evaluate two consecutive-epoch models and write their metrics.

        Args:
            model1: Model holding the weights of ``epoch``.
            model2: Model holding the weights of ``epoch + 1``.
            out_dir: Existing directory the result files are written to.
            epoch: Epoch number of ``model1``.
        """
        import tensorflow as tf

        loss1, acc1 = model1.evaluate(self.test_data, self.test_labels, verbose=2)
        loss2, acc2 = model2.evaluate(self.test_data, self.test_labels, verbose=2)

        preds1 = self._predict_classes(model1)
        conf1 = tf.math.confusion_matrix(self.test_labels, preds1)
        np.savetxt(f'{out_dir}/preds_{epoch}.txt', preds1, delimiter=',')

        preds2 = self._predict_classes(model2)
        conf2 = tf.math.confusion_matrix(self.test_labels, preds2)
        # Also persist the second model's predictions; otherwise the final
        # epoch's predictions are never written anywhere.
        np.savetxt(f'{out_dir}/preds_{epoch+1}.txt', preds2, delimiter=',')

        msg = f"""Epoch {epoch}: Loss - {loss1}, Acc - {acc1}
        Epoch {epoch+1}: Loss - {loss2}, Acc - {acc2}
        
        Confusion Matrix {epoch}: \n{conf1}
        
        Confusion Matrix {epoch+1}: \n{conf2}"""

        # The context manager closes the file; no explicit close() is needed.
        with open(f"{out_dir}/metrics.txt", "w") as f:
            f.write(msg)

    def make_visualization(self,
                           filepath: str,
                           delta: np.ndarray,
                           filename = 'out',
                           hexcode = None,
                           to_export = True
                          ):
        """Render ``delta`` as a colour-graded table.

        Args:
            filepath: Directory the PNG is written to when exporting.
            delta: 2-D array of values to visualise.
            filename: PNG file name without extension.
            hexcode: Base colour of the gradient; defaults to ``'#b00707'``.
            to_export: Write ``<filepath>/<filename>.png`` when true, otherwise
                display the table inline via the notebook's ``display``.

        Returns:
            The pandas ``Styler`` that was exported or displayed.
        """
        cm = sns.light_palette(hexcode or '#b00707', as_cmap=True)
        styled = pd.DataFrame(delta).style.background_gradient(cmap=cm)

        if to_export:
            dfi.export(styled, f"{filepath}/{filename}.png")
        else:
            display(styled)

        return styled

    def make_callback(self,
                      monitor: str,
                     ):
        """Create, remember and return a checkpoint callback saving every epoch.

        Args:
            monitor: Name of the metric handed to ``ModelCheckpoint``.

        Returns:
            The ``ModelCheckpoint`` instance, also stored in ``self.callback``.
        """
        import tensorflow as tf

        checkpoint_path = self.checkpoint_dir+"/epoch-{epoch:02d}.ckpt"

        # save_best_only=False writes a checkpoint after every epoch, so
        # `mode` is currently inert; 'auto' lets Keras pick the direction from
        # the metric name should best-only saving ever be enabled.
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor=monitor,
            verbose=1,
            save_best_only=False,
            mode='auto')

        self.callback = cp_callback

        return cp_callback

In [2]:
import tensorflow as tf
from tensorflow import keras

# Show which TensorFlow version this kernel is running.
print(tf.version.VERSION)

2.6.0


In [3]:
# Number of examples kept from each split. A single constant guarantees that
# images and labels are always truncated to the same length.
N_SAMPLES = 1000

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:N_SAMPLES]
test_labels = test_labels[:N_SAMPLES]

# Flatten each 28x28 image into a 784-vector and scale pixel values by 1/255.
train_images = train_images[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0

In [4]:
# Factory for the small dense classifier used throughout this notebook
def create_model():
    """Build and compile a fresh 784 -> 64 -> 10 dense classifier.

    The final layer has no activation, so the loss is configured with
    ``from_logits=True``.
    """
    layer_stack = [
        keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10),
    ]
    network = tf.keras.Sequential(layer_stack)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    network.compile(optimizer='adam', loss=loss_fn, metrics=[accuracy])

    return network

# Build the model instance that is trained below via create_model()
model = create_model()

# Print the layer-by-layer architecture and parameter counts
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 64)                50240     
_________________________________________________________________
dropout (Dropout)            (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________


In [5]:
# Checkpoints are written to and later read from 'training'; the test split is
# what the explainer evaluates each checkpoint on.
ml_exp = MLExplain(checkpoint_dir='training', test_data=test_images, test_labels=test_labels)
# Creates the per-epoch checkpoint callback and stores it on ml_exp.callback,
# which the training cell passes to fit().
ml_exp.make_callback(monitor='val_loss')

<keras.callbacks.ModelCheckpoint at 0x7fd0456b1400>

In [6]:
# Train the model with the checkpoint callback so every epoch is saved.
# Note: validation_data is the same test split that MLExplain evaluates on.
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[ml_exp.callback])  # Pass callback to training

Epoch 1/10

Epoch 00001: saving model to training/epoch-01.ckpt
INFO:tensorflow:Assets written to: training/epoch-01.ckpt/assets
Epoch 2/10

Epoch 00002: saving model to training/epoch-02.ckpt
INFO:tensorflow:Assets written to: training/epoch-02.ckpt/assets
Epoch 3/10

Epoch 00003: saving model to training/epoch-03.ckpt
INFO:tensorflow:Assets written to: training/epoch-03.ckpt/assets
Epoch 4/10

Epoch 00004: saving model to training/epoch-04.ckpt
INFO:tensorflow:Assets written to: training/epoch-04.ckpt/assets
Epoch 5/10

Epoch 00005: saving model to training/epoch-05.ckpt
INFO:tensorflow:Assets written to: training/epoch-05.ckpt/assets
Epoch 6/10

Epoch 00006: saving model to training/epoch-06.ckpt
INFO:tensorflow:Assets written to: training/epoch-06.ckpt/assets
Epoch 7/10

Epoch 00007: saving model to training/epoch-07.ckpt
INFO:tensorflow:Assets written to: training/epoch-07.ckpt/assets
Epoch 8/10

Epoch 00008: saving model to training/epoch-08.ckpt
INFO:tensorflow:Assets written to

<keras.callbacks.History at 0x7fd048d7d580>

In [7]:
# Compare each pair of consecutive epoch checkpoints; the returned
# {epoch: checkpoint path} mapping is displayed as the cell output.
ml_exp.compare_models()

/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
32/32 - 0s - loss: 1.4338 - sparse_categorical_accuracy: 0.6290
32/32 - 0s - loss: 0.9121 - sparse_categorical_accuracy: 0.7380
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Goog

32/32 - 0s - loss: 0.9121 - sparse_categorical_accuracy: 0.7380
32/32 - 0s - loss: 0.7249 - sparse_categorical_accuracy: 0.7950
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
32/32 - 0s - loss: 0.7249 - sparse_categorical_accuracy: 0.7950
32/32 - 0s - loss: 0.6203 - sparse_categorical_accuracy: 0.8290
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Applications/Google Chrome.app/Contents/MacOS/Google Chrome
/Application

{1: 'training/epoch-01.ckpt',
 2: 'training/epoch-02.ckpt',
 3: 'training/epoch-03.ckpt',
 4: 'training/epoch-04.ckpt',
 5: 'training/epoch-05.ckpt',
 6: 'training/epoch-06.ckpt',
 7: 'training/epoch-07.ckpt',
 8: 'training/epoch-08.ckpt',
 9: 'training/epoch-09.ckpt',
 10: 'training/epoch-10.ckpt'}