In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

2023-07-15 18:53:58.467711: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
 torch.cuda.get_device_name(0)

'Radeon RX 560 Series'

In [3]:
class Discriminator(nn.Module):
    def __init__(self,img_channels,filters_size):
        super(Discriminator,self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(
                img_channels,filters_size,kernel_size=4,stride=2,padding=1
            ), # 32*32
            nn.LeakyReLU(0.2),
            self._block(filters_size,filters_size*2,4,2,1), #16*16
            self._block(filters_size*2,filters_size*4,4,2,1), #8*8
            self._block(filters_size*4,filters_size*8,4,2,1), #4*4
            nn.Conv2d(filters_size*8, 1, kernel_size=4, stride=2, padding=0, bias=False),#1*1
        )
        
    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
        nn.Conv2d(in_channels,
                  out_channels,
                  kernel_size,
                  stride,
                  padding,
                  bias=False),
        nn.InstanceNorm2d(out_channels,affine=True),
        nn.LeakyReLU(0.2)
        )
    
    def forward(self,x):
        return self.disc(x)

In [4]:
class Generator(nn.Module):
    def __init__(self,latent_space,img_channels,filter_size):
        super(Generator,self).__init__()
        self.gen = nn.Sequential(
            # Input n*latent_space*1*1
            self._block(latent_space,filter_size*16,4,1,0), #n*f_g*16*4*4
            self._block(filter_size*16,filter_size*8,4,2,1), #8
            self._block(filter_size*8,filter_size*4,4,2,1), #16
            self._block(filter_size*4,filter_size*2,4,2,1), #32
            nn.ConvTranspose2d(
                filter_size*2, img_channels, kernel_size=4, stride=2, padding=1, bias=False
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh() #[-1,1]
        )
        
    def _block(self,in_channels,out_channels,kernel_size,stride,padding):
        return nn.Sequential(
        nn.ConvTranspose2d(in_channels,
                  out_channels,
                  kernel_size,
                  stride,
                  padding,
                  bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(0.2)
        )
    
    def forward(self,x):
        return self.gen(x)

In [5]:
def gradient_penalty(critic,image,fake_image):
    batch_size, C, H, W = image.shape
    eplison = torch.rand((batch_size, 1, 1, 1)).repeat(1, C, H, W)
    interpolated_images = image * eplison + fake_image * (1 - eplison)
   
    #calculate the critic scores
    mixed_scores = critic(interpolated_images)
    
    gradient = torch.autograd.grad(
        inputs = interpolated_images,
        outputs = mixed_scores,
        grad_outputs = torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True
    )[0]
    
    gradient = gradient.view(gradient.shape[0],-1)
    gradient_norm = gradient.norm(2,dim=1)
    gradient_penalty = torch.mean((gradient_norm-1)**2)
    return gradient_penalty

In [6]:
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [7]:
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    x = torch.randn((N, in_channels, H, W))
    critic = Discriminator(in_channels, 8)
    critic.apply(weights_init)
    assert critic(x).shape == (N, 1, 1, 1), "discriminator test failed"
    gen = Generator(noise_dim, in_channels, 8)
    gen.apply(weights_init)
    z = torch.randn((N, noise_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
    print("Success, tests passed!")


test()

Success, tests passed!


In [11]:
learning_rate = 1e-4
batch_size = 64
img_channels = 3
img_size = 64
latent_space = 100
filter_size = 64
critic_iterations = 5
lambdaGP = 10

In [12]:
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([
            0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)])
    ]
)

dataset = datasets.ImageFolder(root="faces/", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

gen = Generator(latent_space,img_channels,filter_size)
critic = Discriminator(img_channels,filter_size)

gen.apply(weights_init)
critic.apply(weights_init)

opt_gen = optim.Adam(gen.parameters(),lr=learning_rate,betas=(0.0,0.9))
opt_critic = optim.Adam(critic.parameters(),lr=learning_rate,betas=(0.0,0.9))

In [10]:
gen.train()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)

In [11]:
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2), bias=False)
  )
)

In [18]:
num_epochs = 25
fixed_noise = torch.randn(32,latent_space,1,1)
G_losses = []
D_losses = []
writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx,(img,_) in enumerate(dataloader):
        image = img
        cur_batch_size = image.shape[0]
        
        for _ in range(critic_iterations):
            noise = torch.randn(cur_batch_size,latent_space,1,1)
            fake_image = gen(noise)
            
            # Train the critic: min −∇θ [(Dw(x(i)) − 1) (Dw(gθ (z(i))+λ*‖∇gθ (z(i)(Dw(gθ (z(i)))‖2 − 1)2]
            real_critic = critic(image).reshape(-1)
            fake_critic = critic(fake_image).reshape(-1)
            gp = gradient_penalty(critic,image,fake_image)
            loss_critic = (
                -(torch.mean(real_critic)-torch.mean(fake_critic)) + lambdaGP * gp
            )

            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
                
        # Train the Generator: min −∇θ(Dw(gθ (z(i))
        lossGf = critic(fake_image).view(-1)
        loss_gen = -(torch.mean(lossGf))

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} \
                      Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                data = image
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)

                writer_fake.add_image(
                    "Manushya Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Manushya Real Images", img_grid_real, global_step=step
                )
                step += 1
                
        # Save Losses for plotting later
        G_losses.append(loss_gen.item())
        D_losses.append(loss_critic.item())

Epoch [0/25] Batch 0/7                       Loss D: -17.0362, loss G: 28.0876
Epoch [1/25] Batch 0/7                       Loss D: -19.4724, loss G: 28.5305
Epoch [2/25] Batch 0/7                       Loss D: -15.9333, loss G: 29.8470
Epoch [3/25] Batch 0/7                       Loss D: -14.8889, loss G: 29.9046
Epoch [4/25] Batch 0/7                       Loss D: -15.0982, loss G: 29.3289
Epoch [5/25] Batch 0/7                       Loss D: -16.0886, loss G: 29.3797
Epoch [6/25] Batch 0/7                       Loss D: -14.0970, loss G: 28.6335
Epoch [7/25] Batch 0/7                       Loss D: -15.5221, loss G: 28.3341
Epoch [8/25] Batch 0/7                       Loss D: -14.3186, loss G: 30.5688
Epoch [9/25] Batch 0/7                       Loss D: -14.7423, loss G: 30.7595
Epoch [10/25] Batch 0/7                       Loss D: -15.7715, loss G: 30.3951
Epoch [11/25] Batch 0/7                       Loss D: -14.7235, loss G: 32.5841
Epoch [12/25] Batch 0/7                       Loss

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(loss_disc.item(),label="G")
plt.plot(loss_gen.item(),label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [60]:
import numpy as np 
n = torch.randn(1,100,1,1)
fa = gen(n)
dim1 = fa.unsqueeze(2)
dim2 = fa.unsqueeze(3) 
dim1 = np.array(dim1)
# transform = transforms.ToPILImage()
# img = transform((dim1,dim2))
# img.show()

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
torch.save({
            'epoch': num_epochs,
            'model_state_dict': gen.state_dict(),
            'optimizer_state_dict': opt_gen.state_dict(),
            'loss': loss_gen
            }, f="gen_model/")

In [None]:
torch.save({
            'epoch': num_epochs,
            'model_state_dict': disc.state_dict(),
            'optimizer_state_dict': opt_disc.state_dict(),
            'loss': loss_disc,
            ...
            }, "disc_model/")

In [None]:
def save_checkpoint(state, filename="celeba_wgan_gp.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, gen, disc):
    print("=> Loading checkpoint")
    gen.load_state_dict(checkpoint['gen'])
    disc.load_state_dict(checkpoint['disc'])