In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import PIL
from sklearn.model_selection import train_test_split
import torch
import torchvision
import os
import time
import pandas as pd

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Convert human readable str label to int.
label_dict = {"abw": 0, "pbw": 1}
# Convert label int to human readable str. Derived from label_dict so the
# two mappings cannot drift apart.
reverse_label_dict = {index: name for name, index in label_dict.items()}


class WormDataset(torch.utils.data.Dataset):
    """
    Detection dataset that pairs each image with its box annotations.

    The data folder is read as
        <root>/images/<name>.png
        <root>/annotations/<name>.txt
    where every annotation file is a CSV with the columns x1, y1, x2, y2
    and label. Each item is an (image, target) pair; target is a dict with
    the keys "boxes", "labels" and "image_id".
    """

    def __init__(self, root, transforms = None):
        """
        Inputs
            root: str
                Path to the data folder.
            transforms: callable or None
                Called as transforms(img, target); must return the
                transformed (img, target) pair.
        """
        self.root = root
        self.transforms = transforms
        # Set unconditionally so the attribute exists even for an empty folder.
        self.label_dict = label_dict
        # List the images under `root` rather than the current working
        # directory. Sort the raw file names first so the index order matches
        # the directory listing order, then drop the extension with splitext
        # so dots inside a file name are preserved.
        # Entries that are not .png are skipped because __getitem__ can only
        # ever open "<name>.png".
        self.files = []
        for file_name in sorted(os.listdir(os.path.join(root, "images"))):
            stem, extension = os.path.splitext(file_name)
            if extension.lower() == ".png":
                self.files.append(stem)

    def __getitem__(self, i):
        """
        Inputs
            i: int
                Index into the sorted list of image names.
        Returns
            img: PIL image in RGB mode (or whatever transforms returns).
            target: dict
                "boxes": float32 tensor with one [x1, y1, x2, y2] row per
                    annotation row.
                "labels": int64 tensor with one class index per annotation
                    row.
                "image_id": tensor holding i.
        Raises
            KeyError if an annotation label is not a key of label_dict.
        """
        stem = self.files[i]
        # Load image from the hard disc. convert() returns an independent
        # copy, so the file handle can be released right away.
        image_path = os.path.join(self.root, "images", stem + ".png")
        with PIL.Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
        # Load annotation file from the hard disc.
        ann = pd.read_csv(os.path.join(self.root, "annotations", stem + ".txt"))
        # The target is given as a dict.
        target = {}
        # Convert the coordinate columns in one go: one box per CSV row. This
        # avoids handing whole pandas Series to torch.as_tensor.
        target["boxes"] = torch.as_tensor(
            ann[["x1", "y1", "x2", "y2"]].to_numpy(dtype="float32"),
            dtype = torch.float32,
        )
        # Map each label string individually; a Series cannot be used as a
        # dict key.
        target["labels"] = torch.as_tensor(
            [self.label_dict[name] for name in ann["label"]],
            dtype = torch.int64,
        )
        target["image_id"] = torch.as_tensor(i)
        # Apply any transforms to the data if required.
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of images found under <root>/images."""
        return len(self.files)