In [1]:
from plantcv import plantcv as pcv
import numpy as np
import torch

# Pick the compute device once: the first CUDA device when torch reports
# CUDA as available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# n = 1
# img_path = f"./images/sized/{n}.jpg"
# img, _, _ = pcv.readimage(filename=img_path, mode='rgb')
# img.shape

In [2]:
# tt, _, _ = pcv.readimage(filename="./tt.png", mode="rgb")
# tt.shape

In [3]:
# mask_path = f"./images/mask/{n}.jpg"
# mask, _, _ = pcv.readimage(filename=mask_path, mode="gray")
# # black:0, white: 255
# mask[mask>0]=1
# mask=mask[:, :, np.newaxis]
# mask.shape

In [4]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader


class SegDataset(Dataset):
    """Paired image / mask dataset for segmentation, read from two directories.

    An image and its mask are matched by filename: ``img_dir/<name>`` pairs
    with ``mask_dir/<name>``. Only names present in both directories are
    served, and they are served in sorted order so that a given index always
    maps to the same pair.

    Args:
        mask_dir: Directory holding the mask files.
        img_dir: Directory holding the image files.
        transform: Optional callable applied to the image array before it is
            returned.
        mask_transform: Optional callable applied to the mask array before it
            is returned.
    """

    def __init__(self, mask_dir, img_dir, transform=None, mask_transform=None):
        self.mask_dir = mask_dir
        self.img_dir = img_dir
        self.transform = transform
        self.mask_transform = mask_transform

        # List each directory exactly once. Re-listing on every lookup is
        # O(n) per sample and the order returned by os.listdir is arbitrary.
        img_names = os.listdir(self.img_dir)
        mask_names = set(os.listdir(self.mask_dir))
        self.n_img = len(img_names)
        self.n_mask = len(mask_names)
        # Keep only images that have a same-named mask, in a stable order.
        self.filenames = sorted(name for name in img_names if name in mask_names)

    def __len__(self):
        """Return the number of image/mask pairs available."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair.

        Returns:
            tuple: ``(img, mask)`` where ``img`` is the image scaled by 1/255
            and laid out Channels x Height x Width, and ``mask`` is a
            1 x Height x Width array whose non-zero pixels have been set to 1.
            Each has been passed through its transform when one was given.
        """
        filename = self.filenames[idx]

        img_path = os.path.join(self.img_dir, filename)
        img, _, _ = pcv.readimage(filename=img_path, mode='rgb')
        img = img / 255
        # TODO: normalization to be added, e.g.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # plantcv returns Height x Width x Channels; move channels first.
        img = np.transpose(img, (2, 0, 1))

        mask_path = os.path.join(self.mask_dir, filename)
        mask, _, _ = pcv.readimage(filename=mask_path, mode="gray")
        # Binarize: black (0) stays 0, anything brighter (white is 255) becomes 1.
        # NOTE(review): if masks are stored with lossy compression, near-black
        # noise would also become 1 -- confirm the mask file format.
        mask[mask > 0] = 1
        # Add a leading channel axis so the mask is 1 x Height x Width.
        mask = mask[np.newaxis, :, :]

        # Return the transformed arrays; previously the results were bound to
        # unused names and the untransformed data was returned.
        if self.transform:
            img = self.transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return img, mask

# Each split keeps its files in parallel mask/ and img/ folders.
train_mask_dir = "./images/train/mask/"
train_img_dir = "./images/train/img/"
training_set = SegDataset(train_mask_dir, train_img_dir)
# drop_last=True: with batch_size=2 an odd-sized training set would otherwise
# end with a one-sample batch, and BatchNorm layers reject that in training
# mode once the feature map has been pooled to 1x1 (as in DeepLabv3's ASPP
# pooling branch).
training_generator = DataLoader(training_set, batch_size=2, shuffle=True,
                                drop_last=True)

val_mask_dir = "./images/validation/mask/"
val_img_dir = "./images/validation/img/"
validation_set = SegDataset(val_mask_dir, val_img_dir)
# Validation keeps every sample and a fixed order so runs are comparable.
validation_generator = DataLoader(validation_set, batch_size=2, shuffle=False)
    


In [5]:
max_epochs = 1

# Sanity check: pull a single batch from the training loader and show what
# it contains, then stop.
for x in training_generator:
    images = x[0]
    print("\n", type(x), len(x))
    print("\n", type(images), images.size())
    print(images)
    break


 <class 'list'> 2

 <class 'torch.Tensor'> torch.Size([2, 3, 500, 500])
tensor([[[[0.1765, 0.1608, 0.1569,  ..., 0.3333, 0.3647, 0.3529],
          [0.1412, 0.1294, 0.1255,  ..., 0.3490, 0.3647, 0.4000],
          [0.1451, 0.1373, 0.1216,  ..., 0.3725, 0.3647, 0.4235],
          ...,
          [0.2196, 0.1882, 0.2784,  ..., 0.2078, 0.2314, 0.2275],
          [0.1765, 0.1843, 0.2588,  ..., 0.1882, 0.2039, 0.2000],
          [0.2314, 0.1294, 0.1412,  ..., 0.1725, 0.1804, 0.1765]],

         [[0.1373, 0.1216, 0.1137,  ..., 0.4745, 0.5059, 0.4941],
          [0.0941, 0.0902, 0.0863,  ..., 0.4902, 0.5059, 0.5412],
          [0.1020, 0.0941, 0.0824,  ..., 0.5137, 0.5059, 0.5647],
          ...,
          [0.2039, 0.1725, 0.2431,  ..., 0.1255, 0.1490, 0.1373],
          [0.1608, 0.1686, 0.2235,  ..., 0.1059, 0.1216, 0.1098],
          [0.2118, 0.1098, 0.1020,  ..., 0.0902, 0.0980, 0.0863]],

         [[0.0706, 0.0549, 0.0588,  ..., 0.1451, 0.1765, 0.1725],
          [0.0235, 0.0235, 0.0196, 

In [6]:
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models


def createDeepLabv3(outputchannels=1):
    """Build a pretrained DeepLabv3 (ResNet101 backbone) with a new head.

    The stock classifier is replaced by a fresh ``DeepLabHead`` that takes
    2048 input channels and emits ``outputchannels`` channels.

    Args:
        outputchannels (int, optional): Number of channels the replacement
            head produces, i.e. the channel count of the dataset masks.
            Defaults to 1.

    Returns:
        model: The DeepLabv3 model, already switched to training mode.
    """
    net = models.segmentation.deeplabv3_resnet101(
        pretrained=True,
        progress=True,
    )
    # Swap in a head sized for this task.
    head = DeepLabHead(2048, outputchannels)
    net.classifier = head
    # Module.train() returns the module itself, so it can be returned directly.
    return net.train()

# Single-channel output: one prediction map per image, matching the 1 x H x W masks.
model = createDeepLabv3(outputchannels=1)


In [7]:
from sklearn.metrics import f1_score, roc_auc_score

# Loss: mean-reduced squared error between the model output and the target.
criterion = torch.nn.MSELoss(reduction='mean')
# Optimizer: Adam over every model parameter, with a low learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Evaluation metrics, looked up by name.
metrics = dict(f1_score=f1_score, auroc=roc_auc_score)

In [8]:
import copy
# Snapshot the current weights as the initial "best" and start the best-loss
# tracker at a large sentinel value.
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1e10
# Move the model onto the selected device (after the snapshot was taken).
model.to(device)

# History lists for loss, F1 and AUROC, one per split.
train_losses, test_losses = [], []
train_f1, test_f1 = [], []
train_auroc, test_auroc = [], []

In [13]:
num_epochs = 1
# Cut-off applied to the raw model output when turning it into a binary
# prediction for the F1 score.
f1_threshold = 0.1

# For each phase: the loader to iterate and the history lists to fill.
phase_setup = {
    'Train': (training_generator, train_losses, train_f1, train_auroc),
    'Test': (validation_generator, test_losses, test_f1, test_auroc),
}

for epoch in range(1, num_epochs + 1):
    print('Epoch {}/{}'.format(epoch, num_epochs))
    print('-' * 10)

    for phase in ['Train', 'Test']:
        is_train = phase == 'Train'
        loader, loss_log, f1_log, auroc_log = phase_setup[phase]
        # train(True) enables training mode, train(False) is eval mode.
        model.train(is_train)
        phase_losses = []

        for sample in loader:
            inputs = sample[0].to(device, dtype=torch.float32)
            masks = sample[1].to(device, dtype=torch.float32)

            # Gradients are only tracked (and applied) in the training phase.
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                loss = criterion(outputs['out'], masks)
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            batch_loss = loss.item()
            loss_log.append(batch_loss)
            phase_losses.append(batch_loss)

            # Pixel-level metrics on the flattened batch.
            y_pred = outputs['out'].detach().cpu().numpy().ravel()
            y_true = masks.detach().cpu().numpy().ravel() > 0
            f1_log.append(f1_score(y_true, y_pred > f1_threshold))
            # roc_auc_score raises ValueError when only one class is present
            # (e.g. a batch whose masks are entirely background); record NaN
            # so the history stays aligned with the loss list.
            if y_true.any() and not y_true.all():
                auroc_log.append(roc_auc_score(y_true, y_pred))
            else:
                auroc_log.append(float('nan'))

            print('{} Loss: {:.4f}'.format(phase, batch_loss))

        # Keep the weights that achieve the lowest mean validation loss.
        if not is_train and phase_losses:
            mean_loss = sum(phase_losses) / len(phase_losses)
            if mean_loss < best_loss:
                best_loss = mean_loss
                best_model_wts = copy.deepcopy(model.state_dict())

print("loss", train_losses)
print("f1", train_f1)
print("roc", train_auroc)

Epoch 1/1
----------
Train Loss: 0.2131
Train Loss: 0.2821
Train Loss: 0.2390


KeyboardInterrupt: 