In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
from IPython import display
import matplotlib.pyplot as plt
import glob
%matplotlib inline
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Resize every image to a 256x256 square and convert it to a float tensor in [0, 1].
# 256x256 is the size the models below are built around: the Discriminator's six
# stride-2 convs reduce it to a 4x4 map (16 scores per sample), and UNet needs
# H and W divisible by 16 so its skip connections line up.
# There is deliberately no Normalize step: UNet ends in a sigmoid, so real and
# generated images must share the [0, 1] range.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

In [None]:
class CycleGANDataset(torch.utils.data.Dataset):
    """Two-domain image dataset, fully preloaded as (image_a, image_b) tuples.

    Reads ``folder/A/*.jpg`` and ``folder/B/*.jpg``. Both file lists are sorted
    and zipped, so the i-th A image is always paired with the i-th B image and
    the dataset length is ``min(#A, #B)``; surplus files in the larger domain
    are ignored. Files PIL cannot read (``IOError``/``OSError``) are skipped.

    Args:
        folder: directory containing the ``A`` and ``B`` subfolders.
        transform: callable applied to each RGB ``PIL.Image``. Defaults to the
            notebook-level ``preprocess`` pipeline.
    """
    def __init__(self, folder, transform=None):
        if transform is None:
            transform = preprocess
        # glob returns files in filesystem order; sort so the A/B pairing is
        # reproducible across machines and runs.
        self.files_a = sorted(glob.glob(os.path.join(folder, 'A', '*.jpg')))
        self.files_b = sorted(glob.glob(os.path.join(folder, 'B', '*.jpg')))
        self.images = []
        for fn_a, fn_b in zip(self.files_a, self.files_b):
            if len(self.images) % 100 == 0:
                print(len(self.images))  # loading progress
            try:
                image_a = self._load(fn_a, transform)
                image_b = self._load(fn_b, transform)
            except IOError:
                continue
            self.images.append((image_a, image_b))

    @staticmethod
    def _load(path, transform):
        """Open ``path``, force 3-channel RGB, apply ``transform``, and close the file."""
        # Without convert("RGB"), a grayscale or CMYK JPEG would become a 1- or
        # 4-channel tensor that the 3-channel model inputs cannot accept.
        with Image.open(path) as img:
            return transform(img.convert("RGB"))

    def __len__(self):
        """Number of preloaded (A, B) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the preloaded ``(image_a, image_b)`` tuple at ``idx``."""
        return self.images[idx]
# Root folder holding the A/ and B/ image subfolders. Set the CYCLEGAN_DATA_DIR
# environment variable to point elsewhere instead of editing this machine-specific path.
DATA_DIR = os.environ.get(
    "CYCLEGAN_DATA_DIR",
    "/home/yanai-lab/terauchi-k/export/jupyter/notebook/pytorch-CycleGAN-and-pix2pix/datasets/datasets/apple2orange/train",
)
dataset = CycleGANDataset(DATA_DIR)


In [None]:
class DownSamp(nn.Module):
    """Encoder stage: one stride-2 3x3 conv, then two same-resolution 3x3 convs.

    The first conv halves H and W (rounding up) and maps ``in_ch`` to ``out_ch``
    channels; every conv is followed by a default-slope leaky ReLU.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(layer(x))
        return x
class UpSamp(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two 3x3 convs.

    Args:
        in_ch: channels of the tensor being upsampled (reduced to ``in_ch // 2``).
        cat_ch: channels of the skip tensor concatenated in ``forward``; the skip
            tensor must already have the upsampled spatial size.
        out_ch: channels produced by the stage.
    """
    def __init__(self, in_ch, cat_ch, out_ch):
        super().__init__()
        half_ch = in_ch // 2
        # kernel 4, stride 2, padding 1 gives exactly 2*H x 2*W.
        self.deconv = nn.ConvTranspose2d(in_ch, half_ch, kernel_size=4, stride=2, padding=1)
        # Concatenation happens in forward(); cat_ch is the skip tensor's channel
        # count at the upsampled resolution, so conv1 sees half_ch + cat_ch channels.
        self.conv1 = nn.Conv2d(half_ch + cat_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, cat):
        upsampled = F.leaky_relu(self.deconv(x))
        fused = torch.cat([upsampled, cat], dim=1)
        refined = F.leaky_relu(self.conv1(fused))
        return F.leaky_relu(self.conv2(refined))
# Shape smoke test for the building blocks: DownSamp halves H and W, and UpSamp
# doubles them back after fusing a skip tensor of the target resolution.
# These are real assertions (the original bare expressions were neither shown nor
# checked), and they use their own names so the global `d` is not clobbered.
down_check = DownSamp(32, 64)
up_check = UpSamp(64, 32, 32)
assert down_check(torch.zeros((1, 32, 64, 64))).shape == (1, 64, 32, 32)
assert up_check(torch.zeros((1, 64, 32, 32)), torch.zeros((1, 32, 64, 64))).shape == (1, 32, 64, 64)
class UNet(nn.Module):
    """U-Net generator: (B, 3, H, W) image -> (B, 3, H, W) image with values in (0, 1).

    Structure:
    - A full-resolution stem (7x7 then 3x3 conv, 3 -> 32 channels).
    - Four DownSamp stages, each halving H and W.
    - Four UpSamp stages, each doubling H and W and concatenating the encoder
      feature map of the same resolution.
    - A 7x7 output conv followed by a sigmoid.

    H and W must be divisible by 16 so that every skip tensor matches the
    upsampled tensor it is concatenated with.
    """
    def __init__(self):
        super(UNet, self).__init__()
        # Stem at full resolution.
        self.in1 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3)
        self.in2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        # Encoder: each stage halves H and W and doubles the channel count.
        self.down1 = DownSamp(32, 64)
        self.down2 = DownSamp(64, 128)
        self.down3 = DownSamp(128, 256)
        self.down4 = DownSamp(256, 512)
        # Decoder: UpSamp(in_ch, skip_ch, out_ch); skip_ch matches the encoder
        # output at the resolution being restored.
        self.up4 = UpSamp(512, 256, 256)
        self.up3 = UpSamp(256, 128, 128)
        self.up2 = UpSamp(128, 64, 64)
        self.up1 = UpSamp(64, 32, 32)
        self.out = nn.Conv2d(32, 3, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        x = F.leaky_relu(self.in1(x))
        # Activate in2's output. Re-activating in1's output here would leave
        # in2 unused, with no effect on the result and no gradient.
        x1 = F.leaky_relu(self.in2(x))
        x2 = self.down1(x1)  # H/2
        x3 = self.down2(x2)  # H/4
        x4 = self.down3(x3)  # H/8
        x5 = self.down4(x4)  # H/16 (bottleneck)
        y = self.up4(x5, x4)
        y = self.up3(y, x3)
        y = self.up2(y, x2)
        y = self.up1(y, x1)
        return torch.sigmoid(self.out(y))


In [None]:
from torch.nn.utils.spectral_norm import spectral_norm

class Discriminator(nn.Module):
    """Patch discriminator over a 6-channel input (two 3-channel images stacked on dim 1).

    Six stride-2 3x3 convs (each followed by LeakyReLU(0.1)) shrink H and W by
    64x. A 1x1 conv then produces one score per remaining location, and the
    scores are flattened to (B, h*w). When h*w == 1 the singleton dimension is
    squeezed, giving shape (B,). A 256x256 input yields (B, 16).
    """
    def __init__(self):
        super().__init__()
        widths = [6, 32, 64, 128, 256, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, stride=2, kernel_size=3, padding=1))
            layers.append(nn.LeakyReLU(0.1))
        # 1x1 conv to a single score channel, then flatten to (B, h*w).
        layers.append(nn.Conv2d(512, 1, stride=1, kernel_size=1, padding=0))
        layers.append(nn.Flatten())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main(x)
        return scores.squeeze(1)
# Build a discriminator and display its output shape for one 256x256 input pair:
# six stride-2 convs reduce 256 -> 4, so this shows torch.Size([1, 16]).
d = Discriminator()
d(torch.zeros(1, 6, 256, 256)).shape

In [None]:
# CycleGAN training: two generators (A->B, B->A) and two discriminators, trained
# with a least-squares (MSE) adversarial loss plus an L1 cycle-consistency loss.
torch.manual_seed(0)

batch_size = 1
epoch = 3
# Weight of the cycle-consistency term relative to the adversarial terms
# (10 is the value used in the CycleGAN paper).
LAMBDA_CYC = 10.0

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
g_ab = UNet().to(device)
g_ba = UNet().to(device)
d_a = Discriminator().to(device)
d_b = Discriminator().to(device)
lossfunc = nn.MSELoss()
# One optimizer per player, each covering both of that player's networks.
g_opt = opt.SGD(params=list(g_ab.parameters()) + list(g_ba.parameters()), lr=1e-3, momentum=0.8)
d_opt = opt.SGD(params=list(d_a.parameters()) + list(d_b.parameters()), lr=1e-2, momentum=0.8)


def disc_score(disc, img):
    """Return ``disc``'s flattened patch scores for a batch of 3-channel images.

    Discriminator's first conv expects 6 channels (an image pair). CycleGAN data
    is unpaired, so the image is stacked with itself, and the discriminator
    judges ``img`` alone.
    """
    return disc(torch.cat((img, img), dim=1))


def lsgan_loss(pred, is_real):
    """MSE between ``pred`` and an all-ones (real) or all-zeros (fake) target of the same shape."""
    # Targets follow pred's shape, so any input resolution works.
    target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
    return lossfunc(pred, target)


def show_image(title, img):
    """Display the first image of a (B, 3, H, W) batch whose values lie in [0, 1]."""
    plt.imshow(img.detach()[0].cpu().numpy().transpose(1, 2, 0))
    plt.title(title)
    plt.show()


itr = 0
for ep in range(epoch):
    for img_a, img_b in dataloader:
        itr += 1
        img_a = img_a.to(device)
        img_b = img_b.to(device)

        # --- Generator update ---
        fake_b = g_ab(img_a)
        fake_a = g_ba(img_b)
        rec_a = g_ba(fake_b)
        rec_b = g_ab(fake_a)

        l_g_aba_rec = torch.mean(torch.abs(img_a - rec_a))
        l_g_bab_rec = torch.mean(torch.abs(img_b - rec_b))
        l_g_ab_adv = lsgan_loss(disc_score(d_b, fake_b), True)
        l_g_ba_adv = lsgan_loss(disc_score(d_a, fake_a), True)
        g_loss = l_g_ab_adv + l_g_ba_adv + LAMBDA_CYC * (l_g_aba_rec + l_g_bab_rec)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        # --- Discriminator update ---
        # Fakes are detached so this step does not reach the generators.
        # d_opt.zero_grad() also discards the gradients g_loss.backward() left on d_a/d_b.
        d_loss = (
            lsgan_loss(disc_score(d_a, img_a), True)
            + lsgan_loss(disc_score(d_a, fake_a.detach()), False)
            + lsgan_loss(disc_score(d_b, img_b), True)
            + lsgan_loss(disc_score(d_b, fake_b.detach()), False)
        )
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        if itr % 1000 == 1:
            print(f"epoch {ep} itr {itr}: d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f}")
            show_image("a", img_a)
            show_image("b", img_b)
            show_image("a->b", fake_b)
            show_image("b->a", fake_a)