In [1]:
# temporary implementation of the dcgan codebase
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            self._block(features_d * 8, features_d * 16, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 16, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            #nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            self._block(channels_noise, features_g*32, 4, 1, 0),
            self._block(features_g*32, features_g*16, 4, 2, 1),
            self._block(features_g*16, features_g*8, 4, 2, 1),
            self._block(features_g*8, features_g*4, 4, 2, 1),
            self._block(features_g*4, features_g*2, 4, 2, 1),
            nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            #nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 0.00025  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 8
IMAGE_SIZE = 128
CHANNELS_IMG = 3
NOISE_DIM = 128
NUM_EPOCHS = 7500
FEATURES_DISC = 128
FEATURES_GEN = 128


img_list = []
G_loss = []
D_loss = []

transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.ImageFolder(root="/ssd_scratch/cvit/aditya1/temp_images/", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(128, NOISE_DIM, 1, 1).to(device)


step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        G_loss.append(loss_gen.item())
        D_loss.append(loss_disc.item())

        # Print losses occasionally and print to tensorboard
        if batch_idx in range(BATCH_SIZE):
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
    
            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:1], normalize=True
                )
                
                for batch_idx in range(BATCH_SIZE):
                  torchvision.utils.save_image(img_grid_fake, 'sample.png', global_step=step)

            step += 1

Epoch [0/7500] Batch 0/3750                   Loss D: 0.7053, loss G: 6.7711
Epoch [0/7500] Batch 1/3750                   Loss D: 0.8983, loss G: 0.0000
Epoch [0/7500] Batch 2/3750                   Loss D: 50.0000, loss G: 0.0000
Epoch [0/7500] Batch 3/3750                   Loss D: 50.0000, loss G: 0.0000
Epoch [0/7500] Batch 4/3750                   Loss D: 50.0000, loss G: 0.0000
Epoch [0/7500] Batch 5/3750                   Loss D: 50.0000, loss G: 0.0000
Epoch [0/7500] Batch 6/3750                   Loss D: 50.0000, loss G: 0.0000
Epoch [0/7500] Batch 7/3750                   Loss D: 50.0000, loss G: 0.0000


KeyboardInterrupt: 

In [3]:
# self._block(channels_noise, features_g * 32, 4, 1, 0),
#             self._block(features_g * 32, features_g * 16, 4, 1, 1),  # img: 4x4
#             self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
#             self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
#             self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
#             nn.ConvTranspose2d(
#                 features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
#             ),
            
channels_noise = 128
features_g = 128
channels_img = 3

# define the convtranspose2d layer 
def block(in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            #nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

In [10]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

block1 = block(channels_noise, features_g*32, 4, 1, 0).to(device)
block2 = block(features_g*32, features_g*16, 4, 2, 1).to(device)
block3 = block(features_g*16, features_g*8, 4, 2, 1).to(device)
block4 = block(features_g*8, features_g*4, 4, 2, 1).to(device)
block5 = block(features_g*4, features_g*2, 4, 2, 1).to(device)
block6 = nn.ConvTranspose2d(features_g*2, channels_img, kernel_size=4, stride=2, padding=1).to(device)

LEARNING_RATE = 0.00025  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 128
CHANNELS_IMG = 3
NOISE_DIM = 128
NUM_EPOCHS = 7500
FEATURES_DISC = 128
FEATURES_GEN = 128
noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)

out1 = block1(noise)
out2 = block2(out1)
out3 = block3(out2)
out4 = block4(out3)
out5 = block5(out4)
out6 = block6(out5)

print(out1.shape)
print(out2.shape)
print(out3.shape)
print(out4.shape)
print(out5.shape)
print(out6.shape)

torch.Size([128, 4096, 4, 4])
torch.Size([128, 2048, 8, 8])
torch.Size([128, 1024, 16, 16])
torch.Size([128, 512, 32, 32])
torch.Size([128, 256, 64, 64])
torch.Size([128, 3, 128, 128])


In [11]:
# original shapes -- 
block1 = block(channels_noise, features_g * 32, 4, 1, 0).to(device)
block2 = block(features_g * 32, features_g * 16, 4, 1, 1).to(device)
block3 = block(features_g * 16, features_g * 8, 4, 2, 1).to(device)
block4 = block(features_g * 8, features_g * 4, 4, 2, 1).to(device)  
block5 = block(features_g * 4, features_g * 2, 4, 2, 1).to(device)  
block6 = nn.ConvTranspose2d(
    features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
).to(device)

out1 = block1(noise)
out1 = block1(noise)
out2 = block2(out1)
out3 = block3(out2)
out4 = block4(out3)
out5 = block5(out4)
out6 = block6(out5)

print(out1.shape)
print(out2.shape)
print(out3.shape)
print(out4.shape)
print(out5.shape)
print(out6.shape)

torch.Size([128, 4096, 4, 4])
torch.Size([128, 2048, 5, 5])
torch.Size([128, 1024, 10, 10])
torch.Size([128, 512, 20, 20])
torch.Size([128, 256, 40, 40])
torch.Size([128, 3, 80, 80])


In [45]:
out1 = block1(noise)
out2 = block2(out1)
out3 = block3(out2)
out4 = block4(out3)
out5 = block5(out4)
out6 = block7(out5)

print(noise.shape, out1.shape, out2.shape, out3.shape, out4.shape, out5.shape, out6.shape)
 
# expected image size should be 32 x 3 x 128 x 128

torch.Size([32, 128, 1, 1]) torch.Size([32, 4096, 4, 4]) torch.Size([32, 2048, 8, 8]) torch.Size([32, 1024, 16, 16]) torch.Size([32, 512, 32, 32]) torch.Size([32, 256, 64, 64]) torch.Size([32, 3, 128, 128])


In [23]:
# check the disc 
image = torch.randn(32, 3, 128, 128).to(device)

disc_out = disc(image)
print(disc_out.shape)

torch.Size([32, 1, 1, 1])
