In [1]:
import torch
import os
import glob
import pickle
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn
from skimage import io

In [2]:
class Generator(nn.Module):
    """Encoder/decoder colouriser with skip connections.

    The encoder applies four average-pooling stages (factors 4, 2, 2, 2) and
    the decoder mirrors them with nearest-neighbour upsampling (factors
    2, 2, 2, 4). Each decoder stage receives the matching encoder activation
    concatenated along the channel axis. The output has 3 channels and is
    squashed into [-1, 1] by ``tanh``.

    Because of the pooling factors, the input height and width must be
    multiples of 32 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Shape-preserving 3x3 convolution used by every inner stage.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        # ---- encoder ----
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)

        self.av2 = nn.AvgPool2d(kernel_size=4)
        self.conv2 = conv3x3(32, 64)
        self.bn2 = nn.BatchNorm2d(64)

        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = nn.BatchNorm2d(128)

        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4 = conv3x3(128, 256)
        self.bn4 = nn.BatchNorm2d(256)

        # ---- bottleneck ----
        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5 = conv3x3(256, 256)
        self.bn5 = nn.BatchNorm2d(256)

        # ---- decoder (input channels double where a skip is concatenated) ----
        self.un6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv6 = conv3x3(256, 256)
        self.bn6 = nn.BatchNorm2d(256)

        self.un7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv7 = conv3x3(256 * 2, 128)
        self.bn7 = nn.BatchNorm2d(128)

        self.un8 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv8 = conv3x3(128 * 2, 64)
        self.bn8 = nn.BatchNorm2d(64)

        self.un9 = nn.UpsamplingNearest2d(scale_factor=4)
        self.conv9 = conv3x3(64 * 2, 32)
        self.bn9 = nn.BatchNorm2d(32)

        self.conv10 = nn.Conv2d(32 * 2, 3, kernel_size=5, stride=1, padding=2)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 3, H, W) batch in [-1, 1]."""

        def act(t):
            return F.relu(t, inplace=True)

        # Encoder: keep every stage's activation for the skip connections.
        skip1 = act(self.bn1(self.conv1(x)))
        skip2 = act(self.bn2(self.conv2(self.av2(skip1))))
        skip3 = act(self.bn3(self.conv3(self.av3(skip2))))
        skip4 = act(self.bn4(self.conv4(self.av4(skip3))))

        # Bottleneck and first upsampling stage.
        h = act(self.bn5(self.conv5(self.av5(skip4))))
        h = act(self.bn6(self.conv6(self.un6(h))))

        # Decoder: concatenate the skip, upsample, convolve.
        h = act(self.bn7(self.conv7(self.un7(torch.cat([h, skip4], dim=1)))))
        h = act(self.bn8(self.conv8(self.un8(torch.cat([h, skip3], dim=1)))))
        h = act(self.bn9(self.conv9(self.un9(torch.cat([h, skip2], dim=1)))))

        return self.tanh(self.conv10(torch.cat([h, skip1], dim=1)))

In [3]:
class Discriminator(nn.Module):
    """Patch-style convolutional critic.

    Five average-pooling stages halve the spatial size each time (total
    factor 32); a final 1x1 convolution produces one raw logit per remaining
    spatial location. No sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
        self.in1 = nn.InstanceNorm2d(16)

        self.av2 = nn.AvgPool2d(kernel_size=2)
        self.conv2_1 = conv3x3(16, 32)
        self.in2_1 = nn.InstanceNorm2d(32)
        self.conv2_2 = conv3x3(32, 32)
        self.in2_2 = nn.InstanceNorm2d(32)

        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3_1 = conv3x3(32, 64)
        self.in3_1 = nn.InstanceNorm2d(64)
        self.conv3_2 = conv3x3(64, 64)
        self.in3_2 = nn.InstanceNorm2d(64)

        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4_1 = conv3x3(64, 128)
        self.in4_1 = nn.InstanceNorm2d(128)
        self.conv4_2 = conv3x3(128, 128)
        self.in4_2 = nn.InstanceNorm2d(128)

        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5_1 = conv3x3(128, 256)
        self.in5_1 = nn.InstanceNorm2d(256)
        self.conv5_2 = conv3x3(256, 256)
        self.in5_2 = nn.InstanceNorm2d(256)

        self.av6 = nn.AvgPool2d(kernel_size=2)
        self.conv6 = conv3x3(256, 512)
        self.in6 = nn.InstanceNorm2d(512)

        self.conv7 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        """Return a (N, 1, H/32, W/32) map of raw logits for a (N, 3, H, W) batch."""
        # Each entry is (optional pooling, convolution, normalisation); every
        # stage is followed by LeakyReLU with negative slope 0.2.
        stages = (
            (None, self.conv1, self.in1),
            (self.av2, self.conv2_1, self.in2_1),
            (None, self.conv2_2, self.in2_2),
            (self.av3, self.conv3_1, self.in3_1),
            (None, self.conv3_2, self.in3_2),
            (self.av4, self.conv4_1, self.in4_1),
            (None, self.conv4_2, self.in4_2),
            (self.av5, self.conv5_1, self.in5_1),
            (None, self.conv5_2, self.in5_2),
            (self.av6, self.conv6, self.in6),
        )
        for pool, conv, norm in stages:
            if pool is not None:
                x = pool(x)
            x = F.leaky_relu(norm(conv(x)), 0.2, inplace=True)

        return self.conv7(x)

In [4]:
# Shape smoke test: push a random batch through the generator, then feed the
# generator output to the discriminator, and print both output sizes.
g = Generator()
r = Discriminator()

test_imgs = g(torch.randn([2, 3, 128, 128]))
test_res = r(test_imgs)

print("Generator_output", test_imgs.shape)
print("Discriminator_output", test_res.shape)

Generator_output torch.Size([2, 3, 128, 128])
Discriminator_output torch.Size([2, 1, 4, 4])


In [5]:
class DataAugment():
    """Random augmentation pipeline applied to an image before tensor conversion.

    The pipeline chains a random resized crop (scale range 0.9-1.0, output
    size ``resize``), a random horizontal flip and a random vertical flip.
    """

    def __init__(self, resize):
        steps = [
            transforms.RandomResizedCrop(resize, scale=(0.9, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
        # Attribute name (including its spelling) is kept as-is for existing users.
        self.data_tramsform = transforms.Compose(steps)

    def __call__(self, img):
        """Run ``img`` through the augmentation pipeline and return the result."""
        augmented = self.data_tramsform(img)
        return augmented

In [6]:
class MonoColorDataset(data.Dataset):
    """Dataset that yields ``(colour, grayscale)`` pairs built from image files.

    Args:
        file_list: Sequence of image file paths.
        transform_tensor: Callable applied to both the colour image and its
            grayscale version to produce the returned items.
        augment: Optional callable applied to the colour image before the
            grayscale copy is derived, so both items share the same augmentation.
    """

    def __init__(self, file_list, transform_tensor, augment=None):
        self.file_list = file_list
        self.augment = augment
        self.transform_tensor = transform_tensor

    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(colour, grayscale)``.

        The grayscale image keeps 3 channels so it has the same layout as the
        colour image.
        """
        img_path = self.file_list[index]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into an independent RGB image.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        if self.augment is not None:
            img = self.augment(img)

        # Fixed: this previously referenced an undefined global named
        # `transform_tensor`; the helper lives in torchvision's
        # `transforms.functional`. It returns a new image, so no explicit
        # copy of `img` is needed.
        img_gray = transforms.functional.to_grayscale(img, num_output_channels=3)

        img = self.transform_tensor(img)
        img_gray = self.transform_tensor(img_gray)

        return img, img_gray

In [7]:
def load_train_dataloade(file_path, batch_size):
    """Build the shuffled training DataLoader.

    Args:
        file_path: Sequence of training image paths, passed to
            ``MonoColorDataset`` as its file list.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` yielding ``(colour, grayscale)`` batches.

    NOTE(review): ``ImgTransform`` is not defined in the visible cells; it is
    assumed to be provided elsewhere in the notebook and to accept
    ``(size, mean, std)`` - confirm.
    """
    size = 128
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    # Fixed: use the `file_path` argument (the body used to read the global
    # `file_path_train`), and pass the transform under the keyword that
    # MonoColorDataset actually declares (`transform_tensor`).
    train_dataset = MonoColorDataset(
        file_path,
        transform_tensor=ImgTransform(size, mean, std),
        augment=DataAugment(size),
    )
    train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    return train_dataloader

In [8]:
def mat_grid_imgs(imgs, nrow, save_path=None):
    """Show the first ``nrow**2`` images of a batch as a square grid.

    Args:
        imgs: 4-D image batch tensor (indexed as ``imgs[i, :, :, :]``).
        nrow: Number of images per grid row; at most ``nrow**2`` are shown.
        save_path: If given, the normalised grid is also written to this
            path with ``skimage.io.imsave``.
    """
    grid = torchvision.utils.make_grid(
        imgs[0:(nrow ** 2), :, :, :], nrow=nrow, padding=5)
    # detach()/cpu() are no-ops for plain CPU tensors but let callers pass
    # tensors that still track gradients or live on another device.
    grid = grid.detach().cpu().numpy().transpose([1, 2, 0])

    # Min-max normalise for display; skip the division for a constant grid
    # so it stays all-zero instead of turning into NaNs.
    grid = grid - np.min(grid)
    peak = np.max(grid)
    if peak > 0:
        grid = grid / peak

    plt.imshow(grid)
    plt.xticks([])
    plt.yticks([])
    # Fixed: the stray positional `[]` previously passed to show() is removed.
    plt.show()

    if save_path is not None:
        io.imsave(save_path, grid)

In [9]:
def evaluate_test(file_path_test, model_G, device="cuda:0", nrow=4):
    """Display grayscale test inputs next to the generator's colourised output.

    For every batch of ``nrow**2`` test images, the grayscale inputs are shown
    as one grid and the generator output as a second grid.

    Args:
        file_path_test: Sequence of test image paths.
        model_G: Generator; it is moved to ``device``.
        device: Device used for inference.
        nrow: Grid width; also determines the batch size (``nrow**2``).

    NOTE(review): ``ImgTransform`` is not defined in the visible cells; it is
    assumed to be provided elsewhere in the notebook - confirm.
    """
    model_G = model_G.to(device)
    size = 128
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    # Fixed: MonoColorDataset's parameter is `transform_tensor`, not `transform`.
    test_dataset = MonoColorDataset(
        file_path_test,
        transform_tensor=ImgTransform(size, mean, std),
        augment=None,
    )
    test_dataloader = data.DataLoader(test_dataset, batch_size=nrow ** 2, shuffle=False)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for _, img_gray in test_dataloader:
            mat_grid_imgs(img_gray, nrow=nrow)
            img_fake = model_G(img_gray.to(device))
            # Fixed typo: `datach()` -> `detach()`.
            img_fake = img_fake.detach().to("cpu")
            mat_grid_imgs(img_fake, nrow=nrow)

In [10]:
from skimage import io, color, transform

def color_mono(image, threshold=150):
    """Classify an image as ``"color"`` or ``"mono"``.

    For each pair of the first three channels, the absolute value of the
    summed channel difference is divided by the pixel count; the three
    results are added and compared with ``threshold``.

    NOTE(review): the subtraction happens in the array's own dtype, so for
    unsigned 8-bit input negative differences wrap around. The default
    threshold was presumably tuned with that behaviour - confirm before
    changing the arithmetic.

    Args:
        image: Array indexed as ``image[row, col, channel]`` with >= 3 channels.
        threshold: Score above which the image counts as colour.

    Returns:
        ``"color"`` if the score exceeds ``threshold``, otherwise ``"mono"``.
    """
    image_size = image.shape[0] * image.shape[1]

    diff = np.abs(np.sum(image[:, :, 0] - image[:, :, 1])) / image_size
    diff += np.abs(np.sum(image[:, :, 0] - image[:, :, 2])) / image_size
    diff += np.abs(np.sum(image[:, :, 1] - image[:, :, 2])) / image_size

    if diff > threshold:
        return "color"
    return "mono"


# Fixed: `bright_check` and the conversion loop below used to be indented
# inside `color_mono` after its return statements, so they never executed.
def bright_check(image, ave_thres=0.15, std_thres=0.1):
    """Return True if the image is large enough and has usable brightness.

    The image is rejected (False) when it has fewer than 144 rows, when its
    mean gray level lies outside ``[ave_thres, 1 - ave_thres]``, when the
    gray-level standard deviation is below ``std_thres``, or when the
    grayscale conversion fails.
    """
    try:
        gray = color.rgb2gray(image)
    except Exception:
        # Best effort: an image that cannot be converted is simply rejected.
        return False

    if gray.shape[0] < 144:
        return False
    if np.average(gray) > (1. - ave_thres):
        return False
    if np.average(gray) < ave_thres:
        return False
    if np.std(gray) < std_thres:
        return False
    return True


def _center_crop_square(image):
    """Crop the largest centred square out of a (rows, cols, channels) array."""
    x = image.shape[0]
    y = image.shape[1]
    clip_half = min(x, y) / 2
    return image[int(x / 2 - clip_half): int(x / 2 + clip_half),
                 int(y / 2 - clip_half): int(y / 2 + clip_half), :]


def _export_color_images(src_pattern="./test2018/*", dst_dir="./trans"):
    """Save 144x144 centre crops of the colourful, well-lit source images.

    Files that cannot be read or processed are reported and skipped.
    Returns the number of images written.
    """
    # Sorted so the index embedded in each output name is reproducible.
    paths = sorted(glob.glob(src_pattern))
    if not paths:
        return 0

    os.makedirs(dst_dir, exist_ok=True)
    saved = 0
    for i, path in enumerate(paths):
        # os.path.join instead of a hard-coded backslash so the file lands
        # inside `dst_dir` on every platform.
        save_name = os.path.join(dst_dir, "mscoco_" + str(i) + ".png")
        try:
            image = io.imread(path)
            if image.ndim != 3:
                # No channel axis (e.g. single-channel file): cannot be "color".
                continue
            image = _center_crop_square(image)

            if color_mono(image) == "color" and bright_check(image):
                image = transform.resize(image, (144, 144, 3), anti_aliasing=True)
                image = np.uint8(image * 255)
                io.imsave(save_name, image)
                saved += 1
        except Exception as exc:
            # Keep going on bad files, but say which one was skipped and why.
            print("skipped", path, "-", exc)
    return saved


_export_color_images()

In [12]:
def train(model_G, model_D, epoch, epoch_plus):
    """Adversarially train the generator and discriminator.

    Each batch performs one generator step (adversarial BCE loss plus
    ``LAMBD`` times the L1 distance to the real colour image) followed by
    one discriminator step (BCE on real vs. detached fake images). After
    every epoch the mean losses are printed and logged and the generator is
    evaluated on the images matching ``test/*``.

    Args:
        model_G: Generator network.
        model_D: Discriminator network returning raw logits.
        epoch: Number of epochs to run.
        epoch_plus: Offset added to the epoch index in the printed log.

    Returns:
        ``(model_G, model_D, logs)`` where ``logs`` is
        ``[log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D]``,
        each a list with one mean value per epoch.

    NOTE(review): the training file list is read from the global
    ``file_path_train``, which is not defined in the visible cells - confirm
    it is set before this function is called.
    """
    # Fixed typo "cude:0"; fall back to the CPU when CUDA is unavailable.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    batch_size = 32
    LAMBD = 100.0  # weight of the L1 term in the generator loss

    model_G = model_G.to(device)
    model_D = model_D.to(device)

    # Fixed: `model_G,parameters()` (comma) -> `model_G.parameters()`.
    params_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()

    log_loss_G_sum, log_loss_G_bce, log_loss_G_mae = list(), list(), list()
    log_loss_D = list()

    # The loader does not change between epochs (it reshuffles on every
    # iteration), so build it once.
    train_dataloader = load_train_dataloade(file_path_train, batch_size)

    for i in range(epoch):
        loss_G_sum, loss_G_bce, loss_G_mae = list(), list(), list()
        loss_D = list()

        for real_color, input_gray in train_dataloader:
            real_color = real_color.to(device)
            # Fixed: the moved tensor used to be bound to a misspelt name,
            # leaving the generator input on the CPU.
            input_gray = input_gray.to(device)

            # ---- generator step ----
            fake_color = model_G(input_gray)
            out = model_D(fake_color)
            # Fixed: the loss call was missing its target. The generator
            # wants its output judged as real. Labels are shaped like the
            # discriminator output instead of a hard-coded 4x4 map.
            loss_G_bce_tmp = bce_loss(out, torch.ones_like(out))
            loss_G_mae_tmp = LAMBD * mae_loss(fake_color, real_color)
            loss_G_sum_tmp = loss_G_bce_tmp + loss_G_mae_tmp

            loss_G_bce.append(loss_G_bce_tmp.item())
            loss_G_mae.append(loss_G_mae_tmp.item())
            loss_G_sum.append(loss_G_sum_tmp.item())

            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum_tmp.backward()
            params_G.step()

            # ---- discriminator step ----
            real_out = model_D(real_color)
            # detach() keeps the discriminator loss from back-propagating
            # into the generator.
            fake_out = model_D(fake_color.detach())

            loss_D_real = bce_loss(real_out, torch.ones_like(real_out))
            loss_D_fake = bce_loss(fake_out, torch.zeros_like(fake_out))

            loss_D_tmp = loss_D_real + loss_D_fake
            loss_D.append(loss_D_tmp.item())

            # Also clears the discriminator gradients left by the generator step.
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D_tmp.backward()
            # Fixed: removed `params_D.backward()`; optimizers only step.
            params_D.step()

        # Fixed: per-epoch reporting now runs inside the epoch loop; it used
        # to run once after the loop and failed when `epoch` was 0.
        epoch_index = i + epoch_plus
        print(epoch_index, "loss_G", np.mean(loss_G_sum), "loss_D", np.mean(loss_D))
        log_loss_G_sum.append(np.mean(loss_G_sum))
        log_loss_G_bce.append(np.mean(loss_G_bce))
        log_loss_G_mae.append(np.mean(loss_G_mae))
        log_loss_D.append(np.mean(loss_D))

        file_path_test = glob.glob("test/*")
        evaluate_test(file_path_test, model_G, device)

    return model_G, model_D, [log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D]