In [1]:
# 1 import package
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets
from torchvision.utils import save_image
from torchvision import transforms
import os



In [2]:
# Global configuration shared by the cells below.
BATCH_SIze=512   # mini-batch size used by the DataLoader and the training loop
EPOCHS=1000      # number of passes over the training set
image_size=28    # images are image_size x image_size pixels
channel=1        # number of image channels (not referenced elsewhere in the visible cells)
z_dim=128        # dimensionality of the latent vector fed to the decoder
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# data loader: MNIST training split, converted to tensors and normalized
# with mean 0.5 / std 0.5
mnist_transform=transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5),
])
dataset=datasets.MNIST("../data/",train=True,transform=mnist_transform)
# drop_last=True keeps every batch at exactly BATCH_SIze samples.
mnist=DataLoader(dataset,batch_size=BATCH_SIze,shuffle=True,drop_last=True)

# # VAE  model

In [6]:

class Encoder(nn.Module):
    """MLP encoder mapping a flattened image to two values per sample.

    The output has shape ``(batch, 2)``. ``EncoderDecoder`` reads column 0 as
    the mean and column 1 as the log-variance of the latent distribution.

    Args:
        z_dim: width of the last hidden layer.
    """
    def __init__(self,z_dim):
        super(Encoder,self).__init__()
        self.model=nn.Sequential(
            nn.Linear(image_size*image_size,1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024,512),
            nn.ReLU(),
            nn.Linear(512,128),
            nn.ReLU(),
            nn.Linear(128,z_dim),
            nn.ReLU(),
            nn.Linear(z_dim,2)
        )
    def forward(self,x):
        """Flatten ``x`` per sample and return the ``(batch, 2)`` encoding."""
        # Flatten with the real batch size instead of the global BATCH_SIze so
        # the module also works on batches of any other size.
        x=x.view(x.size(0),-1)
        y=self.model(x)
        return y
        
# decoder model

class Decoder(nn.Module):
    """MLP decoder mapping a latent vector to a flattened image.

    The final ``Tanh`` bounds every output value to ``[-1, 1]``; the output
    has ``image_size*image_size`` features per sample.

    Args:
        z_dim: number of input (latent) features.
    """
    def __init__(self,z_dim):
        super().__init__()
        # Layers are created in the same order as they are applied.
        layers=[
            nn.Linear(z_dim,128),
            nn.ReLU(),
            nn.Linear(128,512),
            nn.ReLU(),
            nn.Linear(512,1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024,image_size*image_size),
            nn.Tanh(),
        ]
        self.model=nn.Sequential(*layers)
    def forward(self,x):
        """Decode latent batch ``x`` into flattened images."""
        return self.model(x)
    
class EncoderDecoder(nn.Module):
    """VAE-style wrapper: encode, sample a latent vector, decode.

    Args:
        z_dim: latent dimensionality, passed to both ``Encoder`` and
            ``Decoder`` and used for the sampled noise.
    """
    def __init__(self,z_dim):
        super(EncoderDecoder,self).__init__()
        # Remember the latent size so forward() does not depend on a global.
        self.z_dim=z_dim
        self.encoder=Encoder(z_dim)
        self.decoder=Decoder(z_dim)
    def forward(self,x):
        """Return ``(reconstruction, mean, logvar)`` for the input batch ``x``.

        ``mean`` and ``logvar`` each have shape ``(batch, 1)``.
        """
        z=self.encoder(x)
        mean=z[:,[0]]
        logvar=z[:,[1]]
        # NOTE(review): mean/logvar are a single value per sample that is
        # broadcast over all z_dim latent dimensions -- confirm this is
        # intended rather than one mean/logvar per latent dimension.
        # Draw the noise with the batch size, device and dtype of the encoder
        # output instead of the globals BATCH_SIze / z_dim / device.
        norm_std=torch.randn(mean.size(0),self.z_dim,device=mean.device,dtype=mean.dtype)
        # Reparameterization: z = mean + eps * exp(logvar / 2)
        x=mean+norm_std*torch.exp(logvar*0.5)
        decoder_x=self.decoder(x)
        return decoder_x,mean,logvar


# 3 discriminator model

In [7]:

class Discriminator(nn.Module):
    """MLP that outputs one sigmoid score in ``[0, 1]`` per input image.

    Args:
        image_size: side length of the square input images; the first linear
            layer expects ``image_size**2`` features.
    """
    def __init__(self,image_size):
        super(Discriminator,self).__init__()
        self.model=nn.Sequential(
            nn.Linear(image_size**2,1024),
            nn.ReLU(),
            nn.Linear(1024,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        """Return a 1-D tensor with one score per sample in ``x``."""
        # Flatten with the real batch size instead of the global BATCH_SIze so
        # the module also works on batches of any other size.
        x=self.model(x.view(x.size(0),-1))
        return x.reshape(-1)

# 4 model, optimizer, loss

In [8]:
# Networks: the encoder/decoder (used as the generator) and the
# discriminator, both placed on the configured device.
ed_model=EncoderDecoder(z_dim).to(device)
d_model=Discriminator(image_size).to(device)

# One Adam optimizer per network, each with its own learning rate.
ed_optim=torch.optim.Adam(ed_model.parameters(),lr=5e-5)
d_optim=torch.optim.Adam(d_model.parameters(),lr=1e-4)

# Loss functions:
#   - MSE for the reconstruction term,
#   - BCE for both the generator and the discriminator terms.
# The KL term has no loss object; it is computed inline in the training loop.
loss_fn_mse=nn.MSELoss()
loss_fn_bce=nn.BCELoss()


# 5 train

In [None]:
step=0
# BCE targets are identical for every batch (the loader uses drop_last=True,
# so every batch has exactly BATCH_SIze samples): build them once.
ones=torch.ones(BATCH_SIze,device=device)
zeros=torch.zeros(BATCH_SIze,device=device)
# Output directory for the periodic sample grids.
os.makedirs("image",exist_ok=True)
for epoch in range(EPOCHS):
    for data,_ in mnist:
        x=data.to(device)

        # Forward pass through the VAE: reconstruction plus latent statistics.
        (gx,mean,logvar)=ed_model(x)

        # ---- VAE / generator update ----
        ed_optim.zero_grad()
        # Reconstruction loss between the flattened input and the decoder output.
        rec_loss=loss_fn_mse(x.reshape(x.size(0),-1),gx)
        # KL divergence between N(mean, exp(logvar)) and N(0, 1), averaged.
        kl_loss=torch.mean(-0.5*(logvar+1-mean**2-torch.exp(logvar)))
        # Generator loss: reconstructions should be scored as real (1).
        g_loss=loss_fn_bce(d_model(gx),ones)

        loss=rec_loss+kl_loss+g_loss

        loss.backward()
        ed_optim.step()

        # ---- discriminator update ----
        # The generator step above also left gradients on d_model's
        # parameters; clear them before computing the discriminator loss.
        d_optim.zero_grad()
        y_true=d_model(x)
        # detach(): the discriminator step must not backpropagate into the VAE.
        y_false=d_model(gx.detach())
        d_loss=0.5*(loss_fn_bce(y_true,ones)+loss_fn_bce(y_false,zeros))
        d_loss.backward()
        d_optim.step()

        # Every 200 steps: log the losses and save a 5x5 grid of reconstructions.
        if step%200==0:
            print(f"epoch: {epoch} ,step : {step}, vae_Loss: {loss}, rec_loss: {rec_loss}, kl_loss : {kl_loss},g_loss: {g_loss},  d_loss: {d_loss}")
            image=gx.detach().reshape(-1,1,28,28)[:25]
            save_image(image,"image/%d.png" % step,normalize=True,nrow=5)
        step+=1

epoch: 0 ,step : 0, vae_Loss: 1.7705625295639038, rec_loss: 1.044593095779419, kl_loss : 0.0012291695456951857,g_loss: 0.7247402667999268,  d_loss: 0.6971707940101624
epoch: 1 ,step : 200, vae_Loss: 5.60386848449707, rec_loss: 0.9433663487434387, kl_loss : 0.010479286313056946,g_loss: 4.650022983551025,  d_loss: 0.011575628072023392
epoch: 3 ,step : 400, vae_Loss: 8.625696182250977, rec_loss: 0.8570592999458313, kl_loss : 0.5531123280525208,g_loss: 7.215524673461914,  d_loss: 0.019500596448779106
epoch: 5 ,step : 600, vae_Loss: 8.155243873596191, rec_loss: 0.7585256099700928, kl_loss : 0.27250781655311584,g_loss: 7.124210357666016,  d_loss: 0.0042810868471860886
epoch: 6 ,step : 800, vae_Loss: 5.682069778442383, rec_loss: 0.6626573801040649, kl_loss : 0.22469651699066162,g_loss: 4.794715881347656,  d_loss: 0.04316107556223869
epoch: 8 ,step : 1000, vae_Loss: 9.97937297821045, rec_loss: 0.5591983199119568, kl_loss : 0.2742924392223358,g_loss: 9.145882606506348,  d_loss: 0.01884712465107

epoch: 83 ,step : 9800, vae_Loss: 5.247716426849365, rec_loss: 0.40245071053504944, kl_loss : 0.01224815659224987,g_loss: 4.833017349243164,  d_loss: 0.06445269286632538
epoch: 85 ,step : 10000, vae_Loss: 6.458179473876953, rec_loss: 0.4178459048271179, kl_loss : 0.06452710181474686,g_loss: 5.97580623626709,  d_loss: 0.07334673404693604
epoch: 87 ,step : 10200, vae_Loss: 5.8740339279174805, rec_loss: 0.47632071375846863, kl_loss : 0.1103680431842804,g_loss: 5.2873454093933105,  d_loss: 0.11511889845132828
epoch: 88 ,step : 10400, vae_Loss: 4.697774410247803, rec_loss: 0.3999008238315582, kl_loss : 0.06318850815296173,g_loss: 4.234684944152832,  d_loss: 0.06895769387483597
epoch: 90 ,step : 10600, vae_Loss: 8.219162940979004, rec_loss: 0.4553449749946594, kl_loss : 0.231176495552063,g_loss: 7.532641410827637,  d_loss: 0.04242455214262009
epoch: 92 ,step : 10800, vae_Loss: 5.72592830657959, rec_loss: 0.46476441621780396, kl_loss : 0.08586787432432175,g_loss: 5.175295829772949,  d_loss: 0

epoch: 165 ,step : 19400, vae_Loss: 4.304136276245117, rec_loss: 0.4666629731655121, kl_loss : 0.07830537855625153,g_loss: 3.7591676712036133,  d_loss: 0.1428660899400711
epoch: 167 ,step : 19600, vae_Loss: 2.8727293014526367, rec_loss: 0.46048399806022644, kl_loss : 0.022058434784412384,g_loss: 2.3901867866516113,  d_loss: 0.24343007802963257
epoch: 169 ,step : 19800, vae_Loss: 5.960926055908203, rec_loss: 0.48017507791519165, kl_loss : 0.04512684792280197,g_loss: 5.435624122619629,  d_loss: 0.08358967304229736
epoch: 170 ,step : 20000, vae_Loss: 5.3706464767456055, rec_loss: 0.4192090332508087, kl_loss : 0.03918970003724098,g_loss: 4.912247657775879,  d_loss: 0.04487435519695282
epoch: 172 ,step : 20200, vae_Loss: 2.8197522163391113, rec_loss: 0.4767586886882782, kl_loss : 0.04539422690868378,g_loss: 2.2975993156433105,  d_loss: 0.21373814344406128
epoch: 174 ,step : 20400, vae_Loss: 3.7353079319000244, rec_loss: 0.46041369438171387, kl_loss : 0.030374033376574516,g_loss: 3.244520187

epoch: 247 ,step : 29000, vae_Loss: 3.2945451736450195, rec_loss: 0.468490868806839, kl_loss : 0.06619491428136826,g_loss: 2.759859323501587,  d_loss: 0.2666470408439636
epoch: 249 ,step : 29200, vae_Loss: 3.2544469833374023, rec_loss: 0.4748402237892151, kl_loss : 0.04532436281442642,g_loss: 2.7342824935913086,  d_loss: 0.26119133830070496
epoch: 251 ,step : 29400, vae_Loss: 2.9288268089294434, rec_loss: 0.47221872210502625, kl_loss : 0.05150020122528076,g_loss: 2.4051079750061035,  d_loss: 0.267612099647522
epoch: 252 ,step : 29600, vae_Loss: 3.7980728149414062, rec_loss: 0.4440133273601532, kl_loss : 0.018622171133756638,g_loss: 3.335437297821045,  d_loss: 0.20068037509918213
epoch: 254 ,step : 29800, vae_Loss: 3.283792018890381, rec_loss: 0.4705711603164673, kl_loss : 0.04432956874370575,g_loss: 2.7688913345336914,  d_loss: 0.2577970623970032
epoch: 256 ,step : 30000, vae_Loss: 4.363112449645996, rec_loss: 0.40404555201530457, kl_loss : 0.02529480680823326,g_loss: 3.933772325515747

epoch: 329 ,step : 38600, vae_Loss: 2.861449956893921, rec_loss: 0.4700307250022888, kl_loss : 0.04873783886432648,g_loss: 2.342681407928467,  d_loss: 0.2504655718803406
epoch: 331 ,step : 38800, vae_Loss: 3.1621603965759277, rec_loss: 0.4675562381744385, kl_loss : 0.0291846115142107,g_loss: 2.665419578552246,  d_loss: 0.24702231585979462
epoch: 333 ,step : 39000, vae_Loss: 4.500203609466553, rec_loss: 0.4795455038547516, kl_loss : 0.03786821663379669,g_loss: 3.9827897548675537,  d_loss: 0.250385046005249
epoch: 335 ,step : 39200, vae_Loss: 3.415775775909424, rec_loss: 0.47426125407218933, kl_loss : 0.022694244980812073,g_loss: 2.918820381164551,  d_loss: 0.23836266994476318
epoch: 336 ,step : 39400, vae_Loss: 2.315487861633301, rec_loss: 0.4810028374195099, kl_loss : 0.026782255619764328,g_loss: 1.8077027797698975,  d_loss: 0.3969726264476776
epoch: 338 ,step : 39600, vae_Loss: 3.262760639190674, rec_loss: 0.49381959438323975, kl_loss : 0.029885035008192062,g_loss: 2.73905611038208,  

In [None]:
# Sample 25 random latent vectors and save the decoded images as a 5x5 grid.

# Make sure the output directory exists even if the training cell was skipped.
os.makedirs("image",exist_ok=True)

# The decoder contains BatchNorm1d: switch to eval mode so it uses its running
# statistics instead of the statistics of these 25 samples, and disable
# autograd since no gradients are needed. The previous mode is restored
# afterwards so that training can be resumed unchanged.
was_training=ed_model.training
ed_model.eval()
try:
    with torch.no_grad():
        z=torch.randn(25,z_dim,device=device)
        decoder_x=ed_model.decoder(z)
finally:
    ed_model.train(was_training)

g_x=decoder_x.reshape(-1,1,image_size,image_size) # generator image
save_image(g_x,"image/test.png",normalize=True,nrow=5)