## Training Script

### Dataloader

In [1]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

class ClutteredMNISTDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Custom dataset for Cluttered MNIST.
        :param root_dir: Root directory of the dataset (e.g., "dataset/cluttered_mnist")
        :param transform: Optional torchvision transforms to apply to the images
        """
        self.root_dir = root_dir
        self.transform = transform

        # Gather all image paths and their labels
        self.data = []
        for label in range(10):  # Assuming labels are 0-9
            label_dir = os.path.join(root_dir, str(label))
            if os.path.isdir(label_dir):
                for file_name in os.listdir(label_dir):
                    if file_name.endswith(('.png', '.jpg', '.jpeg')):
                        file_path = os.path.join(label_dir, file_name)
                        self.data.append((file_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve an image and its label at the specified index.
        :param idx: Index of the data point
        :return: Tuple (image, label)
        """
        image_path, label = self.data[idx]
        image = Image.open(image_path).convert('L')  # Convert to grayscale

        if self.transform:
            image = self.transform(image)

        return image, label

# Define transforms for the dataset
transform = transforms.Compose([
    transforms.Resize((100, 100)),  # Resize images to 100x100
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize using MNIST mean and std
])

# Create dataset
dataset_dir = "dataset/cluttered_mnist"
cluttered_mnist_dataset = ClutteredMNISTDataset(root_dir=dataset_dir, transform=transform)

# Split dataset into train and test (80% train, 20% test)
train_size = int(0.9 * len(cluttered_mnist_dataset))
test_size = len(cluttered_mnist_dataset) - train_size
train_dataset, test_dataset = random_split(cluttered_mnist_dataset, [train_size, test_size])

# DataLoader for batching and shuffling
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Example: Iterate through the DataLoader
for images, labels in train_loader:
    print("Training batch of images shape:", images.shape)  # (batch_size, 1, 100, 100)
    print("Training batch of labels shape:", labels.shape)  # (batch_size,)
    break

for images, labels in test_loader:
    print("Test batch of images shape:", images.shape)  # (batch_size, 1, 100, 100)
    print("Test batch of labels shape:", labels.shape)  # (batch_size,)
    break

print(f"train_size: {train_size}, test_size: {test_size}")
64*900

Training batch of images shape: torch.Size([64, 1, 100, 100])
Training batch of labels shape: torch.Size([64])
Test batch of images shape: torch.Size([64, 1, 100, 100])
Test batch of labels shape: torch.Size([64])
train_size: 54000, test_size: 6000


57600

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from glimpse import GlimpseModel

"""print('the code god was here')"""
# Define the RNN model
class MNISTRNN(nn.Module):
    def __init__(self, image_size, hidden_size, num_layers, num_classes, num_kernels):
        super(MNISTRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.image_size = image_size

        # RNN to process the crops
        self.rnn = nn.RNN(num_kernels, hidden_size, num_layers, batch_first=True)
        self.fc_class = nn.Linear(hidden_size, num_classes)  # Class prediction
        self.fc_action = nn.Linear(hidden_size, 2)  # Next crop center prediction
        
        self.eyes = GlimpseModel((image_size, image_size), num_kernels)
        
        self.sc = torch.zeros((batch_size, 2))
        self.sz = torch.ones((batch_size, 1))

    def crop(self, padded_image, center):
        """
        Crop a region around the given center from the padded image.
        :param padded_image: Padded input image (batch_size, 1, padded_size, padded_size)
        :param center: Crop centers (batch_size, 2)
        :return: Cropped image regions (batch_size, crop_size, crop_size)
        """
        crop_size = self.crop_size
        half_crop = crop_size // 2

        # Compute cropping indices
        x_start = (center[:, 0] - half_crop).long()
        x_end = (center[:, 0] + half_crop).long()
        y_start = (center[:, 1] - half_crop).long()
        y_end = (center[:, 1] + half_crop).long()

        # Perform efficient tensor slicing
        crops = torch.stack([
            padded_image[b, :, y_start[b].item():y_end[b].item(), x_start[b].item():x_end[b].item()]
            for b in range(padded_image.size(0))
        ])
        return crops

    def forward(self, images, center, h0):
        """
        Forward pass with dynamic cropping and RNN processing.
        :param image: Full input image (batch_size, 1, 28, 28)
        :param center: Initial crop centers (batch_size, 2)
        :param h0: Initial hidden state (num_layers, batch_size, hidden_size)
        :return: Class prediction, next center, hidden state
        """
        batch_size = len(images)
        
         # TODO: This is the thing that we need to control.
        input = images.squeeze(1)
        output_tensor = self.eyes(input, self.sc, self.sz) # (B, 144)

        # # Process with RNN
        # crops = crops.unsqueeze(1)  # Add sequence dimension (batch_size, seq_len=1, crop_size^2)
        rnn_input = output_tensor.view(batch_size, 1, 144)
        
        
        out, hn = self.rnn(rnn_input, h0)

        # Predict class and next crop center
        class_pred = self.fc_class(out[:, -1, :])  # Class prediction
        action_pred = self.fc_action(out[:, -1, :])  # Action (next crop center)

        return class_pred, action_pred, hn

In [3]:
# Hyperparameters
image_size = 100
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 64
learning_rate = 0.001
num_epochs = 10
num_steps = 3  # RNN steps per image
num_kernels = 12*12

In [4]:
# Create model, optimizer, and loss functions
model = MNISTRNN(image_size, hidden_size, num_layers, num_classes, num_kernels)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion_class = nn.CrossEntropyLoss()
criterion_action = nn.MSELoss()  # For predicting the next center

## Train That Bad Boy

In [5]:
# Training loop
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        # Initialize the hidden state and center
        h0 = torch.zeros(num_layers, batch_size, hidden_size).to(images.device)
        centers = torch.tensor([[14, 14]] * batch_size).float().to(images.device)  # Initial center

        loss = 0
        for step in range(num_steps):
            # Forward pass
            class_pred, action_pred, h0 = model(images, centers, h0)

            # Compute losses
            loss_class = criterion_class(class_pred, labels)
            # loss_action = criterion_action(action_pred, centers)  # Target is to stay at the initial center
            loss += loss_class
            
            print("Completed a step")

            # Update centers for the next step
            # centers = torch.clip(action_pred, min=crop_size // 2, max=28 + crop_size // 2)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 1 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

1
2
3


4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
Epoch [1/10], Step [1/844], Loss: 6.9147
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
Epoch [1/10], Step [2/844], Loss: 7.0246
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
Epoch [1/10], Step [3/844], Loss: 7.0200
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
Epoch [1/10], Step [4/844], Loss: 6.9341
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step
1
2
3
4
5
6
Completed a step


KeyboardInterrupt: 

In [None]:
# Evaluate the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        batch_size = images.size(0)
        h0 = torch.zeros(num_layers, batch_size, hidden_size).to(images.device)
        
        for step in range(num_steps):
            class_pred, action_pred, h0 = model(images, None, h0)

        _, predicted = torch.max(class_pred.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

Accuracy of the model on the test images: 97.25%
