In [1]:
import os
import random
import numpy as np
from PIL import Image
from copy import deepcopy

In [2]:
class Augumentation:
    '''Offline augmentation of a paired LR/HR super-resolution training set.

    Constructing an instance lists the LR and HR training directories and
    immediately runs every augmentation in ``callFunctions``.  Each augmentation
    draws its own subset of the *original* image indices (indices used by an
    earlier augmentation are removed from ``randomlist``, so subsets are
    disjoint), transforms the LR and HR image of each pair identically, and
    saves the result as a new ``NNNN.png`` pair numbered from ``self.length``.

    LR and HR images are paired by identical filename.  Output files are
    written into the same directories that were listed, so constructing the
    class again on the same folders augments the already-augmented set.
    '''

    def __init__(self, lrDir='./data/Train/LR/', hrDir='./data/Train/HR/', seed=None):
        '''
        Parameters
        lrDir - directory of low-resolution training images
        hrDir - directory of high-resolution images (same filenames as lrDir)
        seed  - optional int; seeds both ``random`` and ``np.random`` so the
                selected subsets and CutBlur boxes are reproducible
        '''
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        self.lrts = lrDir
        self.hrts = hrDir
        self.lsize = 128  # nominal LR side length (kept for callers; CutBlur uses real sizes)
        self.hsize = 512  # nominal HR side length (kept for callers; CutBlur uses real sizes)
        self.lrt = sorted(os.listdir(self.lrts))
        self.hrt = sorted(os.listdir(self.hrts))
        self.N = len(self.lrt)
        # Next output index.  NOTE(review): assumes the existing files are named
        # 0000.png ... so new indices start past them; a collision is refused in
        # _augmentPairs instead of silently overwriting training data.
        self.length = self.N
        # Indices of original images not yet used by any augmentation.
        self.randomlist = list(range(self.N))
        self.callFunctions()

    def callFunctions(self):
        '''Run every augmentation once, in a fixed order.'''
        self.VerticalFlip()
        self.HoriZontalFlip()
        self.rotate270()
        self.RGBBlend()
        self.cutBlur()

    def _pickIndices(self, count):
        '''Draw up to ``count`` unused original indices and mark them as used.'''
        count = max(0, min(count, len(self.randomlist)))
        chosen = set(random.sample(list(self.randomlist), count))
        self.randomlist = [i for i in self.randomlist if i not in chosen]
        return chosen

    def _augmentPairs(self, count, transform):
        '''Transform ``count`` random unused LR/HR pairs and save them as new pairs.

        ``transform(imageLR, imageHR)`` receives both images converted to RGB
        and must return the transformed ``(imageLR, imageHR)`` pair.

        Raises FileExistsError instead of overwriting an existing output file.
        '''
        for i in sorted(self._pickIndices(count)):
            imPath = self.lrt[i]
            # Load both images before writing anything, so a missing HR file
            # cannot leave an orphan LR output behind.
            with Image.open(os.path.join(self.lrts, imPath)) as lrFile, \
                    Image.open(os.path.join(self.hrts, imPath)) as hrFile:
                imageLR, imageHR = transform(lrFile.convert('RGB'), hrFile.convert('RGB'))
            name = f'{self.length:04d}.png'
            newLRPath = os.path.join(self.lrts, name)
            newHRPath = os.path.join(self.hrts, name)
            if os.path.exists(newLRPath) or os.path.exists(newHRPath):
                raise FileExistsError(f'refusing to overwrite existing {name}')
            imageLR.save(newLRPath)
            imageHR.save(newHRPath)
            self.length += 1

    def _transposePairs(self, count, method):
        '''Apply the same PIL ``transpose`` ``method`` to LR and HR of random pairs.'''
        self._augmentPairs(count, lambda lr, hr: (lr.transpose(method), hr.transpose(method)))

    def VerticalFlip(self, count=50):
        '''Add ``count`` top-to-bottom flipped copies of random training pairs.'''
        self._transposePairs(count, Image.FLIP_TOP_BOTTOM)

    def HoriZontalFlip(self, count=50):
        '''Add ``count`` left-to-right flipped copies of random training pairs.'''
        self._transposePairs(count, Image.FLIP_LEFT_RIGHT)

    def rotate270(self, count=50):
        '''Add ``count`` copies of random training pairs rotated by 270 degrees.'''
        self._transposePairs(count, Image.ROTATE_270)

    def RGBBlend(self, count=50):
        '''Add ``count`` copies of random training pairs with RGB channels shuffled.

        The same channel permutation is applied to LR and HR.  The identity
        permutation is excluded so every output differs from its source.
        '''
        def shuffleChannels(lr, hr):
            order = [0, 1, 2]
            while order == [0, 1, 2]:
                order = random.sample([0, 1, 2], 3)
            return (Image.fromarray(np.uint8(np.asarray(lr)[:, :, order])),
                    Image.fromarray(np.uint8(np.asarray(hr)[:, :, order])))
        self._augmentPairs(count, shuffleChannels)

    def cutblurHelper(self, im1, im2, prob=1.0, alpha=1):
        '''Mix a random rectangular region between an HR image and an upsampled LR image.

        Input
        im1 - hr, array indexed as [row, col, channel]
        im2 - lr upsampled to exactly im1's shape
        prob - probability of applying CutBlur; otherwise im2 is returned as-is
        alpha - mean side ratio of the box (jittered by N(0, 0.01^2), clipped to [0, 1]);
                alpha <= 0 disables the augmentation

        output
        modified lr: either im2 with the box filled from im1, or im1 with the box
        filled from im2 (chosen with equal probability).  Neither input is modified.

        Raises ValueError if im1 and im2 differ in shape.
        '''
        if im1.shape != im2.shape:
            raise ValueError("im1 and im2 have to be the same resolution.")

        if alpha <= 0 or np.random.rand() >= prob:
            return im2

        # Clip: a ratio > 1 makes the box larger than the image and
        # np.random.randint(0, h-ch+1) raises (low >= high); < 0 gives a negative box.
        cut_ratio = float(np.clip(np.random.randn() * 0.01 + alpha, 0.0, 1.0))
        h, w = im2.shape[0], im2.shape[1]
        ch, cw = int(h * cut_ratio), int(w * cut_ratio)
        cy = np.random.randint(0, h - ch + 1)
        cx = np.random.randint(0, w - cw + 1)

        # apply CutBlur to inside or outside
        if np.random.random() > 0.5:
            im2_aug = im2.copy()
            im2_aug[cy:cy+ch, cx:cx+cw, :] = im1[cy:cy+ch, cx:cx+cw, :]
        else:
            im2_aug = im1.copy()
            im2_aug[cy:cy+ch, cx:cx+cw, :] = im2[cy:cy+ch, cx:cx+cw, :]
        return im2_aug

    def cutBlur(self, count=100, alpha=0.6):
        '''Add ``count`` CutBlur pairs as new files: HR unchanged, LR partly replaced by HR.

        The LR image is upsampled (bilinear) to the HR image's size so the two
        can be mixed by ``cutblurHelper``, then resized back to its own size.
        '''
        def mix(lr, hr):
            lrSize = lr.size
            upLR = np.asarray(lr.resize(hr.size, Image.BILINEAR))
            mixed = self.cutblurHelper(np.asarray(hr), upLR, alpha=alpha)
            return Image.fromarray(np.uint8(mixed)).resize(lrSize), hr
        self._augmentPairs(count, mix)
        
        

In [3]:
# Constructing the class runs every augmentation immediately and saves new
# NNNN.png LR/HR pairs into ./data/Train/LR/ and ./data/Train/HR/.
# NOTE: not idempotent -- re-running this cell lists the already-augmented
# folders and augments them again.
ag = Augumentation()