In [1]:
import os
import time

import numpy as np

from pylibdl import cross_entropy_with_logits
from pylibdl.data import DataLoader
from pylibdl.optim import Adam

from utils import DistractedDriver, MyResNet, validate
# Seed NumPy's global RNG up front so repeated runs start from the same state.
# NOTE(review): presumably pylibdl's shuffling / weight init draw from this RNG — confirm.
np.random.seed(0)

In [2]:
# Dataset location passed to DistractedDriver. Defaults to the original
# author's local directory; set the DDRIVER_PATH environment variable to
# point at your own copy instead of editing this cell.
path = os.environ.get('DDRIVER_PATH', '/home/superbabes/Downloads/ddriver')
batch_size = 64  # batch size handed to DataLoader
lr = 1e-4  # initial learning rate handed to Adam
log_every = 10  # print the loss once every `log_every` iterations

In [3]:
# Dataset built with val=False (presumably the training split), served in
# shuffled batches of `batch_size`.
# NOTE(review): drop_last=True presumably discards a final short batch — confirm against pylibdl.
train_data = DistractedDriver(path, val=False)
train_loader = DataLoader(train_data, batch_size, shuffle=True, drop_last=True)
# Dataset built with val=True; it is passed to validate() directly, not through a DataLoader.
val_data = DistractedDriver(path, val=True)

In [4]:
# Build the network and an Adam optimizer fed by model.parameter(),
# starting at learning rate `lr`.
model = MyResNet()
optim = Adam(model.parameter(), lr)

In [None]:
# Short warm-up run: train on the first batches of a single pass over
# train_loader, checkpoint the weights, then report validation accuracy.
max_iter = 100  # index of the last batch trained on (inclusive, i.e. up to 101 steps)

model.train()  # set training mode explicitly, as the long-run cell below does
begin = time.perf_counter()
print('iteration | loss')
for i, (imgs, labels) in enumerate(train_loader):
    pred = model(imgs)
    loss = cross_entropy_with_logits(pred, labels)
    loss.backward()
    optim.step()
    # Reset gradients after the update (presumably they would accumulate otherwise).
    optim.zero_grad()
    if i % log_every == 0:
        print(f'{i: <9} | {loss.numpy():.4f}')
    if i >= max_iter:
        break
# Stop the clock here so the reported time covers the training loop only,
# not the checkpoint write and the validation pass.
train_seconds = time.perf_counter() - begin

model.save('model')
model.eval()  # switch to eval mode before validating, as the long-run cell below does
print(f'validation accuracy: {validate(model, val_data)*100:.2f}%')
print(f'training took {train_seconds: .2f} seconds')

iteration | loss
0         | 2.6797
10        | 1.9880
20        | 1.6636
30        | 1.1510
40        | 0.9328
50        | 0.7677
60        | 0.7184


In [None]:
# Long training run: repeat full passes over train_loader, decay the learning
# rate after every step, validate after every epoch, and checkpoint whenever
# validation accuracy improves.
max_epochs = 100000  # upper bound on the number of epochs
lr_decay = .998  # multiplicative learning-rate decay applied after every step

best = 0  # best validation accuracy seen so far in this run of the cell
for epoch in range(max_epochs):
    model.train()
    for step, (imgs, labels) in enumerate(train_loader):
        pred = model(imgs)
        loss = cross_entropy_with_logits(pred, labels)
        loss.backward()
        optim.step()
        optim.zero_grad()
        if step % log_every == 0:
            print(f'{epoch}, {step} | {loss.numpy():.4f}')
        # The decay compounds in place on the shared optimizer, so re-running
        # this cell continues from the already-decayed rate, not from `lr`.
        optim.lr *= lr_decay
    model.eval()
    acc = validate(model, val_data)
    if acc > best:
        # Only overwrite the checkpoint when this epoch beats the best so far.
        best = acc
        model.save('model_long')
    print(f'validation accuracy {epoch}: {acc*100:.2f}%')

0, 0 | 0.3795
0, 10 | 0.3673
0, 20 | 0.4727
0, 30 | 0.4124
0, 40 | 0.4258
0, 50 | 0.5351
0, 60 | 0.4523
0, 70 | 0.3872
0, 80 | 0.5308
0, 90 | 0.4304
0, 100 | 0.4114
0, 110 | 0.4302
0, 120 | 0.3551
0, 130 | 0.4129
0, 140 | 0.3442
0, 150 | 0.4351
0, 160 | 0.5221
0, 170 | 0.3605
0, 180 | 0.3901
0, 190 | 0.3318
0, 200 | 0.4651
0, 210 | 0.3849
0, 220 | 0.4958
0, 230 | 0.4724
0, 240 | 0.5280
0, 250 | 0.4773
0, 260 | 0.4235
0, 270 | 0.4167
0, 280 | 0.4425
0, 290 | 0.4840
0, 300 | 0.4400
0, 310 | 0.3507
0, 320 | 0.3575
validation accuracy 0: 94.84%
1, 0 | 0.3161
1, 10 | 0.2648
1, 20 | 0.2924
1, 30 | 0.3590
1, 40 | 0.3625
1, 50 | 0.3179
1, 60 | 0.4408
1, 70 | 0.3792
1, 80 | 0.3594
1, 90 | 0.4294
1, 100 | 0.4517
1, 110 | 0.4050
1, 120 | 0.3854
1, 130 | 0.4545
1, 140 | 0.3652
1, 150 | 0.3309
1, 160 | 0.4118
1, 170 | 0.3743
1, 180 | 0.3776
1, 190 | 0.3420
1, 200 | 0.3774
1, 210 | 0.3320
1, 220 | 0.3402
1, 230 | 0.3564
1, 240 | 0.3815
1, 250 | 0.4011
1, 260 | 0.3843
1, 270 | 0.3133
1, 280 | 0.3512
