# Colab-ICT

Original: [raywzy/ICT](https://github.com/raywzy/ICT)

My fork: [styler00dollar/Colab-ICT](https://github.com/styler00dollar/Colab-ICT)

In [None]:
# Show which GPU this runtime was assigned; the model code later in this
# notebook calls .cuda() unconditionally, so a GPU runtime is required.
!nvidia-smi

In [None]:
#@title install
# Work from /content explicitly so re-running this cell after a later `%cd`
# does not clone a second, nested checkout.
%cd /content
!git clone https://github.com/raywzy/ICT
%cd /content/ICT
# Quote the URL: an unquoted `?` is a shell glob character.
!wget -O ckpts_ICT.zip "https://www.dropbox.com/s/cqjgcj0serkbdxd/ckpts_ICT.zip?dl=1"
# -o: overwrite existing files without an interactive prompt on re-runs.
!unzip -o ckpts_ICT.zip
# -p: do not fail when the folders already exist.
!mkdir -p /content/input /content/masks /content/output

In [None]:
#@title Guided_Upsample/src/Guided_Upsampler.py (fixing with .cuda())
%%writefile /content/ICT/Guided_Upsample/src/Guided_Upsampler.py
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from .dataset_my import Dataset
from .models import  InpaintingModel
from .utils import Progbar, create_dir, stitch_images, imsave
from .metrics import PSNR
import torchvision.utils as vutils

class Guided_Upsampler():
    """Train / evaluate / test driver for the ICT guided upsampler.

    Wraps an ``InpaintingModel`` together with the datasets selected by
    ``config.MODE`` (``2`` builds the test set; any other value builds the
    training and validation sets). Only ``config.MODEL == 2`` (the inpainting
    model) has a train / eval / test / sample implementation here;
    ``MODEL == 1`` is accepted by the constructor (logs are named ``edge``)
    but the run methods reject it with ``NotImplementedError``.
    """

    def __init__(self, config):
        self.config = config

        model_names = {1: 'edge', 2: 'inpaint'}
        if config.MODEL not in model_names:
            # Fail fast with a clear message instead of an unbound ``model_name``.
            raise ValueError('Unsupported config.MODEL %r (expected 1 or 2)' % (config.MODEL,))
        model_name = model_names[config.MODEL]

        self.debug = False
        self.model_name = model_name

        self.inpaint_model = InpaintingModel(config).cuda()

        self.psnr = PSNR(255.0).cuda()

        # test mode
        if self.config.MODE == 2:
            self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, config.TEST_MASK_FLIST, augment=False, training=False)
        else:
            self.train_dataset = Dataset(config, config.TRAIN_FLIST, config.TRAIN_EDGE_FLIST, config.TRAIN_MASK_FLIST, augment=True, training=True)
            self.val_dataset = Dataset(config, config.VAL_FLIST, config.VAL_EDGE_FLIST, config.VAL_MASK_FLIST, augment=False, training=True)
            self.sample_iterator = self.val_dataset.create_iterator(config.SAMPLE_SIZE)

        self.samples_path = os.path.join(config.PATH, 'samples')
        self.results_path = os.path.join(config.PATH, 'results')

        if config.RESULTS is not None:
            self.results_path = os.path.join(config.RESULTS)

        if config.DEBUG is not None and config.DEBUG != 0:
            self.debug = True

        self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat')

    def load(self):
        """Load model weights via ``InpaintingModel.load``."""
        self.inpaint_model.load()

    def save(self):
        """Save model weights via ``InpaintingModel.save``."""
        self.inpaint_model.save()

    def _require_inpaint_model(self):
        """Raise unless ``config.MODEL == 2``, the only model with a run path here."""
        if self.config.MODEL != 2:
            raise NotImplementedError(
                'Only config.MODEL == 2 (inpaint) is implemented by Guided_Upsampler')

    @staticmethod
    def _split_name(name):
        """Split ``name`` into ``(stem, extension)``; the extension keeps its dot."""
        return os.path.splitext(name)

    def _metrics(self, images, outputs_merged):
        """Return ``[('psnr', float), ('mae', float)]`` for a merged output batch."""
        psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
        mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
        return [('psnr', psnr.item()), ('mae', mae.item())]

    def train(self):
        """Train until ``config.MAX_ITERS`` iterations have been reached.

        Every ``LOG_INTERVAL`` / ``SAMPLE_INTERVAL`` / ``EVAL_INTERVAL`` /
        ``SAVE_INTERVAL`` iterations (when the interval is truthy) the loop
        appends to the log file, writes a sample grid, evaluates, or saves.
        """
        self._require_inpaint_model()

        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.config.BATCH_SIZE,
            num_workers=4,
            drop_last=True,
            shuffle=True
        )

        epoch = 0
        keep_training = True
        max_iteration = int(float(self.config.MAX_ITERS))
        total = len(self.train_dataset)

        if total == 0:
            print('No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.')
            return

        if len(train_loader) == 0:
            # drop_last=True discards the final partial batch, so a dataset smaller
            # than BATCH_SIZE yields no batches and the loop below would never end.
            print('Training set has fewer samples than BATCH_SIZE; no full batch can be formed.')
            return

        while keep_training:
            epoch += 1
            print('\n\nTraining epoch: %d' % epoch)

            progbar = None if self.config.No_Bar else Progbar(total, width=20, stateful_metrics=['epoch', 'iter'])

            for items in train_loader:
                self.inpaint_model.train()

                images, edges, masks = self.cuda(*items)

                # process() runs the forward pass and steps both optimizers.
                outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks)
                outputs_merged = (outputs * masks) + (images * (1 - masks))
                logs.extend(self._metrics(images, outputs_merged))

                iteration = self.inpaint_model.iteration
                if iteration >= max_iteration:
                    keep_training = False
                    break

                logs = [
                    ("epoch", epoch),
                    ("iter", iteration),
                ] + logs

                if progbar is not None:
                    progbar.add(len(images), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])

                # log model at checkpoints
                if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0:
                    self.log(logs)

                # sample model at checkpoints
                if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0:
                    self.sample()

                # evaluate model at checkpoints
                if self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0:
                    print('\nstart eval...\n')
                    self.eval()

                # save model at checkpoints
                if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0:
                    self.save()

        print('\nEnd training....')

    def eval(self):
        """Report PSNR / MAE of the model on the validation set.

        Uses only the model's forward pass under ``torch.no_grad()``:
        ``InpaintingModel.process`` steps both optimizers and increments the
        iteration counter, which must not happen while evaluating.
        """
        self._require_inpaint_model()

        val_loader = DataLoader(
            dataset=self.val_dataset,
            batch_size=self.config.BATCH_SIZE,
            drop_last=True,
            shuffle=False
        )

        total = len(self.val_dataset)

        self.inpaint_model.eval()

        progbar = None if self.config.No_Bar else Progbar(total, width=20, stateful_metrics=['it'])
        iteration = 0

        with torch.no_grad():
            for items in val_loader:
                iteration += 1
                images, edges, masks = self.cuda(*items)

                outputs = self.inpaint_model(images, edges, masks)
                outputs_merged = (outputs * masks) + (images * (1 - masks))

                logs = [("it", iteration), ] + self._metrics(images, outputs_merged)
                if progbar is not None:
                    progbar.add(len(images), values=logs)

    def test(self):
        """Run inference over the test set and write the results.

        With ``config.same_face`` each batch is saved as one grid of
        (images, edges, masked images, outputs); otherwise the first image of
        each batch is written as ``<stem>_<k>.png`` with
        ``k = index % config.condition_num``.
        """
        self._require_inpaint_model()
        self.inpaint_model.eval()

        create_dir(self.results_path)

        test_loader = DataLoader(
            dataset=self.test_dataset,
            batch_size=self.config.test_batch_size,
        )

        index = 0
        for items in test_loader:

            name = self.test_dataset.load_name(index)

            print(name)

            # splitext copes with any extension length and with dots in the stem.
            stem, ext = self._split_name(name)

            if self.config.same_face:
                path = os.path.join(self.results_path, name)
            else:
                path = os.path.join(self.results_path, stem + "_%d" % (index % self.config.condition_num) + '.png')

            images, edges, masks = self.cuda(*items)
            index += self.config.test_batch_size

            with torch.no_grad():
                outputs = self.inpaint_model(images, edges, masks)
            if self.config.merge:
                outputs_merged = (outputs * masks) + (images * (1 - masks))
            else:
                outputs_merged = outputs

            if self.config.same_face:
                all_tensor = torch.cat([images, edges, images * (1 - masks), outputs_merged], dim=0)
                vutils.save_image(all_tensor, path, nrow=self.config.test_batch_size, padding=0, normalize=False)
                print(index, name)
            else:
                output = self.postprocess(outputs_merged)[0]
                print(index, name)
                imsave(output, path)

            if self.debug:
                edges = self.postprocess(1 - edges)[0]
                masked = self.postprocess(images * (1 - masks) + masks)[0]

                imsave(edges, os.path.join(self.results_path, stem + '_edge' + ext))
                imsave(masked, os.path.join(self.results_path, stem + '_masked' + ext))

        print('\nEnd test....')

    def sample(self, it=None):
        """Save a grid of (images, inputs, edges, outputs, merged) for one validation batch.

        The file is named after ``it`` when given, otherwise after the model's
        current iteration counter.
        """
        # do not sample when validation set is empty
        if len(self.val_dataset) == 0:
            return

        self._require_inpaint_model()
        self.inpaint_model.eval()

        items = next(self.sample_iterator)
        images, edges, masks = self.cuda(*items)

        iteration = self.inpaint_model.iteration if it is None else it
        inputs = (images * (1 - masks)) + masks
        with torch.no_grad():
            outputs = self.inpaint_model(images, edges, masks)
        outputs_merged = (outputs * masks) + (images * (1 - masks))

        image_per_row = 1 if self.config.SAMPLE_SIZE <= 6 else 2

        images = stitch_images(
            self.postprocess(images),
            self.postprocess(inputs),
            self.postprocess(edges),
            self.postprocess(outputs),
            self.postprocess(outputs_merged),
            img_per_row = image_per_row
        )

        path = os.path.join(self.samples_path, self.model_name)
        name = os.path.join(path, str(iteration).zfill(5) + ".png")
        create_dir(path)
        print('\nsaving sample ' + name)
        images.save(name)

    def log(self, logs):
        """Append the values of ``logs`` (a list of ``(key, value)`` pairs) as one line."""
        with open(self.log_file, 'a') as f:
            f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))

    def cuda(self, *args):
        """Move each tensor to ``config.DEVICE``; returns a generator for unpacking."""
        return (item.to(self.config.DEVICE) for item in args)

    def postprocess(self, img):
        """Scale a ``[0, 1]`` 4-D batch to ``[0, 255]`` ints, moving dim 1 last.

        Presumably NCHW -> NHWC (dim 1 being channels) — confirm against callers.
        """
        img = img * 255.0
        img = img.permute(0, 2, 3, 1)
        return img.int()





In [None]:
#@title Transformer/inference.py (removing broken import, replacing with glob)
%%writefile /content/ICT/Transformer/inference.py
## Inference

import numpy as np
import torchvision
import torch
import matplotlib.pyplot as plt
import logging
from utils.util import set_seed
#from datas.dataset import ImageDataset
from models.model import GPTConfig,GPT
import argparse
from utils.util import sample_mask,sample_mask_all
from tqdm import tqdm
from PIL import Image
import os
import time

import glob

def list_png_files(folder):
    """Return every ``*.png`` file below ``folder`` (searched recursively), sorted by path."""
    return sorted(glob.glob(os.path.join(folder, '**', '*.png'), recursive=True))


def is_matching_pair(image_path, mask_path):
    """Return True when an image and its mask have the same file name.

    Images and masks live in different folders, so their full paths always
    differ; pairing is checked on the base name only.
    """
    return os.path.basename(image_path) == os.path.basename(mask_path)


if __name__=='__main__':

    parser=argparse.ArgumentParser()
    parser.add_argument('--GPU_ids',type=str,default='0')
    parser.add_argument('--ckpt_path',type=str,default='./ckpt')
    parser.add_argument('--BERT',action='store_true', help='BERT model, Image Completion')
    parser.add_argument('--image_url',type=str,default='',help='the folder of image')
    parser.add_argument('--mask_url',type=str,default='',help='the folder of mask')
    parser.add_argument('--top_k',type=int,default=100)

    parser.add_argument('--image_size',type=int,default=32,help='input sequence length: image_size*image_size')

    parser.add_argument('--n_layer',type=int,default=14)
    parser.add_argument('--n_head',type=int,default=8)
    parser.add_argument('--n_embd',type=int,default=256)
    parser.add_argument('--GELU_2',action='store_true',help='use the new activation function')

    parser.add_argument('--save_url',type=str,default='./',help='save the output results')
    parser.add_argument('--n_samples',type=int,default=8,help='sample cnt')

    parser.add_argument('--sample_all',action='store_true',help='sample all pixel together, ablation use')
    parser.add_argument('--skip_number',type=int,default=0,help='since the inference is slow, skip the image which has been inferenced')

    parser.add_argument('--no_progressive_bar',action='store_true',help='')

    opts=parser.parse_args()

    s_time=time.time()

    model_config=GPTConfig(512,opts.image_size*opts.image_size,
                           embd_pdrop=0.0, resid_pdrop=0.0, 
                           attn_pdrop=0.0, n_layer=opts.n_layer, n_head=opts.n_head,
                           n_embd=opts.n_embd, BERT=opts.BERT, use_gelu2=opts.GELU_2)

    # Load model
    IGPT_model=GPT(model_config)
    checkpoint=torch.load(opts.ckpt_path)

    # A .pt file is a bare state dict; other checkpoints nest it under 'model'.
    if opts.ckpt_path.endswith('.pt'):
        IGPT_model.load_state_dict(checkpoint)
    else:
        IGPT_model.load_state_dict(checkpoint['model'])

    IGPT_model.cuda()

    # Load the colour palette (k-means centres) from the working directory.
    # rint(127.5 * (C + 1)) maps [-1, 1] to [0, 255], so the centres are
    # presumably stored in [-1, 1] (the upstream comment said [0, 1]) — confirm.
    C = np.load('kmeans_centers.npy')
    C = np.rint(127.5 * (C + 1.0))
    C = torch.from_numpy(C)

    n_samples=opts.n_samples

    img_list = list_png_files(opts.image_url)
    mask_list = list_png_files(opts.mask_url)

    if not img_list:
        print("### No .png images found under %s ###" % opts.image_url)
    if len(img_list) != len(mask_list):
        # zip() below silently drops the unmatched tail of the longer list.
        print("### %d images but %d masks; unmatched files are skipped ###" % (len(img_list), len(mask_list)))

    if opts.skip_number>0:
        img_list=img_list[opts.skip_number-1:]
        mask_list=mask_list[opts.skip_number-1:]
        if img_list:
            print("Resume from %s"%(img_list[0]))


    if opts.BERT:

        for x_name,y_name in zip(img_list,mask_list):

            if not is_matching_pair(x_name, y_name):
                print("### Something Wrong ###", x_name, y_name)

            # glob already returned complete paths; joining them onto the folder
            # again would double a relative folder prefix.
            input_image=Image.open(x_name).convert("RGB")
            x = input_image.resize((opts.image_size,opts.image_size),resample=Image.BILINEAR)
            x = torch.from_numpy(np.array(x)).view(-1, 3)
            x = x.float()
            # Index of the nearest palette centre (squared RGB distance) per pixel.
            a = ((x[:, None, :] - C[None, :, :])**2).sum(-1).argmin(1)

            input_mask=Image.open(y_name).convert("L")
            y = input_mask.resize((opts.image_size,opts.image_size),resample=Image.NEAREST)
            y = torch.from_numpy(np.array(y)/255.).view(-1)
            y = y>0.5
            y = y.float()

            a_list=[a]*n_samples
            a_tensor=torch.stack(a_list,dim=0) ## Input images
            b_list=[y]*n_samples
            b_tensor=torch.stack(b_list,dim=0) ## Input masks
            # Zero the palette index wherever the mask is 1.
            a_tensor*=(1-b_tensor).long()

            if opts.sample_all:
                pixels=sample_mask_all(IGPT_model,context=a_tensor,length=opts.image_size*opts.image_size,num_sample=n_samples,top_k=opts.top_k,mask=b_tensor,no_bar=opts.no_progressive_bar)
            else:
                pixels=sample_mask(IGPT_model,context=a_tensor,length=opts.image_size*opts.image_size,num_sample=n_samples,top_k=opts.top_k,mask=b_tensor,no_bar=opts.no_progressive_bar)

            img_name=os.path.splitext(x_name)[0]+'.png'
            for i in range(n_samples):

                # One folder per sample: condition_1 ... condition_<n_samples>.
                current_url=os.path.join(opts.save_url,'condition_%d'%(i+1))
                os.makedirs(current_url,exist_ok=True)
                current_img=C[pixels[i]].view(opts.image_size, opts.image_size, 3).numpy().astype(np.uint8)
                tmp=Image.fromarray(current_img)

                tmp.save(os.path.join(current_url,os.path.basename(img_name)))
            print("Finish %s"%(img_name))
        
        e_time=time.time()
        print("This test totally costs %.5f seconds"%(e_time-s_time))


In [None]:
#@title Guided_Upsample/src/models.py (fixing state_dict)
%%writefile /content/ICT/Guided_Upsample/src/models.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from .networks import Discriminator, Discriminator2, InpaintGenerator_5
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss


class BaseModel(nn.Module):
    """Checkpoint handling shared by the guided-upsampling models.

    Subclasses register ``generator`` and ``discriminator`` sub-modules;
    ``load`` / ``save`` use ``<config.PATH>/<name>_gen.pth`` and
    ``<config.PATH>/<name>_dis.pth``.
    """

    def __init__(self, name, config):
        super(BaseModel, self).__init__()

        self.name = name
        self.config = config
        self.iteration = 0

        self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
        self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')

    @staticmethod
    def _read_checkpoint(path):
        """Load ``path``, mapping storages to CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    @staticmethod
    def _extract_state_dict(data, key):
        """Return a plain state dict from either checkpoint layout.

        Accepts the wrapped layout written by ``save()`` (``{key: state_dict}``)
        as well as a bare state dict, and strips the ``module.`` prefix that
        ``nn.DataParallel`` adds to parameter names.
        """
        state = data[key] if key in data else data
        prefix = 'module.'
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state.items()
        }

    def load(self):
        """Load generator weights (and discriminator weights when training or scoring).

        Missing checkpoint files are skipped silently.
        """
        if os.path.exists(self.gen_weights_path):
            print('Loading %s generator...' % self.name)

            data = self._read_checkpoint(self.gen_weights_path)
            self.generator.load_state_dict(self._extract_state_dict(data, 'generator'))
            # Only checkpoints written by save() carry an iteration counter;
            # bare state dicts start from 0.
            self.iteration = data.get('iteration', 0) if 'generator' in data else 0

        # load discriminator only when training
        if (self.config.MODE == 1 or self.config.score) and os.path.exists(self.dis_weights_path):
            print('Loading %s discriminator...' % self.name)

            data = self._read_checkpoint(self.dis_weights_path)
            self.discriminator.load_state_dict(self._extract_state_dict(data, 'discriminator'))

    def save(self):
        """Write generator (with iteration) and discriminator checkpoints."""
        print('\nsaving %s...\n' % self.name)
        torch.save({
            'iteration': self.iteration,
            'generator': self.generator.state_dict()
        }, self.gen_weights_path)

        torch.save({
            'discriminator': self.discriminator.state_dict()
        }, self.dis_weights_path)


class InpaintingModel(BaseModel):
    """Generator + discriminator pair used by the ICT guided upsampler.

    The generator receives the masked image concatenated with ``edges`` along
    dim 1; the discriminator sees the image alone. Only
    ``config.Generator == 4`` (``InpaintGenerator_5``) can be constructed here.
    """

    def __init__(self, config):
        super(InpaintingModel, self).__init__('InpaintingModel', config)

        # generator input: [rgb(3) + edge(1)]
        # discriminator input: [rgb(3)]
        # NOTE(review): the channel count of `edges` above is inherited from the
        # upstream comment — confirm against the dataset.

        if config.Generator == 4:
            print("*******remove IN*******")
            generator = InpaintGenerator_5()
        else:
            # Any other value previously crashed later with an unbound `generator`.
            raise ValueError(
                'Unsupported config.Generator %r; only 4 (InpaintGenerator_5) is available'
                % (config.Generator,))

        if config.Discriminator==0:
            discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')
        else:
            discriminator = Discriminator2(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')

        l1_loss = nn.L1Loss()
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    def process(self, images, edges, masks):
        """Run one training step and return ``(outputs, gen_loss, dis_loss, logs)``.

        Increments ``self.iteration``, steps the discriminator optimizer and
        then the generator optimizer (adversarial + L1 + perceptual + style).
        """
        self.iteration += 1

        # zero optimizers
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()

        # process outputs
        outputs = self(images, edges, masks)
        gen_loss = 0
        dis_loss = 0

        # discriminator loss (generator output detached so only D is updated)
        dis_input_real = images
        dis_input_fake = outputs.detach()
        dis_real, _ = self.discriminator(dis_input_real)                    # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)                    # in: [rgb(3)]
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        dis_loss.backward()
        self.dis_optimizer.step()

        # generator adversarial loss
        gen_input_fake = outputs
        gen_fake, _ = self.discriminator(gen_input_fake)                    # in: [rgb(3)]
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
        gen_loss += gen_gan_loss

        # generator l1 loss, normalised by the mean mask value
        gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
        gen_loss += gen_l1_loss

        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(outputs, images)
        gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
        gen_loss += gen_content_loss

        # generator style loss, restricted to the masked region
        gen_style_loss = self.style_loss(outputs * masks, images * masks)
        gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
        gen_loss += gen_style_loss

        gen_loss.backward()
        self.gen_optimizer.step()

        # create logs
        logs = [
            ("l_d2", dis_loss.item()),
            ("l_g2", gen_gan_loss.item()),
            ("l_l1", gen_l1_loss.item()),
            ("l_per", gen_content_loss.item()),
            ("l_sty", gen_style_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs

    def forward(self, images, edges, masks):
        """Predict the full image from the masked image and ``edges``.

        Where ``masks`` is 1 the image is replaced by 1 (assuming a binary
        mask) before being concatenated with ``edges`` on dim 1.
        """
        images_masked = (images * (1 - masks).float()) + masks
        inputs = torch.cat((images_masked, edges), dim=1)

        if self.config.Generator==0 or self.config.Generator==2 or self.config.Generator==4:
            outputs = self.generator(inputs)
        else:
            outputs = self.generator(inputs, masks)

        if self.config.score:
            gen_fake, _ =self.discriminator(outputs)
            # One row per sample (previously hard-coded to a batch of 8).
            gen_fake=gen_fake.view(outputs.size(0),-1)
            print(torch.mean(gen_fake,dim=1))

        return outputs

In [None]:
#@title fixing imagenet state_dict
# Rewrite the ImageNet upsampler checkpoint as a bare generator state dict
# without the "module." prefix nn.DataParallel adds to parameter names.
%cd /content/

#https://discuss.pytorch.org/t/dataparallel-changes-parameter-names-issue-with-load-state-dict/60211
import torch
from collections import OrderedDict

ckpt_path = "/content/ICT/ckpts_ICT/Upsample/ImageNet/InpaintingModel_gen.pth"
checkpoint = torch.load(ckpt_path, map_location='cpu')

# The downloaded file nests the weights under 'generator'; after this cell has
# run once the file is already a bare state dict. Accept both so re-running is safe.
state_dict = checkpoint.get('generator', checkpoint)

# Strip the prefix only at the start of each key.
prefix = "module."
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)

torch.save(new_state_dict, ckpt_path)

In [None]:
#@title fixing places state_dict
# Rewrite the Places2_Nature upsampler checkpoint as a bare generator state dict
# without the "module." prefix nn.DataParallel adds to parameter names.
%cd /content/

#https://discuss.pytorch.org/t/dataparallel-changes-parameter-names-issue-with-load-state-dict/60211
import torch
from collections import OrderedDict

ckpt_path = "/content/ICT/ckpts_ICT/Upsample/Places2_Nature/InpaintingModel_gen.pth"
checkpoint = torch.load(ckpt_path, map_location='cpu')

# The downloaded file nests the weights under 'generator'; after this cell has
# run once the file is already a bare state dict. Accept both so re-running is safe.
state_dict = checkpoint.get('generator', checkpoint)

# Strip the prefix only at the start of each key.
prefix = "module."
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)

torch.save(new_state_dict, ckpt_path)

Paths
```
/content/input # user data
/content/masks # user masks
/content/output # results
```

In [None]:
# example data
# Image and mask use the same file name (0.png): inference.py pairs images and
# masks by their sorted position in the two folders.
# NOTE(review): the first URL looks like a JPEG / auto-format image saved under a
# .png name; the resize cell re-encodes it with cv2.imwrite, which picks the
# output format from the file extension.
!wget "https://i.guim.co.uk/img/media/6088d89032f8673c3473567a91157080840a7bb8/413_955_2808_1685/master/2808.jpg?width=1200&height=1200&quality=85&auto=format&fit=crop&s=412cc526a799b2d3fff991129cb8f030" -O /content/input/0.png
!wget "https://i.stack.imgur.com/PIfn1.png" -O /content/masks/0.png

In [None]:
# resize, input must be 256px
import glob
import os

import cv2

SIZE = 256


def resize_folder(folder, interpolation, size=SIZE):
    """Resize every PNG directly inside ``folder`` to ``size`` x ``size``, in place."""
    for path in sorted(glob.glob(os.path.join(folder, "*.png"))):
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise RuntimeError("could not read %s (failed download or not an image?)" % path)
        image = cv2.resize(image, (size, size), interpolation=interpolation)
        cv2.imwrite(path, image)


# Area averaging avoids aliasing when shrinking photos; masks must stay
# binary, so they keep nearest-neighbour sampling.
resize_folder("/content/input", cv2.INTER_AREA)
resize_folder("/content/masks", cv2.INTER_NEAREST)

In [None]:
# running with places
# Stage 1: the transformer writes 32x32 (--image_size) priors to
# /content/prior/condition_<k>, one folder per sample (--n_samples of them).
# It loads kmeans_centers.npy from the working directory, hence the %cd.
%cd /content/ICT/Transformer
!python inference.py --ckpt_path /content/ICT/ckpts_ICT/Transformer/Places2_Nature.pth \
                                --BERT --image_url "/content/input" \
                                --mask_url "/content/masks" \
                                --n_layer 35 --n_embd 512 --n_head 8 --top_k 40 --GELU_2 \
                                --save_url "/content/prior" \
                                --image_size 32 --n_samples 1

# Stage 2: run the guided upsampler (test.py) on inputs, masks and priors;
# results go to /content/output. --Generator 4 is the only generator
# InpaintingModel constructs.
# NOTE(review): --condition_num presumably has to equal --n_samples above — confirm.
%cd /content/ICT/Guided_Upsample
!python test.py --input "/content/input/" \
                                        --mask "/content/masks" \
                                        --prior "/content/prior" \
                                        --output "/content/output" \
                                        --checkpoints /content/ICT/ckpts_ICT/Upsample/Places2_Nature \
                                        --test_batch_size 1 --model 2 --Generator 4 --condition_num 1

In [None]:
# Reset the result folders before another run: the output folder must be
# empty, and old priors would otherwise linger next to new ones.
%cd /content
!sudo rm -rf /content/output /content/prior
!mkdir /content/output /content/prior