In [None]:
%%writefile EncDec.py

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np

class Net(nn.Module):
    """Small convolutional encoder-decoder (autoencoder).

    Encoder: three 3x3 convolutions with stride 2 and padding 1, taking the
    channel count 3 -> 4 -> 8 -> 16 while halving each (even) spatial
    dimension per layer.
    Decoder: three 4x4 transposed convolutions with stride 2 and padding 1,
    taking the channel count 16 -> 8 -> 4 -> 3 while doubling each spatial
    dimension per layer.

    For inputs whose height and width are divisible by 8 (e.g. 32x32), the
    output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Encoder (downsampling) layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1)
        # Decoder (upsampling) layers
        self.deconv1 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=8, out_channels=4, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(in_channels=4, out_channels=3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Encode and then decode ``x``, returning the reconstruction.

        Every layer except the last is followed by a ReLU; the final
        transposed convolution is returned without an activation, so the
        output values are not bounded.
        """
        activated_layers = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.deconv1,
            self.deconv2,
        )
        for layer in activated_layers:
            x = F.relu(layer(x))
        return self.deconv3(x)

### Instantiate the encoder-decoder network
net = Net()

## CIFAR-10 training and test sets
# The datasets yield PIL images; ToTensor converts them to tensors the network accepts
transform = transforms.ToTensor()

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=5,
    shuffle=True,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# The test loader is not shuffled, so its batches come out in a fixed order
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=5,
    shuffle=False,
)

### Reconstruction loss (mean squared error) and the Adam optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)

### Main training loop
# Number of minibatches between progress reports; also the averaging window
LOG_EVERY = 5000

for epoch in range(5):
    total_batch_loss = 0.0
    for i, (images, _labels) in enumerate(trainloader):
        # Autoencoder objective: the network is trained to reproduce its own
        # input, so the batch serves as both input and target (labels unused).
        out = net(images)
        loss = criterion(out, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_batch_loss += loss.item()

        # Report the mean loss over the last LOG_EVERY minibatches, then
        # restart the running sum for the next window.
        if (i + 1) % LOG_EVERY == 0:
            print('For %d epoch, %5d minibatch, average loss is %.5f'%(epoch, (i+1), total_batch_loss/LOG_EVERY))
            total_batch_loss = 0.0


### Testing the network on the test set and computing the average per-batch loss
testLoss = 0.0
with torch.no_grad():
    for data in testloader:
        # NOTE: `input` shadows the Python builtin; the name is kept because a
        # later notebook cell inspects `input.shape`.
        input, dummy = data
        target = input
        out = net(target)
        loss = criterion(out, target)
        testLoss += loss.item()

# Divide by the actual number of batches instead of a hard-coded count, so the
# average stays correct if the batch size or the dataset changes.
print('Test Set Loss %.6f' %(testLoss/len(testloader)))

### Displaying or saving the results as well as the ground truth images for the first batch of the test set

with torch.no_grad():
    testiterator = iter(testloader)
    # Use the builtin next(): DataLoader iterators in current PyTorch releases
    # do not provide a `.next()` method.
    o, _ = next(testiterator)
    testout = net(o)

#Display
# The network's last layer has no bounding activation, so reconstructed values
# can fall outside [0, 1]. Clamp explicitly for display rather than relying on
# imshow's implicit clipping (which emits a warning).
testout_display = testout.clamp(0, 1)
plt.imshow(torchvision.utils.make_grid(testout_display).numpy().transpose((1,2,0)))
plt.show()
plt.imshow(torchvision.utils.make_grid(o).numpy().transpose((1,2,0)))
plt.show()

#save
# Reconstructions first, then the ground-truth images, nrow=5 images per grid row
torchvision.utils.save_image(torch.cat((testout,o)), "EncDec.png", nrow=5)

Files already downloaded and verified
Files already downloaded and verified
For 0 epoch,  5000 minibatch, average loss is 0.01871
For 0 epoch, 10000 minibatch, average loss is 0.00689
For 1 epoch,  5000 minibatch, average loss is 0.00537
For 1 epoch, 10000 minibatch, average loss is 0.00471
For 2 epoch,  5000 minibatch, average loss is 0.00439
For 2 epoch, 10000 minibatch, average loss is 0.00421
For 3 epoch,  5000 minibatch, average loss is 0.00414
For 3 epoch, 10000 minibatch, average loss is 0.00409
For 4 epoch,  5000 minibatch, average loss is 0.00404
For 4 epoch, 10000 minibatch, average loss is 0.00404
Test Set Loss 0.003993


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

In [2]:
# Inspect the configuration of the first encoder convolution
net.conv1

Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

In [4]:
# Shape of the batch most recently bound to `input` (last assigned in the test-evaluation loop)
input.shape

torch.Size([5, 3, 32, 32])