In [1]:
import torch
import torchvision
import torch.nn.functional as F
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
from random import randrange
# from torchsummary import summary
import os
import datetime

import sys
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

from CoordConvModule import *

In [2]:
epochs = 1000
batch_size = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
trainset = torch.load('../data/dataset_10steps.pt')

In [4]:
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

In [5]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Input size: [batch, 3, 100, 100]
        # Output size: [batch, 3, 100, 100]
        

    def forward(self, x):
        # print(x.size())
        # wall = x[:, -1]
        # print(wall.size())
        x = self.encoder(x)
        x = self.decoder(x)
        # print(x.size())
        # x = torch.cat((x, wall.reshape(len(x), 1, width_fig, width_fig)), 1)
        # print(x.size())
        return x

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 9, stride=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 8, 7, stride=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout()
        self.fc1 = nn.Linear(128, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        # print(x.size())
        x = self.pool(F.relu(self.conv1(x)))
        # print(x.size())
        x = self.pool(F.relu(self.conv2(x)))
        # print(x.size())
        x = torch.flatten(x, start_dim=1) # flatten all dimensions except batch
        # print(x.size())
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        # print(x.size())
        x = F.sigmoid(self.fc2(x))
        # print(x.size())
        return x

In [7]:
class StepAE(nn.Module):
    def __init__(self):
        """Variational Auto-Encoder Class"""
        super(StepAE, self).__init__()
        
        self.generator = nn.Sequential(
            CoordConv(3, 32, kernel_size = 5, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 32, 5, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 32, 5, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 3, 5, stride=1, padding=1),
        )

    def step(self, x): # For actual run after training
        x = self.generator(x)
        return x

    def forward(self, x): # For training
        # Encode x to z
        recon1 = self.generator(x)
        recon2 = self.generator(recon1)
        recon3 = self.generator(recon2)
        recon4 = self.generator(recon3)
        # recon5 = self.enc(recon4)
        
        return recon1, recon2, recon3, recon4

# Training

In [8]:
disc_history_real = []
disc_history_fake = []
rme_history = []

In [9]:
# https://uos-deep-learning.tistory.com/16
def calc_gradient_penalty(netD, real_data, generated_data):
    # GP strength
    LAMBDA = 3

    b_size = real_data.size()[0]

    # Calculate interpolation
    alpha = torch.rand(b_size, 1, 1, 1)
    alpha = alpha.expand_as(real_data)
    alpha = alpha.cuda()

    interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
    interpolated = torch.autograd.Variable(interpolated, requires_grad=True)
    interpolated = interpolated.cuda()

    # Calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                           create_graph=True, retain_graph=True)[0]

    # Gradients have shape (batch_size, num_channels, img_width, img_height),
    # so flatten to easily take norm per example in batch
    gradients = gradients.view(b_size, -1)

    # Derivatives of the gradient close to 0 can cause problems because of
    # the square root, so manually calculate norm and add epsilon
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()

In [10]:
def train(model, discnet, train_loader, epochnow):
    model.train()
    discnet.train()
    avg_loss = 0
    for step, (x0, _, x1, _, x2, _, x3, _, x4, _) in enumerate(train_loader):
        # noisy_x = x + train_noise_level * torch.randn(*x.shape)
        # noisy_x = np.clip(noisy_x, 0., 1.)
        del _

        x0 = x0.permute(0, 3, 1, 2)
        x1 = x1.permute(0, 3, 1, 2)
        x2 = x2.permute(0, 3, 1, 2)
        x3 = x3.permute(0, 3, 1, 2)
        x4 = x4.permute(0, 3, 1, 2)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))
        # model.eval()
        discnet.train()
        discnet.zero_grad()
        optimizerD.zero_grad()
        r1, r2, r3, r4 = model(x0.cuda())
        
        # All real
        label_size = x1.size(0)
        real_label = 1
        label = torch.full((label_size,), real_label, dtype=torch.float, device=device)       
        for idx, x_ in enumerate([x1, x2, x3, x4]):
            output = discnet(x_.cuda()).view(-1)
            if idx == 0:
                errD_real = criterionD(output, label)
            else:
                errD_real += criterionD(output, label)
        errD_real.backward()
        out_real = output.detach().cpu().numpy()

        # All fake
        fake_label = 0
        label.fill_(fake_label)
        for idx, r_ in enumerate([r1, r2, r3, r4]):
            output = discnet(r_.detach().contiguous()).view(-1)
            if idx == 0:
                errD_fake = criterionD(output, label)
            else:
                errD_fake += criterionD(output, label)
        errD_fake.backward()

        grad_penalty = calc_gradient_penalty(netD, x1.cuda(), r1)
        grad_penalty += calc_gradient_penalty(netD, x2.cuda(), r2)
        grad_penalty += calc_gradient_penalty(netD, x3.cuda(), r3)
        grad_penalty += calc_gradient_penalty(netD, x4.cuda(), r4)
        grad_penalty = 1*grad_penalty
        grad_penalty.backward()

        optimizerD.step()
        del r1, r2, r3, r4

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        # model.train()
        discnet.eval()
        optimizer.zero_grad()
        model.zero_grad()

        r1, r2, r3, r4 = model(x0.cuda())
        label.fill_(real_label)
        for idx, r_ in enumerate([r1, r2, r3, r4]):
            output = discnet(r_.detach().contiguous()).view(-1)
            if idx == 0:
                errG = criterionD(output, label)
            else:
                errG += criterionD(output, label)

        loss1 = criterion(r1, x1.cuda())
        loss2 = criterion(r2, x2.cuda())
        loss3 = criterion(r3, x3.cuda())
        loss4 = criterion(r4, x4.cuda())
        # loss5 = criterion(r5, x5.cuda())
        loss = 10*(loss1 + loss2 + loss3 + loss4) + errG
        
        del x0, x1, x2, x3, x4, r1, r2, r3, r4

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        avg_loss += loss.item()
        loss_history.append(loss.item())

        if step%10 == 0:
            print("step{}, loss = {}, real = {}, fake = {}".format(step, loss, np.mean(out_real), (output.mean().item())), errG.mean().item())
            disc_history_real.append([out_real])
            disc_history_fake.append([output.detach().cpu().numpy()])

        if step%100 == 0:
            # sys.stdout = open(os.devnull, 'w')
            figplot(epochnow, step)
            torch.save(model_origin.state_dict(), "../data/model/" + timestampStr + "/CConv_stepwise_ep{}_{}.pt".format(epochnow, step))
            torch.save(netD.state_dict(), "../data/model/" + timestampStr + "/CConv_discnet_ep{}_{}.pt".format(epochnow, step))
            torch.save(model_origin.state_dict(), "latest_net_wdisc_CConv.pt".format(epochnow, step))
            torch.save(netD.state_dict(), "latest_disc_wdisc_4ch_CConv.pt".format(epochnow, step))
            np.save("../data/model/" + timestampStr +"/disc_history_fake", disc_history_fake)
            np.save("../data/model/" + timestampStr +"/disc_history_real", disc_history_real)
            np.save("../data/model/" + timestampStr +"/loss_history", loss_history)
            # sys.stdout = sys.__stdout__

    return avg_loss / len(train_loader)

In [11]:
netD = Discriminator().to(device)
optimizerD = optim.Adam(netD.parameters(), lr=0.0001)
criterionD = nn.BCEWithLogitsLoss()

In [12]:
model_origin = StepAE().to(device)
model_origin.load_state_dict(torch.load("latest_net_wdisc_CConv.pt"))
criterion = nn.MSELoss()
loss_history = []

In [13]:
oneloader = DataLoader(trainset, batch_size=1, shuffle=True)
torch.cuda.empty_cache()

def figplot(epochnow, step):
    model_origin.eval()
    dataiter = iter(oneloader)
    images, *_ = dataiter.next()
    del _

    recon, *_ = model_origin(images.permute(0, 3, 1, 2).cuda())

    del _
    torch.cuda.empty_cache()
    # get sample outputs

    steps = 100
    skip = int(steps/10)
    plt.figure(figsize=(6, 20))
    ii = 0
    for i in range(steps):
        recon, *_ = model_origin(recon)
        del _
        
        if i%skip == 0:
            ii+=1
            ax = plt.subplot(10, 1, ii + 0)
            # print(i)
            # output is resized into a batch of iages
            # use detach when it's an output that requires_grad
            output = recon.detach().view(1, 3, 100, 100).cpu()[0].permute(1, 2, 0)
            # print(output.min(), output.max())
            plt.imshow(output)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            del output
            torch.cuda.empty_cache()
    plt.savefig("../data/model/"+timestampStr+"/fig_{}_{}.png".format(epochnow, step))
    torch.save(images, "../data/model/"+timestampStr+"/input_{}_{}.pt".format(epochnow,step))
    del images
    del recon
    del dataiter

In [14]:
optimizer = torch.optim.Adam(model_origin.parameters(), lr=0.0002)

import datetime
dateTimeObj = datetime.datetime.now()
timestampStr = "CoordConv_" + dateTimeObj.strftime("%d_%b_%Y_%H_%M")
if timestampStr not in os.listdir("../data/model/"):
    os.mkdir("../data/model/"+timestampStr)

for epoch in range(1, epochs+1):
    epoch_loss = train(model=model_origin, discnet=netD, train_loader=trainloader, epochnow=epoch)
    print("[Epoch {}] loss:{}".format(epoch, epoch_loss))
    torch.save(model_origin.state_dict(), "../data/model/" + timestampStr + "/new_stepwise_ep{}.pt".format(epoch))
    torch.save(netD.state_dict(), "../data/model/" + timestampStr + "/new_discnet_ep{}.pt".format(epoch))

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step0, loss = 2.734926223754883, real = 0.5467444062232971, fake = 0.5463590025901794 1.8273062705993652


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step10, loss = 3.0699539184570312, real = 0.5439685583114624, fake = 0.5442588925361633 1.8301918506622314
step20, loss = 2.4335453510284424, real = 0.5429933071136475, fake = 0.5430705547332764 1.8319628238677979
step30, loss = 2.305051803588867, real = 0.541258692741394, fake = 0.5415282249450684 1.8344321250915527
step40, loss = 2.6950151920318604, real = 0.5386121273040771, fake = 0.53886878490448 1.838399887084961
step50, loss = 2.2160518169403076, real = 0.5366238355636597, fake = 0.5357046127319336 1.8426668643951416
step60, loss = 2.5116753578186035, real = 0.5312451124191284, fake = 0.5314872860908508 1.8494038581848145
step70, loss = 2.8595895767211914, real = 0.5274835824966431, fake = 0.5271724462509155 1.8557509183883667
step80, loss = 2.4085536003112793, real = 0.5214473009109497, fake = 0.5191435217857361 1.8673744201660156
step90, loss = 2.7087936401367188, real = 0.5103798508644104, fake = 0.5107994079589844 1.8799347877502441


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step100, loss = 2.4397919178009033, real = 0.4974340498447418, fake = 0.49732428789138794 1.901609182357788


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step110, loss = 2.3171322345733643, real = 0.47851160168647766, fake = 0.4766179919242859 1.9304732084274292
step120, loss = 2.342245101928711, real = 0.4516436457633972, fake = 0.4538847506046295 1.966871738433838
step130, loss = 2.41424822807312, real = 0.41976746916770935, fake = 0.4255780577659607 2.01167631149292
step140, loss = 3.0267441272735596, real = 0.40946388244628906, fake = 0.40498289465904236 2.0470528602600098
step150, loss = 3.103707790374756, real = 0.3883110582828522, fake = 0.38673701882362366 2.076249837875366
step160, loss = 2.590350389480591, real = 0.3884177803993225, fake = 0.39620840549468994 2.060598134994507
step170, loss = 2.500044107437134, real = 0.4031396508216858, fake = 0.3994787335395813 2.0555899143218994
step180, loss = 2.4495463371276855, real = 0.4098573923110962, fake = 0.4118758738040924 2.0377187728881836
step190, loss = 2.9849939346313477, real = 0.4091266095638275, fake = 0.42028170824050903 2.0205678939819336


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step200, loss = 2.457486629486084, real = 0.4226399064064026, fake = 0.4185061454772949 2.020998477935791


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step210, loss = 2.4739959239959717, real = 0.4214203357696533, fake = 0.4293369650840759 2.0053277015686035
step220, loss = 2.413235902786255, real = 0.4481106996536255, fake = 0.4503028094768524 1.9734009504318237
step230, loss = 2.3775200843811035, real = 0.4415099322795868, fake = 0.44849056005477905 1.974179744720459
step240, loss = 2.329042673110962, real = 0.45137569308280945, fake = 0.4584653377532959 1.9611557722091675
step250, loss = 2.4522252082824707, real = 0.45575886964797974, fake = 0.44931942224502563 1.9727340936660767
step260, loss = 2.4274816513061523, real = 0.46875739097595215, fake = 0.4694766104221344 1.942176342010498
step270, loss = 2.339890480041504, real = 0.46826404333114624, fake = 0.46816375851631165 1.9447444677352905
step280, loss = 2.410634994506836, real = 0.4711078405380249, fake = 0.46434250473976135 1.9505523443222046
step290, loss = 2.9942569732666016, real = 0.46266210079193115, fake = 0.45903700590133667 1.9577598571777344


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step300, loss = 2.9355382919311523, real = 0.4691929817199707, fake = 0.4680069088935852 1.944007396697998


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


step310, loss = 2.4298079013824463, real = 0.4684213697910309, fake = 0.46442124247550964 1.9505743980407715
step320, loss = 2.3967607021331787, real = 0.47937941551208496, fake = 0.46763256192207336 1.942293405532837
step330, loss = 2.3647918701171875, real = 0.4719278812408447, fake = 0.47415563464164734 1.9365752935409546
step340, loss = 2.4009196758270264, real = 0.45961394906044006, fake = 0.4650000035762787 1.9507126808166504
step350, loss = 2.4275760650634766, real = 0.48427948355674744, fake = 0.48001736402511597 1.9299753904342651
step360, loss = 2.3414947986602783, real = 0.48981809616088867, fake = 0.4832304120063782 1.9202570915222168


In [None]:
plt.plot(loss_history)
plt.show()
torch.cuda.empty_cache()

In [None]:
# torch.cuda.empty_cache()
# model_origin = StepAE().to(device)
# model_origin.load_state_dict(torch.load("cifar10.pt"))
model_origin.eval()