In [0]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import STL10
import os
import numpy as np
import torch.nn.functional as F
from torch import optim
import cv2

In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        
        self.conv1=nn.Conv2d(3,64,9,stride=1,padding=4)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2=nn.Conv2d(64,32,1,stride=1,padding=0)
        self.bn2 = nn.BatchNorm2d(32)
       
        self.conv3=nn.Conv2d(32,3,5,stride=1,padding=2)
        
       


    def forward(self, x):
        x = F.relu(self.conv1(x))
        x=self.bn1(x)
        
        x = F.relu(self.conv2(x))
        x=self.bn2(x)
        
        x=self.conv3(x)
        return x

In [0]:
def downupscale(img):
  temp=img.numpy().copy()
  B,C,H,W=img.size()
  
  out=np.empty((batch_size,C,H,W))
  for i in range(batch_size):
    a=temp[i]
    c=a.transpose((1,2,0))
    small = cv2.resize(c, (int(H/3),int(W/3)), interpolation = cv2.INTER_CUBIC)
    reconstruct = cv2.resize(small, (H,W), interpolation = cv2.INTER_CUBIC)
    
    #b=np.expand_dims(b, axis=2)
    
    out[i]=reconstruct.transpose((2,0,1))
  return torch.from_numpy(out)

In [23]:
if not os.path.exists('./gdrive/My Drive/Colab Notebooks/srcnn'):
    os.mkdir('./gdrive/My Drive/Colab Notebooks/srcnn')

    

num_epochs = 100
batch_size = 32
learning_rate = 1e-4

img_transform = transforms.Compose([#transforms.Resize(96),
    transforms.ToTensor()
    
])

dataset = STL10('./data',  split='unlabeled',transform=img_transform,download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)



model = SRCNN()
model.cuda()

from collections import OrderedDict

'''
state_dict = torch.load('./gdrive/My Drive/Colab Notebooks/vae.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k
    #name = k[7:] # remove "module."
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

if torch.cuda.is_available():
    model.cuda()
'''

loss_function = nn.MSELoss()



optimizer = optim.Adam(model.parameters(), lr=learning_rate)



Files already downloaded and verified


In [28]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data
        LR=downupscale(img)
        
        #noised_img=noised_img.type(torch.FloatTensor)
        LR = Variable(LR)
        img=Variable(img)
        
        LR=LR.type(torch.FloatTensor)
        if torch.cuda.is_available():
            LR = LR.cuda()
            img=img.cuda()
            
        optimizer.zero_grad()
        HR  = model(LR)
        loss = loss_function(HR, img)
        loss.backward()
        train_loss += loss.data.item()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.data.item() / len(img)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 5 == 0:
        img=img.cpu().data
        LR=LR.cpu().data
        HR=HR.cpu().data
        HR.clamp(0,1)
        save_image(img, './gdrive/My Drive/Colab Notebooks/srcnn/GT_{}.png'.format(epoch))
        save_image(LR, './gdrive/My Drive/Colab Notebooks/srcnn/LR_{}.png'.format(epoch))
        save_image(HR, './gdrive/My Drive/Colab Notebooks/srcnn/HR_{}.png'.format(epoch))

        torch.save(model.state_dict(), './gdrive/My Drive/Colab Notebooks/srcnn/srcnn.pth')

====> Epoch: 0 Average loss: 0.0002
====> Epoch: 1 Average loss: 0.0002
====> Epoch: 2 Average loss: 0.0002
====> Epoch: 3 Average loss: 0.0002
====> Epoch: 4 Average loss: 0.0002
====> Epoch: 5 Average loss: 0.0002
====> Epoch: 6 Average loss: 0.0002
====> Epoch: 7 Average loss: 0.0002
====> Epoch: 8 Average loss: 0.0002
====> Epoch: 9 Average loss: 0.0002
====> Epoch: 10 Average loss: 0.0002
====> Epoch: 11 Average loss: 0.0002
====> Epoch: 12 Average loss: 0.0002
====> Epoch: 13 Average loss: 0.0002
====> Epoch: 14 Average loss: 0.0002
====> Epoch: 15 Average loss: 0.0002
====> Epoch: 16 Average loss: 0.0002
====> Epoch: 17 Average loss: 0.0002
====> Epoch: 18 Average loss: 0.0002
====> Epoch: 19 Average loss: 0.0002
====> Epoch: 20 Average loss: 0.0002
====> Epoch: 21 Average loss: 0.0002
====> Epoch: 22 Average loss: 0.0002
====> Epoch: 23 Average loss: 0.0002
====> Epoch: 24 Average loss: 0.0002
====> Epoch: 25 Average loss: 0.0002
====> Epoch: 26 Average loss: 0.0002
====> Epoch