In [None]:
class Net(nn.Module):
    """Image classifier with a spatial transformer network (STN) front end.

    ``stn`` predicts one 2x3 affine matrix per sample and resamples the input
    with it; the resampled image then goes through a small CNN classifier
    with 10 output classes.

    The hard-coded feature sizes (``10*3*3`` for the localization branch,
    ``320`` for the classifier) correspond to single-channel 28x28 inputs
    (presumably MNIST). Other spatial sizes fail with a shape error at the
    first ``nn.Linear`` of the affected branch.
    """

    def __init__(self):
        super().__init__()
        # NOTE: submodules are created in the original order so that random
        # initialization under a fixed seed and state_dict keys are unchanged.

        # Classification CNN.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Element-wise dropout on conv2's feature map (nn.Dropout, not the
        # channel-wise nn.Dropout2d).
        self.conv2_drop = nn.Dropout()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Localization network: (conv -> max-pool -> ReLU) twice.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )
        # Regressor: flattened localization features -> 6 affine parameters.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )
        # Zero weights plus an identity bias make the predicted transform
        # exactly [[1, 0, 0], [0, 1, 0]] at initialization, so training starts
        # from "no warp".
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
            )

    def stn(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Returns a tensor with the same size as ``x``.
        """
        xs = self.localization(x)
        # flatten(1) keeps the batch dimension. A wrong input size then fails
        # clearly in fc_loc instead of view(-1, 90) re-splitting the batch.
        xs = xs.flatten(1)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        # align_corners=False is the default since PyTorch 1.3. Passing it
        # explicitly keeps that behavior and silences the UserWarning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = self.stn(x)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.flatten(1)  # (batch, 320) for 28x28 inputs
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Log-probabilities (not probabilities); pairs with F.nll_loss.
        return F.log_softmax(x, dim=1)
    
# Build the STN classifier, then move its parameters to `device`
# (nn.Module.to modifies the module in place and returns it).
net = Net()
net = net.to(device)
        

In [None]:
# One training batch pulled up front; presumably kept for inspection or
# visualization elsewhere in the notebook (not used in this cell).
images, target = next(iter(train_loader))
# Plain SGD over every model parameter.
optimizer = optim.SGD(net.parameters(), lr=0.01)
# Metric histories that test() appends to on each call:
# mean test loss and test accuracy (fraction in [0, 1]).
t_loss, acc = [], []
def train(epoch):
    """Run one epoch of SGD over ``train_loader``, updating ``net`` in place.

    Every 500th batch (starting with the first), prints the epoch, the number
    of samples processed before that batch, the percentage of batches done,
    and that batch's loss.
    """
    net.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(net(inputs), labels)
        loss.backward()   # back-propagate the batch loss
        optimizer.step()  # apply the SGD update

        if step % 500 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / len(train_loader), loss.item()))

def test():
    """Evaluate ``net`` on ``test_loader`` and record the results.

    Appends the mean per-sample NLL loss to ``t_loss`` and the accuracy
    (fraction in [0, 1]) to ``acc``, then prints a one-line summary.
    Leaves ``net`` in eval mode.
    """
    net.eval()
    num_samples = len(test_loader.dataset)
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            # Sum per batch (not mean) so that dividing by the dataset size
            # below gives an exact per-sample average even if the last batch
            # is smaller. reduction='sum' replaces the deprecated
            # size_average=False with identical semantics.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # predicted class index
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= num_samples
    t_loss.append(test_loss)
    acc.append(correct / num_samples)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, num_samples,
                  100. * correct / num_samples))