In [2]:
import os
if not os.path.exists('GTSRB'):
    !pip install -U -q torch_snippets
    !wget -qq https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip
    !wget -qq https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip
    !unzip -qq GTSRB_Final_Training_Images.zip
    !unzip -qq GTSRB_Final_Test_Images.zip
    !wget https://raw.githubusercontent.com/georgesung/traffic_sign_classification_german/master/signnames.csv
    !rm GTSRB_Final_Training_Images.zip GTSRB_Final_Test_Images.zip

from torch_snippets import *

--2020-12-14 11:48:10--  https://raw.githubusercontent.com/georgesung/traffic_sign_classification_german/master/signnames.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.132.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.132.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 999 [text/plain]
Saving to: ‘signnames.csv’


2020-12-14 11:48:11 (102 MB/s) - ‘signnames.csv’ saved [999/999]



In [3]:
classIds = pd.read_csv('signnames.csv')
classIds.set_index('ClassId', inplace=True)
classIds = classIds.to_dict()['SignName']
classIds = {f'{k:05d}':v for k,v in classIds.items()}
id2int = {v:ix for ix,(k,v) in enumerate(classIds.items())}

In [4]:
from torchvision import transforms as T
trn_tfms = T.Compose([
                T.ToPILImage(),
                T.Resize(32),
                T.CenterCrop(32),
                T.ColorJitter(brightness=(0.8,1.2), 
                contrast=(0.8,1.2), 
                saturation=(0.8,1.2), 
                hue=0.25),
                T.RandomAffine(5, translate=(0.01,0.1)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225]),
            ])

In [5]:
val_tfms = T.Compose([
                T.ToPILImage(),
                T.Resize(32),
                T.CenterCrop(32),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225]),
            ])

In [7]:
class GTSRB(Dataset):

    def __init__(self, files, transform=None):
        self.files = files
        self.transform = transform
        logger.info(len(self))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, ix):
        fpath = self.files[ix]
        clss = fname(parent(fpath))
        img = read(fpath, 1)
        return img, classIds[clss]

    def choose(self):
        return self[randint(len(self))]
    
    def collate_fn(self, batch):
        imgs, classes = list(zip(*batch))
        if self.transform:
            imgs =[self.transform(img)[None] \
                   for img in imgs]
        classes = [torch.tensor([id2int[clss]]) \
                   for clss in classes]
        imgs, classes = [torch.cat(i).to(device) \
                         for i in [imgs, classes]]
        return imgs, classes

In [8]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
all_files = Glob('GTSRB/Final_Training/Images/*/*.ppm')
np.random.seed(10)
np.random.shuffle(all_files)

from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, \
                                        random_state=1)

trn_ds = GTSRB(trn_files, transform=trn_tfms)
val_ds = GTSRB(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, \
                    collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, \
                    collate_fn=val_ds.collate_fn)

2020-12-14 11:50:44.934 | INFO     | torch_snippets.loader:Glob:190 - 39209 files found at GTSRB/Final_Training/Images/*/*.ppm
2020-12-14 11:50:45.103 | INFO     | __main__:__init__:6 - 29406
2020-12-14 11:50:45.104 | INFO     | __main__:__init__:6 - 9803


In [9]:
import torchvision.models as models

def convBlock(ni, no):
    return nn.Sequential(
                nn.Dropout(0.2),
                nn.Conv2d(ni, no, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(no),
                nn.MaxPool2d(2),
            )
    
class SignClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
                        convBlock(3, 64),
                        convBlock(64, 64),
                        convBlock(64, 128),
                        convBlock(128, 64),
                        nn.Flatten(),
                        nn.Linear(256, 256),
                        nn.Dropout(0.2),
                        nn.ReLU(inplace=True),
                        nn.Linear(256, len(id2int))
                    )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def compute_metrics(self, preds, targets):
        ce_loss = self.loss_fn(preds, targets)
        acc =(torch.max(preds, 1)[1]==targets).float().mean()
        return ce_loss, acc

In [10]:
def train_batch(model, data, optimizer, criterion):
    model.train()
    ims, labels = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, labels = data
    _preds = model(ims)
    loss, acc = criterion(_preds, labels)
    return loss.item(), acc.item()

In [11]:
model = SignClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 50

log = Report(n_epochs)
for ex in range(n_epochs):
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, \
                                    criterion)
        log.record(ex+(bx+1)/N,trn_loss=loss, trn_acc=acc, \
                                     end='\r')

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        log.record(ex+(bx+1)/N, val_loss=loss, val_acc=acc, \
                                    end='\r')
        
    log.report_avgs(ex+1)
    if ex == 10: optimizer = optim.Adam(model.parameters(), \
                                    lr=1e-4)

EPOCH: 1.000	trn_loss: 1.877	trn_acc: 0.449	val_loss: 0.809	val_acc: 0.761	(24.65s - 1207.95s remaining)
EPOCH: 2.000	trn_loss: 0.690	trn_acc: 0.778	val_loss: 0.832	val_acc: 0.761	(49.04s - 1176.99s remaining)
EPOCH: 3.000	trn_loss: 0.479	trn_acc: 0.843	val_loss: 0.675	val_acc: 0.804	(73.33s - 1148.91s remaining)
EPOCH: 4.000	trn_loss: 0.391	trn_acc: 0.872	val_loss: 0.582	val_acc: 0.844	(97.38s - 1119.88s remaining)
EPOCH: 5.000	trn_loss: 0.326	trn_acc: 0.893	val_loss: 0.423	val_acc: 0.881	(121.32s - 1091.92s remaining)
EPOCH: 6.000	trn_loss: 0.290	trn_acc: 0.906	val_loss: 0.523	val_acc: 0.867	(145.23s - 1065.00s remaining)
EPOCH: 7.000	trn_loss: 0.272	trn_acc: 0.912	val_loss: 0.330	val_acc: 0.893	(169.01s - 1038.20s remaining)
EPOCH: 8.000	trn_loss: 0.239	trn_acc: 0.921	val_loss: 0.357	val_acc: 0.895	(192.86s - 1012.51s remaining)
EPOCH: 9.000	trn_loss: 0.222	trn_acc: 0.927	val_loss: 0.525	val_acc: 0.865	(216.60s - 986.74s remaining)
EPOCH: 10.000	trn_loss: 0.205	trn_acc: 0.933	val_lo