# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# Preprocessing shared by both splits: ToTensor turns each PIL image into a
# float tensor in [0, 1], then Normalize computes (x - 0.5) / 0.5, which
# rescales every pixel into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Build the training split first, then the test split. Both are stored under
# ./data and downloaded there if they are not already present.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=split_is_train, transform=transform, download=True)
    for split_is_train in (True, False)
)

# Only the training data is shuffled. The test loader keeps the dataset's fixed
# order and yields small batches of 5 images.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(6 * 13 * 13, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = x.view(-1, 6 * 13 * 13)
        x = self.fc1(x)
        return x

model = SimpleCNN()

# CrossEntropyLoss takes raw logits (SimpleCNN.forward applies no softmax).
# With the default reduction='mean', it returns the mean loss over the batch.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # `loss` is a per-batch mean. Re-weight it by the batch size so the
        # smaller final batch does not skew the epoch's per-sample average.
        n_in_batch = images.size(0)
        running_loss += loss.item() * n_in_batch
        samples_seen += n_in_batch
    # max(..., 1) guards against an empty loader (previously a ZeroDivisionError).
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")


# First-layer kernels, detached from autograd (still sharing storage with the
# parameter) and moved to CPU so .numpy() works wherever the model lives.
# For conv1 = Conv2d(1, 6, kernel_size=3), the shape is (6, 1, 3, 3).
filters = model.conv1.weight.detach().cpu()

print("Visualizing filters:")
# Derive the panel count from the weights rather than hard-coding 6. The
# plot then stays correct if conv1's out_channels changes.
num_filters = filters.size(0)
# squeeze=False always yields a 2-D array of Axes, even for a single filter.
fig, axes = plt.subplots(1, num_filters, figsize=(12, 6), squeeze=False)
for i, ax in enumerate(axes.flat):
    # Show input channel 0 only (conv1 has a single input channel).
    ax.imshow(filters[i, 0].numpy(), cmap='gray')
    ax.axis('off')
plt.show()

model.eval()
# One batch from the (unshuffled) test loader, i.e. the first test images.
images, _ = next(iter(test_loader))
with torch.no_grad():
    # Raw conv1 responses, taken before ReLU and pooling. Shape is
    # (batch, out_channels, 26, 26) for 28x28 inputs. no_grad already yields
    # tensors without autograd history, so no .detach() is needed.
    feature_maps = model.conv1(images).cpu()

# Size the grid from the data: one row per image, one column per channel.
num_images, num_maps = feature_maps.shape[:2]
print(f"Visualizing feature maps for the first {num_images} test images:")
# squeeze=False keeps `axes` 2-D even when there is a single row or column.
fig, axes = plt.subplots(num_images, num_maps, figsize=(12, 12), squeeze=False)
for i in range(num_images):
    for j in range(num_maps):
        ax = axes[i, j]
        ax.imshow(feature_maps[i, j].numpy(), cmap='gray')
        ax.axis('off')
plt.show()
