In [2]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler

In [3]:
def conv2d_size_out(size: int, kernel_size: int, padding: int = 0, dilation: int = 1, stride: int = 1) -> int:
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 

def conv_transpose_2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0, output_padding: int = 0, dilation: int = 1) -> int:
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

In [4]:
x = torch.randn((1, 28, 28))

conv2d_args = {"kernel_size": 4, "stride": 3}
conv_transpose_2d_args = {"kernel_size": 4, "stride": 3}

print(f"""
Predict:
conv2d_size_out = {conv2d_size_out(size=28, **conv2d_args)}
conv_transpose_2d_out = {conv_transpose_2d_out(size=28, **conv_transpose_2d_args)}

Real:
conv2d_size_out = {nn.Conv2d(1, 16, **conv2d_args)(x).shape}
conv_transpose_2d_out = {nn.ConvTranspose2d(1, 100, **conv_transpose_2d_args)(x).shape}
""")


Predict:
conv2d_size_out = 9
conv_transpose_2d_out = 85

Real:
conv2d_size_out = torch.Size([16, 9, 9])
conv_transpose_2d_out = torch.Size([100, 85, 85])



### Discriminator: 28x28 -> 16x16 -> 8x8 -> 4x4 -> 1
### Generator: 4x4 -> 7x7 -> 14x14 -> 28x28

In [5]:
s = conv_transpose_2d_out(size=4, kernel_size=4, stride=1)
print(s)
s = conv_transpose_2d_out(size=s, kernel_size=4, stride=2, padding=1)
print(s)
s = conv_transpose_2d_out(size=s, kernel_size=4, stride=2, padding=1)
print(s)

7
14
28


In [6]:
def block(in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int):
    return [
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    ]

model = nn.Sequential(
    *block(in_channels=10, out_channels=64, kernel_size=4, stride=1, padding=0),
    *block(in_channels=64, out_channels=64, kernel_size=8, stride=2, padding=1),
    nn.ConvTranspose2d(
        in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=1
    ),
    nn.Tanh()
)

model(torch.randn(20, 10, 1, 1)).shape


torch.Size([20, 1, 24, 24])

In [7]:
def block(in_channel: int, out_channels: int, norm: bool = True):
    layers: List[Any] = [
        nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channels,
            kernel_size=(4, 4),
            stride=2,
            padding=1,
            bias=False,
        ),
    ]
    if norm:
        layers.append(nn.BatchNorm2d(out_channels, affine=True))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers

model = nn.Sequential(
    *block(1, 64, norm=False),
    *block(64, 128),
    # *block(128, 256),
    *block(128, 256),
    nn.Conv2d(256, 1, kernel_size=7, stride=2, bias=False),
)

y = model(torch.randn(100, 1, 64, 64))
y.shape

torch.Size([100, 1, 1, 1])

In [8]:
optim = torch.optim.Adam([torch.rand((10, ))], lr=0.001)
scheduler = lr_scheduler.ConstantLR(optim, factor=0.9, total_iters=10)
for i in range(10):
    scheduler.step()



In [9]:
optim = torch.optim.Adam([torch.rand((10, ))], lr=0.001)
scheduler = lr_scheduler.LinearLR(optim, start_factor=1, end_factor=0.1, total_iters=10)
for i in range(20):
    if i > 5:
        scheduler.step()
    print(scheduler.get_lr())

[0.001]
[0.001]
[0.001]
[0.001]
[0.001]
[0.001]
[0.0008281]
[0.0007389010989010989]
[0.0006498780487804878]
[0.0005610958904109588]
[0.00047265624999999993]
[0.0003847272727272727]
[0.00029760869565217393]
[0.00021189189189189193]
[0.00012892857142857145]
[5.2631578947368444e-05]
[0.00010000000000000003]
[0.00010000000000000003]
[0.00010000000000000003]
[0.00010000000000000003]




In [10]:
class NN(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

nn = NN()
nn.apply(weights_init)

NN()