# Image Classification Sample

|Item|Description|
|---|---|
|DeepLearning Framework|PyTorch|
|Dataset|CIFAR-10|
|Model Architecture|Simple CNN|


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
import numpy as np
import torch
import pprint

from data_loader.data_loader import DataLoader
from models.pytorch import simple_cnn

## Set Random Seed

In [3]:
seed=42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fdaa5573710>

## Device Settings

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## Hyperparameters

In [5]:
epochs = 200
batch_size = 512
learning_rate = 0.001
weight_decay = 0.004

## Load Dataset and Normalize

In [6]:
dataset_dir = '/tmp/dataset'
dataloader = DataLoader('cifar10_pytorch', dataset_dir)

Files already downloaded and verified
Files already downloaded and verified


## Training Model

In [7]:
model = simple_cnn.SimpleCNN(device)

model.train(dataloader.dataset.trainloader, epochs=epochs, lr=learning_rate, wd=weight_decay)

[EPOCH #0] loss: 2.3054567161311117
[EPOCH #1] loss: 2.0840976985692214
[EPOCH #2] loss: 1.9446198368255556
[EPOCH #3] loss: 1.9036013297522136
[EPOCH #4] loss: 1.8671388843466818
[EPOCH #5] loss: 1.8459404222643383
[EPOCH #6] loss: 1.8252441330895695
[EPOCH #7] loss: 1.8138644247579787
[EPOCH #8] loss: 1.793832437929555
[EPOCH #9] loss: 1.7861325860557384
[EPOCH #10] loss: 1.771579421184342
[EPOCH #11] loss: 1.7580820068059178
[EPOCH #12] loss: 1.7644409114400774
[EPOCH #13] loss: 1.7429471612357972
[EPOCH #14] loss: 1.7351299950653973
[EPOCH #15] loss: 1.7269907555210995
[EPOCH #16] loss: 1.7268428049709883
[EPOCH #17] loss: 1.7189283369446289
[EPOCH #18] loss: 1.714641521240913
[EPOCH #19] loss: 1.7075603669527166
[EPOCH #20] loss: 1.702841138809214
[EPOCH #21] loss: 1.7022659230216985
[EPOCH #22] loss: 1.6982211147602444
[EPOCH #23] loss: 1.6931255078056417
[EPOCH #24] loss: 1.6917279020609646
[EPOCH #25] loss: 1.6863489481248082
[EPOCH #26] loss: 1.6873937644641215
[EPOCH #27] los

## Test Model

In [8]:
train_result = model.predict(dataloader.dataset.trainloader)
train_predictions, train_labels = train_result

In [9]:
train_eval_result = model.evaluate(train_labels, train_predictions)
pprint.pprint(train_eval_result)

{'accuracy': 0.9005,
 'classification_report': {'0': {'f1-score': 0.9063670411985018,
                                 'precision': 0.9176060668169707,
                                 'recall': 0.8954,
                                 'support': 5000},
                           '1': {'f1-score': 0.9471037811745776,
                                 'precision': 0.952467637540453,
                                 'recall': 0.9418,
                                 'support': 5000},
                           '2': {'f1-score': 0.8768361581920905,
                                 'precision': 0.9013727560718057,
                                 'recall': 0.8536,
                                 'support': 5000},
                           '3': {'f1-score': 0.8336433956721826,
                                 'precision': 0.8166123153654326,
                                 'recall': 0.8514,
                                 'support': 5000},
                           '4': {'f1-score': 0.8

In [10]:
test_result = model.predict(dataloader.dataset.testloader)
test_predictions, test_labels = test_result

In [11]:
test_eval_result = model.evaluate(test_labels, test_predictions)
pprint.pprint(test_eval_result)

{'accuracy': 0.8112,
 'classification_report': {'0': {'f1-score': 0.8308786185881157,
                                 'precision': 0.8441692466460269,
                                 'recall': 0.818,
                                 'support': 1000},
                           '1': {'f1-score': 0.8982456140350876,
                                 'precision': 0.9005025125628141,
                                 'recall': 0.896,
                                 'support': 1000},
                           '2': {'f1-score': 0.7479338842975206,
                                 'precision': 0.7735042735042735,
                                 'recall': 0.724,
                                 'support': 1000},
                           '3': {'f1-score': 0.6627623230844315,
                                 'precision': 0.6472831267874166,
                                 'recall': 0.679,
                                 'support': 1000},
                           '4': {'f1-score': 0.7940