<a href="https://colab.research.google.com/github/plant-ai-biophysics-lab/DeformableCNN-PlantTraits/blob/main/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training and Evaluation Pipeline

### Data and config setup

Import libraries

In [None]:
import os
import time
import torch, torchvision
import numpy as np
import torch.nn as nn
from torch.functional import split
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split, StratifiedKFold

from torch.utils.tensorboard import SummaryWriter

from datatools import *
from engine import train_single_epoch, validate
from loss import NMSELoss
from architecture import GreenhouseMidFusionRegressor

Download the 2021 Autonomous Greenhouse Challenge dataset

In [2]:
from IPython.display import clear_output

!mkdir data
!wget https://data.4tu.nl/ndownloader/files/28906503 -O data/data.zip #for linux OS
# !curl https://data.4tu.nl/ndownloader/files/28906503 -O data.zip #for macOS
!unzip data/data.zip
!rm data/data.zip

clear_output(wait=True)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  844M    0  844M    0     0  4203k      0 --:--:--  0:03:25 --:--:-- 3329k
curl: (6) Could not resolve host: data.zip
unzip:  cannot find or open data.zip, data.zip.zip or data.zip.ZIP.


Define data and output directories

In [None]:
# Where checkpoints are written during training and read back for evaluation.
sav_dir = 'model_output/'
# makedirs(exist_ok=True) keeps the cell re-runnable; os.mkdir raises once the folder exists.
os.makedirs(sav_dir, exist_ok=True)

# Root of the extracted dataset. The download cell creates a relative `data/` folder;
# set this to '/data/' instead if the dataset lives at the filesystem root.
# NOTE(review): assumes the archive's RGBImages/, DepthImages/ and JSON file end up
# directly under DATA_DIR — adjust if the archive nests them in a subfolder.
DATA_DIR = 'data/'
RGB_Data_Dir   = DATA_DIR + 'RGBImages/'
Depth_Data_Dir = DATA_DIR + 'DepthImages/'
JSON_Files_Dir = DATA_DIR + 'GroundTruth_All_388_Images.json'

Set model architecture options:
- single vs. multi input (SI- or MI-)
- single vs. multi output (-SO or -MO)
- deformable vs. standard convolutions

In [1]:
ConvType = 'deformable'  # or 'standard'

training_category = 'MIMO'  # one of 'MIMO', 'MISO', 'SIMO', 'SISO'

# `inputs` x `outputs` lists the configurations to train/evaluate, one model each:
#   inputs  ['RGB-D']        -> a single model fed both modalities
#           ['RGB', 'D']     -> a separate model per modality (D presumably = depth)
#   outputs ['All']          -> a single model predicting all NumOutputs traits
#           [trait names...] -> one single-output model per trait ('D' here is a trait)
# All four categories build their transforms with train=False.
if training_category == 'MIMO':    # multi-input, multi-output
    transform_type = get_transforms(train=False)
    inputs = ['RGB-D']
    outputs = ['All']
    NumOutputs = 5

elif training_category == 'MISO':  # multi-input, single-output
    transform_type = get_transforms(train=False)
    inputs = ['RGB-D']
    outputs = ['FW', 'DW', 'H', 'D', 'LA']
    NumOutputs = 1

elif training_category == 'SIMO':  # single-input, multi-output
    transform_type = get_RGB_transforms(train=False)
    inputs = ['RGB', 'D']
    outputs = ['All']
    NumOutputs = 5

elif training_category == 'SISO':  # single-input, single-output
    transform_type = get_RGB_transforms(train=False)
    inputs = ['RGB', 'D']
    outputs = ['FW', 'DW', 'H', 'D', 'LA']
    NumOutputs = 1

else:
    # Fail here instead of with a NameError on `transform_type` in a later cell.
    raise ValueError("training_category must be 'MIMO', 'MISO', 'SIMO' or 'SISO', got %r"
                     % (training_category,))

Set other model config parameters

In [None]:
split_seed = 12   # random_state for the train/validation split
num_epochs = 400  # epochs per (input, output) training run

# Seed the global RNGs as well, so weight initialisation and DataLoader shuffling
# repeat across runs (torch.manual_seed also seeds the CUDA generators).
torch.manual_seed(split_seed)
np.random.seed(split_seed)

Create the PyTorch dataset, split it into train/validation/test sets, and create the PyTorch dataloaders

In [2]:
# Instantiate the PyTorch dataset for the autonomous greenhouse data.
dataset = GreenhouseDataset(rgb_dir = RGB_Data_Dir,
                            d_dir = Depth_Data_Dir,
                            jsonfile_dir = JSON_Files_Dir,
                            transforms = transform_type)

# Remove the last 50 images from the training/validation pool; they are the test set.
dataset.df = dataset.df.iloc[:-50]

# Split train and validation sets, stratified by variety.
train_split, val_split = train_test_split(dataset.df,
                                          test_size = 0.2,
                                          random_state = split_seed,
                                          stratify = dataset.df['Variety'])
# The frame's index labels are passed to dataset[...] as sample indices.
# NOTE(review): this assumes dataset.df keeps a 0..N-1 index so labels equal row positions.
train = torch.utils.data.Subset(dataset, train_split.index.tolist())
val   = torch.utils.data.Subset(dataset, val_split.index.tolist())
dataset.set_indices(train.indices, val.indices)

assert not set(train.indices) & set(val.indices), 'train and validation sets overlap'
assert len(train) + len(val) == len(dataset.df)

# Cap DataLoader workers at the available CPU count: 12 workers oversubscribes
# smaller machines such as a default Colab runtime.
num_workers = min(12, os.cpu_count() or 1)

# Create train and validation dataloaders
train_loader = torch.utils.data.DataLoader(train, batch_size=10, num_workers=num_workers, shuffle=True)
val_loader   = torch.utils.data.DataLoader(val,   batch_size=10, shuffle=False, num_workers=num_workers)


Define the loss function as the normalized mean squared error (NMSE), the metric required by the 2021 Autonomous Greenhouse Challenge.

In [3]:
# Training/validation objective: normalized mean squared error (NMSE).
criterion = NMSELoss()

### Training

Define the training loop and fit the model.

In [None]:
# Training loop: trains one fresh model per (input modality, output target) pair.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for input_type in inputs:
    for output_type in outputs:
        # train_loader/val_loader wrap Subsets of this same `dataset`, so these
        # attributes (presumably) switch what both loaders yield for this run.
        dataset.input = input_type
        dataset.out = output_type

        model = GreenhouseMidFusionRegressor(input_data_type = input_type,
                                             num_outputs = NumOutputs,
                                             conv_type = ConvType)
        model.to(device)
        params = [p for p in model.parameters() if p.requires_grad]

        # A fresh optimizer per run so no optimizer state carries over between models.
        optimizer = torch.optim.Adam(params,
                                     lr=0.0005,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay = 0,
                                     amsgrad = False)

        # validate() returns the value carried into the next epoch as the best loss.
        best_val_loss = float('inf')

        # One TensorBoard run directory per configuration (suffix added to runs/<date>_<host>).
        writer = SummaryWriter(comment='_' + training_category + '_' + input_type + '_' + output_type)
        start = time.time()
        try:
            for epoch in range(num_epochs):
                elapsed_min = (time.time() - start) / 60
                with open('run.txt', 'a') as f:
                    f.write('\n')
                    f.write('Epoch: ' + str(epoch + 1) + ', Time Elapsed: ' + str(elapsed_min) + ' mins')
                print('Epoch: ', str(epoch + 1), ', Time Elapsed: ', str(elapsed_min), ' mins')

                train_single_epoch(model, dataset, device, criterion, optimizer, writer, epoch, train_loader)

                best_val_loss = validate(model, dataset, device, training_category, sav_dir,
                                         criterion, writer, epoch, val_loader, best_val_loss)
        finally:
            # Flush pending TensorBoard events and release the event file even if training fails.
            writer.close()

### Evaluation

Define the test dataset

In [8]:
# Build a separate dataset instance for evaluation on the held-out images.
testset = GreenhouseDataset(rgb_dir = RGB_Data_Dir,
                            d_dir = Depth_Data_Dir,
                            jsonfile_dir = JSON_Files_Dir,
                            transforms = transform_type)

# Keep only the last 50 images: exactly the rows the training cell drops with iloc[:-50].
# Explicit .iloc makes the positional slice unambiguous.
testset.df = testset.df.iloc[-50:]

testset_size = len(testset.df)

# A batch size of 50 puts the whole test set in a single batch.
test_loader = torch.utils.data.DataLoader(testset,
                                          batch_size = 50,
                                          num_workers = 0,
                                          shuffle = False)

Define loss functions for model evaluation

In [6]:
# Metrics for scoring the test set:
#   mse -> plain mean squared error, reported per trait
#   cri -> normalized MSE (the training loss), reported as the overall score
mse = nn.MSELoss()
cri = NMSELoss()

Run the evaluation loop

In [11]:
# Evaluation loop: reload the best checkpoint saved for each (input, output) pair
# and score it on the held-out test set.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Column order of multi-output ('All') predictions and targets.
TRAIT_NAMES = ['FW', 'DW', 'H', 'D', 'LA']


def _predict(model, loader, device, num_outputs):
    """Run `model` over every batch of `loader` and gather the results on the CPU.

    Returns a (predictions, targets) pair, each concatenated along dim 0.
    An empty loader yields two empty (0, num_outputs) tensors.
    """
    pred_batches, target_batches = [], []
    for rgbd, targets in loader:
        preds = model(rgbd.to(device))
        pred_batches.append(preds.detach().cpu())
        target_batches.append(targets.detach().cpu())
    if not pred_batches:
        return torch.zeros((0, num_outputs)), torch.zeros((0, num_outputs))
    return torch.cat(pred_batches, 0), torch.cat(target_batches, 0)


with torch.no_grad():
    for input_type in inputs:
        # Single-output predictions/targets for this input, stacked one column per trait.
        final = torch.zeros((testset_size, 0))
        all_targets = torch.zeros((testset_size, 0))

        for output_type in outputs:
            print('Input is ', input_type)
            testset.input = input_type
            testset.out = output_type

            model = GreenhouseMidFusionRegressor(input_data_type = input_type,
                                                 num_outputs = NumOutputs,
                                                 conv_type = ConvType)
            model.to(device)
            checkpoint = sav_dir + 'bestmodel' + training_category + '_' + input_type + '_' + output_type + '.pth'
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            model.load_state_dict(torch.load(checkpoint, map_location=device))
            model.eval()

            ap, at = _predict(model, test_loader, device, NumOutputs)

            if output_type == 'All':
                for col, trait in enumerate(TRAIT_NAMES):
                    print(trait + ' MSE: ', str(mse(ap[:, col], at[:, col]).tolist()))
            else:
                final = torch.cat((final, ap), 1)
                all_targets = torch.cat((all_targets, at), 1)
                print(output_type, ' MSE: ', str(mse(ap, at).tolist()))

        # Overall NMSE: from the single multi-output model, or from the per-trait
        # single-output predictions stacked above.
        if outputs == ['All']:
            print('Overall NMSE: ', str(cri(ap, at).tolist()))
        else:
            print('Overall NMSE: ', str(cri(final, all_targets).tolist()))

Input is  RGB-D
FW MSE:  16857.876953125
DW MSE:  4.854626655578613
H MSE:  3.97654390335083
D MSE:  22.738414764404297
LA MSE:  5795591.0
Overall NMSE:  1.632205843925476
