In [35]:
import sys

import torch
import torch.nn as nn

import numpy as np

sys.path.append('../../../')

%load_ext autoreload
%autoreload 2

from video_processing.yolov7.utils.general import make_grid
from video_processing.yolov7.models.common import Conv, autopad, ImplicitA, ImplicitM

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
class IAuxDetect(nn.Module):
    # class variables shared across all instances
    stride=None # compute during build
    export = False # onnx export
    end2end=False
    include_nms=False
    concat=False
    def __init__(self, nc=80, anchors=(), ch=()):# detection layer
        '''
        Args:
            nc (int): number of object classes
            anchors (sequence): anchors locations
            ch (): channel
        '''
        super(IAuxDetect, self).__init__()
        self.nc=nc # number of classes
        self.no=nc+5 # number of output per anchors
        self.nl=len(anchors) # number of detection layers
        self.na=len(anchors[0]) # number of anchors
        self.grid=[torch.zeros(1)]*self.nl # init grid
        a=torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a) # shape nl, na//2, 2 but the author said this will be nl,na,2...?
        # nl 1 na//2 1 1 2 but the authors said nl 1 na 1 1 2
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))
        print(f'In IAxDetect nl: {self.nl} na: {self.na}')
        print(f'In IAxDetect anchors: {self.anchors.shape} {self.nl}x{self.na//2}x{2}')
        print(f'In IAxDetect anchor_grid: {self.anchor_grid.shape} {self.nl}x1x{self.na//2}x1x1x{2}')
        self.m=nn.ModuleList(nn.Conv2d(x, self.no*self.na, 1) for x in ch[:self.nl]) # output conv
        self.m2=nn.ModuleList(nn.Conv2d(x, self.no*self.na, 1) for x in ch[self.nl:]) # output conv

        self.ia=nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
        self.im=nn.ModuleList(ImplicitM(self.no*self.na) for _ in ch[:self.nl])
    def forward(self, x):
        ## see https://github.com/WongKinYiu/yolov7/blob/main/models/yolo.py#L116
        z=[]
        self.training|=self.export
        for i in range(self.nl):
            print(i, '-'*100)
            print('\tx[i] ', x[i].shape,  ' self.m[i] ', self.m[i])
            x[i]=self.m[i](self.ia[i](x[i]))
            x[i]=self.im[i](x[i])
            print('\tx[i] ', x[i].shape)
            bs,_,ny,nx=x[i].shape
            # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
            x[i]=x[i].view(bs, self.na, self.no, ny, nx).permute(0,1,3,4,2).contiguous() 
            print('\tx[i] ', x[i].shape)
            print('\ti+nl ', i+self.nl)
        
            print('\tx[i+self.nl] ', x[i+self.nl].shape,  ' self.m2[i] ', self.m2[i])
            x[i+self.nl]=self.m2[i](x[i+self.nl])
            print('\tx[i+self.nl] ', x[i+self.nl].shape)
            # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
            x[i+self.nl]=x[i+self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            print('\tx[i+self.nl] ', x[i+self.nl].shape)
            if not self.training: # inference
                print('\tself.grid[i].shape ', self.grid[i].shape)
                print('\tself.grid[i].shape[2:4] ', self.grid[i].shape[2:4], ' x[i].shape[2:4] ', x[i].shape[2:4])
                if self.grid[i].shape[2:4]!=x[i].shape[2:4]:
                    self.grid[i]=make_grid(nx=nx, ny=ny).to(x[i].device) # 1x1xHxWx2
                print('\tself.grid[i].shape ', self.grid[i].shape )
                y=x[i].sigmoid()
                print('\tx[i] ', x[i].min().item(), x[i].max().item())
                print('\ty ', y.min().item(), y.max().item() )
                raise NotImplementedError('Please implement after determine stride')
        return x if self.training else (torch.cat(z, 1), x[:self.nl])
        

In [37]:
nc=80
anchors=[[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]]
ch=[256, 512, 768, 1024, 320, 640, 960, 1280]
detector=IAuxDetect(nc=nc, anchors=anchors, ch=ch)
detector

In IAxDetect nl: 4 na: 6
In IAxDetect anchors: torch.Size([4, 3, 2]) 4x3x2
In IAxDetect anchor_grid: torch.Size([4, 1, 3, 1, 1, 2]) 4x1x3x1x1x2


IAuxDetect(
  (m): ModuleList(
    (0): Conv2d(256, 510, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(512, 510, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(768, 510, kernel_size=(1, 1), stride=(1, 1))
    (3): Conv2d(1024, 510, kernel_size=(1, 1), stride=(1, 1))
  )
  (m2): ModuleList(
    (0): Conv2d(320, 510, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(640, 510, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(960, 510, kernel_size=(1, 1), stride=(1, 1))
    (3): Conv2d(1280, 510, kernel_size=(1, 1), stride=(1, 1))
  )
  (ia): ModuleList(
    (0-3): 4 x ImplicitA()
  )
  (im): ModuleList(
    (0-3): 4 x ImplicitM()
  )
)

(8, 4)

In [44]:
detector.eval()
x=[torch.rand(size) for size in [torch.Size([2, 256, 160, 160]), torch.Size([2, 512, 80, 80]), torch.Size([2, 768, 40, 40]), 
 torch.Size([2, 1024, 20, 20]), torch.Size([2, 320, 160, 160]), torch.Size([2, 640, 80, 80]), 
 torch.Size([2, 960, 40, 40]), torch.Size([2, 1280, 20, 20])] ]
print('len(x) ', len(x), ' nl ', detector.nl, ' na ',  detector.na, ' no ',  detector.no)

z=[]
print(detector.training|detector.export)
for i in range(detector.nl):
    print(i, '-'*100)
    print('\tx[i] ', x[i].shape,  ' detector.m[i] ', detector.m[i])
    x[i]=detector.m[i](detector.ia[i](x[i]))
    x[i]=detector.im[i](x[i])
    print('\tx[i] ', x[i].shape)
    bs,_,ny,nx=x[i].shape
    # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
    x[i]=x[i].view(bs, detector.na, detector.no, ny, nx).permute(0,1,3,4,2).contiguous() 
    print('\tx[i] ', x[i].shape)
    print('\ti+nl ', i+detector.nl)

    print('\tx[i+detector.nl] ', x[i+detector.nl].shape,  ' detector.m2[i] ', detector.m2[i])
    x[i+detector.nl]=detector.m2[i](x[i+detector.nl])
    print('\tx[i+detector.nl] ', x[i+detector.nl].shape)
    # BxAxHxWxO where A=number anchors and O is number of classes+5 (where 5 is for bbox coordinate and objectness score) 
    x[i+detector.nl]=x[i+detector.nl].view(bs, detector.na, detector.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
    print('\tx[i+detector.nl] ', x[i+detector.nl].shape)
    if not detector.training: # inference
        print('\tdetector.grid[i].shape ', detector.grid[i].shape)
        print('\tdetector.grid[i].shape[2:4] ', detector.grid[i].shape[2:4], ' x[i].shape[2:4] ', x[i].shape[2:4])
        if detector.grid[i].shape[2:4]!=x[i].shape[2:4]:
            detector.grid[i]=make_grid(nx=nx, ny=ny).to(x[i].device) # 1x1xHxWx2
        print('\tdetector.grid[i].shape ', detector.grid[i].shape )
        y=x[i].sigmoid()
        print('\tx[i] ', x[i].min().item(), x[i].max().item())
        print('\ty ', y.min().item(), y.max().item() )
        raise NotImplementedError('Please implement after determine stride')
    break
print([v.shape for v in x])

len(x)  8  nl  4  na  6  no  85
False
0 ----------------------------------------------------------------------------------------------------
	x[i]  torch.Size([2, 256, 160, 160])  detector.m[i]  Conv2d(256, 510, kernel_size=(1, 1), stride=(1, 1))
	x[i]  torch.Size([2, 510, 160, 160])
	x[i]  torch.Size([2, 6, 160, 160, 85])
	i+nl  4
	x[i+detector.nl]  torch.Size([2, 320, 160, 160])  detector.m2[i]  Conv2d(320, 510, kernel_size=(1, 1), stride=(1, 1))
	x[i+detector.nl]  torch.Size([2, 510, 160, 160])
	x[i+detector.nl]  torch.Size([2, 6, 160, 160, 85])
	detector.grid[i].shape  torch.Size([1, 1, 160, 160, 2])
	detector.grid[i].shape[2:4]  torch.Size([160, 160])  x[i].shape[2:4]  torch.Size([160, 160])
	detector.grid[i].shape  torch.Size([1, 1, 160, 160, 2])
	x[i]  -1.5234111547470093 1.5491529703140259
	y  0.17895977199077606 0.8247913718223572
[torch.Size([2, 6, 160, 160, 85]), torch.Size([2, 512, 80, 80]), torch.Size([2, 768, 40, 40]), torch.Size([2, 1024, 20, 20]), torch.Size([2, 6, 160,

In [47]:
detector.stride

In [45]:
#if not torch.onnx.is_in_onnx_export():
print(y.shape)
y[...,0:2]=

torch.Size([2, 6, 160, 160, 85])


In [33]:
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5))  # new xy
                    wh = wh ** 2 * (4 * self.anchor_grid[i].data)  # new wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))


torch.Size([50, 20]) torch.Size([50, 20])


torch.Size([1, 1, 50, 20, 2])