In [2]:
import os
import glob
import cv2 as cv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import argparse

import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import utils
from engine import train_one_epoch, evaluate
import transforms as T

In [1]:
# Dataset locations, relative to the notebook's working directory.
# Both folders live under a single dataset root.
dataset_root = "../RiceDiseaseDataset"
imgs_dir = dataset_root + "/images/"    # image files
labels_dir = dataset_root + "/labels/"  # XML annotation files, parsed by RiceDiseaseDataset

In [5]:
class RiceDiseaseDataset(Dataset):
    """Detection dataset pairing each image with its annotated boxes and classes.

    Annotations are parsed once, at construction time, from every ``*.xml``
    file in ``labels_dir`` into a DataFrame (``self.xml_df``). Each item is an
    ``(image, target)`` pair where ``target`` is a dict holding the keys
    ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``.

    NOTE(review): ``labels_dict`` is 0-based. torchvision detection models
    reserve label 0 for the background class, so the first disease class may
    be treated as background during training -- confirm and, if so, shift the
    mapping together with ``num_classes`` of the model.
    """

    def __init__(self, imgs_dir, labels_dir, transforms=None):
        """Index the images and parse all XML annotations.

        Args:
            imgs_dir: Directory whose entries (sorted by name) are the images.
            labels_dir: Directory containing the XML annotation files.
            transforms: Optional callable invoked as
                ``transforms(img, target)`` returning ``(img, target)``.
        """
        self.imgs_dir = imgs_dir
        self.labels_dir = labels_dir
        self.imgs = list(sorted(os.listdir(self.imgs_dir)))
        self.xml_df = self.xml_to_csv(self.labels_dir)
        self.class_names = self.xml_df["class"].unique()
        self.labels_dict = self.create_labels_dict(self.xml_df)
        self.transforms = transforms

    def __len__(self):
        """Return the number of entries found in ``imgs_dir``."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load image ``idx`` (as RGB) and build its detection target.

        Args:
            idx: Position of the image in ``self.imgs``.

        Returns:
            ``(img, target)``; both are passed through ``self.transforms``
            when it is set.

        Raises:
            FileNotFoundError: If OpenCV cannot load the image file.
        """
        img_name = self.imgs[idx]
        img_path = os.path.join(self.imgs_dir, img_name)
        img = cv.imread(img_path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(
                "Could not read image (missing or unreadable): {}".format(img_path)
            )
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

        target = self._build_target(img_name, idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def _build_target(self, img_name, idx):
        """Assemble the target dict for one image from its annotation rows."""
        img_df = self.xml_df[self.xml_df["filename"] == img_name]

        # reshape keeps a (0, 4) shape for images that have no annotation rows.
        boxes = torch.as_tensor(
            img_df[["xmin", "ymin", "xmax", "ymax"]].values, dtype=torch.float32
        ).reshape(-1, 4)

        # Map class names to indices row by row so labels[i] describes boxes[i].
        name_to_index = {name: index for index, name in self.labels_dict.items()}
        classes = [name_to_index[name] for name in img_df["class"]]
        labels = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # No annotation is a crowd region, so the flag is 0 for every box.
        target["iscrowd"] = torch.zeros((len(classes),), dtype=torch.int64)
        return target

    def create_labels_dict(self, xml_df):
        """Return ``{index: class_name}`` built from ``self.class_names``.

        ``xml_df`` is accepted for interface compatibility but is not used.
        """
        return dict(enumerate(self.class_names))

    def get_box_class(self, xml_df, img_name):
        """Return the annotation rows of ``self.xml_df`` for ``img_name``.

        ``xml_df`` is accepted for interface compatibility but is not used.
        """
        labels = self.xml_df[self.xml_df["filename"] == img_name]
        return labels

    def xml_to_csv(self, path):
        """Parse every ``*.xml`` file under ``path`` into a DataFrame.

        One row is produced per ``object`` element. The DataFrame is also
        written to ``../RiceDiseaseDataset/labels.csv`` as a side effect.

        Args:
            path: Directory containing the XML annotation files.

        Returns:
            DataFrame with columns ``filename, width, height, class, xmin,
            ymin, xmax, ymax``.
        """
        xml_list = []
        # glob returns files in arbitrary order; sorting keeps the row order,
        # and therefore the class -> index mapping, stable across machines.
        for xml_file in sorted(glob.glob(path + '/*.xml')):
            tree = ET.parse(xml_file)
            root = tree.getroot()
            for child in root.findall('object'):
                # Fields are read by position.
                # NOTE(review): presumably size = (width, height, ...),
                # object[0] = class name and object[4] = box with children
                # (xmin, ymin, xmax, ymax) -- confirm against the XML files.
                value = (root.find('filename').text,
                        int(root.find('size')[0].text),
                        int(root.find('size')[1].text),
                        child[0].text,
                        int(child[4][0].text),
                        int(child[4][1].text),
                        int(child[4][2].text),
                        int(child[4][3].text)
                        )
                xml_list.append(value)
        column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
        xml_df = pd.DataFrame(xml_list, columns=column_name)
        xml_df.to_csv("../RiceDiseaseDataset/labels.csv")
        return xml_df

In [6]:
# Smoke test: build the dataset without transforms and display one sample
# (the RGB image array together with its target dict).
x = RiceDiseaseDataset(imgs_dir=imgs_dir, labels_dir=labels_dir)
x[1]

(array([[[154, 148, 150],
         [ 93,  88,  92],
         [ 71,  69,  71],
         ...,
         [186, 154, 157],
         [185, 153, 156],
         [184, 152, 155]],
 
        [[142, 140, 141],
         [ 79,  76,  77],
         [ 57,  57,  54],
         ...,
         [187, 155, 158],
         [185, 153, 156],
         [184, 152, 155]],
 
        [[133, 133, 132],
         [ 65,  66,  61],
         [ 47,  49,  41],
         ...,
         [188, 156, 159],
         [185, 153, 156],
         [184, 152, 155]],
 
        ...,
 
        [[234, 243, 235],
         [237, 244, 239],
         [241, 248, 243],
         ...,
         [138, 183,  79],
         [137, 182,  75],
         [136, 182,  74]],
 
        [[219, 237, 211],
         [215, 234, 213],
         [220, 237, 220],
         ...,
         [139, 183,  85],
         [139, 184,  82],
         [138, 184,  80]],
 
        [[203, 230, 170],
         [186, 220, 158],
         [188, 219, 171],
         ...,
         [140, 184,  94],
  

In [7]:
x.labels_dict

{0: 'Bacterial_Blight', 1: 'Rice_Blast', 2: 'Brown_Spot'}

In [8]:
def get_model(num_classes=3):
    """Build a COCO-pretrained Faster R-CNN with a new box-classification head.

    Args:
        num_classes: Number of output classes for the new predictor.
            NOTE(review): torchvision detection models count the background
            as class 0, so this should normally be (number of diseases + 1)
            -- confirm against the dataset's label mapping.

    Returns:
        The Faster R-CNN model with its ROI box predictor replaced.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # The head lives at model.roi_heads.box_predictor; assigning to any other
    # attribute name leaves the pretrained COCO head in place.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

In [9]:
def get_transforms(train):
    """Compose the image/target transform pipeline.

    Args:
        train: When truthy, a random horizontal flip (p=0.5) is appended
            after the tensor conversion.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [10]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two views over the same files: training gets augmentation, evaluation does not.
dataset = RiceDiseaseDataset(imgs_dir, labels_dir, transforms=get_transforms(train=True))
dataset_test = RiceDiseaseDataset(imgs_dir, labels_dir, transforms=get_transforms(train=False))

# Split and batch settings.
# NOTE(review): these three names were not defined in any visible cell, so the
# cell could not run on a fresh kernel. The values below are placeholders
# (20% hold-out, small batches) -- replace them with the intended settings.
test_set_length = max(1, len(dataset) // 5)
train_batch_size = 2
test_batch_size = 1
assert 0 < test_set_length < len(dataset), "need at least one train and one test image"

# Fixed seed so the train/test partition is reproducible between runs.
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
# An explicit split index avoids the `[:-0]` pitfall of negative slicing.
split_at = len(indices) - test_set_length
dataset = torch.utils.data.Subset(dataset, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

data_loader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, 
                    shuffle=True, num_workers=0, collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=test_batch_size,
                    shuffle=False, num_workers=0, collate_fn=utils.collate_fn)
print("Dataset Details\nImages: {}, Train: {}, Test: {}.\n".format(len(indices),
                                len(dataset), len(dataset_test)))

NameError: name 'T' is not defined