In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter


2023-07-22 17:52:41.825454: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class Discriminator(nn.Module):
    """Small MLP that scores a flattened image with a probability of being real.

    Args:
        features: length of each flattened input vector (e.g. 784 for a
            flattened 28x28 single-channel image).
    """

    def __init__(self, features):
        super().__init__()
        layers = [
            nn.Linear(features, 128),  # single hidden layer of width 128
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),         # one score per sample
            nn.Sigmoid(),              # squash to (0, 1) so nn.BCELoss can consume it
        ]
        self.discr = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x``; its last dimension must equal ``features``.

        Returns a tensor whose last dimension is 1, holding values in (0, 1).
        """
        probs = self.discr(x)
        return probs


In [3]:
class Generator(nn.Module):
    """Small MLP that turns a noise vector into a flattened image.

    Args:
        n_dim: length of the input noise vector.
        img_dim: length of the flattened output image.
    """

    def __init__(self, n_dim, img_dim):
        super().__init__()
        layers = [
            nn.Linear(n_dim, 256),     # single hidden layer of width 256
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),                 # outputs lie in (-1, 1)
        ]
        self.genr = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise ``x`` (last dimension ``n_dim``) to images (last dimension ``img_dim``)."""
        images = self.genr(x)
        return images



In [5]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # display the selected device string

'cuda'

In [6]:
# --- Hyperparameters ---
lr = 3e-4
n_dim = 128               # length of the generator's input noise vector
image_dim = 28 * 28 * 1   # 784 pixels per flattened single-channel 28x28 image
batch_size = 32
epochs = 1000
seed = 0                  # fixes weight init, fixed_noise and shuffling for reproducible runs

torch.manual_seed(seed)

# --- Models ---
discr = Discriminator(image_dim).to(device)
genr = Generator(n_dim, image_dim).to(device)
# One latent batch reused at every logging step, so TensorBoard shows how the
# same inputs evolve over training.
fixed_noise = torch.randn((batch_size, n_dim), device=device)

# --- Data ---
# NOTE: use names that do not shadow the imported `transforms` / `datasets`
# modules; rebinding them made this cell raise AttributeError when re-run.
transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # -> [-1, 1], the Generator's Tanh range
    ]
)
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# --- Optimization ---
opt_discr = optim.Adam(discr.parameters(), lr=lr)
opt_genr = optim.Adam(genr.parameters(), lr=lr)
loss = nn.BCELoss()  # Discriminator ends in Sigmoid, so plain BCE on probabilities

# --- Logging ---
writer_fake = SummaryWriter("runs/fake")
writer_real = SummaryWriter("runs/real")
step = 0  # global_step counter for TensorBoard image logging




In [7]:
# Train the GAN. Each batch takes one Adam step on the discriminator, then one
# on the generator. range(epochs) runs exactly `epochs` passes over `loader`;
# the previous range(epochs+1) ran one extra pass.
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to an image_dim-long vector.
        real = real.view(-1, image_dim).to(device)
        # Local name so the configured `batch_size` hyperparameter is not
        # overwritten (the final batch of an epoch can be smaller).
        cur_batch_size = real.shape[0]

        # --- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ---
        noise = torch.randn((cur_batch_size, n_dim), device=device)
        fake = genr(noise)
        discr_real = discr(real).view(-1)
        loss_discr_real = loss(discr_real, torch.ones_like(discr_real))
        # detach(): the discriminator loss must not backprop into the generator.
        # It also leaves fake's generator graph intact for the generator step,
        # so retain_graph=True is no longer needed.
        discr_fake = discr(fake.detach()).view(-1)
        loss_discr_fake = loss(discr_fake, torch.zeros_like(discr_fake))
        loss_discr = (loss_discr_real + loss_discr_fake) / 2  # average of the two BCE terms
        opt_discr.zero_grad()
        loss_discr.backward()
        opt_discr.step()

        # --- Generator step: BCE against target 1, i.e. minimize -log D(G(z)) ---
        # Re-scores `fake` with the freshly updated discriminator.
        output = discr(fake).view(-1)
        loss_genr = loss(output, torch.ones_like(output))
        opt_genr.zero_grad()
        loss_genr.backward()
        opt_genr.step()

        # Once per epoch: report losses and log image grids to TensorBoard.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{epochs}] "
                f"Loss D: {loss_discr.item():.4f}, "
                f"Loss G: {loss_genr.item():.4f}"
            )
            with torch.no_grad():
                fake_fixed = genr(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

                writer_fake.add_image("MNIST FakeImages", img_grid_fake, global_step=step)
                # Real images previously reused the "MNIST FakeImages" tag.
                writer_real.add_image("MNIST RealImages", img_grid_real, global_step=step)
                step += 1

# Make sure the last logged grids are on disk before TensorBoard reads them.
writer_fake.flush()
writer_real.flush()
                

Epoch : [0/1000] \ Loss D: 0.6852  \ Loss G: 0.7211 \ 
Epoch : [1/1000] \ Loss D: 0.5952  \ Loss G: 0.8887 \ 
Epoch : [2/1000] \ Loss D: 1.0799  \ Loss G: 0.5077 \ 
Epoch : [3/1000] \ Loss D: 0.3034  \ Loss G: 1.6203 \ 
Epoch : [4/1000] \ Loss D: 0.4128  \ Loss G: 1.1523 \ 
Epoch : [5/1000] \ Loss D: 0.5446  \ Loss G: 1.2272 \ 
Epoch : [6/1000] \ Loss D: 0.5561  \ Loss G: 1.0598 \ 
Epoch : [7/1000] \ Loss D: 0.4283  \ Loss G: 1.8957 \ 
Epoch : [8/1000] \ Loss D: 0.6828  \ Loss G: 1.1789 \ 
Epoch : [9/1000] \ Loss D: 0.5966  \ Loss G: 1.4175 \ 
Epoch : [10/1000] \ Loss D: 0.4156  \ Loss G: 1.4599 \ 
Epoch : [11/1000] \ Loss D: 0.7805  \ Loss G: 1.1069 \ 
Epoch : [12/1000] \ Loss D: 0.5806  \ Loss G: 1.5629 \ 
Epoch : [13/1000] \ Loss D: 0.9200  \ Loss G: 1.1610 \ 
Epoch : [14/1000] \ Loss D: 0.7108  \ Loss G: 0.8963 \ 
Epoch : [15/1000] \ Loss D: 0.4416  \ Loss G: 1.4987 \ 
Epoch : [16/1000] \ Loss D: 0.5630  \ Loss G: 1.5920 \ 
Epoch : [17/1000] \ Loss D: 0.7654  \ Loss G: 1.2619 \ 
Ep

Epoch : [146/1000] \ Loss D: 0.5153  \ Loss G: 1.2893 \ 
Epoch : [147/1000] \ Loss D: 0.6329  \ Loss G: 1.0249 \ 
Epoch : [148/1000] \ Loss D: 0.6427  \ Loss G: 0.9332 \ 
Epoch : [149/1000] \ Loss D: 0.5804  \ Loss G: 1.0995 \ 
Epoch : [150/1000] \ Loss D: 0.6237  \ Loss G: 1.0302 \ 
Epoch : [151/1000] \ Loss D: 0.5177  \ Loss G: 1.1186 \ 
Epoch : [152/1000] \ Loss D: 0.6808  \ Loss G: 1.0409 \ 
Epoch : [153/1000] \ Loss D: 0.5604  \ Loss G: 1.1594 \ 
Epoch : [154/1000] \ Loss D: 0.5316  \ Loss G: 0.9134 \ 
Epoch : [155/1000] \ Loss D: 0.7113  \ Loss G: 1.0271 \ 
Epoch : [156/1000] \ Loss D: 0.6547  \ Loss G: 0.9884 \ 
Epoch : [157/1000] \ Loss D: 0.6738  \ Loss G: 0.8354 \ 
Epoch : [158/1000] \ Loss D: 0.6034  \ Loss G: 1.0715 \ 
Epoch : [159/1000] \ Loss D: 0.5861  \ Loss G: 1.0469 \ 
Epoch : [160/1000] \ Loss D: 0.5099  \ Loss G: 1.1090 \ 
Epoch : [161/1000] \ Loss D: 0.6271  \ Loss G: 0.8222 \ 
Epoch : [162/1000] \ Loss D: 0.5492  \ Loss G: 0.9158 \ 
Epoch : [163/1000] \ Loss D: 0.

Epoch : [290/1000] \ Loss D: 0.6017  \ Loss G: 0.8932 \ 
Epoch : [291/1000] \ Loss D: 0.5854  \ Loss G: 1.0868 \ 
Epoch : [292/1000] \ Loss D: 0.5391  \ Loss G: 1.2199 \ 
Epoch : [293/1000] \ Loss D: 0.6141  \ Loss G: 1.2682 \ 
Epoch : [294/1000] \ Loss D: 0.5308  \ Loss G: 1.1315 \ 
Epoch : [295/1000] \ Loss D: 0.5912  \ Loss G: 0.9213 \ 
Epoch : [296/1000] \ Loss D: 0.6722  \ Loss G: 0.9317 \ 
Epoch : [297/1000] \ Loss D: 0.5688  \ Loss G: 1.1957 \ 
Epoch : [298/1000] \ Loss D: 0.5810  \ Loss G: 1.3419 \ 
Epoch : [299/1000] \ Loss D: 0.5331  \ Loss G: 1.0818 \ 
Epoch : [300/1000] \ Loss D: 0.5187  \ Loss G: 1.1235 \ 
Epoch : [301/1000] \ Loss D: 0.4922  \ Loss G: 1.3917 \ 
Epoch : [302/1000] \ Loss D: 0.5902  \ Loss G: 1.1856 \ 
Epoch : [303/1000] \ Loss D: 0.6494  \ Loss G: 1.0629 \ 
Epoch : [304/1000] \ Loss D: 0.4602  \ Loss G: 1.1420 \ 
Epoch : [305/1000] \ Loss D: 0.4847  \ Loss G: 1.3019 \ 
Epoch : [306/1000] \ Loss D: 0.5059  \ Loss G: 1.3632 \ 
Epoch : [307/1000] \ Loss D: 0.

Epoch : [434/1000] \ Loss D: 0.5250  \ Loss G: 1.3178 \ 
Epoch : [435/1000] \ Loss D: 0.5073  \ Loss G: 1.4110 \ 
Epoch : [436/1000] \ Loss D: 0.4483  \ Loss G: 1.4062 \ 
Epoch : [437/1000] \ Loss D: 0.4446  \ Loss G: 1.4463 \ 
Epoch : [438/1000] \ Loss D: 0.4206  \ Loss G: 1.4066 \ 
Epoch : [439/1000] \ Loss D: 0.5492  \ Loss G: 1.4511 \ 
Epoch : [440/1000] \ Loss D: 0.5118  \ Loss G: 1.2709 \ 
Epoch : [441/1000] \ Loss D: 0.4434  \ Loss G: 1.7471 \ 
Epoch : [442/1000] \ Loss D: 0.5636  \ Loss G: 1.5565 \ 
Epoch : [443/1000] \ Loss D: 0.5781  \ Loss G: 0.9278 \ 
Epoch : [444/1000] \ Loss D: 0.5017  \ Loss G: 1.1981 \ 
Epoch : [445/1000] \ Loss D: 0.4876  \ Loss G: 1.1783 \ 
Epoch : [446/1000] \ Loss D: 0.4645  \ Loss G: 1.1922 \ 
Epoch : [447/1000] \ Loss D: 0.5768  \ Loss G: 1.2137 \ 
Epoch : [448/1000] \ Loss D: 0.5229  \ Loss G: 1.2149 \ 
Epoch : [449/1000] \ Loss D: 0.4681  \ Loss G: 1.3242 \ 
Epoch : [450/1000] \ Loss D: 0.4257  \ Loss G: 1.5953 \ 
Epoch : [451/1000] \ Loss D: 0.

Epoch : [578/1000] \ Loss D: 0.5201  \ Loss G: 1.0241 \ 
Epoch : [579/1000] \ Loss D: 0.5427  \ Loss G: 1.5439 \ 
Epoch : [580/1000] \ Loss D: 0.6180  \ Loss G: 1.2656 \ 
Epoch : [581/1000] \ Loss D: 0.4279  \ Loss G: 1.2793 \ 
Epoch : [582/1000] \ Loss D: 0.4714  \ Loss G: 1.2913 \ 
Epoch : [583/1000] \ Loss D: 0.4898  \ Loss G: 1.2731 \ 
Epoch : [584/1000] \ Loss D: 0.5201  \ Loss G: 1.5430 \ 
Epoch : [585/1000] \ Loss D: 0.5418  \ Loss G: 1.1567 \ 
Epoch : [586/1000] \ Loss D: 0.4305  \ Loss G: 1.7331 \ 
Epoch : [587/1000] \ Loss D: 0.5292  \ Loss G: 1.1586 \ 
Epoch : [588/1000] \ Loss D: 0.6296  \ Loss G: 1.1192 \ 
Epoch : [589/1000] \ Loss D: 0.4912  \ Loss G: 1.4215 \ 
Epoch : [590/1000] \ Loss D: 0.4252  \ Loss G: 1.0206 \ 
Epoch : [591/1000] \ Loss D: 0.4799  \ Loss G: 1.6162 \ 
Epoch : [592/1000] \ Loss D: 0.4550  \ Loss G: 1.8657 \ 
Epoch : [593/1000] \ Loss D: 0.4972  \ Loss G: 2.0385 \ 
Epoch : [594/1000] \ Loss D: 0.5105  \ Loss G: 1.6713 \ 
Epoch : [595/1000] \ Loss D: 0.

Epoch : [722/1000] \ Loss D: 0.4202  \ Loss G: 1.4312 \ 
Epoch : [723/1000] \ Loss D: 0.4361  \ Loss G: 1.4704 \ 
Epoch : [724/1000] \ Loss D: 0.4880  \ Loss G: 1.6378 \ 
Epoch : [725/1000] \ Loss D: 0.5191  \ Loss G: 1.1589 \ 
Epoch : [726/1000] \ Loss D: 0.4729  \ Loss G: 1.5633 \ 
Epoch : [727/1000] \ Loss D: 0.5828  \ Loss G: 1.1614 \ 
Epoch : [728/1000] \ Loss D: 0.4650  \ Loss G: 1.6202 \ 
Epoch : [729/1000] \ Loss D: 0.4277  \ Loss G: 1.6252 \ 
Epoch : [730/1000] \ Loss D: 0.4476  \ Loss G: 1.6321 \ 
Epoch : [731/1000] \ Loss D: 0.4163  \ Loss G: 1.7379 \ 
Epoch : [732/1000] \ Loss D: 0.5576  \ Loss G: 1.7156 \ 
Epoch : [733/1000] \ Loss D: 0.4562  \ Loss G: 2.1098 \ 
Epoch : [734/1000] \ Loss D: 0.4387  \ Loss G: 1.4456 \ 
Epoch : [735/1000] \ Loss D: 0.5003  \ Loss G: 1.3187 \ 
Epoch : [736/1000] \ Loss D: 0.6256  \ Loss G: 1.0567 \ 
Epoch : [737/1000] \ Loss D: 0.4556  \ Loss G: 1.3100 \ 
Epoch : [738/1000] \ Loss D: 0.3947  \ Loss G: 2.0153 \ 
Epoch : [739/1000] \ Loss D: 0.

Epoch : [866/1000] \ Loss D: 0.5344  \ Loss G: 1.4659 \ 
Epoch : [867/1000] \ Loss D: 0.4721  \ Loss G: 1.4842 \ 
Epoch : [868/1000] \ Loss D: 0.5269  \ Loss G: 1.6508 \ 
Epoch : [869/1000] \ Loss D: 0.4021  \ Loss G: 1.7884 \ 
Epoch : [870/1000] \ Loss D: 0.4513  \ Loss G: 1.7355 \ 
Epoch : [871/1000] \ Loss D: 0.4869  \ Loss G: 1.8118 \ 
Epoch : [872/1000] \ Loss D: 0.6171  \ Loss G: 1.4142 \ 
Epoch : [873/1000] \ Loss D: 0.4193  \ Loss G: 1.6499 \ 
Epoch : [874/1000] \ Loss D: 0.4754  \ Loss G: 1.4738 \ 
Epoch : [875/1000] \ Loss D: 0.5120  \ Loss G: 1.4595 \ 
Epoch : [876/1000] \ Loss D: 0.4067  \ Loss G: 1.5353 \ 
Epoch : [877/1000] \ Loss D: 0.5649  \ Loss G: 1.0779 \ 
Epoch : [878/1000] \ Loss D: 0.4948  \ Loss G: 1.5054 \ 
Epoch : [879/1000] \ Loss D: 0.4972  \ Loss G: 2.0560 \ 
Epoch : [880/1000] \ Loss D: 0.4532  \ Loss G: 1.6663 \ 
Epoch : [881/1000] \ Loss D: 0.5109  \ Loss G: 1.7307 \ 
Epoch : [882/1000] \ Loss D: 0.3826  \ Loss G: 1.5445 \ 
Epoch : [883/1000] \ Loss D: 0.

In [8]:
# Load the TensorBoard notebook extension, then embed a TensorBoard view of the
# event files under runs/ (where writer_fake / writer_real log their image grids).
%load_ext tensorboard
%tensorboard --logdir runs