In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
import torchvision
from PIL import Image
import h5py
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import math
from tqdm.auto import tqdm
from torchvision.transforms.functional import to_pil_image
import random
import os
from skimage.color import rgb2ycbcr
from http.server import HTTPServer, BaseHTTPRequestHandler

In [None]:
# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA, then CPU.
# The original hard-coded "mps", which makes every later `.to(device)` fail on a
# machine without MPS (e.g. the Linux box the training data path points at).
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device:', device)

In [None]:
def psnr(label, outputs, max_val=255.):
    """Peak signal-to-noise ratio between two RGB images, computed on the Y (luma) channel.

    Both images are passed through skimage's ``rgb2ycbcr`` and only channel 0 (Y)
    is compared.

    Args:
        label: reference RGB image, indexable as (H, W, channel) after conversion.
        outputs: reconstructed RGB image with the same shape as ``label``.
        max_val: peak value used in the PSNR formula (default 255.).

    Returns:
        PSNR in dB, or 100 when the two Y channels are identical (zero error).
    """
    y_reference = rgb2ycbcr(label)[:, :, 0]
    y_estimate = rgb2ycbcr(outputs)[:, :, 0]
    rmse = math.sqrt(np.mean((y_estimate - y_reference) ** 2))
    if rmse != 0:
        return 20 * math.log10(max_val / rmse)
    # Identical images would give an infinite PSNR; report a fixed cap instead.
    return 100

In [None]:
class MyHandler(BaseHTTPRequestHandler):
    """Minimal GET handler.

    * HTTP/1.0 requests are rejected with 400.
    * ``*.html`` paths are served from the current working directory; paths that
      resolve outside it (e.g. ``/../x.html``) are treated as not found.
    * Any ``*.jpg`` path is answered with the fixed file ``image.jpg``.
    * Everything else is a 404.

    Files are read completely before any status line is sent, so a read failure
    yields a clean 404 instead of a 200 followed by a second status line, and
    ``Content-Length`` always matches the body.
    """

    def do_GET(self):
        if 'HTTP/1.0' in self.request_version:
            self._send_text(400, b"400 Bad Request")
        elif self.path.endswith(".html"):
            html_path = self._resolve_local_path(self.path[1:])
            body = self._read_file(html_path) if html_path is not None else None
            if body is None:
                self._send_text(404, b"404 Not Found")
            else:
                self._send_body(200, 'text/html', body)
        elif self.path.endswith(".jpg"):
            # Every .jpg request is answered with the same fixed image file.
            body = self._read_file('image.jpg')
            if body is None:
                self._send_text(404, b"404 Not Found")
            else:
                self._send_body(200, 'image/jpeg', body)
        else:
            self._send_text(404, b"404 Not Found")

    @staticmethod
    def _resolve_local_path(relative):
        """Absolute path of ``relative`` under the CWD, or None if it escapes the CWD."""
        base = os.path.realpath(os.getcwd())
        target = os.path.realpath(os.path.join(base, relative))
        try:
            if os.path.commonpath([base, target]) != base:
                return None
        except ValueError:  # e.g. different drives on Windows
            return None
        return target

    @staticmethod
    def _read_file(path):
        """Return the bytes of ``path``, or None if it cannot be opened or read."""
        try:
            with open(path, 'rb') as f:
                return f.read()
        except OSError:
            return None

    def _send_body(self, code, content_type, body):
        """Send status ``code`` with ``body`` (bytes) and accurate headers."""
        self.send_response(code)
        self.send_header('Content-Type', content_type)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_text(self, code, body):
        """Send a short plain-text status body such as b"404 Not Found"."""
        self._send_body(code, 'text/plain', body)
    
if __name__ == '__main__':
    # NOTE: inside a Jupyter kernel `__name__ == '__main__'` is always true, and
    # serve_forever() blocks this kernel until it is interrupted, so the cells below
    # will not run during "Run All" until the server is stopped.
    # NOTE(review): 8888 is also Jupyter's default port; if the notebook server is
    # bound to it, creating this server may fail — confirm or pick another port.
    # The context manager closes the listening socket when serve_forever() exits
    # (including on Ctrl+C / kernel interrupt); without it the socket stays bound
    # inside the kernel and re-running this cell fails with "Address already in use".
    with HTTPServer(('', 8888), MyHandler) as server:
        print('Started WebServer on port 8888')
        print('Press Ctrl + c to quit webserver')
        server.serve_forever()

In [None]:
class VDSR(nn.Module):
  """VDSR-style residual super-resolution network.

  ``num_layers + 2`` bias-free 3x3 convolutions (input conv, ``num_layers``
  hidden convs, output conv) with ReLU after all but the last. The network
  predicts a residual that is added to its input, so output shape == input shape.

  Args:
    num_channels: image channels in and out (default 3, RGB).
    num_features: channel width of the hidden convolutions (default 64).
    num_layers: number of hidden 64->64 style conv layers (default 18, i.e. a
      20-layer network overall).
  """
  def __init__(self, num_channels=3, num_features=64, num_layers=18):
    super().__init__()
    self.input = nn.Conv2d(num_channels, num_features, 3, padding=1, bias=False)
    # Built with a comprehension instead of 18 copy-pasted lines; parameter names
    # stay conv.0 ... conv.{num_layers-1}, so existing state_dicts still load.
    self.conv = nn.ModuleList([
      nn.Conv2d(num_features, num_features, 3, padding=1, bias=False)
      for _ in range(num_layers)
    ])
    self.output = nn.Conv2d(num_features, num_channels, 3, padding=1, bias=False)

    # Kaiming-uniform (default arguments) initialisation for every convolution.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight)

  def forward(self, data):
    """Return ``data`` plus the predicted residual (same shape as ``data``)."""
    x = F.relu(self.input(data))
    for conv_layer in self.conv:
      x = F.relu(conv_layer(x))
    return self.output(x) + data

In [None]:
import torchvision.transforms as transforms

# Training-time pipeline: PIL image -> tensor, then a horizontal flip half of the time.
transform = transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(p=0.5)])

# Evaluation pipeline: tensor conversion only, no randomness.
eval_transform = transforms.Compose([transforms.ToTensor()])


def rotation(image):
    """Rotate ``image`` by an angle drawn uniformly from {0, 90, 180, 270} degrees."""
    return transforms.functional.rotate(image, random.choice((0, 90, 180, 270)))


# Shared tensor -> PIL converter used by the dataset and plotting cells.
toPIL = torchvision.transforms.ToPILImage()

In [None]:
class TrainDataset(Dataset):
    """Random-patch training set for x``scale`` super resolution.

    Every image file in ``root_dir`` is one sample. With a ``transform``, each item
    is a random square crop (side 41-70 px) resized to 41x41 and rotated by a
    random multiple of 90 degrees; that patch is the label. Without a transform
    the whole image (as a tensor) is the label. The network input is the label
    bicubically down-scaled by ``scale`` and up-scaled back to the label's size.

    Args:
        root_dir: directory containing the training images.
        transform: callable mapping a PIL image to a (C, H, W) tensor, plus any
            augmentation. When None, ``transforms.ToTensor()`` is used and no
            cropping happens (items then have the full image size).
        scale: integer down-sampling factor used to synthesise the input (default 2).

    Each item is ``(degraded_input, label)``, two float tensors of equal shape.
    """

    # Extensions treated as images; other entries (e.g. macOS .DS_Store) are skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None, scale=2):
        self.root_dir = root_dir
        self.transform = transform
        self.scale = scale
        # Sorted so the index -> file mapping does not depend on os.listdir order.
        self.image_list = sorted(
            file for file in os.listdir(root_dir)
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_name = os.path.join(self.root_dir, self.image_list[idx])
        # Load eagerly so the file handle is closed; force 3 channels for the network.
        with Image.open(image_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
            crop_size = random.randrange(41, 71)
            image = transforms.RandomCrop((crop_size, crop_size))(image)
            image = transforms.Resize((41, 41), antialias=True)(image)
            image = rotation(image)
        else:
            image = transforms.ToTensor()(image)

        label = image
        height, width = label.shape[-2], label.shape[-1]
        # Synthesise the low-resolution input: bicubic down by `scale`, then back up.
        image_pil = toPIL(label)
        image_pil = image_pil.resize((width // self.scale, height // self.scale), Image.BICUBIC)
        image_pil = image_pil.resize((width, height), Image.BICUBIC)
        image = transforms.ToTensor()(image_pil)
        return image, label

class EvalDataset(Dataset):
    """Full-image evaluation set for x``scale`` super resolution (``*.png`` files only).

    Each item is ``(degraded_input, label)``. The label is the whole image as a
    tensor. The input is that image bicubically down-scaled by ``scale`` and
    up-scaled back to the label's size.

    Args:
        root_dir: directory containing the evaluation ``.png`` images.
        transform: callable mapping a PIL image to a (C, H, W) tensor; when None,
            ``transforms.ToTensor()`` is used.
        transform_input: accepted and stored for backward compatibility; not used.
        scale: integer down-sampling factor (default 2).
    """

    def __init__(self, root_dir, transform=None, transform_input=None, scale=2):
        self.root_dir = root_dir
        self.transform = transform
        self.transform_input = transform_input
        self.scale = scale
        # Sorted so per-image results are reported in a stable order.
        self.image_list = sorted(file for file in os.listdir(root_dir) if file.endswith('.png'))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_name = os.path.join(self.root_dir, self.image_list[idx])
        # Load eagerly so the file handle is closed; force 3 channels for the network.
        with Image.open(image_name) as img:
            image = img.convert('RGB')

        # Fall back to ToTensor so the degradation below always has a tensor;
        # without a transform the input used to be undefined (UnboundLocalError).
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        label = to_tensor(image)

        height, width = label.shape[-2], label.shape[-1]
        image_pil = toPIL(label)
        image_pil = image_pil.resize((width // self.scale, height // self.scale), Image.BICUBIC)
        image_pil = image_pil.resize((width, height), Image.BICUBIC)
        degraded = transforms.ToTensor()(image_pil)

        return degraded, label

In [None]:
# Training images (the "291" set). The default is a machine-specific Linux path,
# while the evaluation and model paths in this notebook are macOS paths; set the
# VDSR_TRAIN_DIR environment variable to point at the data on the current machine.
TRAIN_DIR = os.environ.get('VDSR_TRAIN_DIR', '/home/intern/datasets/291')
train_dataset = TrainDataset(TRAIN_DIR, transform=transform)

In [None]:
import multiprocessing

# DataLoader workers must be able to unpickle TrainDataset. Classes defined in a
# notebook live in the kernel's __main__, which only *forked* workers inherit;
# under "spawn"/"forkserver" (spawn is the macOS default since Python 3.8) the
# workers typically fail with "Can't get attribute 'TrainDataset'". In that case
# load data in the main process instead.
NUM_WORKERS = 2 if multiprocessing.get_start_method() == 'fork' else 0
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=NUM_WORKERS)

In [None]:
# Evaluation images (Set5). The default is a machine-specific absolute path;
# set the VDSR_EVAL_DIR environment variable to override it.
EVAL_DIR = os.environ.get('VDSR_EVAL_DIR', '/Users/hyeon-seongkim/Computer_Vision/dataset/Set5')
eval_dataset = EvalDataset(EVAL_DIR, transform=eval_transform)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

In [None]:
# No module-level counter is needed here: the cells that loop create their own.

In [None]:
# Sanity check: the bicubic-degraded network input next to its ground-truth label.
# (The sample comes from the evaluation set, so it is labelled as the bicubic input
# rather than "train"; `cmap` is omitted because it is ignored for RGB images.)
img, target = eval_dataset[0]
fig, (ax_input, ax_target) = plt.subplots(1, 2)
ax_input.imshow(toPIL(img))
ax_input.set_title('bicubic input')
ax_input.set_axis_off()
ax_target.imshow(toPIL(target))
ax_target.set_title('target')
ax_target.set_axis_off()
plt.show()

In [None]:
# Directory for model checkpoints; override with the VDSR_MODEL_DIR environment
# variable. A trailing separator is guaranteed because file names are built as
# PATH + 'name.pt'.
PATH = os.path.join(os.environ.get('VDSR_MODEL_DIR', '/Users/hyeon-seongkim/Computer_Vision/models/'), '')

In [None]:
criterion = nn.MSELoss()
net = VDSR().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# To resume, load a checkpoint written by the saving step below:
# checkpoint = torch.load(PATH + 'VDSR_ep_<N>.pt', map_location=device)
# net.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epochs = 10
SUB_EPOCHS = 2500      # full passes over train_dataloader per reported "epoch"
CHECKPOINT_EVERY = 0   # save a checkpoint every N epochs; 0 disables saving

losses = []
train_psnr = []
eval_psnr = []


def _to_unit_numpy(tensor):
    """Clamp an image tensor to [0, 1] and return it as a float array in [0, 1].

    Goes through an 8-bit PIL image, so values are quantised to 1/255 steps.
    """
    return np.array(toPIL(torch.clamp(tensor.detach(), 0, 1))) / 255.


def evaluate_psnr(model, loader):
    """Mean Y-channel PSNR of ``model`` over ``loader`` (assumes batch_size=1)."""
    model.eval()
    scores = []
    with torch.no_grad():
        for lr_input, labels in loader:
            preds = model(lr_input.to(device))
            scores.append(psnr(_to_unit_numpy(labels.squeeze(dim=0)),
                               _to_unit_numpy(preds.squeeze(dim=0))))
    model.train()
    # np.mean instead of the builtin `sum`, which other cells may have shadowed.
    return float(np.mean(scores)) if scores else float('nan')


net.train()
for epoch in tqdm(range(epochs)):
    running_loss = 0.0
    running_psnr = 0.0
    steps = 0
    for _ in tqdm(range(SUB_EPOCHS), leave=False):
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = net(inputs)
            loss = criterion(preds, labels)
            running_loss += loss.item()
            # Training PSNR is tracked on the first sample of each batch only.
            running_psnr += psnr(_to_unit_numpy(labels[0]), _to_unit_numpy(preds[0]))
            steps += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    losses.append(running_loss / steps)
    train_psnr.append(running_psnr / steps)
    eval_psnr.append(evaluate_psnr(net, eval_dataloader))
    print("Epoch: ", epoch, " Loss: ", losses[epoch], " PSNR: ", train_psnr[epoch],
          " Eval PSNR: ", eval_psnr[epoch])

    if CHECKPOINT_EVERY and epoch % CHECKPOINT_EVERY == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, PATH + 'VDSR_ep_%d.pt' % epoch)

In [None]:
# Per-epoch training curves: loss, training PSNR and evaluation PSNR.
fig, (ax_loss, ax_train, ax_eval) = plt.subplots(1, 3, figsize=(15, 4))
for ax, values, ylabel in ((ax_loss, losses, 'Loss'),
                           (ax_train, train_psnr, 'train_PSNR'),
                           (ax_eval, eval_psnr, 'eval_PSNR')):
    # x-axis from each series' own length, so an interrupted run still plots.
    ax.plot(range(len(values)), values)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
fig.tight_layout()
# Save *before* showing: plt.show() hands the figure off (the inline backend then
# closes it), so a later plt.savefig() would write an empty image.
fig.savefig("train_loss_psnr.png")
plt.show()

In [None]:
# Evaluate a saved checkpoint on the evaluation set, showing each reconstruction.
CHECKPOINT_NAME = 'VDSR_ep_79.pt'
net = VDSR().to(device)
# map_location lets a checkpoint written on another device (e.g. CUDA) load here.
checkpoint = torch.load(PATH + CHECKPOINT_NAME, map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
print("Epoch: ", checkpoint['epoch'])
net.eval()

eval_psnrs = []
with torch.no_grad():
    for lr_input, labels in eval_dataloader:
        preds = torch.clamp(net(lr_input.to(device)), 0, 1)
        bicubic_img = lr_input.squeeze(dim=0)
        label_img = labels.squeeze(dim=0)
        pred_img = preds.squeeze(dim=0)

        pred_np = np.array(toPIL(pred_img)) / 255.
        label_np = np.array(toPIL(label_img)) / 255.
        score = psnr(label_np, pred_np)
        eval_psnrs.append(score)

        fig, axes = plt.subplots(1, 3)
        for ax, image, title in zip(axes, (bicubic_img, label_img, pred_img),
                                    ('Bicubic', 'ground truth', 'VDSR')):
            ax.imshow(to_pil_image(image))
            ax.set_title(title)
        plt.show()
        print("PSNR: ", score)

# Average over however many images were evaluated (not a hard-coded 5), without
# shadowing the builtin `sum`.
print("Average: ", float(np.mean(eval_psnrs)))