# Deep Pheno Net
### A convolutional neural network for classification of FHD in bread wheate
By Samuel Horovatin, s.horovatin@usask.ca

First, we must import relevant packages for data import and pre-processing. Note, images being fed to this network have been pre-processed using the phenoSEED script included in [BELT source code](https://gitlab.com/usask-speclab/phenoseed).

In [16]:
# general math libraries
import pandas as pd
import numpy as np

# for import of images and displaying images
from skimage.io import imread
import matplotlib.pyplot as plt
%matplotlib inline

# for creating training, validation, and test split
from sklearn.model_selection import train_test_split

# for evaluating the model
from sklearn.metrics import accuracy_score
# use for showing progress of loops
from tqdm import tqdm

# various pytorch libraries
import torch
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
from torch.optim import Adam, SGD

# imports for arguments and logging
import logging
from datetime import datetime, timedelta
import argparse
import os
import glob


The first step in the pipeline is to set up the logger.

In [11]:
class DeepPhenoNet:
    def __init__(self):
        self.outpath = "/birl2/users/sch923/Thesis/Data/deep_phen_out"
        self.inpath = "/birl2/users/sch923/Thesis/Data/phenoSEEDOutput"
        self.netsavepath = "/birl2/users/sch923/Thesis/Data/deep_pheno_net"
        self.labels = "birl2/users/sch923/Thesis/Data/deep_pheno_netlabels.csv"
        self.extension = '*.npz'

In [15]:
# Construct argument parser and define command line arguments

net_settings = DeepPhenoNet()
# arg_parser = argparse.ArgumentParser(description='A convolutional neural network script which classifys fusarium head blight in images of bread wheat kernels.')
# arg_parser.add_argument('-o', "--output", type=str, default=net_settings.outpath,
#                         help='path to output classification findings.')
# arg_parser.add_argument('-i', "--input", type=str, default=net_settings.inpath,
#                         help='path to input data - will be searched recursively.')
# arg_parser.add_argument('-n', "--network", type=str, default=net_settings.netsavepath,
#                         help='path load/save the trained network.')
# arg_parser.add_argument('-l', "--labels", type=str, default=net_settings.labels,
#                         help='path to csv label table. Should have headers: id, fdh')
# args = arg_parser.parse_args()

debug_level = getattr(logging, 'INFO', None)
logging.basicConfig(level=debug_level)

# Check validity of provided paths
# if not os.path.isdir(args.output):
#     logging.error(f'Provided output directory ({args.output}) cannot be found.')
#     exit()
# elif not os.path.isdir(args.input):
#     logging.error(f'Provided input directory ({args.input}) cannot be found.')
#     exit()
# elif not os.path.isdir(args.network):
#     logging.error(f'Provided network save directory ({args.network}) cannot be found.')
#     exit()
# elif not os.path.isfile(args.labels):
#     logging.error(f'Provided labels file ({args.labels}) cannot be found.')
#     exit()
# else:
#     # Slightly redundent in case where defaults are used.
#     net_settings.outpath = args.output
#     net_settings.inpath = args.input
#     net_settings.netsavepath = args.network
#     net_settings.labels = args.labels

logging.info(f' Output Directory Set: {net_settings.outpath}')
logging.info(f' Input Directory Set: {net_settings.inpath}')
logging.info(f' Network Save Directory Set: {net_settings.netsavepath}')
logging.info(f' Labels CSV found: {net_settings.labels}')

INFO:root: Output Directory Set: /birl2/users/sch923/Thesis/Data/deep_phen_out
INFO:root: Input Directory Set: /birl2/users/sch923/Thesis/Data/phenoSEEDOutput
INFO:root: Network Save Directory Set: /birl2/users/sch923/Thesis/Data/deep_pheno_net
INFO:root: Labels CSV found: birl2/users/sch923/Thesis/Data/deep_pheno_netlabels.csv


In [None]:
# Load the label metadata
labels = pd.read_csv(net_settings.labels)
images = glob.glob(os.path.join(net_settings.inpath, '**', net_settings.extension), recursive=True)
if len(images) == 0:
    logging.error(f' there are no {net_settings.extension} found in supplied directory: \n{net_settings.inpath}')
    exit()

images.sort()
   