In [None]:
# Re-import modified modules automatically before each cell executes, so edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Imports

In [None]:
# export
import torchvision

In [None]:
# export
from torch import nn

In [None]:
# export
import torch

In [None]:
# export
import torchvision

In [None]:
# export
from collections import OrderedDict

In [None]:
# export
from IPython.core import debugger as idb

In [None]:
# export
from torchvision.models.resnet import conv1x1

In [None]:
# export
from FLAI.detect_symbol.exp import resnet_ssd as resnet_ssd_detsym

In [None]:
#测试用，不需要导出
from FLAI.detect_symbol.exp import databunch

# functions

### ssd_block

In [None]:
# export
class ssd_block(nn.Module):
    '''
    SSD prediction head that predicts bbox *centers* only.

    Compared with the ssd_block in detect_symbol, the only difference is that
    everything related to bbox width/height has been removed.
    '''
    def __init__(self, k, nin, n_clas):
        '''
        Build the SSD head. From one layer's feature map it produces per-anchor
        bbox predictions made of three parts:
        -- loc:  bbox center offset, 2 values per anchor
        -- conf: objectness confidence, 1 value per anchor
        -- clas: class scores, n_clas values per anchor
        ----------------------------------------
        Args:
        -- k:      number of anchors per grid cell
        -- nin:    number of channels of the input feature map
        -- n_clas: number of object classes
        '''
        super().__init__()
        self.k = k

        def make_head(values_per_anchor):
            # 3x3 conv with padding=1 keeps the spatial size of the feature map;
            # it emits `values_per_anchor` channels for each of the k anchors.
            return nn.Conv2d(nin, values_per_anchor * k, 3, padding=1)

        self.oconv_loc = make_head(2)        # bbox center
        self.oconv_conf = make_head(1)       # confidence
        self.oconv_clas = make_head(n_clas)  # classification

    def forward(self, x):
        '''
        Return a (loc, conf, clas) tuple. Each conv output is passed through
        resnet_ssd_detsym.flatten_grid_anchor together with the anchor count k.
        '''
        flatten = resnet_ssd_detsym.flatten_grid_anchor
        heads = (self.oconv_loc, self.oconv_conf, self.oconv_clas)
        return tuple(flatten(head(x), self.k) for head in heads)

### ResNetIsh_SSD

In [None]:
# export
class ResNetIsh_SSD(resnet_ssd_detsym.ResNetIsh_SSD):
    '''
    ResNetIsh_SSD whose forward merges the per-level head outputs.

    The inherited `_forward_impl` is expected to return one (loc, conf, clas)
    tuple per prediction level (as produced by ssd_block); each of the three
    components is concatenated across levels along dim=1.
    '''
    def forward(self, x):
        # Materialize once so each component can be gathered from every level.
        per_level = list(self._forward_impl(x))
        # Component 0 = loc, 1 = conf, 2 = clas.
        return tuple(
            torch.cat([level_out[component] for level_out in per_level], dim=1)
            for component in range(3)
        )

### ResNetIsh_1SSD

In [None]:
# export
class ResNetIsh_1SSD(resnet_ssd_detsym.ResNetIsh_1SSD):
    '''
    ResNetIsh_1SSD whose forward merges the head outputs of all levels.

    The inherited (or overridden) `_forward_impl` is expected to return one
    (loc, conf, clas) tuple per output level; each of the three components is
    concatenated across levels along dim=1.
    '''
    def forward(self, x):
        # Materialize once so each component can be gathered from every level.
        level_results = list(self._forward_impl(x))
        merged = []
        for component in range(3):  # 0 = loc, 1 = conf, 2 = clas
            parts = [res[component] for res in level_results]
            merged.append(torch.cat(parts, dim=1))
        return tuple(merged)

In [None]:
#export 
class ResNetIsh_1SSD_fpn(ResNetIsh_1SSD):
    '''
    ResNetIsh_1SSD with an optional Feature Pyramid Network (FPN) neck.

    Call `init_fpn()` after construction to enable the FPN.
    - Without an FPN, each prediction layer's feature map goes through its
      neck_block and then the shared head_block (one output per layer).
    - With an FPN, neck_blocks are skipped. The raw feature maps of all
      prediction layers are fed to the FPN, and only the merged output for the
      first prediction layer ('feat0') is passed to head_block, giving a single
      output level.
    '''
    def init_fpn(self, in_channels_list=(256, 512, 1024), out_channels=256):
        '''
        Create the FPN neck and attach it as `self.fpn`.

        Args:
            in_channels_list: channel count of each prediction layer's feature
                map, in `pred_layerIds` order. The default matches the
                resnet18-style config chs=[64,128,256,512,1024] with
                pred_layerIds=[2,3,4].
            out_channels: number of FPN output channels. It presumably has to
                match the head's input channels (head_chin) — confirm.
        '''
        # An FPN only makes sense with more than one prediction layer.
        assert len(self.pred_layerIds) > 1
        self.fpn = torchvision.ops.FeaturePyramidNetwork(list(in_channels_list), out_channels)

    def _forward_impl(self, x):
        '''
        Run the backbone and return a list of head_block outputs.

        This is one output per prediction layer without an FPN, or a single
        output with an FPN.
        '''
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # `self.fpn` only exists after init_fpn(); a missing attribute means
        # "no FPN" instead of an AttributeError.
        fpn = getattr(self, 'fpn', None)

        outs = []
        input4fpn = OrderedDict()
        for i, res_block in enumerate(self.res_blocks):
            x = res_block(x)
            if i not in self.pred_layerIds:
                continue
            # Position of this layer relative to the first prediction layer.
            level = i - self.pred_layerIds[0]
            if fpn is None:
                # No FPN: this level's neck block, then the shared head.
                outs.append(self.head_block(self.neck_blocks[level](x)))
            else:
                # FPN: skip the neck block and collect the raw feature map.
                input4fpn['feat%d' % level] = x

        if fpn is not None:
            fpnout = fpn(input4fpn)
            # For now only take the final fully merged map ('feat0'). This keeps
            # the single-level output of the one-feature-map setup.
            outs.append(self.head_block(fpnout['feat0']))

        return outs

## Tests

In [None]:
import os

# Dataset root for the smoke tests below. Override it with the
# DETECT_SYMBOL_DATA_ROOT environment variable instead of editing a
# hard-coded absolute path.
DATA_ROOT = os.environ.get('DETECT_SYMBOL_DATA_ROOT',
                           '/home/dev/jupyter/detect_symbol/data/ds_20200429/')
data = databunch.get_databunch(data_root=DATA_ROOT, cache=False)
x, y = data.one_batch(denorm=False)

# NOTE(review): the -1 presumably drops a background/placeholder entry from the
# class list — confirm against databunch.
num_classes = len(data.train_ds.y.classes) - 1

In [None]:
# Build the FPN variant on a 5-stage resnet18-like backbone and run one batch.
mt = ResNetIsh_1SSD_fpn(
    block=torchvision.models.resnet.BasicBlock,
    layers=[2] * 5,
    chs=[64, 128, 256, 512, 1024],
    strides=[1] + [2] * 4,
    pred_layerIds=[2, 3, 4],  # stages with 256/512/1024 channels feed the FPN
    num_anchors=1,
    neck_block=resnet_ssd_detsym.cnv1x1_bn_relu,
    head_chin=256,
    head_block=ssd_block,
    num_classes=num_classes,
)
mt.init_fpn()
mt(x)

In [None]:
# Build the model explicitly. `get_resnet18_1ssd` is only defined further down in
# this notebook, so calling it here would raise NameError on a fresh kernel
# (Restart & Run All). This is the same model as
# get_resnet18_1ssd(layers4fpn=True, num_classes=17).
m = ResNetIsh_1SSD(block=torchvision.models.resnet.BasicBlock,
                   layers=[2, 2, 2, 2, 2],
                   chs=[64, 128, 256, 512, 1024],
                   strides=[1, 2, 2, 2, 2],
                   pred_layerIds=[2, 3, 4],
                   num_anchors=1,
                   neck_block=resnet_ssd_detsym.cnv1x1_bn_relu,
                   head_chin=256,
                   head_block=ssd_block,
                   num_classes=17)

In [None]:
def dbg(model=None, inputs=None, use_pdb=False):
    '''
    Run one forward pass for debugging and return the prediction.

    Args:
        model: module to call; defaults to the notebook global `m`.
        inputs: batch to feed; defaults to the notebook global `x`.
        use_pdb: if True, stop in pdb right before the forward pass. It is off
            by default so Restart & Run All does not block on debugger input.
    '''
    if use_pdb:
        import pdb
        pdb.set_trace()
    model = m if model is None else model
    inputs = x if inputs is None else inputs
    return model(inputs)

# Keep the prediction at module level so the following cell can inspect `p`.
p = dbg()

In [None]:
# Size of the first output component (the concatenated loc predictions).
p[0].size()


### ResNetIsh_SSD

In [None]:
# Build the model: 3-stage backbone, predicting from stage index 2 with 1 anchor.
m = ResNetIsh_SSD(
    block=torchvision.models.resnet.BasicBlock,
    layers=[2] * 3,
    chs=[64, 128, 256],
    strides=[1, 2, 2],
    pred_layerIds=[2],
    num_anchors=[1],
    pred_block=ssd_block,
    num_classes=16,
)

In [None]:
# Run one batch through the model
pred = m(x)

In [None]:
# Inspect the type, length, and per-component shapes of the model output
print(f'type(pred)={type(pred)}')
print(f'len(pred)={len(pred)}')

print('-----------------------')
for p in pred:
    print(p.shape)

print('-----------------------')
# Reference anchor counts to compare against the printed shapes.
# Older multi-level config (presumably 49/25/13 grids with 4/3/3 anchors):
#print(49*49*4+25*25*3+13*13*3)
# Current config: presumably a single 49x49 grid with 1 anchor per cell;
# TODO confirm for the input size used by the databunch.
print(49*49*1)

### ResNetIsh_1SSD

In [None]:
# Number of classes derived from the databunch in the data-loading cell (for reference)
num_classes

In [None]:
# Build the model. Bbox width/height are not predicted here, so the later
# (deeper) stages are dropped and only stage index 2 is used for prediction.
m = ResNetIsh_1SSD(
    block=torchvision.models.resnet.BasicBlock,
    layers=[2] * 3,
    chs=[64, 128, 256],
    strides=[1, 2, 2],
    pred_layerIds=[2],
    num_anchors=1,
    neck_block=resnet_ssd_detsym.cnv1x1_bn_relu,
    head_chin=256,
    head_block=ssd_block,
    num_classes=17,
)

In [None]:
# Run one batch through the model
pred = m(x)

In [None]:
# Inspect the type, length, and per-component shapes of the model output
print(f'type(pred)={type(pred)}')
print(f'len(pred)={len(pred)}')

print('-----------------------')
for p in pred:
    print(p.shape)

print('-----------------------')
# Reference anchor counts to compare against the printed shapes.
# Older multi-level config (presumably 49/25/13 grids with 4 anchors each):
#print(49*49*4+25*25*4+13*13*4)
# Current config: presumably a single 49x49 grid with 1 anchor per cell;
# TODO confirm for the input size used by the databunch.
print(49*49*1)

## Wrap model construction in factory functions

In [None]:
# export
def get_resnet18_1ssd(layers4fpn = False, num_classes = 1):
    '''
    Build a resnet18-style ResNetIsh_1SSD with the center-only ssd_block head.

    Args:
        layers4fpn: whether to keep the two extra stages (512 and 1024 channels)
            for FPN use. If True, predict from stages 2, 3 and 4; if False, stop
            at the 256-channel stage and predict from stage 2 only.
        num_classes: number of object classes passed to the model.
    '''
    if layers4fpn:
        backbone = dict(layers=[2, 2, 2, 2, 2],
                        chs=[64, 128, 256, 512, 1024],
                        strides=[1, 2, 2, 2, 2],
                        pred_layerIds=[2, 3, 4])
    else:
        backbone = dict(layers=[2, 2, 2],
                        chs=[64, 128, 256],
                        strides=[1, 2, 2],
                        pred_layerIds=[2])
    # Everything except the backbone layout is shared by both variants.
    return ResNetIsh_1SSD(block=torchvision.models.resnet.BasicBlock,
                          **backbone,
                          num_anchors=1,
                          neck_block=resnet_ssd_detsym.cnv1x1_bn_relu,
                          head_chin=256,
                          head_block=ssd_block,
                          num_classes=num_classes)

In [None]:
# export
def get_resnet18_ssd(layers4fpn = False, num_classes = 1):
    '''
    Build a resnet18-style ResNetIsh_SSD with the center-only ssd_block head.

    Args:
        layers4fpn: whether to keep the two extra stages (512 and 1024 channels).
            If True, predict from stages 2, 3 and 4 with one anchor each; if
            False, predict from stage 2 only with one anchor.
        num_classes: number of object classes passed to the model.
    '''
    if layers4fpn:
        stage_cfg = dict(layers=[2] * 5,
                         chs=[64, 128, 256, 512, 1024],
                         strides=[1, 2, 2, 2, 2],
                         pred_layerIds=[2, 3, 4],
                         num_anchors=[1, 1, 1])
    else:
        stage_cfg = dict(layers=[2] * 3,
                         chs=[64, 128, 256],
                         strides=[1, 2, 2],
                         pred_layerIds=[2],
                         num_anchors=[1])
    return ResNetIsh_SSD(block=torchvision.models.resnet.BasicBlock,
                         **stage_cfg,
                         pred_block=ssd_block,
                         num_classes=num_classes)

In [None]:
# Smoke test: 5-stage single-level model (extra stages kept for FPN); shows its repr.
get_resnet18_1ssd(layers4fpn=True)

In [None]:
# Smoke test: 3-stage single-level model; shows its repr.
get_resnet18_1ssd(layers4fpn=False)

In [None]:
# Smoke test: 5-stage model predicting from stages 2-4; shows its repr.
get_resnet18_ssd(layers4fpn=True)

In [None]:
# Smoke test: 3-stage model predicting from stage 2 only; shows its repr.
get_resnet18_ssd(layers4fpn=False)

In [None]:
import torchvision

In [None]:
from torchvision.ops import *

In [None]:
from collections import *

In [None]:
import torch

In [None]:
# Scratch: try out torchvision.ops.FeaturePyramidNetwork (defined in torchvision's feature_pyramid_network.py)

In [None]:
# Standalone demo of torchvision.ops.FeaturePyramidNetwork on dummy features.
# It uses its own variable names so it does not overwrite the notebook's model
# `m` or image batch `x`, which earlier cells depend on.
fpn_demo = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
# Dummy input: three feature maps with 10/20/30 channels at 64x64, 16x16 and 8x8.
fpn_demo_in = OrderedDict()
fpn_demo_in['feat0'] = torch.rand(1, 10, 64, 64)
fpn_demo_in['feat2'] = torch.rand(1, 20, 16, 16)
fpn_demo_in['feat3'] = torch.rand(1, 30, 8, 8)
# Compute the FPN; each output keeps its input key and spatial size, with 5 channels.
fpn_demo_out = fpn_demo(fpn_demo_in)
print([(k, v.shape) for k, v in fpn_demo_out.items()])

### Scratch (temporary experiments)

In [None]:
# NOTE(review): this cell originally called `get_resnet18_ssd_std`, which is not
# defined in this notebook and raises NameError on a fresh kernel. It now uses the
# local `get_resnet18_ssd` factory; confirm whether a different "_std" variant
# was intended.
model3 = get_resnet18_ssd(layers4fpn=True, num_classes=17)
model3

# export

In [None]:
# Export this notebook to a Python module under ../exp/ via notebook2script.py.
# Presumably only the cells tagged "# export" are written out — confirm against the script.
!python ../notebook2script.py --fname 'resnet_ssd.ipynb' --outputDir '../exp/'