In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


from stdcnet import STDCNet1446, STDCNet813
# from modules.bn import InPlaceABNSync as BatchNorm2d
BatchNorm2d = nn.BatchNorm2d

class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size = ks,
                stride = stride,
                padding = padding,
                bias = False)
        # self.bn = BatchNorm2d(out_chan)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU()
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
        # self.bn_atten = BatchNorm2d(out_chan)
        self.bn_atten = BatchNorm2d(out_chan)

        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
        super(ContextPath, self).__init__()
        
        self.backbone_name = backbone
        if backbone == 'STDCNet1446':
            self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
            self.arm16 = AttentionRefinementModule(512, 128)
            inplanes = 1024
            if use_conv_last:
                inplanes = 1024
            self.arm32 = AttentionRefinementModule(inplanes, 128)
            self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
            self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
            self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

        elif backbone == 'STDCNet813':
            self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
            self.arm16 = AttentionRefinementModule(512, 128)
            inplanes = 1024
            if use_conv_last:
                inplanes = 1024
            self.arm32 = AttentionRefinementModule(inplanes, 128)
            self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
            self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
            self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
        else:
            print("backbone is not in backbone lists")
            exit(0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        print("input_size:")
        print(x.size())

        feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
        print("after_backbone:feat2,feat4,feat8,feat16,feat32")
        print(feat2.size(),feat4.size(),feat8.size(),feat16.size(),feat32.size())
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]
        
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        print("after_avg:avg")
        print(avg.size())

        avg = self.conv_avg(avg)
        print("after_conv_avg:avg")
        print(avg.size())
        
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
        print("after_upsample1:avg_up")
        print(avg_up.size())


        feat32_arm = self.arm32(feat32)
        print("after_attention:feat32_arm")
        print(feat32_arm.size())
        feat32_sum = feat32_arm + avg_up
        print("after_sum:feat32_sum")
        print(feat32_sum.size())
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        print("after_upsample2:feat32_up")
        print(feat32_up.size())
        feat32_up = self.conv_head32(feat32_up)
        print("after_conv_head32:feat32_up")
        print(feat32_up.size())

        feat16_arm = self.arm16(feat16)
        print("after_attention:feat16_arm")
        print(feat16_arm.size())
        feat16_sum = feat16_arm + feat32_up
        print("after_sum:feat16_sum")
        print(feat16_sum.size())
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        print("after_upsample3:feat16_up")
        print(feat16_up.size())
        feat16_up = self.conv_head16(feat16_up)
        print("after_conv_head16:feat16_up")
        print(feat16_up.size())
        print("contextpath end")
        
        return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                out_chan//4,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.conv2 = nn.Conv2d(out_chan//4,
                out_chan,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False, heat_map=False, *args, **kwargs):
        super(BiSeNet, self).__init__()
        
        self.use_boundary_2 = use_boundary_2
        self.use_boundary_4 = use_boundary_4
        self.use_boundary_8 = use_boundary_8
        self.use_boundary_16 = use_boundary_16
        # self.heat_map = heat_map
        self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)
        
        
        
        if backbone == 'STDCNet1446':
            conv_out_inplanes = 128
            sp2_inplanes = 32
            sp4_inplanes = 64
            sp8_inplanes = 256
            sp16_inplanes = 512
            inplane = sp8_inplanes + conv_out_inplanes

        elif backbone == 'STDCNet813':
            conv_out_inplanes = 128
            sp2_inplanes = 32
            sp4_inplanes = 64
            sp8_inplanes = 256
            sp16_inplanes = 512
            inplane = sp8_inplanes + conv_out_inplanes

        else:
            print("backbone is not in backbone lists")
            exit(0)

        self.ffm = FeatureFusionModule(inplane, 256)
        self.conv_out = BiSeNetOutput(256, 64, n_classes)
        self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)

        self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
        
        self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
        self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
        self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)
        
        self.conv_out_ajchan2 = ConvBNReLU(sp2_inplanes, n_classes, ks=3, stride=1, padding=1)
        self.conv_out_ajchan4 = ConvBNReLU(sp4_inplanes, n_classes, ks=3, stride=1, padding=1)
        
#         self.CBR4 = ConvBNReLU(n_classes, n_classes, ks=3, stride=1, padding=1)
#         self.CBR2 = ConvBNReLU(n_classes, n_classes, ks=3, stride=1, padding=1)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        
        feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)

        feat_out_sp2 = self.conv_out_sp2(feat_res2)
        print("feat_out_sp2:")
        print(feat_out_sp2.size())

        feat_out_sp4 = self.conv_out_sp4(feat_res4)
        print("feat_out_sp4:")
        print(feat_out_sp4.size())
  
        feat_out_sp8 = self.conv_out_sp8(feat_res8)
        print("feat_out_sp8:")
        print(feat_out_sp8.size())

        feat_out_sp16 = self.conv_out_sp16(feat_res16)
        print("feat_out_sp16:")
        print(feat_out_sp16.size())

        feat_fuse = self.ffm(feat_res8, feat_cp8)
        print("feat_fuse:")
        print(feat_fuse.size())

        
        feat_out = self.conv_out(feat_fuse)
        print("feat_out:")
        print(feat_out.size())
        feat_out16 = self.conv_out16(feat_cp8)
        print("feat_out16:")
        print(feat_out16.size())
        feat_out32 = self.conv_out32(feat_cp16)
        print("feat_out32:")
        print(feat_out32.size())
        
        feat_out = F.interpolate(feat_out, (H//4, W//4), mode='bilinear', align_corners=True)
        feat_temp8 = feat_out+self.conv_out_ajchan4(feat_res4)
        feat_out = F.interpolate(feat_temp8, (H//2, W//2), mode='bilinear', align_corners=True)
        feat_temp8 = feat_out+self.conv_out_ajchan2(feat_res2)
        feat_out = F.interpolate(feat_temp8, (H, W), mode='bilinear', align_corners=True)
        print("after_upsample:feat_out:")
        print(feat_out.size())
        
        feat_out16 = F.interpolate(feat_out16, (H//4, W//4), mode='bilinear', align_corners=True)
        feat_temp16 = feat_out16+self.conv_out_ajchan4(feat_res4)
        feat_out = F.interpolate(feat_temp16, (H//2, W//2), mode='bilinear', align_corners=True)
        feat_temp16 = feat_out+self.conv_out_ajchan2(feat_res2)
        feat_out = F.interpolate(feat_temp16, (H, W), mode='bilinear', align_corners=True)
        print("after_upsample:feat_out16:")
        print(feat_out.size())
        
        feat_out32 = F.interpolate(feat_out32, (H//4, W//4), mode='bilinear', align_corners=True)
        feat_temp32 = feat_out32+self.conv_out_ajchan4(feat_res4)
        feat_out = F.interpolate(feat_temp32, (H//2, W//2), mode='bilinear', align_corners=True)
        feat_temp32 = feat_out+self.conv_out_ajchan2(feat_res2)
        feat_out = F.interpolate(feat_temp32, (H, W), mode='bilinear', align_corners=True)
        print("after_upsample:feat_out32:")
        print(feat_out.size())


        if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8:
            return feat_out, feat_out16, feat_out32, feat_out_sp2, feat_out_sp4, feat_out_sp8
        
        if (not self.use_boundary_2) and self.use_boundary_4 and self.use_boundary_8:
            return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8

        if (not self.use_boundary_2) and (not self.use_boundary_4) and self.use_boundary_8:
            return feat_out, feat_out16, feat_out32, feat_out_sp8
        
        if (not self.use_boundary_2) and (not self.use_boundary_4) and (not self.use_boundary_8):
            return feat_out, feat_out16, feat_out32

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


In [2]:
if __name__ == "__main__":
    
    net = BiSeNet('STDCNet1446', 19)
#     net.cuda()
    net.eval()
    in_ten = torch.randn(1, 3, 768, 1536)
    out, out16, out32 = net(in_ten)
    print(out.shape)
    torch.save(net.state_dict(), 'STDCNet1446_modify_conv1layeradd.pth')

input_size:
torch.Size([1, 3, 768, 1536])
feat4:
torch.Size([1, 64, 192, 384])
feat8:
torch.Size([1, 256, 96, 192])
feat16tmp:
torch.Size([1, 256, 48, 96])
feat16:
torch.Size([1, 512, 48, 96])
feat32:
torch.Size([1, 1024, 24, 48])
after_backbone:feat2,feat4,feat8,feat16,feat32
torch.Size([1, 32, 384, 768]) torch.Size([1, 64, 192, 384]) torch.Size([1, 256, 96, 192]) torch.Size([1, 512, 48, 96]) torch.Size([1, 1024, 24, 48])
after_avg:avg
torch.Size([1, 1024, 1, 1])
after_conv_avg:avg
torch.Size([1, 128, 1, 1])
after_upsample1:avg_up
torch.Size([1, 128, 24, 48])
after_attention:feat32_arm
torch.Size([1, 128, 24, 48])
after_sum:feat32_sum
torch.Size([1, 128, 24, 48])
after_upsample2:feat32_up
torch.Size([1, 128, 48, 96])
after_conv_head32:feat32_up
torch.Size([1, 128, 48, 96])
after_attention:feat16_arm
torch.Size([1, 128, 48, 96])
after_sum:feat16_sum
torch.Size([1, 128, 48, 96])
after_upsample3:feat16_up
torch.Size([1, 128, 96, 192])
after_conv_head16:feat16_up
torch.Size([1, 128, 96, 1

In [1]:
for name, param in net.named_parameters():
    print(name, param.size())

NameError: name 'net' is not defined

In [None]:
class baseblock(nn.Module):
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        self.conv_list = nn.ModuleList()
        
        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))

In [None]:
class Conv3X(nn.Module):
    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super(Conv3X, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out
class Conv1X(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1):
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))

In [6]:
x=torch.arange(24).view(1,2,4,3)
'''
注意：在这里使用的时候转一下类型，否则会报RuntimeError: Can only calculate the mean of floating types. Got Long instead.的错误。
查看了一下x元素类型是torch.int64,根据提示添加一句x=x.float()转为tensor.float32就行
'''
x=x.float()
x_mean=torch.mean(x)
# x_mean0=torch.mean(x,dim=0,keepdim=True)
x_mean1=torch.mean(x,dim=1,keepdim=True)
print('x:')
print(x)
# print('x_mean0:')
# print(x_mean0)
print('x_mean1:')
print(x_mean1)
print(x_mean1.size())
print('x_mean:')
print(x_mean)

x:
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.],
          [ 9., 10., 11.]],

         [[12., 13., 14.],
          [15., 16., 17.],
          [18., 19., 20.],
          [21., 22., 23.]]]])
x_mean1:
tensor([[[[ 6.,  7.,  8.],
          [ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]]])
torch.Size([1, 1, 4, 3])
x_mean:
tensor(11.5000)


In [15]:
state_dict = torch.load("/xiaoou/STDC-Seg-master/STDC-Seg-master/checkpoint/train_STDC2-Seg_depthwise51/pths/model_final.pth")
# model_resnet101.load_state_dict({k.replace('module.',''):v for k,v in torch.load("densenet169_rnn_fold_1_model_best_f1.pth.tar")['state_dict'].items()})
# state_dict = {key: value for key, value in state_dict.items() if key < 584}
# model = BiSeNet('STDCNet1446', 19)
new_state_dict = {}
sub_list = list(state_dict.keys())[:673]
for k,v in state_dict.items():
    if k in sub_list:
        name = k.replace('cp.backbone.','')
        new_state_dict[name] = v
for k,v in new_state_dict.items():
    print(k)
# model_resnet101.load_state_dict({k.replace('module.',''):v for k,v in torch.load("densenet169_rnn_fold_1_model_best_f1.pth.tar")['state_dict'].items()})


features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight

In [5]:
state_dict = torch.load("/xiaoou/STDC-Seg-master/STDC-Seg-master/STDC-Seg-weight/STDCNet1446_76.47.tar",map_location='cpu')["state_dict"]
model = BiSeNet('STDCNet1446', 19)
for k,v in state_dict.items():
    print(k) 

features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight

In [11]:
! ls

STDCNet1446.pth			      __init__.py
STDCNet1446_modify.pth		      __pycache__
STDCNet1446_modify_ca.pth	      model_stages.py
STDCNet1446_modify_conv.pth	      model_stages_trt.py
STDCNet1446_modify_conv1layeradd.pth  modeltest.ipynb
STDCNet1446_modify_convadd.pth	      stdcnet.py
STDCNet1446_modify_convlayeradd.pth


In [12]:
from torchsummary import summary
from model_stages import BiSeNet
# models = BiSeNet()
summary(BiSeNet, input_size=(3, 768, 1536))

ModuleNotFoundError: No module named 'nets'

In [11]:
pip install torchsummary

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting torchsummary
  Downloading https://files.pythonhosted.org/packages/7d/18/1474d06f721b86e6a9b9d7392ad68bed711a02f3b61ac43f13c719db50a6/torchsummary-1.5.1-py3-none-any.whl
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [13]:
import torch
import torch.nn as nn
from torch.nn import init
import math
from resnet_block import Bottleneck
BatchNorm2d = nn.BatchNorm2d

# class depthwise_separable_conv(nn.Module):
#     def init(self, nin, nout):
#         super(depthwise_separable_conv, self).init()
#         self.depthwise = nn.Conv2d(nin, nin, kernel_size=3, padding=1, groups=nin)
#         self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
#     def forward(self, x):
#         out = self.depthwise(x)
#         out = self.pointwise(out)
#         return out
# class ConvBNReLU(nn.Module):
#     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
#         super(ConvBNReLU, self).__init__()
#         self.conv = nn.Conv2d(in_chan,
#                 out_chan,
#                 kernel_size = ks,
#                 stride = stride,
#                 padding = padding,
#                 bias = False)
#         # self.bn = BatchNorm2d(out_chan)
#         self.bn = BatchNorm2d(out_chan)
#         self.relu = nn.ReLU()
# #         self.init_weight()

#     def forward(self, x):
#         x = self.conv(x)
#         x = self.bn(x)
#         x = self.relu(x)
#         return x
class ConvX(nn.Module):
    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super(ConvX, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False)
#         self.depthwise = nn.Conv2d(in_planes, in_planes, kernel_size=3, padding=1, groups=in_planes)
#         self.pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
#         out = self.relu(self.bn(self.pointwise(self.depthwise(x))))
        return out


class AddBottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(AddBottleneck, self).__init__()
        assert block_num > 1, print("block number should be larger than 1.")
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),
                nn.BatchNorm2d(in_planes),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))
            else:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))
            
    def forward(self, x):
        out_list = []
        out = x

        for idx, conv in enumerate(self.conv_list):
            if idx == 0 and self.stride == 2:
                out = self.avd_layer(conv(out))
            else:
                out = conv(out)
            out_list.append(out)

        if self.stride == 2:
            x = self.skip(x)

        return torch.cat(out_list, dim=1) + x

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
class channal_shuffle(nn.Module):
    def __init__(self,groups=4):
        super(channal_shuffle, self).__init__()
        self.groups = groups
    def forward(self, x):
        batchsize, num_channels, height, width = x.data.size()
        channels_per_group = num_channels // self.groups
        # grouping, 通道分组
        # b, num_channels, h, w =======>  b, groups, channels_per_group, h, w
        x = x.view(batchsize, self.groups, channels_per_group, height, width)
        # channel shuffle, 通道洗牌
        x = torch.transpose(x, 1, 2).contiguous()
        # x.shape=(batchsize, channels_per_group, groups, height, width)
        # flatten
        x = x.view(batchsize, -1, height, width)
        return x
class CatBottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        assert block_num > 1, print("block number should be larger than 1.")
        self.conv_list = nn.ModuleList()
        self.stride = stride
#         self.atenchannal1 = ChannelAttention()
#         self.atenchannal2 = SpatialAttention(1024)
        self.channal_shuffle = channal_shuffle(4)
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            stride = 1
        self.avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)
#         self.conavg_layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,bias=False)
#         self.conavg_layer_BN = nn.BatchNorm2d(out_planes)

#         self.conavg_layer = nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,bias=False),
#                                           nn.BatchNorm2d(out_planes))

        self.conavg_layer = nn.Sequential(nn.Conv2d(in_planes, out_planes//4, kernel_size=1, stride=1, padding=0,bias=False),
                                          nn.BatchNorm2d(out_planes//4),
#                                           nn.ReLU(),
                                          nn.Conv2d(out_planes//4, out_planes//4, kernel_size=3, stride=1, padding=1,bias=False),
                                         nn.BatchNorm2d(out_planes//4),
#                                           nn.ReLU(),
                                         nn.Conv2d(out_planes//4, out_planes, kernel_size=1, stride=1, padding=0,bias=False),
                                         nn.BatchNorm2d(out_planes),nn.Sigmoid())
    
#         self.channal_link = nn.Sequential(nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, padding=0,bias=False),
#                                           nn.BatchNorm2d(out_planes))
        

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
#                 self.conv_list.append()
#                 self.conv_list.append(ChannelAttention(out_planes//2))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))

#                 self.conv_list.append(ChannelAttention(out_planes//2))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))

#                 self.conv_list.append(ChannelAttention(out_planes//4))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))

#                 self.conv_list.append(ChannelAttention(out_planes//int(math.pow(2, idx+1))))
            else:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))
              
#                 self.conv_list.append(ChannelAttention(out_planes//int(math.pow(2, idx))))
                                      
    def forward(self, x):
        out_list = []
        out1 = self.conv_list[0](x)
#         out_atten = self.atenchannal2(out1)
        if self.stride==2:
            x = self.avg(x)
#         out_temp = self.atenchannal2(self.conavg_layer(x))
        out_temp = self.conavg_layer(x)
#         out_temp = self.atenchannal2(out_temp)*out_temp
#         out_temp = self.conavg_layer_BN(out_temp)
        
        
#         print("out_temp:")
#         print(out_temp.size())

        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0:
                if self.stride == 2:
                    out = conv(self.avd_layer(out1))
#                     print("out2:")
#                     print(out.size())
#                     out = self.conv_list[2*idx+3](out1)
#                     print("out2:")
#                     print(out.size())
                else:
                    out = conv(out1)
#                     out = self.conv_list[2*idx+3](out1)+out
#                     print("out2:")
#                     print(out.size())
            else:
                out = conv(out)
#                 out = atten + out
#                 out = (self.conv_list[2*idx+3](out))+out_temp
#                 print("out:")
#                 print(out.size())
            out_list.append(out)

        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)

        out = torch.cat(out_list, dim=1)
#         print("output:")
#         print(out.size())
#         out = self.channal_link(out)
        out = out+out_temp
#         out = self.channal_shuffle(out)
#         out = self.atenchannal2(out)*out
#         out = outlist[]
        return out

#STDC2Net
class STDCNet1446(nn.Module):
    def __init__(self, block_res,layers_res,base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet1446, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block_res, 64, layers_res[0],stride=2)
        self.layer2 = self._make_layer(block_res, 256, layers_res[1], stride=2)
        self.layer3 = self._make_layer(block_res, 512, layers_res[2], stride=2)
        self.layer4 = self._make_layer(block_res, 1024, layers_res[3], stride=2)
#         if type == "cat":
#             block = CatBottleneck
#         elif type == "add":
#             block = AddBottleneck
        self.use_conv_last = use_conv_last
#         self.features = self._make_layers(base, layers, block_num, block)
#         self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
#         self.gap = nn.AdaptiveAvgPool2d(1)
#         self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
#         self.bn = nn.BatchNorm1d(max(1024, base*16))
#         self.relu = nn.ReLU(inplace=True)
#         self.dropout = nn.Dropout(p=dropout)
#         self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

#         self.atenchannal1 = ChannelAttention(1024)
#         self.atenchannal2 = SpatialAttention()
#         self.atenchannal2_3 = ChannelAttention(256)
#         self.atenchannal2_4 = ChannelAttention(512)
#         self.atenchannal2_5 = ChannelAttention(1024)
#         self.spatten = SpatialAttention()
#         self.trans4 = ConvX(base,base*2,3,2)
#         self.trans8 = ConvX(base*2,base*4,3,2)
#         self.trans16 = ConvX(base*4,base*8,3,2)
#         self.concatfuse8 = ConvX(base*4, base*4)
#         self.concatfuse16 = ConvX(base*8, base*8)
#         self.concatfuse32 = ConvX(base*16, base*16)

#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:6])
#         self.t256 = nn.Sequential(self.features[6])
#         self.addlayer7 = nn.Sequential(self.features[7])
#         self.x16 = nn.Sequential(self.features[8:13])
#         self.t512 = nn.Sequential(self.features[13])
#         self.addlayer14 = nn.Sequential(self.features[14])
#         self.x32 = nn.Sequential(self.features[15:18])
#         self.t1024 = nn.Sequential(self.features[18])
#         self.addlayer19 = nn.Sequential(self.features[19])

#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:7])
#         self.t256 = nn.Sequential(self.features[7])
#         self.addlayer7 = nn.Sequential(self.features[8])
#         self.x16 = nn.Sequential(self.features[9:15])
#         self.t512 = nn.Sequential(self.features[15])
#         self.addlayer14 = nn.Sequential(self.features[16])
#         self.x32 = nn.Sequential(self.features[17:21])
#         self.t1024 = nn.Sequential(self.features[21])
#         self.addlayer19 = nn.Sequential(self.features[22])

#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:6])
#         self.x16 = nn.Sequential(self.features[6:11])
#         self.x32 = nn.Sequential(self.features[11:])

#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:5])
#         self.x16 = nn.Sequential(self.features[5:9])
#         self.x32 = nn.Sequential(self.features[9:])

#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:4])
#         self.x16 = nn.Sequential(self.features[4:6])
#         self.x32 = nn.Sequential(self.features[6:])
        
#         self.x2 = nn.Sequential(self.features[:1])
#         self.x4 = nn.Sequential(self.features[1:2])
#         self.x8 = nn.Sequential(self.features[2:7])
#         self.x16 = nn.Sequential(self.features[7:13])
#         self.x32 = nn.Sequential(self.features[13:])
        

        if pretrain_model:
#             print('use pretrain model {}'.format(pretrain_model))
#             self.init_weight(pretrain_model)
#         else:
            self.init_params()

    def init_weight(self, pretrain_model):
        
        state_dict = torch.load(pretrain_model)
        self_state_dict = self.state_dict()
        sub_list = list(state_dict.keys())[:295]
        for k, v in state_dict.items():
            if k in sub_list:
                k = k.replace('cp.backbone.','')
                self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

#     def _make_layers(self, base, layers, block_num, block):
#         features = []
#         features += [ConvX(3, base//2, 3, 2)]
#         features += [ConvX(base//2, base, 3, 2)]

#         for i, layer in enumerate(layers):
#             for j in range(layer):
#                 if i == 0 and j == 0:
#                     features.append(block(base, base*4, block_num, 2))
#                 elif j == 0:
#                     features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
# #                 elif j+2==layer:
# #                     features.append(block(base*int(math.pow(2,i+1)),base*int(math.pow(2,i+2))))
# #                 elif j+2>layer:
# #                     features.append(block(base*int(math.pow(2,i+2)),base*int(math.pow(2,i+2))))
#                 else:
#                     features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))
# #             features.append(block(base*int(math.pow(2,i+1)),base*int(math.pow(2,i+2))))

#         return nn.Sequential(*features)
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
 
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
 
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
#         x = self.maxpool(x)
        
#         feat2 = self.x2(x)
        feat2 = x
        print(feat2.size())
        x = self.relu(self.bn2(self.conv2(x)))
#         feat4 = self.x4(feat2)
        feat4 = self.layer1(x)
        print(feat4.size())
#         print("feat4:")
#         print(feat4.size())
        
#         trans4 = self.trans4(feat4)
        
#         feat8_temp = self.x8(feat4)
#         feat8tmp = self.x8(feat4)
#         feat8 = self.x8(feat4)
        feat8 = self.layer2(feat4)
        print(feat8.size())
#         feat8 = self.addlayer7(self.t256(feat8tmp))
#         feat8_ori = self.addlayer7(self.t256(feat8tmp))
#         feat8 = self.atenchannal2_3(feat8_ori)*feat8_ori
#         feat8 = self.spatten(feat8)*feat8
#         feat8 = feat8+feat8_ori
#         feat8 = self.t256(feat8tmp)
#         feat8 = torch.cat((trans4,feat8_temp),dim=1)
#         feat8 = self.concatfuse8(feat8)
#         print("feat8:")
#         print(feat8.size())
        
#         trans8 = self.trans8(feat8_temp)
        
#         feat16tmp = self.x16(feat8tmp)
#         feat16 = self.x16(feat8)
        feat16 = self.layer3(feat8)
        print(feat16.size())
#         print("feat16tmp:")
#         print(feat16tmp.size())
#         feat16 = self.addlayer14(self.t512(feat16tmp))
#         feat16_ori = self.addlayer14(self.t512(feat16tmp))
#         feat16 = self.atenchannal2_4(feat16_ori)*feat16_ori
#         feat16 = self.spatten(feat16)*feat16
#         feat16 = feat16+feat16_ori
#         feat16 = self.t512(feat16tmp)
#         feat16_temp = self.x16(feat8_temp)
#         feat16 = torch.cat((trans8,feat16_temp),dim=1)
#         feat16 = self.concatfuse16(feat16)
#         print("feat16:")
#         print(feat16.size())
        
#         trans16 = self.trans16(feat16_temp)
        
#         feat32tmp = self.x32(feat16tmp)
#         feat32 = self.x32(feat16)
        feat32 = self.layer4(feat16)
        print(feat32.size())
#         feat32 = self.atenchannal1(feat32)*feat32
#         feat32 = self.atenchannal2(feat32)*feat32
#         feat32 = self.addlayer19(self.t1024(feat32tmp))
#         feat32_ori = self.addlayer19(self.t1024(feat32tmp))
#         feat32 = self.atenchannal2_5(feat32_ori)*feat32_ori
#         feat32 = self.spatten(feat32)*feat32
#         feat32 = feat32+feat32_ori
#         feat32 = self.t1024(feat32tmp)
#         feat32 = torch.cat((trans16,feat32_temp),dim=1)
#         feat32 = self.concatfuse32(feat32)
#         print("feat32:")
#         print(feat32.size())
        
        if self.use_conv_last:
           feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out

# STDC1Net
class STDCNet813(nn.Module):
    def __init__(self, base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet813, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
        self.bn = nn.BatchNorm1d(max(1024, base*16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        self.x32 = nn.Sequential(self.features[6:])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        
        state_dict = torch.load(pretrain_model)["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        features = []
        features += [ConvX(3, base//2, 3, 2)]
        features += [ConvX(base//2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base*4, block_num, 2))
                elif j == 0:
                    features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
                else:
                    features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x):
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
           feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out

if __name__ == "__main__":
    model_b = STDCNet1446(Bottleneck,[3,4,6,3],num_classes=1000, dropout=0.00, block_num=4)
    model_b.eval()
    x = torch.randn(1,3,768,1536)
    y = model(x)
    torch.save(model_b.state_dict(), 'cat_res.pth')
#     print(y.size())

torch.Size([1, 32, 384, 768])
torch.Size([1, 64, 192, 384])
torch.Size([1, 256, 96, 192])
torch.Size([1, 512, 48, 96])
torch.Size([1, 1024, 24, 48])


In [12]:
from ptflops import get_model_complexity_info
from torchvision import models

ops, params = get_model_complexity_info(model_b, (3, 768, 1536), as_strings=True, print_per_layer_stat=True, verbose=True)


torch.Size([1, 32, 384, 768])
torch.Size([1, 64, 192, 384])
torch.Size([1, 256, 96, 192])
torch.Size([1, 512, 48, 96])
torch.Size([1, 1024, 24, 48])
STDCNet1446(
  104.039 M, 100.000% Params, 421.594 GMac, 100.000% MACs, 
  (conv1): Conv2d(0.005 M, 0.005% Params, 1.387 GMac, 0.329% MACs, 3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(0.0 M, 0.000% Params, 0.019 GMac, 0.004% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(0.0 M, 0.000% Params, 0.028 GMac, 0.007% MACs, inplace=True)
  (conv2): Conv2d(0.018 M, 0.018% Params, 5.436 GMac, 1.289% MACs, 32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(0.0 M, 0.000% Params, 0.038 GMac, 0.009% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    0.141 M, 0.135% Params, 11.353 GMac, 2.693% MACs, 
    (0): Bottleneck(
      0.05 M, 0.048% Params, 4.624 GMac, 1.097% MACs, 


In [11]:
# lz = torch.load('/xiaoou/STDC-Seg-master/STDC-Seg-master/yolov5s.pt',_use_new_zipfile_serialization=False)
import torch
# state_dict = torch.load("/xiaoou/STDC-Seg-master/STDC-Seg-master/yolov5s.pt")#加载原来的模型  在torch=1.6时加载
torch.save(state_dict, "/xiaoou/STDC-Seg-master/STDC-Seg-master/yolov5s2.pt", _use_new_zipfile_serialization=False)
#不是zip
# state_dict_new = torch.load("/xiaoou/STDC-Seg-master/STDC-Seg-master/yolov5s.pt")
# for k, v in state_dict_new.items():
#     print(k)

NameError: name 'state_dict' is not defined