# Load Results

## Import packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import gc
gc.collect()  # reclaim unreachable Python objects before loading the model

# Root directory of the image dataset -- point this at your own data.
DATA_DIR = './corn_dataset'
# Saved state_dict of the fine-tuned classifier -- point this at your own weights.
WEIGHT_DIR = './model_weight/corn_model_resize_wd.pth'

In [None]:
# Hand cached-but-unused GPU memory back to the driver; nothing to do on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Load Models

In [None]:
# load Resnet pretrained model
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features

# freeze certain parameters
for module, param in zip(model_ft.modules(), model_ft.parameters()):
	if isinstance(module, nn.BatchNorm2d):
		param.requires_grad = False

# tune the last layers
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 4)

# load the weights
model_ft.load_state_dict(torch.load(WEIGHT_DIR))
model_ft.eval()


## Load Images

In [None]:
image_data_dir = DATA_DIR
# Evaluation preprocessing: resize only (no crop/augmentation), convert to a
# tensor, then normalize with the ImageNet channel means/stds.
image_data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ImageFolder labels each image with the index of its sub-directory name in
# sorted order.
# NOTE(review): these indices must line up with the class-name order used for
# the confusion matrix -- confirm against the training setup.
image_dataset = datasets.ImageFolder(image_data_dir, image_data_transforms)
# shuffle=False: order is irrelevant for evaluation, and a fixed order keeps
# runs reproducible.
image_data_loader = torch.utils.data.DataLoader(
    image_dataset, batch_size=4, shuffle=False, num_workers=4,
)

## GPU Support

In [None]:
# Run on the first CUDA GPU when one exists, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model_ft.to(device)

## Confusion Matrices

In [None]:
y_pred = []  # predicted class index per image
y_true = []  # ground-truth class index per image

# Pure inference over the whole dataset: no_grad() stops autograd from
# keeping every batch's activations around for a backward pass.
with torch.no_grad():
    for inputs, labels in image_data_loader:
        logits = model_ft(inputs.to(device))  # Feed Network
        # Predicted class = argmax over the class dimension. Applying exp()
        # first does not change the argmax but can overflow to inf.
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())  # Save Prediction
        y_true.extend(labels.cpu().numpy())  # Save Truth

In [None]:
classes = ('grass', 'high_tillage', 'low_tillage', 'no_tillage')

# Build the confusion matrix (rows = true class, columns = predicted class).
# labels= keeps it len(classes) x len(classes) even when a class never occurs
# in y_true or y_pred, so it always matches the DataFrame labels below.
cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
print(cf_matrix)

# Row-normalize: each cell becomes the fraction of that true class predicted
# as the column's class. Rows with no samples stay 0 rather than NaN (0/0).
row_totals = cf_matrix.sum(axis=1, keepdims=True)
cf_normalized = np.divide(
    cf_matrix, row_totals,
    out=np.zeros(cf_matrix.shape, dtype=float),
    where=row_totals != 0,
)
df_cm = pd.DataFrame(cf_normalized, index=list(classes), columns=list(classes))
print(df_cm)

ax = plt.axes()
ax.set_title('Confusion Matrix: Tillage Classification')
sn.heatmap(df_cm, ax=ax, annot=True)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

## Print Misclassified Images

In [None]:
from PIL import Image

# Report every image the model misclassifies, walking
# <image_dir>/<phase>/<class_name>/<file>.
# NOTE(review): this walk expects phase sub-directories (e.g. one per data
# split) with class folders inside them, while the ImageFolder loader treats
# DATA_DIR's immediate sub-directories as the classes -- confirm which layout
# DATA_DIR actually has.

# Strip only a trailing "/" -- slicing off the last character unconditionally
# would truncate a path such as './corn_dataset'.
image_dir = DATA_DIR.rstrip("/")
SHOW_FIGURE = False  # set True to display each misclassified image
class_map = {
    0: 'grass',
    1: 'high_tillage',
    2: 'low_tillage',
    3: 'no_tillage'
}


def _subdirs(path):
    """Return the sorted names of the sub-directories of *path*, skipping files."""
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


# Inference only: no autograd graph is needed.
with torch.no_grad():
    for phase in _subdirs(image_dir):
        # Named `class_name` (not `classes`) so the class-name tuple used for
        # the confusion matrix is not overwritten by this loop.
        for class_name in _subdirs(os.path.join(image_dir, phase)):
            class_dir = os.path.join(image_dir, phase, class_name)
            for file in sorted(os.listdir(class_dir)):
                fullpath = os.path.join(class_dir, file)
                # Convert to RGB as ImageFolder's default loader does, so
                # grayscale/RGBA files still give the 3 channels Normalize
                # expects; the `with` closes the handle Image.open keeps.
                with Image.open(fullpath) as im:
                    im = im.convert("RGB")
                image = image_data_transforms(im).unsqueeze(0).to(device)
                # Predict the image class. softmax turns logits into class
                # probabilities (exp() alone does not sum to 1).
                probs = torch.softmax(model_ft(image), dim=1)[0].cpu().numpy()
                output = class_map[int(probs.argmax())]
                # if predicted wrong
                if output != class_name:
                    print("Full path: ", fullpath)
                    print("True: ", class_name)
                    print("Predicted: ", output)
                    print("Probability: ", probs)
                    if SHOW_FIGURE:
                        plt.figure()
                        plt.imshow(im)
                        plt.show()
                    print("\n")

