In [5]:
import copy
import glob
import os
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import time
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from sklearn.metrics import classification_report, f1_score

from wrapper import OASIS
from split import split_data

scans_home = 'data/scans'
labels_file = 'data/OASIS3_MRID2Label_052918.csv'
stats_filepath = 'outputs_1_5_5.txt'
n_classes = 3
freeze_layers = False
start_freeze_layer = 'Mixed_5d'
use_parallel = True

loss_weights = torch.tensor([1.,5., 5.])
if torch.cuda.is_available():
    loss_weights = loss_weights.cuda()
criterion = nn.CrossEntropyLoss(weight=loss_weights)
optimizer_type = torch.optim.Adam
lr_scheduler_type = optim.lr_scheduler.StepLR
num_epochs = 10
best_model_filepath = None
load_model_filepath = None
#load_model_filepath = 'model_best.pth.tar'

def get_counts(filename_labels):
    counts = [0]*3
    for filename, label in filename_labels:
        counts[label] += 1
    return counts


def train_model(model, dataloaders, datasets, dataset_sizes, criterion, optimizer, scheduler, use_gpu, num_epochs=5):
    since = time.time()

    best_model_wts = model.state_dict()
    best_f1_score = 0.0
    best_acc = 0.0
    
    # list of models from all epochs
    model_list = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                    model = model.cuda()
                else:
                    input = Variable(inputs)
                    labels = Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                if type(outputs) == tuple:
                    outputs, _ = outputs
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.item() / dataset_sizes[phase]
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            with open(stats_filepath, 'a') as f:
                f.write('Epoch {} {} Loss: {:.4f} Acc: {:.4f}\n'.format(epoch, phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val':
                predictions = evaluate_model(model, dataloaders['val'], dataset_sizes['val'], use_cuda)
                true_y = [y for img, y in datasets['val']]
                f1 = f1_score(true_y, predictions, average = 'macro')
                all_f1s = f1_score(true_y, predictions, average = None)
                
                # print f1 score and write to file
                print('macro f1_score: {:.4f}'.format(f1))
                print('all f1_scores: {}'.format(str(all_f1s)))
                with open(stats_filepath, 'a') as f:
                    f.write('Epoch {} macro f1_score = {:.4f} \n'.format(epoch, f1))
                    f.write('all f1_scores: {}'.format(str(all_f1s)))
                
                #update epoch acc
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    
                # update best model based on f1_score
                if f1 > best_f1_score:
                    best_f1_score = f1
                    best_model_wts = model.state_dict()

                    state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
                    if best_model_filepath is not None:
                        torch.save(state, best_model_filepath)
        
        model_list.append(copy.deepcopy(model))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    with open(stats_filepath, 'a') as f:
        f.write('Best val Acc: {:4f}\n'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model_list, model


def evaluate_model(model, testset_loader, test_size, use_gpu):
    model.train(False)  # Set model to evaluate mode

    predictions = []
    # Iterate over data
    for inputs, labels in tqdm(testset_loader):
        # TODO: wrap them in Variable?
        if use_gpu:
            inputs = inputs.cuda()
            labels = labels.cuda()

        # forward
        outputs = model(inputs)
        if type(outputs) == tuple:
            outputs, _ = outputs
        _, preds = torch.max(outputs.data, 1)
        predictions.extend(preds.tolist())
    return predictions


def load_saved_model(filepath, model, optimizer=None):
    state = torch.load(filepath)
    model.load_state_dict(state['state_dict'])
    # Only need to load optimizer if you are going to resume training on the model
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer'])

In [6]:
train_filenames, val_filenames, test_filenames = split_data(scans_home, labels_file)
print('train filenames size: ', len(train_filenames))
print('validation filenames size: ', len(val_filenames))
print('test filenames size: ', len(test_filenames))
print('label counts for training set: ', get_counts(train_filenames))
print('label counts for validation set: ', get_counts(val_filenames))
print('label counts for test set: ', get_counts(test_filenames))

train_dataset = OASIS(train_filenames)
val_dataset = OASIS(val_filenames)
test_dataset = OASIS(test_filenames)
print([y for img, y in train_dataset])
print([y for img, y in val_dataset])
print([y for img, y in test_dataset])

#print out a sample image shape
'''image_array, label = train_dataset[4]
print(image_array.shape)'''
# print('training dataset size: ', len(train_dataset))
# print('validation dataset size: ', len(val_dataset))
# print('test dataset size: ', len(test_dataset))
trainset_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4)
valset_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=4)
testset_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4)

# Use GPU if available, otherwise stick with cpu
use_cuda = torch.cuda.is_available()
torch.manual_seed(123)
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

#resnet18 = torchvision.models.resnet18()
inception = torchvision.models.inception_v3()
# Since imagenet has 1000 classes, we need to change our last layer according to the number of classes we have
n_features = inception.fc.in_features
inception.fc = nn.Linear(n_features, n_classes)

# Freeze layers if freeze_layer is True
for i, param in inception.named_parameters():
    if freeze_layers:
        param.requires_grad = False
    else:
        param.requires_grad = True
if freeze_layers:
    ct = []
    for name, child in inception.named_children():
        #if name == 'fc':
        if start_freeze_layer in ct:
            for params in child.parameters():
                params.requires_grad = True
        ct.append(name)
        
# He initialization
def init_weights(m):
    # if type(m) == nn.Linear or type(m) == nn.Conv1d:
    if m.requires_grad:
        nn.init.kaiming_normal_(m.weight)

# To view which layers are freezed and which layers are not freezed:
for name, child in inception.named_children():
    for name_2, params in child.named_parameters():
        print(name_2, params.requires_grad)

if use_parallel:
    print("[Using all the available GPUs]")
    inception = nn.DataParallel(inception, device_ids=[0, 1])

dataloaders = {'train': trainset_loader, 'val': valset_loader}
datasets = {'train': train_dataset, 'val': val_dataset}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
optimizable_params = [param for param in inception.parameters() if param.requires_grad]
optimizer = optimizer_type(optimizable_params, lr=0.001)
exp_lr_scheduler = lr_scheduler_type(optimizer, step_size=7, gamma=0.1)

# If we want to load a model with saved parameters
if load_model_filepath is not None:
    load_saved_model(load_model_filepath, inception, optimizer)

num labels is 2107
num filenames is 2193
num experiments is 1950
counts per class: [1536, 322, 92]
train filenames size:  1365
validation filenames size:  292
test filenames size:  293
label counts for training set:  [1075, 225, 65]
label counts for validation set:  [230, 48, 14]
label counts for test set:  [231, 49, 13]




finished preprocessing
mean is 24.182172676186866
std is 35.69391971454612
finished preprocessing
mean is 23.371209937836372
std is 35.27608823338848
finished preprocessing
mean is 24.090649883711215
std is 35.42283481015584
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

conv.weight True
bn.weight True
bn.bias True
conv.weight True
bn.weight True
bn.bias True
conv.weight True
bn.weight True
bn.bias True
conv.weight True
bn.weight True
bn.bias True
conv.weight True
bn.weight True
bn.bias True
branch1x1.conv.weight True
branch1x1.bn.weight True
branch1x1.bn.bias True
branch5x5_1.conv.weight True
branch5x5_1.bn.weight True
branch5x5_1.bn.bias True
branch5x5_2.conv.weight True
branch5x5_2.bn.weight True
branch5x5_2.bn.bias True
branch3x3dbl_1.conv.weight True
branch3x3dbl_1.bn.weight True
branch3x3dbl_1.bn.bias True
branch3x3dbl_2.conv.weight True
branch3x3dbl_2.bn.weight True
branch3x3dbl_2.bn.bias True
branch3x3dbl_3.conv.weight True
branch3x3dbl_3.bn.weight True
branch3x3dbl_3.bn.bias True
branch_pool.conv.weight True
branch_pool.bn.weight True
branch_pool.bn.bias True
branch1x1.conv.weight True
branch1x1.bn.weight True
branch1x1.bn.bias True
branch5x5_1.conv.weight True
branch5x5_1.bn.weight True
branch5x5_1.bn.bias True
branch5x5_2.conv.weight True
br

In [7]:
model_list, best_model = train_model(inception,
                             dataloaders,
                             datasets,
                             dataset_sizes,
                             criterion,
                             optimizer,
                             exp_lr_scheduler,
                             use_cuda,
                             num_epochs)
    
epoch = 0 
for model in model_list:
    predictions = evaluate_model(model, valset_loader, len(val_dataset), use_cuda)
    true_y = [y for img, y in val_dataset]
    report = classification_report(true_y, predictions)
    with open(stats_filepath, 'a') as f:
        f.write('\n Epoch {} \n'.format(epoch))
        f.write(report)
    epoch += 1
    print(report)


  0%|          | 0/2045 [00:00<?, ?it/s]

Epoch 0/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0982 Acc: 0.6016


100%|██████████| 438/438 [00:42<00:00, 10.32it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0732 Acc: 0.7110


100%|██████████| 438/438 [00:41<00:00, 10.66it/s]
  'precision', 'predicted', average, warn_for)
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3703
all f1_scores: [0.83608456 0.27472527 0.        ]

Epoch 1/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0946 Acc: 0.6209


100%|██████████| 438/438 [00:42<00:00, 10.19it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0798 Acc: 0.6532


100%|██████████| 438/438 [00:41<00:00, 10.64it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3560
all f1_scores: [0.78610603 0.28196059 0.        ]

Epoch 2/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0944 Acc: 0.6439


100%|██████████| 438/438 [00:42<00:00, 10.25it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0761 Acc: 0.6927


 77%|███████▋  | 337/438 [00:31<00:09, 10.61it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 2045/2045 [09:35<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0918 Acc: 0.6336


100%|██████████| 438/438 [00:42<00:00, 10.23it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0858 Acc: 0.5668


100%|██████████| 438/438 [00:41<00:00, 10.62it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3138
all f1_scores: [0.71680414 0.22448115 0.        ]

Epoch 5/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0955 Acc: 0.6324


100%|██████████| 438/438 [00:42<00:00, 10.23it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0722 Acc: 0.6875


100%|██████████| 438/438 [00:41<00:00, 10.62it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3723
all f1_scores: [0.81939649 0.29739777 0.        ]

Epoch 6/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0922 Acc: 0.6156


100%|██████████| 438/438 [00:42<00:00, 10.24it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0858 Acc: 0.6207


100%|██████████| 438/438 [00:41<00:00, 10.65it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3774
all f1_scores: [0.77272001 0.28943259 0.06993007]

Epoch 7/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0910 Acc: 0.6702


100%|██████████| 438/438 [00:43<00:00, 10.18it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0737 Acc: 0.6877


100%|██████████| 438/438 [00:41<00:00, 10.64it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3723
all f1_scores: [0.82006609 0.29692471 0.        ]

Epoch 8/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0891 Acc: 0.6599


100%|██████████| 438/438 [00:42<00:00, 10.21it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0776 Acc: 0.6349


100%|██████████| 438/438 [00:41<00:00, 10.61it/s]
  0%|          | 0/2045 [00:00<?, ?it/s]

macro f1_score: 0.3711
all f1_scores: [0.77058249 0.34278565 0.        ]

Epoch 9/9
----------


100%|██████████| 2045/2045 [09:34<00:00,  3.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

train Loss: 0.0881 Acc: 0.6688


100%|██████████| 438/438 [00:42<00:00, 10.23it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

val Loss: 0.0746 Acc: 0.6571


100%|██████████| 438/438 [00:41<00:00, 10.66it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

macro f1_score: 0.3818
all f1_scores: [0.7900128 0.3554007 0.       ]

Training complete in 112m 36s
Best val Acc: 0.710989


100%|██████████| 438/438 [00:41<00:00, 10.61it/s]
  'precision', 'predicted', average, warn_for)
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.83      0.84      0.84      3447
          1       0.25      0.31      0.27       720
          2       0.00      0.00      0.00       210

avg / total       0.70      0.71      0.70      4377



100%|██████████| 438/438 [00:41<00:00, 10.60it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.83      0.75      0.79      3447
          1       0.22      0.39      0.28       720
          2       0.00      0.00      0.00       210

avg / total       0.69      0.65      0.67      4377



100%|██████████| 438/438 [00:41<00:00, 10.62it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.85      0.79      0.82      3447
          1       0.27      0.45      0.34       720
          2       0.00      0.00      0.00       210

avg / total       0.72      0.69      0.70      4377



100%|██████████| 438/438 [00:41<00:00, 10.56it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.88      0.68      0.77      3447
          1       0.26      0.63      0.37       720
          2       0.00      0.00      0.00       210

avg / total       0.74      0.64      0.67      4377



100%|██████████| 438/438 [00:41<00:00, 10.55it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.81      0.64      0.72      3447
          1       0.16      0.37      0.22       720
          2       0.00      0.00      0.00       210

avg / total       0.66      0.57      0.60      4377



100%|██████████| 438/438 [00:41<00:00, 10.59it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.85      0.79      0.82      3447
          1       0.24      0.39      0.30       720
          2       0.00      0.00      0.00       210

avg / total       0.71      0.69      0.69      4377



100%|██████████| 438/438 [00:41<00:00, 10.63it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.86      0.70      0.77      3447
          1       0.23      0.39      0.29       720
          2       0.06      0.10      0.07       210

avg / total       0.72      0.62      0.66      4377



100%|██████████| 438/438 [00:41<00:00, 10.61it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.85      0.79      0.82      3447
          1       0.24      0.39      0.30       720
          2       0.00      0.00      0.00       210

avg / total       0.71      0.69      0.69      4377



100%|██████████| 438/438 [00:41<00:00, 10.62it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.88      0.69      0.77      3447
          1       0.24      0.57      0.34       720
          2       0.00      0.00      0.00       210

avg / total       0.73      0.63      0.66      4377



100%|██████████| 438/438 [00:41<00:00, 10.61it/s]
  0%|          | 0/438 [00:00<?, ?it/s]

             precision    recall  f1-score   support

          0       0.88      0.72      0.79      3447
          1       0.26      0.57      0.36       720
          2       0.00      0.00      0.00       210

avg / total       0.74      0.66      0.68      4377



100%|██████████| 438/438 [00:41<00:00, 10.59it/s]


             precision    recall  f1-score   support

          0       0.87      0.71      0.78      3458
          1       0.26      0.56      0.36       718
          2       0.00      0.00      0.00       195

avg / total       0.73      0.65      0.68      4371



In [8]:
predictions = evaluate_model(best_model, testset_loader, len(test_dataset), use_cuda)
true_y = [y for img, y in test_dataset]
best_report = classification_report(true_y, predictions)

with open(stats_filepath, 'a') as f:
    f.write('\n Best report \n {}'.format(report))   
    print(report)

100%|██████████| 438/438 [00:41<00:00, 10.62it/s]


             precision    recall  f1-score   support

          0       0.88      0.72      0.79      3447
          1       0.26      0.57      0.36       720
          2       0.00      0.00      0.00       210

avg / total       0.74      0.66      0.68      4377



  'precision', 'predicted', average, warn_for)
