# Cassava Classification - PyTorch Starter (Train)

## References

1. [plot_confusion_matrix](https://deeplizard.com/learn/video/0LhiS6yu2qQ)
2. [sklearn metrics example](https://towardsdatascience.com/confusion-matrix-for-your-multi-class-machine-learning-model-ff9aa3bf7826)
3. [multi_class_classification](https://towardsdatascience.com/multi-class-classification-extracting-performance-metrics-from-the-confusion-matrix-b379b427a872)

## Library imports

In [2]:
# basic imports
import os
import numpy as np
import pandas as pd
import random
import itertools
from tqdm.notebook import tqdm
import math

# augmentations library
from albumentations.pytorch import ToTensorV2
import albumentations as A
import cv2

# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# metrics calculation
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold

# basic plotting library
import matplotlib.pyplot as plt

# interactive plots
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import warnings  
warnings.filterwarnings('ignore')

## Helper function

In [3]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current matplotlib figure.

    Args:
        cm: square confusion matrix (array-like); rows are true labels and
            columns are predicted labels.
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its total so each cell shows the
            fraction of that true class; rows with no samples are shown as 0.
        title: axes title.
        cmap: matplotlib colormap passed to ``imshow``.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guard against classes with no true samples: a zero row total would
        # otherwise produce NaNs, and a NaN max makes every label render black.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Switch text colour at half the maximum so labels stay legible on dark cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [4]:
# Maps each integer class index to its human-readable Cassava class name.
index_label_map = {
    0: "Cassava Bacterial Blight (CBB)",
    1: "Cassava Brown Streak Disease (CBSD)",
    2: "Cassava Green Mottle (CGM)",
    3: "Cassava Mosaic Disease (CMD)",
    4: "Healthy",
}

# Display names ordered by class index (dicts preserve insertion order),
# e.g. for confusion-matrix tick labels.
class_names = list(index_label_map.values())

In [5]:
# Identifier of the model whose saved validation predictions are analysed.
model_name = 'R50_v1'
# Directory holding this model's per-fold prediction/label .npy files.
val_folder = f'val_predictions/{model_name}/'

In [6]:
# Fold index for inspecting a single fold; a later per-fold loop reuses
# (and overwrites) this name.
fold: int = 0

In [15]:
# Per-class recall (confusion-matrix diagonal / row total, in percent) for every
# cross-validation fold, flattened fold-major: entries for fold f are
# acc[f * n_classes : (f + 1) * n_classes].
n_folds = 5
class_ids = np.arange(len(class_names))
acc = []
for fold in range(n_folds):
    val_preds = np.load(os.path.join(val_folder, f'{model_name}_val_preds_{fold}.npy'))
    val_labels = np.load(os.path.join(val_folder, f'{model_name}_val_labels_{fold}.npy'))
    # Pin the label set so the matrix is always n_classes x n_classes; without
    # it, a class absent from a fold shrinks the matrix and shifts the indices.
    cm = confusion_matrix(val_labels, val_preds, labels=class_ids)
    support = cm.sum(axis=1)
    # NaN (instead of a divide-by-zero) when a class has no samples in this fold.
    recall = np.divide(np.diag(cm), support, out=np.full(len(class_ids), np.nan), where=support > 0) * 100
    acc.extend(recall.tolist())
print(acc)

[57.3394495412844, 79.68036529680366, 80.71278825995807, 95.02090459901179, 74.2248062015504, 64.6788990825688, 80.36529680365297, 76.10062893081762, 96.7692892436336, 76.35658914728683, 58.986175115207374, 81.05022831050228, 75.8909853249476, 96.80851063829788, 77.86407766990291, 51.61290322580645, 72.6027397260274, 74.42348008385744, 97.26443768996961, 76.69902912621359, 60.36866359447005, 80.54919908466819, 78.66108786610879, 95.93465045592706, 74.9514563106796]


In [18]:
# One row per fold, one column per class (acc was filled fold-major, class-minor);
# infer the number of folds instead of hard-coding it.
acc = np.asarray(acc, dtype=float).reshape(-1, len(class_names))

In [19]:
# Mean per-class recall across folds (axis 0 runs over the folds).
np.mean(acc, axis=0)

array([58.59721811, 78.84956584, 77.15779409, 96.35955853, 76.01919169])

In [20]:
# Spread of per-class recall across folds; np.std defaults to ddof=0
# (population standard deviation).
np.std(acc, axis=0)

array([4.2576704 , 3.15411392, 2.24039065, 0.79508299, 1.2914719 ])

In [9]:

#print(cm)
#plt.figure(figsize=(8,8))
#plot_confusion_matrix(cm, classes=class_names, normalize=True)

for class 0: accuracy: 57.3394495412844
for class 1: accuracy: 79.68036529680366
for class 2: accuracy: 80.71278825995807
for class 3: accuracy: 95.02090459901179
for class 4: accuracy: 74.2248062015504
