In [None]:
from dataset.utils import pNormalize, classCount
from dataset.datasets import sentinel
from dataset.stats import quantiles
from model.models import UNET
from torch.utils.data import DataLoader
from train.utils import plots
from train.metrics import computeConfMats, computeClassMetrics, wma, printClassMetrics, printModelMetrics, plotConfusionMatrices, plotConfusionMatrix
import torch

In [None]:
# Hyperparams For Models — shared by every DataLoader built in this notebook
BATCH_SIZE: int = 10  # samples per test batch
NUM_WORKERS: int = 1  # DataLoader worker subprocesses

In [None]:
############################### MODEL 1 ###############################
# <RGB Sentinel-2 TIMEPERIOD 1>
# Evaluate the time-period-1 RGB model on its test split and accumulate the
# per-class 2x2 confusion matrices in cMats_1 (later used for late-fusion weights).

q_hi = quantiles['high']['1'][0:3]           # NB! RGB!
q_lo = quantiles['low']['1'][0:3]            # NB! RGB!
norm = pNormalize(maxPer=q_hi, minPer=q_lo)

# Create experimental dataset, rgb=True for 3 channels (default = False)
# POINT TO FOLDER WITH TIMEPERIOD(S) WITH SUBFOLDERS: 'test', 'train', 'val'
test_set = sentinel(root_dir='./', img_transform=norm, data="test", timeperiod=1, rgb=True) # NB! RGB!

# Pass in the dataset into DataLoader to create an iterable over the dataset
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Define model: peek at one batch only to read the input channel count (dim 1).
# Use the builtin next(): DataLoader iterators no longer expose a .next() method
# in recent PyTorch releases, so `iter(loader).next()` raises AttributeError.
images, labels = next(iter(test_loader))
model_1 = UNET(in_channels=images.shape[1], classes=28)
model_1.load_state_dict(torch.load('model_epoch_87.pt', map_location=torch.device('cpu')))  # trained weights for model 1

# 1 EPOCH TESTING
model_1.eval()  # evaluation mode for inference
with torch.no_grad():
    cMats_1 = torch.zeros((27, 2, 2), dtype=torch.int32)  # n_class - unclassified class, i.e. 28-1 = 27

    for images, labels in test_loader:
        outputs = model_1(images)
        preds = torch.nn.functional.softmax(outputs, dim=1)
        preds = torch.argmax(preds, dim=1)
        cMats_1 += computeConfMats(labels, preds)

In [None]:
# Compute Class IoU for Model 1.
# computeClassMetrics yields a per-class metrics table; column 4 is taken as IoU
# (column layout lives in train.metrics — NOTE(review): confirm if that module changes).
class_metrics_1 = computeClassMetrics(cMats_1)
iou_1 = class_metrics_1[:, 4]

In [None]:
############################### MODEL 2 ###############################
# <RGB Sentinel-2 TIMEPERIOD 2>
# Evaluate the time-period-2 RGB model on its test split and accumulate the
# per-class 2x2 confusion matrices in cMats_2 (later used for late-fusion weights).

q_hi = quantiles['high']['2'][0:3]           # NB! RGB!
q_lo = quantiles['low']['2'][0:3]            # NB! RGB!
norm = pNormalize(maxPer=q_hi, minPer=q_lo)

# Create experimental dataset, rgb=True for 3 channels (default = False)
# POINT TO FOLDER WITH TIMEPERIOD(S) WITH SUBFOLDERS: 'test', 'train', 'val'
test_set = sentinel(root_dir='./', img_transform=norm, data="test", timeperiod=2, rgb=True) # NB! RGB!

# Pass in the dataset into DataLoader to create an iterable over the dataset
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Define model: peek at one batch only to read the input channel count (dim 1).
# Use the builtin next(): DataLoader iterators no longer expose a .next() method
# in recent PyTorch releases, so `iter(loader).next()` raises AttributeError.
images, labels = next(iter(test_loader))
model_2 = UNET(in_channels=images.shape[1], classes=28)
model_2.load_state_dict(torch.load('model_epoch_99.pt', map_location=torch.device('cpu')))  # trained weights for model 2

# 1 EPOCH TESTING
model_2.eval()  # evaluation mode for inference
with torch.no_grad():
    cMats_2 = torch.zeros((27, 2, 2), dtype=torch.int32)  # n_class - unclassified class, i.e. 28-1 = 27

    for images, labels in test_loader:
        outputs = model_2(images)
        preds = torch.nn.functional.softmax(outputs, dim=1)
        preds = torch.argmax(preds, dim=1)
        cMats_2 += computeConfMats(labels, preds)

In [None]:
# Compute Class IoU for Model 2.
# computeClassMetrics yields a per-class metrics table; column 4 is taken as IoU
# (column layout lives in train.metrics — NOTE(review): confirm if that module changes).
class_metrics_2 = computeClassMetrics(cMats_2)
iou_2 = class_metrics_2[:, 4]

In [None]:
# Compute Late Fusion Weights (Performance Weighting).
# For each scored class c, model k gets weight IoU_k[c] / (IoU_1[c] + IoU_2[c]).
# Where IoU_1[c] + IoU_2[c] == 0 the reciprocal is skipped, leaving both weights at 0
# (that class then contributes nothing to the fused score).
iou_sum = iou_1 + iou_2
nonzero_mask = iou_sum != 0
iou_sum[nonzero_mask] = 1 / iou_sum[nonzero_mask]  # iou_sum now holds the reciprocal where defined

lf_weights_1 = torch.multiply(iou_1, iou_sum)
lf_weights_2 = torch.multiply(iou_2, iou_sum)

# NB! Label 0 is uniformly-weighted (not performance weighted): prepend 0.5 for each model,
# so index j of lf_weights_* lines up with output channel / label j.
label0_weight = torch.tensor([0.5])
lf_weights_1, lf_weights_2 = (torch.cat((label0_weight, w), dim=0) for w in (lf_weights_1, lf_weights_2))

In [None]:
################# Late Fusion: Model 1 & Model 2 Using Performance Weighted Bayesian Sum Rule #################

# Late Fusion as proposed by:
# https://github.com/alessandrosebastianelli/S1-S2-DataFusion/blob/main/Main.ipynb
#
# Per pixel, fused score for class c = lf_weights_1[c] * softmax_1[c] + lf_weights_2[c] * softmax_2[c];
# the prediction is the argmax over c.

q_hi_1 = quantiles['high']['1'][0:3]                   # NB! RGB!
q_lo_1 = quantiles['low']['1'][0:3]                    # NB! RGB!
norm_1 = pNormalize(maxPer=q_hi_1, minPer=q_lo_1)

q_hi_2 = quantiles['high']['2'][0:3]                   # NB! RGB!
q_lo_2 = quantiles['low']['2'][0:3]                    # NB! RGB!
norm_2 = pNormalize(maxPer=q_hi_2, minPer=q_lo_2)

# Create experimental dataset, rgb=True for 3 channels (default = False)
# POINT TO FOLDER WITH TIMEPERIOD(S) WITH SUBFOLDERS: 'test', 'train', 'val'
test_set_1 = sentinel(root_dir='./', img_transform=norm_1, data="test", timeperiod=1, rgb=True)  ### MODEL 1 <> NB! RGB!
test_set_2 = sentinel(root_dir='./', img_transform=norm_2, data="test", timeperiod=2, rgb=True)  ### MODEL 2 <> NB! RGB!

# Both loaders are unshuffled with the same batch size, so zipping them pairs samples by index.
# Fusion is only meaningful if index i is the same tile in both time periods
# (NOTE(review): presumably guaranteed by sentinel's file ordering — confirm).
# A length mismatch would pair different tiles, so fail loudly instead of cycling a loader.
assert len(test_set_1) == len(test_set_2), (
    f"late fusion needs aligned test sets, got {len(test_set_1)} vs {len(test_set_2)} samples"
)

# Pass in the dataset into DataLoader to create an iterable over the dataset
test_loader_1 = DataLoader(test_set_1, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)                  ### MODEL 1
test_loader_2 = DataLoader(test_set_2, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)                  ### MODEL 2

# One weight per output channel, shaped (1, C, 1, 1) to broadcast over (B, C, H, W) softmax maps.
# This weights EVERY channel; the former per-index loop ran over len(iou_sum) == C - 1
# channels and left the last class unweighted.
weights_1 = lf_weights_1.reshape(1, -1, 1, 1)
weights_2 = lf_weights_2.reshape(1, -1, 1, 1)

# 1 EPOCH TESTING
model_1.eval()  # evaluation mode for inference
model_2.eval()
with torch.no_grad():
    cMats_lf = torch.zeros((27, 2, 2), dtype=torch.int32)  # n_class - unclassified class, i.e. 28-1 = 27

    # Collect flattened per-batch tensors and concatenate once after the loop
    # (repeated torch.cat inside the loop copies everything each iteration).
    pred_chunks = []
    label_chunks = []

    for batch_idx, ((images_1, labels_1), (images_2, _labels_2)) in enumerate(zip(test_loader_1, test_loader_2)):
        softmaxOutput_1 = torch.nn.functional.softmax(model_1(images_1), dim=1)
        softmaxOutput_2 = torch.nn.functional.softmax(model_2(images_2), dim=1)

        assert softmaxOutput_1.shape[1] == weights_1.shape[1], "one late-fusion weight per output channel expected"

        softmaxWeightedSum = softmaxOutput_1 * weights_1 + softmaxOutput_2 * weights_2
        preds = torch.argmax(softmaxWeightedSum, dim=1)

        cMats_lf += computeConfMats(labels_1, preds)                 ## NB! Labels for Model 1

        # Plot predictions; batch_idx is no longer shadowed, so every batch gets its own index.
        plots(preds, labels_1, images_1, savedir='./', idx=batch_idx, source='S2')

        # Flatten dimensions BxHxW --> B*H*W
        pred_chunks.append(preds.reshape(-1))
        label_chunks.append(labels_1.reshape(-1))                   ## NB! Labels for Model 1

# Empty-loader fallback keeps the previous empty int32 tensors.
predarr = torch.cat(pred_chunks) if pred_chunks else torch.tensor([], dtype=torch.int32)
labelarr = torch.cat(label_chunks) if label_chunks else torch.tensor([], dtype=torch.int32)

In [None]:
# Get Class Counts for Dataset 1 — the same loader whose labels built cMats_lf.
# classCount returns a pair; only the first element is kept. The discard name is
# `_unused` rather than `_` so IPython's last-output variable is not overwritten.
classCounts, _unused = classCount(test_loader_1)

In [None]:
# Compute Class and Model Metrics for Late Fusion Model.
# class_metrics_lf: per-class metrics derived from the fused confusion matrices cMats_lf.
# model_metrics_lf: wma() over those metrics with classCounts — presumably a
# class-count-weighted mean; NOTE(review): confirm in train.metrics.
class_metrics_lf = computeClassMetrics(cMats_lf)
model_metrics_lf = wma(class_metrics_lf,classCounts)

In [None]:
# Print Model Metrics for Late Fusion Model (the aggregate produced by wma() from the fused confusion matrices)
printModelMetrics(model_metrics_lf)

In [None]:
# Print Class Metrics for Late Fusion Model, alongside the per-class counts from test_loader_1
printClassMetrics(class_metrics_lf,classCounts)

In [None]:
# Full N_CLASS x N_CLASS confusion matrix of the late-fusion model,
# built from the flattened per-pixel predictions and labels gathered during fusion.
plotConfusionMatrix(yPred=predarr, yTrue=labelarr)

In [None]:
# Plot Confusion Matrices for Late Fusion Model: the per-class 2x2 matrices accumulated in cMats_lf
plotConfusionMatrices(cMats_lf)