This is the main code for the paper titled **"Improving Out-of-Distribution Data Handling and Corruption Resistance via Modern Hopfield Networks"**.

In [None]:
from google.colab import drive
drive.mount('/content/drive/')

In [13]:
import sys
sys.path.append('/content/drive/MyDrive/Hopfield/')

In [None]:
!pip install light-the-torch >> /.tmp
!ltt install torch torchvision >> /.tmp
!pip install fastai --upgrade >> /.tmp

In [None]:
cd /content/drive/MyDrive/Hopfield/

In [None]:
# Importing necessary libraries
import torch
import numpy as np
import time
import pickle
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import nn

# Select the compute device, preferring CUDA, then Apple MPS, then the CPU.
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

# A single chained conditional expression encodes the cuda > mps > cpu order.
device = torch.device("cuda" if use_cuda else "mps" if use_mps else "cpu")

print(device)

# Training the HopfieldPooling Layer

In this section, we train the HopfieldPooling layer on the denoising task. To do so, we utilize the official implementation of the HopfieldPooling layer (https://github.com/ml-jku/hopfield-layers). As a result, it is necessary to clone this repository.

In [None]:
!git clone https://github.com/ml-jku/hopfield-layers.git
!pip3 install git+https://github.com/ml-jku/hopfield-layers

In [None]:
!python train_denoising_task.py --save-model

In [None]:
# Restore the trained HopfieldPooling denoiser from its saved checkpoint.
from train_denoising_task import HopfieldModule

# Build the module and place it on the selected device in one step.
hopfieldPooling = HopfieldModule().to(device)
# map_location remaps the stored tensors onto `device` while loading.
hopfieldPooling.load_state_dict(
    torch.load('models/hop_new.pt', map_location=device)
)

In [None]:
# Read back the pickled training log written for the HopfieldPooling denoiser.
with open('logs/hopfield_denoise.pkl', 'rb') as history_file:
    hopfield_denoise_history = pickle.load(history_file)

# Bare trailing expression so the notebook displays the stored "loss" entry.
hopfield_denoise_history["loss"]

# Training the base model

To ensure reproducibility, we train and use the default convolutional neural network from the official PyTorch examples repository:

https://github.com/pytorch/examples/blob/main/mnist/main.py


In [None]:
# Training the baseline model
!python conv_mnist.py --save-model

In [None]:
# Restore the trained baseline CNN from its saved checkpoint.
from conv_mnist import Net

# Build the network and place it on the selected device in one step.
baseline = Net().to(device)
# map_location remaps the stored tensors onto `device` while loading.
baseline.load_state_dict(
    torch.load('models/mnist_cnn.pt', map_location=device)
)

# Loading MNIST-C Test Data

In this section, we load and visualize the MNIST-C dataset. For this purpose, we use the MNIST-C dataset implementation from the TorchUncertainty library, with minor changes to fix a few bugs.

You can find their official repository here:

https://github.com/ENSTA-U2IS-AI/torch-uncertainty/blob/main/torch_uncertainty/datasets/classification/mnist_c.py

In [None]:
import mnist_c

# MNIST-C test split requested with subset="all" (every corruption at once);
# download=True asks the dataset class to fetch the files if they are missing.
test_data_all = mnist_c.MNISTC(
    root=".",
    download=True,
    split="test",
    transform=transforms.ToTensor(),
    subset="all",
)

# Shuffled loader over that dataset, 20 images per batch.
test_loader_all = torch.utils.data.DataLoader(
    test_data_all,
    batch_size=20,
    num_workers=1,
    shuffle=True,
)

In [None]:
def visualize_data(data_loader) -> None:
    """
    Helper method to visualize a sample of data.

    Shows up to four images, with their labels, from the first batch the
    loader yields. The sample is random only if the loader itself shuffles.
    Batches holding fewer than four images are supported: the unused panels
    are simply left blank instead of raising an IndexError.

    :param data_loader: The data loader to pull the samples from; each batch
        must unpack into ``(images, labels)``.
    :return: Nothing.
    """
    # Pull the batch first so a failing loader does not leave an empty figure.
    images, labels = next(iter(data_loader))
    # 2x2 grid; `axes.flat` walks it row by row, matching the original layout.
    fig, axes = plt.subplots(2, 2, figsize=(6, 6))
    for idx, axis in enumerate(axes.flat):
        axis.axis("off")
        if idx >= len(images):
            # Short batch: keep the remaining panels blank.
            continue
        axis.imshow(images[idx].squeeze(), cmap="gray")
        axis.set_title(f"Label: {labels[idx].item()}")
    # Check out the shape of one batch.
    print(f"Shape of a batch images: {images.shape}")
    print(f"Shape of a batch labels: {labels.shape}")

visualize_data(test_loader_all)

# The Integration Algorithm

In this section, we implement our proposed integration algorithm using the pre-trained `hopfieldPooling` module and `baseline` model.



In [None]:
# Names of the MNIST-C subsets evaluated in this notebook; each one is passed
# as the `subset` argument of `mnist_c.MNISTC`. The "identity" entry is the one
# that `calculate_robustness_metrics` leaves out of the corruption averages and
# uses as the reference for the relative error.
mnistc_subsets = [
    "identity",
    "brightness",
    "canny_edges",
    "dotted_line",
    "fog",
    "glass_blur",
    "impulse_noise",
    "motion_blur",
    "rotate",
    "scale",
    "shear",
    "shot_noise",
    "spatter",
    "stripe",
    "translate",
    "zigzag",
]

def test_hop(
    basemodel: nn.Module, hop: None | nn.Module, cdae: None | nn.Module, corruption: str
) -> float | tuple[float, float]:
    """
    Evaluate a classifier on one MNIST-C test subset, optionally integrated
    with a denoising module.

    Without a denoiser, the prediction is the argmax of ``basemodel``'s
    output. With a denoiser, every sample is classified twice -- from the raw
    image and from the denoised image -- and the prediction with the higher
    top score is kept. Ties go to the denoised prediction and count as a use
    of the denoiser.

    Side effects: prints a one-line report and switches ``basemodel`` (and
    the denoiser, if any) to eval mode; they are left in eval mode afterwards.

    :param basemodel: The classifier under test.
    :param hop: Optional denoising module applied before ``basemodel``. Takes
        precedence over ``cdae`` when both are given.
    :param cdae: Optional alternative denoising module, used only when
        ``hop`` is None.
    :param corruption: Subset name forwarded to ``mnist_c.MNISTC``.
    :return: The accuracy as a fraction when no denoiser is given; otherwise
        the tuple ``(accuracy, usage)``, where ``usage`` is the fraction of
        samples whose final prediction came from the denoised image.
    """
    test_data = mnist_c.MNISTC(root=".", split = "test", transform=transforms.ToTensor(), subset = corruption)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=20, num_workers=1, shuffle = True)

    denoiser = hop if hop is not None else cdae
    integrated = denoiser is not None

    basemodel.eval()
    if integrated:
        # The denoiser must be in inference mode too; otherwise any
        # train-time-only layers it contains would make the evaluation
        # stochastic.
        denoiser.eval()

    correct = 0
    number_use = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            raw_output = basemodel(data)

            if integrated:
                denoised_output = basemodel(denoiser(data))
                # Top score and class index per sample (scores are presumably
                # log-probabilities, per the original comment).
                raw_score, raw_pred = raw_output.max(dim=1, keepdim=True)
                den_score, den_pred = denoised_output.max(dim=1, keepdim=True)
                keep_raw = raw_score > den_score
                pred = torch.where(keep_raw, raw_pred, den_pred)
                number_use += (~keep_raw).sum().item()
            else:
                pred = raw_output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    total = len(test_loader.dataset)

    if not integrated:
        mode = "not-integrated"
    elif hop is not None:
        mode = "Hopfield-integrated"
    else:
        mode = "CDAE-integrated"

    usage_text = (
        ", Percentage of Use: ({}/{}:{:0.2f})".format(
            number_use, total, number_use / total * 100
        )
        if integrated
        else ""
    )

    print(
        "\nCorruption: {}, {}{} -> Accuracy: {}/{} ({:.2f}%)\n".format(
            corruption,
            mode,
            usage_text,
            correct,
            total,
            100.0 * correct / total,
        )
    )

    if integrated:
        return correct / total, number_use / total
    return correct / total

In [None]:
# Evaluate every MNIST-C subset twice: the plain baseline first, then the
# baseline integrated with the HopfieldPooling module.
acc, acc_hop, hop_usage = {}, {}, {}

for sub in mnistc_subsets:
    acc[sub] = test_hop(baseline, hop=None, cdae=None, corruption=sub)
    # With a denoiser, test_hop returns an (accuracy, usage) pair.
    hop_result = test_hop(baseline, hop=hopfieldPooling, cdae=None, corruption=sub)
    acc_hop[sub], hop_usage[sub] = hop_result

In [None]:
# Persist the baseline and Hopfield-integrated results as pickles under logs/.
for log_path, log_payload in (
    ('logs/acc.pkl', acc),
    ('logs/acc_hop.pkl', acc_hop),
    ('logs/hop_usage.pkl', hop_usage),
):
    with open(log_path, 'wb') as f:
        pickle.dump(log_payload, f)

In [None]:
# Calculate corruption robustness metrics
def calculate_robustness_metrics(acc, acc_integrated):
    """
    Summarise the corruption robustness of an integrated model against the
    baseline.

    Both arguments map subset names to accuracies given as fractions. The
    "identity" entry is left out of every average and serves as the reference
    for the relative error, so it must be present in both mappings.

    The mCE values are the ratio of the integrated model's summed error to
    the baseline's summed error, times 100. For the relative variant, each
    model's own "identity" error is subtracted from every term first.

    :param acc: Per-subset accuracies of the baseline model.
    :param acc_integrated: Per-subset accuracies of the integrated model.
    :return: The tuple ``(baseline, integrated)``; each is a dict with the
        keys "corruption_accuracy" (mean accuracy in percent), "relative mCE"
        and "mCE". The baseline's two mCE entries are fixed at 100.
    """

    def mean_corruption_accuracy(scores):
        # Scale each accuracy to percent before averaging.
        percentages = [
            value * 100 for name, value in scores.items() if name != "identity"
        ]
        return np.mean(np.array(percentages))

    def summed_error(scores, relative):
        # Plain running total, accumulated in the mapping's iteration order.
        total = 0
        for name in scores.keys():
            if name == "identity":
                continue
            reference = (1 - scores["identity"]) if relative else 0
            total += (1 - scores[name]) - reference
        return total

    def error_ratio(relative):
        denominator = summed_error(acc, relative)
        numerator = summed_error(acc_integrated, relative)
        return numerator / denominator * 100

    baseline_summary = {
        "corruption_accuracy": mean_corruption_accuracy(acc),
        "relative mCE": 100,
        "mCE": 100,
    }
    integrated_summary = {
        "corruption_accuracy": mean_corruption_accuracy(acc_integrated),
    }
    integrated_summary["relative mCE"] = error_ratio(True)
    integrated_summary["mCE"] = error_ratio(False)

    return baseline_summary, integrated_summary

In [None]:
baseline_metrics, hopfield_integration_metrics = calculate_robustness_metrics(acc, acc_hop)

In [None]:
baseline_metrics

In [None]:
hopfield_integration_metrics

# Ablation Study

In this part, we replace the HopfieldPooling layer with a stacked Convolutional Denoising Autoencoder (CDAE) and compare the results.

### Denoising Task

In [None]:
!python train_denoising_task.py --cdae --save-model

In [None]:
# Restore the trained convolutional denoising autoencoder from its checkpoint.
from train_denoising_task import CDAE

# Build the module and place it on the selected device in one step.
cdae = CDAE().to(device)
# map_location remaps the stored tensors onto `device` while loading.
cdae.load_state_dict(
    torch.load('models/CDAE.pt', map_location=device)
)

In [None]:
# Read back the pickled training log written for the CDAE denoiser.
with open('logs/CDAE_denoise.pkl', 'rb') as history_file:
    cdae_denoise_history = pickle.load(history_file)

# Bare trailing expression so the notebook displays the stored "loss" entry.
cdae_denoise_history["loss"]

In [None]:
# Compare HopfieldPooling layer and CDAE in terms of MSE for the denoising task
hopfield_losses = hopfield_denoise_history["loss"]
cdae_losses = cdae_denoise_history["loss"]
# Derive the tick count from the recorded histories instead of assuming
# exactly 20 epochs, so the axis stays correct if training length changes.
num_epochs = max(len(hopfield_losses), len(cdae_losses))

plt.figure(figsize=(15, 5), dpi = 200)
plt.plot(hopfield_losses, marker = "o", label = "Hopfield")
plt.plot(cdae_losses, marker = "^", label = "Autoencoder")
# Points are plotted at 0-based positions; label them as 1-based epochs.
plt.xticks(range(num_epochs), labels=[f"{i}" for i in range(1, num_epochs + 1)])
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.show()

### Integration Algorithm

In [None]:
# Evaluate every MNIST-C subset with the baseline integrated with the CDAE.
acc_AE, AE_usage = {}, {}

for sub in mnistc_subsets:
    # With a denoiser, test_hop returns an (accuracy, usage) pair.
    cdae_result = test_hop(baseline, hop=None, cdae=cdae, corruption=sub)
    acc_AE[sub], AE_usage[sub] = cdae_result

In [None]:
# Persist the CDAE-integrated results as pickles.
# NOTE(review): these go to the working directory, whereas the Hopfield
# results are written under logs/ -- confirm this is intended.
for log_path, log_payload in (
    ('acc_AE.pkl', acc_AE),
    ('AE_usage.pkl', AE_usage),
):
    with open(log_path, 'wb') as f:
        pickle.dump(log_payload, f)

In [None]:
baseline_metrics, cdae_integration_metrics = calculate_robustness_metrics(acc, acc_AE)

In [None]:
baseline_metrics

In [None]:
cdae_integration_metrics

### Visualize Output of Modules

In [None]:
batch_size = 20
corruption = "fog"

test_data = mnist_c.MNISTC(root=".", split = "test", transform=transforms.ToTensor(), subset = corruption)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=1, shuffle = True)

# obtain one batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)

images, labels = images.to(device), labels.to(device)

# Pure inference: put both modules in eval mode and skip autograd bookkeeping.
hopfieldPooling.eval()
cdae.eval()
with torch.no_grad():
    output1 = hopfieldPooling(images)
    output2 = cdae(images)

# Reshape the outputs into image batches; use the real batch length so a
# short batch does not break the reshape.
num_images = images.size(0)
output1 = output1.view(num_images, 1, 28, 28)
output2 = output2.view(num_images, 1, 28, 28)
# Move to host memory as NumPy arrays for plotting.
output1 = output1.cpu().detach().numpy()
output2 = output2.cpu().detach().numpy()

# plot the first seven input images and then reconstructed images
fig, axes = plt.subplots(nrows=3, ncols=7, sharex=True, sharey=True, figsize=(10,4), dpi = 200)

y_labels = ["Corrupted", "Hopfield", "Autoencoder"]
# Row 0: corrupted inputs; row 1: HopfieldPooling output; row 2: CDAE output.
# zip() stops at the seven axes of each row, so only seven images are drawn.
row_batches = [images.cpu(), output1, output2]
for r, (row_images, row_axes) in enumerate(zip(row_batches, axes)):
    for c, (img, ax) in enumerate(zip(row_images, row_axes)):
        if r == 0:
            # Only the top row carries the ground-truth label as a title.
            ax.set_title(f"True label: {labels[c].item()}")
        ax.imshow(np.squeeze(img), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])
        if c == 0:
            # Only the first column names the row.
            ax.set_ylabel(y_labels[r])