In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18


class Resnet18(nn.Module):
    def __init__(self, num_classes=2, pretrained=False):
        super(Resnet18, self).__init__()
        model = resnet18(pretrained)

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.relu = model.relu
        self.maxpool = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.last_conv = nn.Conv2d(512, num_classes, kernel_size=1)

    def forward(self, x):
        input_size = x.shape[-2:]
        x = torch.cat([x, x, x], dim=1)  # 扩充为3通道
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        low = x
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        high = x
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.last_conv(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
        return x, low, high

In [83]:
import torch.nn as nn
import torch


class conv_block(nn.Module):
    """
    Convolution Block
    """

    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):
    """
    Up Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x


class U_Net_R18(nn.Module):
    """
    UNet - Basic Implementation
    Paper : https://arxiv.org/abs/1505.04597
    """
    def __init__(self, in_ch=1, out_ch=4):
        super(U_Net, self).__init__()
        n1 = 32
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

    # self.active = torch.nn.Sigmoid()

    def forward(self, x):
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)
        
        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out, e2, e5  

In [84]:
r18 = U_Net()
x = torch.rand(2,1,224,224)
out, e2, e5   = r18(x)
print(out.shape)

torch.Size([2, 512, 14, 14])
torch.Size([2, 4, 224, 224])


In [88]:
n1 = 64
filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16, n1*32]
print(filters)

[64, 128, 256, 512, 1024, 2048]


In [198]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class U_Net_R50(nn.Module):
    """
    UNet with ResNet-50 Encoder
    """
    def __init__(self, in_ch=3, out_ch=4):
        super(U_Net_R50, self).__init__()
        
        # Encoder (ResNet-50)
        resnet = models.resnet50(pretrained=False)
        self.encoder1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.encoder2 = resnet.layer1
        self.encoder3 = resnet.layer2
        self.encoder4 = resnet.layer3
        self.encoder5 = resnet.layer4
        
        # Decoder
        self.Up5 = up_conv(2048, 1024)  # ResNet-50 layer4 output has 2048 channels
        self.Up_conv5 = conv_block(2048, 1024)
        
        self.Up4 = up_conv(1024, 512)
        self.Up_conv4 = conv_block(1024, 512)
        
        self.Up3 = up_conv(512, 256)
        self.Up_conv3 = conv_block(512, 256)
        
        self.Up2 = up_conv(256, 128)
        self.Up_conv2 = conv_block(192, 128)
        
        self.Up1 = up_conv(128, 64)
        self.Up_conv1 = conv_block(67, 64)

        self.Conv = nn.Conv2d(64, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)

        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        e1 = F.interpolate(e1, size=d2.size()[2:], mode='bilinear', align_corners=True)

        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Up1(d2)
        d1 = torch.cat((x, d1), dim=1)  # Fix the input channel dimension here
        d1 = self.Up_conv1(d1)

        out = self.Conv(d1)

        return e1,e2,e3,e4,e5, d5,d4,d3,d2,out


# Example usage:
# Create a U-Net with ResNet-50 encoder, 3 input channels, and 4 output channels
model = U_Net_R50(in_ch=3, out_ch=4)
# r18 = U_Net_ResNet18()
x = torch.rand(2,3,224,224)
e1,e2,e3,e4,e5, d5,d4,d3,d2,out   = model(x)
print(e1.shape,e2.shape,e3.shape,e4.shape,e5.shape, d5.shape,d4.shape,d3.shape,d2.shape,out.shape)

torch.Size([2, 64, 112, 112]) torch.Size([2, 256, 56, 56]) torch.Size([2, 512, 28, 28]) torch.Size([2, 1024, 14, 14]) torch.Size([2, 2048, 7, 7]) torch.Size([2, 1024, 14, 14]) torch.Size([2, 512, 28, 28]) torch.Size([2, 256, 56, 56]) torch.Size([2, 128, 112, 112]) torch.Size([2, 4, 224, 224])


In [174]:
import torch
import torch.nn as nn
import torchvision.models as models

class U_Net_ResNet18(nn.Module):
    def __init__(self, in_ch=3, out_ch=4):
        super(U_Net_ResNet18, self).__init__()
        
        # Encoder (ResNet-18)
        resnet = models.resnet18(pretrained=False)
        self.encoder1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.encoder2 = resnet.layer1
        self.encoder3 = resnet.layer2
        self.encoder4 = resnet.layer3
        self.encoder5 = resnet.layer4
        
        # Decoder
        self.Up5 = up_conv(512, 256)  # ResNet-18 layer4 output has 512 channels
        self.Up_conv5 = conv_block(512, 256)
        
        self.Up4 = up_conv(256, 128)
        self.Up_conv4 = conv_block(256, 128)
        
        self.Up3 = up_conv(128, 64)
        self.Up_conv3 = conv_block(128, 64)
        
        self.Up2 = up_conv(64, 32)
        self.Up_conv2 = conv_block(64+32, 32)
        
        self.Conv = nn.Conv2d(32, out_ch, kernel_size=1, stride=1, padding=0)
    
    def t_guided_s(self, s, t):
        """
        Compact Cross-Attention from Teacher to Student feature maps.
        """
#         print(s.shape)
        s_pri = s
        channel_decomp = nn.Conv2d(s.shape[1], t.shape[1], kernel_size=1)
        s = channel_decomp(s)
        
        if s.shape[2] != t.shape[2]:
            s = F.interpolate(s, t.size()[-2:], mode='bilinear')
                
        attn_map = torch.matmul(t, s.transpose(2, 3))
        attn_map = F.softmax(attn_map, dim=-1)
        guided_s = torch.matmul(attn_map.transpose(2, 3), s)
        
        channel_comp = nn.Conv2d(guided_s.shape[1], s_pri.shape[1], kernel_size=1)
        guided_s = channel_comp(guided_s)
        
        guided_s = F.interpolate(guided_s, s_pri.size()[-2:], mode='bilinear')
        
        return guided_s

    def s_guided_t(self, t, s):
        """
        Compact Cross-Attention from Student to Teacher feature maps.
        """
        attn_map = torch.matmul(s, t.transpose(2, 3))
        attn_map = F.softmax(attn_map, dim=-1)
        guided_t = torch.matmul(attn_map.transpose(2, 3), t)
        return guided_t
    
    def forward(self, x, t_attn_skips):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        
        t_e1, t_e2, t_e3, t_e4, t_e5 = t_attn_skips
        t_s_e1, t_s_e2, t_s_e3, t_s_e4, t_s_e5 = self.t_guided_s(e1, t_e1), self.t_guided_s(e2, t_e2), self.t_guided_s(e3, t_e3), self.t_guided_s(e4, t_e4), self.t_guided_s(e5, t_e5)
        print('----------------------------------------------')
        print(e1.shape,e2.shape,e3.shape,e4.shape,e5.shape)
        print(t_s_e1.shape, t_s_e2.shape, t_s_e3.shape, t_s_e4.shape, t_s_e5.shape)
        d5 = self.Up5(t_s_e5)
        d5 = torch.cat((t_s_e4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((t_s_e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((t_s_e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        t_s_e1 = F.interpolate(t_s_e1, size=d2.size()[2:], mode='bilinear', align_corners=True)
        d2 = torch.cat((t_s_e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)
        out = F.interpolate(out, size=x.size()[2:], mode='bilinear', align_corners=True)

        return t_s_e1, t_s_e2, t_s_e3, t_s_e4, t_s_e5, d5, d4, d3, d2, out


# Example usage:
# Create a U-Net with ResNet-18 encoder, 3 input channels, and 4 output channels
# model = U_Net_ResNet18(in_ch=3, out_ch=4)


In [175]:
r18 = U_Net_ResNet18()
x = torch.rand(2,3,224,224)
e1,e2,e3,e4,e5, d5,d4,d3,d2,out = r18(x, [e1,e2,e3,e4,e5])
print(e1.shape,e2.shape,e3.shape,e4.shape,e5.shape, d5.shape,d4.shape,d3.shape,d2.shape,out.shape)

----------------------------------------------
torch.Size([2, 64, 56, 56]) torch.Size([2, 64, 56, 56]) torch.Size([2, 128, 28, 28]) torch.Size([2, 256, 14, 14]) torch.Size([2, 512, 7, 7])
torch.Size([2, 64, 56, 56]) torch.Size([2, 64, 56, 56]) torch.Size([2, 128, 28, 28]) torch.Size([2, 256, 14, 14]) torch.Size([2, 512, 7, 7])
torch.Size([2, 64, 112, 112]) torch.Size([2, 64, 56, 56]) torch.Size([2, 128, 28, 28]) torch.Size([2, 256, 14, 14]) torch.Size([2, 512, 7, 7]) torch.Size([2, 256, 14, 14]) torch.Size([2, 128, 28, 28]) torch.Size([2, 64, 56, 56]) torch.Size([2, 32, 112, 112]) torch.Size([2, 4, 224, 224])


In [199]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class U_Net_MobileNetV2(nn.Module):
    """
    UNet with MobileNetV2 Encoder
    """
    def __init__(self, in_ch=3, out_ch=4):
        super(U_Net_MobileNetV2, self).__init__()

        # Encoder (MobileNetV2)
        mobilenet_v2 = models.mobilenet_v2(pretrained=True)
        print(mobilenet_v2)
        self.encoder1 = mobilenet_v2.features[0:2]
        self.encoder2 = mobilenet_v2.features[2:4]
        self.encoder3 = mobilenet_v2.features[4:7]
        self.encoder4 = mobilenet_v2.features[7:14]
        self.encoder5 = mobilenet_v2.features[14:18]


        # Decoder
        self.Up5 = up_conv(320, 96)  # MobileNetV2 last layer output has 320 channels
        self.Up_conv5 = conv_block(96 + 96, 96)  # Adjust input channels to match MobileNetV2

        self.Up4 = up_conv(96, 24)
        self.Up_conv4 = conv_block(24 + 32, 24)  # Adjust input channels to match MobileNetV2

        self.Up3 = up_conv(24, 16)
        self.Up_conv3 = conv_block(16 + 24, 16)  # Adjust input channels to match MobileNetV2

        self.Up2 = up_conv(16, 8)
        self.Up_conv2 = conv_block(8 + 16, 8)  # Adjust input channels to match MobileNetV2

        self.Up1 = up_conv(8, 3)
        self.Up_conv1 = conv_block(3 + 3, 3)  # Adjust input channels to match MobileNetV2

        self.Conv = nn.Conv2d(3, out_ch, kernel_size=1, stride=1, padding=0)
    
    def t_guided_s(self, s, t):
        """
        Compact Cross-Attention from Teacher to Student feature maps.
        """
#         print(s.shape)
        s_pri = s
        channel_decomp = nn.Conv2d(s.shape[1], t.shape[1], kernel_size=1)
        s = channel_decomp(s)
        
        if s.shape[2] != t.shape[2]:
            s = F.interpolate(s, t.size()[-2:], mode='bilinear')
                
        attn_map = torch.matmul(t, s.transpose(2, 3))
        attn_map = F.softmax(attn_map, dim=-1)
        guided_s = torch.matmul(attn_map.transpose(2, 3), s)
        
        channel_comp = nn.Conv2d(guided_s.shape[1], s_pri.shape[1], kernel_size=1)
        guided_s = channel_comp(guided_s)
        
        guided_s = F.interpolate(guided_s, s_pri.size()[-2:], mode='bilinear')
        
        return guided_s
    
    def forward(self, x, t_attn_skips):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        
        t_e1, t_e2, t_e3, t_e4, t_e5 = t_attn_skips
        t_s_e1, t_s_e2, t_s_e3, t_s_e4, t_s_e5 = self.t_guided_s(e1, t_e1), self.t_guided_s(e2, t_e2), self.t_guided_s(e3, t_e3), self.t_guided_s(e4, t_e4), self.t_guided_s(e5, t_e5)
        print('----------------------------------------------')
        print(e1.shape,e2.shape,e3.shape,e4.shape,e5.shape)
        print(t_s_e1.shape, t_s_e2.shape, t_s_e3.shape, t_s_e4.shape, t_s_e5.shape)

        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        e1 = F.interpolate(e1, size=d2.size()[2:], mode='bilinear', align_corners=True)

        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Up1(d2)
        d1 = torch.cat((x, d1), dim=1)  # Fix the input channel dimension here
        d1 = self.Up_conv1(d1)

        out = self.Conv(d1)

        return e1, e2, e3, e4, e5, d5, d4, d3, d2, out

# Helper functions (up_conv and conv_block) should be defined as per your original code.


In [200]:
r18 = U_Net_MobileNetV2()
x = torch.rand(2,3,224,224)
e1,e2,e3,e4,e5, d5,d4,d3,d2,out = r18(x, [e1,e2,e3,e4,e5])
# print(e1.shape,e2.shape,e3.shape,e4.shape,e5.shape, d5.shape,d4.shape,d3.shape,d2.shape,out.shape)

MobileNetV2(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05,

In [124]:

def at(x, exp):
    """
    attention value of a feature map
    :param x: feature
    :return: attention value
    """
    return F.normalize(x.pow(exp).mean(1).view(x.size(0), -1))

def t_guided_s(s, t):
    """
    Compact Cross-Attention from Teacher to Student feature maps.
    """
    channel_decomp = nn.Conv2d(s.shape[1], t.shape[1], kernel_size=1)
    s = channel_decomp(s)

    attn_map = torch.matmul(t, s.transpose(2, 3))
    attn_map = F.softmax(attn_map, dim=-1)
    guided_s = torch.matmul(attn_map.transpose(2, 3), s)

    channel_comp = nn.Conv2d(guided_s.shape[1], s.shape[1], kernel_size=1)
    guided_s = channel_comp(guided_s)
    return guided_s
def importance_maps_distillation(s, t, exp=4):
    """
    importance_maps_distillation KD loss, based on "Paying More Attention to Attention:
    Improving the Performance of Convolutional Neural Networks via Attention Transfer"
    https://arxiv.org/abs/1612.03928
    :param exp: exponent
    :param s: student feature maps
    :param t: teacher feature maps
    :return: imd loss value
    """
    if s.shape[2] != t.shape[2]:
        s = F.interpolate(s, t.size()[-2:], mode='bilinear')
    print(s.shape, t.shape)
    s = t_guided_s(s,t)
    return torch.sum((at(s, exp) - at(t, exp)).pow(2), dim=1).mean()


In [125]:
loss_imd = importance_maps_distillation(low, high)

torch.Size([2, 64, 28, 28]) torch.Size([2, 128, 28, 28])


In [113]:
a = [1,2,3]
a,b,c = a
b

2