In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [16]:
class Encoder(nn.Module):
    """Toy module holding a single dilated 5x5 convolution over 64 channels.

    With stride 1, dilation 2 and padding 4 the convolution returns a map of
    the same height and width as its input; ``forward`` prints the output
    shape so this can be checked by eye.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Dilation 2 spreads the 5 kernel taps over 9 pixels, so 4 pixels of
        # padding on each side keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(width, width, kernel_size=5, stride=1,
                               padding=4, dilation=2)

    def forward(self, x):
        """Run the convolution, print the resulting shape and return the result."""
        features = self.conv1(x)
        print(features.shape)
        return features

In [17]:
# Smoke test for the toy Encoder: push one random 64-channel 128x128 map
# through it; forward() prints the output shape.
model = Encoder()
# Deliberately not called `input`: that would shadow the builtin for every
# later cell of the notebook.
sample_batch = torch.rand((1, 64, 128, 128))
out1 = model(sample_batch)

torch.Size([1, 64, 128, 128])


In [85]:
import torch.nn as nn
import torch
from torchsummary import summary
# from transformer import Block

class Encoder(nn.Module):
    """
    DCGAN-style encoder that also returns its intermediate feature maps.

    Five stages of (4x4 stride-2 convolution, batch norm, activation) each
    halve the spatial size (rounding down). Channel widths run
    ``nc -> ngf -> 128 -> 256 -> 512 -> 1024``; only the first width follows
    ``ngf``, the later ones are fixed.
    """

    def __init__(self, imageSize, nz, nc, ngf, ngpu, n_extra_layers=0, add_final_conv=True):
        # imageSize      : only validated to be a multiple of 16
        # nz             : dimensionality of the latent space (not used here)
        # nc             : number of image channels
        # ngf            : output channels of the first convolution
        # ngpu           : number of GPUs; stored on the module, not used here
        # n_extra_layers : extra layers of Encoder and Decoder (not used here)
        # add_final_conv : not used here; the final 1024 -> nz convolution is disabled
        super().__init__()
        self.ngpu = ngpu
        assert imageSize % 16 == 0, "imageSize has to be a multiple of 16"

        # Stages are created in this fixed order so that seeded weight
        # initialisation and state_dict keys stay stable.
        self.initial0 = self._stage(nc, ngf, nn.ReLU(True))
        self.pyramid0 = self._stage(ngf, 128, nn.LeakyReLU(0.2, inplace=True))
        self.pyramid1 = self._stage(128, 256, nn.LeakyReLU(0.2, inplace=True))
        self.pyramid2 = self._stage(256, 512, nn.LeakyReLU(0.2, inplace=True))
        self.pyramid3 = self._stage(512, 1024, nn.LeakyReLU(0.2, inplace=True))

    @staticmethod
    def _stage(in_channels, out_channels, activation):
        """Build one downsampling stage: 4x4 stride-2 conv (no bias), batch norm, activation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            activation,
        )

    def forward(self, input):
        """Return ``(deepest_features, skips)``.

        ``skips`` lists the outputs of ``initial0``, ``pyramid0``, ``pyramid1``
        and ``pyramid2`` in that order; the output of ``pyramid3`` is returned
        separately and is not part of the list.
        """
        skips = []
        features = input
        for stage in (self.initial0, self.pyramid0, self.pyramid1, self.pyramid2):
            features = stage(features)
            skips.append(features)

        return self.pyramid3(features), skips

    
class Decoder(nn.Module):
    """
    DCGAN-style decoder with U-Net-like skip connections.

    Each of the four ``pyramid*`` stages doubles the spatial size with a 4x4
    stride-2 transposed convolution, after which the matching encoder feature
    map is concatenated on the channel axis and fused by a ``double_conv*``
    block. ``final0`` upsamples once more to ``nc`` channels and applies Tanh.

    Channel widths are fixed (1024 -> 512 -> 256 -> 128 -> 64 -> nc), so the
    skip maps must carry 512, 256, 128 and 64 channels (deepest first).
    """

    def __init__(self, imageSize, nz, nc, ngf, ngpu, n_extra_layers=0):
        # imageSize      : only validated to be a multiple of 16; no layer size depends on it
        # nz             : dimensionality of the latent space (not used here)
        # nc             : number of image channels produced by the last layer
        # ngf            : not used here; all intermediate widths are fixed
        # ngpu           : number of GPUs; stored on the module, not used here
        # n_extra_layers : extra layers of Encoder and Decoder (not used here)

        super(Decoder, self).__init__()
        self.ngpu = ngpu
        assert imageSize % 16 == 0, "imageSize has to be a multiple of 16"

        # Modules are created in a fixed order so that seeded initialisation
        # and state_dict keys stay stable.
        self.pyramid0 = self._up_stage(1024, 512)
        self.pyramid1 = self._up_stage(512, 256)
        self.pyramid2 = self._up_stage(256, 128)
        self.pyramid3 = self._up_stage(128, 64)
        self.final0 = nn.Sequential(
            nn.ConvTranspose2d(64, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # double_conv0 is not used by forward(); it is kept so that existing
        # checkpoints still load with strict key matching.
        self.double_conv0 = self._double_conv(2048, 1024)
        self.double_conv1 = self._double_conv(1024, 512)
        self.double_conv2 = self._double_conv(512, 256)
        self.double_conv3 = self._double_conv(256, 128)
        self.double_conv4 = self._double_conv(128, 64)

    @staticmethod
    def _up_stage(in_channels, out_channels):
        """Build one upsampling stage: 4x4 stride-2 transposed conv (no bias), batch norm, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build two size-preserving 3x3 conv + batch norm + ReLU layers."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, input, y):
        """Decode ``input`` using the skip feature maps in ``y``.

        Args:
            input: deepest feature map with 1024 channels.
            y: sequence of skip feature maps ordered shallow to deep; the
               last four entries are consumed, deepest first.

        Returns:
            The Tanh-activated output of ``final0`` with ``nc`` channels.

        Raises:
            ValueError: if fewer than four skip maps are supplied.
        """
        stages = (
            (self.pyramid0, self.double_conv1),
            (self.pyramid1, self.double_conv2),
            (self.pyramid2, self.double_conv3),
            (self.pyramid3, self.double_conv4),
        )
        if len(y) < len(stages):
            # Without this check a short list would wrap around to negative
            # indices and fail later with an unrelated shape error.
            raise ValueError(
                "Decoder.forward needs at least {0} skip feature maps, got {1}".format(
                    len(stages), len(y)))

        features = input
        for depth, (upsample, fuse) in enumerate(stages, start=1):
            skip = y[-depth]  # deepest skip first
            features = fuse(torch.cat([skip, upsample(features)], dim=1))

        return self.final0(features)
class Encoder_ori(nn.Module):
    """
    Sequential DCGAN encoder without skip outputs.

    An initial 4x4 stride-2 convolution is followed by optional
    size-preserving extra layers and then by stride-2 "pyramid" stages that
    double the channel count until the tracked feature-map size is at most 4.
    The final convolution down to ``nz`` channels is disabled, so the module
    returns the output of the last pyramid stage.
    """

    def __init__(self, imageSize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):
        # imageSize      : input size; must be a multiple of 16 and sets the number of pyramid stages
        # nz             : dimensionality of the latent space (not used while the final conv is disabled)
        # nc             : number of image channels
        # ndf            : output channels of the first convolution
        # ngpu           : number of GPUs; more than one routes forward() through data_parallel
        # n_extra_layers : number of extra 3x3 conv blocks after the first convolution
        # add_final_conv : not used while the final conv is disabled
        super().__init__()
        self.ngpu = ngpu
        assert imageSize % 16 == 0, "imageSize has to be a multiple of 16"

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # Collect (name, layer) pairs first; they are registered in this order below.
        # input is nc x imageSize x imageSize
        named_layers = [
            ('initial-conv-{0}-{1}'.format(nc, ndf), nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)),
            ('initial-relu-{0}'.format(ndf), leaky()),
        ]
        width = ndf

        # Extra layers keep both channel count and spatial size.
        for t in range(n_extra_layers):
            named_layers += [
                ('extra-layers-{0}-{1}-conv'.format(t, width),
                 nn.Conv2d(width, width, 3, 1, 1, bias=False)),
                ('extra-layers-{0}-{1}-batchnorm'.format(t, width),
                 nn.BatchNorm2d(width)),
                ('extra-layers-{0}-{1}-relu'.format(t, width),
                 leaky()),
            ]

        # True division on purpose: the tracked size may become fractional,
        # and the loop condition is evaluated on that float value.
        size = imageSize / 2
        while size > 4:
            doubled = width * 2
            named_layers += [
                ('pyramid-{0}-{1}-conv'.format(width, doubled),
                 nn.Conv2d(width, doubled, 4, 2, 1, bias=False)),
                ('pyramid-{0}-batchnorm'.format(doubled),
                 nn.BatchNorm2d(doubled)),
                ('pyramid-{0}-relu'.format(doubled),
                 leaky()),
            ]
            width = doubled
            size = size / 2

        # state size. K x 4 x 4 (the final conv to nz channels is disabled)
        main = nn.Sequential()
        for layer_name, layer in named_layers:
            main.add_module(layer_name, layer)
        self.main = main

    def forward(self, input):
        """Apply the sequential stack, via ``data_parallel`` when ``ngpu`` > 1."""
        if self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
    
class NetG(nn.Module):
    """
    GENERATOR NETWORK

    Pipeline: ``encoder1`` -> transformer blocks on the flattened bottleneck
    -> ``decoder`` (with the skip maps from ``encoder1``) -> ``encoder2`` on
    the generated image.

    NOTE(review): ``Block`` is neither defined nor imported in the visible
    source (the ``from transformer import Block`` line is commented out); it
    has to be in scope, and its ``forward`` must accept ``(tokens, H, W)``.
    """

    def __init__(self,dim=1024, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.,norm_layer=nn.LayerNorm, linear=False,depths=4):
        super().__init__()

        # Fixed image / network configuration shared by the three sub-networks.
        imageSize, nz, nc, ngf = 128, 100, 3, 64
        ngpu, n_extra_layers = 1, 0

        self.depths = depths
        # One drop-path rate per block, spaced linearly from 0 to drop_path_rate.
        drop_path_schedule = torch.linspace(0, drop_path_rate, depths).tolist()

        blocks = []
        for block_rate in drop_path_schedule:
            blocks.append(Block(
                dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=block_rate, norm_layer=norm_layer,
                linear=linear))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.encoder1 = Encoder(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
        self.decoder = Decoder(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
        self.encoder2 = Encoder_ori(imageSize, nz, nc, ngf, ngpu, n_extra_layers)

    def forward(self, x):
        """Return ``(gen_imag, latent_i, latent_o)`` for the image batch ``x``.

        ``latent_i`` is the bottleneck after the transformer blocks, ``gen_imag``
        the decoder output and ``latent_o`` the re-encoding of ``gen_imag``.
        Tensor shapes are printed along the way.
        """
        latent_i, skips = self.encoder1(x)
        batch, _, height, width = latent_i.shape

        # (B, C, H, W) -> (B, H*W, C) token sequence for the transformer blocks.
        tokens = latent_i.flatten(2).transpose(1, 2)
        for block in self.transformer_blocks:
            tokens = block(tokens, height, width)

        # Back to a channels-first feature map.
        latent_i = tokens.reshape(batch, height, width, -1).permute(0, 3, 1, 2).contiguous()
        print(latent_i.shape)

        gen_imag = self.decoder(latent_i, skips)
        latent_o = self.encoder2(gen_imag)
        print('gen_imag, latent_i, latent_o',gen_imag.shape, latent_i.shape, latent_o.shape)
        return gen_imag, latent_i, latent_o
    
class NetD(nn.Module):
    """
    DISCRIMINATOR NETWORK

    Reuses the layers of an ``Encoder_ori`` built with ``nz=1`` as the
    feature extractor and puts a single-channel 4x4 convolution plus Sigmoid
    on top as the classifier.
    """

    def __init__(self, opt):
        """Build the discriminator.

        Args:
            opt: options object providing ``imageSize``, ``nc``, ``ngf``,
                ``ngpu`` and ``n_extra_layers``.
        """
        super(NetD, self).__init__()
        # Encoder_ori is the encoder in this file that exposes its layers as `.main`.
        model = Encoder_ori(opt.imageSize, 1, opt.nc, opt.ngf, opt.ngpu, opt.n_extra_layers)
        layers = list(model.main.children())

        if isinstance(layers[-1], nn.Conv2d):
            # The encoder ends with its own final convolution: use it as the head.
            head = layers.pop()
        else:
            # Encoder_ori currently has its final convolution disabled, so the
            # last child is an activation. Build the head here instead: a 4x4
            # convolution with no padding that reduces the encoder's K x 4 x 4
            # state to one value per sample.
            conv_widths = [m.out_channels for m in layers if isinstance(m, nn.Conv2d)]
            head = nn.Conv2d(conv_widths[-1], 1, 4, 1, 0, bias=False)

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(head)
        self.classifier.add_module('Sigmoid', nn.Sigmoid())

    def forward(self, x):
        """Return ``(scores, features)``: flattened classifier output and the feature map."""
        features = self.features(x)
        classifier = self.classifier(features)
        classifier = classifier.view(-1, 1).squeeze(1)

        return classifier, features


In [86]:
# Smoke test for NetG: one random 3-channel 128x128 image through the full
# generator; forward() prints the intermediate shapes.
model = NetG(dim=1024, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.,
             attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, linear=False, depths=4)
# Deliberately not called `input`: that would shadow the builtin for every
# later cell of the notebook.
image_batch = torch.rand((1, 3, 128, 128))
out1 = model(image_batch)

torch.Size([1, 1024, 4, 4])
gen_imag, latent_i, latent_o torch.Size([1, 3, 128, 128]) torch.Size([1, 1024, 4, 4]) torch.Size([1, 1024, 4, 4])


In [82]:
# Shape of the first element of out1 (gen_imag when out1 is the tuple returned
# by NetG.forward). NOTE(review): this cell's execution count precedes the
# NetG cells above, so it relies on kernel state from an earlier run.
out1[0].shape

torch.Size([1, 3, 128, 128])

In [62]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from functools import partial

# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# from timm.models.registry import register_model
# from timm.models.vision_transformer import _cfg
# import math

# class DWConv(nn.Module):
#     def __init__(self, dim=768):
#         super(DWConv, self).__init__()
#         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

#     def forward(self, x, H, W):
#         B, N, C = x.shape
#         x = x.transpose(1, 2).view(B, C, H, W)
#         x = self.dwconv(x)
#         x = x.flatten(2).transpose(1, 2)

#         return x

# class Mlp(nn.Module):
#     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
#         super().__init__()
#         out_features = out_features or in_features
#         hidden_features = hidden_features or in_features
#         self.fc1 = nn.Linear(in_features, hidden_features)
#         self.dwconv = DWConv(hidden_features)
#         self.act = act_layer()
#         self.fc2 = nn.Linear(hidden_features, out_features)
#         self.drop = nn.Dropout(drop)
#         self.linear = linear
#         if self.linear:
#             self.relu = nn.ReLU(inplace=True)
#         self.apply(self._init_weights)

#     def _init_weights(self, m):
#         if isinstance(m, nn.Linear):
#             trunc_normal_(m.weight, std=.02)
#             if isinstance(m, nn.Linear) and m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#         elif isinstance(m, nn.LayerNorm):
#             nn.init.constant_(m.bias, 0)
#             nn.init.constant_(m.weight, 1.0)
#         elif isinstance(m, nn.Conv2d):
#             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#             fan_out //= m.groups
#             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
#             if m.bias is not None:
#                 m.bias.data.zero_()

#     def forward(self, x, H, W):
#         x = self.fc1(x)
#         if self.linear:
#             x = self.relu(x)
#         x = self.dwconv(x, H, W)
#         x = self.act(x)
#         x = self.drop(x)
#         x = self.fc2(x)
#         x = self.drop(x)
#         return x

# class Attention(nn.Module):
#     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
#         super().__init__()
#         assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

#         self.dim = dim
#         self.num_heads = num_heads
#         head_dim = dim // num_heads
#         self.scale = qk_scale or head_dim ** -0.5

#         self.q = nn.Linear(dim, dim, bias=qkv_bias)
#         self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj = nn.Linear(dim, dim)
#         self.proj_drop = nn.Dropout(proj_drop)

#         self.linear = linear
#         self.sr_ratio = sr_ratio

#         self.apply(self._init_weights)

#     def _init_weights(self, m):
#         if isinstance(m, nn.Linear):
#             trunc_normal_(m.weight, std=.02)
#             if isinstance(m, nn.Linear) and m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#         elif isinstance(m, nn.LayerNorm):
#             nn.init.constant_(m.bias, 0)
#             nn.init.constant_(m.weight, 1.0)
#         elif isinstance(m, nn.Conv2d):
#             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#             fan_out //= m.groups
#             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
#             if m.bias is not None:
#                 m.bias.data.zero_()

#     def forward(self, x, H, W):
#         B, N, C = x.shape
#         q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
#         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#         k, v = kv[0], kv[1]
#         attn = (q @ k.transpose(-2, -1)) * self.scale
#         attn = attn.softmax(dim=-1)
#         attn = self.attn_drop(attn)
#         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#         x = self.proj(x)
#         x = self.proj_drop(x)
#         return x
    
# class Block(nn.Module):

#     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
#                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, linear=False):
#         super().__init__()
#         self.norm1 = norm_layer(dim)
        
#         self.attn = Attention(
#             dim,
#             num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
#             attn_drop=attn_drop, proj_drop=drop, linear=linear)
#         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
#         self.norm2 = norm_layer(dim)
#         mlp_hidden_dim = int(dim * mlp_ratio)
#         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)

#         self.apply(self._init_weights)
#         self.norm = nn.LayerNorm(1024)

#     def _init_weights(self, m):
#         if isinstance(m, nn.Linear):
#             trunc_normal_(m.weight, std=.02)
#             if isinstance(m, nn.Linear) and m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#         elif isinstance(m, nn.LayerNorm):
#             nn.init.constant_(m.bias, 0)
#             nn.init.constant_(m.weight, 1.0)
#         elif isinstance(m, nn.Conv2d):
#             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#             fan_out //= m.groups
#             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
#             if m.bias is not None:
#                 m.bias.data.zero_()

#     def forward(self,x):
        
# #         print(x.shape)
#         H=4
#         W=4
# #         B, C, H, W = x.shape

# #         x = x.flatten(2).transpose(1, 2)
# #         x = self.norm(x)
        
#         x = x + self.drop_path(self.attn(self.norm1(x), H, W))
#         x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

#         return x

In [51]:
# Smoke test for a single transformer Block on a (1, 16, 1024) tensor.
# NOTE(review): Block must already exist in the kernel; its definition in the
# cell above is commented out.
model = Block(dim=1024, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
              drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, linear=False)
# Deliberately not called `input`, so the builtin is not shadowed.
token_batch = torch.rand((1, 16, 1024))
out1 = model(token_batch)

In [52]:
# Display the shape of out1 (presumably the Block output from the previous cell).
out1.shape

torch.Size([1, 16, 1024])

In [40]:
# Fold out1 into a (B, H, W, -1) grid, then move the last axis to position 1
# and make the result contiguous.
B, H, W = 1, 4, 4
x = out1.reshape(B, H, W, -1)
x = x.permute(0, 3, 1, 2).contiguous()

In [41]:
# Display the shape of x (presumably the permuted tensor from the previous cell).
x.shape

torch.Size([1, 1024, 4, 4])

In [21]:
# Build 11 independent rows; row k holds the value k repeated 4 times.
x = [[row_value] * 4 for row_value in range(11)]

In [22]:
# Display x (presumably the nested list built in the previous cell).
x

[[0, 0, 0, 0],
 [1, 1, 1, 1],
 [2, 2, 2, 2],
 [3, 3, 3, 3],
 [4, 4, 4, 4],
 [5, 5, 5, 5],
 [6, 6, 6, 6],
 [7, 7, 7, 7],
 [8, 8, 8, 8],
 [9, 9, 9, 9],
 [10, 10, 10, 10]]