In [1]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from swin_test import SwinTransformerBlock,BasicLayer_up,PatchMerging,PatchExpand,FinalPatchExpand_X4,BasicLayer,PatchEmbed,WindowAttention,Mlp

In [2]:
from memory_module import MemModule_window,MemModule,MemModule1_new

In [46]:
class SwinTransformerSys(nn.Module):
    """Swin-Transformer encoder/decoder with a memory module ahead of each decoder stage.

    The encoder is a stack of ``BasicLayer`` stages fed by ``PatchEmbed``; the
    decoder mirrors it with one ``PatchExpand`` followed by ``BasicLayer_up``
    stages.  Before every decoder stage the token sequence is reshaped to a
    ``(B, C, H, W)`` map and passed through a ``MemModule_window``; the
    ``"output"`` entry of what that module returns replaces the features.

    Notes:
        * ``depths_decoder`` is accepted and printed but the decoder stages are
          built from ``depths``.
        * The four memory modules use hard-coded ``fea_dim``/``window`` values,
          so only the default four-stage, ``embed_dim=96`` configuration is
          known to work (see the recorded run in the notebook).
        * Token sequences are assumed to come from a square patch grid
          (``int(L ** 0.5)`` is used as both height and width).
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=3,
                 embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        """Build the encoder, decoder, output head and memory modules.

        Args:
            img_size: Input image side length, forwarded to ``PatchEmbed``.
            patch_size: Patch side length, forwarded to ``PatchEmbed``.
            in_chans: Number of input image channels.
            num_classes: Number of output channels of the final 1x1 conv.
            embed_dim: Channel width of the first stage; stage ``i`` uses
                ``embed_dim * 2 ** i``.
            depths: Number of blocks per stage (used for encoder and decoder).
            depths_decoder: Only reported in the start-up message.
            num_heads: Attention heads per stage.
            window_size: Attention window size passed to every stage.
            mlp_ratio: MLP expansion ratio passed to every stage.
            qkv_bias, qk_scale: Forwarded to every stage.
            drop_rate: Dropout probability (position dropout and stage ``drop``).
            attn_drop_rate: Attention dropout probability.
            drop_path_rate: Maximum stochastic-depth rate; per-block rates are
                spaced linearly from 0 to this value.
            norm_layer: Normalisation layer class.
            ape: If True, add a learned absolute position embedding.
            patch_norm: If True, ``PatchEmbed`` receives ``norm_layer``.
            use_checkpoint: Forwarded to every stage.
            final_upsample: When ``"expand_first"`` the x4 expand and the 1x1
                output conv are created and used.
            **kwargs: Ignored.
        """
        super().__init__()

        print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,
        depths_decoder,drop_path_rate,num_classes))

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: one rate per encoder block
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build encoder and bottleneck layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # the last encoder stage (bottleneck) does not downsample
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # build decoder layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # encoder stage mirrored by this decoder stage
            mirror = self.num_layers - 1 - i_layer
            stage_dim = int(embed_dim * 2 ** mirror)
            stage_resolution = (patches_resolution[0] // (2 ** mirror),
                                patches_resolution[1] // (2 ** mirror))

            # after concatenating a skip the width doubles; project it back
            concat_linear = nn.Linear(2 * stage_dim, stage_dim) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                layer_up = PatchExpand(input_resolution=stage_resolution, dim=stage_dim,
                                       dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=stage_dim,
                                         input_resolution=stage_resolution,
                                         depth=depths[mirror],
                                         num_heads=num_heads[mirror],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:mirror]):sum(depths[:mirror + 1])],
                                         norm_layer=norm_layer,
                                         # the last decoder stage does not upsample
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
                                          dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes,
                                    kernel_size=1, bias=False)

        self.apply(self._init_weights)

        # Memory modules, one per decoder stage.  They are created after
        # self.apply(...) above, so _init_weights is not applied to them.
        # NOTE(review): fea_dim/window are hard-coded; for the default config
        # they appear to satisfy fea_dim == C * window**2 / c_num for the map
        # each module receives -- confirm against MemModule_window.
        mem_dim = 2000
        shrink_thres = 0.0005
        self.mem_rep0_ = MemModule_window(mem_dim=mem_dim, fea_dim=96, window=1, c_num=8, shrink_thres=shrink_thres)
        self.mem_rep0 = MemModule_window(mem_dim=mem_dim, fea_dim=192, window=2, c_num=8, shrink_thres=shrink_thres)
        self.mem_rep1 = MemModule_window(mem_dim=mem_dim, fea_dim=384, window=4, c_num=8, shrink_thres=shrink_thres)
        self.mem_rep2 = MemModule_window(mem_dim=mem_dim, fea_dim=768, window=8, c_num=8, shrink_thres=shrink_thres)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` (truncated normal, zero bias) and ``nn.LayerNorm`` (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names ``{'absolute_pos_embed'}``."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the set of parameter-name keywords ``{'relative_position_bias_table'}``."""
        return {'relative_position_bias_table'}

    # Encoder and bottleneck
    def forward_features(self, x):
        """Run patch embedding and all encoder stages.

        Args:
            x: Input accepted by ``self.patch_embed``.

        Returns:
            Tuple ``(x, x_downsample)``: the normalised output of the last
            encoder stage, and a list holding the input of every encoder
            stage (used later as skip connections).
        """
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        x_downsample = []
        for layer in self.layers:
            # store the features *before* the stage runs
            x_downsample.append(x)
            x = layer(x)

        x = self.norm(x)  # B L C

        return x, x_downsample

    # Decoder and skip connections
    def forward_up_features(self, x, x_downsample):
        """Run the decoder, passing the features through a memory module before each stage.

        Args:
            x: Bottleneck tokens of shape ``(B, L, C)`` with ``L`` a perfect square.
            x_downsample: Skip features as returned by ``forward_features``.

        Returns:
            Tuple ``(x, res)``: the normalised decoder output and the list of
            objects returned by the memory modules, in decoder order.
        """
        res = []
        i = 0  # index of the memory module used by the non-first stages
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                print("x0.shape",x.shape)
                b, h_w, c = x.shape
                # tokens (B, L, C) -> map (B, C, sqrt(L), sqrt(L))
                x = x.view(b, int(h_w**0.5), int(h_w**0.5), c).permute(0, 3, 1, 2).contiguous()
                res0 = self.mem_rep0_(x)
                res.append(res0)
                # map -> tokens (B, L, C)
                x = res0["output"].view(b, c, -1).permute(0, 2, 1).contiguous()
                x = layer_up(x)

            else:
                print("x.shape",x.shape)
                b, h_w, c = x.shape
                x = x.view(b, int(h_w**0.5), int(h_w**0.5), c).permute(0, 3, 1, 2).contiguous()
                print("x.shape1",x.shape)
                if i == 0:
                    res0 = self.mem_rep0(x)
                    print(0)
                elif i == 1:
                    res0 = self.mem_rep1(x)
                    print(1)
                else:
                    res0 = self.mem_rep2(x)
                    print(2)

                res.append(res0)
                i += 1
                x = res0["output"].view(b, c, -1).permute(0, 2, 1).contiguous()
                # Skip connection from the mirrored encoder stage.  Indexed via
                # num_layers instead of a literal 3 so it tracks len(depths).
                x = torch.cat([x, x_downsample[self.num_layers - 1 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

        x = self.norm_up(x)  # B L C

        return x, res

    def up_x4(self, x):
        """Apply the final x4 expand and 1x1 conv when ``final_upsample == "expand_first"``.

        Args:
            x: Tokens of shape ``(B, H*W, C)`` where ``(H, W)`` is
                ``self.patches_resolution``.

        Returns:
            A ``(B, num_classes, 4H, 4W)`` tensor when ``final_upsample`` is
            ``"expand_first"``; otherwise ``x`` unchanged.

        Raises:
            AssertionError: If the token count does not equal ``H * W``.
        """
        H, W = self.patches_resolution
        B, L, C = x.shape
        assert L == H*W, "input features has wrong size"

        if self.final_upsample == "expand_first":
            x = self.up(x)
            x = x.view(B, 4*H, 4*W, -1)
            x = x.permute(0, 3, 1, 2)  # B,C,H,W
            x = self.output(x)

        return x

    def forward(self, x):
        """Full forward pass.

        Returns:
            Tuple ``(x, x_downsample, res)``: the ``up_x4`` output, the encoder
            skip features each reshaped from ``(B, L, C)`` to
            ``(B, sqrt(L), sqrt(L), C)``, and the memory-module results from
            ``forward_up_features``.
        """
        x, x_downsample = self.forward_features(x)
        x, res = self.forward_up_features(x, x_downsample)
        x = self.up_x4(x)

        for idx, feat in enumerate(x_downsample):
            b, wh, c = feat.shape
            side = int(wh**0.5)
            # Keep the real batch size: the previous view(1, ...) folded the
            # batch axis into the channel axis whenever B > 1.
            x_downsample[idx] = feat.view(b, side, side, c)
        return x, x_downsample, res

    def flops(self):
        """Return a FLOP estimate.

        Sums ``patch_embed.flops()``, every encoder layer's ``flops()``, and
        two closed-form terms based on ``num_features``.  Decoder stages and
        memory modules are not included.
        """
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


In [47]:
# Smoke test: build SwinTransformerSys with its default arguments and push one
# random 3x224x224 image through it.
if __name__ == '__main__':
    x=torch.randn(1,3,224,224)
    model= SwinTransformerSys()
    # forward returns (prediction, encoder skip features, memory-module results)
    output,down,res=model(x)
    # the recorded run below shows torch.Size([1, 3, 224, 224])
    print('output.shape',output.shape)

SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:3
---final upsample expand_first---
x0.shape torch.Size([1, 49, 768])
x.shape torch.Size([1, 196, 384])
x.shape1 torch.Size([1, 384, 14, 14])
0
x.shape torch.Size([1, 784, 192])
x.shape1 torch.Size([1, 192, 28, 28])
1
x.shape torch.Size([1, 3136, 96])
x.shape1 torch.Size([1, 96, 56, 56])
2
output.shape torch.Size([1, 3, 224, 224])


In [4]:

def window_partition(x, window_size):
    """Cut a ``(B, H, W, C)`` tensor into non-overlapping square windows.

    Args:
        x: Tensor of shape ``(B, H, W, C)``; it must be viewable as a
            ``H // window_size`` by ``W // window_size`` grid of windows.
        window_size: Side length of each window.

    Returns:
        Tensor of shape ``(num_windows * B, window_size, window_size, C)``.
    """
    batch, height, width, channels = x.shape
    grid_rows = height // window_size
    grid_cols = width // window_size
    tiled = x.view(batch, grid_rows, window_size, grid_cols, window_size, channels)
    # Put the two grid axes next to each other so each window is contiguous.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """Stitch square windows back into a ``(B, H, W, C)`` tensor.

    Args:
        windows: Tensor of shape ``(num_windows * B, window_size, window_size, C)``.
        window_size: Side length of each window.
        H: Height of the reassembled map.
        W: Width of the reassembled map.

    Returns:
        Tensor of shape ``(B, H, W, C)``.
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave grid and in-window axes back into row-major image order.
    stitched = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return stitched.view(batch, H, W, -1)

In [33]:
class NetD(nn.Module):
    """Convolutional discriminator: six stride-2 convolutions, a feature head and a sigmoid classifier.

    With the defaults and a 224x224 input the recorded run shows the spatial
    size going 112 -> 56 -> 28 -> 14 -> 7 -> 3, after which the un-padded 3x3
    feature convolution reduces it to 1x1.

    Args:
        ngf: Channel width of the first convolution; each following stage
            doubles it (``ngf << k``).  Defaults to 48.
        nc: Number of input image channels.  Defaults to 3.
    """

    def __init__(self, ngf=48, nc=3):
        super(NetD, self).__init__()
        # Attribute names and the layer order inside each Sequential are kept
        # as-is so existing state_dict keys continue to match.
        self.model = nn.Sequential(
            nn.Conv2d(nc, ngf, 4, 2, 1, bias=False),        # -> ngf channels
            nn.LeakyReLU(0.2, inplace=True))

        self.mode2 = nn.Sequential(
            nn.Conv2d(ngf, ngf << 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 1),                       # 2 * ngf channels
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode3 = nn.Sequential(
            nn.Conv2d(ngf << 1, ngf << 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 2),                       # 4 * ngf channels
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode4 = nn.Sequential(
            nn.Conv2d(ngf << 2, ngf << 3, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 3),                       # 8 * ngf channels
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode5 = nn.Sequential(
            nn.Conv2d(ngf << 3, ngf << 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 4),                       # 16 * ngf channels
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode6 = nn.Sequential(
            nn.Conv2d(ngf << 4, ngf << 5, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 5),                       # 32 * ngf channels
            nn.LeakyReLU(0.2, inplace=True),)

        # Feature head: un-padded 3x3 conv down to 100 channels.
        self.mode7 = nn.Sequential(
            nn.Conv2d(ngf << 5, 100, 3, 1, 0, bias=False),
            nn.BatchNorm2d(100),
        )

        # Classifier head: 100 -> 1 channel followed by a sigmoid.
        self.classify = nn.Sequential(
            nn.Conv2d(100, 1, 3, 1, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(classification, feature)``.

        ``feature`` is the output of ``mode7``; ``classification`` is the
        sigmoid output flattened with ``view(-1, 1).squeeze(1)`` (one value per
        sample when ``feature`` is 1x1 spatially).  The shape of every
        intermediate activation is printed.
        """
        x = self.model(x)
        print("x1.shape",x.shape)
        x = self.mode2(x)
        print("x2.shape",x.shape)
        x = self.mode3(x)
        print("x3.shape",x.shape)
        x = self.mode4(x)
        print("x4.shape",x.shape)
        x = self.mode5(x)
        print("x5.shape",x.shape)
        x = self.mode6(x)
        print("x6.shape",x.shape)
        feature = self.mode7(x)
        print("x7.shape",feature.shape)

        classification = self.classify(feature)
        return classification.view(-1, 1).squeeze(1), feature
    

In [35]:
# Smoke test: run the NetD defined above on a batch of two random 224x224
# images; forward prints the shape after every stage.
if __name__ == '__main__':
    x1=torch.randn(2,3,224,224)
    model1= NetD()
    # forward returns (per-sample sigmoid score, 100-channel feature map)
    classification,feature=model1(x1)


x1.shape torch.Size([2, 48, 112, 112])
x2.shape torch.Size([2, 96, 56, 56])
x3.shape torch.Size([2, 192, 28, 28])
x4.shape torch.Size([2, 384, 14, 14])
x5.shape torch.Size([2, 768, 7, 7])
x6.shape torch.Size([2, 1536, 3, 3])
x7.shape torch.Size([2, 100, 1, 1])


In [36]:
# One score per input sample; the recorded output is torch.Size([2]).
classification.shape

torch.Size([2])

In [37]:
class NetD(nn.Module):
    """Convolutional discriminator: six stride-2 convolutions, a feature head and a sigmoid classifier.

    With the defaults and a 256x256 input the recorded run shows the spatial
    size going 128 -> 64 -> 32 -> 16 -> 8 -> 4, after which the un-padded 4x4
    feature convolution reduces it to 1x1.

    NOTE: defining this class rebinds the name ``NetD`` in the notebook
    namespace, replacing any earlier definition with the same name.

    Args:
        ngf: Channel width of the first convolution; each following stage
            doubles it (``ngf << k``).  Defaults to 64.
        nc: Number of input image channels.  Defaults to 3.
    """

    def __init__(self, ngf=64, nc=3):
        super(NetD, self).__init__()
        # Attribute names and the layer order inside each Sequential are kept
        # as-is so existing state_dict keys continue to match.
        self.model = nn.Sequential(
            nn.Conv2d(nc, ngf, 4, 2, 1, bias=False),        # -> ngf channels
            nn.LeakyReLU(0.2, inplace=True))

        self.mode2 = nn.Sequential(
            nn.Conv2d(ngf, ngf << 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 1),                       # 2 * ngf channels (128 by default)
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode3 = nn.Sequential(
            nn.Conv2d(ngf << 1, ngf << 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 2),                       # 4 * ngf channels (256 by default)
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode4 = nn.Sequential(
            nn.Conv2d(ngf << 2, ngf << 3, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 3),                       # 8 * ngf channels (512 by default)
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode5 = nn.Sequential(
            nn.Conv2d(ngf << 3, ngf << 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 4),                       # 16 * ngf channels (1024 by default)
            nn.LeakyReLU(0.2, inplace=True),)

        self.mode6 = nn.Sequential(
            nn.Conv2d(ngf << 4, ngf << 5, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf << 5),                       # 32 * ngf channels (2048 by default)
            nn.LeakyReLU(0.2, inplace=True),)

        # Feature head: un-padded 4x4 conv down to 100 channels.
        self.mode7 = nn.Sequential(
            nn.Conv2d(ngf << 5, 100, 4, 1, 0, bias=False),
            nn.BatchNorm2d(100),
        )

        # Classifier head: 100 -> 1 channel followed by a sigmoid.
        self.classify = nn.Sequential(
            nn.Conv2d(100, 1, 3, 1, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(classification, feature)``.

        ``feature`` is the output of ``mode7``; ``classification`` is the
        sigmoid output flattened with ``view(-1, 1).squeeze(1)`` (one value per
        sample when ``feature`` is 1x1 spatially).  The shape of every
        intermediate activation is printed.
        """
        x = self.model(x)
        print("x1.shape",x.shape)
        x = self.mode2(x)
        print("x2.shape",x.shape)
        x = self.mode3(x)
        print("x3.shape",x.shape)
        x = self.mode4(x)
        print("x4.shape",x.shape)
        x = self.mode5(x)
        print("x5.shape",x.shape)
        x = self.mode6(x)
        print("x6.shape",x.shape)
        feature = self.mode7(x)
        print("x7.shape",feature.shape)

        classification = self.classify(feature)
        return classification.view(-1, 1).squeeze(1), feature
    


In [38]:
# Smoke test: run the redefined NetD on a batch of two random 256x256 images;
# forward prints the shape after every stage.
if __name__ == '__main__':
    x1=torch.randn(2,3,256,256)
    model1= NetD()
    # forward returns (per-sample sigmoid score, 100-channel feature map)
    classification,feature=model1(x1)


x1.shape torch.Size([2, 64, 128, 128])
x2.shape torch.Size([2, 128, 64, 64])
x3.shape torch.Size([2, 256, 32, 32])
x4.shape torch.Size([2, 512, 16, 16])
x5.shape torch.Size([2, 1024, 8, 8])
x6.shape torch.Size([2, 2048, 4, 4])
x7.shape torch.Size([2, 100, 1, 1])


In [39]:
# One score per input sample; the recorded output is torch.Size([2]).
classification.shape

torch.Size([2])