In [1]:
import sys

import torch
import torch.nn as nn

import numpy as np

sys.path.append('../../../')

%load_ext autoreload
%autoreload 2

from video_processing.yolov7.utils.general import make_grid
from video_processing.yolov7.models.common import Conv, autopad, ImplicitA, ImplicitM

In [4]:
class IAuxDetect(nn.Module):
    '''YOLOv7 detection head with implicit knowledge and an auxiliary branch.

    Each of the ``nl`` detection levels has
      * a lead head:      ImplicitA (``ia``) -> 1x1 conv (``m``) -> ImplicitM (``im``)
      * an auxiliary head: 1x1 conv (``m2``)
    Both heads emit ``na * no`` channels that are reshaped to BxAxHxWxO.
    Mirrors ``IAuxDetect`` in https://github.com/WongKinYiu/yolov7/blob/main/models/yolo.py
    '''
    # class variables shared across all instances
    stride=None # per-level strides, compute during build; must be set before inference (forward indexes it)
    export = False # onnx export; forward() ORs this into self.training, forcing the raw output
    end2end=False # not used by this class
    include_nms=False # not used by this class
    concat=False # not used by this class
    def __init__(self, nc=80, anchors=(), ch=()):# detection layer
        '''
        Args:
            nc (int): number of object classes
            anchors (list[list[int]]): one entry per detection level, each a flat list of
                (width, height) anchor pairs, e.g. 3 pairs per level for small, medium,
                and large bounding boxes:
                [[19,27,  44,40,  38,94],
                 [96,68,  86,152,  180,137],
                 [140,301,  303,264,  238,542],
                 [436,615,  739,380,  925,792]]
            ch (list[int]): 2*nl input channels; the first nl are the inputs of the lead
                heads (m) per level and the remaining nl those of the auxiliary heads (m2), e.g.
                [256, 512, 768, 1024, 320, 640, 960, 1280]
        '''
        super(IAuxDetect, self).__init__()
        self.nc=nc # number of classes
        self.no=nc+5 # number of output per anchors (4 box values + objectness + nc class scores)
        self.nl=len(anchors) # number of detection layers
        self.na=len(anchors[0])//2 # number of anchors
        self.grid=[torch.zeros(1)]*self.nl # init grid; placeholder replaced lazily in forward()
        a=torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a) # shape nl, na, 2 
        # nl 1 na 1 1 2 -> broadcasts against BxAxHxWx2 once indexed by level
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))
        print(f'In IAxDetect nl: {self.nl} na: {self.na}')
        print(f'In IAxDetect anchors: {self.anchors.shape} {self.nl}x{self.na}x{2}')
        print(f'In IAxDetect anchor_grid: {self.anchor_grid.shape} {self.nl}x1x{self.na}x1x1x{2}')
        self.m=nn.ModuleList(nn.Conv2d(x, self.no*self.na, 1) for x in ch[:self.nl]) # output conv
        self.m2=nn.ModuleList(nn.Conv2d(x, self.no*self.na, 1) for x in ch[self.nl:]) # output conv

        self.ia=nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
        self.im=nn.ModuleList(ImplicitM(self.no*self.na) for _ in ch[:self.nl])
    def forward(self, x, verbose=False):
        '''Run lead and auxiliary heads on every level; in inference also decode boxes.

        Args:
            x (list[Tensor]): 2*nl feature maps. x[:nl] feed the lead heads and x[nl:]
                the auxiliary heads; x[i] and x[i+nl] must share the same HxW because the
                auxiliary output is reshaped with the lead level's ny, nx.
                The list is modified in place: every entry is replaced by its
                BxAxHxWxO head output.
            verbose (bool): print intermediate shapes and values for debugging.

        Returns:
            Training (or export) mode: the list x with 2*nl tensors of shape BxAxHxWxO.
            Inference mode: a tuple (pred, x[:nl]) where pred is Bx(sum_i A*H_i*W_i)xO,
            built from the lead levels only, with xy scaled by ``stride`` and wh by
            ``anchor_grid``.
        '''
        ## see https://github.com/WongKinYiu/yolov7/blob/main/models/yolo.py#L116
        z=[]
        # export forces the raw output; note this assignment persists on self.training
        self.training|=self.export
        for i in range(self.nl):
            if verbose:
                print(i, '-'*100)
                # use this head's own module, not a notebook global
                print('\tx[i] ', x[i].shape,  ' detector.m[i] ', self.m[i])
            x[i]=self.m[i](self.ia[i](x[i]))
            x[i]=self.im[i](x[i])
            if verbose: print('\tx[i] ', x[i].shape)
            bs,_,ny,nx=x[i].shape
            # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
            x[i]=x[i].view(bs, self.na, self.no, ny, nx).permute(0,1,3,4,2).contiguous() 
            if verbose:print('\tx[i] ', x[i].shape,'\n\ti+nl ', i+self.nl)
        
            if verbose: print('\tx[i+self.nl] ', x[i+self.nl].shape,  ' self.m2[i] ', self.m2[i])
            x[i+self.nl]=self.m2[i](x[i+self.nl])
            if verbose: print('\tx[i+self.nl] ', x[i+self.nl].shape)
            # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
            x[i+self.nl]=x[i+self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if verbose: print('\tx[i+self.nl] ', x[i+self.nl].shape)
            if not self.training: # inference
                if verbose:
                    print('\tself.grid[i].shape ', self.grid[i].shape)
                    print('\tself.grid[i].shape[2:4] ', self.grid[i].shape[2:4], ' x[i].shape[2:4] ', x[i].shape[2:4])
                # rebuild the cached feature map grid when the spatial size OR the device
                # changed (e.g. the model was moved with .to() after an earlier call)
                if self.grid[i].shape[2:4]!=x[i].shape[2:4] or self.grid[i].device!=x[i].device:
                    self.grid[i]=make_grid(nx=nx, ny=ny).to(x[i].device) # 1x1xHxWx2
                if verbose: print('\tself.grid[i].shape ', self.grid[i].shape )
                y=x[i].sigmoid()
                if verbose: 
                    print('\tx[i] ', x[i].min().item(), x[i].max().item())
                    print('\ty ', y.shape, y.min().item(), y.max().item() )
                    print('\tstride ', self.stride[i])
                if not torch.onnx.is_in_onnx_export():
                    # xy coordinates of bounding boxes
                    #               BxAxHxWx2        1x1xHxWx2  
                    y[...,0:2]=(2.*y[...,0:2] -0.5 + self.grid[i])*self.stride[i]
                    # width and height of bounding boxes
                    #              BxAxHxWx2             1xAx1x1x2
                    y[...,2:4]=( (2.*y[...,2:4])**2. ) * self.anchor_grid[i]
                else:
                    # split the last dim of BxAxHxWxO into BxAxHxWx2 of xy, BxAxHxWx2 of width/height,
                    # and BxAxHxWx(Nc+1) of objectness plus Nc class scores (no in-place writes for ONNX)
                    xy,wh,confidence=y.split((2,2,self.nc+1), dim=4)
                    if verbose: print('xy ', xy.shape, ' wh ', wh.shape, ' confidence ', confidence.shape)
                    xy=xy*(2.* self.stride[i])+(self.stride[i]*(self.grid[i]-0.5))
                    wh=wh**2. * (4.*self.anchor_grid[i].data)
                    y=torch.cat([xy,wh,confidence], dim=4)
                    
                z.append(y.view(bs, -1, self.no)) # BxAxHxWxO -> Bx(AHW)xO
                
        return x if self.training else (torch.cat(z, 1), x[:self.nl])
        

In [5]:
# Build the 4-level head (3 anchors per level) and run one inference pass on random features.
torch.manual_seed(0)  # reproducible dummy inputs under Restart & Run All
nc=80
anchors=[[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]]
# first 4 channels feed the lead heads (m), the last 4 the auxiliary heads (m2)
ch=[256, 512, 768, 1024, 320, 640, 960, 1280]
detector=IAuxDetect(nc=nc, anchors=anchors, ch=ch)
# IAuxDetect.stride is None by default and forward() indexes it in eval mode
detector.stride=torch.tensor([ 8., 16., 32., 64.], dtype=torch.float32)

detector.eval()
# dummy features (BxCxHxW): 4 lead levels followed by the matching 4 auxiliary levels
feature_sizes=[(2, 256, 160, 160), (2, 512, 80, 80), (2, 768, 40, 40), (2, 1024, 20, 20),
               (2, 320, 160, 160), (2, 640, 80, 80), (2, 960, 40, 40), (2, 1280, 20, 20)]
x=[torch.rand(size) for size in feature_sizes]
print('len(x) ', len(x), ' nl ', detector.nl, ' na ',  detector.na, ' no ',  detector.no)
output=detector(x)
out0, out1=output
print(type(out0), type(out1))
# predictions are flattened over anchors and all lead-level cells: Bx(A*sum(H*W))xO
n_cells=sum(h*w for _, _, h, w in feature_sizes[:detector.nl])
assert out0.shape==(2, detector.na*n_cells, detector.no)
assert [tuple(t.shape) for t in out1]==[(2, detector.na, h, w, detector.no) for _, _, h, w in feature_sizes[:detector.nl]]
out0.shape,[i.shape for i in out1]

In IAxDetect nl: 4 na: 3
In IAxDetect anchors: torch.Size([4, 3, 2]) 4x3x2
In IAxDetect anchor_grid: torch.Size([4, 1, 3, 1, 1, 2]) 4x1x3x1x1x2
len(x)  8  nl  4  na  3  no  85


In [15]:
# Training mode: forward() returns the raw per-level outputs (4 lead levels, then 4 auxiliary)
detector.train()
input_shapes=[(2, 256, 160, 160), (2, 512, 80, 80), (2, 768, 40, 40), (2, 1024, 20, 20),
              (2, 320, 160, 160), (2, 640, 80, 80), (2, 960, 40, 40), (2, 1280, 20, 20)]
x=list(map(torch.rand, input_shapes))
print('len(x) ', len(x),
      ' nl ', detector.nl,
      ' na ', detector.na,
      ' no ', detector.no)
output=detector(x)
[tensor.shape for tensor in output]

len(x)  8  nl  4  na  3  no  85


[torch.Size([2, 3, 160, 160, 85]),
 torch.Size([2, 3, 80, 80, 85]),
 torch.Size([2, 3, 40, 40, 85]),
 torch.Size([2, 3, 20, 20, 85]),
 torch.Size([2, 3, 160, 160, 85]),
 torch.Size([2, 3, 80, 80, 85]),
 torch.Size([2, 3, 40, 40, 85]),
 torch.Size([2, 3, 20, 20, 85])]

In [7]:
# Step through IAuxDetect.forward by hand for level 0 only (see the `break` at the end),
# printing every intermediate shape. Unlike forward(), this cell does not apply
# `training |= export`, and it leaves i, bs, ny, nx, y and z in the namespace for the cells below.
detector.eval()
x=[torch.rand(size) for size in [torch.Size([2, 256, 160, 160]), torch.Size([2, 512, 80, 80]), torch.Size([2, 768, 40, 40]), 
                                 torch.Size([2, 1024, 20, 20]), torch.Size([2, 320, 160, 160]), torch.Size([2, 640, 80, 80]), 
                                 torch.Size([2, 960, 40, 40]), torch.Size([2, 1280, 20, 20])] ]

z=[]
print(detector.training|detector.export)
for i in range(detector.nl):
    print(i, '-'*100)
    print('\tx[i] ', x[i].shape,  ' detector.m[i] ', detector.m[i])
    # lead head: ImplicitA -> 1x1 conv -> ImplicitM
    x[i]=detector.m[i](detector.ia[i](x[i]))
    x[i]=detector.im[i](x[i])
    print('\tx[i] ', x[i].shape)
    bs,_,ny,nx=x[i].shape
    # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
    x[i]=x[i].view(bs, detector.na, detector.no, ny, nx).permute(0,1,3,4,2).contiguous() 
    print('\tx[i] ', x[i].shape)
    print('\ti+nl ', i+detector.nl)

    # auxiliary head for the same level: plain 1x1 conv on x[i+nl]
    print('\tx[i+detector.nl] ', x[i+detector.nl].shape,  ' detector.m2[i] ', detector.m2[i])
    x[i+detector.nl]=detector.m2[i](x[i+detector.nl])
    print('\tx[i+detector.nl] ', x[i+detector.nl].shape)
    # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
    x[i+detector.nl]=x[i+detector.nl].view(bs, detector.na, detector.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
    print('\tx[i+detector.nl] ', x[i+detector.nl].shape)
    if not detector.training: # inference
        print('\tdetector.grid[i].shape ', detector.grid[i].shape)
        print('\tdetector.grid[i].shape[2:4] ', detector.grid[i].shape[2:4], ' x[i].shape[2:4] ', x[i].shape[2:4])
        if detector.grid[i].shape[2:4]!=x[i].shape[2:4]:
            # feature map grid
            detector.grid[i]=make_grid(nx=nx, ny=ny).to(x[i].device) # 1x1xHxWx2
        print('\tdetector.grid[i].shape ', detector.grid[i].shape )
        y=x[i].sigmoid()
        print('\tx[i] ', x[i].min().item(), x[i].max().item())
        print('\ty ', y.shape, y.min().item(), y.max().item() )
        print('\tstride ', detector.stride[i])
        if not torch.onnx.is_in_onnx_export():
            # xy coordinates of bounding boxes
            #               BxAxHxWx2        1x1xHxWx2  
            y[...,0:2]=(2.*y[...,0:2] -0.5 + detector.grid[i])*detector.stride[i]
            # width and height of bounding boxes
            #              BxAxHxWx2             1xAx1x1x2
            # anchor_grid is nl x 1 x na x 1 x 1 x 2: it must be indexed by level, otherwise the
            # product gains a leading nl dim and cannot be assigned back into y
            y[...,2:4]=( (2.*y[...,2:4])**2. ) * detector.anchor_grid[i]
        else:
            # split the last dim of BxAxHxWxO into BxAxHxWx2 of xy, BxAxHxWx2 of width/height, and BxAxHxWx(Nc+1) of objectness and Nc for number of classes
            xy,wh,confidence=y.split((2,2,detector.nc+1), dim=4)
            print('xy ', xy.shape, ' wh ', wh.shape, ' confidence ', confidence.shape)
            xy=xy*(2.* detector.stride[i])+(detector.stride[i]*(detector.grid[i]-0.5))
            wh=wh**2. * (4.*detector.anchor_grid[i].data)
            y=torch.cat([xy,wh,confidence], dim=4)
        z.append(y.view(bs, -1, detector.no)) # BxAxHxWxO -> Bx(AHW)xO
    break  # inspect only the first level; remove to walk all levels
print([v.shape for v in x])

len(x)  8  nl  4  na  3  no  85
False
0 ----------------------------------------------------------------------------------------------------
	x[i]  torch.Size([2, 256, 160, 160])  detector.m[i]  Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1))
	x[i]  torch.Size([2, 255, 160, 160])
	x[i]  torch.Size([2, 3, 160, 160, 85])
	i+nl  4
	x[i+detector.nl]  torch.Size([2, 320, 160, 160])  detector.m2[i]  Conv2d(320, 255, kernel_size=(1, 1), stride=(1, 1))
	x[i+detector.nl]  torch.Size([2, 255, 160, 160])
	x[i+detector.nl]  torch.Size([2, 3, 160, 160, 85])
	detector.grid[i].shape  torch.Size([1])
	detector.grid[i].shape[2:4]  torch.Size([])  x[i].shape[2:4]  torch.Size([160, 160])
	detector.grid[i].shape  torch.Size([1, 1, 160, 160, 2])
	x[i]  -1.5335314273834229 1.637326955795288
	y  torch.Size([2, 3, 160, 160, 85]) 0.17747758328914642 0.8371708989143372
	stride  tensor(8.)
[torch.Size([2, 3, 160, 160, 85]), torch.Size([2, 512, 80, 80]), torch.Size([2, 768, 40, 40]), torch.Size([2, 1024, 20, 

In [8]:
# stride of level 0 (the level stepped through above); index explicitly instead of
# relying on the loop variable `i` leaking out of that cell
detector.stride[0]

tensor(8.)

False

In [15]:
# Uses `y` and `bs` left over from the level-0 step-through cell: shows the
# BxAxHxWxO -> Bx(AHW)xO flattening applied before predictions are concatenated.
(y.shape,
 y.view(bs, -1, detector.no).shape,
 detector.no)

(torch.Size([2, 3, 160, 160, 85]), torch.Size([2, 76800, 85]), 85)

In [33]:
                # Reference: box-decoding step of upstream yolov7 IAuxDetect.forward.
# Source: https://github.com/WongKinYiu/yolov7/blob/main/models/yolo.py
# The original paste was a bare, indented fragment referring to `self`, `i`, `y`, `z`
# and `bs`, so it could not run on a fresh kernel. It is wrapped here in a function
# with those inputs made explicit (`self` -> `head`), for side-by-side comparison
# with IAuxDetect.forward above.
def upstream_decode(head, y, i, bs, z):
    '''Decode level-`i` predictions the way upstream yolov7 does.

    Args:
        head: the detection head, or anything exposing `grid`, `stride`,
            `anchor_grid`, `nc` and `no` (e.g. `detector`).
        y (Tensor): sigmoid of the level-`i` head output, BxAxHxWxO.
            Written in place on the non-ONNX path.
        i (int): detection level index.
        bs (int): batch size.
        z (list[Tensor]): collector; the flattened Bx(AHW)xO tensor is appended.

    Returns:
        Tensor: the decoded BxAxHxWxO predictions.
    '''
    if not torch.onnx.is_in_onnx_export():
        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + head.grid[i]) * head.stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * head.anchor_grid[i]  # wh
    else:
        xy, wh, conf = y.split((2, 2, head.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
        xy = xy * (2. * head.stride[i]) + (head.stride[i] * (head.grid[i] - 0.5))  # new xy
        wh = wh ** 2 * (4 * head.anchor_grid[i].data)  # new wh
        y = torch.cat((xy, wh, conf), 4)
    z.append(y.view(bs, -1, head.no))
    return y


torch.Size([50, 20]) torch.Size([50, 20])


torch.Size([1, 1, 50, 20, 2])