In [1]:
import os
import sys

import argparse
import datetime
import pickle
import time

import cv2
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import torch
import torch.autograd
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Run identification. These three values are passed to mkdr() to build the
# output path prefix used for checkpoints and loss plots.
# Name of this run; used both as the sub-directory name and as the file-name prefix.
Project_name = 'DP1_0106'
# Parent directory that holds one sub-directory per run; mkdr() does not create it.
Project_dir = 'Trained_Generators/DP1'
# True: start a new training run (mkdr() creates the run directory).
# False: only resolve the path of an existing run.
Training = True

In [3]:
def mkdr(proj,proj_dir,Training):
    """
    Resolve (and, when training, create) the output location of a project.

    When training, tries to create ``proj_dir/proj``. If that directory already
    exists the user is prompted: an empty answer reuses the existing directory
    (its files will be overwritten by later saves), any other answer is taken as
    a new project name and the procedure is repeated with it. When not training,
    nothing is created and the path is only computed.

    :param proj: project name
    :param proj_dir: parent directory of all projects; must already exist when training
    :param Training: whether this is a new training run (True) or a test run (False)
    :return: file-name prefix ``proj_dir/proj/proj`` to which callers append suffixes
    :raises SystemExit: with a non-zero status if ``proj_dir`` does not exist when training
    """
    if not Training:
        return proj_dir + '/' + proj + '/' + proj

    # Loop instead of recursing so that repeated renames cannot grow the call stack.
    while True:
        pth = proj_dir + '/' + proj
        try:
            os.mkdir(pth)
        except FileExistsError:
            print('Directory', pth, 'already exists. Enter new project name or hit enter to overwrite')
            new = input()
            if new == '':
                return pth + '/' + proj
            proj = new
        except FileNotFoundError:
            print('The specified project directory ' + proj_dir + ' does not exist. Please change to a directory that does exist and try again')
            # Exit with a failure status; a bare sys.exit() would report success (status 0).
            sys.exit(1)
        else:
            return pth + '/' + proj

In [4]:
# Paths of the three 2D training images, one per slicing plane.
# Each image is read with cv2.imread and randomly cropped into patches further below.
xy = "inputs/DP1_xy.png"
zx = "inputs/DP1_xz.png"
# NOTE(review): the variable name suggests this image must be supplied rotated
# 90 degrees to the right relative to the raw yz section — confirm against the data.
yz_90_degree_right_rotate = "inputs/DP1_yz.png"

In [5]:
# Order matters: index 0, 1 and 2 become the x, y and z data loaders, which the
# training loop pairs position-by-position with the three discriminators.
input_image_paths = [xy, zx, yz_90_degree_right_rotate]

In [6]:
# Root directory for dataset.
# NOTE(review): not referenced anywhere in the visible notebook — the training
# images come from input_image_paths instead; looks like a leftover, confirm before removing.
dataroot = "data/celeba"

# Number of worker processes for the data loaders (0 = load in the main process)
workers = 0

# Batch sizes. g_batch_size is used for the real-image data loaders, the generator
# update and the gradient penalty; d_batch_size is the number of 3D volumes
# generated for each discriminator update.
g_batch_size, d_batch_size = 8, 8

# Edge length (pixels) of the square training patches that are randomly cropped
# from the input images; also the edge length of the generated volume slices.
l = 64

# Number of channels of the images fed to the discriminators and produced by the
# generator. Real patches have one channel per distinct grey value (phase) found
# in the input image, so this must equal the number of phases.
nc = 2

# Number of channels in z latent vector
nz = 64

# Spatial edge length of the latent input: the generator input has shape
# (batch, nz, lz, lz, lz)
lz = 4

# Base number of feature maps in generator (layer widths are multiples of this)
ngf = 64

# Base number of feature maps in discriminator (layer widths are multiples of this)
ndf = 64

# Number of training epochs
num_epochs = 1

# Learning rate for optimizers (generator, discriminator)
g_lr, d_lr = 0.0001, 0.0001

# Beta1 and Beta2 hyperparam for Adam optimizers (generator, discriminator)
g_betas, d_betas = (.9, .99), (.9, .99)

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

In [7]:
# Output path prefix of the form Project_dir/Project_name/Project_name.
# Checkpoints and plots are written as pth + "_<suffix>". When Training is True
# this call creates the run directory and may prompt for input if it already exists.
pth = mkdr(Project_name, Project_dir, Training)

In [8]:
# Number of random l x l patches cropped from each input image.
n_patches_per_image = 32 * 900

# One TensorDataset of one-hot encoded patches per input image, in the order of
# input_image_paths.
train_images_set_xyz = []
for input_image_path in input_image_paths:
    input_image = cv2.imread(input_image_path)
    # cv2.imread signals an unreadable/missing file by returning None instead of raising.
    if input_image is None:
        raise FileNotFoundError("Could not read training image: " + str(input_image_path))
    # Keep a single channel; the image is treated as a map of phase labels.
    if len(input_image.shape) > 2:
        input_image = input_image[:, :, 0]
    h, w = input_image.shape[0], input_image.shape[1]
    # The crop origin is drawn from [1, h-l-1) x [1, w-l-1), which must be non-empty.
    if h - l - 1 <= 1 or w - l - 1 <= 1:
        raise ValueError(
            "Training image " + str(input_image_path) + " of size " + str((h, w))
            + " is too small for patches of edge length " + str(l))

    # Every distinct grey value is treated as one phase and gets its own channel.
    phases = np.unique(input_image)
    if len(phases) != nc:
        raise ValueError(
            "Training image " + str(input_image_path) + " contains " + str(len(phases))
            + " distinct phases but nc = " + str(nc))

    # float32 matches the tensor dtype used for training and halves the memory
    # of the default float64 buffer.
    train_images_set = np.empty([n_patches_per_image, len(phases), l, l], dtype=np.float32)
    for i in range(n_patches_per_image):
        x = np.random.randint(1, h-l-1)
        y = np.random.randint(1, w-l-1)
        patch = input_image[x:x+l, y:y+l]
        # One-hot encode all phases at once: channel c is 1 where the patch equals phases[c].
        train_images_set[i] = patch[np.newaxis, :, :] == phases[:, np.newaxis, np.newaxis]

    # from_numpy shares the float32 buffer instead of copying it.
    train_images_set_tensor = torch.utils.data.TensorDataset(torch.from_numpy(train_images_set))
    train_images_set_xyz.append(train_images_set_tensor)

In [9]:
# Build one shuffled loader per slicing plane from the first three datasets,
# in the same order as the datasets were created.
x_trainloader, y_trainloader, z_trainloader = (
    torch.utils.data.DataLoader(
        plane_dataset,
        batch_size=g_batch_size,
        shuffle=True,
        num_workers=workers,
    )
    for plane_dataset in train_images_set_xyz[:3]
)

In [10]:
# Layer specification shared by the Generator and Discriminator constructors.
# Number of convolution layers in each network.
lays = 5
# Channel counts at the layer boundaries: layer i maps df[i] -> df[i+1] (2D convs of
# the discriminator) and gf[i] -> gf[i+1] (3D transposed convs of the generator).
df, gf = [nc, ndf, ndf*2, ndf*4, ndf*8, 1], [nz, ngf*16, ngf*8, ngf*4, ngf*2, nc]
# Kernel size of every layer.
dk, gk = [4]*lays, [4]*lays
# Stride of every layer.
ds, gs = [2]*lays, [2]*lays
# Padding of every layer. By the standard conv size formulas these values take the
# generator from lz = 4 to 64 voxels per edge (4 -> 6 -> 10 -> 18 -> 34 -> 64) and
# the discriminator from 64 pixels to a single output value (64 -> 32 -> 16 -> 8 -> 4 -> 1).
dp, gp = [1, 1, 1, 1, 0], [2, 2, 2, 2, 3]

In [11]:
class Generator(nn.Module):
    """
    Stack of 3D transposed convolutions that maps a latent volume to a
    per-voxel probability distribution over the output channels.

    :param gf: channel counts; layer ``i`` maps ``gf[i]`` to ``gf[i+1]`` channels
    :param gk: kernel size of each layer
    :param gs: stride of each layer
    :param gp: padding of each layer
    """

    def __init__(self, gf, gk, gs, gp):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for idx, (kernel, stride, padding) in enumerate(zip(gk, gs, gp)):
            in_channels, out_channels = gf[idx], gf[idx+1]
            deconv = nn.ConvTranspose3d(in_channels, out_channels, kernel, stride, padding, bias=False)
            self.convs.append(deconv)
            # A batch norm is registered for every layer, including the last one,
            # although forward() does not apply the last one.
            self.bns.append(nn.BatchNorm3d(out_channels))

    def forward(self, x):
        """
        :param x: latent input passed to the first transposed convolution
        :return: output of the last layer after a softmax over dimension 1
        """
        last = len(self.convs) - 1
        # Hidden layers: transposed conv -> batch norm -> in-place ReLU.
        for idx in range(last):
            x = self.bns[idx](self.convs[idx](x))
            x = F.relu_(x)
        # Output layer: no batch norm; softmax over the channel dimension.
        logits = self.convs[-1](x)
        return torch.softmax(logits, 1)

In [12]:
class Discriminator(nn.Module):
    """
    Stack of bias-free 2D convolutions with ReLU between them and no
    activation on the output.

    :param df: channel counts; layer ``i`` maps ``df[i]`` to ``df[i+1]`` channels
    :param dk: kernel size of each layer
    :param ds: stride of each layer
    :param dp: padding of each layer
    """

    def __init__(self, df, dk, ds, dp):
        super().__init__()
        self.convs = nn.ModuleList()
        for idx, (kernel, stride, padding) in enumerate(zip(dk, ds, dp)):
            in_channels, out_channels = df[idx], df[idx + 1]
            self.convs.append(nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=False))

    def forward(self, x):
        """
        :param x: batch of 2D images passed to the first convolution
        :return: raw (unbounded) output of the last convolution
        """
        last = len(self.convs) - 1
        # Hidden layers: conv -> in-place ReLU.
        for idx in range(last):
            x = F.relu_(self.convs[idx](x))
        # Output layer: plain convolution, no activation.
        return self.convs[-1](x)

In [13]:
def calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device, gp_lambda,nc):
    """
    Calculate the gradient penalty for a batch of real and fake data.

    Each real sample is mixed with the fake sample at the same batch position
    using one random weight per sample; the penalty is the mean squared deviation
    of the per-sample gradient norm of ``netD`` at these mixed points from 1,
    multiplied by ``gp_lambda``.

    :param netD: Discriminator network (any callable mapping a batch to outputs)
    :param real_data: batch of real samples, first dimension is the batch
    :param fake_data: batch of fake samples with the same shape as ``real_data``
    :param batch_size: number of samples in the batch; must equal ``real_data.shape[0]``
    :param l: image size (kept for interface compatibility; the layout is now
        taken from ``real_data`` so non-square inputs work as well)
    :param device: device on which the random mixing weights are created
    :param gp_lambda: weight of the gradient penalty
    :param nc: channels (kept for interface compatibility, see ``l``)
    :return: gradient penalty (scalar tensor, differentiable w.r.t. ``netD`` parameters)
    :raises ValueError: if the shapes of ``real_data`` and ``fake_data`` differ or
        their batch dimension is not ``batch_size``
    """
    if real_data.shape != fake_data.shape:
        raise ValueError(
            "real_data and fake_data must have the same shape, got "
            + str(tuple(real_data.shape)) + " and " + str(tuple(fake_data.shape)))
    if real_data.shape[0] != batch_size:
        raise ValueError(
            "batch_size " + str(batch_size) + " does not match data batch dimension "
            + str(real_data.shape[0]))

    # One mixing weight per sample, shaped (batch, 1, ..., 1) so that it
    # broadcasts over all remaining dimensions without materialising a full copy.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    # Interpolated samples; detached so the penalty does not backpropagate into
    # whatever produced the real/fake batches.
    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
    interpolates.requires_grad_(True)

    # Gradient of the discriminator output w.r.t. the interpolated inputs.
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True)[0]

    # Per-sample L2 norm of the flattened gradient.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
    return gradient_penalty

In [14]:
class GradientPenalty(nn.Module):
    """
    Gradient penalty as a module: weighted mean squared deviation of the
    per-sample gradient norm from 1.

    :param gp_lambda: weight of the penalty; defaults to 10, the value that was
        previously hard-coded
    """

    def __init__(self, gp_lambda: float = 10):
        super().__init__()
        self.gp_lambda = gp_lambda

    def forward(self, x: torch.Tensor, f: torch.Tensor):
        """
        :param x: inputs with ``requires_grad=True``; first dimension is the batch
        :param f: outputs computed from ``x``
        :return: ``gp_lambda * mean((||df/dx||_2 - 1) ** 2)`` as a scalar tensor
        """
        batch_size = x.shape[0]
        # create_graph=True keeps the graph so the penalty can be backpropagated.
        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
        # Flatten everything but the batch dimension before taking the norm.
        gradients = gradients.reshape(batch_size, -1)
        norm = gradients.norm(2, dim=-1)
        return self.gp_lambda * torch.mean((norm - 1) ** 2)

In [15]:
# Use the first CUDA device when CUDA is available and ngpu > 0; otherwise run on the CPU.
device = torch.device("cuda:0" if(torch.cuda.is_available() and ngpu > 0) else "cpu")

In [16]:
# Build the generator from the layer specification and move it to the target device.
netG = Generator(gf, gk, gs, gp).to(device)
# Replicate across GPUs only when more than one is requested.
if (device.type == "cuda") and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Adam optimizer for the generator parameters.
optG = torch.optim.Adam(netG.parameters(), lr=g_lr, betas=g_betas)

In [17]:
# Three independent discriminators with one Adam optimizer each. The training
# loop pairs them position-by-position with the x, y and z training batches.
netDs = []
optDs = []
for i in range(3):
    netD = Discriminator(df, dk, ds, dp)
    # NOTE(review): unlike netG, the discriminators are always wrapped in
    # DataParallel, so their state_dict keys carry a "module." prefix. With
    # ngpu = 0 the device list is empty — confirm this path works in CPU mode.
    netD = nn.DataParallel(netD, list(range(ngpu))).to(device)
    netDs.append(netD)
    optDs.append(torch.optim.Adam(netDs[i].parameters(), lr=d_lr, betas=d_betas))

In [20]:
def _save_line_plot(series, labels, title, ylabel, pth, name):
    """
    Draw one line per data series on a fresh figure and save it.

    :param series: iterable of sequences, one per line to draw
    :param labels: legend label for each entry of ``series``
    :param title: figure title
    :param ylabel: label of the y axis (the x axis is always "iterations")
    :param pth: output path prefix
    :param name: suffix appended to ``pth`` after an underscore to form the file name
    """
    # A dedicated figure guarantees nothing is drawn on top of a figure that
    # happens to be open, and that exactly this figure is closed afterwards.
    fig, ax = plt.subplots()
    for values, label in zip(series, labels):
        ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel("iterations")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(pth + "_" + name)
    plt.close(fig)

def plot_generator_loss(data, labels, pth, name):
    """
    Save a line plot of generator loss curves to ``pth + "_" + name``.

    :param data: list of loss series
    :param labels: legend label for each series
    :param pth: output path prefix
    :param name: file-name suffix
    """
    _save_line_plot(data, labels, "Generator Loss", "loss", pth, name)

def plot_discriminator_loss(data, labels, pth, name):
    """
    Save a line plot of discriminator loss curves to ``pth + "_" + name``.

    :param data: list of loss series
    :param labels: legend label for each series
    :param pth: output path prefix
    :param name: file-name suffix
    """
    _save_line_plot(data, labels, "Discriminator Loss", "loss", pth, name)

def plot_wasserstein_distance(data, labels, pth, name):
    """
    Save a line plot of Wasserstein distance curves to ``pth + "_" + name``.

    :param data: list of distance series
    :param labels: legend label for each series
    :param pth: output path prefix
    :param name: file-name suffix
    """
    _save_line_plot(data, labels, "Wasserstein Distance", "wasserstein distance", pth, name)

def plot_gradient_penalty(data, labels, pth, name):
    """
    Save a line plot of gradient penalty curves to ``pth + "_" + name``.

    :param data: list of penalty series
    :param labels: legend label for each series
    :param pth: output path prefix
    :param name: file-name suffix
    """
    _save_line_plot(data, labels, "Gradient Penalty", "gradient penalty", pth, name)

def plot_real_fake_loss(data, labels, pth, name):
    """
    Save a line plot of discriminator outputs on real and fake batches to
    ``pth + "_" + name``.

    :param data: list of output series
    :param labels: legend label for each series
    :param pth: output path prefix
    :param name: file-name suffix
    """
    _save_line_plot(data, labels, "Discriminator Outputs", "loss", pth, name)

In [23]:
# Weight of the gradient-penalty term in the discriminator loss.
gp_lambda = 10
# The generator is updated once every `critic_iters` discriminator iterations.
critic_iters = 5
# Checkpoints and plots are written every `save_every` iterations.
save_every = 25
# One (d1, d2, d3) triple per discriminator. A generated volume (N, C, a2, a3, a4)
# is permuted to (N, d1, C, d2, d3) and reshaped so that axis d1 is folded into
# the batch dimension, yielding l 2D slices per volume.
slice_axes = [(2, 3, 4), (3, 2, 4), (4, 2, 3)]

start_time = time.time()
d_real_output_list = []
d_fake_output_list = []
gradient_penality_list = []
wasserstein_distance_list = []
g_loss_list = []
d_loss_list = []
for epoch in range(1, num_epochs+1):

    start_epoch_time = time.time()
    for iteration, (x_train_images, y_train_images, z_train_images) in enumerate(zip(x_trainloader, y_trainloader, z_trainloader), 1):
        train_2d_images_set = [x_train_images, y_train_images, z_train_images]

        # Fake volumes for the discriminator updates. No gradient is needed
        # here, so skip building the generator graph altogether.
        noise = torch.randn(d_batch_size, nz, lz, lz, lz, device=device)
        with torch.no_grad():
            fake_3d_image = netG(noise)

        # Discriminator: each discriminator sees real patches of its own plane
        # and the slices of the fake volumes taken along the matching axis.
        for train_images_set, netD, optD, (d1, d2, d3) in zip(train_2d_images_set, netDs, optDs, slice_axes):
            netD.zero_grad()
            real_2d_images = train_images_set[0].to(device)
            errD_real = netD(real_2d_images).view(-1).mean()
            fake_2d_images = fake_3d_image.permute(0, d1, 1, d2, d3).reshape(l * d_batch_size, nc, l, l)
            errD_fake = netD(fake_2d_images).mean()
            # The penalty is evaluated on as many fake slices as there are real images.
            gradient_penalty = calc_gradient_penalty(netD, real_2d_images, fake_2d_images[:g_batch_size], g_batch_size, l, device, gp_lambda, nc)
            errD = errD_fake - errD_real + gradient_penalty
            errD.backward()
            optD.step()
        # Only the values of the last discriminator of this iteration are logged.
        d_real_output_list.append(errD_real.item())
        d_fake_output_list.append(errD_fake.item())
        wasserstein_distance_list.append(errD_real.item() - errD_fake.item())
        gradient_penality_list.append(gradient_penalty.item())
        d_loss_list.append(errD.item())

        # Generator
        if iteration % critic_iters == 0:
            netG.zero_grad()
            errG = 0

            noise = torch.randn(g_batch_size, nz, lz, lz, lz, device=device)
            fake_3d_images = netG(noise)

            for netD, (d1, d2, d3) in zip(netDs, slice_axes):
                # Slice the freshly generated, graph-attached volumes. Slicing the
                # gradient-free batch from the discriminator step here would leave
                # the generator without any gradient, so optG.step() would do nothing.
                fake_2d_images = fake_3d_images.permute(0, d1, 1, d2, d3).reshape(l * g_batch_size, nc, l, l)
                errG_fake = netD(fake_2d_images)
                errG -= errG_fake.mean()
            errG.backward()
            optG.step()
            g_loss_list.append(errG.item())

        if iteration % save_every == 0:
            netG.eval()
            with torch.no_grad():
                torch.save(netG.state_dict(), pth + "_generator.pth")
                # NOTE(review): only the last of the three discriminators is
                # checkpointed — confirm whether the other two should be saved too.
                torch.save(netDs[-1].state_dict(), pth + "_discriminator.pth")
                plot_real_fake_loss([d_real_output_list, d_fake_output_list], ["Real", "Fake"], pth, "d_output_graoh")
                plot_wasserstein_distance([wasserstein_distance_list], ["Wasserstein Distance"], pth, "wd_graph")
                plot_gradient_penalty([gradient_penality_list], ["Gradient Penalty"], pth, "gp_graph")
                plot_discriminator_loss([d_loss_list], ["Discriminator"], pth, "d_loss_graph")
                plot_generator_loss([g_loss_list], ["Generator"], pth, "g_loss_graph")
            netG.train()

    # Per-epoch progress report: percentage done, current epoch, time taken by
    # this epoch, and total elapsed time.
    print(
    "進捗率: {0} % \n"
    "現在のエポック: {1}/{2} \n"
    "1エポックに要した時間: {3} s/1epoch \n"
    "これまでに要した時間: {4} \n"
    .format(
        round((epoch) * 100 / num_epochs , 1),
        epoch,
        num_epochs,
        datetime.timedelta(seconds=time.time() - start_epoch_time),
        datetime.timedelta(seconds=(time.time() - start_time))))
    print("-"*40)

finish
finish
finish


KeyboardInterrupt: 