# Analysis

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from sklearn.utils import class_weight
import matplotlib.pyplot as plt
import time
import os
import PIL
from sklearn.metrics import confusion_matrix, f1_score
import seaborn as sn
import pandas as pd
import copy
from collections import Counter

# Names of the dataset splits; used below as the row labels of the
# per-split class-distribution table.
PHASES = ['train', 'val', 'test']
# Root directory of the image dataset; the loading code joins it with a split
# name (e.g. "train") to locate that split's images.
DATA_DIR = "./corn_dataset"

In [None]:
# Each split lives in DATA_DIR/<phase>/ with one sub-folder per class.
# NOTE(review): no `transform` is passed, so samples are PIL images. The
# loaders below are never iterated in this file; presumably a tensor
# transform is needed before they can be batched -- TODO confirm.
image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x))
                  for x in PHASES}

# Only the training split is shuffled; evaluation splits keep a fixed order.
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                   shuffle=(x == 'train'), num_workers=4)
    for x in PHASES
}
dataset_sizes = {x: len(image_datasets[x]) for x in PHASES}

classes = image_datasets['train'].classes
# List of class-name strings (not the dataset object), so that
# len(class_names) is the number of classes.
class_names = classes

# Per-split list of integer class indices, one entry per sample.
label_count = {x: image_datasets[x].targets for x in PHASES}


def _class_counts(dataset, class_order):
    """Return the number of samples per class, ordered like *class_order*.

    Counting by class *name* keeps the columns aligned across splits even if
    a split is missing a class folder (the count is then 0) and therefore
    assigns different integer indices to its classes.
    """
    counts = Counter(dataset.classes[target] for target in dataset.targets)
    return [counts.get(name, 0) for name in class_order]


distribution = {x: _class_counts(image_datasets[x], classes) for x in PHASES}

# Rows: splits, columns: classes (in the training split's class order).
total_distribution = [distribution[x] for x in PHASES]
df_cm = pd.DataFrame(total_distribution, index=PHASES, columns=classes)
print(df_cm)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load Model

In [None]:
# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_ft = models.resnet18(pretrained=True)

# Freeze the parameters of every BatchNorm layer. Iterate each module's own
# parameters: zipping modules() with parameters() pairs unrelated items and
# would freeze arbitrary tensors instead.
for module in model_ft.modules():
    if isinstance(module, nn.BatchNorm2d):
        for param in module.parameters():
            param.requires_grad = False

# Replace the classifier head with a 4-output linear layer so the
# architecture matches the fine-tuned checkpoint loaded below.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 4)

# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine as well.
model_ft.load_state_dict(
    torch.load('./model_weight/corn_full5.pth', map_location=device))
# Inference mode: BatchNorm uses running statistics, dropout is disabled.
model_ft.eval()

model_ft.to(device)

## Analysis

### Analysis Helper Functions

In [None]:
def printResult(path, classes, output, percentage):
    """Print a short report about a single classified image.

    Args:
        path: Full path of the image file.
        classes: The true class label of the image.
        output: The class label predicted for the image.
        percentage: The class probabilities reported for the image.
    """
    report_rows = (
        ("Full path: ", path),
        ("True: ", classes),
        ("Predicted: ", output),
        ("Probability: ", percentage),
    )
    print("Less than threshold!")
    for caption, value in report_rows:
        print(caption, value)
    # Trailing blank lines separate consecutive reports.
    print("\n")

def plotCM(y_true, y_pred, ax1, phase, threshold, plot_graph = False):
    """Compute accuracy for a set of predictions and optionally plot the
    row-normalized confusion matrix as a heatmap.

    The matrix is always built over the four known classes, in a fixed order,
    so rows/columns are labeled correctly even when some class does not occur
    in the data. Labels outside those four classes are not counted.

    Args:
        y_true: Sequence of true class names.
        y_pred: Sequence of predicted class names (same length as y_true).
        ax1: Matplotlib axes to draw on, or None to use the current axes.
        phase: Name of the dataset split; only used in the plot title.
        threshold: Threshold shown in the plot title; -1 means "no threshold".
        plot_graph: When true, draw the heatmap.

    Returns:
        (accuracy, total): fraction of correct predictions (NaN when there
        are no samples) and the number of samples counted in the matrix.
    """
    class_titles = ['grass', 'high_tillage', 'low_tillage', 'no_tillage']

    if len(y_true) == 0:
        # Nothing to evaluate: use an all-zero matrix instead of asking
        # sklearn to infer labels from empty input.
        cf_matrix = np.zeros((len(class_titles), len(class_titles)), dtype=int)
    else:
        cf_matrix = confusion_matrix(y_true, y_pred, labels=class_titles)

    # Normalize each row by its support; rows without samples stay all-zero
    # instead of producing NaN from a division by zero.
    row_totals = cf_matrix.sum(axis=1, keepdims=True)
    normalized = np.divide(cf_matrix, row_totals,
                           out=np.zeros(cf_matrix.shape, dtype=float),
                           where=row_totals != 0)
    df_cm = pd.DataFrame(normalized, index=class_titles, columns=class_titles)

    total_correct = np.sum(cf_matrix.diagonal())
    total = cf_matrix.sum()
    accuracy = total_correct / total if total else float("nan")
    print("Number of correct data: ", total_correct)
    print("Total number of data: ", total)
    print("Accuracy: ", accuracy)

    if plot_graph:
        # Title and heatmap must land on the same axes; fall back to the
        # current axes when the caller did not supply one.
        axis = ax1 if ax1 is not None else plt.gca()
        if threshold == -1:
            axis.set_title('%s Confusion Matrix: Tillage Classification without threshold' % (phase))
        else:
            print(cf_matrix)
            axis.set_title('%s Confusion Matrix: Tillage Classification with threshold %f' % (phase, threshold))
        sn.heatmap(df_cm, ax=axis, annot=True)

    return accuracy, total

    
def plotThreshold(threshold_high_correct, threshold_high_wrong, threshold_low_correct, threshold_low_wrong, ax1, phase, threshold, plot_graph = False):
    """Draw a pie chart of the four threshold-outcome counts and show it.

    Does nothing unless plot_graph is true.

    Args:
        threshold_high_correct: Count for the 'Threshold High Correct' slice.
        threshold_high_wrong: Count for the 'Threshold High Wrong' slice.
        threshold_low_correct: Count for the 'Threshold Low Correct' slice.
        threshold_low_wrong: Count for the 'Threshold Low Wrong' slice.
        ax1: Matplotlib axes the pie chart is drawn on.
        phase: Name of the dataset split, used in the title.
        threshold: Threshold value, used in the title.
        plot_graph: When true, draw and show the chart.
    """
    if not plot_graph:
        return

    slice_counts = {
        'Threshold High Correct': threshold_high_correct,
        'Threshold High Wrong': threshold_high_wrong,
        'Threshold Low Correct': threshold_low_correct,
        'Threshold Low Wrong': threshold_low_wrong,
    }
    # No slice is pulled out of the pie.
    no_offsets = (0,) * len(slice_counts)

    ax1.pie(list(slice_counts.values()), explode=no_offsets,
            labels=tuple(slice_counts), autopct='%1.1f%%',
            shadow=True, startangle=90)
    # Equal aspect ratio so the pie is drawn as a circle.
    ax1.axis('equal')
    ax1.set_title("Percentage analysis for %s with threshold of %f" % (phase, threshold))
    plt.show()
    

In [None]:
# Pre-processing applied to every image before it is fed to the model:
# resize to a fixed 576x768 size, convert to a tensor, then normalize each
# of the three channels with the mean/std below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(size=(576, 768), antialias=True),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
image_data_transforms = transforms.Compose(_preprocessing_steps)

# thresholds = [0.5, 0.6, 0.7, 0.8, 0.9, 0.95]

def _predict_probabilities(image_path):
    """Return the model's class probabilities for one image file.

    The result is a NumPy array of shape (1, number_of_model_outputs): the
    softmax of the model output for a batch containing only this image.
    """
    # Local import: a bare ``import PIL`` does not by itself guarantee that
    # the ``PIL.Image`` submodule has been loaded.
    from PIL import Image

    # The context manager closes the file handle once the pixels are read.
    with Image.open(image_path) as im:
        # Force three channels so the 3-channel normalization also works for
        # grayscale or RGBA files.
        image = image_data_transforms(im.convert("RGB"))

    batch = image.unsqueeze(0).to(device)
    # Inference only: do not record an autograd graph for every image.
    with torch.no_grad():
        logits = model_ft(batch)
    # softmax == exp(x) / sum(exp(x)), computed in a numerically stable way.
    return torch.softmax(logits, dim=1).cpu().numpy()


def thresholdAnalysis(threshold_value, plot_graph = False):
    """Evaluate the model on every split, keeping only confident predictions.

    For each of the "train", "val" and "test" folders under DATA_DIR, every
    image in every class sub-folder is classified. A prediction is kept only
    if its highest class probability is at least ``threshold_value``; macro
    F1 and accuracy are then computed over the kept predictions.

    Args:
        threshold_value: Minimum top-class probability for a prediction to be
            kept. Confident-but-wrong predictions are printed via printResult
            when this is >= 0.8.
        plot_graph: When true, draw the confusion matrix for each split and
            save it under ./corn1/.

    Returns:
        (acc, f_scores, supports): three lists with one entry per split, in
        the order train, val, test. ``supports`` holds the sample count
        reported by plotCM for the kept predictions. If no prediction in a
        split reaches the threshold, that split gets NaN, NaN and 0.
    """
    class_map = {
        0: 'grass',
        1: 'high_tillage',
        2: 'low_tillage',
        3: 'no_tillage'
    }
    figure_dir = "./corn1"

    acc, f_scores, supports = [], [], []

    for folders in ["train", "val", "test"]:
        print(folders)

        y_pred_threshold = []
        y_true_threshold = []

        phase_dir = os.path.join(DATA_DIR, folders)
        # Sorted for a reproducible order; stray files that are not class
        # folders (e.g. .DS_Store) are skipped.
        class_dirs = sorted(
            entry for entry in os.listdir(phase_dir)
            if os.path.isdir(os.path.join(phase_dir, entry))
        )

        for classes in class_dirs:
            class_dir = os.path.join(phase_dir, classes)
            for files in sorted(os.listdir(class_dir)):
                fullpath = os.path.join(class_dir, files)

                probabilities = _predict_probabilities(fullpath)
                output = class_map[int(np.argmax(probabilities))]

                # Compare the unrounded probability: rounding first would let
                # e.g. 0.896 pass a 0.9 threshold.
                if float(np.max(probabilities)) < threshold_value:
                    continue

                y_pred_threshold.append(output)
                y_true_threshold.append(classes)
                if output != classes and threshold_value >= 0.8:
                    # Probabilities are rounded for display only.
                    printResult(fullpath, classes, output,
                                np.around(probabilities, 2))

        if not y_true_threshold:
            # Metrics are undefined without samples; record placeholders so
            # the per-split lists keep their length.
            print("No prediction reached the threshold for this split.")
            acc.append(float("nan"))
            f_scores.append(float("nan"))
            supports.append(0)
            continue

        if plot_graph:
            plt.figure(figsize = (7, 7))

        f_score = f1_score(y_true_threshold, y_pred_threshold, average='macro')
        print(f"f1 score: {f_score}")
        f_scores.append(f_score)

        # No explicit axes: the confusion matrix is drawn on the figure
        # created above.
        accuracy, support = plotCM(y_true_threshold, y_pred_threshold, None,
                                   folders, threshold_value, plot_graph)
        acc.append(accuracy)
        supports.append(support)

        if plot_graph:
            # Only save when something was drawn; savefig does not create
            # missing directories.
            os.makedirs(figure_dir, exist_ok=True)
            plt.savefig("./corn1/%s_cfm_%d.png" % (folders, int(threshold_value * 10)), facecolor="white")

    return acc, f_scores, supports


### Testing different thresholds

In [None]:
accs = [0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
fs = [0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
sups = [0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


In [None]:
t = 0
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
accs[0] = acc
fs[0] = f
sups[0] = sup

In [None]:
t = 0.4
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 1
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
# a = thresholdAnalysis(0.5)
t = 0.5
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 2
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
# a = thresholdAnalysis(0.5)
threshold = [0.5, 0.6, 0.7, 0.8, 0.9]
t = 0.6
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 3
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
# a = thresholdAnalysis(0.5)
threshold = [0.5, 0.6, 0.7, 0.8, 0.9]
t = 0.7
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 4
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
# a = thresholdAnalysis(0.5)
threshold = [0.5, 0.6, 0.7, 0.8, 0.9]
t = 0.8
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 5
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
t = 0.9
print("Threshold %f" % (t))
acc, f, sup = thresholdAnalysis(t, True)
idx = 6
accs[idx] = acc
fs[idx] = f
sups[idx] = sup

In [None]:
a = np.array(accs)
fm = np.array(fs)
s = np.array(sups)
print(accs)
print(fm)
print(sups)

### Generate figures

In [None]:
# Threshold values on the x-axis; must match the order in which the rows of
# a, fm and s were filled.
t = [0, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
FIGURE_SAVE_PATH = "./corn/"
# savefig does not create missing directories.
os.makedirs(FIGURE_SAVE_PATH, exist_ok=True)

# Column i of a / fm / s holds the values for the i-th split.
for i, phase in enumerate(['train', 'val', 'test']):
    plt.figure("%s phase" % (phase))
    plt.title("%s phase Accuracy & F1-macro vs. Thresholds" % (phase))
    plt.plot(t, a[:, i], label="accuracy")
    plt.plot(t, fm[:, i], label="macro-f1")
    plt.xlabel("Threshold value")
    # The axis carries both accuracy and macro-F1.
    plt.ylabel("Score")
    plt.legend()
    # Join with the directory itself; prefixing "." to "./corn/" would write
    # into the parent directory ("../corn/").
    plt.savefig(os.path.join(FIGURE_SAVE_PATH, f"_graph_{phase}_acc_macro.png"),
                facecolor="white")
    plt.show()

    plt.figure()
    plt.title("%s phase Support vs. Thresholds" % (phase))
    plt.plot(t, s[:, i], label="support")
    plt.xlabel("Threshold value")
    plt.ylabel("Support")
    # Add the legend before saving so it is part of the saved image.
    plt.legend()
    plt.savefig(os.path.join(FIGURE_SAVE_PATH, f"_graph_{phase}_support.png"),
                facecolor="white")
    plt.show()
    print(a[:, i])
    print(fm[:, i])
    print(s[:, i])

