In [1]:
%load_ext autoreload
%autoreload 2
import os
import torch
import numpy as np
import torch.utils.data
from PIL import Image
 
 
class PlantDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        
        total = list(sorted(os.listdir(os.path.join(root))))
        self.imgs, self.masks = [], []
        for file in total:
            #print(file)
            if 'rgb' in file:
                self.imgs.append(file)
                label = file.replace('rgb', 'label_mask')
                self.masks.append(label)
        #print(self.masks)
        return
    
    def __getitem__(self, idx):
        # load images ad masks
        img_path = os.path.join(self.root, self.imgs[idx])
        #print(img_path)
        mask_path = os.path.join(self.root, self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance with 0 being background
        mask = Image.open(mask_path)
        mask = np.array(mask)
        #print(mask.shape)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        
        # split the color-encoded mask into a set of binary masks
        masks = np.array(mask == obj_ids[:, None, None])
       
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):

            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            if xmin > xmax:
                print(1)
            if ymin > ymax:
                print(1)
            boxes.append([xmin, ymin, xmax, ymax])
 
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
 
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
 
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
 
        if self.transforms is not None:
            img, target = self.transforms(img, target)
 
        return img, target
 
    def __len__(self):
        return len(self.imgs)

In [2]:
class FullDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        list1 = ['LSC_2017/A1', 'LSC_2017/A2', 'LSC_2017/A3', 'LSC_2017/A4']
        for root in list1:
            total = list(sorted(os.listdir(os.path.join(root))))
            self.imgs, self.masks = [], []
            for file in total:
                #print(file)
                if 'rgb' in file:
                    p = os.path.join(root, file)
                    self.imgs.append(p)
                    label = p.replace('rgb', 'label_mask')
                    self.masks.append(label)
        #print(self.masks)
        return
    
    def __getitem__(self, idx):
        # load images ad masks
        img_path = os.path.join(self.root, self.imgs[idx])
        #print(img_path)
        mask_path = os.path.join(self.root, self.masks[idx])
        img = Image.open(self.imgs[idx]).convert("RGB")
        
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance with 0 being background
        mask = Image.open(self.masks[idx])
        mask = np.array(mask)
        #print(mask.shape)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        
        # split the color-encoded mask into a set of binary masks
        masks = np.array(mask == obj_ids[:, None, None])
       
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):

            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            if xmin > xmax:
                print(1)
            if ymin > ymax:
                print(1)
            boxes.append([xmin, ymin, xmax, ymax])
 
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
 
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
 
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
 
        if self.transforms is not None:
            img, target = self.transforms(img, target)
 
        return img, target
 
    def __len__(self):
        return len(self.imgs)

In [3]:
import utils
import transforms as T
from torchvision import transforms

from engine import train_one_epoch, evaluate

def get_transform(train):
    trans = []
    # converts the image, a PIL image, into a PyTorch Tensor
    trans.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        trans.append(T.RandomHorizontalFlip(0.5))
        #trans.append(transforms.RandomRotation(90,resample=False,expand=False,center=None))
 
    return T.Compose(trans)

In [4]:
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
 
# load a pre-trained model for classification and return only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
 
# FasterRCNN needs to know the number of output channels in a backbone. 
# For mobilenet_v2, it's 1280. So we need to add it here
backbone.out_channels = 1280
 
# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and aspect ratios 
anchor_generator = AnchorGenerator(sizes=((16, 32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
 
# let's define what are the feature maps that we will use to perform the region of 
# interest cropping, as well as the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an OrderedDict[Tensor], 
# and in featmap_names you can choose which feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                output_size=7,
                                                sampling_ratio=2)
 
# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)

In [5]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
 
      
def get_instance_segmentation_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
 
    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
 
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
 
    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
 
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
 
    return model

In [6]:
# use the PennFudan dataset and defined transformations
dataset = PlantDataset('LSC_2017/A2', get_transform(train=True))
dataset_test = PlantDataset('LSC_2017/A1', get_transform(train=False))
 
# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()

dataset = torch.utils.data.Subset(dataset, list(range(len(dataset))))
dataset_test = torch.utils.data.Subset(dataset_test, indices[:])
 
# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0,
    collate_fn=utils.collate_fn)
 
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 
# the dataset has two classes only - background and person
num_classes = 2
 
# get the model using the helper function
model = get_instance_segmentation_model(num_classes)
# move model to the right device
model.to(device)
 
# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
#optimizer = torch.optim.SGD(params, lr=0.001,
#                            momentum=0.9, weight_decay=0.0005)
 
# the learning rate scheduler decreases the learning rate by 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

In [14]:
# training

num_epochs = 30
losses = []
evals = []
for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    
    loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # update the learning rate
    lr_scheduler.step()
    
    losses.append(loss)
    # evaluate on the test dataset
    
    x = evaluate(model, data_loader_test, device=device)
    evals.append(x)

Epoch: [0]  [ 0/16]  eta: 0:00:13  lr: 0.000068  loss: 4.8927 (4.8927)  loss_classifier: 0.5674 (0.5674)  loss_box_reg: 0.5784 (0.5784)  loss_mask: 2.5959 (2.5959)  loss_objectness: 1.0679 (1.0679)  loss_rpn_box_reg: 0.0831 (0.0831)  time: 0.8549  data: 0.1287  max mem: 4003


KeyboardInterrupt: 

In [86]:
print(x)

<coco_eval.CocoEvaluator object at 0x0000029635008B88>


In [115]:
print(losses)
all_loss = []
cls_loss = []
bbx_loss = []
mask_loss = []
for loss in losses:
    print(loss)
    all_loss.append(loss.loss.value)
    cls_loss.append(loss.loss_classifier.value)
    bbx_loss.append(loss.loss_box_reg.value)
    mask_loss.append(loss.loss_mask.value)

print(evals[0].summarize())

[<utils.MetricLogger object at 0x000002961BB6B788>, <utils.MetricLogger object at 0x00000296528FCF48>, <utils.MetricLogger object at 0x00000294F9BBFF48>, <utils.MetricLogger object at 0x00000296194DC848>, <utils.MetricLogger object at 0x0000029635404B08>, <utils.MetricLogger object at 0x0000029619F74C88>, <utils.MetricLogger object at 0x000002963E7FD648>, <utils.MetricLogger object at 0x00000296319D6248>, <utils.MetricLogger object at 0x000002963190ECC8>, <utils.MetricLogger object at 0x00000296194A11C8>, <utils.MetricLogger object at 0x0000029631126E48>, <utils.MetricLogger object at 0x000002963F9845C8>, <utils.MetricLogger object at 0x0000029633368708>, <utils.MetricLogger object at 0x0000029634E34A48>, <utils.MetricLogger object at 0x0000029631A54948>, <utils.MetricLogger object at 0x000002963E074888>, <utils.MetricLogger object at 0x000002963E09FB08>, <utils.MetricLogger object at 0x00000296407BD448>, <utils.MetricLogger object at 0x000002963E010D48>, <utils.MetricLogger object at 

In [123]:
for e in evals:
    evals[0].summarize()

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.427
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.886
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.358
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.341
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.461
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.415
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.039
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.347
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.532
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.447
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.578
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.474
IoU metric: segm
 Avera

In [121]:
print(mask_loss)

[0.21040065586566925, 0.18052609264850616, 0.20139296352863312, 0.15309494733810425, 0.12295417487621307, 0.1492757350206375, 0.13650934398174286, 0.1570550501346588, 0.15765519440174103, 0.22076646983623505, 0.16934970021247864, 0.1563444435596466, 0.15594248473644257, 0.21010935306549072, 0.14290395379066467, 0.13467177748680115, 0.12209370732307434, 0.14645618200302124, 0.13524404168128967, 0.13391563296318054, 0.1447090059518814, 0.15345405042171478, 0.14448238909244537, 0.13554979860782623, 0.14319033920764923, 0.17906273901462555, 0.2140408754348755, 0.1639394462108612, 0.17139047384262085, 0.1544235348701477]


In [118]:
import numpy as np
import matplotlib.pyplot as plt

matplotlib.use('TkAgg')
x = range(1,31)
fig = plt.figure()                     


plt.plot(x,all_loss,label="Total loss")
plt.plot(x,cls_loss, label = "Classifer loss")
plt.plot(x,bbx_loss, label = "Bbox regressor loss")
plt.plot(x,mask_loss, label = "Mask loss")
plt.legend(loc='best')
plt.show()

In [124]:
torch.save(model.state_dict(), 'model_save_A1_train')
torch.cuda.empty_cache()

In [8]:
model.load_state_dict(torch.load( 'model_save_A5_train'))
model.eval()

MaskRCNN(
  (transform): GeneralizedRCNNTransform()
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d()
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d()
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d()
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d()
          )
    

In [9]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2


def select_top_predictions(predictions, threshold):
    idx = (predictions["scores"] > threshold).nonzero().squeeze(1)
    new_predictions = {}
    for k, v in predictions.items():
        new_predictions[k] = v[idx]
    return new_predictions


def compute_colors_for_labels(labels, palette=None):
    """
    Simple function that adds fixed colors depending on the class
    """
    if palette is None:
        palette = torch.tensor([123, 86, 212])

    colors = labels[:, None] * palette
    #print(colors)
    colors = (colors % 255).numpy().astype("uint8")
    colors = colors / 255.
    return colors

def overlay_boxes(image, predictions):
    """
    Adds the predicted boxes on top of the image
    Arguments:
        image (np.ndarray): an image as returned by OpenCV
        predictions (BoxList): the result of the computation by the model.
            It should contain the field `labels`.
    """
    labels = predictions["labels"]
    boxes = predictions['boxes']
    n = len(labels)
    
    colors = compute_colors_for_labels(torch.tensor(range(1,n+1))).tolist()

    for box, color in zip(boxes, colors):
        box = box.to(torch.int64)
        top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
        image = cv2.rectangle(
            image, tuple(top_left), tuple(bottom_right), tuple(color), 1
        )

    return image

import random
def overlay_mask(image, predictions):
    """
    Adds the instances contours for each predicted object.
    Each label has a different color.
    Arguments:
        image (np.ndarray): an image as returned by OpenCV
        predictions (BoxList): the result of the computation by the model.
            It should contain the field `mask` and `labels`.
    """
    masks = predictions["masks"].ge(0.5).mul(255).byte().numpy()
    labels = predictions["labels"]

    n = len(labels)
    print(n)
    colors = compute_colors_for_labels(torch.tensor(range(1,n+1))).tolist()
    

    for mask, color in zip(masks, colors):
        
        r = random.randint(0, 255)/ 255.
        g = random.randint(0, 255)/ 255.
        b = random.randint(0, 255)/ 255.
        color = [r, g, b] 
        thresh = mask[0, :, :, None]
        #thresh=cv2.cvtColor(thresh,cv2.COLOR_GRAY2RGB) 
        contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, color, -1)
    composite = image

    return composite


def overlay_keypoints(image, predictions):
    kps = predictions["keypoints"]
    scores = predictions["keypoints_scores"]
    kps = torch.cat((kps[:, :, 0:2], scores[:, :, None]), dim=2).numpy()
    for region in kps:
        image = vis_keypoints(image, region.transpose((1, 0)))
    return image

def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7):
    """Visualizes keypoints (adapted from vis_one_image).
    kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob).
    """
    dataset_keypoints = PersonKeypoints.NAMES
    kp_lines = PersonKeypoints.CONNECTIONS

    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
    colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw mid shoulder / mid hip first for better visualization.
    mid_shoulder = (
        kps[:2, dataset_keypoints.index('right_shoulder')] +
        kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
    sc_mid_shoulder = np.minimum(
        kps[2, dataset_keypoints.index('right_shoulder')],
        kps[2, dataset_keypoints.index('left_shoulder')])
    mid_hip = (
        kps[:2, dataset_keypoints.index('right_hip')] +
        kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
    sc_mid_hip = np.minimum(
        kps[2, dataset_keypoints.index('right_hip')],
        kps[2, dataset_keypoints.index('left_hip')])
    nose_idx = dataset_keypoints.index('nose')
    if sc_mid_shoulder > kp_thresh and kps[2, nose_idx] > kp_thresh:
        cv2.line(
            kp_mask, tuple(mid_shoulder), tuple(kps[:2, nose_idx]),
            color=colors[len(kp_lines)], thickness=2, lineType=cv2.LINE_AA)
    if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
        cv2.line(
            kp_mask, tuple(mid_shoulder), tuple(mid_hip),
            color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA)

    # Draw the keypoints.
    for l in range(len(kp_lines)):
        i1 = kp_lines[l][0]
        i2 = kp_lines[l][1]
        p1 = kps[0, i1], kps[1, i1]
        p2 = kps[0, i2], kps[1, i2]
        if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
            cv2.line(
                kp_mask, p1, p2,
                color=colors[l], thickness=2, lineType=cv2.LINE_AA)
        if kps[2, i1] > kp_thresh:
            cv2.circle(
                kp_mask, p1,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
        if kps[2, i2] > kp_thresh:
            cv2.circle(
                kp_mask, p2,
                radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)

def kp_connections(keypoints):
    kp_lines = [
        [keypoints.index('left_eye'), keypoints.index('right_eye')],
        [keypoints.index('left_eye'), keypoints.index('nose')],
        [keypoints.index('right_eye'), keypoints.index('nose')],
        [keypoints.index('right_eye'), keypoints.index('right_ear')],
        [keypoints.index('left_eye'), keypoints.index('left_ear')],
        [keypoints.index('right_shoulder'), keypoints.index('right_elbow')],
        [keypoints.index('right_elbow'), keypoints.index('right_wrist')],
        [keypoints.index('left_shoulder'), keypoints.index('left_elbow')],
        [keypoints.index('left_elbow'), keypoints.index('left_wrist')],
        [keypoints.index('right_hip'), keypoints.index('right_knee')],
        [keypoints.index('right_knee'), keypoints.index('right_ankle')],
        [keypoints.index('left_hip'), keypoints.index('left_knee')],
        [keypoints.index('left_knee'), keypoints.index('left_ankle')],
        [keypoints.index('right_shoulder'), keypoints.index('left_shoulder')],
        [keypoints.index('right_hip'), keypoints.index('left_hip')],
    ]
    return kp_lines

#PersonKeypoints.CONNECTIONS = kp_connections(PersonKeypoints.NAMES)


def overlay_class_names(image, predictions):
    """
    Adds detected class names and scores in the positions defined by the
    top-left corner of the predicted bounding box
    Arguments:
        image (np.ndarray): an image as returned by OpenCV
        predictions (BoxList): the result of the computation by the model.
            It should contain the field `scores` and `labels`.
    """
    scores = predictions["scores"].tolist()
    labels = predictions["labels"].tolist()
    labels = [CATEGORIES[i] for i in labels]
    boxes = predictions['boxes']

    template = "{}: {:.2f}"
    for box, score, label in zip(boxes, scores, labels):
        x, y = box[:2]
        s = template.format(label, score)
        cv2.putText(
            image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
        )

    return image

import visual__ as vs
def predict(img, model, device):
    cv_img = np.array(img)
    
    #img_tensor = torchvision.transforms.ToPILImage(img)
    with torch.no_grad():
        model.eval()
        output = model([img.to(device)])
    top_predictions = select_top_predictions(output[0], 0.7)
    top_predictions = {k:v.cpu() for k, v in top_predictions.items()}

    result = cv_img.copy()
    result = result.transpose(1,2,0)

    result=cv2.cvtColor(result,cv2.COLOR_RGB2BGR) 

    #result = overlay_boxes(result, top_predictions)
    
    if 'masks' in top_predictions:
        result = overlay_mask(result, top_predictions)
        
    #if 'keypoints' in top_predictions:
    #    result = overlay_keypoints(result, top_predictions)
    #result = overlay_class_names(result, top_predictions)
    return result, output, top_predictions



In [11]:
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        total = list(sorted(os.listdir(os.path.join(root))))
        self.imgs = []
        for file in total:
            #print(file)
            if 'filted' in file:
                self.imgs.append(file)
        #print(self.masks)
        return
    
    def __getitem__(self, idx):
        # load images ad masks
        img_path = os.path.join(self.root, self.imgs[idx])
        #print(img_path)

        mask_path = os.path.join('LSC_2017/A2', 'plant001_label_mask.png')
        img = Image.open(img_path).convert("RGB")
        
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance with 0 being background
        mask = Image.open(mask_path)
        mask = np.array(mask)
        #print(mask.shape)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        
        # split the color-encoded mask into a set of binary masks
        masks = np.array(mask == obj_ids[:, None, None])
       
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):

            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            if xmin > xmax:
                print(1)
            if ymin > ymax:
                print(1)
            boxes.append([xmin, ymin, xmax, ymax])
 
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
 
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
 
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
 
        if self.transforms is not None:
            img, target = self.transforms(img, target)
 
        return img, target
 
    def __len__(self):
        return len(self.imgs)
    

In [12]:
#dataset_init = PennFudanDataset('LSC_2017/test/A4', get_transform(train=False))

dataset_init = TestDataset('LSC_2017/A1train', get_transform(train=False))
#dataset_init = PennFudanDataset('LSC_2017/A3train', get_transform(train=False))

test_len = 0

# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset_init)).tolist()
print(len(indices))
dataset = torch.utils.data.Subset(dataset_init, indices[:])
dataset_test = torch.utils.data.Subset(dataset_init, indices[-test_len:])


# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)
'''
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)
''' 

def batch_predict(dataset, ids, dataset_init):
    for img, id in zip(dataset, ids):
        torch.cuda.empty_cache()
        img = img[0]
        #print(mask)
        #img, _ = dataset_init[id]
        #img = torch.tensor(np.array(img).transpose(2,1,0))
        print(img.shape)
        name = dataset_init.imgs[id]
        
        print(name)
        #cv2.imwrite('LSC_2017/A1train/'+name.replace('rgb','fore'), np.array(img))
        result, output, top_predictions = predict(img, model, device)
        #cv.imwrite('LSC_2017/test/demo/A4/'+name, res))
        masks = top_predictions['masks']
        res = masks[0][0].ge(0.5).mul(255).byte().numpy()
        
        
        result = result*255.
        print(result[123][123])
        print(result.shape)
        #cv2.imshow('',result)
        #cv2.waitKey(0)
        cv2.imwrite('LSC_2017/test/demo/3d/'+name,result)
        #new_im = Image.fromarray(result)
        
        #new_im.save('LSC_2017/test/demo/A4/'+name)
    
        #Image.('LSC_2017/test/demo/A4/'+name, np.array(result)))

        index = 2
        for i in range(res.shape[0]):
            for j in range(res.shape[1]):
                if res[i][j] != 0:
                    res[i][j] = 1
        for x in range(1, len(masks)):
            mask = masks[x].ge(0.5).mul(255).byte().numpy()
            for i in range(mask[0].shape[0]):
                for j in range(mask[0].shape[1]):
                    if mask[0][i][j] == 1:
                        #print(mask[0][i][j])
                        res[i][j] = index
                        #print(res[i][j])
            index += 1
        
        name = name.replace('rgb', 'label')
        #print(name, res)
        
        #print(cv2.imwrite('LSC_2017/test/result/A4/'+name, res))
        
batch_predict(dataset, indices[:], dataset_init)

108
torch.Size([3, 530, 500])
plant127_filted.png
19
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant129_filted.png
16
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant099_filted.png
20
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant060_filted.png
17
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant113_filted.png
12
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant065_filted.png
17
[ 90. 216. 107.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant067_filted.png
12
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant032_filted.png
15
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant100_filted.png
14
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant013_filted.png
11
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant080_filted.png
19
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant116_filted.png
17
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
plant130_filted.png
16
[0. 0. 0.]
(530, 500, 3)
torch.Size([3, 530, 500])
pl