<a href="https://colab.research.google.com/github/ydsyvn/mnist-activation-maximization/blob/main/barebones.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from nn_from_scratch import DeepNeuralNetwork

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): absolute Google Drive path; it presumably only resolves in a
# Colab session with Drive mounted. Adjust it when running elsewhere.
load_path = '/content/drive/MyDrive/Data/Mech Interp/mnist_model_weights.npz'

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open
# until it is closed. Read every array into a plain dict and close the archive
# immediately; the dict still supports the loaded_weights['W1'] style lookups
# used further down in the notebook.
with np.load(load_path) as weights_file:
    loaded_weights = {name: weights_file[name] for name in weights_file.files}

input_size = 784  # For MNIST (28*28)
hidden_size = 128
output_size = 10  # 10 digits

nn = DeepNeuralNetwork(input_size, hidden_size, output_size)

# Assign the loaded weights and biases to the new network instance
nn.W1 = loaded_weights['W1']
nn.b1 = loaded_weights['b1']
nn.W2 = loaded_weights['W2']
nn.b2 = loaded_weights['b2']

print("Model weights and biases loaded successfully.")

Model weights and biases loaded successfully.


In [None]:
import tensorflow as tf
import numpy as np

# Fetch MNIST through Keras (any loader that yields the same four arrays would do)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Convert to float32, divide by 255, then flatten every 28x28 image into one row
X_train = (X_train.astype('float32') / 255.0).reshape(-1, 28*28)
X_test = (X_test.astype('float32') / 255.0).reshape(-1, 28*28)

# Custom one-hot encoding function
def one_hot_encode(y, num_classes=10):
    """Convert integer class labels into a one-hot matrix.

    Args:
        y: Integer labels as an array or sequence. Any shape is accepted and
            flattened first, so a column vector of shape (n, 1) is encoded
            the same way as a flat vector of length n.
        num_classes: Number of columns in the result; every label must lie
            in ``[0, num_classes)``.

    Returns:
        A float array of shape ``(n, num_classes)`` with a single 1 per row.

    Raises:
        IndexError: If any label falls outside ``[0, num_classes)``. Negative
            labels are rejected explicitly because NumPy indexing would
            otherwise wrap them around to the last columns without warning.
    """
    # Flatten so that fancy indexing pairs row i with label i; a (n, 1) input
    # would otherwise broadcast against arange(n) and mark the wrong cells.
    labels = np.asarray(y).reshape(-1)
    encoded = np.zeros((labels.size, num_classes))
    if labels.size == 0:
        return encoded

    if labels.min() < 0 or labels.max() >= num_classes:
        raise IndexError(
            f"labels must be in [0, {num_classes}), "
            f"got values from {labels.min()} to {labels.max()}"
        )

    encoded[np.arange(labels.size), labels] = 1
    return encoded

# Swap both label arrays for their one-hot encoded versions
y_train, y_test = map(one_hot_encode, (y_train, y_test))


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
class ActivationMaximizationNN(DeepNeuralNetwork):
  """DeepNeuralNetwork extended with activation maximization.

  Starting from random noise, an input image is adjusted by gradient ascent
  so that one chosen entry of `z2` grows while the weights stay untouched.
  The parent class is expected to provide `forward`, `relu_derivative`, the
  weights `W1`/`W2`, and to store `z1`/`z2` during `forward`.
  NOTE(review): `z2` is presumably the pre-softmax output of the last layer;
  confirm in nn_from_scratch.
  """

  # Images are handled as 28x28 grids and flattened to a single row of 784
  # values before being passed to `forward`.
  _IMAGE_SHAPE = (28, 28)
  # Number of output units that each get their own maximized image.
  _NUM_CLASSES = 10

  def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01):
    super().__init__(input_size, hidden_size, output_size, learning_rate)

  def generate_random_image(self, mean=0.5, std=0.1):
    """Sample a 28x28 image of Gaussian noise and clip it to [0, 1].

    Uses NumPy's global random state, so seed it beforehand for repeatable
    starting images.
    """
    image = np.random.normal(loc=mean, scale=std, size=self._IMAGE_SHAPE)
    return np.clip(image, 0.0, 1.0)

  def show_images_grid(self, images, title="Activation Maximization"):
    """Display images on a 2x5 grid, one panel per class index.

    Args:
      images: Sequence of arrays, each reshapeable to 28x28. Entry `i` is
        drawn in panel `i` and titled 'Class i'. Panels beyond
        `len(images)` are left blank rather than raising IndexError.
      title: Figure-level title.
    """
    _, axes = plt.subplots(2, 5, figsize=(10, 4))
    for i, ax in enumerate(axes.flat):
      if i < len(images):
        ax.imshow(images[i].reshape(self._IMAGE_SHAPE), cmap='gray')
        ax.set_title(f'Class {i}')
      ax.axis('off')
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

  def activation_max_backward(self, target_class):
    """Return the gradient of `z2[0, target_class]` with respect to the input.

    Reads `self.z1` and `self.z2`, so it must be called after `forward` has
    been run on the image of interest. The upstream gradient is one-hot on
    the target entry of `z2`, i.e. the objective is that single value alone
    and the other output units do not contribute.

    Args:
      target_class: Column index into `z2` of the unit to maximize.

    Returns:
      Array with as many rows as `z2` and as many columns as `W1` has rows.
    """
    d_z2 = np.zeros_like(self.z2)
    d_z2[0, target_class] = 1.0

    # Backpropagate through the network
    d_a1 = np.dot(d_z2, self.W2.T)
    d_z1 = d_a1 * self.relu_derivative(self.z1)
    d_x = np.dot(d_z1, self.W1.T)

    return d_x

  def activation_maximize_class(self, target_class, steps=100, lr=0.01, verbose=True):
    """Run gradient ascent on a random image to raise one entry of `z2`.

    Args:
      target_class: Index of the `z2` entry to maximize.
      steps: Number of ascent iterations.
      lr: Initial step size; multiplied by 0.9 at steps 50, 100, 150, ...
      verbose: Print the activation every 10 steps and on the last step.

    Returns:
      Tuple `(history, image)`. `history[k]` is the activation measured at
      step `k` before that step's update is applied; `image` is the final
      image with shape (1, 784).
    """
    image = self.generate_random_image(mean=0.1, std=0.5)
    history = []

    if verbose:
      print(f"Maximizing class {target_class}")

    for step in range(steps):
      # The forward pass is expected to refresh self.z1 / self.z2 for the
      # current image; its return value is not needed here.
      self.forward(image.reshape(1, -1))
      activation = self.z2[0, target_class]
      history.append(activation)

      # Ascent step, then clamp the pixels back into [0, 1].
      grad = self.activation_max_backward(target_class).reshape(self._IMAGE_SHAPE)
      image = np.clip(image + lr * grad, 0.0, 1.0)

      # Shrink the step size every 50 steps (skipping step 0).
      if step % 50 == 0 and step > 0:
        lr *= 0.9

      if verbose and (step % 10 == 0 or step+1 == steps):
        print(f"Step {step+1}/{steps} | Activation: {activation:.4f}")

    if verbose:
      print()

    return history, image.reshape(1, -1)

  def activation_maximize_all_classes(self, steps=100, lr=0.01, num_attempts=3, visualize=True, verbose=True):
    """Maximize every class several times and keep the best image per class.

    Args:
      steps: Ascent iterations per attempt.
      lr: Initial step size for every attempt.
      num_attempts: Independent random restarts per class; must be >= 1.
      visualize: Show the best images with `show_images_grid` when done.
      verbose: Print per-attempt and per-class progress.

    Returns:
      Tuple `(history, best_images)`. `history[c][a]` is the activation
      trace of attempt `a` for class `c`; `best_images[c]` is the (1, 784)
      image with the highest final activation for class `c`.

    Raises:
      ValueError: If `num_attempts` is less than 1, because no image could
        be produced for any class.
    """
    if num_attempts < 1:
      raise ValueError(f"num_attempts must be >= 1, got {num_attempts}")

    best_images = []
    history = []

    for idx in range(self._NUM_CLASSES):

      best_image = None
      best_activation = -np.inf
      class_history = []

      for attempt in range(num_attempts):
        if verbose and num_attempts > 1:
          print(f"Attempt {attempt+1}/{num_attempts}")

        attempt_history, image = self.activation_maximize_class(
          target_class=idx,
          steps=steps,
          lr=lr,
          verbose=verbose
        )
        class_history.append(attempt_history)

        # The trace stops one update short of the returned image, so score
        # the final image with one more forward pass.
        self.forward(image)
        activation = self.z2[0, idx]

        # `best_image is None` guarantees the first attempt is always kept,
        # even if its activation does not compare greater than -inf.
        if best_image is None or activation > best_activation:
          best_image = image.copy()
          best_activation = activation

        if verbose and num_attempts > 1:
          print(f"Final activation: {activation:.4f}")

      best_images.append(best_image)
      history.append(class_history)

      if verbose:
        print(f"Best activation for class {idx}: {best_activation:.4f}")

    if visualize:
      self.show_images_grid(best_images)

    return history, best_images


In [None]:
def visualize_image(img):
  """Show one image in grayscale with the axes hidden.

  Inputs shaped (784,) or (1, 784) are reshaped to 28x28 first; any other
  shape is handed to `plt.imshow` unchanged.
  """
  flat_shapes = ((784,), (1, 784))
  pixels = img.reshape(28, 28) if img.shape in flat_shapes else img

  plt.imshow(pixels, cmap='gray')
  plt.axis('off')
  plt.show()

In [None]:
# Build the activation-maximization network with the same layer sizes as above
act_max_nn = ActivationMaximizationNN(input_size=784, hidden_size=128, output_size=10)

# Attach the parameters that were read from the weights file
act_max_nn.W1, act_max_nn.b1, act_max_nn.W2, act_max_nn.b2 = (
  loaded_weights[key] for key in ('W1', 'b1', 'W2', 'b2')
)

In [None]:
# Seed NumPy's global RNG so the random starting images (drawn with
# np.random.normal inside generate_random_image) and the results that follow
# are repeatable across Restart & Run All and across re-runs of this cell.
np.random.seed(42)

history, images = act_max_nn.activation_maximize_all_classes(steps=200, lr=0.1, visualize=False)

In [None]:
# Draw the best image found for each class on one grid
act_max_nn.show_images_grid(images=images)

In [None]:
def plot_activation_history_grid(history):
    """Plot per-class activation traces on a 2x5 grid of line charts.

    Args:
        history: Sequence indexed by class; ``history[c]`` is a sequence of
            attempts and each attempt is a sequence of activation values,
            one per optimization step. Only the first 10 classes are drawn.
            If fewer than 10 are supplied, the remaining panels are left
            blank instead of raising IndexError.
    """
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))

    for class_idx, ax in enumerate(axes.flatten()):
        if class_idx >= len(history):
            # No data for this slot: hide the empty axes.
            ax.axis('off')
            continue

        attempts = history[class_idx]
        for attempt_number, attempt in enumerate(attempts, start=1):
            ax.plot(attempt, label=f'Attempt {attempt_number}')
        ax.set_title(f'Class {class_idx}')
        ax.set_xlabel('Step')
        ax.set_ylabel('Activation')
        # A legend without any labelled line only triggers a matplotlib warning.
        if len(attempts) > 0:
            ax.legend(fontsize='small')
        ax.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# One panel per class, one line per optimization attempt
plot_activation_history_grid(history=history)