<a href="https://colab.research.google.com/github/tenten0727/dcgan_pokemon_pytorch/blob/main/dcgan_pokemon_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so the Pokémon images stored under "My Drive/pokemon" are readable.
from google.colab import drive

MOUNT_POINT = '/content/gdrive'
drive.mount(MOUNT_POINT)

Mounted at /content/gdrive


In [2]:
import torch
from torch import nn, optim
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, datasets
import tqdm
from statistics import mean
import os, glob
from PIL import Image

In [3]:
# Disabled alternative data source: extract a sample-data archive from Drive.
# tar.extractall() with no path argument extracts into the current working directory.
# NOTE(review): extractall on an untrusted archive can write outside the target dir;
# only re-enable for archives you trust. The dataset cell below reads images from Drive.
# import tarfile
# with tarfile.open('/content/gdrive/My Drive/sample-data.tar.gz', 'r:gz') as tar:tar.extractall()

In [4]:
def Transformimage(input_path, out_path, flag_delete_original_files):
  """Convert every ``*.png`` directly under ``input_path`` to an RGB JPEG in ``out_path``.

  Transparent pixels are flattened onto a white background before the RGB
  conversion, because JPEG has no alpha channel.

  Args:
    input_path: directory searched (non-recursively) for ``.png`` files.
    out_path: directory the ``<name>.jpg`` files are written to.
    flag_delete_original_files: if true, delete each source ``.png`` after its
      ``.jpg`` has been saved.
  """
  # glob.escape: a directory name containing "[", "*" or "?" must be matched
  # literally, not interpreted as a glob pattern.
  filepath_list = glob.glob(os.path.join(glob.escape(input_path), '*.png'))
  for filepath in filepath_list:
      stem = os.path.splitext(os.path.basename(filepath))[0]
      save_filepath = os.path.join(out_path, stem + '.jpg')

      # `with` closes the source file handle before the PNG is (optionally) deleted.
      with Image.open(filepath) as img:
          rgba = img.convert('RGBA')

      # Paint the image onto an opaque white canvas, using its alpha channel as mask,
      # so transparent regions become white.
      new = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
      new.paste(rgba, mask=rgba.split()[-1])

      # RGBA -> RGB: JPEG cannot store an alpha channel.
      new = new.convert('RGB')
      new.save(save_filepath, "JPEG", quality=95)
      print(filepath, '->', save_filepath)
      if flag_delete_original_files:
          os.remove(filepath)
          print('delete', filepath)

# Convert the Pokémon PNGs in place (JPEGs land in the same folder) and delete the
# PNG originals afterwards. Not idempotent: a second run finds no PNGs to convert.
input_path = "/content/gdrive/My Drive/pokemon/images"
out_path, flag_delete_original_files = input_path, True
Transformimage(input_path, out_path, flag_delete_original_files=flag_delete_original_files)

In [5]:
# Training set: every image under the Drive "pokemon" folder. ImageFolder treats each
# sub-directory as a class; the GAN ignores those labels.
# NOTE(review): ToTensor yields pixels in [0.0, 1.0] while Generator ends in Tanh
# ([-1, 1]). DCGAN setups usually add transforms.Normalize((0.5,) * 3, (0.5,) * 3);
# if you add it, rescale generated images before save_image as well — confirm first.
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),  # the Generator/Discriminator pair works on 64x64 images
    transforms.ToTensor(),        # PIL image -> float tensor with values in [0.0, 1.0]
])
dataset = datasets.ImageFolder("/content/gdrive/My Drive/pokemon", transform=preprocess)

batch_size = 32
data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
print(dataset)

Dataset ImageFolder
    Number of datapoints: 809
    Root location: /content/gdrive/My Drive/pokemon
    StandardTransform
Transform: Compose(
               Resize(size=(64, 64), interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )


In [6]:
class Generator(nn.Module):
  """DCGAN generator: maps latent noise of shape (N, 100, 1, 1) to images (N, 3, 64, 64).

  The first transposed conv turns 1x1 into 4x4; each following one doubles the
  spatial size (4 -> 8 -> 16 -> 32 -> 64). The final Tanh bounds outputs to [-1, 1].
  """

  def __init__(self):
    super().__init__()
    # Projection from the 100-channel latent vector to a 256x4x4 feature map.
    layers = [
        nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
    ]
    # Upsampling stages: halve the channels, double height/width.
    for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
      layers += [
          nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
          nn.BatchNorm2d(c_out),
          nn.ReLU(inplace=True),
      ]
    # Output stage: RGB image, no BatchNorm.
    layers += [nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False), nn.Tanh()]
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    return self.main(x)

In [7]:
class Discriminator(nn.Module):
    """DCGAN discriminator: one real/fake logit per image in a batch of RGB images.

    Inputs are assumed to be 3x64x64, so the stride-2 convs give 64 -> 32 -> 16 -> 8 -> 4
    and the final 4x4 conv (no padding) gives 1x1. There is no sigmoid at the end; the
    raw logit is meant for a logits-based loss such as BCEWithLogitsLoss.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        """Return logits of shape (N,) for an input batch of shape (N, 3, 64, 64)."""
        # view(-1) rather than squeeze(): squeeze() also drops the batch dimension when
        # N == 1, producing a 0-d tensor that no longer matches a length-1 label vector.
        return self.main(x).view(-1)

In [8]:
# Networks on the GPU, each with its own Adam optimizer (lr=2e-4, beta1=0.5).
# Despite the names, params_G / params_D hold optimizer objects, not parameter lists.
model_G = Generator().to("cuda:0")
model_D = Discriminator().to("cuda:0")
params_G, params_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (model_G, model_D)
)

# Length of the latent vector z; must equal Generator's first in_channels (100).
nz = 100

# Real (1) / fake (0) target vectors of length batch_size.
ones = torch.ones(batch_size, device="cuda:0")
zeros = torch.zeros(batch_size, device="cuda:0")

# Logits-based binary cross-entropy: Discriminator outputs raw scores.
loss_f = nn.BCEWithLogitsLoss()

# Fixed noise, sampled once, for generating comparable sample grids over time.
check_z = torch.randn(batch_size, nz, 1, 1).to("cuda:0")

def train_dcgan(model_G, model_D, params_G, params_D, data_loader):
  """Run one epoch of DCGAN training and return the mean losses.

  Each batch first updates the generator (its fakes are labelled "real" for the
  discriminator), then updates the discriminator on the real batch and on the same
  fake batch, detached so no gradient reaches the generator.

  Args:
    model_G, model_D: generator / discriminator modules, assumed to be on the same
      device; the device is taken from model_G's parameters.
    params_G, params_D: optimizers for model_G and model_D.
    data_loader: yields (images, labels) batches; labels are ignored.

  Returns:
    (mean generator loss, mean discriminator loss) over the epoch.

  Reads the module-level ``nz`` (latent size) and ``loss_f`` (logits BCE loss).
  """
  device = next(model_G.parameters()).device
  log_loss_G = []
  log_loss_D = []

  for real_img, _ in tqdm.tqdm(data_loader):
    batch_len = len(real_img)

    # Noise is drawn with the CPU RNG and then moved, keeping the random stream unchanged.
    z = torch.randn(batch_len, nz, 1, 1).to(device)
    fake_img = model_G(z)

    # --- Generator step: push D to classify the fakes as real. ---
    out = model_D(fake_img)
    # Targets are built to match each output's shape, so any batch size works
    # (no dependence on a pre-sized global label vector).
    loss_G = loss_f(out, torch.ones_like(out))
    log_loss_G.append(loss_G.item())

    model_D.zero_grad()
    model_G.zero_grad()
    loss_G.backward()
    params_G.step()

    # --- Discriminator step: real -> 1, fake -> 0. ---
    real_out = model_D(real_img.to(device))
    loss_D_real = loss_f(real_out, torch.ones_like(real_out))

    # detach(): the discriminator loss must not backpropagate into the generator.
    fake_out = model_D(fake_img.detach())
    loss_D_fake = loss_f(fake_out, torch.zeros_like(fake_out))

    loss_D = loss_D_real + loss_D_fake
    log_loss_D.append(loss_D.item())

    model_D.zero_grad()
    model_G.zero_grad()
    loss_D.backward()
    params_D.step()

  return mean(log_loss_G), mean(log_loss_D)

In [9]:
# Output folders for checkpoints and sample grids. exist_ok=True makes the cell safe
# to re-run; `!mkdir` errors if the folder already exists.
os.makedirs("Weight_Generator", exist_ok=True)
os.makedirs("Generated_Image", exist_ok=True)

In [None]:
# Main training loop: one train_dcgan call = one epoch. Every 10 epochs, log the mean
# losses, checkpoint both networks, and save a grid generated from the fixed `check_z`.
# NOTE(review): at roughly 2 s/epoch (see the tqdm output) 100000 epochs takes days;
# lower it to fit the runtime you have.
os.makedirs("Weight_Generator", exist_ok=True)
os.makedirs("Generated_Image", exist_ok=True)

for epoch in range(100000):
  loss_G_mean, loss_D_mean = train_dcgan(model_G, model_D, params_G, params_D, data_loader)

  if epoch % 10 == 0:
    print("epoch {}: loss_G={:.4f} loss_D={:.4f}".format(epoch, loss_G_mean, loss_D_mean))
    torch.save(
        model_G.state_dict(),
        "Weight_Generator/G_{:03d}.pth".format(epoch),
        pickle_protocol=4
    )
    torch.save(
        model_D.state_dict(),
        "Weight_Generator/D_{:03d}.pth".format(epoch),
        pickle_protocol=4
    )

    # no_grad: sampling for visualisation must not build an autograd graph.
    with torch.no_grad():
      generated_img = model_G(check_z)
    save_image(generated_img, "Generated_Image/{:03d}.jpg".format(epoch))

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
100%|██████████| 26/26 [00:01<00:00, 14.05it/s]
100%|██████████| 26/26 [00:01<00:00, 14.28it/s]
100%|██████████| 26/26 [00:01<00:00, 14.11it/s]
100%|██████████| 26/26 [00:01<00:00, 13.90it/s]
100%|██████████| 26/26 [00:01<00:00, 14.00it/s]
100%|██████████| 26/26 [00:01<00:00, 14.30it/s]
100%|██████████| 26/26 [00:01<00:00, 13.59it/s]
100%|██████████| 26/26 [00:01<00:00, 13.32it/s]
100%|██████████| 26/26 [00:01<00:00, 13.79it/s]
100%|██████████| 26/26 [00:01<00:00, 14.44it/s]
100%|██████████| 26/26 [00:01<00:00, 14.34it/s]
100%|██████████| 26/26 [00:01<00:00, 14.55it/s]
100%|██████████| 26/26 [00:01<00:00, 13.96it/s]
100%|██████████| 26/26 [00:01<00:00, 13.60it/s]
100%|██████████| 26/26 [00:01<00:00, 13.34it/s]
100%|██████████| 26/26 [00:01<00:00, 13.93it/s]
100%|██████████| 26/26 [00:01<00:00, 14.06it/s]
100%|██████████| 26/26 [00:01<00:00, 14.60it/s]
100%|██████████| 26/26 [00:01<00:00, 14.31it/s]
100%|██████████| 26/26 [00:01<00:00, 14.40

In [None]:
# Show the sample grid saved every 10000 epochs. The training loop names files with
# "{:03d}.jpg" (e.g. "000.jpg", "10000.jpg"), so the same format is used here.
# IPython's Image is imported under another name so it does not shadow PIL's Image,
# which Transformimage uses.
import os
from IPython.display import Image as IPyImage, display

for i in range(11):
  epoch = i * 10000
  image_path = os.path.join("Generated_Image", "{:03d}.jpg".format(epoch))
  print('epoch' + str(epoch) + ':')
  if os.path.exists(image_path):
    display(IPyImage(filename=image_path))
  else:
    # e.g. epoch 100000 is never reached by range(100000).
    print('  not found:', image_path)