In [1]:
from torchvision.models._utils import IntermediateLayerGetter
import torch
import torch.nn as nn

class BasicBlockEnc(nn.Module):
    """Residual basic block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    Args:
        in_channels: accepted for signature compatibility; see the note below.
        out_channels: accepted for signature compatibility; see the note below.
        stride: stride of the first convolution and of the projection
            shortcut. It also scales the output width (``2 * stride`` channels).

    NOTE(review): ``in_channels`` and ``out_channels`` are overwritten in
    ``__init__``: the block always consumes 2 channels and emits ``2 * stride``.
    The hard-coded reshapes in ``Pspnet.forward`` only line up with 2-channel
    feature maps, so honouring these arguments has to be done together with
    every caller, not in this class alone.
    """

    def __init__(self, in_channels,out_channels, stride=1):
        super(BasicBlockEnc, self).__init__()
        # Fixed width; the passed-in values are intentionally discarded (see class note).
        in_channels = 2
        out_channels = in_channels * stride

        # conv1 carries the stride so the main path downsamples exactly like the
        # projection shortcut below; otherwise the two could not be summed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride == 1:
            # Identity shortcut: an empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            # Projection shortcut: 1x1 conv matches both channel count and resolution.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))`` where ``F`` is the two-conv main path."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The residual sum was previously computed but never added to the output.
        out = out + self.shortcut(x)
        return torch.relu(out)


class ResNet18cus(nn.Module):
    """Small ResNet-style encoder: 3x3 stem, four residual stages, linear head.

    Args:
        num_Blocks: number of residual blocks in each of the four stages.
        nc: number of channels of the input image.

    Child modules are registered in the order ``conv1, bn1, layer1..layer4,
    linear``.

    NOTE(review): ``forward`` feeds the pooled ``layer4`` output into
    ``self.linear`` (64 input features), so it only works when the block class
    really emits the 64 channels requested for ``layer4`` - confirm that before
    calling this module directly.
    """

    def __init__(self, num_Blocks=(2, 2, 2, 2), nc=2):
        super().__init__()
        # Running channel count handed to the next block by _make_layer.
        self.in_channels = 2
        self.conv1 = nn.Conv2d(nc, 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(2)
        # Requested stage widths are 8/16/32/64; every stage keeps stride 1.
        self.layer1 = self._make_layer(BasicBlockEnc, 8, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 16, num_Blocks[1], stride=1)
        self.layer3 = self._make_layer(BasicBlockEnc, 32, num_Blocks[2], stride=1)
        self.layer4 = self._make_layer(BasicBlockEnc, 64, num_Blocks[3], stride=1)
        self.linear = nn.Linear(64, 2048)

    def _make_layer(self, BasicBlockEnc,out_channels, num_Blocks, stride):
        """Build one stage of ``num_Blocks`` blocks.

        Only the first block receives ``stride``; the remaining ones use 1.
        ``self.in_channels`` is advanced to ``out_channels`` after each block.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_Blocks - 1):
            # Pass each value in its own slot: (in_channels, out_channels, stride).
            layers.append(BasicBlockEnc(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` into a ``(batch, 2048)`` tensor via global average pooling."""
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # nn.functional is used directly; no `F` alias is imported in this notebook.
        x = nn.functional.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x


class PPM(nn.ModuleList):
    """Pyramid pooling branches stored as a module list.

    Each entry is ``AdaptiveMaxPool2d(pool_size)`` followed by a 1x1 convolution
    mapping ``in_channels`` to ``out_channels``; there is one entry per value in
    ``pool_sizes``.
    """

    def __init__(self, pool_sizes, in_channels, out_channels):
        super(PPM, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels

        branches = [
            nn.Sequential(
                nn.AdaptiveMaxPool2d(bins),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
            for bins in pool_sizes
        ]
        self.extend(branches)

    def forward(self, x):
        """Return a list with every branch output resized back to ``x``'s height and width."""
        target_hw = x.shape[-2:]
        return [
            nn.functional.interpolate(branch(x), size=target_hw, mode='bilinear', align_corners=True)
            for branch in self
        ]
 
    
class PSPHEAD(nn.Module):
    """Pyramid-pooling decoder head.

    Runs the input through a ``PPM``, concatenates every branch output with the
    input along the channel axis, and fuses the result with a 3x3 conv, batch
    norm and ReLU. ``num_classes`` is stored on the instance but not used to
    build any layer here.
    """

    def __init__(self, in_channels, out_channels,pool_sizes = [1, 2, 3, 6],num_classes=3):
        super(PSPHEAD, self).__init__()
        self.pool_sizes = pool_sizes
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels

        # One out_channels-wide map per pyramid level plus the untouched input.
        fused_channels = in_channels + len(pool_sizes) * out_channels

        self.psp_modules = PPM(pool_sizes, in_channels, out_channels)
        self.final = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Fuse the pyramid features with ``x`` and return an ``out_channels``-wide map."""
        pyramid = self.psp_modules(x)
        fused = torch.cat([*pyramid, x], 1)
        return self.final(fused)
 
# Build an FCN segmentation head, used to compute the auxiliary loss
class Aux_Head(nn.Module):
    """Three-convolution head producing ``num_classes`` maps.

    Channel widths go ``in_channels -> in_channels // 2 -> in_channels // 4 ->
    num_classes``; the first two 3x3 convolutions are each followed by batch
    norm and ReLU. All convolutions use padding 1, so height and width are
    preserved.
    """

    def __init__(self, in_channels=1024, num_classes=3):
        super(Aux_Head, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        half = in_channels // 2
        quarter = in_channels // 4

        stages = [
            nn.Conv2d(in_channels, half, kernel_size=3, padding=1),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(half, quarter, kernel_size=3, padding=1),
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            nn.Conv2d(quarter, num_classes, kernel_size=3, padding=1),
        ]
        self.decode_head = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return the per-class maps."""
        logits = self.decode_head(x)
        return logits
 
class Pspnet(nn.Module):
    """Segmentation network: ResNet18cus backbone, PSP decoder, optional auxiliary head.

    Args:
        num_classes: number of output maps of both heads.
        aux_loss: when true, an ``Aux_Head`` is built and its prediction is
            returned next to the main one.

    The backbone is wrapped in ``IntermediateLayerGetter`` and tapped at
    ``layer3`` (key ``"aux"``) and ``linear`` (key ``"stage4"``).

    NOTE(review): ``forward`` reshapes the tapped tensors to fixed sizes
    (64x64x64 and 128x8x8 per sample). This appears to assume 64x64 inputs, as
    in the commented-out example below the class - confirm before feeding other
    resolutions.
    """

    def __init__(self, num_classes, aux_loss = True):
        super(Pspnet, self).__init__()
        self.num_classes = num_classes
        self.backbone = IntermediateLayerGetter(
            ResNet18cus(),
            return_layers={'layer3':"aux" ,'linear': 'stage4'}
        )
        self.aux_loss = aux_loss
        self.decoder = PSPHEAD(in_channels=64, out_channels=32, pool_sizes = [1, 2, 3, 6], num_classes=self.num_classes)
        self.cls_seg = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=3, padding=1),
        )
        if self.aux_loss:
            self.aux_head = Aux_Head(in_channels=128, num_classes=self.num_classes)

    def forward(self, x):
        """Return ``{"output": ...}`` and, if enabled, ``"aux_output"``.

        Both predictions are bilinearly resized to the input height and width.

        Raises:
            RuntimeError: from ``Tensor.view`` when a tapped feature cannot be
                reshaped to the fixed per-sample size for this batch.
        """
        batch, _, h, w = x.size()
        feats = self.backbone(x)

        # Pin the batch dimension instead of using -1: with -1 a feature of
        # the wrong size is silently folded into extra "samples".
        stage4 = feats["stage4"].view(batch, 64, 64, 64)
        x = self.decoder(stage4)
        x = self.cls_seg(x)
        x = nn.functional.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        # Optional auxiliary prediction, used for the auxiliary loss.
        if self.aux_loss:
            aux_output = self.aux_head(feats['aux'].view(batch, 128, 8, 8))
            aux_output = nn.functional.interpolate(aux_output, size=(h, w), mode='bilinear', align_corners=True)

            return {"output":x, "aux_output":aux_output}
        return {"output":x}
 
 
# if __name__ == "__main__":
#     model = Pspnet(num_classes=4, aux_loss=True)
#     model = model.cuda()
#     a = torch.ones([10, 2, 64, 64])
#     a = a.cuda()
 
#     for name, out in model(a).items():
#         print(name,": ", out.shape)
model = Pspnet(num_classes=4, aux_loss=True)
# Use the GPU when one is available, otherwise stay on the CPU so the cell
# still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print("Num params: ", sum(p.numel() for p in model.parameters()))



  from .autonotebook import tqdm as notebook_tqdm


Num params:  292272


In [6]:
# Parameter memory: element count times the byte width of each tensor's own
# dtype, rather than assuming 4 bytes (float32) for every parameter.
model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
model_size_mb = model_size_bytes / (1024 ** 2)  # bytes -> mebibytes
print(model_size_mb)

1.11492919921875


In [5]:
in_channels = 1024
num_classes = 3

# Channel widths after the first and the second convolution, and the 3x3 kernel area
half_channels = in_channels // 2
quarter_channels = in_channels // 4
kernel_area = 3 ** 2

# Convolution layers: c_in * c_out * kernel_area weights plus one bias per output channel
conv1_params = in_channels * half_channels * kernel_area + half_channels
conv2_params = half_channels * quarter_channels * kernel_area + quarter_channels
conv3_params = quarter_channels * num_classes * kernel_area + num_classes

# Batch-norm layers: two learnable values per channel
bn1_params = 2 * half_channels
bn2_params = 2 * quarter_channels

# Total parameter count
total_params = sum([conv1_params, bn1_params, conv2_params, bn2_params, conv3_params])

print("Total parameters in Aux_Head:", total_params)


Total parameters in Aux_Head: 5907459
