In [1]:
# DCGAN
import math
import torch
import torch.nn as nn
import torch.optim as optim


def weights_init(m):
    classname = m.__class__.__name__
    if classname == 'ConvTranspose2d' or classname == 'Conv2d':
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname == 'BatchNorm2d':
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class ConvBlock(nn.Module):
    def __init__(self, conv_kwargs, activation='leaky_relu', normalization='batch_normalization'):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(**conv_kwargs)
        
        if normalization == 'batch_normalization': self.norm = nn.BatchNorm2d(conv_kwargs['out_channels'])
        elif normalization == 'instance_normalization': self.norm = nn.InstanceNorm2d(conv_kwargs['out_channels'])
        elif normalization is None: self.norm = nn.Sequential()
        
        if activation == 'leaky_relu': self.actv = nn.LeakyReLU(0.2)
        elif activation == 'sigmoid': self.actv = nn.Sigmoid()
    
    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.actv(out)
        return out

    
class DCGANDiscriminator(nn.Module):
    def __init__(self, args):
        super(DCGANDiscriminator, self).__init__()
        depth = int(math.log2(args.resolution)) -1

        net = [ConvBlock(
            {
                'in_channels': args.nc,
                'out_channels': args.ngf,
                'kernel_size': 4,
                'stride': 2,
                'padding': 1
            },
            normalization=None
        )]
        
        mult = 1
        for _ in range(depth - 2):
            net.append(ConvBlock(
                {
                    'in_channels': args.ngf * mult,
                    'out_channels': args.ngf * mult * 2,
                    'kernel_size': 4,
                    'stride': 2,
                    'padding': 1
                }
            ))
            mult *= 2
        net.append(ConvBlock(
            {
                'in_channels': args.ngf * mult,
                'out_channels': 1,
                'kernel_size': 4,
                'stride': 1,
                'padding': 0
            },
            activation='sigmoid',
            normalization=None
        ))
        self.net = nn.Sequential(*net)
        self.apply(weights_init)
            
    def forward(self, x):
        return self.net(x)
        

class ConvTransposeBlock(nn.Module):
    
    def __init__(self, deconv_kwargs, activation='relu', normalization='batch_normalization'):
        super(ConvTransposeBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(**deconv_kwargs)
        
        if normalization == 'batch_normalization': self.norm = nn.BatchNorm2d(deconv_kwargs['out_channels'])
        elif normalization == 'instance_normalization': self.norm = nn.InstanceNorm2d(deconv_kwargs['out_channels'])
        elif normalization is None: self.norm = nn.Sequential()
        
        if activation == 'relu': self.actv = nn.ReLU()
        elif activation == 'tanh': self.actv = nn.Tanh()
    
    def forward(self, x):
        out = self.deconv(x)
        out = self.norm(out)
        out = self.actv(out)
        return out
    
    

class DCGANGenerator(nn.Module):
    def __init__(self, args):
        super(DCGANGenerator, self).__init__()
        depth = int(math.log2(args.resolution)) -1
        
        mult = 2 ** (depth - 2)
        net = [ConvTransposeBlock(
            {
                'in_channels': args.nz,
                'out_channels': args.ngf * mult,
                'kernel_size': 4,
                'stride': 1,
                'padding': 0
            },
        )]
        for _ in range(depth - 2):
            mult = int(mult * 0.5)
            net.append(ConvTransposeBlock(
                {
                    'in_channels': args.ngf * mult * 2,
                    'out_channels': args.ngf * mult,
                    'kernel_size': 4,
                    'stride': 2,
                    'padding': 1
                },
            ))
        net.append(ConvTransposeBlock(
            {
                'in_channels': args.ngf * mult,
                'out_channels': args.nc,
                'kernel_size': 4,
                'stride': 2,
                'padding': 1
            },
            activation='tanh',
            normalization=None
        ))
        
        self.net = nn.Sequential(*net)
        self.apply(weights_init)
        
    def forward(self, x):
        return self.net(x)


class DCGANLoss(nn.Module):
    def __init__(self, args):
        super(DCGANLoss, self).__init__()
        self.device = args.device
        self.bce = nn.BCELoss()
        
    def forward(self, x, mode='discriminator_loss'):
        if mode == 'discriminator_loss':
            fake_pred, real_pred = x
            real_loss = self.bce(real_pred, torch.tensor(1.0).expand_as(real_pred).to(self.device))
            fake_loss = self.bce(fake_pred, torch.tensor(0.0).expand_as(fake_pred).to(self.device))
            loss = (real_loss + fake_loss) * 0.5
            
        elif mode == 'generator_loss':
            fake_pred, _ = x
            loss = self.bce(fake_pred, torch.tensor(1.0).expand_as(fake_pred).to(self.device))
        
        return loss
    
class DCGANOptimizer(optim.Adam):
    def __init__(self, args, params):
        self.args = args
        self.params = params
        
        super(DCGANOptimizer, self).__init__(
            params,
            lr=args.lr,
        )

    def step(self):
        loss = super(DCGANOptimizer, self).step()
        return loss

components = {
    'generator': DCGANGenerator,
    'discriminator': DCGANDiscriminator,
    'criterion': DCGANLoss,
    'optimizer': DCGANOptimizer
}

NameError: name 'optim' is not defined

In [4]:
from easydict import EasyDict
args = EasyDict({
    'resolution': 32,
    'nz': 100,
    'ngf': 64,
    'ndf': 64,
    'nc': 3,
    'device': 'cpu'
})
g = DCGANGenerator(args)
g(torch.rand(1, 100, 1, 1))
# d = DCGANDiscriminator(args)
# d(torch.rand(1, 3, 32, 32))
criterion = DCGANLoss(args)
criterion((torch.rand(1), torch.rand(1)), mode='generator_loss')

tensor(3.9743)

In [5]:
g

DCGANGenerator(
  (net): Sequential(
    (0): ConvTransposeBlock(
      (deconv): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (actv): ReLU()
    )
    (1): ConvTransposeBlock(
      (deconv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (actv): ReLU()
    )
    (2): ConvTransposeBlock(
      (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (actv): ReLU()
    )
    (3): ConvTransposeBlock(
      (deconv): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): Sequential()
      (actv): Tanh()
    )
  )
)

In [131]:
g.net[0].norm.weight

Parameter containing:
tensor([-3.3361e-02,  4.7484e-03,  8.8337e-03, -4.2278e-02, -1.4795e-02,
         1.5242e-02,  1.4019e-03, -7.7961e-03,  1.1149e-03, -6.3838e-03,
         4.4132e-02,  1.6390e-02,  7.8639e-03, -3.3794e-03,  4.6303e-03,
         6.1060e-03, -1.9545e-04, -3.9471e-03,  2.2000e-03, -9.4550e-03,
         2.1371e-02, -1.9330e-02,  1.9404e-02, -1.4085e-03,  9.7211e-03,
         7.1018e-04, -2.0313e-02,  5.4012e-03,  2.8326e-02, -3.5813e-02,
        -2.5176e-02, -5.6153e-03, -4.2708e-03, -2.1380e-02, -5.4555e-03,
         6.1231e-03, -2.1843e-02,  1.7375e-04,  4.2771e-03, -1.1101e-02,
        -1.6016e-02,  2.2322e-02, -5.2337e-03,  8.3100e-03,  3.6959e-02,
        -2.2778e-05, -1.5645e-02, -1.2547e-02, -4.2400e-03,  4.3271e-03,
         2.5820e-02, -2.4466e-02, -6.3521e-03, -3.1910e-02,  1.6226e-02,
        -4.7438e-04, -2.4821e-02, -4.8451e-02, -4.1103e-02, -1.2347e-02,
        -5.0930e-03,  2.0740e-02,  3.0244e-04,  1.3278e-02, -2.3070e-04,
        -7.1682e-03, -8.6342e