In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
from IPython.display import HTML

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box

In [2]:
# Seed the Python, NumPy and PyTorch random number generators with a fixed value.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0);  # trailing ';' hides the returned Generator repr in the cell output

In [3]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Note: `device` is a plain string here, not a torch.device object.
device = ("cuda:0" if torch.cuda.is_available() else "cpu")
device

'cuda:0'

### Set Initial Parameters

In [4]:
# Number of samples per batch handed to the training DataLoader.
batch_size = 8

### Create Dataset and Dataloader

In [5]:
# Root folder of the dataset. Set the DATA_DIR environment variable to point the
# notebook at another location; the default is the original local path.
data_dir = os.environ.get('DATA_DIR', '/home/alexander/data')
# The annotation file lives inside the data folder, so derive it instead of repeating the path.
annotation_csv = os.path.join(data_dir, 'annotation.csv')

# You shouldn't change the unlabeled_scene_index
# The first 106 scenes (0 - 105) are unlabeled
unlabeled_scene_index = np.arange(106)
# The scenes from 106 - 133 are labeled
# You should divide the labeled_scene_index into two subsets (training and validation)
# np.arange(106, 130) selects scenes 106 - 129 only; scenes 130 - 133 are not used here
# (presumably held out for validation -- TODO confirm).
labeled_scene_index = np.arange(106, 130)

In [6]:
# Plain ToTensor transform. In the visible cells it is only referenced by the
# commented-out unlabeled loader below.
transform = torchvision.transforms.ToTensor()

# NOTE(review): the disabled code below refers to `image_folder`, which is not defined
# in the visible cells -- it would need `data_dir` (or similar) before being re-enabled.
# unlabeled_trainset = UnlabeledDataset(image_folder=image_folder, scene_index=labeled_scene_index, first_dim='sample', transform=transform)
# trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=3, shuffle=True, num_workers=2)

In [7]:
# Camera-image pipeline: centre-crop to 256, convert to a tensor, then shift/scale
# with mean 0.5 and std 0.5.
img_transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Road-map pipeline: go through a PIL image so the map can be resized to 256,
# then convert back to a tensor.
map_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.ToTensor(),
])

labeled_trainset = LabeledDataset(
    image_folder=data_dir,
    annotation_file=annotation_csv,
    scene_index=labeled_scene_index,
    img_transform=img_transform,
    map_transform=map_transform,
    extra_info=True,
)

# Shuffled mini-batches, assembled by the project's collate_fn using two worker processes.
trainloader = torch.utils.data.DataLoader(
    labeled_trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

In [8]:
# Peek at the shape of element 0 (the image tensor) of the sample at index 1.
labeled_trainset[1][0].size()

torch.Size([18, 256, 256])

In [9]:
# Fetch just the first batch and report the shape of its image tensor.
for images, road_maps in trainloader:
    print(images.shape)
    break

# # sample, target, road_image, extra = iter(trainloader).next()
# # print(len(sample[0]))
# # # print(torch.stack(sample).shape)
# # print(road_image[0].size())

torch.Size([8, 18, 256, 256])


In [10]:
# print(road_image[0])

In [11]:
# # The 6 images are organized in the following order:
# # CAM_FRONT_LEFT, CAM_FRONT, CAM_FRONT_RIGHT, CAM_BACK_LEFT, CAM_BACK, CAM_BACK_RIGHT
# plt.imshow(torchvision.utils.make_grid(sample[4], nrow=3).numpy().transpose(1, 2, 0))
# plt.axis('off');

In [12]:
# # The road map layout is encoded into a binary array of size [800, 800] per sample 
# # Each pixel is 0.1 meter in physical space, so 800 * 800 is 80m * 80m centered at the ego car
# # The ego car is located in the center of the map (400, 400) and it is always facing the left

# fig, ax = plt.subplots()
# ax.imshow(road_image[4][0], cmap='binary');

### Define models

In [13]:
class ResnetBlock(nn.Module):
    """
    Residual block: conv -> batch norm -> ReLU -> dropout -> conv -> batch norm,
    with the block's input added back onto its output (skip connection).
    Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
    """
    def __init__(self, dim):
        """
        Parameters:
            dim (int) -- number of channels entering and leaving both 3x3 conv layers
        """
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
        ]
        # Attribute name is kept: it is part of the state_dict key prefix.
        self.resnet_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the residual computed by the conv stack."""
        residual = self.resnet_block(x)
        return x + residual

In [14]:
class Generator(nn.Module):
    """
    RESNET-based generator that consists of Resnet blocks + downsampling/upsampling operations.

    Layout: reflection-padded 7x7 conv (in_ch -> ngf), two stride-2 convs
    (ngf -> 2*ngf -> 4*ngf), `n_blocks` ResnetBlocks at 4*ngf channels, two stride-2
    transposed convs back to ngf channels, then a reflection-padded 7x7 conv to
    out_ch followed by Tanh, so outputs lie in [-1, 1].
    """
    def __init__(self, in_ch, out_ch, ngf, n_blocks=6, init_gain=0.02):
        """
        Parameters:
            in_ch (int)         -- the number of channels in input images
            out_ch (int)        -- the number of channels in output images
            ngf (int)           -- the number of filters produced by the first conv layer
                                   (and consumed by the last one)
            n_blocks (int)      -- the number of ResNet blocks; this argument is now honoured
                                   (previously nine blocks were always built regardless of it)
            init_gain (float)   -- not used by this class; accepted so existing calls keep working
        """
        assert(n_blocks >= 0)
        super(Generator, self).__init__()

        # Stem + two downsampling stages.
        layers = [
            nn.ReflectionPad2d(3),

            nn.Conv2d(in_ch, ngf, kernel_size=7, padding=0),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
        ]

        # Residual trunk. With n_blocks=9 the Sequential indices (and therefore the
        # state_dict keys) are identical to the previous hard-coded nine-block layout.
        layers += [ResnetBlock(ngf * 4) for _ in range(n_blocks)]

        # Two upsampling stages + output head.
        layers += [
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_ch, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input batch through the full encoder / residual / decoder stack."""
        return self.model(x)

In [15]:
class Discriminator(nn.Module):
    """
    Convolutional discriminator that maps an input batch to a one-channel grid of
    raw (un-squashed) scores.

    Structure: one stride-2 conv + LeakyReLU (block1), `n_layers - 1` stride-2
    conv/BatchNorm/LeakyReLU stages with a growing filter count capped at 8 * ndf
    (block2), one stride-1 stage of the same kind (block3), and a final stride-1
    conv down to a single channel.
    """
    def __init__(self, in_ch, ndf=64, n_layers=3):
        """
        Parameters:
            in_ch (int)    -- number of channels of the discriminator input
            ndf (int)      -- number of filters produced by the first conv layer
            n_layers (int) -- controls the depth: block2 holds n_layers - 1 stages
        """
        super().__init__()

        def conv_norm_act(c_in, c_out, stride):
            # 4x4 conv -> batch norm -> LeakyReLU(0.2), returned as a flat layer list.
            return [nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(0.2, True)]

        # Filter multiplier per stage index k: 1, 2, 4, 8, 8, ...
        mults = [min(2 ** k, 8) for k in range(n_layers + 1)]

        self.block1 = nn.Sequential(
            nn.Conv2d(in_ch, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        )

        downsampling = []
        for k in range(1, n_layers):  # gradually increase the number of filters
            downsampling.extend(conv_norm_act(ndf * mults[k - 1], ndf * mults[k], stride=2))
        self.block2 = nn.Sequential(*downsampling)

        # Width leaving block2 (or block1 when block2 is empty) and width of the last stage.
        width_in = ndf * mults[max(n_layers - 1, 0)]
        width_out = ndf * mults[n_layers]
        self.block3 = nn.Sequential(*conv_norm_act(width_in, width_out, stride=1))

        head = nn.Conv2d(width_out, 1, kernel_size=4, stride=1, padding=1)
        self.model = nn.Sequential(self.block1, self.block2, self.block3, head)

    def forward(self, x):
        """Return the grid of raw scores for the input batch."""
        return self.model(x)

In [16]:
# --- Network shape ---
in_ch = 18         # generator input channels (passed to Generator; D receives in_ch + out_ch)
out_ch = 1         # generator output channels
ngf = 64           # base filter count passed to Generator
ndf = 64           # base filter count passed to Discriminator
n_blocks_g = 9     # passed to Generator as n_blocks
n_layers_d = 3     # passed to Discriminator as n_layers

# --- Optimisation ---
n_epochs = 100
lr_g = 0.0001      # Adam learning rate for the generator
lr_d = 0.0004      # Adam learning rate for the discriminator
beta1 = 0.5        # Adam betas, shared by both optimizers
beta2 = 0.999
L1_lambda = 100    # weight of the L1 term in the generator loss

# --- GAN target values ---
real_label = 0.9   # target for real samples; below 1.0, presumably as label smoothing -- TODO confirm
gen_label = 0.0    # target for generated samples

In [17]:
# Sanity check - dimensions of inputs and outputs to both networks.
# Built from in_ch (instead of a hard-coded 18) and run under no_grad because only
# tensor shapes are of interest here.

test = torch.zeros([1, in_ch, 256, 256])
print("Input to generator:\t", test.size())

test_generator = Generator(in_ch, out_ch, ngf, n_blocks=n_blocks_g, init_gain=0.02)
with torch.no_grad():
    test_result_g = test_generator(test)
print("\nOutput of generator:\t", test_result_g.size())

# Same channel order as the training loop: input images first, then the map.
test_d_input = torch.cat((test, test_result_g), 1)
print("\nInput to discriminator:\t", test_d_input.size())

test_discriminator = Discriminator(in_ch + out_ch, ndf=ndf, n_layers=n_layers_d)
with torch.no_grad():
    test_result_d = test_discriminator(test_d_input)
print("\nOutput of discriminator:", test_result_d.size())

Input to generator:	 torch.Size([1, 18, 256, 256])

Output of generator:	 torch.Size([1, 1, 256, 256])

Input to discriminator:	 torch.Size([1, 19, 256, 256])

Output of discriminator: torch.Size([1, 1, 30, 30])


In [18]:
class GANLoss(nn.Module):
    """
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """
    def __init__(self, gan_mode, real_label=1.0, gen_label=0.0):
        """
        Parameters:
            gan_mode (str)     -- the type of GAN objective: 'LS' (nn.MSELoss) or
                                  'BCE' (nn.BCEWithLogitsLoss)
            real_label (float) -- target value for a real image
            gen_label (float)  -- target value for a generated image
        Raises:
            ValueError -- if gan_mode is not 'LS' or 'BCE'.
        Note: Do not use sigmoid as the last layer of Discriminator.
        'LS' needs no sigmoid. 'BCE' handles it inside BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()

        self.gan_mode = gan_mode
        self.real_label = real_label
        self.gen_label = gen_label

        # Both criteria are parameter-free, so no device placement is needed here.
        if gan_mode == 'LS':
            self.loss = nn.MSELoss()
        elif gan_mode == 'BCE':
            self.loss = nn.BCEWithLogitsLoss()
        else:
            # Fail at construction time instead of with an AttributeError on first use.
            raise ValueError("Unsupported gan_mode %r: expected 'LS' or 'BCE'" % (gan_mode,))

    def get_target_tensor(self, output, target_is_real):
        """
        Create a label tensor with the same size as the discriminator output.

        The label is created on the device and with the dtype of `output`, so the
        loss works wherever the discriminator runs without consulting any global.

        Parameters:
            output (tensor)       -- discriminator output
            target_is_real (bool) -- True selects real_label, False selects gen_label
        Returns:
            A tensor filled with the selected label, expanded to the size of `output`.
        """
        label = self.real_label if target_is_real else self.gen_label
        target_tensor = torch.tensor(label, dtype=output.dtype, device=output.device)
        return target_tensor.expand_as(output)

    def forward(self, output, target_is_real):
        """
        Compute the loss between the discriminator output and the matching label tensor.

        Implemented as forward() so that nn.Module.__call__ (and its hooks) is used;
        calling the instance directly works exactly as before.
        """
        target_tensor = self.get_target_tensor(output, target_is_real)
        loss = self.loss(output, target_tensor)
        return loss

In [19]:
def init_weights(m):
    """
    Initialization hook meant to be passed to `Module.apply`.

    Layers whose class name contains 'Conv' or 'Linear' (and that have a `weight`
    attribute) get weights drawn from N(0, 0.02). Layers whose class name contains
    'BatchNorm2d' get weights drawn from N(1, 0.02) and biases set to 0.
    Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    is_conv_or_linear = ('Conv' in layer_kind) or ('Linear' in layer_kind)

    if hasattr(m, 'weight') and is_conv_or_linear:
        init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm2d' in layer_kind:
        # BatchNorm weight is a per-channel scale vector, so it is centred on 1 rather than 0.
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

In [20]:
# Build both networks and move them to the selected device.
# The discriminator input has in_ch + out_ch channels.
generator = Generator(in_ch, out_ch, ngf, n_blocks=n_blocks_g, init_gain=0.02).to(device)
discriminator = Discriminator(in_ch + out_ch, ndf=ndf, n_layers=n_layers_d).to(device)

# Re-initialise the weights of each network in turn (generator first, then discriminator).
for net in (generator, discriminator):
    net.apply(init_weights)

# Put both networks in training mode.
generator.train()
discriminator.train()

Discriminator(
  (block1): Sequential(
    (0): Conv2d(19, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (block3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(19, 64, kernel_size=(4, 4), stride=(2, 2), padding=(

In [21]:
# The generator objective combines an adversarial (GAN) term with an L1 term.
criterion_gan = GANLoss('BCE', real_label=real_label, gen_label=gen_label).to(device)
criterion_L1 = nn.L1Loss().to(device)

# One Adam optimizer per network: same betas, separate learning rates.
adam_betas = (beta1, beta2)
optim_G = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=adam_betas)
optim_D = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=adam_betas)

### Train

In [22]:
# Bookkeeping shared with the training loop and the plotting cells that follow it.
iters = 0                 # running batch counter across all epochs
update_stats_rate = 50    # interval, in batches, used for log lines and image snapshots
losses_g = []             # generator loss, one entry per batch
losses_d = []             # discriminator loss, one entry per batch
intermediate_images = []  # image grids of generator outputs collected during training

In [23]:
####################################
########## TRAINING LOOP ###########
####################################
for epoch in range(n_epochs):
    is_last_epoch = (epoch == n_epochs - 1)

    # `i` is the 1-based batch index within the current epoch (used for logging).
    for i, (real_18ch, real_map) in enumerate(tqdm(trainloader), start=1):
        # ----------------------------------------------------------
        # TRAIN DISCRIMINATOR ON REAL AND GENERATED SAMPLES
        # ----------------------------------------------------------
        real_18ch = real_18ch.to(device)
        real_map = real_map.to(device)

        gen_map = generator(real_18ch)

        # Discriminator input: the input images followed by the (real or generated) map.
        # Both operands already live on `device`, so no extra transfer is needed.
        real_19ch = torch.cat((real_18ch, real_map), 1)
        gen_19ch = torch.cat((real_18ch, gen_map), 1)

        output_real = discriminator(real_19ch)
        # detach(): the discriminator update must not backpropagate into the generator.
        output_gen = discriminator(gen_19ch.detach())

        d_real_mean = output_real.mean().item()
        d_gen_mean_pre = output_gen.mean().item()

        # Compute loss (criterion_gan builds the label tensors itself).
        loss_d_real = criterion_gan(output_real, target_is_real=True)
        loss_d_gen = criterion_gan(output_gen, target_is_real=False)
        loss_d_total = loss_d_real + loss_d_gen

        discriminator.zero_grad()
        loss_d_total.backward()
        optim_D.step()

        # ----------------------------------------------------------
        # TRAIN GENERATOR
        # ----------------------------------------------------------
        # Score the generated pair again, this time keeping the graph back to the generator.
        output_gen = discriminator(gen_19ch)
        loss_g = criterion_gan(output_gen, target_is_real=True) + L1_lambda * criterion_L1(gen_map, real_map)
        d_gen_mean_post = output_gen.mean().item()

        generator.zero_grad()
        loss_g.backward()
        optim_G.step()

        # ----------------------------------------------------------
        # COMPILE TRAINING STATISTICS
        # ----------------------------------------------------------
        # Store plain Python floats: keeping the loss tensors themselves would hold on to
        # GPU memory / autograd state and cannot be plotted directly by matplotlib.
        loss_d_value = loss_d_total.item()
        loss_g_value = loss_g.item()
        losses_d.append(loss_d_value)
        losses_g.append(loss_g_value)

        # Print out training stats
        if i % update_stats_rate == 1:
            print('Epoch: [%d/%d]\tBatch: [%d/%d]\nLoss_D: %.4f\tLoss_G: %.4f\nD(real): %.4f  ||  D(gen) Pre-G-step: %.4f  ||  D(gen) Post-G-step: %.4f'
                  % (epoch + 1, n_epochs,
                     i, len(trainloader),
                     loss_d_value, loss_g_value,
                     d_real_mean, d_gen_mean_pre, d_gen_mean_post))

        # Append current generated image to list of intermediate images to view post training:
        # every `update_stats_rate` iterations, plus once at the very last batch of training.
        # (Previously every batch of the last epoch was stored, which keeps hundreds of
        # image grids in memory.)
        is_final_batch = is_last_epoch and i == len(trainloader)
        if (iters % update_stats_rate == 0) or is_final_batch:
            with torch.no_grad():
                snapshot = generator(real_18ch).detach().cpu()
            intermediate_images.append(vutils.make_grid(snapshot, padding=2, normalize=True))

        iters += 1
        

  0%|          | 1/378 [00:01<11:42,  1.86s/it]

Epoch: [1/100]	Batch: [1/378]
Loss_D: 1.6159	Loss_G: 56.0310
D(real): -0.2356  ||  D(gen) Pre-G-step: -0.2298  ||  D(gen) Post-G-step: -0.3413


  6%|▌         | 21/378 [00:11<03:40,  1.62it/s]

Epoch: [1/100]	Batch: [21/378]
Loss_D: 1.4294	Loss_G: 31.0258
D(real): -0.1999  ||  D(gen) Pre-G-step: -0.2046  ||  D(gen) Post-G-step: -0.2130


 11%|█         | 41/378 [00:21<03:02,  1.85it/s]

Epoch: [1/100]	Batch: [41/378]
Loss_D: 1.3902	Loss_G: 25.2355
D(real): -0.1765  ||  D(gen) Pre-G-step: -0.1884  ||  D(gen) Post-G-step: -0.2243


 16%|█▌        | 61/378 [00:31<02:54,  1.82it/s]

Epoch: [1/100]	Batch: [61/378]
Loss_D: 1.4016	Loss_G: 24.3958
D(real): -0.2108  ||  D(gen) Pre-G-step: -0.2130  ||  D(gen) Post-G-step: -0.1722


 21%|██▏       | 81/378 [00:40<02:49,  1.75it/s]

Epoch: [1/100]	Batch: [81/378]
Loss_D: 1.3719	Loss_G: 26.6460
D(real): -0.1837  ||  D(gen) Pre-G-step: -0.2286  ||  D(gen) Post-G-step: -0.2293


 27%|██▋       | 101/378 [00:49<02:25,  1.91it/s]

Epoch: [1/100]	Batch: [101/378]
Loss_D: 1.4250	Loss_G: 26.8321
D(real): -0.0977  ||  D(gen) Pre-G-step: -0.2280  ||  D(gen) Post-G-step: -0.3516


 32%|███▏      | 121/378 [01:00<02:33,  1.68it/s]

Epoch: [1/100]	Batch: [121/378]
Loss_D: 1.3042	Loss_G: 24.1908
D(real): -0.1312  ||  D(gen) Pre-G-step: -0.3657  ||  D(gen) Post-G-step: -0.3343


 37%|███▋      | 141/378 [01:09<02:15,  1.75it/s]

Epoch: [1/100]	Batch: [141/378]
Loss_D: 1.2783	Loss_G: 19.2819
D(real): 0.0464  ||  D(gen) Pre-G-step: -0.2639  ||  D(gen) Post-G-step: -0.3404


 43%|████▎     | 161/378 [01:19<02:10,  1.66it/s]

Epoch: [1/100]	Batch: [161/378]
Loss_D: 1.3104	Loss_G: 20.0129
D(real): -0.0102  ||  D(gen) Pre-G-step: -0.3387  ||  D(gen) Post-G-step: -0.3355


 48%|████▊     | 181/378 [01:28<01:46,  1.85it/s]

Epoch: [1/100]	Batch: [181/378]
Loss_D: 1.2004	Loss_G: 24.1573
D(real): 0.4693  ||  D(gen) Pre-G-step: -0.1920  ||  D(gen) Post-G-step: -0.7360


 53%|█████▎    | 201/378 [01:37<01:33,  1.90it/s]

Epoch: [1/100]	Batch: [201/378]
Loss_D: 1.1348	Loss_G: 16.9678
D(real): 0.2451  ||  D(gen) Pre-G-step: -0.5660  ||  D(gen) Post-G-step: -0.6607


 58%|█████▊    | 221/378 [01:46<01:19,  1.96it/s]

Epoch: [1/100]	Batch: [221/378]
Loss_D: 1.5701	Loss_G: 16.0813
D(real): 0.1063  ||  D(gen) Pre-G-step: -0.4139  ||  D(gen) Post-G-step: -0.4692


 64%|██████▍   | 241/378 [01:56<01:20,  1.70it/s]

Epoch: [1/100]	Batch: [241/378]
Loss_D: 0.8738	Loss_G: 17.9817
D(real): 0.4183  ||  D(gen) Pre-G-step: -1.2808  ||  D(gen) Post-G-step: -0.9505


 69%|██████▉   | 261/378 [02:06<01:05,  1.77it/s]

Epoch: [1/100]	Batch: [261/378]
Loss_D: 0.9968	Loss_G: 24.4938
D(real): 0.1603  ||  D(gen) Pre-G-step: -1.3954  ||  D(gen) Post-G-step: -1.1841


 74%|███████▍  | 281/378 [02:15<00:55,  1.75it/s]

Epoch: [1/100]	Batch: [281/378]
Loss_D: 2.1405	Loss_G: 24.2041
D(real): -0.9222  ||  D(gen) Pre-G-step: -1.7219  ||  D(gen) Post-G-step: -0.6909


 80%|███████▉  | 301/378 [02:25<00:42,  1.83it/s]

Epoch: [1/100]	Batch: [301/378]
Loss_D: 0.9274	Loss_G: 20.3977
D(real): 0.1584  ||  D(gen) Pre-G-step: -1.2838  ||  D(gen) Post-G-step: -1.1451


 85%|████████▍ | 321/378 [02:34<00:31,  1.80it/s]

Epoch: [1/100]	Batch: [321/378]
Loss_D: 0.8318	Loss_G: 21.9775
D(real): 0.5413  ||  D(gen) Pre-G-step: -1.5818  ||  D(gen) Post-G-step: -1.5920


 90%|█████████ | 341/378 [02:44<00:20,  1.79it/s]

Epoch: [1/100]	Batch: [341/378]
Loss_D: 0.5629	Loss_G: 19.9760
D(real): 1.1404  ||  D(gen) Pre-G-step: -1.9773  ||  D(gen) Post-G-step: -2.3403


 91%|█████████ | 344/378 [02:46<00:16,  2.07it/s]


KeyboardInterrupt: 

In [None]:
# Loss history entries may be Python floats or single-element tensors (possibly on the GPU
# and attached to autograd); float() turns either into a plain number matplotlib can plot.
loss_curve_g = [float(loss) for loss in losses_g]
loss_curve_d = [float(loss) for loss in losses_d]

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(loss_curve_g,label="G")
plt.plot(loss_curve_d,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
#%%capture
# Animate the stored generator snapshots: one frame per grid, shown channels-last.
fig = plt.figure()
plt.axis("off")
ims = []
for grid in intermediate_images:
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=500, blit=True)

HTML(ani.to_jshtml())

In [None]:
# Grab a batch from the dataloader: element 0 holds the input images, element 1 the
# ground-truth maps (the same unpacking the training loop uses).
real_batch = next(iter(trainloader))

# Plot the real maps. The generator snapshots on the right are maps, so the real maps are
# the matching reference; the multi-channel input tensor in real_batch[0] cannot be
# rendered by imshow as a single image grid.
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[1][:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the most recent generator snapshot collected during training
# (stored in `intermediate_images`; the name `img_list` is never defined in this notebook).
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(intermediate_images[-1],(1,2,0)))
plt.show()