In [22]:
from timm import create_model
import torch
import torch.nn as nn
import torch.nn.functional as F

# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
feature_extractor = create_model(
        'resnet18', 
        pretrained    = True, 
        features_only = True
    )
# 
img = torch.randn(40, 3, 256, 256)
features = feature_extractor(img)
for i, v in enumerate(features):
    print(f'{i} : {v.shape}')
# print(feature_extractor)

f_in = features[0]
f_out = features[-1]
f_ii = features[1:-1]
catted = []
for ii in f_ii:
    catted.append(torch.cat([ii, ii], dim=1))

class MSFFBlock(nn.Module):
    def __init__(self, in_channel):
        super(MSFFBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
#         self.attn = CoordAtt(in_channel, in_channel)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        x_conv = self.conv1(x)
#         x_att = self.attn(x)
        
#         x = x_conv * x_att
        x = self.conv2(x_conv)
        return x

    
class MSFF(nn.Module):
    def __init__(self):
        super(MSFF, self).__init__()
        self.blk1 = MSFFBlock(128)
        self.blk2 = MSFFBlock(256)
        self.blk3 = MSFFBlock(512)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.upconv32 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        )
        self.upconv21 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, features):
        # features = [level1, level2, level3]
        f1, f2, f3 = features 
#         print(f1.shape) 
#         print(f2.shape) 
#         print(f3.shape)
        
        # MSFF Module
        f1_k = self.blk1(f1)
        f2_k = self.blk2(f2)
        f3_k = self.blk3(f3)
#         print(f1_k.shape) 
#         print(f2_k.shape) 
#         print(f3_k.shape)

        f2_f = f2_k + self.upconv32(f3_k)
        f1_f = f1_k + self.upconv21(f2_f)

        # spatial attention
        
        # mask 
        m3 = f3[:,256:,...].mean(dim=1, keepdim=True)
        m2 = f2[:,128:,...].mean(dim=1, keepdim=True) * self.upsample(m3)
        m1 = f1[:,64:,...].mean(dim=1, keepdim=True) * self.upsample(m2)
        print(m1.shape) 
        print(m2.shape) 
        print(m3.shape)
        
        f1_out = f1_f * m1
        f2_out = f2_f * m2
        f3_out = f3_k * m3
        
        print(f1_out.shape) 
        print(f2_out.shape) 
        print(f3_out.shape)
        
        return [f1_out, f2_out, f3_out]
msff = MSFF()
msff(catted)[0].shape

0 : torch.Size([40, 64, 128, 128])
1 : torch.Size([40, 64, 64, 64])
2 : torch.Size([40, 128, 32, 32])
3 : torch.Size([40, 256, 16, 16])
4 : torch.Size([40, 512, 8, 8])
torch.Size([40, 1, 64, 64])
torch.Size([40, 1, 32, 32])
torch.Size([40, 1, 16, 16])
torch.Size([40, 64, 64, 64])
torch.Size([40, 128, 32, 32])
torch.Size([40, 256, 16, 16])


torch.Size([40, 64, 64, 64])