In [175]:
from typing import List
import numpy as np
import os
from tqdm.auto import tqdm
from pathlib import Path

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
from torchvision.utils import save_image

In [176]:
class args:
    gpus = "0"
    # dataset
    n_epochs = 200000
    batch_size = 64
    img_shape = [1,32,32]
    
    n_latent = 100
    
    G_lr = 0.0002
    G_betas = (0.5, 0.999)
    D_lr = 0.0002
    D_betas = (0.5, 0.999)
    
    gene_img_dir = Path("./generated_images")

if not os.path.isdir(args.gene_img_dir):
    os.makedirs(args.gene_img_dir)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [177]:
transform = T.Compose([
    T.ToTensor()
])
train_dataset = torchvision.datasets.MNIST(root="/data", download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, shuffle=True)

In [237]:
def init_weight(layer):
    cls_name = layer.__class__.__name__
    if cls_name.find("Conv") != -1:
        torch.nn.init.normal_(layer.weight.data, 0.0, 0.02)
    elif cls_name.find("BatchNorm") != -1:
        torch.nn.init.normal_(layer.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(layer.bias.data, 0.0)
        

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = 28 // 4
        self.l1 = nn.Sequential(nn.Linear(100, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.ConvTranspose2d(128, 128, 2, 2),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 2, 2),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = 32 // 2 ** 4
        self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [238]:
# def init_weight(layer):
#     cls_name = layer.__class__.__name__
#     if cls_name.find("Conv") != -1:
#         torch.nn.init.normal_(layer.weight.data, 0.0, 0.02)
#     elif cls_name.find("BatchNorm") != -1:
#         torch.nn.init.normal_(layer.weight.data, 1.0, 0.02)
#         torch.nn.init.constant_(layer.bias.data, 0.0)
        

# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()

#         self.init_size = 28 // 4
#         self.l1 = nn.Sequential(nn.Linear(100, 128 * self.init_size ** 2))

#         self.conv_blocks = nn.Sequential(
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(128, 128, 3, stride=1, padding=1),
#             nn.BatchNorm2d(128, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(128, 64, 3, stride=1, padding=1),
#             nn.BatchNorm2d(64, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(64, 1, 3, stride=1, padding=1),
#             nn.Tanh(),
#         )

#     def forward(self, z):
#         out = self.l1(z)
#         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
#         img = self.conv_blocks(out)
#         return img


# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()

#         def discriminator_block(in_filters, out_filters, bn=True):
#             block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
#             if bn:
#                 block.append(nn.BatchNorm2d(out_filters, 0.8))
#             return block

#         self.model = nn.Sequential(
#             *discriminator_block(1, 16, bn=False),
#             *discriminator_block(16, 32),
#             *discriminator_block(32, 64),
#             *discriminator_block(64, 128),
#         )

#         # The height and width of downsampled image
#         ds_size = 32 // 2 ** 4
#         self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)

#     def forward(self, img):
#         out = self.model(img)
#         out = out.view(out.shape[0], -1)
#         validity = self.adv_layer(out)

#         return validity

In [239]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(init_weight)
discriminator.apply(init_weight)                                         

Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout2d(p=0.25, inplace=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Dropout2d(p=0.25, inplace=False)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Dropout2d(p=0.25, inplace=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): Dropout2d(p=0.25, inplace=False)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_laye

In [240]:
# generator = Generator(args.n_latent, args.img_shape).to(device)
# discriminator = Discriminator(args.img_shape).to(device)
# generator.apply(init_weight)
# discriminator.apply(init_weight)                                         

In [241]:
criterion_MSE = torch.nn.MSELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr = args.G_lr, betas=args.G_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = args.D_lr, betas=args.D_betas)

In [None]:
ones_label = Variable(torch.ones(args.batch_size), requires_grad=False).to(device)
zeros_label = Variable(torch.zeros(args.batch_size), requires_grad=False).to(device)
for epoch in range(args.n_epochs):
    train_loop = tqdm(train_loader, total=len(train_loader), desc="training", colour="blue", leave=False)
    G_loss_sum = 0
    D_loss_sum = 0
    for img, label in train_loop:
        img = img.to(device)
        label = label.to(device)
        latent_z = torch.randn(args.batch_size, args.n_latent).to(device)
        
        
        
        gene_img = generator(latent_z)
        # training D
        
        gene_logit = discriminator(gene_img.detach())
        real_logit = discriminator(img)
        D_loss = (criterion_MSE(gene_logit, zeros_label) + criterion_MSE(real_logit, ones_label)) / 2
        optimizer_D.zero_grad()
        D_loss.backward()
        optimizer_D.step()    
        
        # training G
        gene_logit = discriminator(gene_img)
        G_loss = criterion_MSE(gene_logit, ones_label)
        optimizer_G.zero_grad()
        G_loss.backward()
        optimizer_G.step()
        
        D_loss_sum += D_loss.item()
        G_loss_sum += G_loss.item()
    print(f"D loss : {D_loss_sum/len(train_loader)}, G loss : {G_loss_sum/len(train_loader)}")
    save_image(gene_img, args.gene_img_dir / f"{epoch}.png")

training:   0%|          | 0/938 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


D loss : 0.2548767401656108, G loss : 0.2754301617521721


  return F.mse_loss(input, target, reduction=self.reduction)


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.24846034824276275, G loss : 0.25792739533984077


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.19520263977126398, G loss : 0.3774784980520511


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.12004356870034547, G loss : 0.5882726635045207


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.10924712336187296, G loss : 0.6260763279347024


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.1035724109344518, G loss : 0.6523555263297073


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.10156750957320716, G loss : 0.6593104291683448


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.10393557656627855, G loss : 0.6675183930551447


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.09704415490036644, G loss : 0.6893338115294096


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.09039472842882913, G loss : 0.6948069869153408


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.08651084705555775, G loss : 0.7340499018944466


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.08199972086953425, G loss : 0.7452040328177562


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.08217734487743171, G loss : 0.7304330029086009


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.07657982344264543, G loss : 0.7540727558611299


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.07425184233058522, G loss : 0.759927102211696


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0723334171884715, G loss : 0.7753521281916068


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06889812219272386, G loss : 0.7904512759274257


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.07041584572165029, G loss : 0.7876349793695437


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06798102589510778, G loss : 0.7786724237379616


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06693907839649205, G loss : 0.7806868094609363


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06373399849295981, G loss : 0.7979178836366643


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06242899356314591, G loss : 0.8073522317955997


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06166772477811318, G loss : 0.7960328769677484


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.06000169735503143, G loss : 0.8084767268760118


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.058949252929606004, G loss : 0.8209490404565578


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05692683499225024, G loss : 0.8190853446166018


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05788255867852109, G loss : 0.8157173424943297


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05944624750106447, G loss : 0.8310795330575534


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05782779618434465, G loss : 0.8113566004931291


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05572336284753515, G loss : 0.8289882734672093


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.058443632901215285, G loss : 0.8232559271966979


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05719145860022573, G loss : 0.8169176541030534


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05700576088494504, G loss : 0.8148581352092819


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05616404821794393, G loss : 0.8185458779176161


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05717554913440556, G loss : 0.8064308121546245


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.055772483530047254, G loss : 0.8380120582640298


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05331745280373865, G loss : 0.8453755307076837


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.053509198911071046, G loss : 0.8439054444813525


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05426721309007866, G loss : 0.8631604775023867


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05329041998597112, G loss : 0.8447759423746484


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05658038514060602, G loss : 0.8357376894399301


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05541432649592188, G loss : 0.8387295622180011


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05246981686261147, G loss : 0.8217150845221365


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05351537037372732, G loss : 0.8284258075963968


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.052105151357621685, G loss : 0.8479840109374985


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04986067889863923, G loss : 0.8515614735038042


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05367344218010365, G loss : 0.8367699625522598


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05095184339420882, G loss : 0.8374747069500911


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05157300139239221, G loss : 0.8475546865448006


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05031208425406804, G loss : 0.8446192913758221


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05108637534040235, G loss : 0.8458055508162167


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05105935571306169, G loss : 0.8496524898736462


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05448625025166663, G loss : 0.8359217155875682


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.051625930234543614, G loss : 0.8530579773919669


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05233369342017689, G loss : 0.8524382417834898


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04790974814228332, G loss : 0.8401618502668734


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.05152942124319308, G loss : 0.8675709655607687


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.050831927198058825, G loss : 0.8581161646923022


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.051228559192722796, G loss : 0.8611132648866823


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.048354634687975684, G loss : 0.8650221566338021


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0515361698652143, G loss : 0.8472372241524745


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04788861234686268, G loss : 0.8561012680882584


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04777290269400853, G loss : 0.8482516003029941


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04829826407970539, G loss : 0.8635279124082406


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.050174652236893075, G loss : 0.8508822877110958


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04869271547688461, G loss : 0.8653896256391682


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.049734924044182054, G loss : 0.8557035548568789


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.048742689827242607, G loss : 0.8561412321765032


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.047186105067520016, G loss : 0.8490888413462812


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0460100985824196, G loss : 0.8700176972284246


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0481804861267135, G loss : 0.855041694189948


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04637699531240345, G loss : 0.8571369243519646


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0453581796382675, G loss : 0.8726619455192898


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.047532150182805495, G loss : 0.8647770890072465


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04679716044196935, G loss : 0.8626959123400483


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04731313382629861, G loss : 0.8501262768848873


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.045855843454603355, G loss : 0.8723018104746652


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0453159734706031, G loss : 0.8730006886602465


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.047992121629309714, G loss : 0.8590425019865351


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04416026110385558, G loss : 0.8736723503991485


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04599053800793122, G loss : 0.8572896196008492


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.046583950890402896, G loss : 0.8647703102656773


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.047119086701869727, G loss : 0.8694536342962719


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.044938098084725646, G loss : 0.8721412222967473


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04818318235173599, G loss : 0.8762650362559473


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.044647891619759024, G loss : 0.8719679747880903


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.042484791129767926, G loss : 0.8755897855612514


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0466445396318714, G loss : 0.8546743776117053


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04465179303948924, G loss : 0.8603958151043097


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.044621990732739425, G loss : 0.8723435702163782


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.043072139306404966, G loss : 0.8684308447444172


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04432313229694073, G loss : 0.8782986388849551


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.046436981318631707, G loss : 0.8643762638796367


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04476962809531348, G loss : 0.8720225349608769


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04347610398347595, G loss : 0.8570202629068004


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.043303452395170784, G loss : 0.8795048577000083


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0430431447809995, G loss : 0.87233508171748


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04265166578793537, G loss : 0.8822567569358009


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04318044901634854, G loss : 0.8757295657449694


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04335283423770966, G loss : 0.891931964326769


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04373251299571389, G loss : 0.8898919561206659


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04456288988482572, G loss : 0.8690432117723706


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04398768740683309, G loss : 0.8754283905601197


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04123902982939829, G loss : 0.8748710447314706


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04374637747921153, G loss : 0.8755172926829313


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.043776233491650635, G loss : 0.8630438067638544


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.042644704396435335, G loss : 0.8827452541890938


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04390718812780228, G loss : 0.8741201466874782


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04365141158660393, G loss : 0.8807947217687361


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0421469750591238, G loss : 0.8770315943083276


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04007066000125873, G loss : 0.8768771575458014


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04339875365773252, G loss : 0.88825063261269


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.041826378723392, G loss : 0.892818072711481


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.043231782422754474, G loss : 0.8793695075298423


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.042971184178849244, G loss : 0.8801295487547734


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03905891649189717, G loss : 0.8959317251817503


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04085655095475688, G loss : 0.8867467377168029


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.040540635939498845, G loss : 0.8912142180780104


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04117116014431439, G loss : 0.8836476546900867


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04008707181731267, G loss : 0.888794231039883


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0399413523346142, G loss : 0.8795560131798675


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03946670047146505, G loss : 0.8870578984271235


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.039681026532844915, G loss : 0.899814143999299


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.039466371683072224, G loss : 0.8722250255694521


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03932869893142275, G loss : 0.8951900679546633


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03810744727306417, G loss : 0.8903060097144102


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.040385640015317276, G loss : 0.8961881731491862


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.038763537408728434, G loss : 0.8962764559365285


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.042561219974871736, G loss : 0.8916835057011037


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03893022218903801, G loss : 0.8987530579349634


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03674107347292377, G loss : 0.8894476865146206


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03815668310745677, G loss : 0.8894137536475399


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0389444741749588, G loss : 0.8745680264254877


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0382157053941412, G loss : 0.9022100523574901


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.04009699832840658, G loss : 0.8926267740505336


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03956527667897525, G loss : 0.891127071019683


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0398116111204322, G loss : 0.8903103690983644


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03565814835888181, G loss : 0.8916453186001605


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.035892169462638034, G loss : 0.8985501319996075


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.037463121719088464, G loss : 0.8993509353986427


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03688647264072588, G loss : 0.8878439699631255


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03854637578708618, G loss : 0.8929253439150894


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03499127585125138, G loss : 0.8996278518584492


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03706757483883763, G loss : 0.8916297059323488


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0363108983696766, G loss : 0.9027238688997622


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03552086943778386, G loss : 0.883611305007167


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03627540291855926, G loss : 0.9040832402291836


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03715930584652313, G loss : 0.9131189224594183


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03558871188937728, G loss : 0.9041398208437443


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03697334816465889, G loss : 0.8961977529119073


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.0350610574417964, G loss : 0.9046230367315349


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.038941280717508735, G loss : 0.8960708206904723


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03495431252008578, G loss : 0.9138467566354442


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.037623039126566955, G loss : 0.8903053749694245


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03631461377534618, G loss : 0.9084618569437121


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03535606216295247, G loss : 0.8977132154298998


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.037514152546683305, G loss : 0.9186029601167006


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03228838552163838, G loss : 0.9100233900871104


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.037461588638765156, G loss : 0.8962934230055127


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.033706993864936585, G loss : 0.89539160774842


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03497768607489995, G loss : 0.9176767698483173


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03426070070973457, G loss : 0.8943332000645493


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.032680979660297156, G loss : 0.8988845805560094


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03413116969942633, G loss : 0.900758332694009


training:   0%|          | 0/938 [00:00<?, ?it/s]

D loss : 0.03372885089894848, G loss : 0.9128478422824507


training:   0%|          | 0/938 [00:00<?, ?it/s]

In [235]:
gene_img.shape

torch.Size([64, 1, 14, 14])

In [236]:
from sympy import Derivative, symbols