In [2]:
# @title Imports

import numpy as np

# Scikit-Learn for machine learning utilities
from sklearn.decomposition import PCA
from sklearn import manifold

# --- Plotting tools
import seaborn as sns
import matplotlib.pyplot as plt

# --- Torch tools for the RNN
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [31]:
# @title 3. RNN Model Definition {display-mode: "form"}
class MemoryRNN(nn.Module):
    def __init__(self, input_dim, hidden_dim=100, num_classes=10, rnn_type="RNN"):
        super().__init__()
        if rnn_type == "LSTM":
            self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        elif rnn_type == "GRU":
            self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        else:
            self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim, num_classes)
    def forward(self, x, return_seq=True):
        h_seq, _ = self.rnn(x)
        dropout = self.dropout(h_seq[:, -1])
        out = self.fc(dropout)
        if return_seq:
            return out, h_seq
        return out

In [3]:
class MemoryDataset(Dataset):
    def __init__(self, X, y, noise=None, noise_std=0.05):
        self.X = X
        self.y = y
        self.noise = noise
        self.noise_std = noise_std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx]
        y = self.y[idx]

        if self.noise == "input":
            x = x + self.noise_std * torch.randn_like(x)

        return x, y

In [40]:
# @title Training Function (clean / input‑noise / weight‑noise) {display-mode: "form"}

def train_rnn(X, y, batch_size, variant, learning_rate, epochs=1000, noise_std=0.05):
    ds = MemoryDataset(X, y, noise="input" if variant=="input_noise" else None,
                       noise_std=noise_std)
    dl = DataLoader(ds, batch_size, shuffle=True, drop_last=False)
    print(len(dl))
    print(X.shape[1])

    model = MemoryRNN(X.shape[-1], rnn_type='GRU').to(device)
    opt   = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lossf = nn.CrossEntropyLoss()
    print(f" Training variant: {variant}")

    losses = [] # List to store loss at each epoch
    for ep in range(epochs):
        running = 0
        for xb, yb in dl:
            xb, yb = xb.to(device), yb.to(device)

            # forward
            pred, _ = model(xb)
            #yb_class_idx = torch.argmax(yb, dim=0).to(torch.float32)
            loss = lossf(pred, yb)

            # weight noise variant
            if variant == "weight_noise":
                for p in model.parameters():
                    p.data += noise_std * torch.randn_like(p)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            running += loss.item()
        epoch_loss = running / len(dl)
        losses.append(epoch_loss) # Store the loss
        print(f"  Epoch {ep+1}/{epochs} | loss={epoch_loss:.4f}")
    return model, losses # Return model and losses

In [57]:
# @title Model Evaluation

def evaluate_model(model, X_test, y_test, batch_size=128):
    model.eval() # Set model to evaluation mode
    test_ds = MemoryDataset(X_test, y_test)
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    correct = 0
    total = 0
    hidden_states_list = []
    with torch.no_grad(): # Disable gradient calculations during evaluation
        for xb, yb in test_dl:
            xb, yb = xb.to(device), yb.to(device)
            outputs, hidden_states = model(xb)
            _, predicted = torch.max(outputs.data, 1)
            total += yb.size(0)
            correct += (predicted == yb).sum().item()
            hidden_states_list.append(hidden_states)
    accuracy = 100 * correct / total
    print(f'\nAccuracy on test set: {accuracy:.2f}%')
    return accuracy, hidden_states_list

In [7]:
# @title MNIST Dataset preparation

import torchvision
import torchvision.transforms as transforms

mnist_transform = transforms.Compose([
    transforms.ToTensor(), # Converts PIL Image to FloatTensor and scales to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)) # Standard normalization for MNIST
])

# Load MNIST training dataset
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=mnist_transform,
    download=True
)

# Load MNIST test dataset
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=mnist_transform,
    download=True
)

# Prepare data for RNN: reshape images
# Each image (1, 28, 28) needs to become (28, 28) for the RNN,
# where 28 is sequence_length and 28 is input_dim.
# The DataLoader will add the batch dimension, making it (batch_size, 28, 28).

# Extract X and y from datasets and reshape
# X_train_mnist will be (num_samples, sequence_length, input_dim) -> (60000, 28, 28)
X_train_mnist = train_dataset.data.float().view(-1, 28, 28) # Reshape to (num_samples, 28, 28)
y_train_mnist = train_dataset.targets # Labels are already LongTensor (integers)

# X_test_mnist will be (num_samples, sequence_length, input_dim) -> (10000, 28, 28)
X_test_mnist = test_dataset.data.float().view(-1, 28, 28)
y_test_mnist = test_dataset.targets

100%|██████████| 9.91M/9.91M [00:00<00:00, 51.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.64MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 11.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.39MB/s]


In [42]:
#@title Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_clean, losses = train_rnn(X_train_mnist, y_train_mnist, batch_size=128, learning_rate=1e-3, epochs=20, variant="clean")

469
28
 Training variant: clean
  Epoch 1/20 | loss=1.0256
  Epoch 2/20 | loss=0.5538
  Epoch 3/20 | loss=0.4567
  Epoch 4/20 | loss=0.3997
  Epoch 5/20 | loss=0.3533
  Epoch 6/20 | loss=0.3265
  Epoch 7/20 | loss=0.3040
  Epoch 8/20 | loss=0.2819
  Epoch 9/20 | loss=0.2692
  Epoch 10/20 | loss=0.2613
  Epoch 11/20 | loss=0.2503
  Epoch 12/20 | loss=0.2397
  Epoch 13/20 | loss=0.2325
  Epoch 14/20 | loss=0.2266
  Epoch 15/20 | loss=0.2189
  Epoch 16/20 | loss=0.2141
  Epoch 17/20 | loss=0.2083
  Epoch 18/20 | loss=0.2035
  Epoch 19/20 | loss=0.2037
  Epoch 20/20 | loss=0.1958


In [61]:
#@title Evaluation
_, hidden_states = evaluate_model(model_clean, X_test_mnist, y_test_mnist, batch_size=128)
# --- Svaes the model hidden states for the test process
# --- Dimensions are batch_size=128, image_columns=28, rnn_hidden_units=100
torch.save(hidden_states, 'hidden_states.pt')


Accuracy on test set: 92.16%


In [62]:
hidden_states[0].shape

torch.Size([128, 28, 100])