# Spatial Transformer Networks

### Imports

In [None]:
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Enable matplotlib interactive mode so figures drawn while the notebook runs
# are displayed without blocking; it is turned off again (plt.ioff()) right
# before the final plt.show() at the end of the notebook.
plt.ion()   # interactive mode

### Loading the data

In [None]:
# Use the GPU when one is available; the model and batches are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing shared by both MNIST splits: convert images to tensors, then
# standardize. (0.1307, 0.3081) are presumably the MNIST training-set
# mean/std used by the original notebook — kept unchanged.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training dataset: downloaded into the current directory on first run and
# shuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=mnist_transform),
    batch_size=64, shuffle=True, num_workers=4)

# Test dataset: download=True makes it independent of the training dataset
# having been created first. Shuffling does not change evaluation metrics, so
# a fixed order is used to keep evaluation (and any batch taken from this
# loader for plotting) reproducible.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, download=True,
                   transform=mnist_transform),
    batch_size=64, shuffle=False, num_workers=4)

In [None]:
from model import Net

# Instantiate the STN classifier, then place its parameters on the chosen
# device. Rebinding the result of .to() keeps this correct whether or not
# .to() operates in place.
model = Net()
model = model.to(device)

### Training the model

In [None]:
from main import train, test

# Number of passes over the training set.
NUM_EPOCHS = 20

# Plain SGD over every model parameter, fixed learning rate.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

# Each iteration trains for one epoch (epochs numbered from 1), then runs a
# full evaluation pass on the test split.
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch, model, optimizer, train_loader, device)
    test(model, test_loader, device)

### Visualizing the STN results

In [None]:
from utils import visualize_stn

# Visualize the STN transformation on some input batch
# NOTE(review): visualize_stn() is called with no arguments, whereas train()
# and test() above are given model, loader and device explicitly. Confirm
# that utils.visualize_stn obtains the trained model and data on its own;
# otherwise pass them in here.
visualize_stn()
# Leave interactive mode so that plt.show() blocks and keeps the figure
# window(s) open until they are closed.
plt.ioff()
plt.show()