In [0]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from math import log10
from __future__ import print_function

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_epochs = 30
learning_rate = 0.01

# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                             train=True, 
                                             transform=transform,
                                             download=True)

test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                            train=False, 
                                            transform=transforms.ToTensor())

# Data loader
trainloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)

testloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100, 
                                          shuffle=False)

class SRCNN(nn.Module):
    def __init__(self):
        #TODO: define padding
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=3, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        #self.conv4 = nn.Conv2d(5, 1, kernel_size=1, bias=True)
        #self.shuffle = nn.PixelShuffle(3)
        #self.fc = nn.Linear(5, 20)

    def forward(self, x):
      x = self.conv1(x)
      #print(x.size())
      x = self.relu(x)
      x = self.conv2(x)
      #print(x.size())
      x = self.relu(x)
      x = self.conv3(x)
      #print(x.size())
      #x = self.conv4(x)
      #x = x.view(x.size(0), -1)
      #x = self.shuffle(x)
      #x = self.fc(x)

      return x

model=SRCNN().to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# For updating learning rate
def update_lr(optimizer, lr):    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Train the model
total_step = len(trainloader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(trainloader):
        images = images.to(device)
        #print(images.size())
        #labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        #outputs = torch.Tensor([outputs.unsqueeze(0), 0, 0 ,0])
        #outputs.unsqueeze(0)
        #print(outputs.size())
        #print(outputs)
        loss = criterion(outputs, images)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Decay learning rate
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        mse = criterion(prediction, label)
        psnr = 10 * log10(1/mse.item())
        avg_psnr += psnr
    print(batch_num, len(testingloader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
   #  print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')
"""
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
train_loss = 0

for epoch in range(num_epochs):
    for batch_num, (image, label) in enumerate(trainloader):
        image = image.to(device)
        label = label.to(device).unsqueeze(1)
        print(image.size())
        print(label.size())

        outputs = model(image)
        loss = criterion(outputs, label)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))

    print("    Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))

model.eval().to(device)
with torch.no_grad():
    for batch_num, (image, label) in enumerate(testloader):
        image = image.to(device)
        label = label.to(device)

        prediction = model(image)
        mse = criterion(prediction, label)
        psnr = 10 * log10(1/mse.item())
        avg_psnr += psnr
        progress_bar(batch_num, len(testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
"""

Files already downloaded and verified
Epoch [1/30], Step [100/500] Loss: 0.0113
Epoch [1/30], Step [200/500] Loss: 0.0081
Epoch [1/30], Step [300/500] Loss: 0.0058
Epoch [1/30], Step [400/500] Loss: 0.0058
Epoch [1/30], Step [500/500] Loss: 0.0041
Epoch [2/30], Step [100/500] Loss: 0.0040
Epoch [2/30], Step [200/500] Loss: 0.0055
Epoch [2/30], Step [300/500] Loss: 0.0030
Epoch [2/30], Step [400/500] Loss: 0.0032
Epoch [2/30], Step [500/500] Loss: 0.0026
Epoch [3/30], Step [100/500] Loss: 0.0024
Epoch [3/30], Step [200/500] Loss: 0.0030
Epoch [3/30], Step [300/500] Loss: 0.0022
Epoch [3/30], Step [400/500] Loss: 0.0019
Epoch [3/30], Step [500/500] Loss: 0.0020
Epoch [4/30], Step [100/500] Loss: 0.0023
Epoch [4/30], Step [200/500] Loss: 0.0017
Epoch [4/30], Step [300/500] Loss: 0.0044
Epoch [4/30], Step [400/500] Loss: 0.0296
Epoch [4/30], Step [500/500] Loss: 0.0018
Epoch [5/30], Step [100/500] Loss: 0.0014
Epoch [5/30], Step [200/500] Loss: 0.0015
Epoch [5/30], Step [300/500] Loss: 0.0

RuntimeError: ignored