<a href="https://colab.research.google.com/github/yastiaisyah/DataSynthesis/blob/main/adversarial_autoencoder_realfake_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import torch
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np


# Autentikasi di Google Colab
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Definisikan transformasi yang ingin Anda terapkan pada gambar
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Buat custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, drive, folder_id, transform=None):
        self.drive = drive
        self.folder_id = folder_id
        self.file_list = self.get_file_list()
        self.transform = transform

    def get_file_list(self):
        file_list = self.drive.ListFile({'q': "'{}' in parents".format(self.folder_id)}).GetList()
        return file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file = self.file_list[idx]
        img = self.load_image(file)
        if self.transform:
            img = self.transform(img)
        return img

    @staticmethod
    def load_image(file):
        img = drive.CreateFile({'id': file['id']})
        img.GetContentFile(file['title'])
        img = Image.open(file['title'])
        return img

# Ganti 'YOUR_FOLDER_ID' dengan ID folder Google Drive yang sesuai
folder_id = '1gkvKPLPtyTlUUdLVZ0yOiL0Vga-enW6n'

# Gunakan custom dataset untuk mengakses gambar-gambar
google_drive_dataset = CustomImageDataset(drive, folder_id, transform=transform)

# Tentukan ukuran batch (mb_size) sesuai dengan preferensi Anda
mb_size = 32  # Misalnya, gunakan ukuran batch 32
z_dim = 5
h_dim = 128
X_dim = 28 * 28  # Untuk gambar berukuran 28x28 piksel
lr = 1e-3

# DataLoader untuk dataset
data_loader = DataLoader(google_drive_dataset, batch_size=mb_size, shuffle=True)

# Lanjutkan dengan definisi model dan pelatihan sesuai yang telah Anda lakukan sebelumnya.

# Encoder
Q = nn.Sequential(
    nn.Linear(X_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, z_dim)
)

# Decoder
P = nn.Sequential(
    nn.Linear(z_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, X_dim),
    nn.Sigmoid()
)

# Discriminator
D = nn.Sequential(
    nn.Linear(z_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, 1),
    nn.Sigmoid()
)

def reset_grad():
    Q.zero_grad()
    P.zero_grad()
    D.zero_grad()

Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)

cnt = 0

for it in range(100000):
    for X in data_loader:
        """ Reconstruction phase """
        z_sample = Q(X.view(mb_size, 1, 28, 28))
        X_sample = P(z_sample)

        # Clip values to be within [0, 1]
        X_sample = X_sample.clamp(0, 1)

        # Use BCELoss for binary cross entropy
        recon_loss = nn.BCELoss()(X_sample, X.view(-1, X_dim))

        recon_loss.backward()
        P_solver.step()
        Q_solver.step()
        reset_grad()

        """ Regularization phase """
        # Discriminator
        z_real = torch.randn(mb_size, z_dim)
        z_fake = Q(X.view(-1, X_dim))

        D_real = D(z_real)
        D_fake = D(z_fake)

        D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

        D_loss.backward()
        D_solver.step()
        reset_grad()

        # Generator
        z_fake = Q(X.view(-1, X_dim))
        D_fake = D(z_fake)

        G_loss = -torch.mean(torch.log(D_fake))

        G_loss.backward()
        Q_solver.step()
        reset_grad()

        # Print and plot every now and then
        if it % 1000 == 0:
            print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
                  .format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

            samples = P(z_real).data.numpy()[:16]

            fig = plt.figure(figsize=(4, 4))
            gs = gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                ax = plt.subplot(gs[i])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

            if not os.path.exists('out/'):
                os.makedirs('out/')

            plt.savefig('out/{}.png'
                        .format(str(cnt).zfill(3)), bbox_inches='tight')
            cnt += 1
            plt.close(fig)


RuntimeError: ignored