In [1]:
!nvidia-smi

Sat Dec 31 15:58:07 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.203.03   Driver Version: 450.203.03   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# 라이브러리 및 파일 불러오기

In [2]:
# load libraries

import os
import numpy as np
import pandas as pd
from statistics import mean

import torch
import torchvision

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from scipy import stats

from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

PROJECT_PATH = os.getenv('HOME') + '/aiffel/project/AIFFELTHON'
MODEL_PATH = os.path.join(PROJECT_PATH, 'weights/om_weights')
DATA_PATH = os.path.join('data')
TRAIN_PATH = os.path.join(DATA_PATH, 'train')
TEST_PATH = os.path.join(DATA_PATH, 'test')
REJECT_PATH = os.path.join(DATA_PATH, 'reject')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) # connected to GPU if 'cuda' is printed here

cuda


In [None]:
# checking imgs in a folder

for dirpath, dirnames, filenames in os.walk(TRAIN_PATH):
    for i, filename in enumerate(filenames):
        print(os.path.join(dirpath, filename)) # prints file names
        image = Image.open(os.path.join(dirpath, filename), 'r')
        print(f'size: ({image.width}, {image.height}, {image.getbands()})') # prints img info
        plt.imshow(image)
        plt.show()
        if i==4:
            break # print 4 per folder

# Create Functions

In [5]:
# Normalize imgs, resize to 224x224
# Create pipeline
# PyTorch offers various augmentation techniques in torchvision.transforms.Compose

def create_dataloader(path, batch_size, istrain):
    nearest_mode = torchvision.transforms.InterpolationMode.NEAREST
    normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
    )
    train_transformer = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224), interpolation=nearest_mode),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ColorJitter(),
        torchvision.transforms.ToTensor(),
        normalize
    ])

    test_transformer = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224), interpolation=nearest_mode),
        torchvision.transforms.ToTensor(),
        normalize
    ])
    
    if istrain:
        data = torchvision.datasets.ImageFolder(path, transform=train_transformer)
        dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
        
    else:
        data = torchvision.datasets.ImageFolder(path, transform=test_transformer)
        dataloader = torch.utils.data.DataLoader(data, shuffle=False)

    return dataloader, data

In [6]:
# creating train dataset

BATCH_SIZE = 64 # changed from 64 to 1

train_loader, _train_data = create_dataloader(TRAIN_PATH, BATCH_SIZE, True)
target_class_num = len(os.listdir(os.path.join(TRAIN_PATH)))

print('target_class_num: ', target_class_num)
print('train: ', _train_data.class_to_idx)

target_class_num:  3
train:  {'07_inner_cupholder_resized': 0, 'resized_11_inner_front_seat': 1, 'resized_data12_inner_rear_seat': 2}


In [7]:
# checking num of imgs in each class

for rootpath, dirpath, filenames in os.walk(TRAIN_PATH):
    print(f'{rootpath} : {len(filenames)}')

data/train : 0
data/train/resized_11_inner_front_seat : 1671
data/train/resized_data12_inner_rear_seat : 1957
data/train/07_inner_cupholder_resized : 2000


In [8]:
# creating test dataset

BATCH_SIZE = 64 # changed from 64 to 1

test_loader, _test_data = create_dataloader(TEST_PATH, BATCH_SIZE, False)
target_class_num = len(os.listdir(os.path.join(TEST_PATH)))

print('target_class_num: ', target_class_num)
print('test: ', _test_data.class_to_idx)

target_class_num:  3
test:  {'07_inner_cupholder_resized': 0, 'resized_11_inner_front_seat': 1, 'resized_data12_inner_rear_seat': 2}


In [9]:
# checking num of imgs in each class

for rootpath, dirpath, filenames in os.walk(TEST_PATH):
    print(f'{rootpath} : {len(filenames)}')

data/test : 0
data/test/resized_11_inner_front_seat : 1671
data/test/resized_data12_inner_rear_seat : 1957
data/test/07_inner_cupholder_resized : 2000


In [10]:
# metrics from sklearn.metrics

def calculate_metrics(trues, preds):
    accuracy = accuracy_score(trues, preds)
    f1 = f1_score(trues, preds, average='macro')
    precision = precision_score(trues, preds, average='macro')
    recall = recall_score(trues, preds, average='macro')
    return accuracy, f1, precision, recall

In [11]:
# train function

def train(dataloader, net, learning_rate, weight_decay_level, device):
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        net.parameters(),
        lr = learning_rate, 
        weight_decay = weight_decay_level
    )

    net.train()

    train_losses = list()
    train_preds = list()
    train_trues = list()

    for idx, (img, label) in enumerate(dataloader):

        img = img.to(device)
        label = label.to(device)
        
        optimizer.zero_grad()

        out = net(img)

        _, pred = torch.max(out, 1)
        loss = criterion(out, label)

        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        train_trues.extend(label.view(-1).cpu().numpy().tolist())
        train_preds.extend(pred.view(-1).cpu().detach().numpy().tolist())

    acc, f1, prec, rec = calculate_metrics(train_trues, train_preds)

    print('\n''====== Training Metrics ======')
    print('Loss: ', mean(train_losses))
    print('Acc: ', acc)
    print('F1: ', f1)
    print('Precision: ', prec)
    print('Recall: ', rec)
    print(confusion_matrix(train_trues, train_preds))

    return net, acc, f1, prec, rec

In [12]:
# test function

def test(dataloader, net, device):

    criterion = torch.nn.CrossEntropyLoss()
    
    net.eval()
    test_losses = list()
    test_trues = list()
    test_preds = list()
    
    with torch.no_grad():
        for idx, (img, label) in enumerate(dataloader):

            img = img.to(device)
            label = label.to(device)

            out = net(img)

            _, pred = torch.max(out, 1)
            loss = criterion(out, label)

            test_losses.append(loss.item())
            test_trues.extend(label.view(-1).cpu().numpy().tolist())
            test_preds.extend(pred.view(-1).cpu().detach().numpy().tolist())

    acc, f1, prec, rec = calculate_metrics(test_trues, test_preds)

    print('====== Test Metrics ======')
    print('Test Loss: ', mean(test_losses))
    print('Test Acc: ', acc)
    print('Test F1: ', f1)
    print('Test Precision: ', prec)
    print('Test Recall: ', rec)
    print(confusion_matrix(test_trues, test_preds))

    return net, acc, f1, prec, rec

In [13]:
# code to save best params based on acc

def train_classifier(net, train_loader, test_loader, n_epochs, learning_rate, weight_decay, device):
    best_test_acc = 0
    
    model_save_path = None
    model_save_base = 'weights/om_weights'
    if not os.path.exists(model_save_base):
        os.makedirs(model_save_base)
    
    print('>> Start Training Model!')
    for epoch in range(n_epochs):
        
        print('> epoch: ', epoch)

        net, _, _, _, _ = train(train_loader, net, learning_rate, weight_decay, device)
        net, test_acc, _, _, _  = test(test_loader, net, device)

        if test_acc > best_test_acc:

            best_test_acc = test_acc
            test_acc_str = '%.5f' % test_acc

            print('[Notification] Best Model Updated!')
            model_save_path = os.path.join(model_save_base, 'om_300_classifier_acc_' + str(test_acc_str) + '.pth') 
            torch.save(net.state_dict(), model_save_path)
                
    return model_save_path

In [14]:
target_class_num

3

In [15]:
# load pre-trained resnet50

net = torchvision.models.resnet50(pretrained=True)
net.fc = torch.nn.Linear(
    net.fc.in_features,
    target_class_num
)

net.to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /aiffel/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [16]:
# training 300 epochs

EPOCHS = 300
LEARNING_RATE = 0.005
WEIGHT_DECAY = 0.0005

saved_weight_path = train_classifier(net, train_loader, test_loader, EPOCHS, LEARNING_RATE, WEIGHT_DECAY, device)

>> Start Training Model!
> epoch:  0

Loss:  0.5427066770128228
Acc:  0.7937100213219617
F1:  0.7888276157394847
Precision:  0.788829221898224
Recall:  0.7888716073008339
[[1760  146   94]
 [ 151 1182  338]
 [ 113  319 1525]]
Test Loss:  0.30675318112983146
Test Acc:  0.8987206823027718
Test F1:  0.8961094657618087
Test Precision:  0.8975098542774579
Test Recall:  0.8956445053285575
[[1899   83   18]
 [ 153 1409  109]
 [  73  134 1750]]
[Notification] Best Model Updated!
> epoch:  1

Loss:  0.2972937584431334
Acc:  0.9134683724235964
F1:  0.9122387030677569
Precision:  0.9127151417158218
Recall:  0.9122537337006563
[[1829  101   70]
 [  47 1485  139]
 [  45   85 1827]]
Test Loss:  1.467846912093647
Test Acc:  0.4564676616915423
Test F1:  0.34839237088733704
Test Precision:  0.7610401702580473
Test Recall:  0.43615003831122373
[[ 588    2 1410]
 [   0   25 1646]
 [   0    1 1956]]
> epoch:  2

Loss:  0.24509322778745132
Acc:  0.9193319118692252
F1:  0.9179266364339987
Precision:  0.9184


Loss:  0.08429585794113915
Acc:  0.974591329068941
F1:  0.974171517514697
Precision:  0.9742730370188406
Recall:  0.9741024280661797
[[1951   30   19]
 [  21 1612   38]
 [  13   22 1922]]
Test Loss:  0.5178840602324926
Test Acc:  0.8050817341862118
Test F1:  0.7918203657517474
Test Precision:  0.873976280172494
Test Recall:  0.789759083205332
[[1740   16  244]
 [   7  836  828]
 [   2    0 1955]]
> epoch:  18

Loss:  0.08532437799625438
Acc:  0.9721037668798863
F1:  0.9716340830553972
Precision:  0.9716007934040146
Recall:  0.9716996593731108
[[1948   32   20]
 [  17 1611   43]
 [  14   31 1912]]
Test Loss:  0.09794055414003003
Test Acc:  0.9685501066098081
Test F1:  0.9679029829601006
Test Precision:  0.9675180780787119
Test Recall:  0.9693898978139718
[[1951   46    3]
 [   9 1649   13]
 [  14   92 1851]]
[Notification] Best Model Updated!
> epoch:  19

Loss:  0.08610598427582193
Acc:  0.9731698649609097
F1:  0.9726261974531836
Precision:  0.9728057893423582
Recall:  0.9725102393460


Loss:  0.07407690603709356
Acc:  0.9793887704335466
F1:  0.9789499676992491
Precision:  0.979061471582646
Recall:  0.9788577653135063
[[1964   26   10]
 [  17 1619   35]
 [  10   18 1929]]
Test Loss:  1.756597816756339
Test Acc:  0.6457000710732054
Test F1:  0.5694930172773108
Test Precision:  0.7884464382672572
Test Recall:  0.6605573777977973
[[1812  185    3]
 [  17 1654    0]
 [ 176 1613  168]]
> epoch:  35

Loss:  0.07806107562183487
Acc:  0.9769012082444918
F1:  0.9764004404192953
Precision:  0.9766042644516414
Recall:  0.9762364239793094
[[1957   31   12]
 [  23 1610   38]
 [  10   16 1931]]
Test Loss:  0.08381054565183395
Test Acc:  0.972636815920398
Test F1:  0.9720301057986868
Test Precision:  0.97344184619021
Test Recall:  0.9711463592717188
[[1960   21   19]
 [  16 1576   79]
 [  12    7 1938]]
> epoch:  36

Loss:  0.07716545800212771
Acc:  0.9770788912579957
F1:  0.9766459814430742
Precision:  0.9768083325798299
Recall:  0.9765160390343309
[[1959   26   15]
 [  19 1614   

Test Loss:  0.07061413043429048
Test Acc:  0.9790334044065387
Test F1:  0.9786313880495513
Test Precision:  0.9787637325851507
Test Recall:  0.9787252699955079
[[1939   42   19]
 [   6 1623   42]
 [   3    6 1948]]
> epoch:  52

Loss:  0.07181370988572863
Acc:  0.9763681592039801
F1:  0.9759499029741031
Precision:  0.9759624389549241
Recall:  0.9759549966204374
[[1958   27   15]
 [  15 1618   38]
 [  13   25 1919]]
Test Loss:  0.4592199186555781
Test Acc:  0.830135039090263
Test F1:  0.8274613983708905
Test Precision:  0.8738688541955574
Test Recall:  0.8397638630618135
[[1195  775   30]
 [   0 1659   12]
 [   1  138 1818]]
> epoch:  53

Loss:  0.07358979716346684
Acc:  0.9776119402985075
F1:  0.9771838187175018
Precision:  0.9772871531505863
Recall:  0.977107158944231
[[1961   24   15]
 [  16 1617   38]
 [  10   23 1924]]
Test Loss:  0.22332015613081774
Test Acc:  0.9244847192608386
Test F1:  0.9215560842870788
Test Precision:  0.9335223715381646
Test Recall:  0.9184832968772759
[[192


Loss:  0.06525147839618678
Acc:  0.9800995024875622
F1:  0.9797450225878466
Precision:  0.9799220369978014
Recall:  0.9795937234116184
[[1965   23   12]
 [  16 1621   34]
 [  12   15 1930]]
Test Loss:  0.06344972929020075
Test Acc:  0.9811656005685856
Test F1:  0.9808126096044719
Test Precision:  0.9805864632460132
Test Recall:  0.9811515730026814
[[1954   29   17]
 [   3 1638   30]
 [   2   25 1930]]
> epoch:  70

Loss:  0.06513688339724798
Acc:  0.9799218194740583
F1:  0.9795179486949831
Precision:  0.9796297188422093
Recall:  0.9794233946771609
[[1965   24   11]
 [  17 1621   33]
 [  10   18 1929]]
Test Loss:  0.05986145823222992
Test Acc:  0.9815209665955935
Test F1:  0.9811013240337115
Test Precision:  0.9811803176949127
Test Recall:  0.9810802878790056
[[1963   27   10]
 [   7 1625   39]
 [   6   15 1936]]
> epoch:  71

Loss:  0.06306764205642552
Acc:  0.9795664534470505
F1:  0.9791345078810632
Precision:  0.9792605217632774
Recall:  0.9790280940479638
[[1964   26   10]
 [  18 1

Test Loss:  0.21852063194256766
Test Acc:  0.9214641080312722
Test F1:  0.9210628149921919
Test Precision:  0.9271417577863911
Test Recall:  0.9234892982894448
[[1650  157  193]
 [   1 1585   85]
 [   0    6 1951]]
> epoch:  87

Loss:  0.061207727214258
Acc:  0.9800995024875622
F1:  0.97971305559008
Precision:  0.9797984647816677
Recall:  0.9796447045041093
[[1967   23   10]
 [  11 1623   37]
 [  12   19 1926]]
Test Loss:  0.08919180115444345
Test Acc:  0.9697938877043355
Test F1:  0.9692535249820732
Test Precision:  0.9691520064471876
Test Recall:  0.9697143689565025
[[1981   11    8]
 [  26 1623   22]
 [  31   72 1854]]
> epoch:  88

Loss:  0.06527995601275259
Acc:  0.98045486851457
F1:  0.9799892525610883
Precision:  0.9802489534730999
Recall:  0.979770307471397
[[1970   22    8]
 [  15 1616   40]
 [   9   16 1932]]
Test Loss:  1.4931351308923906
Test Acc:  0.7199715707178393
Test F1:  0.6698429863600591
Test Precision:  0.8040359775147333
Test Recall:  0.7309397308948701
[[1947   5


Loss:  0.05857179044026204
Acc:  0.9827647476901208
F1:  0.9824475769341409
Precision:  0.98283423912516
Recall:  0.9821193578861541
[[1973   16   11]
 [  17 1621   33]
 [  10   10 1937]]
Test Loss:  0.3116647942221035
Test Acc:  0.8813077469793887
Test F1:  0.8802423688795838
Test Precision:  0.8919471563528392
Test Recall:  0.8866453653816379
[[1470  219  311]
 [   5 1619   47]
 [   7   79 1871]]
> epoch:  105

Loss:  0.0677637454725548
Acc:  0.9774342572850035
F1:  0.9769775640883557
Precision:  0.9770021952728666
Recall:  0.9769696448916007
[[1960   26   14]
 [  18 1618   35]
 [   9   25 1923]]
Test Loss:  0.09603014012651495
Test Acc:  0.9717484008528785
Test F1:  0.9717975475354351
Test Precision:  0.9726647894391199
Test Recall:  0.9713995725064755
[[1921   17   62]
 [  21 1609   41]
 [   6   12 1939]]
> epoch:  106

Loss:  0.058884105131834404
Acc:  0.984363894811656
F1:  0.9841057342329607
Precision:  0.9842839964600677
Recall:  0.9839475047044267
[[1972   18   10]
 [  14 163

Test Loss:  0.19845297229209732
Test Acc:  0.9331911869225302
Test F1:  0.931291785188637
Test Precision:  0.9441668228968547
Test Recall:  0.9273710229744004
[[1996    2    2]
 [ 246 1376   49]
 [  76    1 1880]]
> epoch:  122

Loss:  0.0590363126988946
Acc:  0.9832977967306326
F1:  0.9828331570323557
Precision:  0.9828980595877234
Recall:  0.9827834312952904
[[1971   25    4]
 [  14 1626   31]
 [   5   15 1937]]
Test Loss:  0.0542214878764678
Test Acc:  0.9818763326226013
Test F1:  0.9817640951049932
Test Precision:  0.9823361544249863
Test Recall:  0.981479538432574
[[1947   20   33]
 [  10 1625   36]
 [   2    1 1954]]
> epoch:  123

Loss:  0.05888042063981464
Acc:  0.9836531627576404
F1:  0.9833198052747568
Precision:  0.9832296205948342
Recall:  0.9834229390401511
[[1969   21   10]
 [  11 1636   24]
 [   6   20 1931]]
Test Loss:  0.16073104008266242
Test Acc:  0.9424307036247335
Test F1:  0.9409048980049235
Test Precision:  0.9482632882425821
Test Recall:  0.9380721999143565
[[19


Loss:  0.05956906446425075
Acc:  0.9820540156361052
F1:  0.981686536480027
Precision:  0.9817262270221289
Recall:  0.9816677457211149
[[1966   24   10]
 [  12 1628   31]
 [   8   16 1933]]
Test Loss:  0.6126995785148627
Test Acc:  0.829957356076759
Test F1:  0.8176234854215895
Test Precision:  0.8874953482555871
Test Recall:  0.8143949450386584
[[1994    2    4]
 [ 744  895   32]
 [ 175    0 1782]]
> epoch:  140

Loss:  0.059313463139749896
Acc:  0.9822316986496091
F1:  0.9817696155907097
Precision:  0.9818809270306573
Recall:  0.981677663114227
[[1970   24    6]
 [  15 1623   33]
 [   6   16 1935]]
Test Loss:  0.1460706897042596
Test Acc:  0.9461620469083155
Test F1:  0.9456293473886611
Test Precision:  0.9474501501973193
Test Recall:  0.9486342488579259
[[1814  171   15]
 [   0 1656   15]
 [   2  100 1855]]
> epoch:  141

Loss:  0.06347484666515481
Acc:  0.9811656005685856
F1:  0.9807761247811927
Precision:  0.9809654306020574
Recall:  0.9806156958183633
[[1965   24   11]
 [  20 162

Test Loss:  0.15389732774744921
Test Acc:  0.949182658137882
Test F1:  0.9478634479959779
Test Precision:  0.9489685105725432
Test Recall:  0.9500473556693323
[[1980   15    5]
 [  32 1626   13]
 [  27  194 1736]]
> epoch:  157

Loss:  0.060862378323112025
Acc:  0.9809879175550817
F1:  0.9806854460067891
Precision:  0.9809301796523254
Recall:  0.9804855059013148
[[1962   23   15]
 [  19 1622   30]
 [  10   10 1937]]
Test Loss:  0.16133019576942087
Test Acc:  0.9422530206112296
Test F1:  0.9407278636353831
Test Precision:  0.9517780231073377
Test Recall:  0.9373524971507398
[[1937    3   60]
 [   2 1413  256]
 [   3    1 1953]]
> epoch:  158

Loss:  0.05886202389162711
Acc:  0.9811656005685856
F1:  0.98076826414122
Precision:  0.9807694151218532
Recall:  0.9807942735703725
[[1964   26   10]
 [  10 1627   34]
 [   8   18 1931]]
Test Loss:  0.3020449033492449
Test Acc:  0.8857498223169865
Test F1:  0.8725091474087926
Test Precision:  0.9076711681983518
Test Recall:  0.8724332995835763
[[1


Loss:  0.05638784083633006
Acc:  0.9815209665955935
F1:  0.9810848309115716
Precision:  0.9813333828761138
Recall:  0.980887061855833
[[1968   22   10]
 [  15 1619   37]
 [   6   14 1937]]
Test Loss:  0.4282437531160586
Test Acc:  0.8589196872778962
Test F1:  0.8411072031152176
Test Precision:  0.8885389957113493
Test Recall:  0.8434778644507418
[[1984    4   12]
 [ 180  939  552]
 [  46    0 1911]]
> epoch:  175

Loss:  0.05806012940063903
Acc:  0.9815209665955935
F1:  0.9811097973086692
Precision:  0.9812372037532198
Recall:  0.9810000102441877
[[1969   21   10]
 [  17 1623   31]
 [   7   18 1932]]
Test Loss:  0.18306923708006126
Test Acc:  0.9292821606254442
Test F1:  0.9262952518414723
Test Precision:  0.9432976343463545
Test Recall:  0.9223456014362658
[[1947    0   53]
 [   7 1326  338]
 [   0    0 1957]]
> epoch:  176

Loss:  0.055596875931686635
Acc:  0.9841862117981521
F1:  0.9837742062036178
Precision:  0.9839501453186076
Recall:  0.9836240887642055
[[1974   19    7]
 [  15 

Test Loss:  0.11540062763332189
Test Acc:  0.9550461975835111
Test F1:  0.9535356832667315
Test Precision:  0.9598716366832908
Test Recall:  0.9511353322444934
[[1960   11   29]
 [   8 1468  195]
 [   6    4 1947]]
> epoch:  192

Loss:  0.05467415150170299
Acc:  0.9818763326226013
F1:  0.9813923526653925
Precision:  0.9814287484728087
Recall:  0.9813698203271392
[[1969   26    5]
 [  14 1624   33]
 [   6   18 1933]]
Test Loss:  0.3178285337701443
Test Acc:  0.8763326226012793
Test F1:  0.8622103545017707
Test Precision:  0.8989492647609105
Test Recall:  0.8628598768393796
[[1958    1   41]
 [ 320 1028  323]
 [  11    0 1946]]
> epoch:  193

Loss:  0.05829815464676358
Acc:  0.9827647476901208
F1:  0.9824338520627302
Precision:  0.9824582559607511
Recall:  0.9824218702298908
[[1970   20   10]
 [  10 1631   30]
 [   9   18 1930]]
Test Loss:  0.04114676886939747
Test Acc:  0.988272921108742
Test F1:  0.9880437753426072
Test Precision:  0.9884765095537126
Test Recall:  0.9876726205274564
[[


Loss:  0.061463915324896916
Acc:  0.9795664534470505
F1:  0.9791493363228246
Precision:  0.9792107400621933
Recall:  0.9791010475472
[[1960   28   12]
 [  23 1621   27]
 [   8   17 1932]]
Test Loss:  0.3482201494702837
Test Acc:  0.8937455579246624
Test F1:  0.8945152956557559
Test Precision:  0.9102652636067892
Test Recall:  0.8992762910963942
[[1619  378    3]
 [   0 1662    9]
 [   6  202 1749]]
> epoch:  210

Loss:  0.055487713649530306
Acc:  0.9820540156361052
F1:  0.9817574507934205
Precision:  0.9817087113654236
Recall:  0.9818244949946694
[[1963   23   14]
 [  13 1633   25]
 [   8   18 1931]]
Test Loss:  0.16859267019565735
Test Acc:  0.9383439943141436
Test F1:  0.9385434018563883
Test Precision:  0.9413828130043455
Test Recall:  0.9412146091088056
[[1706  110  184]
 [   4 1645   22]
 [   3   24 1930]]
> epoch:  211

Loss:  0.06074557926463471
Acc:  0.9840085287846482
F1:  0.9835408061496157
Precision:  0.9837141054122731
Recall:  0.9833844685983025
[[1977   18    5]
 [  18 1

Test Loss:  0.07561579961912168
Test Acc:  0.9715707178393745
Test F1:  0.9709771870686094
Test Precision:  0.9705210740730777
Test Recall:  0.972562795953821
[[1931   65    4]
 [   5 1656   10]
 [   3   73 1881]]
> epoch:  227

Loss:  0.058940006435891104
Acc:  0.9797441364605544
F1:  0.9792988655269524
Precision:  0.9794570996536754
Recall:  0.979165608100594
[[1965   24   11]
 [  19 1618   34]
 [   8   18 1931]]
Test Loss:  0.17410994958774526
Test Acc:  0.9376332622601279
Test F1:  0.937288413194865
Test Precision:  0.9408457557403223
Test Recall:  0.9406771861733841
[[1810  175   15]
 [   0 1663    8]
 [   1  152 1804]]
> epoch:  228

Loss:  0.054520999028516766
Acc:  0.9838308457711443
F1:  0.9834123170401176
Precision:  0.9835241166380086
Recall:  0.9833199080449084
[[1972   21    7]
 [  15 1627   29]
 [   4   15 1938]]
Test Loss:  0.8104850974688744
Test Acc:  0.7713219616204691
Test F1:  0.7745252706706256
Test Precision:  0.854001764824096
Test Recall:  0.7687325287415724
[[1


Loss:  0.05439181625843048
Acc:  0.9831201137171286
F1:  0.9827549002322771
Precision:  0.9827935071964463
Recall:  0.982718726813606
[[1974   18    8]
 [  15 1630   26]
 [   8   20 1929]]
Test Loss:  0.13491123883443912
Test Acc:  0.9578891257995735
Test F1:  0.9570301374913169
Test Precision:  0.9575142481849709
Test Recall:  0.9582512187270685
[[1987    9    4]
 [  35 1622   14]
 [  48  127 1782]]
> epoch:  245

Loss:  0.053662707134869626
Acc:  0.9848969438521677
F1:  0.9845111180786273
Precision:  0.9847265100853212
Recall:  0.9843418804516535
[[1972   21    7]
 [  12 1627   32]
 [   4    9 1944]]
Test Loss:  0.17351461049408518
Test Acc:  0.9346126510305615
Test F1:  0.9351843516662707
Test Precision:  0.9439661128430212
Test Recall:  0.9322466148055933
[[1992    5    3]
 [ 166 1496    9]
 [ 156   29 1772]]
> epoch:  246

Loss:  0.059328907079444354
Acc:  0.9808102345415778
F1:  0.9804638029015798
Precision:  0.9804286673831238
Recall:  0.9805045971939488
[[1966   22   12]
 [  1

Test Loss:  0.04831118295477118
Test Acc:  0.9872068230277186
Test F1:  0.9870378739150186
Test Precision:  0.9870721104617667
Test Recall:  0.9870042854852294
[[1977   21    2]
 [  10 1643   18]
 [  16    5 1936]]
> epoch:  262

Loss:  0.06766564701154659
Acc:  0.9792110874200426
F1:  0.9787870036939331
Precision:  0.9788985082190758
Recall:  0.978702084850212
[[1960   28   12]
 [  19 1619   33]
 [   9   16 1932]]
Test Loss:  0.30755396818977265
Test Acc:  0.8896588486140725
Test F1:  0.8863101995612622
Test Precision:  0.8901278913262431
Test Recall:  0.8888879851068877
[[1961   29   10]
 [ 107 1483   81]
 [  23  371 1563]]
> epoch:  263

Loss:  0.05304500420408493
Acc:  0.9840085287846482
F1:  0.9836601662379535
Precision:  0.9837876087774703
Recall:  0.9835485420074388
[[1972   19    9]
 [  16 1629   26]
 [   6   14 1937]]
Test Loss:  0.08892219109729486
Test Acc:  0.972636815920398
Test F1:  0.9721428539998501
Test Precision:  0.9715452317876793
Test Recall:  0.9734573156293381
[[


Loss:  0.05998312630055642
Acc:  0.9820540156361052
F1:  0.9816907860697422
Precision:  0.9819015504804552
Recall:  0.981514658515351
[[1968   22   10]
 [  15 1623   33]
 [   9   12 1936]]
Test Loss:  0.07527878194205986
Test Acc:  0.9744136460554371
Test F1:  0.9743904727873107
Test Precision:  0.9748632603166992
Test Recall:  0.9745998109565105
[[1904   41   55]
 [   2 1629   40]
 [   2    4 1951]]
> epoch:  280

Loss:  0.05424615162403577
Acc:  0.9832977967306326
F1:  0.9829078289755709
Precision:  0.9831049215726516
Recall:  0.9827396304100907
[[1975   16    9]
 [  12 1625   34]
 [   6   17 1934]]
Test Loss:  0.15022087279808755
Test Acc:  0.9475835110163469
Test F1:  0.947061520139172
Test Precision:  0.9516909690940043
Test Recall:  0.9448456933281593
[[1931   13   56]
 [  45 1494  132]
 [  48    1 1908]]
> epoch:  281

Loss:  0.05307255874239755
Acc:  0.9852523098791756
F1:  0.9849451859617774
Precision:  0.9849792954352358
Recall:  0.9849194209006505
[[1971   23    6]
 [  13 1

Test Loss:  0.7932885926025086
Test Acc:  0.7883795309168443
Test F1:  0.7894424637077425
Test Precision:  0.8585960841002803
Test Recall:  0.7984533035263145
[[1633  364    3]
 [   0 1670    1]
 [  13  810 1134]]
> epoch:  297

Loss:  0.0551548249098811
Acc:  0.9818763326226013
F1:  0.9814423036716836
Precision:  0.981523657899564
Recall:  0.9813881306660934
[[1964   27    9]
 [  17 1624   30]
 [   5   14 1938]]
Test Loss:  0.06909570623131552
Test Acc:  0.9770788912579957
Test F1:  0.9770328776630696
Test Precision:  0.9771482101466881
Test Recall:  0.9774096474358288
[[1914   38   48]
 [   2 1639   30]
 [   1   10 1946]]
> epoch:  298

Loss:  0.05809373544814827
Acc:  0.9811656005685856
F1:  0.9808031009716321
Precision:  0.9810817194667157
Recall:  0.9805572466620003
[[1973   16   11]
 [  16 1620   35]
 [  11   17 1929]]
Test Loss:  0.05470168287435214
Test Acc:  0.9856076759061834
Test F1:  0.985288523323523
Test Precision:  0.98473503509611
Test Recall:  0.9860873377659577
[[1968

In [18]:
# create confidence function
# need softmax and entropy for to check confidence
# get the highest softmax value out of all softmax values and compute entropy based on the mathematical expression

def get_confidence(net, infer_loader, device):    
    container = list()
    
    with torch.no_grad():
        for idx, (img, label) in enumerate(infer_loader):
            img = img.to(device)
            label = label.to(device)
            out = net(img) 
            out_softmax = torch.softmax(out, 1)

            msp = float(out_softmax.detach().cpu().numpy().max()) # max softmax value

            pA = out_softmax.detach().cpu().numpy() / out_softmax.detach().cpu().numpy().sum()
            entropy = -np.sum( pA * np.log2(pA))

            fname, _ = infer_loader.dataset.samples[idx]
            label = int(label.detach().cpu().numpy())

            tmp_container = {
                'fname':fname,
                'label':label,
                'msp':msp,
                'entropy':entropy
            }
            container.append(tmp_container)
        
    return container

# Extract Activation Vector

In [25]:
# need correct activation vector for openmax
# input val in softmax layer is Activation Vector, so retrieve activation from torch.softmax()

train_loader, _train_data = create_dataloader(TRAIN_PATH, 1, False)
target_class_num = len(os.listdir(TRAIN_PATH))

train_preds = list()
train_actvecs = list()
train_outputs_softmax = list()
train_labels = list()

with torch.no_grad():
    for idx, (img, label) in enumerate(train_loader):
        img = img.to(device)
        label = label.to(device)

        out = net(img)
        out_actvec = out.cpu().detach().numpy()[0]
        out_softmax = torch.softmax(out, 1).cpu().detach().numpy()[0]
        out_pred = int(torch.argmax(out).cpu().detach().numpy())
        out_label = int(label.cpu().detach().numpy())

        train_actvecs.append(out_actvec) # component 1: Activation Vector before softmax
        train_preds.append(out_pred) # componenet 2: preds of each data
        train_outputs_softmax.append(out_softmax) # component 3: softmax of each data
        train_labels.append(out_label) # component 4: labels of each data

train_actvecs = np.asarray(train_actvecs)
train_preds = np.asarray(train_preds)
train_outputs_softmax = np.asarray(train_outputs_softmax)
train_labels = np.asarray(train_labels)

In [26]:
# only using correct activations vectors in OpenMax algorithm

train_correct_actvecs = train_actvecs[train_labels==train_preds]
train_correct_labels = train_labels[train_labels==train_preds]
print('Activation vector: ', train_correct_actvecs.shape)
print('Labels: ', train_correct_labels.shape)

Activation vector:  (5427, 3)
Labels:  (5427,)


# Weibull-Distribution

In [27]:
np.unique(train_labels)

array([0, 1, 2])

In [28]:
# parameters for weibull-dist are 3 = shape, loc, scale, class has 4 parameters, so total of 12 nums

class_means = list()
dist_to_means = list()
mr_models = {}

for class_idx in np.unique(train_labels):
    
    print('class_idx: ', class_idx)
    class_act_vec = train_correct_actvecs[train_correct_labels==class_idx]
    print(class_act_vec.shape)
    
    class_mean = class_act_vec.mean(axis=0)
    class_means.append(class_mean)
    
    dist_to_mean = np.square(class_act_vec - class_mean).sum(axis=1) # compute distance of activation vectors
    dist_to_mean_sorted = np.sort(dist_to_mean).astype(np.float64) # sort based on distance
    dist_to_means.append(dist_to_mean_sorted)

    shape, loc, scale = stats.weibull_max.fit(dist_to_mean_sorted[-100:]) # parameters of furthest 100 act vecs
    
    mr_models[str(class_idx)] = {
        'shape':shape,
        'loc':loc,
        'scale':scale
    }
    
class_means = np.asarray(class_means)

class_idx:  0
(1958, 3)
class_idx:  1
(1605, 3)
class_idx:  2
(1864, 3)


In [29]:
def compute_openmax(actvec, class_means, mr_models):
    dist_to_mean = np.square(actvec - class_means).sum(axis=1)

    scores = list()
    for class_idx in range(len(class_means)):
        params = mr_models[str(class_idx)]
        score = stats.weibull_max.cdf(
            dist_to_mean[class_idx],
            params['shape'],
            params['loc'],
            params['scale']
        )
        scores.append(score)
    scores = np.asarray(scores)
    
    weight_on_actvec = 1 - scores # weight of each class
    rev_actvec = np.concatenate([
        weight_on_actvec * actvec, # multiplication of known class
        [((1-weight_on_actvec) * actvec).sum()] # computing unknown class
    ])
    
    openmax_prob = np.exp(rev_actvec) / np.exp(rev_actvec).sum()
    return openmax_prob

In [30]:
def inference(actvec, threshold, target_class_num, class_means, mr_models):
    openmax_prob = compute_openmax(actvec, class_means, mr_models)
    openmax_softmax = np.exp(openmax_prob)/sum(np.exp(openmax_prob))

    pred = np.argmax(openmax_softmax)
    if np.max(openmax_softmax) < threshold:
        pred = target_class_num
    return pred

In [31]:
def inference_dataloader(net, data_loader, threshold, target_class_num, class_means, mr_models, is_reject=False):
    result_preds = list()
    result_labels = list()

    with torch.no_grad():
        for idx, (img, label) in enumerate(data_loader):
            img = img.to(device)
            label = label.to(device)

            out = net(img)
            out_actvec = out.cpu().detach().numpy()[0]
            out_softmax = torch.softmax(out, 1).cpu().detach().numpy()[0]
            out_label = int(label.cpu().detach().numpy())

            pred = inference(out_actvec, threshold, target_class_num, class_means, mr_models)

            result_preds.append(pred)
            if is_reject:
                result_labels.append(target_class_num) # 3
            else:
                result_labels.append(out_label) # 0, 1, 2

    return result_preds, result_labels

In [32]:
# compute acc with 0.35 threshold

test_loader, _test_data = create_dataloader(TEST_PATH, 1, False)
reject_loader, _reject_data = create_dataloader(REJECT_PATH, 1, False)
target_class_num = len(os.listdir(TEST_PATH)) # 3

test_preds, test_labels = inference_dataloader(net, test_loader, 0.35, target_class_num, class_means, mr_models)
reject_preds, reject_labels = inference_dataloader(net, reject_loader, 0.35, target_class_num, class_means, mr_models, is_reject=True)

print('Test Accuracy: ', accuracy_score(test_labels, test_preds))
print('Reject Accuracy: ', accuracy_score(reject_labels, reject_preds))

Test Accuracy:  0.9253731343283582
Reject Accuracy:  0.25196123673281035


In [35]:
# finding adequate threshold
for i in np.arange(0.1,1,0.1):
    test_preds, test_labels = inference_dataloader(net, test_loader, i, target_class_num, class_means, mr_models)
    reject_preds, reject_labels = inference_dataloader(net, reject_loader, i, target_class_num, class_means, mr_models, is_reject=True)
    print('threshold: ', i)
    print('Test Accuracy: ', accuracy_score(test_labels, test_preds))
    print('Reject Accuracy: ', accuracy_score(reject_labels, reject_preds))

threshold:  0.1
Test Accuracy:  0.9630419331911869
Reject Accuracy:  0.003230272265805261
threshold:  0.2
Test Accuracy:  0.9630419331911869
Reject Accuracy:  0.003230272265805261
threshold:  0.30000000000000004
Test Accuracy:  0.9554015636105189
Reject Accuracy:  0.09598523304107061
threshold:  0.4
Test Accuracy:  0.8605188343994314
Reject Accuracy:  0.4134748500230734
threshold:  0.5
Test Accuracy:  0.0
Reject Accuracy:  1.0
threshold:  0.6
Test Accuracy:  0.0
Reject Accuracy:  1.0
threshold:  0.7000000000000001
Test Accuracy:  0.0
Reject Accuracy:  1.0
threshold:  0.8
Test Accuracy:  0.0
Reject Accuracy:  1.0
threshold:  0.9
Test Accuracy:  0.0
Reject Accuracy:  1.0


In [36]:
# finding adequate threshold
for i in np.arange(0.3,0.5,0.01):
    test_preds, test_labels = inference_dataloader(net, test_loader, i, target_class_num, class_means, mr_models)
    reject_preds, reject_labels = inference_dataloader(net, reject_loader, i, target_class_num, class_means, mr_models, is_reject=True)
    print('threshold: ', i)
    print('Test Accuracy: ', accuracy_score(test_labels, test_preds))
    print('Reject Accuracy: ', accuracy_score(reject_labels, reject_preds))

threshold:  0.3
Test Accuracy:  0.9554015636105189
Reject Accuracy:  0.09598523304107061
threshold:  0.31
Test Accuracy:  0.9507818052594172
Reject Accuracy:  0.127365020766036
threshold:  0.32
Test Accuracy:  0.9447405828002843
Reject Accuracy:  0.15736040609137056
threshold:  0.33
Test Accuracy:  0.9399431414356787
Reject Accuracy:  0.18458698661744347
threshold:  0.34
Test Accuracy:  0.933546552949538
Reject Accuracy:  0.21873557914167052
threshold:  0.35000000000000003
Test Accuracy:  0.9253731343283582
Reject Accuracy:  0.25196123673281035
threshold:  0.36000000000000004
Test Accuracy:  0.9166666666666666
Reject Accuracy:  0.27826488232579605
threshold:  0.37000000000000005
Test Accuracy:  0.9065387348969438
Reject Accuracy:  0.31702814951545916
threshold:  0.38000000000000006
Test Accuracy:  0.8928571428571429
Reject Accuracy:  0.3474850023073373
threshold:  0.39000000000000007
Test Accuracy:  0.8800639658848614
Reject Accuracy:  0.37517305029995385
threshold:  0.4000000000000001