In [1]:
%env CUDA_VISIBLE_DEVICES=1
import torch
import torch.nn as nn
class shave_block(nn.Module):
    """Crop ``s`` pixels from every border of dims 2 and 3 of a 4-D tensor (presumably N, C, H, W).

    Used on the identity branch of a residual block so it matches the spatial size of
    the unpadded-convolution branch before the two are summed.

    Args:
        s: number of pixels removed from each side of dims 2 and 3; must be >= 0.

    Raises:
        ValueError: if ``s`` is negative.
    """
    def __init__(self, s):
        super(shave_block, self).__init__()
        if s < 0:
            raise ValueError('shave_block: s must be non-negative, got %r' % (s,))
        self.s = s

    def forward(self, x):
        # For s == 0 the slice end -0 == 0 would give an empty tensor, so use None (no end bound).
        end = -self.s if self.s else None
        return x[:, :, self.s:end, self.s:end]

env: CUDA_VISIBLE_DEVICES=1


# Define the generator G (maps the 1-channel Y input to 2 UV chroma channels)

In [2]:
from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    """Container whose ``forward_prepare`` feeds the *same* input to every child (no chaining).

    Subclasses combine the children's outputs with ``lambda_func``. This is how Torch7
    table containers (ConcatTable / CAddTable) are expressed in this model.

    Args:
        fn: callable stored as ``self.lambda_func`` and applied by the subclass's ``forward``.
        *args: child modules, registered in order exactly as for ``nn.Sequential``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        # A plain callable (not a Module), so it is stored as an ordinary attribute.
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child on ``input``.

        Returns:
            A list with one output per child, in registration order. If the container has
            no children, ``input`` itself is returned unchanged. This lets a childless
            reducer operate on a list produced by the previous layer.
        """
        results = [child(input) for child in self._modules.values()]
        return results or input

class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the whole prepared value.

    The prepared value is the list of child outputs, or ``input`` itself if there are no children.
    """
    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)

class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each prepared element and return the results as a list.

    With children (the ConcatTable conversion) the elements are the children's outputs
    for the same input. With no children they are the items of ``input`` itself.
    """
    def forward(self, input):
        prepared = self.forward_prepare(input)
        return [self.lambda_func(item) for item in prepared]

class LambdaReduce(LambdaBase):
    """Fold the prepared elements pairwise with ``lambda_func`` (functools.reduce, no initial value).

    When constructed without children it reduces its input directly. For example, a list from
    the preceding LambdaMap combined with ``lambda x, y: x + y`` sums the branches (CAddTable).
    """
    def forward(self, input):
        items = self.forward_prepare(input)
        combined = reduce(self.lambda_func, items)
        return combined


def _residual_block():
    """Build one residual block as converted from Torch7 (ConcatTable + CAddTable).

    LambdaMap runs two branches on the same input and returns them as a list:
      * conv3x3 -> BN -> ReLU -> conv3x3 -> BN at 128 channels. Both convs are unpadded,
        so each spatial dimension loses 2 pixels per conv (2 per side in total);
      * shave_block(2), which crops the identity by 2 pixels per side to the same size.
    LambdaReduce then sums the two branches.
    """
    return nn.Sequential( # Sequential,
        LambdaMap(lambda x: x, # ConcatTable,
            nn.Sequential( # Sequential,
                nn.Conv2d(128,128,(3, 3)),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128,128,(3, 3)),
                nn.BatchNorm2d(128),
            ),
            shave_block(2),
        ),
        LambdaReduce(lambda x,y: x+y), # CAddTable,
    )


# Generator: maps a 1-channel Y image to 2 output channels squashed by Tanh.
# Layout: reflection pad -> 9x9 conv -> two stride-2 downsamples -> 5 residual blocks
#         -> two stride-2 transposed-conv upsamples -> 9x9 conv -> Tanh.
# The 40-px reflection pad offsets the residual blocks' cropping: 5 blocks x 2 px per side
# at 1/4 resolution = 10 px, i.e. 40 px at full resolution. So for H and W divisible by 4
# the output has the same spatial size as the input.
# The helper is called in the same position and order as the former written-out blocks.
# Child indices (state_dict keys 10.* .. 14.*) and parameter-initialisation order are
# therefore unchanged, and existing checkpoints still load.
G = nn.Sequential( # Sequential,
    nn.ReflectionPad2d((40, 40, 40, 40)),
    nn.Conv2d(1,32,(9, 9),(1, 1),(4, 4)),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32,64,(3, 3),(2, 2),(1, 1)),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64,128,(3, 3),(2, 2),(1, 1)),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    *[_residual_block() for _ in range(5)],
    nn.ConvTranspose2d(128,64,(3, 3),(2, 2),(1, 1),(1, 1)),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.ConvTranspose2d(64,32,(3, 3),(2, 2),(1, 1),(1, 1)),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32,2,(9, 9),(1, 1),(4, 4)),
    nn.Tanh(),
)

G=G.cuda()

# Define the discriminator D (scores 3-channel Y+UV images as real or generated)

In [3]:
import torchvision.models as models
# Discriminator: a randomly initialised torchvision ResNet-18 whose classifier head is replaced
# by one sigmoid unit, trained to score 3-channel (Y + UV) images as real (1) or generated (0).
# num_classes=2 only sizes the stock fc layer, which is overwritten by the next statement.
# NOTE(review): resnet18's last stage has 512 channels, so in_features=2048 only fits when the
# pooling leaves a 2x2 map. An example is older torchvision's fixed AvgPool2d(7) applied to
# 256x256 inputs, giving 512*2*2 = 2048. With AdaptiveAvgPool2d((1, 1)) the features are
# 512-d and this Linear would fail. Confirm against the torchvision version the saved D
# weights came from.
D = models.resnet18(pretrained=False, num_classes=2)
D.fc = nn.Sequential(
    nn.Linear(2048, 1),
    nn.Sigmoid(),
)
D = D.cuda()

# Define the dataset (Y input, UV target), data loaders and preview image

In [5]:
import torch
import os
from torch.utils import data
import numpy as np
from PIL import Image
from skimage.color import rgb2yuv,yuv2rgb

class img_data(data.Dataset):
    """Image-folder dataset yielding (Y, UV) tensor pairs.

    Every regular file directly inside ``path`` is treated as an image. Each sample is
    converted to YUV with skimage's ``rgb2yuv`` and returned as:

    * ``y``  -- shape (1, H, W): the Y plane shifted by -0.5;
    * ``uv`` -- shape (2, H, W): U and V divided by 0.43601035 and 0.61497538
      (presumably the chroma ranges, so targets lie roughly in [-1, 1]).

    NOTE(review): images are not resized here, so batching them with a DataLoader
    assumes every file in the folder already has the same size. Confirm per dataset.
    """
    def __init__(self, path):
        # Sorted so the index -> file mapping is deterministic (os.listdir order is arbitrary).
        # Sub-directories (e.g. .ipynb_checkpoints) are skipped because Image.open cannot read them.
        self.files = [
            os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if os.path.isfile(os.path.join(path, name))
        ]

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.files)

    def __getitem__(self, index):
        'Generates one sample of data'
        # convert('RGB') guarantees the 3 channels rgb2yuv expects; grayscale, palette or RGBA
        # files would otherwise break it. This matches how the preview image is loaded.
        # The with-block closes the file as soon as the converted copy exists.
        with Image.open(self.files[index]) as img:
            yuv = rgb2yuv(img.convert('RGB'))
        y = yuv[...,0]-0.5
        u_t = yuv[...,1] / 0.43601035
        v_t = yuv[...,2] / 0.61497538

        return torch.Tensor(np.expand_dims(y,axis=0)),torch.Tensor(np.stack([u_t,v_t],axis=0))
# Dataset roots: training images from the relative 'train' folder, validation images from
# an absolute COCO validation folder (adjust VAL_DIR for the machine this runs on).
TRAIN_DIR = 'train'
VAL_DIR = '/data/data/coco/val'
trainset = img_data(TRAIN_DIR)
valset = img_data(VAL_DIR)


In [6]:
from torch.utils import data
# Shared DataLoader settings: batches of 20, shuffled, loaded by 6 worker processes.
params = dict(batch_size=20, shuffle=True, num_workers=6)
training_generator = data.DataLoader(trainset, **params)
validation_generator = data.DataLoader(valset, **params)

In [16]:
import cv2
# Fixed preview input: '5.jpg' as a 256x256 RGB image, converted to YUV.
p = Image.open('5.jpg').convert('RGB')
p = p.resize((256, 256))
img_yuv = rgb2yuv(p)
# Keep only the Y plane as a (1, 1, 256, 256) batch; the -0.5 shift matches img_data's inputs.
infimg = img_yuv[..., 0].reshape((1, 1, 256, 256))
img_variable = Variable(torch.Tensor(infimg - 0.5)).cuda()

# Pre-train D (disabled: the loop below is kept in a string literal and does not run)

In [8]:
# Disabled cell: a D-only pre-training loop kept inside a bare string literal, so running the
# cell does not execute it. When enabled, only D is stepped, for at most 2000 iterations of a
# single epoch. The weights are then saved to 'Dinit.pth', the file named in the commented-out
# D.load_state_dict line of the main training cell.
'''
adversarial_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.999))
i=0
for epoch in range(1):
    for y, uv in training_generator:
        # Adversarial ground truths
        valid = Variable(torch.Tensor(y.size(0), 1).fill_(1.0), requires_grad=False).cuda()
        fake = Variable(torch.Tensor(y.size(0), 1).fill_(0.0), requires_grad=False).cuda()

        # Configure input
        yvar = Variable(y).cuda()
        real_imgs = torch.cat([yvar,Variable(uv).cuda()],dim=1)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        gen_imgs = torch.cat([yvar.detach(),G(yvar)],dim=1)
        optimizer_D.zero_grad()
        
        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(D(real_imgs), valid)
        fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
        i+=1
        if i%100==0:
            print ("[D loss: %f]" % (d_loss.item()))
        if i%2000==0:
            torch.save(D.state_dict(), 'Dinit.pth')
            break
'''
True

True

In [40]:
import json
# Build the export dict here instead of reusing `j` from the "In [39]" cell. That cell comes
# *after* this one in the notebook, so on a fresh Restart & Run All `j` was undefined (NameError).
# Tensors are loaded straight to the CPU and turned into nested lists so json can serialise them.
j = torch.load('weights/G3.pth', map_location='cpu')
for name in j:
    j[name] = j[name].cpu().numpy().tolist()
with open('weight.json','w') as k:
    json.dump(j,k)

In [39]:
# Convert the G3 checkpoint into plain nested Python lists (JSON-serialisable), keyed by
# parameter name. The checkpoint was saved from the GPU-resident G, so map_location='cpu'
# loads it straight into host memory instead of round-tripping through the GPU. This also
# works on machines without CUDA.
j = torch.load('weights/G3.pth', map_location='cpu')
for name in j:
    j[name] = j[name].cpu().numpy().tolist()

In [17]:
# Load the saved G/D checkpoints and run G on the fixed preview image, writing current.jpg.
G.load_state_dict(torch.load('weights/G3.pth'))
D.load_state_dict(torch.load('weights/D3.pth'))

# Run inference in eval mode under no_grad. In train mode BatchNorm would normalise with this
# single image's statistics and fold them into G's running stats. The previous mode is restored,
# so a later training cell still sees G in the mode it was in before.
was_training = G.training
G.eval()
try:
    with torch.no_grad():
        res = G(img_variable)
finally:
    G.train(was_training)

# Undo the dataset's chroma normalisation (img_data divides U and V by ~0.436 / ~0.615).
uv = res.detach().cpu().numpy()
uv[:,0,:,:] *= 0.436
uv[:,1,:,:] *= 0.615
# Re-attach the original Y plane and convert the (256, 256, 3) YUV image back to RGB.
fr = np.concatenate([infimg,uv],axis=1).reshape(3,256,256)
rgb = yuv2rgb(fr.transpose(1,2,0))
# Map [0, 1] explicitly to 0..255 uint8 (multiplying by 256 maps 1.0 outside the 8-bit range).
# Reverse the channels to BGR, the order cv2 uses.
cv2.imwrite('current.jpg', (rgb.clip(min=0,max=1)*255).round().astype(np.uint8)[:,:,[2,1,0]])

True

In [None]:
# GAN training loop. `flag` advances every 100 iterations and cycles through 0..5.
# At flag == 0 only G is stepped; at flag 1..5 only D is stepped, so D gets 5x as many updates.
# Both losses are computed every iteration so the periodic log always shows both.
i=0
flag = 0
adversarial_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=5e-8, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=1e-6, betas=(0.5, 0.999))
#G.load_state_dict(torch.load('model.pth'))
#D.load_state_dict(torch.load('Dinit.pth'))
for epoch in range(1000):
    for y, uv in training_generator:
        # Adversarial ground truths: 1 = real, 0 = generated (constant targets, no gradient)
        valid = torch.ones(y.size(0), 1).cuda()
        fake = torch.zeros(y.size(0), 1).cuda()

        # Configure input: real images are the Y channel plus the ground-truth UV channels
        yvar = y.cuda()
        real_imgs = torch.cat([yvar, uv.cuda()], dim=1)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images: same Y channel, predicted UV channels
        gen_imgs = torch.cat([yvar.detach(), G(yvar)], dim=1)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(D(gen_imgs), valid)
        if flag == 0:
            g_loss.backward()
            optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # zero_grad also clears the D gradients left behind by g_loss.backward() above
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(D(real_imgs), valid)
        fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        if flag > 0:
            d_loss.backward()
            optimizer_D.step()

        i += 1
        if i % 100 == 0:
            print("Epoch: %d: [D loss: %f] [G loss: %f]" % (epoch, d_loss.item(), g_loss.item()))

            # Checkpoints are named by epoch, so later saves within an epoch overwrite earlier ones.
            torch.save(D.state_dict(), 'weights/D'+str(epoch)+'.pth')
            torch.save(G.state_dict(), 'weights/G'+str(epoch)+'.pth')

            # Preview on the fixed image. Eval mode + no_grad stops this single-image forward pass
            # from building an autograd graph or updating G's BatchNorm running stats.
            # The previous mode is restored before training continues.
            was_training = G.training
            G.eval()
            try:
                with torch.no_grad():
                    res = G(img_variable)
            finally:
                G.train(was_training)
            # Separate name so the batch's `uv` loop variable is not shadowed.
            preview_uv = res.detach().cpu().numpy()
            # Undo the dataset's chroma normalisation before converting back to RGB
            preview_uv[:, 0, :, :] *= 0.436
            preview_uv[:, 1, :, :] *= 0.615
            fr = np.concatenate([infimg, preview_uv], axis=1).reshape(3, 256, 256)
            rgb = yuv2rgb(fr.transpose(1, 2, 0))
            # Map [0, 1] to 0..255 uint8 and reverse RGB -> BGR channel order for cv2
            cv2.imwrite('current.jpg', (rgb.clip(min=0, max=1) * 255).round().astype(np.uint8)[:, :, [2, 1, 0]])
            flag = (flag + 1) % 6

Epoch: 0: [D loss: 0.000000] [G loss: 24.361118]
Epoch: 0: [D loss: 0.000000] [G loss: 21.169556]
Epoch: 0: [D loss: 0.000000] [G loss: 21.799685]
Epoch: 0: [D loss: 0.000000] [G loss: 24.758905]
Epoch: 0: [D loss: 0.000000] [G loss: 22.759775]
Epoch: 0: [D loss: 0.000000] [G loss: 22.161333]
Epoch: 0: [D loss: 0.000000] [G loss: 23.299866]
Epoch: 0: [D loss: 0.000000] [G loss: 20.657955]
Epoch: 0: [D loss: 0.000000] [G loss: 21.871845]
Epoch: 0: [D loss: 0.000000] [G loss: 23.164768]
Epoch: 0: [D loss: 0.000000] [G loss: 22.579809]
Epoch: 0: [D loss: 0.000000] [G loss: 19.386799]
Epoch: 0: [D loss: 0.000000] [G loss: 23.742941]
Epoch: 0: [D loss: 0.000000] [G loss: 22.875778]
Epoch: 0: [D loss: 0.000000] [G loss: 22.054083]
Epoch: 0: [D loss: 0.000000] [G loss: 20.886280]
Epoch: 0: [D loss: 0.000000] [G loss: 23.602940]
Epoch: 0: [D loss: 0.000000] [G loss: 22.440645]
Epoch: 0: [D loss: 0.000000] [G loss: 22.391920]
Epoch: 0: [D loss: 0.000000] [G loss: 21.965626]
Epoch: 0: [D loss: 0