In [None]:
import glob
import random
import matplotlib.pyplot as plt
from PIL import Image
import scipy.misc
import scipy.ndimage
import numpy as np
import os
import torch
from torchvision import transforms
from PIL import Image

In [None]:
def imsave(image, path):
  """Write `image` to `path` as an 8-bit image file.

  Replacement for ``scipy.misc.imsave``, which no longer exists in SciPy >= 1.2.
  Like that helper, arrays that are not already uint8 are linearly rescaled so
  their minimum maps to 0 and their maximum to 255. A trailing singleton channel
  axis (the (H, W, 1) layout produced by ``merge``) is dropped so Pillow receives
  a 2-D grayscale array.

  Args:
    image: array-like of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
    path: destination file; Pillow infers the format from the extension.
  """
  array = np.asarray(image)
  if array.ndim == 3 and array.shape[2] == 1:
    array = array[:, :, 0]
  if array.dtype != np.uint8:
    low, high = float(array.min()), float(array.max())
    # A constant image would divide by zero; map it to all zeros instead.
    span = high - low if high > low else 1.0
    array = (((array - low) * (255.0 / span)).clip(0, 255) + 0.5).astype(np.uint8)
  return Image.fromarray(array).save(path)

def merge(images, size):
  """Tile a batch of equally sized patches into one image laid out on a grid.

  Patch k goes to grid row ``k // size[1]`` and column ``k % size[1]``. Patches
  are placed edge to edge with no overlap.

  Args:
    images: array of shape (N, h, w) or channels-last (N, h, w, C).
    size: (rows, cols) of the grid; rows * cols must be >= N.

  Returns:
    float64 array of shape (rows * h, cols * w, C), with C = 1 for 3-D input.
  """
  images = np.asarray(images)
  h, w = images.shape[1], images.shape[2]
  # Size the canvas from the input instead of hard-coding one channel, so
  # RGB/YCbCr batches no longer fail with a broadcast error.
  channels = images.shape[3] if images.ndim == 4 else 1
  img = np.zeros((h * size[0], w * size[1], channels))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    # reshape lets a 2-D (h, w) patch fill the single channel of the canvas.
    img[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = image.reshape(h, w, channels)

  return img

In [None]:
def modcrop(image, scale=3):
  """Trim `image` so its height and width are exact multiples of `scale`.

  Surplus rows are removed from the bottom and surplus columns from the right.
  This way a later down-scale by `scale` followed by an up-scale leaves no
  remainder. Works on 2-D (H, W) arrays and 3-D (H, W, C) arrays; the channel
  axis is never cropped. Any other rank raises ValueError when the shape is
  unpacked.
  """
  dims = image.shape
  if len(dims) == 3:
    height, width, _ = dims
  else:
    height, width = dims
  height -= np.mod(height, scale)
  width -= np.mod(width, scale)
  return image[:height, :width]

In [None]:
def _imread(path, is_grayscale=True):
  """Read an image file into a float32 array with values in [0, 255].

  Stand-in for the ``imread`` helper this notebook called but never defined.
  - Grayscale (the default) keeps only the luma channel, using Pillow's "F"
    conversion (same ITU-R 601-2 weights as the Y channel of YCbCr), with
    shape (H, W).
  - Otherwise the three YCbCr channels are returned, with shape (H, W, 3).
  """
  with Image.open(path) as img:
    # Go through RGB first so palette / alpha images convert cleanly.
    converted = img.convert('RGB').convert('F' if is_grayscale else 'YCbCr')
  return np.asarray(converted, dtype=np.float32)


def preprocess(path, scale=3, transform=None, is_grayscale=True):
  """
  Preprocess a single image file into an (input, label) pair.
    (1) Read the image as YCbCr (only the Y/luma channel by default)
    (2) Crop so height and width are divisible by `scale` (see modcrop)
    (3) Normalize to [0, 1] as float32
    (4) Optionally apply `transform` to both arrays

  The bicubic down/up-scaling degradation is currently disabled, so input_ and
  label_ hold identical values.

  Args:
    path: file path of the image to read.
    scale: factor passed to modcrop.
    transform: optional callable applied to both arrays. input_setup passes
      ``transforms.ToTensor()``, which turns these float arrays into (C, H, W)
      tensors without rescaling them again.
    is_grayscale: read only the luma channel (H, W) instead of YCbCr (H, W, 3).

  Returns:
    input_: network input (currently equal to the label).
    label_: image at its original (cropped) resolution.
  """
  image = _imread(path, is_grayscale=is_grayscale)
  # float32 so ToTensor yields float tensors rather than double ones.
  label_ = modcrop(image, scale).astype(np.float32) / 255.
  input_ = label_.copy()

  if transform is not None:
    input_ = transform(input_)
    label_ = transform(label_)

  return input_, label_

In [None]:


def _as_chw_tensor(image):
    """Return `image` as a channels-first (C, H, W) torch tensor.

    Tensors are assumed to be channels-first already (the ToTensor layout); a
    2-D tensor gets a leading channel axis. NumPy arrays are read as (H, W) or
    channels-last (H, W, C).
    """
    if torch.is_tensor(image):
        return image.unsqueeze(0) if image.ndim == 2 else image
    array = np.asarray(image)
    array = array[np.newaxis] if array.ndim == 2 else array.transpose(2, 0, 1)
    return torch.from_numpy(np.ascontiguousarray(array))


def _extract_patches(input_, label_, config, padding):
    """Slide a window over (C, H, W) tensors and collect aligned input/label patches.

    Returns:
        (input patches, label patches, nx, ny), where nx / ny count the window
        positions along height / width.
    """
    _, h, w = input_.shape
    xs = range(0, h - config.image_size + 1, config.stride)
    ys = range(0, w - config.image_size + 1, config.stride)
    sub_inputs, sub_labels = [], []
    for x in xs:
        for y in ys:
            sub_inputs.append(input_[:, x:x + config.image_size, y:y + config.image_size])
            lx, ly = x + padding, y + padding
            sub_labels.append(label_[:, lx:lx + config.label_size, ly:ly + config.label_size])
    return sub_inputs, sub_labels, len(xs), (len(ys) if len(xs) else 0)


def input_setup(config):
    """
    Read image files from a directory and create sub-images for training or testing.

    Training (config.is_train):
        Every image in <dataset_path>/Train is cut into image_size x image_size
        input patches (step config.stride). Each has a matching
        label_size x label_size label patch, and the pairs are shuffled
        together.

    Testing:
        One image from <dataset_path>/Test is used: index
        config.test_image_index (default 2) of the sorted file list. It is
        symmetric-padded so H and W are multiples of image_size, then cut into
        patches.

    Returns:
        training: (inputs, labels) stacked tensors of shape (N, C, size, size).
        testing:  (inputs, labels, nx, ny, pad_h, pad_w), where nx / ny are the
                  patch-grid rows / columns and pad_h / pad_w the padding added.

    Raises:
        ValueError: if no training patch could be extracted (no images found,
            or every image is smaller than image_size).
    """
    data_path = os.path.join(config.dataset_path, "Train" if config.is_train else "Test")
    # Sorted so the patch set and the test-image choice do not depend on the
    # file system's (arbitrary) listing order.
    image_files = sorted(
        os.path.join(data_path, f) for f in os.listdir(data_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    # Offset of the label patch inside the input patch (0 when sizes match).
    padding = abs(config.image_size - config.label_size) // 2
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    if config.is_train:
        sub_input_sequence, sub_label_sequence = [], []
        for img_file in image_files:
            input_, label_ = preprocess(img_file, config.scale, transform)
            # ToTensor output is (C, H, W): the spatial dims are the last two.
            patches_in, patches_lab, _, _ = _extract_patches(
                _as_chw_tensor(input_), _as_chw_tensor(label_), config, padding)
            sub_input_sequence.extend(patches_in)
            sub_label_sequence.extend(patches_lab)

        if not sub_input_sequence:
            raise ValueError(
                f"No {config.image_size}x{config.image_size} patches could be "
                f"extracted from {data_path!r}")

        # Shuffle input/label pairs together so they stay aligned.
        combined = list(zip(sub_input_sequence, sub_label_sequence))
        random.shuffle(combined)
        sub_input_sequence, sub_label_sequence = zip(*combined)
        return torch.stack(sub_input_sequence), torch.stack(sub_label_sequence)

    test_index = getattr(config, "test_image_index", 2)
    input_init, label_init = preprocess(image_files[test_index], config.scale, transform)
    input_init = _as_chw_tensor(input_init)
    label_init = _as_chw_tensor(label_init)

    _, h, w = input_init.shape
    # Pad only up to the next multiple of image_size; the previous formula
    # added a whole extra image_size when h (or w) was already a multiple.
    pad_h = (-h) % config.image_size
    pad_w = (-w) % config.image_size
    # Pad bottom/right of the spatial axes only, never the channel axis.
    pad_width = ((0, 0), (0, pad_h), (0, pad_w))
    input_ = torch.from_numpy(np.pad(input_init.numpy(), pad_width, 'symmetric'))
    label_ = torch.from_numpy(np.pad(label_init.numpy(), pad_width, 'symmetric'))

    sub_input_sequence, sub_label_sequence, nx, ny = _extract_patches(
        input_, label_, config, padding)
    return (torch.stack(sub_input_sequence), torch.stack(sub_label_sequence),
            nx, ny, pad_h, pad_w)




In [None]:
class Config:
    """Hyper-parameters and paths read by input_setup() and train()."""
    is_train = True                      # True: shuffled training patches; False: tile one test image
    dataset_path = './images/BSDS200'    # must contain "Train" and "Test" sub-directories
    image_size = 33                      # side length of input patches
    label_size = 33                      # side length of label patches (offset by the size difference // 2)
    scale = 1                            # modcrop factor; 1 leaves the image size unchanged
    stride = 14                          # step between neighbouring patches
    epoch = 2                            # passes over the training patches
    batch_size = 128                     # DataLoader batch size; train() reads this and it was missing
    checkpoint_dir = './new_checkpoints'

# Build the training patches once so their shapes can be inspected in the next cell.
# NOTE(review): train() calls input_setup(config) again, so these patches are
# regenerated (and reshuffled) there rather than reused.
config = Config()
train_inputs, train_labels = input_setup(config)

In [None]:
# Shapes of the stacked input and label patch tensors, one per line.
print(*(batch.shape for batch in (train_inputs, train_labels)), sep="\n")

In [None]:
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import save_image
from PIL import Image
import numpy as np

def _as_batch(patches):
    """Return `patches` as one stacked tensor (input_setup already stacks; lists are stacked here)."""
    return patches if torch.is_tensor(patches) else torch.stack(list(patches))


def _train_epochs(model, config, data_loader, start_time):
    """Run config.epoch optimisation passes, logging every 10 steps and checkpointing every 500 global steps."""
    model.model.train()
    global_step = 0
    for ep in range(config.epoch):
        for idx, (batch_images, batch_labels) in enumerate(data_loader):
            model.optimizer.zero_grad()
            outputs = model.model(batch_images)
            loss = model.criterion(outputs, batch_labels)
            loss.backward()
            model.optimizer.step()
            global_step += 1

            if idx % 10 == 0:
                print(f"Epoch: [{ep+1}/{config.epoch}], Step: [{idx}/{len(data_loader)}], "
                      f"Time: [{time.time() - start_time:.4f}], Loss: [{loss.item():.8f}]")

            # Use a global counter: `idx` restarts every epoch, so the old code
            # saved under the same step numbers in every epoch.
            if global_step % 500 == 0:
                model.save(config.checkpoint_dir, global_step)

    # Always persist the final weights, even for runs shorter than 500 steps.
    if global_step and global_step % 500 != 0:
        model.save(config.checkpoint_dir, global_step)


def _run_inference(model, data, nx, ny, pad_h, pad_w):
    """Predict every test patch, stitch the predictions with merge() and strip the test-time padding."""
    model.model.eval()
    with torch.no_grad():
        results = model.model(data).cpu().numpy()
    # NOTE(review): assumes the model returns (N, C, H, W) like its input
    # patches; merge() expects channels-last (N, H, W, C).
    # merge() tiles patches edge to edge, so the stitched image only matches
    # the source image when config.stride == config.image_size.
    stitched = merge(results.transpose(0, 2, 3, 1), [nx, ny]).squeeze()
    # shape[:2] works for both grayscale (H, W) and multi-channel (H, W, C).
    h, w = stitched.shape[:2]
    return stitched[:h - pad_h, :w - pad_w]


def _save_sample(results, sample_dir):
    """Write the stitched test output to <cwd>/<sample_dir>/test.png."""
    image_path = os.path.join(os.getcwd(), sample_dir, "test.png")
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    sample = torch.from_numpy(np.ascontiguousarray(results)).float()
    # save_image wants channels-first: (H, W) -> (1, H, W), (H, W, C) -> (C, H, W).
    sample = sample.unsqueeze(0) if sample.ndim == 2 else sample.permute(2, 0, 1)
    save_image(sample, image_path)


def train(model, config):
    """Train `model` on patches from input_setup(), or run it on the test image.

    Args:
        model: wrapper exposing .model (presumably an nn.Module), .device,
            .optimizer, .criterion, .load(dir) (truthy on success) and
            .save(dir, step).
        config: Config-like object. Reads is_train, epoch and checkpoint_dir,
            plus the optional batch_size (default 128) and sample_dir (when
            set, the test output is written to <sample_dir>/test.png).

    Returns:
        None when training. When testing, the stitched, un-padded prediction
        as a NumPy array.
    """
    # Prepare data
    if config.is_train:
        sub_input_sequence, sub_label_sequence = input_setup(config)
    else:
        sub_input_sequence, sub_label_sequence, nx, ny, pad_h, pad_w = input_setup(config)

    # input_setup() already returns stacked tensors; torch.stack() on a tensor raises.
    train_data = _as_batch(sub_input_sequence).to(model.device)

    start_time = time.time()
    if model.load(config.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    if config.is_train:
        print("Training...")
        train_label = _as_batch(sub_label_sequence).to(model.device)
        data_loader = DataLoader(TensorDataset(train_data, train_label),
                                 batch_size=getattr(config, "batch_size", 128),
                                 shuffle=True)
        _train_epochs(model, config, data_loader, start_time)
        return None

    print("Testing...")
    results = _run_inference(model, train_data, nx, ny, pad_h, pad_w)
    sample_dir = getattr(config, "sample_dir", None)
    if sample_dir:
        _save_sample(results, sample_dir)
    return results

In [None]:
# Train (or test, when config.is_train is False) with a fresh configuration.
# NOTE(review): `Model` is not defined in any cell visible here; it must come
# from another cell or module. train() uses its .model, .device, .optimizer,
# .criterion, .load() and .save() members.
config = Config()
model = Model(config)
train(model, config)