In [1]:
import torch

In [109]:
class YoloDetect:
    """
    Post-processing for a YOLO-style grid prediction: decode boxes, filter by
    score, and run per-class non-maximum suppression.

    Parameters (class-level configuration):
        S: number of grid cells along each image side.
        B: number of boxes predicted per cell; each box uses 5 channels
           (x, y, w, h, confidence).
        C: number of classes.
        N: channels per cell, ``B * 5 + C``.
        ProbThreshold: minimum ``confidence * class score`` for a box to be
           kept by ``Decoder``.
        ConfidenceThreshold: not read by any method in this class.
        IoUThreshold: boxes overlapping an already kept box by more than this
           IoU are suppressed in ``NMS``.
    """
    S: int=7
    B: int=2
    C: int=2
    N: int=B * 5 + C
    ProbThreshold: float=0.1
    ConfidenceThreshold: float=0.1
    IoUThreshold: float=0.5

    def Decoder(self, Prediction: torch.Tensor):
        """
        Turn a raw grid prediction into a flat list of scored corner boxes.

        Args:
            Prediction: tensor holding ``S * S`` cells with ``B * 5 + C``
                channels each, e.g. shape ``(S, S, N)`` or ``(S * S, N)``.

        Returns:
            Tensor of shape ``(K, 7)`` with one row per box whose
            ``confidence * class score`` exceeds ``ProbThreshold``:
            ``[xmin, ymin, xmax, ymax, confidence, class score, class index]``.
            ``K`` may be 0.
        """
        assert Prediction.dim() < 4
        S, B = self.S, self.B
        Cells = S * S
        Flat = Prediction.reshape(Cells, Prediction.shape[-1])

        # X is i, Y is j: the row index of a cell offsets x and the column
        # index offsets y.
        # NOTE(review): many YOLO encoders pair x with the column index
        # instead - confirm against the target encoder used for training.
        Steps = torch.arange(S, device=Prediction.device)
        I = Steps.repeat_interleave(S)
        J = Steps.repeat(S)
        Offsets = (torch.stack((I, J), dim=-1) / S).unsqueeze(1)      # (Cells, 1, 2)

        PerBox = Flat[:, :B * 5].reshape(Cells, B, 5)
        # Centres are relative to their cell: rescale to the image and shift
        # by the cell origin.
        Centers = PerBox[..., 0:2] / S + Offsets
        # Width/height are used as predicted.
        # NOTE(review): assumes w/h are encoded relative to the whole image
        # (YOLOv1 convention) - confirm against the target encoder.
        Sizes = PerBox[..., 2:4]

        # One row per box, (xmin, ymin, xmax, ymax); boxes of a cell stay adjacent.
        BBoxes = torch.cat((Centers - 0.5 * Sizes, Centers + 0.5 * Sizes), dim=-1).reshape(-1, 4)
        Confidence = PerBox[..., 4].reshape(-1, 1)

        # Class scores are shared by all B boxes of a cell, so repeat them per box.
        LabelScores, Labels = torch.max(Flat[:, B * 5:], dim=-1)
        LabelScores = LabelScores.repeat_interleave(B).reshape(-1, 1)
        Labels = Labels.repeat_interleave(B).reshape(-1, 1)

        Mask = ((Confidence * LabelScores) > self.ProbThreshold).reshape(-1)
        return torch.cat(
            (BBoxes[Mask], Confidence[Mask], LabelScores[Mask], Labels[Mask].to(BBoxes.dtype)),
            dim=-1,
        )

    def NMS(self, Boxes: torch.Tensor, Scores: torch.Tensor, top_k: int=200):
        """
        Greedy non-maximum suppression.

        Args:
            Boxes: ``(n, 4)`` tensor of ``(x1, y1, x2, y2)`` corners.
            Scores: ``(n,)`` tensor used to rank the boxes.
            top_k: only the ``top_k`` highest scoring boxes are considered.

        Returns:
            ``(keep, count)``: ``keep`` is a long tensor of length ``n`` whose
            first ``count`` entries are the indices of the kept boxes in
            descending score order; the remaining entries are zero padding and
            must be ignored (use ``keep[:count]``).
        """
        count = 0
        keep = torch.zeros(Scores.size(0), dtype=torch.long, device=Scores.device)
        x1 = Boxes[:, 0]
        y1 = Boxes[:, 1]
        x2 = Boxes[:, 2]
        y2 = Boxes[:, 3]
        area = (x2 - x1) * (y2 - y1)
        # Ascending sort: the best remaining candidate is always idx[-1].
        _, idx = Scores.sort(0)
        idx = idx[-top_k:]
        while idx.numel() > 0:
            i = idx[-1]

            keep[count] = i
            count += 1

            if idx.size(0) == 1: break
            idx = idx[:-1]

            # Intersection of every remaining box with the box just kept.
            inter_x1 = torch.clamp(x1[idx], min=x1[i])
            inter_y1 = torch.clamp(y1[idx], min=y1[i])
            inter_x2 = torch.clamp(x2[idx], max=x2[i])
            inter_y2 = torch.clamp(y2[idx], max=y2[i])

            inter_w = torch.clamp(inter_x2 - inter_x1, min=0.0)
            inter_h = torch.clamp(inter_y2 - inter_y1, min=0.0)
            inter = inter_w * inter_h

            union = (area[idx] - inter) + area[i]
            IoU = inter / union

            # Drop candidates that overlap the kept box too much.
            idx = idx[IoU.le(self.IoUThreshold)]

        return keep, count

    def Detect(self, Prediction: torch.Tensor):
        """
        Decode a prediction and run NMS separately for every class.

        Args:
            Prediction: raw grid prediction, see ``Decoder``.

        Returns:
            Dict mapping class index to a ``(k, 5)`` tensor of
            ``[xmin, ymin, xmax, ymax, confidence * class score]``. Classes
            without any detection are absent from the dict.
        """
        Decoded = self.Decoder(Prediction)
        Result = {}
        for ClassesLabel in range(self.C):
            Mask = Decoded[:, -1] == ClassesLabel
            if not Mask.any(): continue
            Candidates = Decoded[Mask]
            Boxes = Candidates[:, :4]
            Confs = Candidates[:, 4]
            LabelScores = Candidates[:, 5]
            keep, count = self.NMS(Boxes, Confs)
            # Only the first `count` entries of `keep` are valid indices.
            keep = keep[:count]
            Probs = (Confs[keep] * LabelScores[keep]).reshape(-1, 1)
            Result[ClassesLabel] = torch.cat((Boxes[keep], Probs), dim=-1)
        return Result

    def AP(self):
        """Placeholder for average precision; not implemented, returns None."""
        return

    def MeanAP(self):
        """Placeholder for mean average precision; not implemented, returns None."""
        return

In [112]:
# Smoke run: push a seeded random 7x7 grid with 2 * 5 + 2 channels, squashed
# by a sigmoid, through the detection pipeline.
torch.manual_seed(1)
ret = YoloDetect().Detect(torch.rand((7, 7, 2 * 5 + 2)).sigmoid())