# Learning to Drive in Adverse Weather Conditions

In [None]:
# This cell sits above the main import cell, so it imports torch itself;
# otherwise "Restart Kernel -> Run All" fails here with a NameError.
import torch

# Release GPU memory still cached by PyTorch from an earlier run in this kernel.
torch.cuda.empty_cache()

## Introduction & Setup

This project explores using reinforcement learning to teach an autonomous agent to drive in adverse weather conditions.

### Setup

- The project is carried out in CARLA, an autonomous driving simulator.
- Python 3.8
- Anaconda

[Project Github](https://github.com/rbuckley25/Tempestas)

### Image Segmentor

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import random
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import carla

from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import DataLoader, Subset 
from torch.utils.data import ConcatDataset
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision.utils import save_image
from models import PerceptionNet
from utils import CropResizeTransform,Hflip,recode_tags


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
class CustomImageDataset(Dataset):
    """Paired RGB image / semantic label dataset stored on disk.

    Files are read from ``./Data/<weather>/<town>`` (or its ``test``
    sub-folder): images from ``RGB/`` and labels from ``Semantic/``. The label
    for an image is the ``.npy`` file whose name matches the image file name
    up to its first ``.``.

    Args:
        weather: name of the weather folder under ``./Data``.
        town: name of the town folder under the weather folder.
        test: when True, read from the ``test`` sub-folder instead.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the label tensor.
    """

    def __init__(self, weather, town, test=False , transform=None, target_transform=None):
        dirt = './Data/'+weather+'/'+town
        if test:
            dirt = dirt+'/test'

        self.sem_dir = dirt+'/Semantic'
        self.rgb_dir = dirt+'/RGB'
        self.transform = transform
        self.target_transform = target_transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.names = sorted(os.listdir(self.rgb_dir))

    def __len__(self):
        # Items are indexed through self.names, so the length must come from
        # the same list (and not from a fresh directory listing on every call).
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th RGB file."""
        name = self.names[idx]
        image = read_image(os.path.join(self.rgb_dir, name))
        label_name = name.split('.')[0]+'.npy'
        label = np.load(os.path.join(self.sem_dir, label_name))
        # Move the last axis to the front (presumably HWC -> CHW, to match the
        # channel-first image returned by read_image).
        label = torch.tensor(label).permute(2,0,1)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [13]:
# Build the perception network and move its parameters to the selected device.
model = PerceptionNet(device)
model.to(device)

PerceptionNet(
  (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6a): Conv2d(512, 64, kernel_size=(4, 4), stride=(1, 1))
  (conv6b): Conv2d(512, 64, kernel_size=(4, 4), stride=(1, 1))
  (conv7): ConvTranspose2d(64, 512, kernel_size=(4, 4), stride=(1, 1))
  (bn6): BatchNorm2d(512, eps=1e-05, mom

In [14]:
def initalize_weights(layer):
    """Kaiming-normal (ReLU gain) initialisation for convolutional weights.

    Meant to be passed to ``Module.apply``; any layer that is not a ``Conv2d``
    or ``ConvTranspose2d`` is left untouched.
    """
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    if not isinstance(layer, conv_types):
        return
    nn.init.kaiming_normal_(layer.weight.data, nonlinearity='relu')

In [15]:
# Module.apply visits every submodule, so each (transposed) convolution is re-initialised.
model.apply(initalize_weights)

PerceptionNet(
  (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (bn5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6a): Conv2d(512, 64, kernel_size=(4, 4), stride=(1, 1))
  (conv6b): Conv2d(512, 64, kernel_size=(4, 4), stride=(1, 1))
  (conv7): ConvTranspose2d(64, 512, kernel_size=(4, 4), stride=(1, 1))
  (bn6): BatchNorm2d(512, eps=1e-05, mom

In [16]:
# Label frequency for each of the 13 classes.
label_freq_weights = np.array([
    426548., 103471., 59734., 61954., 3507., 17818.,
    40000., 2355300., 28709., 355806., 4624., 49085.,
    21033.,
])

# Some tags are almost empty: any count below 10,000 is raised to a 15,000
# baseline, otherwise its inverse-frequency weight would be unusably large.
clean_lables = np.where(label_freq_weights < 10000, 15000, label_freq_weights)

# Inverse-frequency weights, normalised so that they sum to 1.
inverse_lables = 1 / clean_lables
normalized_lables = inverse_lables / sum(inverse_lables)
lable_weights = torch.tensor(normalized_lables, dtype=torch.float32).to(device)

In [17]:
# Adam over all model parameters; the loss is cross-entropy weighted per class
# with `lable_weights`.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss_fn = nn.CrossEntropyLoss(weight=lable_weights)

Reference: https://github.com/pytorch/examples/blob/master/vae/main.py

In [18]:
# Crop/resize transform. It is only needed for the cropped dataset variants
# (transform=crop, target_transform=crop), which are currently disabled and
# not part of train_data / test_data.
crop = CropResizeTransform(64, 14, 50, 100, (128, 128))

# Original ClearNoon recordings (Town07 has no test split here).
town03_data = CustomImageDataset('ClearNoon', 'Town03')
town03_test_data = CustomImageDataset('ClearNoon', 'Town03', test=True)
town04_data = CustomImageDataset('ClearNoon', 'Town04')
town04_test_data = CustomImageDataset('ClearNoon', 'Town04', test=True)
town07_data = CustomImageDataset('ClearNoon', 'Town07')

# The same recordings with Hflip applied to both image and label.
town03_data_hf = CustomImageDataset(
    'ClearNoon', 'Town03', transform=Hflip(), target_transform=Hflip())
town03_test_data_hf = CustomImageDataset(
    'ClearNoon', 'Town03', test=True, transform=Hflip(), target_transform=Hflip())
town04_data_hf = CustomImageDataset(
    'ClearNoon', 'Town04', transform=Hflip(), target_transform=Hflip())
town04_test_data_hf = CustomImageDataset(
    'ClearNoon', 'Town04', test=True, transform=Hflip(), target_transform=Hflip())
town07_data_hf = CustomImageDataset(
    'ClearNoon', 'Town07', transform=Hflip(), target_transform=Hflip())

train_data = ConcatDataset([
    town03_data,
    town04_data,
    town07_data,
    town03_data_hf,
    town04_data_hf,
    town07_data_hf,
])
test_data = ConcatDataset([
    town03_test_data,
    town04_test_data,
    town03_test_data_hf,
    town04_test_data_hf,
])

train_loader = DataLoader(train_data, batch_size=512, shuffle=True)
test_loader = DataLoader(test_data, batch_size=512, shuffle=True)

In [None]:
# Sorted class weights, to eyeball how unevenly the classes are weighted.
plt.plot(np.sort(normalized_lables))

In [None]:
def train(epoch,writer):
    """Run one training epoch over ``train_loader`` and return the mean batch loss.

    Uses the notebook globals ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader``, ``device``, ``recode_tags`` and the running TensorBoard
    counter ``step`` (incremented once per batch).

    Args:
        epoch: epoch number; only used in the progress output.
        writer: object providing ``add_scalar`` and ``flush`` (a ``SummaryWriter``).

    Returns:
        The mean of the per-batch losses of this epoch.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no batches.
    """
    global step
    model.train()
    train_loss = 0.0
    num_batches = len(train_loader)
    samples_seen = 0

    for batch_idx, data in enumerate(train_loader):
        # NOTE(review): rgb is handed to the model exactly as loaded; presumably
        # PerceptionNet moves/converts its input itself (it is built with
        # `device`) - confirm.
        rgb = data[0]
        target = recode_tags(data[1])
        batch_size = target.shape[0]
        # CrossEntropyLoss expects class indices of shape (batch, H, W).
        target = target.reshape(batch_size,128,128)
        target = target.to(device,dtype=torch.long)
        optimizer.zero_grad()
        # The latent-space output is not needed while training.
        y_batch,_  = model(rgb)
        loss = loss_fn(y_batch,target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss

        # loss_fn already returns a mean, so the batch loss is printed as is
        # (dividing by len(data) would only divide by the 2-element batch pair).
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, len(train_loader.dataset),
            100. * batch_idx / num_batches,
            batch_loss))
        samples_seen += batch_size

        writer.add_scalar("AE Loss", batch_loss, step)
        step += 1

    # train_loss is a sum of per-batch means, so average over batches rather
    # than over the number of samples.
    avg_loss = train_loss / num_batches
    print('====> Epoch: {} Average loss: {:.8f}'.format(
          epoch, avg_loss))
    writer.flush()
    return avg_loss
    
    


def test(epoch,writer):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            rgb = data[0]
            target = data[1]
            batch_size = target.shape[0]
            target = recode_tags(target)
            #preds = F.one_hot(preds.to(torch.int64))
            target = target.permute(0,3,1,2)
            target= target.reshape(batch_size,128,128)
            target = target.to(device,dtype=torch.long)
            y_batch,_   = model(rgb)
            test_loss += loss_fn(y_batch,target).item()

    print('====> Test set loss: {:.8f}'.format(test_loss))
    writer.add_scalar("AE test Loss", test_loss,epoch)
    writer.flush()

writer = SummaryWriter()
step = 0
# Lowest training loss seen so far. It has to live outside the loop and be
# updated on every improvement, otherwise the "best" checkpoint is simply
# overwritten after every epoch.
smallest_loss = float('inf')
# torch.save does not create missing directories.
os.makedirs('./AE_params', exist_ok=True)
for epoch in range(1, 20 + 1):

    avg_loss = train(epoch,writer)
    # NOTE: "best" is selected on the training loss, not on the test loss.
    if avg_loss < smallest_loss:
        smallest_loss = avg_loss
        torch.save(model.state_dict(), './AE_params/model_55.best')
    test(epoch,writer)
    # Rolling checkpoint: rewritten every epoch, so after the last epoch it
    # holds the final weights and an interrupted run keeps the latest ones.
    torch.save(model.state_dict(), './AE_params/model_55.final')
writer.close()
   

In [None]:
# Restore saved weights. map_location=device lets a checkpoint that was saved
# from the GPU load on a CPU-only machine as well.
# NOTE(review): this loads 'model_4.best', while the training cell writes
# 'model_55.best' - confirm this is the intended checkpoint.
model.load_state_dict(torch.load('./AE_params/model_4.best', map_location=device))
model.eval()

In [None]:
# Dataset read from ./Data/./Town03 (no weather folder). It is bound to
# `old_data` only: the former chained assignment also rebound `town03_data`,
# silently changing what later cells see under that name.
old_data = CustomImageDataset('.','Town03',test=False)
len(old_data)

In [None]:
# Show a random training sample: input image, ground-truth labels and the
# model's prediction side by side.
# NOTE: `replace`, `tag_convert_dict` and `generate_semantic_im` are defined
# in cells further down; those cells must be run before this one.

# Draw the index from the actual dataset size; a hard-coded upper bound can
# point past the end of the dataset.
num = random.randrange(len(train_data))
print(num)
data = train_data[num]
imgs = []
org = data[0]
sem = replace(data[1].numpy().transpose(1,2,0))
imgs.append(Image.fromarray(org.numpy().transpose(1,2,0)))
imgs.append(Image.fromarray(sem))
# generate_semantic_im is defined with the image as its only required
# argument and uses the global `model` itself.
imgs.append(generate_semantic_im(org))
fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
    axs[0, i].imshow(np.asarray(img))
    axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    #img.save('./AE_results/AE/'+str(num)+'_'+str(i)+'.png')



In [None]:
# Lookup table: semantic tag id -> [R, G, B] colour used when rendering label maps.
# NOTE(review): the colours look like CARLA's semantic-segmentation palette -
# verify against the CARLA documentation. Tags 0 and 13 share the same colour.
tag_convert_dict = {
    0: [70, 130, 180],
    1: [70, 70, 70],
    2: [100, 40, 40],
    3: [55, 90, 80],
    4: [220, 20, 60],
    5: [153, 153, 153],
    6: [157, 234, 50],
    7: [128, 64, 128],
    8: [244, 35, 232],
    9: [107, 142, 35],
    10: [0, 0, 142],
    11: [102, 102, 156],
    12: [220, 220, 0],
    13: [70, 130, 180],
    14: [81, 0, 81],
    15: [150, 100, 100],
    16: [230, 150, 140],
    17: [180, 165, 180],
    18: [250, 170, 30],
    19: [110, 190, 160],
    20: [170, 120, 50],
    21: [45, 60, 150],
    22: [145, 170, 100],
}

In [None]:
def generate_semantic_im(RGB_image, net=None):
    """Render the network's predicted segmentation of one image as an RGB picture.

    Args:
        RGB_image: tensor with 3*128*128 elements; it is reshaped to
            ``(1, 3, 128, 128)`` before being passed to the network.
        net: network to run; defaults to the notebook-global ``model``.

    Returns:
        A ``PIL.Image`` in RGB mode in which every pixel carries the colour of
        its highest-scoring class (colours come from ``replace``).
    """
    if net is None:
        net = model
    new_obs = RGB_image.reshape(1,3,128,128)
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # Only the first output (the per-class scores) is used. Indexing, rather
        # than unpacking a fixed number of values, works no matter how many
        # additional outputs the network returns.
        out = net(new_obs)[0]
    sample = out.cpu().argmax(dim=1)
    pic = replace(sample.numpy())
    return Image.fromarray(pic,'RGB')

In [None]:
def replace(a):
    """Convert a 128x128 map of tag ids into a uint8 RGB array.

    Args:
        a: array-like with exactly 128*128 elements (any shape that can be
            reshaped to ``(128, 128)``) holding keys of ``tag_convert_dict``.

    Returns:
        ``numpy.ndarray`` of shape ``(128, 128, 3)`` and dtype ``uint8``.

    Raises:
        KeyError: if ``a`` contains a tag that is not in ``tag_convert_dict``.
    """
    tags = np.asarray(a).reshape(128,128)
    pic = np.zeros((128,128,3),dtype='uint8')
    # One masked assignment per distinct tag instead of a Python-level loop
    # over all 16,384 pixels.
    for value in np.unique(tags):
        pic[tags == value] = tag_convert_dict[value][:3]
    return pic

In [None]:
def show(imgs):
    """Display the given image tensors side by side in one row, without axis ticks."""
    _, axes = plt.subplots(ncols=len(imgs), squeeze=False)
    to_pil = T.ToPILImage()
    for col, tensor in enumerate(imgs):
        ax = axes[0, col]
        ax.imshow(np.asarray(to_pil(tensor.to('cpu'))))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
    # Show a random Town03 sample after the crop/resize step: input image,
    # ground-truth labels and the model's prediction side by side.
    # NOTE(review): assumes the label returned by the dataset is a
    # single-channel, channel-first tensor (the dataset already moves the
    # label's last axis to the front) - confirm.

    # Index drawn from the real dataset size instead of a hard-coded bound.
    num = random.randrange(len(town03_data))
    data = town03_data[num]
    imgs = []
    # The same crop window is applied to image and label so they stay aligned.
    org = TF.resized_crop(data[0],64,14,50,100,(128,128))
    # Nearest-neighbour resizing keeps every label value an existing tag;
    # the default bilinear mode would blend neighbouring tag ids.
    sem = TF.resized_crop(data[1],64,14,50,100,(128,128),
                          interpolation=T.InterpolationMode.NEAREST)
    # replace() needs a NumPy array and already returns a (128, 128, 3) image.
    sem = replace(sem.numpy())

    imgs.append(Image.fromarray(org.numpy().transpose(1,2,0)))
    imgs.append(Image.fromarray(sem))
    imgs.append(generate_semantic_im(org))
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])