In [4]:
import torch.nn as nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F
import torch.optim as optim

In [5]:
class Generator(nn.Module):

    def __init__(self, kernel_size=3):
        super(Generator, self).__init__()
        #input size (6,256,256)
        self.conv1 = nn.Sequential(
            nn.Conv2d(6,32,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32,32,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.AvgPool2d(3,stride=2,padding=1))
        #(3,128,128)
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32,64,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64,64,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.AvgPool2d(3,stride=2,padding=1))
        #(64,64,64)
        self.conv3 = nn.Sequential(
            nn.Conv2d(64,128,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(128,128,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.AvgPool2d(3,stride=2,padding=1))
        #(128,32,32)
        self.conv4 = nn.Sequential(
            nn.Conv2d(128,256,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256,256,kernel_size,stride = 1, padding=1),
            nn.LeakyReLU(),
            nn.AvgPool2d(3,stride=2,padding=1))
        #(256,16,16)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(256,128,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(128,128,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        #(128,32,32)
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(256,64,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(64,64,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        #(64,64,64)
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(128,32,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(32,32,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        #(3,256,256)
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(64,3,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(3,3,kernel_size,stride = 1, padding=2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'))
        
    def forward(self,inputs):
        e1 = self.conv1(inputs)
        e2 = self.conv2(e1)
        e3 = self.conv3(e2)
        e4 = self.conv4(e3)
        d4 = self.deconv4(e4)
        d3_in = torch.cat((e3,d4),1)
        d3 = self.deconv3(d3_in)
        d2_in = torch.cat((e2,d3),1)
        d2 = self.deconv2(d2_in)
        d1_in = torch.cat((e1,d2),1)
        d1 = self.deconv1(d1_in)
        return d1

In [6]:
#test Generator
inputs = torch.rand((10,6,240,320))
model = Generator()
outputs = model.forward(inputs)
print(outputs.size())

torch.Size([10, 3, 240, 320])


  "See the documentation of nn.Upsample for details.".format(mode))


In [19]:
class Discriminator(nn.Module):

    def __init__(self, kernel_size=5, dim=64):
        super(Discriminator, self).__init__()
        self.kernel_size = kernel_size
        self.dim = dim
        self.conv1 = nn.Conv2d(3, self.dim, self.kernel_size, stride=2, padding=2)
        #(64,120,160)
        #LeakyRelu
        self.conv2 = nn.Conv2d(self.dim, 2*self.dim,kernel_size,stride=2,padding=2)
        self.bn1 = nn.BatchNorm2d(2*self.dim)
        #LeakyRelu
        #(128,60,80)
        self.conv3 = nn.Conv2d(2*self.dim,4*self.dim,kernel_size,stride=2,padding=2)
        self.bn2 = nn.BatchNorm2d(4*self.dim)
        #LeakyRelu
        #(256,30,40)
        self.conv4 = nn.Conv2d(4*self.dim,8*self.dim,kernel_size,stride=2,padding=2)
        self.bn3 = nn.BatchNorm2d(8*self.dim)
        #LeakyRelu
        #(512,15,20)
        self.conv5 = nn.Conv2d(8*self.dim,8*self.dim,kernel_size,stride=2,padding=2)
        self.bn4 = nn.BatchNorm2d(8*self.dim)
        #LeakyRelu
        #(512,8,10)
        self.fc = nn.Linear(8*10*8*self.dim,1)
        
    def forward(self, x):
        batch_size, channel, height, width = x.shape
        output = self.conv1(x)
        output = F.leaky_relu(output)
        output = self.conv2(output)
        output = self.bn1(output)
        output = F.leaky_relu(output)
        output = self.conv3(output)
        output = self.bn2(output)
        output = F.leaky_relu(output)
        output = self.conv4(output)
        output = self.bn3(output)
        output = F.leaky_relu(output)
        output = self.conv5(output)
        output = self.bn4(output)
        output = F.leaky_relu(output)
        
        output = output.view(-1, 8*10*8*self.dim)
        output = self.fc(output)
        
        return output
      

In [18]:
#test discriminator
x = torch.rand((1,3,240,320))
model = Discriminator()
model.forward(x)

tensor([[-0.2042]], grad_fn=<AddmmBackward>)

In [18]:
L1_criterion = nn.L1Loss()
BCE_criterion = nn.BCELoss()
#generator_one = GeneratorCNN_Pose_UAEAfterResidual_256(21, z_num, repeat_num)
generator_two = Generator()
discriminator = Discriminator()

gen_train_op2 = optim.Adam(generator_two.parameters(), lr=2e-5, betas=(0.5, 0.999))
dis_train_op1 = optim.Adam(discriminator.parameters(), lr=2e-5, betas=(0.5, 0.999))

In [21]:
def train():
    for epoch in range(10):
        for step, example in enumerate(pose_loader):
            [x, x_target, mask_target] = example
            x = Variable(x.cuda())
            x_target = Variable(x_target.cuda())
            mask_target = Variable(mask_target.cuda())
            
            
            
            DiffMap = generator_two(torch.cat([G1, x], dim=1))
            G2 = G1 + DiffMap
            triplet = torch.cat([x_target, G2, x], dim=0)
            D_z = Discriminator(triplet)
            D_z_pos_x_target, D_z_neg_g2, D_z_neg_x = torch.split(D_z, 3)
            D_z_pos = D_z_pos_x_target
            D_z_neg = torch.cat([D_z_neg_g2, D_z_neg_x], 0)
            
            g_loss_2 = BCE_criterion(D_z_neg, torch.ones((2)).cuda())
            PoseMaskLoss2 = L1_criterion(G2 * mask_target, x_target * mask_target)
            L1Loss2 = L1_criterion(G2, x_target) + PoseMaskLoss2
            g_loss_2 += 50*L1Loss2

            gen_train_op2.zero_grad()
            g_loss_2.backward()
            gen_train_op2.step(retain_graph=True)

            d_loss = BCE_criterion(D_z_pos, torch.ones((1)).cuda())
            d_loss += BCE_criterion(D_z_neg, torch.zeros((2)).cuda())
            d_loss /= 2
            
            dis_train_op1.zero_grad()
            d_loss.backward()
            dis_train_op1.step()
