In [30]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchvision.io import read_image

import torch.nn as nn             # for torch.nn.Module, the parent object for Pytorch models
import torch.nn.functional as F   # for activation function

import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np

In [42]:
class TrafficSignDataset(Dataset):
    """Map-style dataset of images listed in an annotations CSV.

    The CSV is read with pandas defaults. For every row, column 0 is joined
    onto ``directory`` to form the image path and column 1 is returned as the
    label.
    """

    def __init__(self, annotations, directory, transform=None):
        """Load the annotations table and remember where the images live.

        Args:
            annotations: Path of the annotations CSV; it is joined onto
                ``directory`` before being read.
            directory: Base folder; image paths taken from the CSV are joined
                onto it as well.
            transform: Optional callable applied to each loaded image.
        """
        # directory containing the images
        self.directory = directory

        # loading the csv with info about images
        annotations_path = os.path.join(self.directory, annotations)
        self.labels = pd.read_csv(annotations_path)

        # transform to be applied on images
        self.transform = transform

        # Number of images at construction time. Kept for callers that read
        # the attribute; __len__ always reflects the current table.
        self.len = len(self.labels)

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.labels)

    @staticmethod
    def _as_int_index(idx):
        """Unwrap a single-element tensor index into a plain Python number."""
        if torch.is_tensor(idx):
            # .item() raises for tensors with more than one element, which is
            # the right outcome here: this dataset serves one sample per index.
            return idx.item()
        return idx

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        Args:
            idx: Positional row index. A single-element tensor is accepted
                and converted to a Python number first.

        Returns:
            A tuple of the image loaded with ``read_image`` (after
            ``transform``, if one was given) and the value in column 1 of
            the annotations row.
        """
        row = self._as_int_index(idx)

        # defining the image path (column 0) and reading the image
        image_path = os.path.join(self.directory, self.labels.iloc[row, 0])
        image = read_image(image_path)

        # corresponding class label of the image (column 1)
        label = self.labels.iloc[row, 1]

        # Compare against None rather than relying on truthiness, so that a
        # callable which happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Build the dataset: images and the annotations CSV both live under `directory`.
directory = "data"
annotations = "annotations.csv"
custom_dataset = TrafficSignDataset(annotations=annotations, directory=directory)

In [None]:
# TODO: create a DataLoader over custom_dataset (not implemented yet)


In [None]:
# Human-readable names for the class ids used in the annotations.
# GTSRB labels reference: https://github.com/magnusja/GTSRB-caffe-model/blob/master/labeller/main.py
# NOTE(review): keys are strings; if labels read from the CSV are integers,
# convert with str(label) before looking them up here - confirm against the data.
label_map = {
    "0": "stop",
    "1": "turn_straight_left",
    "2": "20_speed",
}

In [None]:
# Need datasets for:
# traffic signs
# traffic lights
# vehicles

# Open question: use an existing model, such as ResNet50?
