In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc, classification_report
from sklearn.preprocessing import label_binarize
import seaborn as sns

# Load data and preprocess it
def load_data():
    """Load Fashion-MNIST and prepare it for the model.

    Images are divided by 255.0 and given a trailing channel axis; labels
    are one-hot encoded over 10 classes.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    num_classes = 10

    def _prepare(images, labels):
        # Scale the pixel values, then append the channel axis.
        scaled = (images / 255.0)[..., tf.newaxis]
        one_hot = tf.keras.utils.to_categorical(labels, num_classes)
        return scaled, one_hot

    train_split, test_split = tf.keras.datasets.fashion_mnist.load_data()
    x_train, y_train = _prepare(*train_split)
    x_test, y_test = _prepare(*test_split)
    return x_train, y_train, x_test, y_test

# Define a simple Capsule Layer
def CapsuleLayer(inputs, num_capsule, dim_capsule, routings):
    """Project ``inputs`` into capsule vectors and return their lengths.

    A Dense layer with ``num_capsule * dim_capsule`` units is reshaped to
    ``(num_capsule, dim_capsule)``; the result is the square root of the sum
    of squares over the last axis, i.e. one value per capsule.

    Args:
        inputs: Keras tensor fed to the Dense projection.
        num_capsule: Number of capsules to produce.
        dim_capsule: Size of each capsule vector.
        routings: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        Keras tensor holding one length value per capsule.
    """
    projected = layers.Dense(num_capsule * dim_capsule)(inputs)
    capsules = layers.Reshape((num_capsule, dim_capsule))(projected)
    # Kept as an inline lambda wrapped in a Lambda layer, as in the original.
    vector_length = layers.Lambda(
        lambda s: tf.sqrt(tf.reduce_sum(tf.square(s), axis=-1))
    )
    return vector_length(capsules)

# Build the CapsNet model
def build_capsnet(input_shape):
    """Assemble the classifier: two conv layers, capsule lengths, softmax.

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.

    Returns:
        An uncompiled Keras ``Model`` with a 10-unit softmax output.
    """
    inputs = layers.Input(shape=input_shape)

    # Two stacked 3x3 convolutions with 64 filters and ReLU activation.
    features = inputs
    for _ in range(2):
        features = layers.Conv2D(64, (3, 3), activation='relu')(features)

    # Flatten everything except the batch axis before the capsule projection.
    flat = layers.Flatten()(features)
    capsule_lengths = CapsuleLayer(flat, num_capsule=10, dim_capsule=16, routings=3)

    probabilities = layers.Dense(10, activation='softmax')(capsule_lengths)
    return models.Model(inputs=inputs, outputs=probabilities)

# Train the model
def train_model(model, x_train, y_train, x_test, y_test, epochs=10, batch_size=64):
    """Compile ``model`` and fit it, validating on the given test split.

    The model is compiled with the Adam optimizer, categorical cross-entropy
    loss and an accuracy metric every time this function is called.

    Args:
        model: Object exposing Keras-style ``compile`` and ``fit`` methods.
        x_train: Training inputs.
        y_train: Training targets.
        x_test: Inputs passed to ``fit`` as validation data.
        y_test: Targets passed to ``fit`` as validation data.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.

    Returns:
        Whatever ``model.fit`` returns (the training history for Keras models).
    """
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    fit_options = {
        'validation_data': (x_test, y_test),
        'epochs': epochs,
        'batch_size': batch_size,
    }
    return model.fit(x_train, y_train, **fit_options)

# Load Fashion-MNIST via load_data(): pixel values divided by 255, one-hot labels.
x_train, y_train, x_test, y_test = load_data()

# Build the model and train it on the first 30,000 training samples only.
# The test split is passed to train_model() as validation data.
print("Training initial model...")
model = build_capsnet(x_train.shape[1:])
initial_history = train_model(model, x_train[:30000], y_train[:30000], x_test, y_test)

# Save the initially trained model to an .h5 file in the working directory.
model.save('pretrained_capsnet.h5')

# Reload the saved model from disk and continue training it ("transfer learning").
# NOTE(review): the saved model contains a Lambda layer wrapping a Python lambda
# (see CapsuleLayer); Keras documents serialization limitations for such layers,
# so confirm that load_model works on the target Keras version.
print("Applying transfer learning...")
model = tf.keras.models.load_model('pretrained_capsnet.h5')

# Freeze every layer except the final one so only the last layer keeps training.
for layer in model.layers[:-1]:  # Freeze all but the last layer
    layer.trainable = False

# Retrain on the full training set. train_model() calls compile() again before
# fit(), i.e. after the trainable flags above were changed.
transfer_history = train_model(model, x_train, y_train, x_test, y_test)

# Function to plot confusion matrix
def plot_confusion_matrix(y_true, y_pred):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    counts = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 10))
    # Integer annotations in every cell; rows are actual, columns predicted.
    axes = sns.heatmap(counts, annot=True, fmt="d", linewidths=.5)
    axes.set_title('Confusion Matrix')
    axes.set_ylabel('Actual Label')
    axes.set_xlabel('Predicted Label')
    plt.show()

# Function to plot ROC curve
def plot_roc_curve(y_test, y_pred):
    """Plot one-vs-rest ROC curves, one per class, on a single figure.

    Args:
        y_test: Ground-truth labels, binarized here with ``label_binarize``
            over the classes 0 through 9.
        y_pred: 2-D array of per-class scores; column ``i`` is used as the
            score for class ``i``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # One indicator column per class.
    truth = label_binarize(y_test, classes=list(range(10)))
    n_classes = truth.shape[1]

    # (false-positive rates, true-positive rates, area under curve) per class.
    curves = []
    for k in range(n_classes):
        false_pos, true_pos, _ = roc_curve(truth[:, k], y_pred[:, k])
        curves.append((false_pos, true_pos, auc(false_pos, true_pos)))

    plt.figure(figsize=(8, 6))
    palette = plt.cm.rainbow(np.linspace(0, 1, n_classes))
    for k, ((false_pos, true_pos, area), color) in enumerate(zip(curves, palette)):
        plt.plot(false_pos, true_pos, color=color, lw=2,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(k, area))

    # Dashed diagonal as the chance-level reference line.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for multi-class')
    plt.legend(loc="lower right")
    plt.show()

# Evaluate the retrained model on the test split.
print("Evaluating model...")
y_pred = model.predict(x_test)  # per-class scores from the softmax output layer
# Reduce predicted scores and one-hot ground truth to integer class indices.
y_pred_labels = np.argmax(y_pred, axis=1)
y_true_labels = np.argmax(y_test, axis=1)
plot_confusion_matrix(y_true_labels, y_pred_labels)
# The ROC plot takes the one-hot y_test and the raw scores, not the argmax labels.
plot_roc_curve(y_test, y_pred)
print(classification_report(y_true_labels, y_pred_labels))
