<a href="https://colab.research.google.com/github/youngG124/infrared-image-colorize-with-GAN-model/blob/main/%EC%8B%9C%EC%97%B0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import glob
import numpy as np
import datetime
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch.autograd import Variable

In [None]:
def weights_init_normal(m):
    """Re-initialise the weights of a single module in place.

    Any module whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...)
    gets weights drawn from N(0, 0.02); its bias is left untouched.
    Modules whose class name contains "BatchNorm2d" get weights from
    N(1, 0.02) and a zero bias. Every other module is left as-is (note that
    InstanceNorm2d, used throughout this file, matches neither pattern).

    Suitable for ``model.apply(weights_init_normal)``.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

# U-NET 생성

class UNetDown(nn.Module):
    """One encoder stage of the U-Net.

    A 4x4, stride-2, padding-1 convolution (no bias) halves the spatial size,
    followed by an optional InstanceNorm2d, LeakyReLU(0.2) and an optional
    Dropout. Layer order inside ``self.model`` determines the state_dict keys,
    so it must stay fixed for saved checkpoints to load.

    Args:
        in_size: input channel count.
        out_size: output channel count.
        normalize: insert InstanceNorm2d after the convolution.
        dropout: dropout probability; 0 (falsy) means no Dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)
        norm = [nn.InstanceNorm2d(out_size)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """One decoder stage of the U-Net.

    A 4x4, stride-2, padding-1 transposed convolution (no bias) doubles the
    spatial size, followed by InstanceNorm2d, ReLU and an optional Dropout.
    The result is concatenated along the channel axis with the matching
    encoder activation (``skip_input``), so the output has
    ``out_size + skip_input.shape[1]`` channels.

    Args:
        in_size: input channel count.
        out_size: channel count produced by the transposed convolution.
        dropout: dropout probability; 0 (falsy) means no Dropout layer.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)


class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()
        
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)

class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    The two input images are concatenated on the channel axis, passed through
    four stride-2 conv blocks (64, 128, 256, 512 filters; the first without
    normalisation), then an asymmetric zero pad and a 4x4 conv to a single
    channel. The result is a spatial map of scores (16x16 for 256x256 inputs),
    not a single scalar.

    The order of layers in ``self.model`` fixes the state_dict indices and
    must match the saved checkpoints.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Return the layers of one downsampling block."""
            block = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                block.append(nn.InstanceNorm2d(out_filters))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        widths = [in_channels * 2, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the first block (stage 0) runs without InstanceNorm.
            layers.extend(discriminator_block(c_in, c_out, normalization=stage > 0))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        # Condition the decision on both images by stacking their channels.
        return self.model(torch.cat((img_A, img_B), dim=1))

In [None]:
root = ''

# Training hyper-parameters. This notebook only runs inference from a saved
# checkpoint, so most of these are not used by the cells below.
n_epochs = 100
dataset_name = "Night449"
lr = 0.0002
b1 = 0.5                    # adam: decay of first order momentum of gradient
b2 = 0.999                  # adam: decay of second order momentum of gradient
decay_epoch = 100           # epoch from which to start lr decay
#n_cpu = 8                   # number of cpu threads to use during batch generation
channels = 3                # number of image channels
checkpoint_interval = 20    # interval between model checkpoints


batch_size = 12
img_height = 256
img_width = 256

# Loss functions (training only; unused during inference)
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Output size of the PatchGAN discriminator: its four stride-2 blocks shrink
# each spatial dimension by 2**4.
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

cuda = torch.cuda.is_available()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Tensor type used to cast input batches onto the same device as the models.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
# Restore the "_47" checkpoints. Tensors are mapped to the CPU first so the
# files also load without a GPU; load_state_dict copies them into the model
# parameters wherever those live.
_cpu = torch.device('cpu')
generator.load_state_dict(
    torch.load("/content/drive/MyDrive/weights_with_new_mean_std/generator_47.pth", map_location=_cpu)
)
discriminator.load_state_dict(
    torch.load("/content/drive/MyDrive/weights_with_new_mean_std/discriminator_47.pth", map_location=_cpu)
)

<All keys matched successfully>

In [None]:
pip install ffmpeg-python

Collecting ffmpeg-python
  Downloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)
Installing collected packages: ffmpeg-python
Successfully installed ffmpeg-python-0.2.0


In [None]:
import cv2
import os
import shutil


def _reset_dir(path):
    """Delete ``path`` and everything under it (if it exists), then recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


# Remove the previous run's video so the ffmpeg step writes a fresh file.
if os.path.isfile('output.mp4'):
    os.remove('output.mp4')

# Start each run with empty frame directories: both the dataset and the ffmpeg
# step glob every file in them, so stale frames from an earlier (longer) video
# would otherwise leak into this run.
PREPROCESSING_DIR = '/content/drive/MyDrive/cut_images'
_reset_dir(PREPROCESSING_DIR)

AFTER_DIR = '/content/drive/MyDrive/colored_cut_images'
_reset_dir(AFTER_DIR)

  

###########################################
## 흑백 영상 이미지로 쪼개기

# import ffmpeg
# (
#     ffmpeg
#     .input("/content/drive/MyDrive/road.mp4")
#     .filter('fps', fps='24')
#     .output('/content/drive/MyDrive/cut_images/%04d.jpg', start_number = 0)
#     .overwrite_output()
#     .run()
# )


def vid_to_img(link, out_dir='/content/drive/MyDrive/cut_images'):
    """Split a video into numbered JPEG frames.

    Frames are written to ``out_dir`` as ``0000.jpg``, ``0001.jpg``, ... in
    decode order. A progress line is printed every 50 frames and the total
    frame count is printed at the end.

    Args:
        link: path/URL accepted by ``cv2.VideoCapture``.
        out_dir: directory receiving the frames (created if missing).

    Returns:
        The frame rate OpenCV reports for the video, rounded to the nearest
        integer. A container that reports no frame rate yields 0.

    Raises:
        OSError: if the video cannot be opened, ``out_dir`` cannot be
            created, or a frame cannot be written.
    """
    cam = cv2.VideoCapture(link)
    try:
        # Fail loudly instead of silently producing zero frames and fps 0.
        if not cam.isOpened():
            raise OSError("Cannot open video: " + str(link))

        fps = cam.get(cv2.CAP_PROP_FPS)
        os.makedirs(out_dir, exist_ok=True)

        currentframe = 0
        while True:
            ret, frame = cam.read()
            if not ret:  # end of stream (or unreadable frame)
                break

            # Zero-padded 4-digit index so sorted() returns frames in order.
            # NOTE(review): indices >= 10000 get 5 digits and no longer sort
            # correctly as strings; widen the padding for long videos.
            name = os.path.join(out_dir, "%04d" % currentframe + '.jpg')
            if currentframe % 50 == 0:
                print('전처리중...' + name)

            # imwrite reports failure (e.g. Drive not mounted) via its return value.
            if not cv2.imwrite(name, frame):
                raise OSError("Failed to write frame: " + name)

            currentframe += 1
    finally:
        cam.release()

    print("총 프레임 : " + str(currentframe))
    # round() rather than int(): truncating e.g. 29.97 to 29 makes the
    # re-encoded video play noticeably slower/longer than the source.
    return round(fps)


# Extract the grayscale frames and remember the source frame rate for the
# final re-encode.
video_fps = vid_to_img("/content/drive/MyDrive/g.mp4")
print(f"fps : {video_fps}")


#################################
# 저장된 흑백 이미지들 셋에 싣기

class NightTestDataset(Dataset):
    """Inference dataset over the frames extracted from the input video.

    Every ``*.*`` file directly inside ``root`` is one item, returned in sorted
    filename order. Each item is ``{"A": tensor}``, where the tensor is the
    frame (converted to RGB) after ``gray_transforms_``.

    Args:
        root: directory containing the frame images.
        color_transforms_: accepted for interface compatibility with the
            training dataset; no colour target exists at inference time, so it
            is composed but never applied.
        gray_transforms_: list of torchvision transforms applied to each frame.
            ``None`` means no transform.
    """

    def __init__(self, root, color_transforms_=None, gray_transforms_=None):
        self.color_transforms = transforms.Compose(color_transforms_ or [])
        self.gray_transforms = transforms.Compose(gray_transforms_ or [])
        # Glob inside ``root`` itself. Joining ``root`` with an absolute path
        # would make os.path.join discard ``root`` entirely.
        self.gray_files = sorted(glob.glob(os.path.join(root, "*.*")))
        # Kept as an alias for compatibility; there are no separate colour files.
        self.color_files = self.gray_files

    def __getitem__(self, index):
        path = self.gray_files[index % len(self.gray_files)]
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(path) as img:
            gray_img = img.convert("RGB")
        return {"A": self.gray_transforms(gray_img)}

    def __len__(self):
        return len(self.gray_files)




# Normalisation statistics. The colour values match the commonly used ImageNet
# RGB mean/std; the gray values are a single statistic replicated over the
# three channels produced by Grayscale(num_output_channels=3).
color_mean = [0.485, 0.456, 0.406]
color_std = [0.229, 0.224, 0.225]
gray_mean = [0.44] * 3
gray_std = [0.26] * 3

color_transforms_ = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=color_mean, std=color_std),
]

# Resize to 256x256, convert to a tensor, collapse to gray replicated over
# 3 channels, then normalise.
gray_transforms_ = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Normalize(mean=gray_mean, std=gray_std),
]



# Directory holding the frames extracted by vid_to_img (root is '' here).
test_root = root + '/content/drive/MyDrive/cut_images'
test_batch_size = 6

# shuffle=False keeps frames in sorted (video) order, which test_images
# relies on when numbering its output files.
test_loader = DataLoader(
    dataset=NightTestDataset(
        test_root,
        color_transforms_=color_transforms_,
        gray_transforms_=gray_transforms_,
    ),
    batch_size=test_batch_size,
    shuffle=False,
)


def reNormalize(img, mean, std):
    """Undo channel-wise normalisation and return an HWC image clipped to [0, 1].

    Args:
        img: CHW image as a torch tensor (any device; gradients allowed) or a
            numpy array.
        mean: per-channel means used for the original normalisation.
        std: per-channel standard deviations used for the original normalisation.

    Returns:
        numpy array of shape (H, W, C) equal to ``img * std + mean``, clipped
        to [0, 1].
    """
    if isinstance(img, torch.Tensor):
        # .numpy() refuses tensors that require grad or live on the GPU.
        img = img.detach().cpu().numpy()
    img = np.asarray(img).transpose(1, 2, 0)
    img = img * np.asarray(std) + np.asarray(mean)
    return img.clip(0, 1)




#################################
# eval모드, 이미지 채색해서 저장

# Inference only: put both networks in eval mode (disables the Dropout layers
# used in the U-Net).
for _net in (generator, discriminator):
    _net.eval()

from skimage import *
from skimage import io

def test_images(epoch, loader, mode):
    """Colourise every frame yielded by ``loader`` and save each as a JPEG.

    Output files are /content/drive/MyDrive/colored_cut_images/0000.jpg,
    0001.jpg, ... numbered in the order ``loader`` yields frames. A progress
    line is printed every 50 frames.

    Args:
        epoch: unused; kept for call compatibility.
        loader: DataLoader whose batches are dicts holding the grayscale input
            tensor under key ``"A"``; its dataset must support ``len()``.
        mode: unused; kept for call compatibility.
    """
    total = len(loader.dataset)
    frame_idx = 0  # running index, independent of the loader's batch size
    # no_grad: inference only, so don't build an autograd graph per batch.
    with torch.no_grad():
        for batch in loader:
            gray = batch["A"].type(Tensor)
            # Copy the whole batch to the CPU once, not once per frame.
            output = generator(gray).cpu().numpy()
            for frame in output:
                out_img = np.transpose(frame, (1, 2, 0))  # CHW -> HWC
                # NOTE(review): the Tanh output (range [-1, 1]) is saved as-is and
                # img_as_ubyte clips negative values to 0. If the model was
                # trained on targets normalised with color_mean/color_std, map it
                # back with reNormalize(...) first — confirm against training code.
                io.imsave("/content/drive/MyDrive/colored_cut_images/" + "%04d" % frame_idx + ".jpg",
                          img_as_ubyte(out_img))
                if frame_idx % 50 == 0:
                    print('후처리중...' + str(frame_idx) + '/' + str(total))
                frame_idx += 1

# Colourise all extracted frames (epoch and mode are passed through unused).
test_images(n_epochs, test_loader, mode='test')


##############################
# 채색된 이미지들 mp4로 output

import ffmpeg

# Re-assemble the colourised frames into output.mp4 at the source frame rate.
# ffmpeg's glob input orders the matched files by name.
colored_frames = ffmpeg.input(
    '/content/drive/MyDrive/colored_cut_images/*.jpg',
    pattern_type='glob',
    framerate=video_fps,
)
colored_frames.output('output.mp4').run()


전처리중.../content/drive/MyDrive/cut_images/0000.jpg
전처리중.../content/drive/MyDrive/cut_images/0050.jpg
전처리중.../content/drive/MyDrive/cut_images/0100.jpg
전처리중.../content/drive/MyDrive/cut_images/0150.jpg
전처리중.../content/drive/MyDrive/cut_images/0200.jpg
전처리중.../content/drive/MyDrive/cut_images/0250.jpg
전처리중.../content/drive/MyDrive/cut_images/0300.jpg
전처리중.../content/drive/MyDrive/cut_images/0350.jpg
전처리중.../content/drive/MyDrive/cut_images/0400.jpg
전처리중.../content/drive/MyDrive/cut_images/0450.jpg
전처리중.../content/drive/MyDrive/cut_images/0500.jpg
전처리중.../content/drive/MyDrive/cut_images/0550.jpg
전처리중.../content/drive/MyDrive/cut_images/0600.jpg
전처리중.../content/drive/MyDrive/cut_images/0650.jpg
전처리중.../content/drive/MyDrive/cut_images/0700.jpg
전처리중.../content/drive/MyDrive/cut_images/0750.jpg
전처리중.../content/drive/MyDrive/cut_images/0800.jpg
전처리중.../content/drive/MyDrive/cut_images/0850.jpg
전처리중.../content/drive/MyDrive/cut_images/0900.jpg
전처리중.../content/drive/MyDrive/cut_images/0950.jpg


(None, None)