<a href="https://colab.research.google.com/github/tomek-l/fire-detect-nn/blob/master/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from model import Model, load_dataset, accuracy
import numpy as np
import torch
import torchvision

BACKBONES = ['resnet18','resnet34','resnet50','resnet101', 'densenet121', 'mobilenet']
BACKBONES = ['resnet50'] # override with just one backbone

dataset_paths = {'mine': '/home/013855803/fire_aerial2k_dataset/',
                 'dunnings': '/home/013855803/fire-dataset-dunnings/images-224x224/train',
                 'dunnings_test': '/home/013855803/fire-dataset-dunnings/images-224x224/test'}

train, valid = load_dataset(dataset_paths['dunnings'], batch_size=64)

tr = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),
                            torchvision.transforms.ToTensor()])

test_dataset = torchvision.datasets.ImageFolder(root=dataset_paths['dunnings_test'],
                                                transform=tr)


test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    num_workers=0,
    shuffle=False
)

print(f'loaded {len(train)} training batches and {len(valid)} validation batches')
print(f'loaded {len(test)} test batches')

# Can be useful if we're retraining many times on the entire dataset
# completely memory extravagant but I have 256GB of RAM to use :)
# train, valid = list(train), list(valid)

loaded 330 training batches and 37 validation batches
loaded 184 test batches


In [2]:
import torch
device = torch.device("cuda:0")
is_validating = True
is_testing = True

history = {
    'train_samples': [],
    'train_acc': [],
    'valid_acc': [],
    'test_acc': [],
    'loss': []
}


for b in BACKBONES:

    import torch.optim as optim

    m = Model(backbone=b)
    m = m.to(device)

    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(m.parameters(), lr=1e-4, weight_decay=1e-3)

    for epoch in range(1): # epochs

        running_loss = []
        running_acc = []

        # epoch training
        for i, data in enumerate(train):

            # get the inputs; data is a list of [inputs, labels]
            inputs = data[0].to(device)
            labels = data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = m(inputs)
            loss = criterion(outputs[:,0], labels.type_as(outputs[:,0]))
            loss.backward()
            optimizer.step()

            acc = accuracy(outputs, labels)
            # print statistics

            running_loss.append(loss.item())
            running_acc.append(acc)


            if i%20 == 19:
                print(f'epoch: {epoch+1}, batch: {i}, \
                loss: {np.mean(running_loss)}, training accuracy: {np.mean(running_acc)}')

                history['loss'].append(np.mean(running_loss))
                history['train_samples'].append(epoch*len(train)+i)
                history['train_acc'].append(np.mean(running_acc))

        #########################################
        # on epoch end:
        if is_validating:
            valid_acc = []
            # epoch validation
            for i, data in enumerate(valid):
                # get the inputs; data is a list of [inputs, labels]
                inputs = data[0].to(device)
                labels = data[1].to(device)

                # could pehaps do:
                # for param in m.parameters():
                #     param.requires_grad = False

                outputs = m(inputs)
                valid_acc.append(accuracy(outputs, labels))
            va = round(np.mean(valid_acc), 4)
            print(f'validation accuracy {va}')
            history['valid_acc'].append(va)
        else:
            va='-1'

        if is_testing:
            test_acc = []
            # epoch validation
            for i, data in enumerate(test):
                # get the inputs; data is a list of [inputs, labels]
                inputs = data[0].to(device)
                labels = data[1].to(device)

                # could pehaps do:
                # for param in m.parameters():
                #     param.requires_grad = False

                outputs = m(inputs)
                test_acc.append(accuracy(outputs, labels))
            tst = round(np.mean(test_acc), 4)
            print(f'test_accuracy {tst}')
            history['test_acc'].append(tst)
        else:
            tst = '-1'

        fname =  f'weights/{b}-epoch-{epoch}-valid_acc={va}-test_acc={tst}.pt'
        print(f'Saved {fname}')
        torch.save(m, fname)


    print(f'Finished Training: {b}')

epoch: 1, batch: 19,                 loss: 0.2499097514897585, training accuracy: 0.89375
epoch: 1, batch: 39,                 loss: 0.17549494206905364, training accuracy: 0.926953125
epoch: 1, batch: 59,                 loss: 0.14291056434934338, training accuracy: 0.9432291666666667
epoch: 1, batch: 79,                 loss: 0.1258889297954738, training accuracy: 0.9509765625
epoch: 1, batch: 99,                 loss: 0.10898326711729169, training accuracy: 0.95859375
epoch: 1, batch: 119,                 loss: 0.09728525512618943, training accuracy: 0.963671875
epoch: 1, batch: 139,                 loss: 0.0914144716258826, training accuracy: 0.9658482142857143
epoch: 1, batch: 159,                 loss: 0.08747305851138662, training accuracy: 0.967578125
epoch: 1, batch: 179,                 loss: 0.08304406131048583, training accuracy: 0.9693576388888889
epoch: 1, batch: 199,                 loss: 0.07870092785684392, training accuracy: 0.971171875
epoch: 1, batch: 219,          

In [3]:
import matplotlib.pyplot as plt

for history in histories:
    plt.figure()
    plt.plot(history['train_samples'], history['train_acc'])
    plt.plot(history['train_samples'], history['loss'])
    plt.scatter([len(train)], history['valid_acc'])
    plt.scatter([len(train)], history['test_acc'])

NameError: name 'histories' is not defined

In [None]:
for k,v in histories.items():
    print(v['train_acc'])

In [None]:
history.keys()

