In [None]:
# Mount Google Drive; the training loop later writes sample images under
# /content/drive/MyDrive/..., which only exists once the drive is mounted.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install pyforest



In [None]:
import pyforest
import torch

import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

In [None]:
# Define the data transformations
transform = transforms.Compose([transforms.Resize(64), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Create the CIFAR-10 dataset
dataset = dset.CIFAR10(root='./data', download=True, transform=transform)

# Create a data loader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=6)

Files already downloaded and verified




In [None]:
def weights_init(m):
    """Re-initialize one module's parameters; meant for ``nn.Module.apply``.

    Dispatch is by substring of the module's class name:

    * contains ``'Conv'`` (e.g. Conv2d, ConvTranspose2d): weight ~ N(0, 0.02)
    * contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias set to 0

    Any other module is left untouched. Only ``weight`` (and ``bias`` for
    BatchNorm) are modified; a convolution's bias, if it has one, is not.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

In [None]:
class Generator(nn.Module):
    """DCGAN generator: latent vector -> image with values in [-1, 1] (Tanh).

    Args:
        nz: number of latent channels; the first layer expects input of shape
            (N, nz, H, W). The training loop feeds (N, nz, 1, 1), which
            produces a 64x64 image.
        ngf: base feature width; the hidden layers have 8*ngf, 4*ngf, 2*ngf and
            ngf channels.
        nc: number of output image channels.

    The defaults (nz=100, ngf=64, nc=3) reproduce the original fixed
    architecture 100 -> 512 -> 256 -> 128 -> 64 -> 3, with the same module order
    and therefore the same ``state_dict`` keys.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 8*ngf, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # -> (N, 4*ngf, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # -> (N, 2*ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # -> (N, ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # -> (N, nc, 64, 64), squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Map a latent batch of shape (N, nz, 1, 1) to images (N, nc, 64, 64)."""
        output = self.main(input)
        return output

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: image -> probability (Sigmoid) that it is real.

    Args:
        nc: number of input image channels.
        ndf: base feature width; the hidden layers have ndf, 2*ndf, 4*ndf and
            8*ndf channels.

    The defaults (nc=3, ndf=64) reproduce the original fixed architecture
    3 -> 64 -> 128 -> 256 -> 512 -> 1, with the same module order and therefore
    the same ``state_dict`` keys.
    """

    def __init__(self, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (N, nc, 64, 64) -> (N, ndf, 32, 32); no BatchNorm on the input layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, 2*ndf, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, 4*ndf, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, 8*ndf, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (N, 1, 1, 1), squashed to [0, 1]
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return a flat tensor of probabilities.

        For 64x64 inputs this has shape (N,), one probability per image. For
        larger inputs the final conv yields a spatial map, which ``view(-1)``
        flattens into more than N values.
        """
        output = self.main(input)
        return output.view(-1)

In [None]:
# Build each network and immediately re-initialize its Conv/BatchNorm layers.
# Module.apply returns the module itself, so the two steps chain. The order
# (build G, init G, build D, init D) is kept so the random draws match.
generator = Generator().apply(weights_init)
discriminator = Discriminator().apply(weights_init)

# Show the Discriminator's layer summary as the cell output.
discriminator

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [None]:
# Both networks share the same Adam hyperparameters.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)

# Binary cross-entropy on the Discriminator's Sigmoid outputs (probabilities).
criterion = nn.BCELoss()

optimizerD = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizerG = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [None]:
import os

# ---- Training configuration ----
NUM_EPOCHS = 25
LATENT_DIM = 100     # must equal the in_channels of the Generator's first ConvTranspose2d
LOG_EVERY = 50       # print losses every N iterations (printing every step floods the output)
SAMPLE_EVERY = 200   # save real/fake image grids every N iterations
RESULTS_DIR = '/content/drive/MyDrive/DCGan/results'

# save_image does not create missing directories, so make sure the target exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Put each batch on whatever device the models live on. Nothing earlier in this
# notebook moves them, so this is the CPU unless they were moved (e.g. with
# .to('cuda')) before this cell runs.
device = next(generator.parameters()).device

for epoch in range(NUM_EPOCHS):
    for i, (real, _) in enumerate(dataloader):
        real = real.to(device)
        batch_size = real.size(0)  # the last batch of an epoch can be smaller
        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # (1) Discriminator step: push D(real) toward 1 and D(G(z)) toward 0.
        discriminator.zero_grad()
        errD_real = criterion(discriminator(real), real_labels)

        noise = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake = generator(noise)
        # detach() keeps the discriminator loss from back-propagating into G.
        errD_fake = criterion(discriminator(fake.detach()), fake_labels)

        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # (2) Generator step: push the (just updated) D(G(z)) toward 1.
        generator.zero_grad()
        errG = criterion(discriminator(fake), real_labels)
        errG.backward()
        optimizerG.step()

        if i % LOG_EVERY == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, NUM_EPOCHS, i, len(dataloader), errD.item(), errG.item()))

        if i % SAMPLE_EVERY == 0:
            vutils.save_image(real, f'{RESULTS_DIR}/real_samples{epoch}_{i}.jpg', normalize=True)
            # Sampling only: no autograd graph is needed for this forward pass.
            with torch.no_grad():
                fake = generator(noise)
            vutils.save_image(fake, f'{RESULTS_DIR}/fake_samples{epoch}_{i}.jpg', normalize=True)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[2/25][1201/1563] Loss_D: 0.9887 Loss_G: 6.7650
[2/25][1202/1563] Loss_D: 1.9874 Loss_G: 0.6496
[2/25][1203/1563] Loss_D: 1.5374 Loss_G: 7.5359
[2/25][1204/1563] Loss_D: 1.7235 Loss_G: 3.6410
[2/25][1205/1563] Loss_D: 0.5936 Loss_G: 1.0662
[2/25][1206/1563] Loss_D: 1.9033 Loss_G: 6.0231
[2/25][1207/1563] Loss_D: 2.0525 Loss_G: 1.6707
[2/25][1208/1563] Loss_D: 0.3645 Loss_G: 2.5465
[2/25][1209/1563] Loss_D: 1.2831 Loss_G: 5.5169
[2/25][1210/1563] Loss_D: 1.2583 Loss_G: 1.8962
[2/25][1211/1563] Loss_D: 0.6085 Loss_G: 2.2548
[2/25][1212/1563] Loss_D: 0.4702 Loss_G: 3.0741
[2/25][1213/1563] Loss_D: 0.6030 Loss_G: 3.2160
[2/25][1214/1563] Loss_D: 0.6514 Loss_G: 2.6230
[2/25][1215/1563] Loss_D: 0.8958 Loss_G: 1.5635
[2/25][1216/1563] Loss_D: 1.1474 Loss_G: 4.9954
[2/25][1217/1563] Loss_D: 0.6812 Loss_G: 2.7729
[2/25][1218/1563] Loss_D: 0.3093 Loss_G: 2.9243
[2/25][1219/1563] Loss_D: 0.2561 Loss_G: 3.3259
[2/25][1220/1563] Loss_