In [None]:
class GANImageDataset(Dataset):
    def __init__(self, weather, town, test=False , transform=None, target_transform=None):
        dirt = './Datasets/'+weather+'/'+town
        if test:
            dirt = dirt+'/test'
        
        self.sem_dir = dirt+'/Semantic'
        self.rgb_dir = dirt+'/RGB'
        self.transform = transform
        self.target_transform = target_transform
        self.names = os.listdir(self.rgb_dir)

    def __len__(self):
        return len(os.listdir(self.sem_dir))

    def __getitem__(self, idx): 
        img_path = os.path.join(self.rgb_dir, self.names[idx])
        image = read_image(img_path)
        label_name = self.names[idx].split('.')[0]+'.npy'
        label = np.load(os.path.join(self.sem_dir, label_name))
        label = torch.tensor(label).permute(2,0,1)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import random
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import carla
from utils import AE_initalize_weights


from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import DataLoader, Subset 
from torch.utils.data import ConcatDataset
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class PerceptionNet(nn.Module):

    def __init__(self):
        super(PerceptionNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        self.conv5 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        
        self.conv6a = nn.Conv2d(512, 64, kernel_size=4, stride=1)
        self.conv6b = nn.Conv2d(512, 64, kernel_size=4, stride=1)
        
        self.conv7 = torch.nn.ConvTranspose2d(64,512, kernel_size =4, stride=1)
        self.bn6 = nn.BatchNorm2d(512)
        
        self.conv8 = torch.nn.ConvTranspose2d(512,256, kernel_size =4, stride=2, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        
        self.conv9 = torch.nn.ConvTranspose2d(256,128, kernel_size =4, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(128)
        
        self.conv10 = torch.nn.ConvTranspose2d(128,64, kernel_size =4, stride=2, padding=1)
        self.bn9 = nn.BatchNorm2d(64)
        
        self.conv11 = torch.nn.ConvTranspose2d(64,32, kernel_size =4, stride=2, padding=1)
        self.bn10 = nn.BatchNorm2d(32)
        
        self.conv12 = torch.nn.ConvTranspose2d(32,13, kernel_size =4, stride=2,padding=1)
        
            
    def encode(self, x):
        x = F.leaky_relu(self.conv1(x),negative_slope=0.02)
        x = F.leaky_relu(self.bn2(self.conv2(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn3(self.conv3(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn4(self.conv4(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn5(self.conv5(x)),negative_slope=0.02)
        return self.conv6a(x)

    
    def decode(self, x):
        x = F.leaky_relu(self.bn6(self.conv7(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn7(self.conv8(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn8(self.conv9(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn9(self.conv10(x)),negative_slope=0.02)
        x = F.leaky_relu(self.bn10(self.conv11(x)),negative_slope=0.02)
        return torch.sigmoid(self.conv12(x))

    def reparameterize(self,mu,logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        latent_sample = mu + eps*std
        return latent_sample
    
    def forward(self, x):
        x = x.to(device, dtype=torch.float32)
        latent = self.encode(x)
        #latent = self.reparameterize(mu,logvar)
        out = self.decode(latent)
        return out, latent

In [3]:
master = PerceptionNet()
servant = PerceptionNet()

master.to(device)
master.load_state_dict(torch.load('./AE_params/model_43.best'))
master.eval()

servant.to(device)
servant.apply(AE_initalize_weights)
servant.train()

gan.to(device)

NameError: name 'CycleGan' is not defined

In [None]:
crop_transform = CropResizeTransform(64,14,50,100,(128,128))
town03_data = CustomImageDataset('ClearNoon','Town03',test=False)
town03_test_data = CustomImageDataset('ClearNoon','Town03',test=True)
town04_data = CustomImageDataset('ClearNoon','Town04',test=False)
town04_test_data = CustomImageDataset('ClearNoon','Town04',test=True)
town07_data = CustomImageDataset('ClearNoon','Town07',test=False)
#flipped data
town03_data_hf = CustomImageDataset('ClearNoon','Town03',test=False,transform=Hflip(),target_transform=Hflip())
town03_test_data_hf = CustomImageDataset('ClearNoon','Town03',test=True,transform=Hflip(),target_transform=Hflip())
town04_data_hf = CustomImageDataset('ClearNoon','Town04',test=False,transform=Hflip(),target_transform=Hflip())
town04_test_data_hf = CustomImageDataset('ClearNoon','Town04',test=True,transform=Hflip(),target_transform=Hflip())
town07_data_hf = CustomImageDataset('ClearNoon','Town07',test=False,transform=Hflip(),target_transform=Hflip())

train_data = ConcatDataset([town03_data,town04_data,town07_data,town03_data_hf,town04_data_hf,town07_data_hf])
test_data = ConcatDataset([town03_test_data,town04_test_data,town03_test_data_hf,town04_test_data_hf])

train_loader = DataLoader(train_data, batch_size=512, shuffle=True)
test_loader = DataLoader(test_data, batch_size=512, shuffle=True)

In [None]:
def train(epoch,recode_dict):
    global step
    servant.train()
    writer = SummaryWriter()
    train_loss = 0
    
    for batch_idx, data in enumerate(train_loader):
        rgb = data[0]
        target = data[1]
        target = recode_tags(target,recode_dict)
        batch_size = target.shape[0]
        #preds = F.one_hot(preds.to(torch.int64))
        target = target.reshape(batch_size,128,128)
        target = target.to(device,dtype=torch.long)
        optimizer.zero_grad()
        #dont need latent space output while training
        y_batch,_ = model(rgb)
        loss = loss_fn(y_batch,target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_size * batch_idx, len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            loss.item() / len(data)))

        writer.add_scalar("AE Loss", loss.item() / len(data[0]),step)
        step += 1

    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.8f}'.format(
          epoch, avg_loss))
    writer.flush()
    return avg_loss
    
    


def test(epoch,recode_dict):
    model.eval()
    writer = SummaryWriter()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            rgb = data[0]
            target = data[1]
            batch_size = target.shape[0]
            target = recode_tags(target,recode_dict)
            #preds = F.one_hot(preds.to(torch.int64))
            target = target.permute(0,3,1,2)
            target= target.reshape(batch_size,128,128)
            target = target.to(device,dtype=torch.long)
            y_batch,_ = model(rgb)
            test_loss += loss_fn(y_batch,target).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.8f}'.format(test_loss))
    writer.add_scalar("AE test Loss", test_loss,epoch)
    writer.flush()
    writer.close()


for epoch in range(1, 20 + 1):
    step = 0
    smallest_loss = 1000
    
    avg_loss = train(epoch,recode_dict)
    if avg_loss < smallest_loss:
        torch.save(model.state_dict(), './AE_params/model_44.best')
    test(epoch,recode_dict)
    torch.save(model.state_dict(), './AE_params/model_44.final')