In [9]:
from __future__ import annotations

from tqdm import tqdm
from PIL import Image
from typing import Optional, Any, Callable
import numpy as np
from skimage.draw import polygon
from skimage.filters import gaussian
from skimage.feature import peak_local_max
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import os, glob, copy, random, pickle

import torch
from torch import nn, optim
import torch.nn.functional as F 
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.tensorboard.writer import SummaryWriter

# Models

In [3]:
# Model
class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x_in):
        x = self.bn1(self.conv1(x_in))
        x = F.relu(x)
        x = self.bn2(self.conv2(x))
        return x + x_in

class GRConvNet4(nn.Module):

    def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout=False, prob=0.0, clip=False):
        super(GRConvNet4, self).__init__()
        self.clip = clip
        self.conv1 = nn.Conv2d(input_channels, channel_size, kernel_size=9, stride=1, padding=4)
        self.bn1 = nn.BatchNorm2d(channel_size)

        self.conv2 = nn.Conv2d(channel_size, channel_size // 2, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(channel_size // 2)

        self.conv3 = nn.Conv2d(channel_size // 2, channel_size // 4, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(channel_size // 4)

        self.res1 = ResidualBlock(channel_size // 4, channel_size // 4)
        self.res2 = ResidualBlock(channel_size // 4, channel_size // 4)
        self.res3 = ResidualBlock(channel_size // 4, channel_size // 4)
        self.res4 = ResidualBlock(channel_size // 4, channel_size // 4)
        self.res5 = ResidualBlock(channel_size // 4, channel_size // 4)

        self.conv4 = nn.ConvTranspose2d(channel_size // 4, channel_size // 2, kernel_size=4, stride=2, padding=1,
                                        output_padding=1)
        self.bn4 = nn.BatchNorm2d(channel_size // 2)

        self.conv5 = nn.ConvTranspose2d(channel_size // 2, channel_size, kernel_size=4, stride=2, padding=2,
                                        output_padding=1)
        self.bn5 = nn.BatchNorm2d(channel_size)

        self.conv6 = nn.ConvTranspose2d(channel_size, channel_size, kernel_size=9, stride=1, padding=4)

        self.grasp_outputs = nn.Conv2d(in_channels=channel_size, out_channels=output_channels*4, kernel_size=2)
        self.confidence_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2)

        self.dropout = dropout
        self.dropout_conf = nn.Dropout(p=prob)
        self.dropout_grasp = nn.Dropout(p=prob)


        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight, gain=1)

    def forward(self, x_in):
        x = F.relu(self.bn1(self.conv1(x_in)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        x = self.res5(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.conv6(x)

        if self.dropout:
            grasp_output = self.grasp_outputs(self.dropout_grasp(x))
            confidence_output = self.confidence_output(self.dropout_conf(x))
        else:
            grasp_output = self.grasp_outputs(x)
            confidence_output = self.confidence_output(x)

        confidence_output = F.sigmoid(confidence_output)      
        grasp_output = F.tanh(grasp_output)  

        output = torch.cat([confidence_output, grasp_output], dim=1)
        if self.clip:
            output = output.clip(-1, 1)
        return output

# Dataset

In [4]:
# Dataset parent class
class JacquardGraspCLSDataset(Dataset):
    """
    Parent class for all dataset objects dealing with the Jacquard Dataset. Handles file handling and combining 
    file paths for every instance of training data. Also handles train and test splitting.
    
    Assumes that all jacquard dataset instances have been cleaned using the methods in dataset/preprocess.py

    Child classes must implement methods:
        - get_instance_from_dataset - Gets a single instance from training data given file paths. Handle any augmentations here.
        - get_instance_from_cache - Gets a single instance of training data from the cached files (only use if heavy preprocessing)
        - visualize_instance - visualizes a single instance of training data
        - cache_dataset - caches the entire dataset (only use if heavy preprocessing)

    Cache assumes the following directory structure:
    cache_location
        |__ class_1
            |__ training_data_1.npz
            |__ training_data_2.npz
            ...
        |__ class_2
        ...
    """

    def __init__(
            self, 
            image_size: int, 
            dataset_path: Optional[str] = None, 
            cache_path: Optional[str] = None,
            random_augment: bool = True
        ) -> None:
        assert dataset_path is not None or cache_path is not None, "One of dataset_path or cache_path must be given"
        assert dataset_path is None or cache_path is None, "One of dataset_path or cache_path much be left empty"
        
        self.dataset_path = dataset_path
        self.random_augment = random_augment
        self.cache_path = cache_path
        if self.dataset_path is not None:
            self.class_to_idx, self.idx_to_class = self.get_class_map(self.dataset_path)    
        else:
            self.class_to_idx, self.idx_to_class = self.get_class_map(self.cache_path)
        self.image_size = image_size
        self.individual_file_paths = self.get_all_file_paths(self.dataset_path)

    def extract_test_dataset(self, test_split_ratio: float):
        """
        Extracts test_split_ratio * len(self) number of data points into a test dataset and returns it.
        Also turns off random_augment for the test dataset.
        """
        random.seed(0)
        test_dataset = copy.copy(self)

        self.train_idxs = random.sample(list(range(self.__len__())), k = int((1 - test_split_ratio) * self.__len__()), )
        self.test_idxs = [i for i in range(self.__len__()) if i not in self.train_idxs]

        train_fps = [self.individual_file_paths[i] for i in self.train_idxs]
        test_fps = [self.individual_file_paths[i] for i in self.test_idxs]

        self.individual_file_paths = train_fps
        test_dataset.individual_file_paths = test_fps
        test_dataset.random_augment = False
        return test_dataset

    def get_instance_from_dataset(self, file_paths: list[str], random_augment: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        To be implemented by a child class
        """
        raise NotImplementedError
    
    def get_instance_from_cache(self, file_path: list[str], random_augment: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        To be implemented by a child class
        """
        raise NotImplementedError
    
    def visualize_instance(self, file_paths: list[str]) -> None:
        """
        To be implemented by a child class
        """
        raise NotImplementedError
    
    def cache_dataset(self, cache_location: str):
        """
        To be implemented by a child class
        """
        raise NotImplementedError

    def __getitem__(self, idx):
        file_paths = self.individual_file_paths[idx]
        if self.dataset_path is not None:
            return_values = self.get_instance_from_dataset(file_paths, random_augment=self.random_augment)
        else:
            return_values = self.get_instance_from_cache(file_paths, random_augment=self.random_augment)
        return return_values

    def __len__(self):
        return len(self.individual_file_paths)
    
    def visualize(self, idx) -> None:
        self.visualize_instance(self.individual_file_paths[idx])
    
    def get_all_file_paths(self, dataset_path: str) -> list[list[str]]:
        """
        Returns a list of lists \n
        Each list contains the file paths for the files in the following order:
            - rgb image file paths
            - perfect depth file paths
            - stereo depth file paths
            - grasp file paths
            - mask file paths
            - class of object (not a path)

        If loading from cache, outputs a list containing a single file path
        """
        if self.dataset_path is not None:
            rgb_paths = glob.glob(os.path.join(dataset_path, "*/*", "*RGB.png"))
            output = []
            for rgb_path in rgb_paths:
                assert os.path.isfile(rgb_path)
                instance_data = [rgb_path]

                perfect_depth_path = rgb_path.replace("RGB.png", "perfect_depth.tiff")
                assert os.path.isfile(perfect_depth_path)
                instance_data.append(perfect_depth_path)

                stereo_depth_path = rgb_path.replace("RGB.png", "stereo_depth.tiff")
                assert os.path.isfile(stereo_depth_path)
                instance_data.append(stereo_depth_path)

                grasps_path = rgb_path.replace("RGB.png", "grasps.txt")
                assert os.path.isfile(grasps_path)
                instance_data.append(grasps_path)

                mask_path = rgb_path.replace("RGB.png", "mask.png")
                assert os.path.isfile(mask_path)
                instance_data.append(mask_path)

                instance_data.append(rgb_path.split("/")[-3])
                output.append(instance_data)
            return output
        else:
            return glob.glob(os.path.join(self.cache_path, "*/*"))
            
    def get_class_map(self, dataset_path: str) -> tuple[dict[str, int], dict[int, str]]:
        all_classes = self.listdir(dataset_path)
        class_to_idx = {c:i for i, c in enumerate(sorted(all_classes))}
        idx_to_class = {v:k for k, v in class_to_idx.items()}
        return class_to_idx, idx_to_class

    def listdir(self, folder_path: str) -> list[str]:
        files = os.listdir(folder_path)
        fn = [i for i in files if ".DS_Store" in i]
        if len(fn) > 0:
            files.remove(fn[0])
        return files

In [5]:
# Dataset Child class
class MapBasedJacquardDataset(JacquardGraspCLSDataset):

    def __init__(
        self, 
        image_size: int, 
        precision: torch.dtype,
        dataset_path: Optional[str] = None, 
        cache_path: Optional[str] = None,
        random_augment: bool = True,
        width_scale_factor: int = 1
    ) -> None:
        super().__init__(image_size, dataset_path, cache_path, random_augment)
        self.width_scale_factor = width_scale_factor
        self.precision = precision

    ###### Image and depth map loading functions
    def load_rgbd_image(self, rgb_image_path: str, depth_image_path: str) -> torch.Tensor:
        T = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor()
        ])
        rgb_image = T(Image.open(rgb_image_path))[:3]
        depth_image = T(Image.open(depth_image_path))
        return torch.cat([rgb_image, depth_image], dim=0)
    
    def preprocess_rbgd_image(self, rgbd_image: torch.Tensor) -> torch.Tensor:
        ## Add preprocessing steps here (normalization, etc..)
        return rgbd_image
    

    ###### Classification label loading functions
    def load_mask_image(self, mask_path: str) -> torch.Tensor:
        T = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor()
        ])
        return T(Image.open(mask_path))
    
    def load_classification_labels(self, mask_image: torch.Tensor, class_label: str):
        """
        The output of this function has two parts;
            - output[0] is the object/background mask. This map contains a 0 for all background pixels
              and a 1 for all object pixels. A sigmoid non-linearity is applied on the first channel of the model
              output to reflect this.
            - output[1:] are the class maps. This map contains a 0 value for all background pixels. When the class_map index
              matches the target class index, the object pixels are valued at 1 and when it does not, the object pixels
              are valued at -1. A tanh non-linearity is applied to the remaining channels of the model output
              to reflect this.

        Use self.visualize(idx) to visualize instances of cls_maps.
        """
        mask_image = mask_image.round()
        class_idx = self.class_to_idx[class_label]

        # Initially the cls maps are entirely -1 valued.
        class_maps = torch.ones(len(self.class_to_idx), mask_image.shape[1], mask_image.shape[2]) * -1
        # changing the cls map of the correct class to be +1 valued.
        class_maps[class_idx] *= -1

        # making all non-object, background pixels 0
        class_maps = class_maps * mask_image.repeat(len(self.class_to_idx), 1, 1)
        return torch.cat([mask_image, class_maps], dim=0)


    ###### Grasp label loading functions
    def load_grasp_file(self, grasp_path: str) -> np.ndarray:
        """
        Returns in the order [y, x, angle, length, width]
        """
        grasps = []
        with open(grasp_path, "r") as f:
            for l in f:
                x, y, theta, w, h = [float(v) for v in l[:-1].split(';')]
                grasps.append([x, y, -theta / 180.0 * np.pi, w * self.width_scale_factor, h])
        grasps = np.array(grasps)

        # rescaling values based on image size
        grasps[:, :2] *= (self.image_size / 1024)
        grasps[:, 3:] *= (self.image_size / 1024)
        return grasps
    
    def compute_grasp_rectangle(self, grasp_arr: np.ndarray) -> np.ndarray:
        """
        Converts the jacquard dataset grasps into grasp rectangles 
        (returns coordinates of 4 corners of each rectangle)
        """
        x, y, angle, length, width = grasp_arr.transpose()
        xo = np.cos(angle)
        yo = np.sin(angle)

        ya = y + length / 2 * yo
        xa = x - length / 2 * xo
        yb = y - length / 2 * yo
        xb = x + length / 2 * xo

        y1, x1 = ya - width / 2 * xo, xa - width / 2 * yo
        y2, x2 = yb - width / 2 * xo, xb - width / 2 * yo
        y3, x3 = yb + width / 2 * xo, xb + width / 2 * yo
        y4, x4 = ya + width / 2 * xo, xa + width / 2 * yo

        # p1 = np.stack([y1, x1], 1)
        # p2 = np.stack([y2, x2], 1)
        # p3 = np.stack([y3, x3], 1)
        # p4 = np.stack([y4, x4], 1)

        p1 = np.stack([x1, y1], 1)
        p2 = np.stack([x2, y2], 1)
        p3 = np.stack([x3, y3], 1)
        p4 = np.stack([x4, y4], 1)

        output_arr = np.stack([p1, p2, p3, p4], 1)
        return output_arr
    
    def compute_grasp_map(
            self,
            grasp_arr: np.ndarray, 
            grasp_rect_array: np.ndarray
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns a tuple containing torch tensors of shape (image_size, image_size) in
        the order (confidence_map, cos, sin, width)
        """
        conf_out = np.zeros((self.image_size, self.image_size))
        angle_out = np.zeros((self.image_size, self.image_size))
        width_out = np.zeros((self.image_size, self.image_size))
        length_out = np.zeros((self.image_size, self.image_size))

        for rect, (_, _, angle, width, length) in zip(grasp_rect_array, grasp_arr):
            cc, rr = polygon(rect[:, 0], rect[:, 1], (self.image_size, self.image_size))
            conf_out[rr, cc] = 1.0
            angle_out[rr, cc] = angle
            width_out[rr, cc] = width
            length_out[rr, cc] = length
        width_out /= self.image_size
        length_out /= self.image_size

        out_tensor = torch.cat([
            torch.from_numpy(conf_out).unsqueeze(0),
            torch.from_numpy(np.cos(angle_out)).unsqueeze(0),
            torch.from_numpy(np.sin(angle_out)).unsqueeze(0),
            torch.from_numpy(width_out).unsqueeze(0),
            torch.from_numpy(length_out).unsqueeze(0)], dim=0)
        return out_tensor
    
    def rotate_augment(
            self,
            rotation_angle: int, 
            rgbd_image: torch.Tensor, 
            grasp_map: torch.Tensor,
            grasp_rect: torch.Tensor,
            cls_map: torch.Tensor,
            rotate: Optional[bool] = None
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.random_augment or rotate:
            if rotation_angle == 270:
                rgbd_image = rgbd_image.transpose(-1, -2)
                grasp_map = grasp_map.transpose(-1, -2)
                grasp_rect = grasp_rect.flip(-1)
                cls_map = cls_map.transpose(-1, -2)
            elif rotation_angle == 180:
                rgbd_image = rgbd_image.flip(-2)
                grasp_map = grasp_map.flip(-2)
                grasp_rect[:, :, 1] = self.image_size - grasp_rect[:, :, 1]
                cls_map = cls_map.flip(-2)
            elif rotation_angle == 90:
                rgbd_image = rgbd_image.flip(-2).transpose(-1, -2)
                grasp_map = grasp_map.flip(-2).transpose(-1, -2)
                grasp_rect = grasp_rect.flip(-1)
                grasp_rect[:, :, 0] = self.image_size - grasp_rect[:, :, 0]
                cls_map = cls_map.flip(-2).transpose(-1, -2)
            elif rotation_angle == 0:
                pass
            else:
                raise Exception("Invalid rotation angle")
        return rgbd_image, grasp_map, grasp_rect, cls_map
    
    def get_random_rotation_angle(self):
        return random.choice([0, 90, 180, 270])
    
    def cast_to_fp_precision(self, rgbd, grasp_map, grasp_rect, cls_map):
        rgbd = rgbd.to(self.precision)
        grasp_map = grasp_map.to(self.precision)
        grasp_rect = grasp_rect.to(self.precision)
        cls_map = cls_map.to(self.precision)
        return rgbd, grasp_map, grasp_rect, cls_map

    def get_instance_from_dataset(self, file_paths: list[str], random_augment: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        rgb_fp, perfect_depth_fp, stereo_depth_fp, grasp_fp, mask_fp, class_label = file_paths

        rgbd_image = self.load_rgbd_image(rgb_fp, perfect_depth_fp)
        rgbd_image = self.preprocess_rbgd_image(rgbd_image)

        mask_image = self.load_mask_image(mask_fp)
        cls_labels = self.load_classification_labels(mask_image, class_label)

        grasp_arr = self.load_grasp_file(grasp_fp)
        grasp_rect_arr = self.compute_grasp_rectangle(grasp_arr)
        grasp_map = self.compute_grasp_map(grasp_arr, grasp_rect_arr)
        grasp_rect_arr = torch.from_numpy(grasp_rect_arr)

        rgbd_image, grasp_map, grasp_rect, cls_labels = self.rotate_augment(self.get_random_rotation_angle(), rgbd_image, grasp_map, grasp_rect_arr, cls_labels)
        rgbd_image, grasp_map, grasp_rect, cls_labels = self.cast_to_fp_precision(rgbd_image, grasp_map, grasp_rect, cls_labels)

        return rgbd_image, grasp_map, grasp_rect, cls_labels
    
    def get_instance_from_cache(self, file_paths: str, random_augment: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        rgbd, grasp_map, grasp_rect, cls_map = np.load(file_paths).values()
        rgbd, grasp_map, grasp_rect, cls_map = torch.from_numpy(rgbd), torch.from_numpy(grasp_map), torch.from_numpy(grasp_rect), torch.from_numpy(cls_map)
        rgbd, grasp_map, grasp_rect, cls_map = self.rotate_augment(self.get_random_rotation_angle(), rgbd, grasp_map, grasp_rect, cls_map)
        rgbd, grasp_map, grasp_rect, cls_map = self.cast_to_fp_precision(rgbd, grasp_map, grasp_rect, cls_map)
        return rgbd, grasp_map, grasp_rect, cls_map
    
    def visualize_instance(self, file_paths: list[str]) -> None:
        if self.dataset_path is not None:
            rgbd_image, (conf, cos, sin, width, length), grasp_rect,  cls_labels = self.get_instance_from_dataset(file_paths)
        else:
            rgbd_image, (conf, cos, sin, width, length), grasp_rect, cls_labels = self.get_instance_from_cache(file_paths)
        
        fig1, ax1 = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))
        ax1 = ax1.flatten()

        ax1[0].imshow(rgbd_image[:-1].permute(1, 2, 0))
        ax1[0].axis(False)
        ax1[0].set_title("RGB Image")

        ax1[1].imshow(rgbd_image[-1])
        ax1[1].axis(False)
        ax1[1].set_title("Depth Map")

        ax1[2].imshow(rgbd_image[:-1].permute(1, 2, 0))
        for rect in grasp_rect:
            rect = Polygon(rect, linewidth=1, edgecolor="r", facecolor="none")
            ax1[2].add_patch(rect)
        ax1[2].axis(False)
        ax1[2].set_title("Bounding boxes")
        plt.show()

        fig2, ax2 = plt.subplots(nrows=1, ncols=5, figsize=(20, 6))
        ax2 = ax2.flatten()
        for ax, img, name in zip(ax2, (conf, cos, sin, width, length), ("Confidence map", "cos", "sin", "width", "length")):
            im = ax.imshow(img)
            ax.axis(False)
            ax.set_title(name)
        plt.show()

        fig3, ax3 = plt.subplots(nrows=1, ncols=len(self.class_to_idx) + 1, figsize=(20, 6))
        ax3 = ax3.flatten()
        for ax, img, idx in zip(ax3, cls_labels, range(len(self.class_to_idx) + 1)):
            im = ax.imshow(img, vmin=-1, vmax=1)
            ax.axis(False)
            if idx == 0:
                ax.set_title("Background/Object")
            else:
                ax.set_title(self.idx_to_class[idx - 1])
        plt.colorbar(im)
        plt.show()

    def cache_dataset(self, cache_location: str):
        os.mkdir(cache_location)
        for class_name in self.class_to_idx.keys():
            class_path = os.path.join(cache_location, class_name)
            os.mkdir(class_path)

        class_item_indices = {cn: 0 for cn in self.class_to_idx.keys()}
        print("Creating cache...")
        loop = tqdm(range(len(self)))
        self.random_augment = False
        for i in loop:
            rgbd, grasp_map, grasp_rect_arr, cls_map = self[i]
            class_name = self.individual_file_paths[i][-1]
            save_dir = os.path.join(cache_location, class_name, class_name + "_" + str(class_item_indices[class_name]))
            class_item_indices[class_name] += 1

            rgbd, grasp_map, cls_map = rgbd.numpy(), grasp_map.numpy(), cls_map.numpy()
            np.savez_compressed(save_dir, rgbd=rgbd, grasp_map=grasp_map, grasp_rect=grasp_rect_arr, cls_map=cls_map)

# Loss functions

In [19]:
# Loss functions
class DoubleLogLoss:

    def __init__(self, mean_reduction: bool = True):
        self.mean_reduction = mean_reduction

    def check_range(self, y, yhat):
        assert ((y < -1).sum() + (y > 1).sum()).item() == 0, "Target outside valid range"
        assert ((yhat < -1).sum() + (yhat > 1).sum()).item() == 0, "Predicted value outside valid range"

    def fp_error_recentre(self, yhat, tolerance = 1e-5):
        max_filter_map = torch.zeros_like(yhat)
        min_filter_map = torch.zeros_like(yhat)
        max_filter_map[yhat > 1] += tolerance
        min_filter_map[yhat < -1] += tolerance
        yhat = yhat - max_filter_map
        yhat = yhat + min_filter_map
        return yhat

    def __call__(self, yhat, y):
        """
        yhat should be a tensor of predictions of shape [batch_size, num_classes], with each element being
        between -1 and 1 (tanh activation after final layer).

        y should be a tensor of target labels of shape [batch_size, num_classes].
        IMPORTANT_NOTE : each element of y will be a tensor of shape [num_classes]. Each element of this
                        SHOULD be either -1 or 1 when training for cls, with -1 (not 0) indicating that the object
                        is not a particular class and 1 indicating that it is.
        """
        self.check_range(y=y, yhat=yhat)
        y_less_than_loss = - torch.log(1 + (1 / (1 + yhat + 1e-5))*(y - yhat))
        yhat_less_than_loss = - torch.log(1 + (1 / (1 - yhat + 1e-5))*(yhat - y))
        output = torch.where(y < yhat, y_less_than_loss, yhat_less_than_loss)
        if self.mean_reduction:
            output = output.mean()
        return output
    
class DoubleLogMapLoss:

    def __init__(self):
        self.double_log = DoubleLogLoss(mean_reduction=False)
        self.bce = nn.BCELoss()

    def __call__(self, predicted_map: torch.Tensor, target_map: torch.Tensor) -> torch.Tensor:
        """
        predicted_map and target_map are of shape [batch_size, n_channels, img_size, img_size].
        indexing with [:, 0, :, :] represents confidence maps for both predicted and target.
        """
        # Reshape to [n_channels, batch_size, img_size, img_size]
        predicted_map = predicted_map.permute(1, 0, 2, 3)
        target_map = target_map.permute(1, 0, 2, 3)

        # Computing a confidence map loss with the first channel
        confidence_loss = self.bce(predicted_map[0], target_map[0])

        # Computing the grasp/cls losses with remaining channels
        grasp_cls_loss = self.double_log(predicted_map[1:], target_map[1:])

        # Valid pixels for grasp_cls_loss are those where the target confidence map is not 0, 
        # since these are pixels which belong to the object and not the background.
        valid_pixels = target_map[0] != 0
        valid_pixels = valid_pixels.unsqueeze(0).repeat(predicted_map.shape[0] - 1, 1, 1, 1)
        grasp_cls_loss = (grasp_cls_loss * valid_pixels).mean()
        return confidence_loss + grasp_cls_loss * 2
    
class CrossEntropyMapLoss:

    def __init__(self):
        self.bce = nn.BCELoss()

    def __call__(self, predicted_map: torch.Tensor, target_map: torch.Tensor) -> torch.Tensor:
        """
        predicted_map and target_map are of shape [batch_size, n_channels, img_size, img_size].
        indexing with [:, 0, :, :] represents confidence maps for both predicted and target.
        """
        # Reshape to [n_channels, batch_size, img_size, img_size]
        predicted_map = predicted_map.permute(1, 0, 2, 3)
        target_map = target_map.permute(1, 0, 2, 3)

        confidence_predicted_map = predicted_map[0]
        confidence_target_map = target_map[0]
        grasp_or_cls_predictions = predicted_map[1:]
        grasp_or_cls_targets = target_map[1:]

        # Computing a confidence map loss with the first channel
        confidence_loss = self.bce(confidence_predicted_map, confidence_target_map)

        # Rescaling predicted and target outputs to be between 0 and 1 for bce loss
        grasp_or_cls_predictions = (grasp_or_cls_predictions + 1) / 2
        grasp_or_cls_targets = (grasp_or_cls_targets + 1) / 2

        # Computing the grasp/cls losses with remaining channels
        grasp_cls_loss = self.bce(grasp_or_cls_predictions, grasp_or_cls_targets)

        # Valid pixels for grasp_cls_loss are those where the target confidence map is not 0, 
        # since these are pixels which belong to the object and not the background.
        valid_pixels = target_map[0] != 0
        valid_pixels = valid_pixels.unsqueeze(0).repeat(predicted_map.shape[0] - 1, 1, 1, 1)
        grasp_cls_loss = (grasp_cls_loss * valid_pixels).mean()
        return confidence_loss + grasp_cls_loss * 2

# Utils

In [7]:
# Grasp utility functions
def post_process_map_output(
        q_img: torch.Tensor, 
        cos_img: torch.Tensor, 
        sin_img: torch.Tensor, 
        width_img: torch.Tensor,
        length_img: torch.Tensor,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    q_img = q_img.cpu().numpy().squeeze()
    ang_img = (torch.atan2(sin_img, cos_img) / 2.0).cpu().numpy().squeeze()
    width_img = width_img.cpu().numpy().squeeze()
    width_img *= width_img.shape[-1]
    length_img = length_img.cpu().numpy().squeeze()
    length_img *= length_img.shape[-1]

    q_img = gaussian(q_img, 2.0, preserve_range=True)
    ang_img = gaussian(ang_img, 2.0, preserve_range=True)
    width_img = gaussian(width_img, 1.0, preserve_range=True)
    length_img = gaussian(length_img, 1.0, preserve_range=True)

    return q_img, ang_img, width_img, length_img

def cls_from_map(cls_map: torch.Tensor, verbose: bool = False):
    """
    Expects input tensor cls_map to be of shape [batch_size, num_classes + 1, image_size, image_size]
    cls_map[:, 0] is the confidence map
    """
    cls_map = cls_map.permute(1, 0, 2, 3)
    conf, preds = cls_map[0].round(), cls_map[1:]
    conf = conf.repeat(preds.shape[0], 1, 1, 1)
    masked_preds = conf * preds
    masked_preds = masked_preds.view(masked_preds.shape[0], masked_preds.shape[1], -1)
    masked_preds = masked_preds.mean(-1)
    if verbose:
        print(masked_preds)
    masked_preds = masked_preds.argmax(0)
    return masked_preds

def grasp_rect_from_grasps(grasps: np.ndarray) -> np.ndarray:
    """
    Expects inputs to be in the Grasp format with shape (N, 5).
    Each of the N in the first dim represent a different grasp, with each being in the order
    [center_x, center_y, angle, length, width]
    """
    y, x, angle, length, width = grasps.transpose()
    xo = np.cos(angle)
    yo = np.sin(angle)

    ya = y + length / 2 * yo
    xa = x - length / 2 * xo
    yb = y - length / 2 * yo
    xb = x + length / 2 * xo

    y1, x1 = ya - width / 2 * xo, xa - width / 2 * yo
    y2, x2 = yb - width / 2 * xo, xb - width / 2 * yo
    y3, x3 = yb + width / 2 * xo, xb + width / 2 * yo
    y4, x4 = ya + width / 2 * xo, xa + width / 2 * yo

    p1 = np.stack([x1, y1], 1)
    p2 = np.stack([x2, y2], 1)
    p3 = np.stack([x3, y3], 1)
    p4 = np.stack([x4, y4], 1)

    output_arr = np.stack([p1, p2, p3, p4], 1)
    return output_arr

def check_grasp_success(
    predicted_rects: np.ndarray,
    target_rects: np.ndarray,
    image_size: int,
    angle_threshold: float,
    iou_threshold: float,
    verbose: bool = False
    ):
    """
    Given an array list of predicted grasp rectangles and target grasp rectangles, outputs
    if max_iou between at least one predicted grasp rectangle > threshold.
    """
    from metrics.iou import max_iou # placed here due to circular import problems

    for rect in predicted_rects:
        iou_score = max_iou(
            predicted_rect=rect, 
            target_rects=target_rects,
            angle_threshold=angle_threshold,
            image_size=image_size,
            verbose=verbose
        )
        if iou_score > iou_threshold:
            return True
    return False

def grasps_from_map(
    conf_map: np.ndarray,
    angle_map: np.ndarray,
    width_map: np.ndarray,
    length_map: np.ndarray,
    num_peaks: int,
    verbose: bool = False
    ):
    """
    Computes and returns the top num_peaks grasps from the maps.    
    A Grasp describes a bounding box in the following format
    [center_x, center_y, angle, length, width]

    Output is of shape (num_peaks, 5)
    """
    local_max = peak_local_max(conf_map, min_distance=10, threshold_abs=0.2, num_peaks=num_peaks)

    predicted_grasps = []
    for grasp_coord in local_max:
        grasp_coord = tuple(grasp_coord)
        grasp_angle = angle_map[grasp_coord]
        grasp_width = width_map[grasp_coord]
        grasp_length = length_map[grasp_coord]
        predicted_grasps.append((grasp_coord[0], grasp_coord[1], grasp_angle, grasp_width, grasp_length))

    predicted_grasps = np.array(predicted_grasps)

    if verbose:
        print(predicted_grasps)

    return predicted_grasps

def map_based_iou(
    conf_map: torch.Tensor,
    cos_map: torch.Tensor,
    sin_map: torch.Tensor,
    width_map: torch.Tensor,
    length_map: torch.Tensor,
    target_grasp_rects: torch.Tensor,
    angle_threshold: float = np.pi / 6,
    iou_threshold: float = 0.25,
    num_peaks: int = 3,
    verbose: bool = False
    ) -> bool:
    """ 
    Returns a boolean value prediction whether the predicted grasps as successful or not
    """
    target_grasp_rects = target_grasp_rects.numpy()
    image_size = conf_map.shape[-1]

    conf, angle, width, length = post_process_map_output(conf_map, cos_map, sin_map, width_map, length_map)
    predicted_grasps = grasps_from_map(conf, angle, width, length, num_peaks=num_peaks, verbose=verbose)
    if predicted_grasps.size == 0:
        return False
    predicted_grasp_rects = grasp_rect_from_grasps(predicted_grasps)
    return check_grasp_success(predicted_grasp_rects, target_grasp_rects, image_size, angle_threshold, iou_threshold, verbose=verbose)

In [8]:
# Cls utility functions
def cls_from_map(cls_map: torch.Tensor, verbose: bool = False):
    """
    Expects input tensor cls_map to be of shape [batch_size, num_classes + 1, image_size, image_size]
    cls_map[:, 0] is the confidence map
    """
    cls_map = cls_map.permute(1, 0, 2, 3)
    conf, preds = cls_map[0].round(), cls_map[1:]
    conf = conf.repeat(preds.shape[0], 1, 1, 1)
    masked_preds = conf * preds
    masked_preds = masked_preds.view(masked_preds.shape[0], masked_preds.shape[1], -1)
    masked_preds = masked_preds.mean(-1)
    if verbose:
        print(masked_preds)
    masked_preds = masked_preds.argmax(0)
    return masked_preds

# Trainer

In [10]:
# Trainer
class MapBasedTrainer:
    def __init__(
            self,
            training_mode: str,
            model: nn.Module,
            device: str,
            loss_fn: Any,
            dataset: MapBasedJacquardDataset,
            optimizer: optim.Optimizer,
            lr: float,
            train_batch_size: int,
            test_split_ratio: float,
            checkpoint_dir: str,
            log_dir: str,
            scheduler: Callable[[float, int], float] = lambda lr, step: lr,
            num_accumulate_batches: int = 1,
            test_batch_size: int = 16,
        ):
        """
        Initializer for MapBasedTrainer;
        
        Arguments:
            - training_mode: expected to be either "grasp" or "cls"
            - model: torch model to be trained
            - loss_fn: a callable loss function class that outputs the loss given predicted and target maps
            - dataset: an instance of MapBasedJacquardDataset (cached gives better performance)
            - optimizer: an uninitialized optimizer (eg. optimizer = torch.optim.Adam)
            - log_dir: directory for saving tensorboard logs
            - scheduler: a functions that takes the current lr and step number/epoch and outputs the next lr
            - num_accumulate_batches: if the batch size is small, we may want to accumulate gradients over multiple batches
                before updating model weights. This argument controls the number of batches we accumulate gradients for.
        """
        self.training_mode = training_mode
        assert self.training_mode == "grasp" or self.training_mode == "cls", "training_mode must be 'grasp' or 'cls'"

        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.loss_fn = loss_fn
        self.lr = lr
        self.optimizer = optimizer(self.model.parameters(), self.lr)
        self.num_accumulate_batches = num_accumulate_batches
        self.scheduler = scheduler
        self.log_dir = log_dir
        self.tb_writer = SummaryWriter(log_dir=log_dir)

        self.checkpoint_dir = checkpoint_dir
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(checkpoint_dir)

        self.train_batch_size, self.test_batch_size = train_batch_size, test_batch_size
        self.train_dataset = dataset
        self.test_split_ratio = test_split_ratio

        self.test_dataset = self.train_dataset.extract_test_dataset(self.test_split_ratio)
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True, collate_fn=self.rect_collate)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=True, collate_fn=self.rect_collate)
        self.step_number = 1

    def run(self, num_steps: int = 10):
        """
        Run function that runs the appropriate train functions based on whether training_mode was set to grasp or cls.
        """
        if self.training_mode == "grasp":
            self.grasp_run(num_steps)
        elif self.training_mode == "cls":
            self.cls_run(num_steps)

    ################## Functions for saving and loading checkpoints ##################
    @staticmethod
    def load_state(load_path: str) -> MapBasedTrainer:
        """
        Returns a MapBasedTrainer object loaded from the checkpoint file located at load_path.
        """
        f = open(load_path, "rb")
        trainer_obj = pickle.load(f)
        trainer_obj.tb_writer = SummaryWriter(log_dir=trainer_obj.log_dir)
        return trainer_obj
    
    def save_state(self, grasp_or_cls: str, iteration: int, save_loss: float, save_acc: float, decimal_place: int = 6):
        """
        Saves the entire MapBasedTrainer class as a serialized pickle object. This saves both model weights
        and training state (iteration number, optimizer state, current lr, etc.).

        File names are generated according to model type (grasp or cls) and test metrics of most recent test step.
        """
        tb_writer = self.tb_writer
        self.tb_writer = None
        save_name = grasp_or_cls + "_Step_" + str(iteration) + "_Acc_" + str(round(save_acc, decimal_place)) + "_Loss_" + str(round(save_loss, decimal_place)) + ".pth"
        save_path = os.path.join(self.checkpoint_dir, save_name)
        with open(save_path, "wb") as outfile:
            pickle.dump(self, outfile, pickle.HIGHEST_PROTOCOL)
        self.tb_writer = tb_writer


    ################## Grasping functions ##################
    def grasp_train_step(self):
        """
        Runs a single grasp train step (trains model on the entire dataset once)
        """
        self.model = self.model.train()
        loop = tqdm(self.train_loader, total=len(self.train_loader), leave=True, position=0)
        loop.set_description(f"Grasp training step {self.step_number}")

        for i, (rgbd_image, target_grasp_maps, _, _) in enumerate(loop):
            rgbd_image, target_grasp_maps = rgbd_image.to(self.device), target_grasp_maps.to(self.device)
            predicted_grasp_maps = self.model(rgbd_image)
            loss = self.loss_fn(predicted_grasp_maps, target_grasp_maps)
            loss.backward()
            if i % self.num_accumulate_batches == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            loop.set_postfix(loss = loss.item())

    def grasp_test_step(self):
        """
        Runs a single grasp test step (computes test metrics for model on entire test dataset)
        """
        self.model = self.model.eval()
        loop = tqdm(self.test_loader, total=len(self.test_loader), leave=True, position=0)
        loop.set_description(f"Grasp test step {self.step_number}")

        total_loss = 0
        num_correct = 0
        with torch.no_grad():
            for rgbd_image, target_grasp_maps, target_grasp_rects, _ in loop:
                rgbd_image, target_grasp_maps = rgbd_image.to(self.device), target_grasp_maps.to(self.device)
                predicted_maps = self.model(rgbd_image)
                loss = self.loss_fn(predicted_maps, target_grasp_maps)
                total_loss += loss.item()

                for grasp_map, target_rect in zip(predicted_maps, target_grasp_rects):
                    conf, cos, sin, width, length = grasp_map
                    num_correct += map_based_iou(conf, cos, sin, width, length, target_rect)

        avg_loss = total_loss / len(self.test_loader)
        avg_acc = num_correct / len(self.test_dataset)
        print(f"Average Loss: {avg_loss} | Accuracy: {avg_acc}")
        return avg_loss, avg_acc

    def grasp_run(self, num_steps: int = 10):
        """
        This function does the following;
            - Runs a grasp train step
            - Runs a grasp test step
            - Saves model weights and training state
            - Updates lr based on self.scheduler
        """
        while self.step_number <= num_steps:
            print("-" * 50)
            self.grasp_train_step()
            test_loss, test_acc = self.grasp_test_step()

            self.tb_writer.add_scalar("Grasp Test loss", test_loss, self.step_number)
            self.tb_writer.add_scalar("Grasp Test Accuracy", test_acc, self.step_number)

            self.save_state("Grasp", self.step_number, test_loss, test_acc)
            self.step_number += 1
            self.set_lr(self.scheduler(self.lr, self.step_number))
    

    ################## Classification functions ##################
    def cls_train_step(self):
        """
        Runs a single cls train step (trains model on the entire dataset once)
        """
        self.model = self.model.train()
        loop = tqdm(self.train_loader, total=len(self.train_loader), leave=True, position=0)
        loop.set_description(f"Cls training step {self.step_number}")

        for i, (rgbd_image, _, _, cls_maps) in enumerate(loop):
            rgbd_image, cls_maps = rgbd_image.to(self.device), cls_maps.to(self.device)
            predicted_cls_maps = self.model(rgbd_image)
            loss = self.loss_fn(predicted_cls_maps, cls_maps)
            loss.backward()
            if i % self.num_accumulate_batches == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            loop.set_postfix(loss = loss.item())

    def cls_test_step(self):
        """
        Runs a single cls test step (computes test metrics on entire test dataset)
        """
        self.model = self.model.eval()
        loop = tqdm(self.test_loader, total=len(self.test_loader), leave=True, position=0)
        loop.set_description(f"Cls test step {self.step_number}")

        total_loss = 0
        num_correct = 0
        with torch.no_grad():
            for rgbd_image, _, _, cls_maps in loop:
                rgbd_image, cls_maps = rgbd_image.to(self.device), cls_maps.to(self.device)
                predicted_cls_maps = self.model(rgbd_image)
                loss = self.loss_fn(predicted_cls_maps, cls_maps)
                total_loss += loss.item()

                predicted_labels = cls_from_map(predicted_cls_maps)
                target_labels = cls_from_map(cls_maps)
                num_correct += (predicted_labels == target_labels).sum().item()

        avg_loss = total_loss / len(self.test_loader)
        avg_acc = num_correct / len(self.test_dataset)
        print(f"Average Loss: {avg_loss} | Accuracy: {avg_acc}")
        return avg_loss, avg_acc
    
    def cls_run(self, num_steps: int = 10):
        """
        This function does the following;
            - Runs a cls train step
            - Runs a cls test step
            - Saves model weights and training state
            - Updates lr based on self.scheduler
        """
        while self.step_number <= num_steps:
            print("-" * 50)
            self.cls_train_step()
            test_loss, test_acc = self.cls_test_step()

            self.tb_writer.add_scalar("Cls Test loss", test_loss, self.step_number)
            self.tb_writer.add_scalar("Cls Test Accuracy", test_acc, self.step_number)

            self.save_state("Cls", self.step_number, test_loss, test_acc)
            self.step_number += 1
            self.set_lr(self.scheduler(self.lr, self.step_number))


    ################## Utility functions ##################
    def set_lr(self, lr: float):
        self.lr = lr
        for p in self.optimizer.param_groups:
            p["lr"] = lr

    def rect_collate(self, batch):
        rgbds = []
        grasp_maps = []
        rects = []
        cls_maps = []
        for a, b, c, d in batch:
            rgbds.append(a.unsqueeze(0))
            grasp_maps.append(b.unsqueeze(0))
            rects.append(c)
            cls_maps.append(d.unsqueeze(0))
        return torch.cat(rgbds, dim=0), torch.cat(grasp_maps, dim=0), rects, torch.cat(cls_maps, dim=0)

# Training

In [20]:
dataset = MapBasedJacquardDataset(
    image_size = 224, 
    precision = torch.float32,
    cache_path = "/Users/gursi/Desktop/jacquard/cache",
    random_augment = True
)

model = GRConvNet4(clip=True)
loss_fn = CrossEntropyMapLoss()
lr = 1e-6
optimizer = torch.optim.Adam

# Scheduler that halves learning rate every 25 iterations
def scheduler(lr, step):
    if (step+1) % 25 == 0:
        return lr/2
    return lr

trainer = MapBasedTrainer(
    training_mode = "cls",
    model = model,
    device = "mps",
    loss_fn = loss_fn,
    dataset = dataset,
    optimizer = optimizer,
    lr = lr,
    train_batch_size = 8,
    test_split_ratio = 0.2,
    checkpoint_dir = "/Users/gursi/Desktop/new_trials",
    log_dir = "logs",
    scheduler = scheduler,
    num_accumulate_batches = 8,
)

In [21]:
trainer.run(100)

--------------------------------------------------


Cls training step 1: 100%|██████████| 111/111 [00:25<00:00,  4.42it/s, loss=0.619]
Cls test step 1: 100%|██████████| 14/14 [00:05<00:00,  2.41it/s]


Average Loss: 0.6246466892106193 | Accuracy: 0.3963963963963964
--------------------------------------------------


Cls training step 2: 100%|██████████| 111/111 [00:21<00:00,  5.13it/s, loss=0.594]
Cls test step 2: 100%|██████████| 14/14 [00:02<00:00,  5.53it/s]


Average Loss: 0.6052048121179853 | Accuracy: 0.3963963963963964
--------------------------------------------------


Cls training step 3: 100%|██████████| 111/111 [00:21<00:00,  5.14it/s, loss=0.665]
Cls test step 3: 100%|██████████| 14/14 [00:02<00:00,  5.65it/s]


Average Loss: 0.587119996547699 | Accuracy: 0.3963963963963964
--------------------------------------------------


Cls training step 4:  47%|████▋     | 52/111 [00:10<00:11,  5.15it/s, loss=0.575]


KeyboardInterrupt: 