<a href="https://colab.research.google.com/github/yalopez84/GAN_study/blob/master/15_6_GAN_Video_Serie_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from  torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [27]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc=nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        return self.disc(x)

In [28]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen=nn.Sequential(
            nn.Linear(z_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,img_dim),
            nn.Tanh()
        )
    def forward(self,x):
        return self.gen(x)

In [29]:
#Hyperparameters
device="cuda" if torch.cuda.is_available() else "cpu"
lr=3e-4
z_dim=64
image_dim=28 * 28 *1
batch_size=32
num_epochs=50
disc=Discriminator(image_dim).to(device)
gen=Generator(z_dim=z_dim, img_dim=image_dim).to(device)
fixed_noise=torch.randn(batch_size, z_dim).to(device)
transforms=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
dataset=datasets.MNIST(root="dataset/", transform=transforms,download=True)
loader=DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc=optim.Adam(disc.parameters(), lr=lr)
opt_gen=optim.Adam(gen.parameters(), lr=lr)
criterion=nn.BCELoss()
writer_fake= SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real= SummaryWriter(f"runs/GAN_MNIST/real")
step=0
for epoch in range (num_epochs):
    for batch_idx, (real,_) in enumerate(loader):
        real=real.view(-1,784).to(device)
        batch_size=real.shape[0]

        #train discriminator: max log(D(real)) + log(1-D(G(z)))
        noise=torch.randn(batch_size,z_dim).to(device)
        fake=gen(noise)
        disc_real=disc(real).view(-1)
        lossD_real=criterion(disc_real,torch.ones_like(disc_real))
        disc_fake=disc(fake.detach()).view(-1)
        lossD_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
        lossD=(lossD_real +lossD_fake)/2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()


        #El detach de arriba es porque una vez que se hace el backward en la perdida, todo lo que influyo en ella se limpia, sin embargo,
        #se necesita en el entrenamiento del generador esas imagenes calculadas en el generador para no tener que calcularlas twice.
        # otra forma de solucionarlo es poniendo en el metodo backward lossD.backward (retain_graph=True)
        #train generator min log(1-D(G(z))) ->max(log(D(G(z))))
        output=disc(fake).view(-1)
        lossG=criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        #Adicional code for tensor board

        if batch_idx == 0:
            print(
                f"Epoch[{epoch}/{num_epochs}] \ "
                f"Loss D:{lossD:.4f} Loss G:{lossG:.4f}"
            )
            with torch.no_grad():
                fake=gen(fixed_noise). reshape(-1,1,28,28)
                data=real.reshape(-1,1,28,28)
                img_grid_fake =torchvision.utils.make_grid(fake,normalize=True)
                img_grid_real=torchvision.utils.make_grid(data,normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist real Images", img_grid_real, global_step=step
                )
                step+=1

Epoch[0/50] \ Loss D:0.7419 Loss G:0.6998
Epoch[1/50] \ Loss D:0.1128 Loss G:2.1092
Epoch[2/50] \ Loss D:0.1426 Loss G:2.6138
Epoch[3/50] \ Loss D:0.0642 Loss G:3.9327
Epoch[4/50] \ Loss D:0.0836 Loss G:3.6783
Epoch[5/50] \ Loss D:0.0325 Loss G:4.7291
Epoch[6/50] \ Loss D:0.1130 Loss G:5.5562
Epoch[7/50] \ Loss D:0.1297 Loss G:4.9644
Epoch[8/50] \ Loss D:0.0094 Loss G:5.7367
Epoch[9/50] \ Loss D:0.0066 Loss G:5.2078
Epoch[10/50] \ Loss D:0.0093 Loss G:5.3875
Epoch[11/50] \ Loss D:0.0114 Loss G:5.6058
Epoch[12/50] \ Loss D:0.0045 Loss G:5.8727
Epoch[13/50] \ Loss D:0.0307 Loss G:5.9416
Epoch[14/50] \ Loss D:0.0060 Loss G:5.7367
Epoch[15/50] \ Loss D:0.0701 Loss G:5.5558
Epoch[16/50] \ Loss D:0.0378 Loss G:6.0299
Epoch[17/50] \ Loss D:0.0066 Loss G:5.5463
Epoch[18/50] \ Loss D:0.0088 Loss G:6.0077
Epoch[19/50] \ Loss D:0.0100 Loss G:5.7299
Epoch[20/50] \ Loss D:0.0095 Loss G:7.4298
Epoch[21/50] \ Loss D:0.0048 Loss G:6.3707
Epoch[22/50] \ Loss D:0.0304 Loss G:5.7240
Epoch[23/50] \ Loss D