# Classification
### Step 1: resize/crop images
Requires having run: `beescrape.py`
Depends on yolov5 and `shutil`

In [2]:
# init
import os
from os import listdir
from shutil import copyfile
import shutil
from PIL import Image
import skimage
import random
import numpy as np

# Directories
DATA_DIR = '/home/wilber/Documents/RESEARCH/research/beespotter'
OUT_DIR = '/m2docs/res/data'
CROPPED_PATH = '/m2docs/res/cropped_imgs'
cropped_files = listdir(CROPPED_PATH)

# Args:
VAL_SIZE = 256
TEST_SIZE = 128
preprocess = True
resize = True
length = 512
size = (length, length) # (512, 512)

# Uncropped args: (use the same images except without passing through yolo bee finder for comparison)
OUT_DIR_RAW = '/m2docs/res/data_raw'
UNCROPPED_PATH = '/m2docs/res/uncropped_imgs'
COPY_UNCROPPED = True

print("resizing images to ({},{}) and copying uncropped = {}".format(length,length,COPY_UNCROPPED))

# Species with >= 1000 images:
classes = ['Apis_mellifera','Bombus_impatiens','Bombus_auricomus','Bombus_bimaculatus','Bombus_griseocollis']


resizing images to (512,512) and copying uncropped = True


In [29]:
from tqdm import tqdm
from os import listdir
from os.path import isfile, join
import cv2
from IPython.display import display

TXT_DIR = '/m2docs/res/predicted_text/' # text directory
IMG_DIR = '/m2docs/res/bees_unsorted_all' # images directory

%rm -R /m2docs/res/data/*
%rm -R /m2docs/res/data_raw/*

VAL_SIZE = 2
TEST_SIZE = 2
NEW_ROTATE = True
# print("resizing images to {}".format(size))

def rotate(filename, angle):
    if NEW_ROTATE is True:
        pic = Image.open(join(IMG_DIR,filename))
        img = np.asarray(pic)
        [img_dim_x, img_dim_y, z] = img.shape
        print("{} size: {}x{}x{} mode {}".format(filename,img_dim_x,img_dim_y,z,pic.mode))

        # Get prediction
        if (isfile(join(TXT_DIR,filename.replace('.jpg','.txt')))):
            inference = open(join(TXT_DIR,filename.replace('.jpg','.txt')),"r")

            for line in inference:
                [id, x, y, xlen, ylen] = [float(i) for i in line.split()]

                box_center_x = int(y*img_dim_x)
                box_center_y = int(x*img_dim_y)
                box_length	 = int(xlen*img_dim_x)
                box_height	 = int(ylen*img_dim_y)
        else:
            return None

        edge_length = 0
        edge_length = max(box_length, box_height)
        
        print("{} preds: {}, {}, {}, {}".format(filename, box_center_x,box_center_y,edge_length,edge_length))
        (x1, x2, y1, y2) = (int(box_center_x - edge_length),
                            int(box_center_x + edge_length), 
                            int(box_center_y - edge_length),
                            int(box_center_y + edge_length))
        print("Bounds x {}-{} y {}-{}".format(x1,x2,y1,y2))

        if x1 < 0: x1 = 0
        if y1 < 0: y1 = 0
        if x2 > img_dim_x: x2 = img_dim_x
        if y2 > img_dim_y: y2 = img_dim_y

        im_arr = np.asarray(img)[x1:x2, y1:y2]
        
        if img_dim_x < edge_length*2 or img_dim_y < edge_length*2:
            # Just do the old version then
            print("Using old cropper (img too small)")
            pic = Image.open(os.path.join(CROPPED_PATH,filename))
            return pic.rotate(angle)
        
        rotated_image = np.asarray(Image.fromarray(im_arr,pic.mode).rotate(angle))
        crop_rot = rotated_image[rotated_image.shape[0]//4:
                            3*rotated_image.shape[0]//4,
                            rotated_image.shape[0]//4:
                            3*rotated_image.shape[0]//4]
        img = Image.fromarray(crop_rot,pic.mode).resize(size)
        return img
    else:
        pic = Image.open(os.path.join(CROPPED_PATH,filename))
        return pic.rotate(angle)

# from https://gist.github.com/Prasad9/28f6a2df8e8d463c6ddd040f4f6a028a
noise_modes = [None,'salt','pepper','s&p']
NOISE_AMOUNT = .01 # default amount
def add_noise(img, mode, noise_amount = .03):
    if mode is not None:
        gimg = skimage.util.random_noise(img, mode = mode, amount = random.uniform(0,noise_amount))
        return gimg
    else:
        print("oops, you shouldn't see this")

total_skipped = 0
for label in classes:  # for each type of bee
    skipnum = 0
    classID = classes.index(label)
    
    input_path = os.path.join(DATA_DIR, label)
    im_list = os.listdir(input_path)
    random.shuffle(im_list)
    
    # Attempt to create directories:
    for a in [OUT_DIR,OUT_DIR_RAW]:
        for b in ['/train/','/test/','/valid/']:
            if not os.path.exists(a+b+label):
                os.makedirs(a+b+label)
        
    train_path = os.path.join(OUT_DIR,'train/'+ label + '/')
    test_path  = os.path.join(OUT_DIR,'test/' + label + '/')
    valid_path = os.path.join(OUT_DIR,'valid/' + label + '/')
    train_path_raw = os.path.join(OUT_DIR_RAW,'train/'+ label + '/')
    test_path_raw  = os.path.join(OUT_DIR_RAW,'test/' + label + '/')
    valid_path_raw = os.path.join(OUT_DIR_RAW,'valid/' + label + '/')
    
    index = 0
    for img in tqdm(im_list):
        if img in cropped_files:
            pic = Image.open(os.path.join(CROPPED_PATH,img))
            
            if (COPY_UNCROPPED):
                    pic2 = Image.open(os.path.join(UNCROPPED_PATH,img))
            
            if resize: #resize image
                out = pic.resize(size)
                if (COPY_UNCROPPED):
                    out2 = pic2.resize(size)
            else:
                out = pic
                if (COPY_UNCROPPED):
                    out2 = pic2
            
            if index < VAL_SIZE:
                out.save(os.path.join(valid_path,img))
            elif index < VAL_SIZE + TEST_SIZE:
                out.save(os.path.join(test_path,img))
            if index < VAL_SIZE and COPY_UNCROPPED:
                out2.save(os.path.join(valid_path_raw,img))
            elif index < VAL_SIZE + TEST_SIZE and COPY_UNCROPPED:
                out2.save(os.path.join(test_path_raw,img))    
            else:    #training set, rotate
                out.save(os.path.join(train_path,img.replace('.jpg','-0.jpg')))
                if COPY_UNCROPPED:
                    out2.save(os.path.join(train_path_raw,img.replace('.jpg','-0.jpg')))
                if preprocess:
                    for rot,mode,ext in zip([random.randint(-44,45),random.randint(-90,90),random.randint(0,359)],[random.choice(noise_modes),random.choice(noise_modes),random.choice(noise_modes)],[1,2,3]):
                        print(' ',img,rot,mode,ext,pic.mode)
                        if mode is not None:
                            rotimg = rotate(img,rot)
                            if rotimg is not None:
                                Image.fromarray((add_noise(np.array(rotimg),mode,NOISE_AMOUNT)*255).astype(np.uint8),pic.mode).save(os.path.join(train_path,img.replace('.jpg','-'+str(ext)+'.jpg')))
                            else:
                                print("Rotation failed.")
                            if COPY_UNCROPPED:
                                Image.fromarray((add_noise(np.array(out2.rotate(rot)),mode,NOISE_AMOUNT)*255).astype(np.uint8),pic2.mode).save(os.path.join(train_path_raw,img.replace('.jpg','-'+str(ext)+'.jpg')))
                        else:
                            rotimg = rotate(img,rot)
                            if rotimg is not None:
                                rotimg.save(os.path.join(train_path,img.replace('.jpg','-'+str(ext)+'.jpg')))
                            else:
                                print("Rotation failed.")
                            if COPY_UNCROPPED:
                                out2.rotate(rot).save(os.path.join(train_path_raw,img.replace('.jpg','-'+str(ext)+'.jpg')))
                
        else:
            skipnum += 1
            index -= 1
            pass
        index += 1
#     print(label, "skipped", skipnum)
#     print(label, "total", index)
    
    total_skipped += skipnum
    
print("Images moved and rotated. Skipped ", total_skipped)

  0%|          | 4/2757 [00:00<06:56,  6.60it/s]

  3545-4.jpg -33 salt 1 RGB
3545-4.jpg size: 3072x4608x3 mode RGB
3545-4.jpg preds: 1334, 2210, 927, 927
Bounds x 407-2261 y 1283-3137
  3545-4.jpg -19 salt 2 RGB
3545-4.jpg size: 3072x4608x3 mode RGB
3545-4.jpg preds: 1334, 2210, 927, 927
Bounds x 407-2261 y 1283-3137
  3545-4.jpg 108 None 3 RGB
3545-4.jpg size: 3072x4608x3 mode RGB
3545-4.jpg preds: 1334, 2210, 927, 927
Bounds x 407-2261 y 1283-3137


  0%|          | 7/2757 [00:01<10:03,  4.56it/s]

  458-2.jpg 1 salt 1 RGB
458-2.jpg size: 301x373x3 mode RGB
458-2.jpg preds: 138, 231, 167, 167
Bounds x -29-305 y 64-398
Using old cropper (img too small)
  458-2.jpg -7 s&p 2 RGB
458-2.jpg size: 301x373x3 mode RGB
458-2.jpg preds: 138, 231, 167, 167
Bounds x -29-305 y 64-398
Using old cropper (img too small)
  458-2.jpg 7 salt 3 RGB
458-2.jpg size: 301x373x3 mode RGB
458-2.jpg preds: 138, 231, 167, 167
Bounds x -29-305 y 64-398
Using old cropper (img too small)
  7975-2.jpg -15 s&p 1 RGB
7975-2.jpg size: 3456x5184x3 mode RGB
7975-2.jpg preds: 1380, 3059, 917, 917
Bounds x 463-2297 y 2142-3976
  7975-2.jpg 24 pepper 2 RGB
7975-2.jpg size: 3456x5184x3 mode RGB
7975-2.jpg preds: 1380, 3059, 917, 917
Bounds x 463-2297 y 2142-3976
  7975-2.jpg 306 s&p 3 RGB
7975-2.jpg size: 3456x5184x3 mode RGB
7975-2.jpg preds: 1380, 3059, 917, 917
Bounds x 463-2297 y 2142-3976


  0%|          | 8/2757 [00:02<22:34,  2.03it/s]

  1608-1.jpg 39 None 1 RGB
1608-1.jpg size: 1015x651x3 mode RGB
1608-1.jpg preds: 780, 305, 592, 592
Bounds x 188-1372 y -287-897
Using old cropper (img too small)
  1608-1.jpg 9 None 2 RGB
1608-1.jpg size: 1015x651x3 mode RGB
1608-1.jpg preds: 780, 305, 592, 592
Bounds x 188-1372 y -287-897
Using old cropper (img too small)
  1608-1.jpg 251 None 3 RGB
1608-1.jpg size: 1015x651x3 mode RGB
1608-1.jpg preds: 780, 305, 592, 592
Bounds x 188-1372 y -287-897
Using old cropper (img too small)
  5822-2.jpg 0 pepper 1 RGB
5822-2.jpg size: 664x632x3 mode RGB
5822-2.jpg preds: 341, 326, 132, 132
Bounds x 209-473 y 194-458
  5822-2.jpg 57 s&p 2 RGB
5822-2.jpg size: 664x632x3 mode RGB
5822-2.jpg preds: 341, 326, 132, 132
Bounds x 209-473 y 194-458


  1%|          | 14/2757 [00:03<16:29,  2.77it/s]

  5822-2.jpg 250 s&p 3 RGB
5822-2.jpg size: 664x632x3 mode RGB
5822-2.jpg preds: 341, 326, 132, 132
Bounds x 209-473 y 194-458
  7917-1.jpg 33 s&p 1 RGB
7917-1.jpg size: 3024x4032x3 mode RGB
7917-1.jpg preds: 2177, 1762, 614, 614
Bounds x 1563-2791 y 1148-2376
  7917-1.jpg 46 s&p 2 RGB
7917-1.jpg size: 3024x4032x3 mode RGB
7917-1.jpg preds: 2177, 1762, 614, 614
Bounds x 1563-2791 y 1148-2376


  1%|          | 17/2757 [00:03<15:09,  3.01it/s]

  7917-1.jpg 101 None 3 RGB
7917-1.jpg size: 3024x4032x3 mode RGB
7917-1.jpg preds: 2177, 1762, 614, 614
Bounds x 1563-2791 y 1148-2376
  6043-1.jpg -33 s&p 1 RGB
6043-1.jpg size: 3456x4608x3 mode RGB
6043-1.jpg preds: 1823, 2289, 764, 764
Bounds x 1059-2587 y 1525-3053
  6043-1.jpg -41 s&p 2 RGB
6043-1.jpg size: 3456x4608x3 mode RGB
6043-1.jpg preds: 1823, 2289, 764, 764
Bounds x 1059-2587 y 1525-3053
  6043-1.jpg 155 None 3 RGB
6043-1.jpg size: 3456x4608x3 mode RGB
6043-1.jpg preds: 1823, 2289, 764, 764
Bounds x 1059-2587 y 1525-3053


  1%|          | 22/2757 [00:05<10:51,  4.20it/s]

  3136-1.jpg -39 s&p 1 RGB





ValueError: not enough values to unpack (expected 3, got 0)

In [6]:
import fnmatch
cropped_count = []
uncropped_count = []
for c in classes:
    print(c)
    cropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR + '/train/' + c + "/")), '*.jpg')))
    uncropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR_RAW + '/train/' + c + "/")), '*.jpg')))
    for t in ['/train/','/test/','/valid/']:
        print("images in {}: {}".format(t,len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR + t + c + "/")), '*.jpg'))))
        print("images in {}: {} (uncropped)".format(t,len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR_RAW + t + c + "/")), '*.jpg'))))
print(cropped_count)
print(uncropped_count)

Apis_mellifera
images in /train/: 3332
images in /train/: 3332 (uncropped)
images in /test/: 128
images in /test/: 128 (uncropped)
images in /valid/: 256
images in /valid/: 256 (uncropped)
Bombus_impatiens
images in /train/: 6532
images in /train/: 6532 (uncropped)
images in /test/: 128
images in /test/: 128 (uncropped)
images in /valid/: 256
images in /valid/: 256 (uncropped)
Bombus_auricomus
images in /train/: 1360
images in /train/: 1360 (uncropped)
images in /test/: 128
images in /test/: 128 (uncropped)
images in /valid/: 256
images in /valid/: 256 (uncropped)
Bombus_bimaculatus
images in /train/: 2908
images in /train/: 2908 (uncropped)
images in /test/: 128
images in /test/: 128 (uncropped)
images in /valid/: 256
images in /valid/: 256 (uncropped)
Bombus_griseocollis
images in /train/: 5788
images in /train/: 5788 (uncropped)
images in /test/: 128
images in /test/: 128 (uncropped)
images in /valid/: 256
images in /valid/: 256 (uncropped)
[3332, 6532, 1360, 2908, 5788]
[3332, 6532

In [7]:
# Flip images over axes to even out training sets:
import fnmatch
import os
from PIL import Image
import random
import skimage

classes = ['Apis_mellifera','Bombus_impatiens','Bombus_auricomus','Bombus_bimaculatus','Bombus_griseocollis']
ABS_PATH_TRAIN = '/m2docs/res/data/train'
MULTIPLIER = 1.05 # how many times more images need to be created: mult * len(largest class)
verbose = False

cropped_count = []
uncropped_count = []
for c in classes:
    print(c)
    cropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR + '/train/' + c + "/")), '*.jpg')))
    uncropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR_RAW + '/train/' + c + "/")), '*.jpg')))

print(classes)
print(cropped_count)
print(uncropped_count)

# add more images to largest class,
# and bring the others to the same count
target = max(cropped_count) * 1.1
print("Target number of images: {}".format(target))
    
print("Copying cropped images...")

noise_modes = [None,'salt','pepper','s&p']
NOISE_AMOUNT = .01 # default amount
def add_noise(img, mode, noise_amount = .03):
    if mode is not None:
        gimg = skimage.util.random_noise(img, mode = mode, amount = random.uniform(0, noise_amount))
        return gimg
    else:
        print("oops, you shouldn't see this")

for c_name in classes:
    current = cropped_count[classes.index(c_name)]
    if (verbose): 
        print(c_name, current, "->", target)
    im_list = os.listdir(os.path.join(ABS_PATH_TRAIN, c_name))
    while current < target:
        filename = random.choice(im_list)
        if(verbose):
            print(filename)
        out = Image.open(os.path.join(ABS_PATH_TRAIN, c_name, filename))
        
        rot = random.randint(0,359)
        mode = random.choice(noise_modes)
        if (verbose): 
            print(filename,rot,mode,out.mode)
        if (verbose): 
            print(filename.replace('.jpg','-x'+str(current)+'.jpg'))
        if mode is not None:
            Image.fromarray((add_noise(np.array(out.rotate(rot)),mode,NOISE_AMOUNT)*255).astype(np.uint8),pic.mode).save(os.path.join(ABS_PATH_TRAIN, c_name, filename.replace('.jpg','-x'+str(current)+'.jpg')))
        else:
            out.rotate(rot).save(os.path.join(ABS_PATH_TRAIN, c_name, filename.replace('.jpg','-x'+str(current)+'.jpg')))
        current += 1

cropped_count = []
for c in classes:
    cropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR + '/train/' + c + "/")), '*.jpg')))
print(cropped_count)
print("Done.\n")

if COPY_UNCROPPED:
    print("Copying uncropped images...")
    ABS_PATH_TRAIN_RAW = '/m2docs/res/data_raw/train'
    c_count = [len(fnmatch.filter(os.listdir(os.path.join(ABS_PATH_TRAIN_RAW,c_name)), '*')) for c_name in classes]
    print(classes)
    print(c_count)

    # add more images to largest class,
    # and bring the others to the same count
    print("Target count: n =",target)

    noise_modes = [None,'salt','pepper','s&p']
    NOISE_AMOUNT = .01 # default amount
    def add_noise(img, mode, noise_amount = .03):
        if mode is not None:
            gimg = skimage.util.random_noise(img, mode = mode, amount = random.uniform(0, noise_amount))
            return gimg
        else:
            print("oops, you shouldn't see this")

    for c_name in classes:
        current = uncropped_count[classes.index(c_name)]
        if (verbose): 
            print(c_name, current, "->", target)
        im_list = os.listdir(os.path.join(ABS_PATH_TRAIN_RAW, c_name))
        while current < target:
            filename = random.choice(im_list)
            if (verbose): 
                print(filename)
            out = Image.open(os.path.join(ABS_PATH_TRAIN_RAW, c_name, filename))

            rot = random.randint(0,359)
            mode = random.choice(noise_modes)
            if (verbose): 
                print(filename,rot,mode,out.mode)
            if (verbose): 
                print(filename.replace('.jpg','-x'+str(current)+'.jpg'))
            if mode is not None:
                Image.fromarray((add_noise(np.array(out.rotate(rot)),mode,NOISE_AMOUNT)*255).astype(np.uint8),pic.mode).save(os.path.join(ABS_PATH_TRAIN_RAW, c_name, filename.replace('.jpg','-x'+str(current)+'.jpg')))
            else:
                out.rotate(rot).save(os.path.join(ABS_PATH_TRAIN_RAW, c_name, filename.replace('.jpg','-x'+str(current)+'.jpg')))
            current += 1

    uncropped_count = []
    for c in classes:
        uncropped_count.append(len(fnmatch.filter(os.listdir(os.path.join(OUT_DIR_RAW + '/train/' + c + "/")), '*.jpg')))
    print(uncropped_count)
    print("Done.")

Apis_mellifera
Bombus_impatiens
Bombus_auricomus
Bombus_bimaculatus
Bombus_griseocollis
['Apis_mellifera', 'Bombus_impatiens', 'Bombus_auricomus', 'Bombus_bimaculatus', 'Bombus_griseocollis']
[3332, 6532, 1360, 2908, 5788]
[3332, 6532, 1360, 2908, 5788]
Target number of images: 7185.200000000001
Copying cropped images...
[7186, 7186, 7186, 7186, 7186]
Done.

Copying uncropped images...
['Apis_mellifera', 'Bombus_impatiens', 'Bombus_auricomus', 'Bombus_bimaculatus', 'Bombus_griseocollis']
[3332, 6532, 1360, 2908, 5788]
Target count: n = 7185.200000000001
[7186, 7186, 7186, 7186, 7186]
Done.


In [5]:
import torch
import torchvision
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import os, datetime, math, io
import tensorflow as tf

from torch.utils.tensorboard import SummaryWriter 
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp 

print("Libraries loaded")

Libraries loaded


## feature visualizer
* [From Sovit Ranjan Rath](https://debuggercafe.com/visualizing-filters-and-feature-maps-in-convolutional-neural-networks-using-pytorch/)

* Adapted to work with any model passed into it.

In [6]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import cv2 as cv
import argparse
from torchvision import models, transforms
from PIL import Image

def savefeatures(model, writer, image = "/m2docs/res/cropped_imgs/428-2.jpg"):
    if model is not None and writer is not None and image is not None:
        # load the model
        print(model)
        model_cpu = model.cpu()
        model_weights = [] # we will save the conv layer weights in this list
        conv_layers = [] # we will save the 49 conv layers in this list
        # get all the model children as list
        model_children = list(model_cpu.children())

        # counter to keep count of the conv layers
        counter = 0 
        # append all the conv layers and their respective weights to the list
        for i in range(len(model_children)):
            if type(model_children[i]) == nn.Conv2d:
                counter += 1
                model_weights.append(model_children[i].weight)
                conv_layers.append(model_children[i])
            elif type(model_children[i]) == nn.Sequential:
                for j in range(len(model_children[i])):
                    for child in model_children[i][j].children():
                        if type(child) == nn.Conv2d:
                            counter += 1
                            model_weights.append(child.weight)
                            conv_layers.append(child)
        print(f"Total convolutional layers: {counter}")

        # take a look at the conv layers and the respective weights
        for weight, conv in zip(model_weights, conv_layers):
            # print(f"WEIGHT: {weight} \nSHAPE: {weight.shape}")
            print(f"CONV: {conv} ====> SHAPE: {weight.shape}")

        # visualize the first conv layer filters
        plt.figure(figsize=(20, 17))
        for i, filter in enumerate(model_weights[0]):
            plt.subplot(8, 8, i+1) # (8, 8) because in conv0 we have 7x7 filters and total of 64 (see printed shapes)
            plt.imshow(filter[0, :, :].detach(), cmap='gray')
            plt.axis('off')
            plt.savefig('outputs/filter.png')
        plt.show()

        # read and visualize an image
        img = cv.imread(f"{image}")
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        plt.imshow(img)
        plt.show()
        # define the transforms
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ])
        img = np.array(img)
        # apply the transforms
        img = transform(img)
        print(img.size())
        # unsqueeze to add a batch dimension
        img = img.unsqueeze(0)
        print(img.size())

        # pass the image through all the layers
        results = [conv_layers[0](img)]
        for i in range(1, len(conv_layers)):
            # pass the result from the last layer to the next layer
            results.append(conv_layers[i](results[-1]))
        # make a copy of the `results`
        outputs = results

        # visualize 64 features from each layer 
        # (although there are more feature maps in the upper layers)
        for num_layer in range(len(outputs)):
            figure = plt.figure(figsize=(30, 30))
            layer_viz = outputs[num_layer][0, :, :, :]
            layer_viz = layer_viz.data
            print(layer_viz.size())
            for i, filter in enumerate(layer_viz):
                if i == 64: # we will visualize only 8x8 blocks from each layer
                    break
                plt.subplot(8, 8, i + 1)
                plt.imshow(filter, cmap='gray')
                plt.axis("off")
            print(f"Saving layer {num_layer} feature maps...")
            plt.savefig(f"outputs/layer_{num_layer}.png")
            writer.add_figure("features/layer_"+str(num_layer), figure)
            # plt.show()
            plt.close()
print("save_features loaded.")

save_features loaded.


## Train classifier
Based on [this](https://nextjournal.com/gkoehler/pytorch-mnist) pytorch tutorial

Load images and train classifier.

In [23]:
length = 256 #256
skip_all = False

classes = ['Apis_mellifera','Bombus_impatiens','Bombus_auricomus','Bombus_bimaculatus','Bombus_griseocollis']
#classes = ['Bombus_auricomus','Bombus_bimaculatus','Bombus_griseocollis']

ABS_PATH_TRAIN = '/m2docs/res/data/train'
ABS_PATH_VALID = '/m2docs/res/data/valid'
ABS_PATH_TEST = '/m2docs/res/data/test'
ABS_PATH_TRAIN_RAW = '/m2docs/res/data_raw/train'
ABS_PATH_VALID_RAW = '/m2docs/res/data_raw/valid'
ABS_PATH_TEST_RAW = '/m2docs/res/data_raw/test'

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def load(dir_name, batch_size, shuffle = False):
    return(
        #create a data loader
        torch.utils.data.DataLoader(
            datasets.ImageFolder(root = dir_name, 
                                 transform = transforms.Compose([
                                     transforms.Resize((length,length)),
                                     transforms.ToTensor()
                                 ])),
            batch_size = batch_size,
            num_workers = 8,
            shuffle = shuffle,
        )
    )

class Net1(nn.Module):
#     def __init__(self, hparams):
#         super(Net, self).__init__()
#         self.pool = nn.MaxPool2d(2, 2)
#         self.dropout = nn.Dropout2d(p = hparams[HP_DROPOUT])
        
#         self.conv1 = nn.Conv2d(3, hparams[HP_NUM_UNITS]//32, 2, padding = 2)
#         self.conv2 = nn.Conv2d(hparams[HP_NUM_UNITS]//32, hparams[HP_NUM_UNITS]//32, 5, padding = 4) 
#         self.conv3 = nn.Conv2d(hparams[HP_NUM_UNITS]//32, hparams[HP_NUM_UNITS]//16, 2, padding = 1)
#         self.conv4 = nn.Conv2d(hparams[HP_NUM_UNITS]//16, hparams[HP_NUM_UNITS]//8, 5, padding = 4)
#         self.conv5 = nn.Conv2d(hparams[HP_NUM_UNITS]//8, hparams[HP_NUM_UNITS], 3) 
                
#         self.fc1 = nn.Linear(hparams[HP_NUM_UNITS]*34**2, (hparams[HP_NUM_UNITS]*34**2)//2)
#         self.fc2 = nn.Linear((hparams[HP_NUM_UNITS]*34**2)//2, length)
#         self.fc3 = nn.Linear(length, len(classes))

    def __init__(self, hparams):
        super(Net1, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(p = hparams[HP_DROPOUT])
        
        self.conv1 = nn.Conv2d(3, hparams[HP_NUM_UNITS]//8, 2)
        self.conv2 = nn.Conv2d(hparams[HP_NUM_UNITS]//8, hparams[HP_NUM_UNITS]//8, 2) 
        self.conv3 = nn.Conv2d(hparams[HP_NUM_UNITS]//8, hparams[HP_NUM_UNITS]//4, 2)
        self.conv4 = nn.Conv2d(hparams[HP_NUM_UNITS]//4, hparams[HP_NUM_UNITS]//4, 2)
        self.conv5 = nn.Conv2d(hparams[HP_NUM_UNITS]//4, hparams[HP_NUM_UNITS], 2) 
                
        self.fc1 = nn.Linear(hparams[HP_NUM_UNITS]*30**2, length*4)
        self.fc2 = nn.Linear(length*4, length//2)
        self.fc3 = nn.Linear(length//2, len(classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(F.relu(self.conv5(x)))
        #print(x.size())
        x = x.view(x.size(0),-1)
        #print(x.size())
        x = F.relu(self.dropout(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return x

class Net2(nn.Module):
    def __init__(self, hparams):
        super(Net2, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(p = hparams[HP_DROPOUT])
        
        self.conv1 = nn.Conv2d(3, hparams[HP_NUM_UNITS]//4, 2)
        self.conv2 = nn.Conv2d(hparams[HP_NUM_UNITS]//4, hparams[HP_NUM_UNITS]//2, 2) 
        self.conv3 = nn.Conv2d(hparams[HP_NUM_UNITS]//2, hparams[HP_NUM_UNITS], 2)
        
        self.fc1 = nn.Linear(hparams[HP_NUM_UNITS]*31**2, length)
        self.fc2 = nn.Linear(length, len(classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        #print(x.size())
        x = x.view(x.size(0),-1)
        #print(x.size())
        x = F.relu(self.dropout(self.fc1(x)))
        x = F.relu(self.fc2(x))
        return x

def validate(network,device,load_valid,optimizer,criterion = nn.CrossEntropyLoss()):
    network.train().to(device)
    correct = 0
    valid_loss = 0
    for index, data in enumerate(load_valid, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = network(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        valid_loss += loss.item()

        # gather accuracy stats:
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).float().sum().item()
    print(" Validation correct: {} / {}".format(correct,len(load_valid.dataset)))  
    accuracy = 100 * correct / len(load_valid.dataset)
    valid_loss = valid_loss / len(load_valid.dataset)
    return accuracy, valid_loss

def test(network,device,load_test,criterion = nn.CrossEntropyLoss()):
    network.eval().to(device)
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for index, data in enumerate(load_test, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            # forward + backward + optimize
            outputs = network(inputs).to(device)
            test_loss += nn.functional.nll_loss(outputs, labels).item()

            # gather accuracy stats:
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).float().sum().item()
    print(" Testing correct: {} / {}".format(correct,len(load_test.dataset)))        
    accuracy = 100 * correct / len(load_test.dataset)
    valid_loss = test_loss / len(load_test.dataset)
    return (accuracy, test_loss)

def train(max_epochs, min_epochs, epoch_stretch, batch_size, train_path, valid_path, test_path, labels, hparams, writer):
    epochs = max_epochs
    class_names = labels
    num_classes = len(class_names)
    train_batch = batch_size
    test_batch = batch_size
    SAVE_PATH = '/m2docs/res/trained_models/model'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))
    if (hparams[HP_NETWORK] == "Net1"):
        print(" Loading network 1")
        net = Net1(hparams).to(device)
    else:
        print(" Loading network 2")
        net = Net2(hparams).to(device)
    print(net)
    print("\n")
    
    tag = datetime.datetime.now().strftime(".%Y%m%d-%H%M%S")
    
    VISUAL_OUT = '/m2docs/res/visualizations/' + tag + "/"
    if not os.path.exists(VISUAL_OUT):
        os.makedirs(VISUAL_OUT)

    criterion = nn.CrossEntropyLoss()
    
    # Choose optimizer: (from ['adam','sgd','adagrad'])
    #optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    if (hparams[HP_OPTIMIZER] == 'adam'):
        optimizer = optim.Adam(net.parameters(), lr = .001)
    elif (hparams[HP_OPTIMIZER] == 'sgd'):
        optimizer = optim.SGD(net.parameters(), lr = .001, momentum=0.1)
    elif (hparams[HP_OPTIMIZER] == 'adagrad'):
        optimizer = optim.Adagrad(net.parameters(), lr = .001)
    else:
        # default
        optimizer = optim.Adam(net.parameters(), lr = 0.001)

    torch.manual_seed(417)
    
    load_train = load(train_path, batch_size, shuffle=True)
    load_valid = load(valid_path, batch_size, shuffle=True)
    load_test  = load(test_path, batch_size, shuffle=True)

    validation_accuracies = []
    best_epoch = 0
    epoch = 0
    killed = False
    print(type(list(net.children())[2].weight))
    while (epoch <= best_epoch + epoch_stretch or epoch < min_epochs) and epoch < max_epochs and not killed:
        net.train()
        run_loss = 0.0
        sum_loss = 0.0
        count = 0
        correct = 0.0
        categorical_correct = [0.0 for i in range(num_classes)]
        for index, data in enumerate(load_train, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            # gather accuracy stats:
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).float().sum().item()
            for i in range(num_classes):
                categorical_correct[i] += ((predicted==i) == (labels==i)).float().sum().item()
                #print(classes[i],categorical_correct[i]/((index+1)*batch_size))
            
            # print statistics
            run_loss += loss.item()
            sum_loss += loss.item()
            count += 1 
            if index % 100 == 0:    # print every 100 mini-batches
                print('  Epoch: {} [{}/{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}%'.format(
                    epoch, int(correct), (index + 1) * batch_size, len(load_train.dataset),
                    100. * index / len(load_train), loss.item(), 100. * correct / ((index+1) * batch_size)))
                run_loss = 0.0
        accuracy = 100. * correct / len(load_train.dataset)
        print('Epoch: {}\tLoss: {:.6f}\tAcc: {:.6f}'.format(
                epoch, sum_loss/count, accuracy))
        writer.add_scalar("Loss/train", sum_loss/count, epoch)
        writer.add_scalar("Acc/train", accuracy, epoch)
        
        modules_list = iter(net.named_modules())
        next(modules_list)
        for module in modules_list:
            try:
                writer.add_histogram("Model/"+module[0]+".weights", module[1].weight, epoch)
                writer.add_histogram("Model/"+module[0]+".bias", module[1].bias, epoch)
            except:
                pass
        # Categorical accuracy:
        for i in range(num_classes):
            writer.add_scalar("Acc/" + classes[i],categorical_correct[i]/len(load_train.dataset), epoch)
        torch.save(net.state_dict(), SAVE_PATH + tag + "-progress")
        
        # get validation accuracy: 
        valid_acc, valid_loss = validate(net, device, load_valid, optimizer, criterion)
        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("Acc/valid", valid_acc, epoch)
        print('Validation: acc: {:.6f}%\tloss: {:.6f}'.format(
                valid_acc, valid_loss))
        validation_accuracies.append(valid_acc)
        best_epoch = validation_accuracies.index(max(validation_accuracies))
        
        # get test accuracy: 
        test_acc, test_loss = test(net, device, load_test, criterion)
        writer.add_scalar("Loss/test", test_loss, epoch)
        writer.add_scalar("Acc/test", test_acc, epoch)
        print('Testing: acc: {:.6f}%\tloss: {:.6f}'.format(
                test_acc, test_loss))
        
        # this is the best epoch so far, save these weights:
        if (best_epoch == epoch):
            torch.save(net.state_dict(), SAVE_PATH + tag + "-best")
        if (epoch >= 2):
            if (validation_accuracies[epoch] == 100/num_classes):
                killed = True
                print("[!] This run has failed, accuracies are bad. Aborting.")
        
        print('Best: {} @ {:.6f}% -> epoch target {}'.format(best_epoch,validation_accuracies[best_epoch],max([best_epoch+epoch_stretch,min_epochs])))
        epoch += 1
        
    print('Done training.')
    torch.save(net.state_dict(), SAVE_PATH + tag + "-final")
    
    ## TODO: test here

    prediction_list = torch.zeros(0,dtype=torch.long).to(device)
    label_list = torch.zeros(0,dtype=torch.long).to(device)
    
    ## Testing
    net.eval().to(device)
    correct = 0
    valid_loss = 0
    with torch.no_grad():
        for index, data in enumerate(load_test, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            outputs = net(inputs)
            valid_loss += criterion(outputs, labels).item()

            _, predicted = torch.max(outputs.data, 1)
            prediction_list = torch.cat([prediction_list, predicted.view(-1)])  
            label_list = torch.cat([label_list, labels.view(-1)])
            #print("Predictions, ground:")
            #print(prediction_list)
            #print(label_list)
            correct += (predicted == labels).float().sum().item()
    print(" Testing correct: {} / {}".format(correct,len(load_test.dataset)))        
    t_acc = 100 * correct / len(load_test.dataset)
    t_loss = valid_loss / len(load_test.dataset)

    
    matrix = confusion_matrix(label_list.cpu().numpy(), prediction_list.cpu().numpy())
    print(matrix)
    fig, ax = plt.subplots()
    im = ax.imshow(matrix)
    # We want to show all ticks...
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(classes)):
        for j in range(len(classes)):
            text = ax.text(j, i, "{:4.2f}".format(matrix[i, j]/sum(matrix[i])),
                           ha="center", va="center", color="w")

    ax.set_title("Testing confusion matrix (n = {})".format(len(load_test.dataset)))

    plt.show()
    writer.add_figure('Testing/conf',fig)
    
    if (hparams[HP_YOLOCROPPED]):
        # cropped images
        savefeatures(net, writer, image = "/m2docs/res/data/test/Apis_mellifera/2001-1.jpg")
    else:
        savefeatures(net, writer, image = "/m2docs/res/data_raw/test/Apis_mellifera/2001-1.jpg")
    
    ## Visualise features
#     FV = FilterVisualizer(net,VISUAL_OUT)
#     image_out = reconstructions_single_layer(list((net.children()))[2],'Layer 1 Block 1 Conv1',
#                                              list(range(6,12)),n_cols=3,
#                                              save_fig=True,album_hash=None)
    
    class_accuracy=100*matrix.diagonal() / matrix.sum(1)
    #print(classes)
    #print(class_accuracy)
    for i in range(len(classes)):
        print("{}: {:.4f}".format(classes[i],class_accuracy[i]))
    print("Best val_acc: {:6.4f}".format(max(validation_accuracies)))
    return max(validation_accuracies)

# magic here.
def run(run_dir, hparams):
    with tf.summary.create_file_writer(run_dir).as_default():
        hp.hparams(hparams)  # record the values used in this trial
        run_writer = SummaryWriter(log_dir = run_dir)            
        if (hparams[HP_YOLOCROPPED]):
            train_path = ABS_PATH_TRAIN
            valid_path = ABS_PATH_VALID
            test_path = ABS_PATH_TEST
            print("Running using cropped (yolo) images")
        else:
            train_path = ABS_PATH_TRAIN_RAW
            valid_path = ABS_PATH_VALID_RAW
            test_path = ABS_PATH_TEST_RAW
            print("Running using uncropped (plain) images")
        
        accuracy = train(max_epochs = 100, min_epochs = 15, epoch_stretch = 5, batch_size = 128, train_path = train_path, valid_path = valid_path, test_path = test_path, labels = classes, hparams = hparams, writer = run_writer)
        tf.summary.scalar(METRIC_ACCURACY, accuracy, step=1)

if (skip_all is False): 
    HP_NUM_UNITS = hp.HParam('channels', hp.Discrete([64, 128, 256]))
    HP_DROPOUT = hp.HParam('dropout', hp.RealInterval(0.1, 0.5))
    #HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['sgd','adagrad']))
    HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['sgd']))
    HP_YOLOCROPPED = hp.HParam('processed (yolo)', hp.Discrete([True, False]))
    HP_NETWORK = hp.HParam('network', hp.Discrete(["Net2","Net1"]))
    METRIC_ACCURACY = 'accuracy'
    
    hpdirname = 'runs/' + (datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + ':hparam_tuning')
    with tf.summary.create_file_writer(hpdirname).as_default():
        hp.hparams_config(
            hparams=[HP_NUM_UNITS, HP_DROPOUT, HP_OPTIMIZER, HP_YOLOCROPPED],
            metrics=[hp.Metric(METRIC_ACCURACY, display_name='best accuracy (validation)')]
        )
    
    session_num = 0
    print("Hyperparameter tuning.")
    print("Channels: {} options: {} to {}".format(len(HP_NUM_UNITS.domain.values), HP_NUM_UNITS.domain.values[0], HP_NUM_UNITS.domain.values[len( HP_NUM_UNITS.domain.values)-1]))
    print("Dropout: {} options: {}, {}, {}".format(3, HP_DROPOUT.domain.min_value, ((HP_DROPOUT.domain.min_value + HP_DROPOUT.domain.max_value)/2), HP_DROPOUT.domain.max_value))
    print("Optimizer: {} options, including '{}'".format(len(HP_NUM_UNITS.domain.values), HP_OPTIMIZER.domain.values[0]))
    print("Cropped: {} options (True/False)".format(len(HP_NUM_UNITS.domain.values)))
    print("Network: {} or {}".format(HP_NETWORK.domain.values[0],HP_NETWORK.domain.values[1]))
    for num_units in HP_NUM_UNITS.domain.values:
        for dropout_rate in (HP_DROPOUT.domain.min_value, ((HP_DROPOUT.domain.min_value + HP_DROPOUT.domain.max_value)/2), HP_DROPOUT.domain.max_value):
            for optimizer in HP_OPTIMIZER.domain.values:
                for network in HP_NETWORK.domain.values:
                    for yolo in HP_YOLOCROPPED.domain.values:
                        hparams = {
                            HP_NUM_UNITS: num_units,
                            HP_DROPOUT: float("%0.2f"%float(dropout_rate)),
                            HP_OPTIMIZER: optimizer,
                            HP_YOLOCROPPED: yolo,
                            HP_NETWORK: network,
                        }
                        torch.cuda.empty_cache()
                        run_name = "run-%d" % session_num
                        print("-> Starting trial %s" % (run_name))
                        print({h.name: hparams[h] for h in hparams})
                        if hparams[HP_YOLOCROPPED]:
                            run(os.path.join(hpdirname, '_' + str(session_num) + '_true'), hparams)
                        else:
                            run(os.path.join(hpdirname, '_' + str(session_num)) + '_false', hparams)
                        session_num += 1

Hyperparameter tuning.
Channels: 3 options: 64 to 256
Dropout: 3 options: 0.1, 0.3, 0.5
Optimizer: 3 options, including 'sgd'
Cropped: 3 options (True/False)
Network: Net1 or Net2
-> Starting trial run-0
{'channels': 64, 'dropout': 0.1, 'optimizer': 'sgd', 'processed (yolo)': False, 'network': 'Net1'}
Running using uncropped (plain) images
Device: cuda:0
 Loading network 1
Net1(
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout2d(p=0.1, inplace=False)
  (conv1): Conv2d(3, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(8, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv4): Conv2d(16, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv5): Conv2d(16, 64, kernel_size=(2, 2), stride=(1, 1))
  (fc1): Linear(in_features=57600, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=5, bias=True)


KeyboardInterrupt: 

# 2 stage classifier

1. Train to detect *Apis Mellifera* vs. *Bombus Auricomus* vs. a merged set of the other classes.
2. Train a second model to tell apart the other three classes.'
3. Note that both of these datasets will by default be unbalanced unless something is changed above.
4. When running detections, if the first model's confidence is below a threshold, apply the second model to see if there is an improved prediction.

`todo: balance datasets`

In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import os, datetime, math

from torch.utils.tensorboard import SummaryWriter 
length = 512
var_droupout = 0.2
skip_all = True
tag1 = ':' + str(length) + '_STAGE-1_' + str(var_droupout)
tag2 = ':' + str(length) + '_STAGE-2_' + str(var_droupout)
writer1 = SummaryWriter(log_dir = os.path.join('runs/',(datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + tag1)))
writer2 = SummaryWriter(log_dir = os.path.join('runs/',(datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + tag2)))
print("tensorboard writing"+tag1)
print("tensorboard writing"+tag2)

classes1 = ['Apis_mellifera','Bombus_impatiens','Merged']
classes2 = ['Bombus_auricomus','Bombus_bimaculatus','Bombus_griseocollis']

ABS_PATH_TRAIN1 = '/m2docs/res/data1/train'
ABS_PATH_VALID1 = '/m2docs/res/data1/valid'
ABS_PATH_TEST1 = '/m2docs/res/data1/test'
ABS_PATH_TRAIN2 = '/m2docs/res/data2/train'
ABS_PATH_VALID2 = '/m2docs/res/data2/valid'
ABS_PATH_TEST2 = '/m2docs/res/data2/test'

class SaveFeatures():
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = output
    def close(self):
        self.hook.remove()

class FilterVisualizer():
        def __init__(self, network, OUTPUT_DIR):
                self.model = nn.Sequential(*list(network.children())[:-2]).cuda().eval()
                self.network = network
                self.OUTPUT_DIR = OUTPUT_DIR

        def visualize(self, sz, layer, filter, upscaling_steps=12, upscaling_factor=1.2, lr=0.1, opt_steps=20, blur=None, save=False, print_losses=False):
                with Torch.no_grad():
                    img = (np.random.random((sz, sz, 3)) * 20 + 128.)/255.
    #                img = np.random.uniform(0, 1, size=(sz, sz, 3)).astype(np.float32)
    #                median_filter_size = 4 if sz < 100 else 8
    #                img = scipy.ndimage.filters.median_filter(img, [median_filter_size,median_filter_size,1])

                    activations = SaveFeatures(layer)  # register hook

                    for i in range(upscaling_steps):  # scale the image up upscaling_steps times
                            train_tfms, val_tfms = tfms_from_model(network, sz)
                            img_var = V(val_tfms(img)[None], requires_grad=True)  # convert image to Variable that requires grad
                            optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
                            if i > upscaling_steps/2:
                                    opt_steps_ = int(opt_steps*1.3)
                            else:
                                    opt_steps_ = opt_steps
                            for n in range(opt_steps_):  # optimize pixel values for opt_steps times
                                    optimizer.zero_grad()
                                    self.model(img_var)
                                    loss = -1 * activations.features[0, filter].mean()
                                    if print_losses:
                                            if i%3==0 and n%5==0:
                                                    print(f'{i} - {n} - {float(loss)}')
                                    loss.backward()
                                    optimizer.step()
                            img = val_tfms.denorm(np.rollaxis(to_np(img_var.data),1,4))[0]
                            self.output = img
                            sz = int(upscaling_factor * sz)  # calculate new image size
                            img = cv2.resize(img, (sz, sz), interpolation = cv2.INTER_CUBIC)  # scale image up
                            if blur is not None: img = cv2.blur(img,(blur,blur))  # blur image to reduce high frequency patterns
                    activations.close()
                    return np.clip(self.output, 0, 1)
        
        def get_transformed_img(self,img,sz):
            with Torch.no_grad():
                train_tfms, val_tfms = tfms_from_model(network, sz)
                return val_tfms.denorm(np.rollaxis(to_np(val_tfms(img)[None]),1,4))[0]
        
        def most_activated(self, image, layer, limit_top=None):
            with Torch.no_grad():
                train_tfms, val_tfms = tfms_from_model(network, 224)
                transformed = val_tfms(image)

                activations = SaveFeatures(layer)  # register hook
                self.model(V(transformed)[None]);
                
                mean_act = [activations.features[0,i].mean().data.cpu().numpy()[0] for i in range(activations.features.shape[1])]
                activations.close()
                return mean_act

def plot_reconstructions_single_layer(imgs,layer_name,filters,
                                      n_cols=3,
                                      cell_size=4,save_fig=True,
                                      album_hash=None):
        n_rows = ceil((len(imgs))/n_cols)

        fig,axes = plt.subplots(n_rows,n_cols, figsize=(cell_size*n_cols,cell_size*n_rows))
                    
        for i,ax in enumerate(axes.flat):
                ax.grid(False)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

                if i>=len(filters):
                        pass

                ax.set_title(f'fmap {filters[i]}')

                ax.imshow(imgs[i])
        fig.suptitle(f'cnn {layer_name}', fontsize="x-large",y=1.0)
        plt.tight_layout()
        plt.subplots_adjust(top=0.88)
        save_name = layer_name.lower().replace(' ','_')
        if save_fig:
                plt.savefig(f'' + OUTPUT_DIR + 'network_{save_name}_fmaps_{"_".join([str(f) for f in filters])}.png')
                plt.close()
                return True
        else:
                plt.show()
                return None

def reconstructions_single_layer(layer,layer_name,filters,
                                 init_size=56, upscaling_steps=12, 
                                 upscaling_factor=1.2, 
                                 opt_steps=20, blur=5,
                                 lr=1e-1,print_losses=False,
                                 n_cols=3, cell_size=4,
                                 save_fig=True,album_hash=None):
        
        imgs = []
        for i in range(len(filters)):
                imgs.append(FV.visualize(init_size,layer, filters[i], 
                            upscaling_steps=upscaling_steps, 
                            upscaling_factor=upscaling_factor, 
                            opt_steps=opt_steps, blur=blur,
                            lr=lr,print_losses=print_losses))
                
        return plot_reconstructions_single_layer(imgs,layer_name,filters,
                                                 n_cols=n_cols,cell_size=cell_size,
                                                 save_fig=save_fig,album_hash=album_hash)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def load(dir_name, batch_size, shuffle = False):
    return(
        #create a data loader
        torch.utils.data.DataLoader(
            datasets.ImageFolder(root = dir_name, transform = transforms.ToTensor()),
            batch_size = batch_size,
            num_workers = 2,
            shuffle = shuffle
        )
    )

# Layer/network 1:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(p = var_droupout)
        
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 5)
        self.conv4 = nn.Conv2d(64, 128, 5)
        
        self.fc1 = nn.Linear(128*28*28, length)
        self.fc2 = nn.Linear(length, int(length/2))
        #self.fc3 = nn.Linear(int(length/2), int(length/4))
        self.fc3 = nn.Linear(int(length/2), 5)

    def forward(self, x):
        x = F.relu(self.pool(self.conv1(x)))
        x = self.dropout(F.relu(self.pool(self.conv2(x))))
        x = self.dropout(F.relu(self.pool(self.conv3(x))))
        x = (F.relu(self.pool(self.conv4(x))))
        #print(x.size())
        x = x.view(x.size(0),-1)
        #print(x.size())
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.dropout(self.fc3(x))
        return x
    
# Layer/network 2:    
class Net2(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout2d(p = var_droupout)
        
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 5)
        self.conv4 = nn.Conv2d(64, 128, 5)
        
        self.fc1 = nn.Linear(128*28*28, length)
        self.fc2 = nn.Linear(length, int(length/2))
        #self.fc3 = nn.Linear(int(length/2), int(length/4))
        self.fc3 = nn.Linear(int(length/2), 5)

    def forward(self, x):
        x = F.relu(self.pool(self.conv1(x)))
        x = self.dropout(F.relu(self.pool(self.conv2(x))))
        x = self.dropout(F.relu(self.pool(self.conv3(x))))
        x = (F.relu(self.pool(self.conv4(x))))
        #print(x.size())
        x = x.view(x.size(0),-1)
        #print(x.size())
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.dropout(self.fc3(x))
        return x

def validate(network,device,load_valid,criterion = nn.CrossEntropyLoss()):
    network.eval().to(device)
    correct = 0
    valid_loss = 0
    with torch.no_grad():
        for index, data in enumerate(load_valid, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            # forward + backward + optimize
            outputs = network(inputs)
            valid_loss += criterion(outputs, labels).item()

            # gather accuracy stats:
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).float().sum().item()
    accuracy = 100 * correct / len(load_valid.dataset)
    valid_loss = valid_loss / len(load_valid.dataset)
    return accuracy , valid_loss

def train(max_epochs = 50, min_epochs = 5, epoch_stretch = 5, train_path = ABS_PATH_TRAIN, valid_path = ABS_PATH_VALID, test_path = ABS_PATH_TEST, labels = classes, batch_size = 24):
    epochs = max_epochs
    class_names = labels
    num_classes = len(class_names)
    train_batch = batch_size
    test_batch = 128
    SAVE_PATH = '/m2docs/res/models'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))
    net = Net().to(device)
    print(net)

    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr = 0.001)

    torch.manual_seed(417)
    
    load_train = load(train_path, batch_size, shuffle=True)
    load_valid = load(valid_path, batch_size, shuffle=True)
    load_test  = load(test_path, batch_size, shuffle=True)

    validation_accuracies = []
    best_epoch = 0
    epoch = 0
    while (epoch <= best_epoch + epoch_stretch or epoch < min_epochs) and epoch < max_epochs:
        net.train()
        run_loss = 0.0
        sum_loss = 0.0
        count = 0
        correct = 0.0
        categorical_correct = [0.0 for i in range(num_classes)]
        for index, data in enumerate(load_train, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            # gather accuracy stats:
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).float().sum().item()
            for i in range(num_classes):
                categorical_correct[i] += ((predicted==i) == (labels==i)).float().sum().item()
                #print(classes[i],categorical_correct[i]/((index+1)*batch_size))
            
            # print statistics
            run_loss += loss.item()
            sum_loss += loss.item()
            count += 1 
            if index % 200 == 0:    # print every 200 mini-batches
                print('  Epoch: {} [{}/{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}%'.format(
                    epoch, int(correct), (index + 1) * batch_size, len(load_train.dataset),
                    100. * index / len(load_train), loss.item(), 100. * correct / ((index+1) * batch_size)))
                run_loss = 0.0
        accuracy = 100. * correct / len(load_train.dataset)
        print('Epoch: {}\tLoss: {:.6f}\tAcc: {:.6f}'.format(
                epoch, sum_loss/count, accuracy))
        writer.add_scalar("Loss/train", sum_loss/count, epoch)
        writer.add_scalar("Acc/train", accuracy, epoch)
        
        modules_list = iter(net.named_modules())
        next(modules_list)
        for module in modules_list:
            try:
                writer.add_histogram("Model/"+module[0]+".weights", module[1].weight, epoch)
                writer.add_histogram("Model/"+module[0]+".bias", module[1].bias, epoch)
            except:
                pass
        # Categorical accuracy:
        for i in range(num_classes):
            writer.add_scalar("Acc/" + classes[i],categorical_correct[i]/len(load_train.dataset), epoch)
        torch.save(net.state_dict(), SAVE_PATH+"_progress")
        
        # get validation accuracy: 
        valid_acc, valid_loss = validate(net, device, load_valid, criterion)
        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("Acc/valid", valid_acc, epoch)
        print('Validation: acc: {:.6f}%\tloss: {:.6f}'.format(
                valid_acc, valid_loss))
        validation_accuracies.append(valid_acc)
        best_epoch = validation_accuracies.index(max(validation_accuracies))
        print('Best: {} @ {:.6f}% -> epoch target {}'.format(best_epoch,validation_accuracies[best_epoch],max([best_epoch+epoch_stretch,min_epochs])))
        epoch += 1
        
    print('Done training.')
    torch.save(net.state_dict(), SAVE_PATH)

    prediction_list = torch.zeros(0,dtype=torch.long).to(device)
    label_list = torch.zeros(0,dtype=torch.long).to(device)
              
    with torch.no_grad():
        for data in load_test:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
              
            prediction_list = torch.cat([prediction_list, predicted.view(-1)])  
            label_list = torch.cat([label_list, labels.view(-1)])
    
    matrix = confusion_matrix(label_list.cpu().numpy(), prediction_list.cpu().numpy())
    print(matrix)
    fig, ax = plt.subplots()
    im = ax.imshow(matrix)
    # We want to show all ticks...
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(classes)):
        for j in range(len(classes)):
            text = ax.text(j, i, matrix[i, j],
                           ha="center", va="center", color="w")

    ax.set_title("Testing confusion matrix")

    plt.show()
    writer.add_figure('Testing/conf',fig)
    
    class_accuracy=100*matrix.diagonal() / matrix.sum(1)
    print(classes)
    print(class_accuracy)

if (skip_all is False):
    # setup: create 2-stage directories:
    # data1: first stage: mellifera, impatiens, 3-merged
    %mkdir data1
    %rm -r data1/*
    %mkdir data1/train
    %cp -r data/train/Apis_mellifera data1/train/Apis_mellifera
    %cp -r data/train/Bombus_auricomus data1/train/Bombus_auricomus
    %cp -r data/train/Bombus_bimaculatus data1/train/Merged
    %cp data/train/Bombus_griseocollis/* data1/train/Merged
    %cp data/train/Bombus_impatiens/* data1/train/Merged
    %mkdir data1/valid
    %cp -r data/valid/Apis_mellifera data1/valid/Apis_mellifera
    %cp -r data/valid/Bombus_auricomus data1/valid/Bombus_auricomus
    %cp -r data/valid/Bombus_bimaculatus data1/valid/Merged
    %cp data/valid/Bombus_griseocollis/* data1/valid/Merged
    %cp data/valid/Bombus_impatiens/* data1/valid/Merged
    %mkdir data1/test
    %cp -r data/test/Apis_mellifera data1/test/Apis_mellifera
    %cp -r data/test/Bombus_auricomus data1/test/Bombus_auricomus
    %cp -r data/test/Bombus_bimaculatus data1/test/Merged
    %cp data/test/Bombus_griseocollis/* data1/test/Merged
    %cp data/test/Bombus_impatiens/* data1/test/Merged

    # data2: auricomus vs bimaculatus vs griseocollis
    %mkdir data2
    %rm -r data2/*
    %mkdir data2/train
    %cp -r data/train/Bombus_bimaculatus data2/train/Bombus_bimaculatus
    %cp -r data/train/Bombus_griseocollis data2/train/Bombus_griseocollis
    %cp -r data/train/Bombus_impatiens data2/train/Bombus_impatiens
    %mkdir data2/valid
    %cp -r data/valid/Bombus_bimaculatus data2/valid/Bombus_bimaculatus
    %cp -r data/valid/Bombus_griseocollis data2/valid/Bombus_griseocollis
    %cp -r data/valid/Bombus_impatiens data2/valid/Bombus_impatiens
    %mkdir data2/test
    %cp -r data/test/Bombus_bimaculatus data2/test/Bombus_bimaculatus
    %cp -r data/test/Bombus_griseocollis data2/test/Bombus_griseocollis
    %cp -r data/test/Bombus_impatiens data2/test/Bombus_impatiens
    
    # train
    train(max_epochs = 160, min_epochs = 70, epoch_stretch = 15, train_path = ABS_PATH_TRAIN, valid_path = ABS_PATH_VALID, test_path = ABS_PATH_TEST, labels = classes, batch_size = 32)