# Preprocessing
1. Take the full image as input.
2. Pad the image by 32 px on all sides.
3. Add extra padding so that the image height and width are both multiples of 128.
4. Initialize a global probability map with the same spatial size (H x W) as the input image and one channel per class.

# Bookkeeping
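1. Generate tiles with a nested loop over rows and columns, using a step size of 32 px.
2. Track the correspondence between global (full-image) and local (tile) pixel indices.
3. Add each tile's predicted probabilities to the global probability map at the corresponding location.
4. Normalize the accumulated probabilities at each pixel by dividing by their sum over the class channels, i.e. $P_{h,w,c} \leftarrow P_{h,w,c} / \sum_{c'} P_{h,w,c'}$. This takes advantage of the softmax output, which sums to 1 over the classes at every pixel of every tile.
5. Take the channel-wise argmax at each pixel to obtain the class predictions.
6. Compare the predictions with the ground truth (GT) using IoU or other metrics.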

In [9]:
# Import necessary packages
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image

# update path
sys.path.append("../")

# custom
from utils.GetLowestGPU import GetLowestGPU
from utils.TileGenerator import TileGenerator
from utils.BuildUNet import BuildUNet

In [35]:
# Load model
# Build the multiclass UNet architecture. Its freshly initialised parameters
# are overwritten when the checkpoint's state dict is loaded via
# `unet.load_state_dict(...)`.
model_kwargs = dict(
    input_features=3,                           # 3-channel input image
    output_features=4,                          # 4 output channels, one per class
    layers=[32, 64, 128],                       # feature widths per UNet level
    n_convs=3,
    dropout_rate=0.1,
    use_batchnorm=True,
    hidden_activation=torch.nn.SELU(),
    # Softmax over dim 1 (the channel dim of the NCHW Conv2d output), so the
    # class scores at every pixel sum to 1.
    output_activation=torch.nn.Softmax(dim=1),
)

unet = BuildUNet(**model_kwargs)

# NOTE(review): `opt` and `loss_function` are not used by the visible
# inference cells; presumably kept from the training setup.
opt = torch.optim.AdamW(params=unet.parameters(), lr=1e-3)
loss_function = torch.nn.BCELoss()


In [40]:
# load model weights
# Path to the trained weights. Two formats are supported:
#   * a full training checkpoint: a dict whose 'model' entry holds the state
#     dict (e.g. "../checkpoints/checkpoint_61000.pt");
#   * a bare state dict (e.g. "../weights/pennycress_multiclass_best_val_model.pt").
# Switching between them only requires changing this path.
checkpoint_path = "../checkpoints/checkpoint_61000.pt"

# map_location loads every tensor onto the CPU regardless of the device it was
# saved from.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Extract the model weights: training checkpoints wrap them under 'model',
# while a bare state dict can be passed to load_state_dict as-is.
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint

unet.load_state_dict(weights)
unet.eval()  # inference mode: dropout disabled, batch norm uses running statistics

BuildUNet(
  (hidden_activation): SELU()
  (output_activation): Softmax(dim=1)
  (input_conv): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SELU()
    (3): Dropout2d(p=0.1, inplace=False)
    (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): SELU()
    (7): Dropout2d(p=0.1, inplace=False)
    (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): SELU()
    (11): Dropout2d(p=0.1, inplace=False)
  )
  (encoder_ops): ModuleList(
    (0): Sequential(
      (0): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_