In [40]:
import torch
from Util import *

In [41]:
# MNIST digit pictures are 28 x 28.
image_dims = (28, 28)
latent_size = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters shared by every encoder conv stage.
_ENC_KERNEL = 5
_ENC_STRIDE = 2
_LEAKY_SLOPE = 0.01

# (in_channels, out_channels, padding) of each encoder conv stage.
_ENC_STAGES = [
    (1, 32, 2),    # 1 x 28 x 28  => 32 x 14 x 14
    (32, 64, 2),   # 32 x 14 x 14 => 64 x 7 x 7
    (64, 128, 1),  # 64 x 7 x 7   => 128 x 3 x 3 (padding 1 gives 3, not 4)
]


def _conv_block(in_channels, out_channels, padding):
    """Return a bias-free strided Conv2d followed by LeakyReLU.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        padding: symmetric zero padding applied to both spatial axes.
    """
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=(_ENC_KERNEL, _ENC_KERNEL),
                        stride=(_ENC_STRIDE, _ENC_STRIDE),
                        padding=(padding, padding), bias=False),
        torch.nn.LeakyReLU(negative_slope=_LEAKY_SLOPE),
    )


def _conv_out_size(size, padding):
    """Length of one spatial axis after an encoder conv stage (dilation 1)."""
    return (size + 2 * padding - _ENC_KERNEL) // _ENC_STRIDE + 1


# Built in order conv1, conv2, conv3 so parameter initialisation consumes the
# RNG in the same sequence as building them one by one.
conv1, conv2, conv3 = (_conv_block(c_in, c_out, pad)
                       for c_in, c_out, pad in _ENC_STAGES)

my_flatten = torch.nn.Flatten()

# Flattened size of conv3's output (128 * 3 * 3 = 1152 for 28 x 28 input).
# Derived from the stage settings so fc1 cannot drift out of sync with them.
_enc_h, _enc_w = image_dims
for _c_in, _c_out, _pad in _ENC_STAGES:
    _enc_h = _conv_out_size(_enc_h, _pad)
    _enc_w = _conv_out_size(_enc_w, _pad)
enc_flat_features = _ENC_STAGES[-1][1] * _enc_h * _enc_w

# 1152 => latent_size
fc1 = torch.nn.Sequential(
    torch.nn.Linear(in_features=enc_flat_features, out_features=latent_size),
    torch.nn.LeakyReLU(negative_slope=_LEAKY_SLOPE),
)

# latent_size => 1
# NOTE(review): this head ends in a LeakyReLU. If its output is meant to be a
# raw score/logit for a loss such as BCEWithLogitsLoss, a final activation is
# usually omitted — confirm against the training code.
fc2 = torch.nn.Sequential(
    torch.nn.Linear(in_features=latent_size, out_features=1),
    torch.nn.LeakyReLU(negative_slope=_LEAKY_SLOPE),
)


In [42]:
# Sanity-check the encoder: push a dummy batch of 2 images through each stage,
# print every intermediate shape, and fail loudly if any shape is unexpected
# (printing alone would let a regression go unnoticed).
encoder_stages = [
    (conv1, (2, 32, 14, 14)),
    (conv2, (2, 64, 7, 7)),
    (conv3, (2, 128, 3, 3)),
    (my_flatten, (2, 1152)),
    (fc1, (2, latent_size)),
    (fc2, (2, 1)),
]

X = torch.randn(size=(2, 1, 28, 28))
print("Shape:", X.shape)
with torch.no_grad():  # only shapes matter here; skip building an autograd graph
    for stage, expected_shape in encoder_stages:
        X = stage(X)
        print("Shape:", X.shape)
        assert X.shape == expected_shape, (
            f"expected {expected_shape}, got {tuple(X.shape)}")


Shape: torch.Size([2, 1, 28, 28])
Shape: torch.Size([2, 32, 14, 14])
Shape: torch.Size([2, 64, 7, 7])
Shape: torch.Size([2, 128, 3, 3])
Shape: torch.Size([2, 1152])
Shape: torch.Size([2, 10])
Shape: torch.Size([2, 1])


In [95]:
# Decoder: latent vector => 1 x 28 x 28 image, mirroring the encoder.
#
# Seed feature map that the latent vector is projected onto. fc3's width is
# derived from it, so the Linear layer and the reshape cannot disagree.
dec_seed_shape = (128, 3, 3)
dec_seed_features = dec_seed_shape[0] * dec_seed_shape[1] * dec_seed_shape[2]  # 1152

# latent_size => 1152
fc3 = torch.nn.Sequential(
            torch.nn.Linear(in_features=latent_size,
                            out_features=dec_seed_features),
            torch.nn.BatchNorm1d(num_features=dec_seed_features),
            torch.nn.LeakyReLU(negative_slope=0.01))

# 1152 => 128 x 3 x 3 (Reshape comes from the project's Util module)
reshape = Reshape(*dec_seed_shape)

# 128 x 3 x 3 => 64 x 7 x 7   ((3 - 1) * 2 - 2 * 1 + 5 = 7)
conv4 = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(5, 5),
                             stride=(2, 2), padding=(1, 1), bias=False),
    torch.nn.BatchNorm2d(num_features=64),
    torch.nn.LeakyReLU(negative_slope=0.01)
)

# 64 x 7 x 7 => 32 x 14 x 14   ((7 - 1) * 2 - 2 * 2 + 6 = 14)
conv5 = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(6, 6),
                             stride=(2, 2), padding=(2, 2), bias=False),
    torch.nn.BatchNorm2d(num_features=32),
    torch.nn.LeakyReLU(negative_slope=0.01)
)

# 32 x 14 x 14 => 1 x 28 x 28   ((14 - 1) * 2 - 2 * 2 + 6 = 28)
# NOTE(review): the image output layer applies BatchNorm followed by LeakyReLU,
# so output values are not confined to any fixed range. If reconstructions are
# compared against MNIST pixel intensities, a bounded output (e.g. Sigmoid,
# without BatchNorm) is the more common choice — confirm against the loss used.
conv6 = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(6, 6),
                             stride=(2, 2), padding=(2, 2), bias=False),
    torch.nn.BatchNorm2d(num_features=1),
    torch.nn.LeakyReLU(negative_slope=0.01)
)


In [97]:
# Sanity-check the decoder: map a dummy batch of 2 latent vectors back to
# 1 x 28 x 28 images, printing and asserting every intermediate shape.
#
# fc3 and conv4-conv6 contain BatchNorm layers. In training mode these update
# their running mean/variance on every forward pass, so probing them with
# random noise would leak into the statistics used later. The check therefore
# runs in eval mode, and each module's previous mode is restored afterwards.
decoder_stages = [
    (fc3, (2, 1152)),
    (reshape, (2, 128, 3, 3)),
    (conv4, (2, 64, 7, 7)),
    (conv5, (2, 32, 14, 14)),
    (conv6, (2, 1, 28, 28)),
]
_bn_modules = [fc3, conv4, conv5, conv6]
_prev_training = [module.training for module in _bn_modules]

X = torch.randn(size=(2, latent_size))
print("Shape:", X.shape)
try:
    for module in _bn_modules:
        module.eval()
    with torch.no_grad():  # only shapes matter here
        for stage, expected_shape in decoder_stages:
            X = stage(X)
            print("Shape:", X.shape)
            assert X.shape == expected_shape, (
                f"expected {expected_shape}, got {tuple(X.shape)}")
finally:
    for module, was_training in zip(_bn_modules, _prev_training):
        module.train(was_training)


Shape: torch.Size([2, 10])
Shape: torch.Size([2, 1152])
Shape: torch.Size([2, 128, 3, 3])
Shape: torch.Size([2, 64, 7, 7])
Shape: torch.Size([2, 32, 14, 14])
Shape: torch.Size([2, 1, 28, 28])
