In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


from stdcnet import STDCNet1446, STDCNet813
# from modules.bn import InPlaceABNSync as BatchNorm2d
BatchNorm2d = nn.BatchNorm2d

class ConvBNReLU(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and ReLU.

    Args:
        in_chan: number of input channels of the convolution.
        out_chan: number of output channels (also the BatchNorm width).
        ks: convolution kernel size.
        stride: convolution stride.
        padding: zero padding applied by the convolution.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        conv_options = dict(kernel_size=ks, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_chan, out_chan, **conv_options)
        self.bn = BatchNorm2d(out_chan)
        self.relu = nn.ReLU()
        self.init_weight()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        conv_children = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for layer in conv_children:
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)


class BiSeNetOutput(nn.Module):
    """Prediction head: a 1x1 ConvBNReLU followed by a bias-free 1x1 convolution.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels produced by the intermediate ConvBNReLU.
        n_classes: channels of the final output.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return ``conv_out(conv(x))``."""
        return self.conv_out(self.conv(x))

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Linear/Conv2d weights go to the first list; their biases and all
        BatchNorm2d parameters go to the second.
        """
        decayed, undecayed = [], []
        for sub in self.modules():
            if isinstance(sub, (nn.Linear, nn.Conv2d)):
                decayed.append(sub.weight)
                if sub.bias is not None:
                    undecayed.append(sub.bias)
            elif isinstance(sub, BatchNorm2d):
                undecayed.extend(sub.parameters())
        return decayed, undecayed


class AttentionRefinementModule(nn.Module):
    """Channel-attention block.

    A 3x3 ConvBNReLU produces a feature map; its global spatial average is
    passed through a 1x1 convolution, BatchNorm and a sigmoid, and the
    resulting per-channel gate multiplies the feature map.

    Args:
        in_chan: channels of the incoming feature map.
        out_chan: channels of the refined feature map.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return the convolved features scaled by their channel-attention gate."""
        feat = self.conv(x)
        # Average over the full spatial extent -> one value per channel.
        pooled = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return feat * gate

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)


class ContextPath(nn.Module):
    """Context branch built on an STDC backbone.

    The backbone yields five feature maps (named feat2 ... feat32 here). The
    deepest map is globally average-pooled into a context vector, both the
    feat32 and feat16 maps are refined by an AttentionRefinementModule, and the
    results are summed and upsampled (nearest) step by step to the feat16 and
    feat8 resolutions.

    Args:
        backbone: ``'STDCNet1446'`` or ``'STDCNet813'``. Any other value,
            including the signature default ``'CatNetSmall'``, is rejected.
        pretrain_model: forwarded to the backbone constructor.
        use_conv_last: forwarded to the backbone constructor.
        *args, **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``backbone`` is not a supported name. (Previously the
            constructor printed a message and called ``exit(0)``, which ended
            the interpreter with a success status.)
    """

    # When True (the default, which keeps the historical output) forward()
    # prints the shape of every intermediate tensor. Set to False on the class
    # or on an instance to silence it, e.g. during training.
    debug_shapes = True

    def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
        super(ContextPath, self).__init__()

        self.backbone_name = backbone
        # Resolve the class lazily so an unsupported name fails before any model is built.
        if backbone == 'STDCNet1446':
            backbone_cls = STDCNet1446
        elif backbone == 'STDCNet813':
            backbone_cls = STDCNet813
        else:
            raise ValueError(
                "backbone is not in backbone lists: {!r} "
                "(expected 'STDCNet1446' or 'STDCNet813')".format(backbone))

        # Both supported backbones use the same channel layout here
        # (512 channels into arm16, 1024 into arm32 / conv_avg).
        inplanes = 1024
        self.backbone = backbone_cls(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        self.arm16 = AttentionRefinementModule(512, 128)
        self.arm32 = AttentionRefinementModule(inplanes, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def _log_shapes(self, label, *tensors):
        """Print ``label`` and the sizes of ``tensors`` when ``debug_shapes`` is enabled."""
        if self.debug_shapes:
            print(label)
            print(*(t.size() for t in tensors))

    def forward(self, x):
        """Run the context path.

        Returns:
            Tuple ``(feat2, feat4, feat8, feat16, feat16_up, feat32_up)``:
            the first four backbone outputs unchanged, then the refined
            context features at the feat8 and feat16 resolutions.
        """
        self._log_shapes("input_size:", x)

        feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
        self._log_shapes("after_backbone:feat2,feat4,feat8,feat16,feat32",
                         feat2, feat4, feat8, feat16, feat32)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: average over the whole feat32 map, project, broadcast back.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        self._log_shapes("after_avg:avg", avg)

        avg = self.conv_avg(avg)
        self._log_shapes("after_conv_avg:avg", avg)

        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
        self._log_shapes("after_upsample1:avg_up", avg_up)

        # feat32 level: attention-refine, add global context, upsample to feat16 size.
        feat32_arm = self.arm32(feat32)
        self._log_shapes("after_attention:feat32_arm", feat32_arm)
        feat32_sum = feat32_arm + avg_up
        self._log_shapes("after_sum:feat32_sum", feat32_sum)
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        self._log_shapes("after_upsample2:feat32_up", feat32_up)
        feat32_up = self.conv_head32(feat32_up)
        self._log_shapes("after_conv_head32:feat32_up", feat32_up)

        # feat16 level: attention-refine, add the upsampled feat32 branch, upsample to feat8 size.
        feat16_arm = self.arm16(feat16)
        self._log_shapes("after_attention:feat16_arm", feat16_arm)
        feat16_sum = feat16_arm + feat32_up
        self._log_shapes("after_sum:feat16_sum", feat16_sum)
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        self._log_shapes("after_upsample3:feat16_up", feat16_up)
        feat16_up = self.conv_head16(feat16_up)
        self._log_shapes("after_conv_head16:feat16_up", feat16_up)
        if self.debug_shapes:
            print("contextpath end")

        return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Linear/Conv2d weights go to the first list; their biases and all
        BatchNorm2d parameters go to the second.
        """
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    """Fuse two feature maps with a squeeze-and-excite style channel gate.

    The inputs are concatenated on the channel axis and projected by a 1x1
    ConvBNReLU. A gate is computed from the global spatial average through
    two 1x1 convolutions (reduction factor 4, ReLU in between, sigmoid at the
    end); the output is ``feat * gate + feat``.

    Args:
        in_chan: total channels of the concatenated inputs.
        out_chan: channels of the fused output.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        squeezed = out_chan // 4
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=False)
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan, squeezed, **pointwise)
        self.conv2 = nn.Conv2d(squeezed, out_chan, **pointwise)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse ``fsp`` and ``fcp`` (concatenated along dim 1) and return the gated result."""
        feat = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Global average over the full spatial extent -> one value per channel.
        squeezed = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        gated = feat * gate
        return gated + feat

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Linear/Conv2d weights go to the first list; their biases and all
        BatchNorm2d parameters go to the second.
        """
        decayed, undecayed = [], []
        for sub in self.modules():
            if isinstance(sub, (nn.Linear, nn.Conv2d)):
                decayed.append(sub.weight)
                if sub.bias is not None:
                    undecayed.append(sub.bias)
            elif isinstance(sub, BatchNorm2d):
                undecayed.extend(sub.parameters())
        return decayed, undecayed


class BiSeNet(nn.Module):
    """Segmentation network: ContextPath + feature fusion + skip-refined prediction heads.

    Three class-score maps are produced (from the fused features and from the
    two context-path outputs). Each one is upsampled to the input size in
    three bilinear steps, adding projected feat_res4 and feat_res2 skip
    features on the way. Optional single-channel heads on the backbone
    features (``conv_out_sp*``) are returned when the matching
    ``use_boundary_*`` flag is set.

    Args:
        backbone: ``'STDCNet1446'`` or ``'STDCNet813'``.
        n_classes: number of output channels of the class-score heads.
        pretrain_model: forwarded to ContextPath.
        use_boundary_2, use_boundary_4, use_boundary_8: include the matching
            ``conv_out_sp*`` output in the tuple returned by forward().
        use_boundary_16: stored, but forward() does not return a feat16
            boundary output.
        use_conv_last: forwarded to ContextPath.
        heat_map: accepted and ignored.
        *args, **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``backbone`` is not a supported name. (Previously the
            constructor printed a message and called ``exit(0)``.)
    """

    # When True (the default, which keeps the historical output) forward()
    # prints the shape of intermediate tensors. Set to False on the class or
    # on an instance to silence it, e.g. during training.
    debug_shapes = True

    _SUPPORTED_BACKBONES = ('STDCNet1446', 'STDCNet813')

    def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False, heat_map=False, *args, **kwargs):
        super(BiSeNet, self).__init__()

        # Validate before building anything so a typo fails fast and loudly.
        if backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(
                "backbone is not in backbone lists: {!r} (expected one of {})".format(
                    backbone, ', '.join(self._SUPPORTED_BACKBONES)))

        self.use_boundary_2 = use_boundary_2
        self.use_boundary_4 = use_boundary_4
        self.use_boundary_8 = use_boundary_8
        self.use_boundary_16 = use_boundary_16
        self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)

        # Both supported backbones share the same channel layout.
        conv_out_inplanes = 128
        sp2_inplanes = 32
        sp4_inplanes = 64
        sp8_inplanes = 256
        sp16_inplanes = 512
        inplane = sp8_inplanes + conv_out_inplanes

        self.ffm = FeatureFusionModule(inplane, 256)
        self.conv_out = BiSeNetOutput(256, 64, n_classes)
        self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)

        self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)

        self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
        self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
        self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)

        # Project the feat_res2 / feat_res4 skip features to n_classes channels
        # so they can be added to the class-score maps during upsampling.
        self.conv_out_ajchan2 = ConvBNReLU(sp2_inplanes, n_classes, ks=3, stride=1, padding=1)
        self.conv_out_ajchan4 = ConvBNReLU(sp4_inplanes, n_classes, ks=3, stride=1, padding=1)

        self.init_weight()

    def _log_shape(self, label, tensor):
        """Print ``label`` and ``tensor.size()`` when ``debug_shapes`` is enabled."""
        if self.debug_shapes:
            print(label)
            print(tensor.size())

    @staticmethod
    def _upsample_with_skips(scores, skip4, skip2, size):
        """Upsample ``scores`` to ``size`` in three bilinear steps, adding the skips.

        The map is resized to (H//4, W//4) and summed with ``skip4``, resized
        to (H//2, W//2) and summed with ``skip2``, then resized to (H, W).
        """
        H, W = size
        scores = F.interpolate(scores, (H // 4, W // 4), mode='bilinear', align_corners=True)
        scores = F.interpolate(scores + skip4, (H // 2, W // 2), mode='bilinear', align_corners=True)
        return F.interpolate(scores + skip2, (H, W), mode='bilinear', align_corners=True)

    def forward(self, x):
        """Compute the class-score maps and the enabled boundary outputs.

        Returns:
            Tuple starting with ``(feat_out, feat_out16, feat_out32)``, each
            upsampled to the spatial size of ``x``, followed by
            ``feat_out_sp2``, ``feat_out_sp4`` and ``feat_out_sp8`` (in that
            order) for every enabled ``use_boundary_*`` flag.
        """
        H, W = x.size()[2:]

        feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)

        feat_out_sp2 = self.conv_out_sp2(feat_res2)
        self._log_shape("feat_out_sp2:", feat_out_sp2)

        feat_out_sp4 = self.conv_out_sp4(feat_res4)
        self._log_shape("feat_out_sp4:", feat_out_sp4)

        feat_out_sp8 = self.conv_out_sp8(feat_res8)
        self._log_shape("feat_out_sp8:", feat_out_sp8)

        # Computed (and logged) as before, but not part of any returned tuple.
        feat_out_sp16 = self.conv_out_sp16(feat_res16)
        self._log_shape("feat_out_sp16:", feat_out_sp16)

        feat_fuse = self.ffm(feat_res8, feat_cp8)
        self._log_shape("feat_fuse:", feat_fuse)

        feat_out = self.conv_out(feat_fuse)
        self._log_shape("feat_out:", feat_out)
        feat_out16 = self.conv_out16(feat_cp8)
        self._log_shape("feat_out16:", feat_out16)
        feat_out32 = self.conv_out32(feat_cp16)
        self._log_shape("feat_out32:", feat_out32)

        # The skip projections are identical for all three branches, so compute
        # them once (they used to be recomputed per branch, which also updated
        # the BatchNorm running statistics three times per forward in train mode).
        skip4 = self.conv_out_ajchan4(feat_res4)
        skip2 = self.conv_out_ajchan2(feat_res2)

        # Each branch keeps its own result. Earlier code wrote the 1/16 and
        # 1/32 branch results into `feat_out`, overwriting the fused prediction
        # and returning feat_out16 / feat_out32 at (H//4, W//4) without skips.
        feat_out = self._upsample_with_skips(feat_out, skip4, skip2, (H, W))
        self._log_shape("after_upsample:feat_out:", feat_out)

        feat_out16 = self._upsample_with_skips(feat_out16, skip4, skip2, (H, W))
        self._log_shape("after_upsample:feat_out16:", feat_out16)

        feat_out32 = self._upsample_with_skips(feat_out32, skip4, skip2, (H, W))
        self._log_shape("after_upsample:feat_out32:", feat_out32)

        # Append one boundary output per enabled flag. The flag combinations
        # that were handled before yield exactly the same tuples; the others
        # used to fall through and return None.
        outputs = [feat_out, feat_out16, feat_out32]
        if self.use_boundary_2:
            outputs.append(feat_out_sp2)
        if self.use_boundary_4:
            outputs.append(feat_out_sp4)
        if self.use_boundary_8:
            outputs.append(feat_out_sp8)
        return tuple(outputs)

    def init_weight(self):
        """Re-initialise every direct Conv2d child: Kaiming-normal weight (a=1), zero bias."""
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    @staticmethod
    def _split_decay_params(module):
        """Split ``module``'s parameters into ``(weight_decay, no_weight_decay)`` lists.

        Same rule as the ``get_params`` methods of the sub-modules:
        Linear/Conv2d weights are decayed; their biases and BatchNorm2d
        parameters are not.
        """
        wd_params, nowd_params = [], []
        for sub in module.modules():
            if isinstance(sub, (nn.Linear, nn.Conv2d)):
                wd_params.append(sub.weight)
                if sub.bias is not None:
                    nowd_params.append(sub.bias)
            elif isinstance(sub, BatchNorm2d):
                nowd_params.extend(sub.parameters())
        return wd_params, nowd_params

    def get_params(self):
        """Group parameters for the optimizer.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Parameters of FeatureFusionModule and BiSeNetOutput children go to
            the two ``lr_mul_*`` lists; all other children go to the first two.
        """
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            if hasattr(child, 'get_params'):
                child_wd_params, child_nowd_params = child.get_params()
            else:
                # e.g. the ConvBNReLU skip projections, which define no
                # get_params(); calling it used to raise AttributeError.
                child_wd_params, child_nowd_params = self._split_decay_params(child)
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


In [2]:
if __name__ == "__main__":
    # Smoke test: build the network, push one random batch through it in
    # eval mode, print the output shape and save the freshly initialised weights.
    net = BiSeNet('STDCNet1446', 19)
    # net.cuda()  # uncomment to run on a GPU (the input must then be moved as well)
    net.eval()
    in_ten = torch.randn(1, 3, 768, 1536)
    # Inference only: skip autograd bookkeeping so no graph is kept for this large input.
    with torch.no_grad():
        out, out16, out32 = net(in_ten)
    print(out.shape)
    torch.save(net.state_dict(), 'STDCNet1446_modify_conv1layeradd.pth')

input_size:
torch.Size([1, 3, 768, 1536])
feat4:
torch.Size([1, 64, 192, 384])
feat8:
torch.Size([1, 256, 96, 192])
feat16tmp:
torch.Size([1, 256, 48, 96])
feat16:
torch.Size([1, 512, 48, 96])
feat32:
torch.Size([1, 1024, 24, 48])
after_backbone:feat2,feat4,feat8,feat16,feat32
torch.Size([1, 32, 384, 768]) torch.Size([1, 64, 192, 384]) torch.Size([1, 256, 96, 192]) torch.Size([1, 512, 48, 96]) torch.Size([1, 1024, 24, 48])
after_avg:avg
torch.Size([1, 1024, 1, 1])
after_conv_avg:avg
torch.Size([1, 128, 1, 1])
after_upsample1:avg_up
torch.Size([1, 128, 24, 48])
after_attention:feat32_arm
torch.Size([1, 128, 24, 48])
after_sum:feat32_sum
torch.Size([1, 128, 24, 48])
after_upsample2:feat32_up
torch.Size([1, 128, 48, 96])
after_conv_head32:feat32_up
torch.Size([1, 128, 48, 96])
after_attention:feat16_arm
torch.Size([1, 128, 48, 96])
after_sum:feat16_sum
torch.Size([1, 128, 48, 96])
after_upsample3:feat16_up
torch.Size([1, 128, 96, 192])
after_conv_head16:feat16_up
torch.Size([1, 128, 96, 1

In [1]:
# List every learnable tensor of the network with its shape.
# `net` is created by the model smoke-test cell; run that cell first in a fresh kernel.
for param_name, tensor in net.named_parameters():
    print(param_name, tensor.size())

NameError: name 'net' is not defined

In [None]:
class baseblock(nn.Module):
    """Unfinished block that collects its convolutions in ``conv_list``.

    Only the first layer (a kernel-1 ``ConvX`` halving ``out_planes``) is
    created so far, and no ``forward`` is defined yet.

    Args:
        in_planes: input channels of the first layer.
        out_planes: target output channels; the first layer emits ``out_planes // 2``.
        block_num: number of loop iterations; only index 0 adds a layer.
        stride: currently unused.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        # Must name this class: the previous `super(CatBottleneck, self)`
        # referred to a class that is not defined in this notebook.
        super(baseblock, self).__init__()
        self.conv_list = nn.ModuleList()

        for idx in range(block_num):
            if idx == 0:
                # NOTE(review): ConvX is neither defined nor imported in the visible
                # notebook; presumably it lives in `stdcnet` - confirm and import it.
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))

In [None]:
class Conv3X(nn.Module):
    """Bias-free Conv2d (padding = kernel // 2) followed by BatchNorm2d and in-place ReLU.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        same_pad = kernel // 2
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel,
            stride=stride,
            padding=same_pad,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        feat = self.conv(x)
        feat = self.bn(feat)
        return self.relu(feat)
class Conv1X(nn.Module):
    """Bias-free 1x1 Conv2d followed by BatchNorm2d and in-place ReLU.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        stride: convolution stride.
    """

    def __init__(self, in_planes, out_planes, stride=1):
        # nn.Module must be initialised before sub-modules can be assigned;
        # without this call the first `self.conv = ...` raises.
        super(Conv1X, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))`` (the result was previously computed but not returned)."""
        out = self.relu(self.bn(self.conv(x)))
        return out

In [6]:
# Demonstrate torch.mean over all elements vs. over the channel axis (dim=1).
# The tensor is created as float32 up front: torch.arange(24) yields int64 by
# default, and torch.mean on it fails with
# "RuntimeError: Can only calculate the mean of floating types. Got Long instead."
x = torch.arange(24, dtype=torch.float32).view(1, 2, 4, 3)
x_mean = torch.mean(x)                        # scalar mean of all 24 values
x_mean1 = torch.mean(x, dim=1, keepdim=True)  # mean across dim 1, kept as a size-1 axis
print('x:')
print(x)
print('x_mean1:')
print(x_mean1)
print(x_mean1.size())
print('x_mean:')
print(x_mean)

x:
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.],
          [ 9., 10., 11.]],

         [[12., 13., 14.],
          [15., 16., 17.],
          [18., 19., 20.],
          [21., 22., 23.]]]])
x_mean1:
tensor([[[[ 6.,  7.,  8.],
          [ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]]])
torch.Size([1, 1, 4, 3])
x_mean:
tensor(11.5000)


In [20]:
# Extract the leading entries of a trained BiSeNet checkpoint and strip the
# 'cp.backbone.' prefix from their names, then list the resulting keys.
CHECKPOINT_PATH = "/xiaoou/STDC-Seg-master/STDC-Seg-master/checkpoint/train_STDC2-Seg_depthwise14/pths/model_final.pth"
# Number of leading checkpoint entries to keep.
# NOTE(review): presumably these are exactly the backbone tensors - confirm against the checkpoint.
NUM_LEADING_ENTRIES = 583

# map_location='cpu' lets the checkpoint load on machines without a GPU
# (same as the pretrained-backbone cell). torch.load unpickles the file, so
# only point it at checkpoints from a trusted source.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model = BiSeNet('STDCNet1446', 19)

# state_dict preserves insertion order, so slicing its keys selects the first entries.
sub_list = list(state_dict.keys())[:NUM_LEADING_ENTRIES]
new_state_dict = {k.replace('cp.backbone.', ''): state_dict[k] for k in sub_list}
for name in new_state_dict:
    print(name)


features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight

In [5]:
# List the tensor names stored under "state_dict" in the pretrained backbone archive.
checkpoint = torch.load(
    "/xiaoou/STDC-Seg-master/STDC-Seg-master/STDC-Seg-weight/STDCNet1446_76.47.tar",
    map_location='cpu',
)
state_dict = checkpoint["state_dict"]
model = BiSeNet('STDCNet1446', 19)
for k in state_dict.keys():
    print(k)

features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight