In [1]:
pip install torch torchvision numpy matplotlib flask

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader
from torch.nn.utils import spectral_norm
import torch.nn.functional as F

In [3]:
batch_size = 128
image_size = 64  # Resize Pokémon images to 64x64
nz = 100  # Latent vector size
ngf = 128  # Generator feature map size
ndf = 64  # Discriminator feature map size
nc = 3  # Number of channels (RGB)
epochs = 200
lr = 0.0002
beta1 = 0.5  # Adam optimizer beta
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

dataset_path = "/kaggle/input/pokemonclassification"  # Correct path

# Define transformations for 64x64 images
transform = transforms.Compose([
    transforms.Resize(64),  # Resize images to 64x64
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Load dataset using ImageFolder
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Check if images are loaded
print(f"Total images found: {len(dataset)}")
print(f"Class labels: {dataset.classes}")

Total images found: 6820
Class labels: ['PokemonData']


In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()  # Output in range [-1,1]
        )

    def forward(self, x):
        return self.main(x)

generator = Generator().to(device)

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.3),

            spectral_norm(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))  # No Sigmoid
        )

    def forward(self, x):
        return self.main(x)

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Initialize the Discriminator and move it to device
discriminator = Discriminator().to(device)

In [7]:
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(beta1, 0.999))

In [8]:
# ✅ Fixed noise for consistent evaluation
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# ✅ Training loop
for epoch in range(epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        ### ---- Train Discriminator ---- ###
        for _ in range(1):  
            optimizer_D.zero_grad()

            # ✅ Add Small Noise to Real Images (Regularization)
            real_images_noisy = real_images + 0.05 * torch.randn_like(real_images)  # Add noise

            # ✅ Random Label Smoothing (Prevents Overconfidence)
            labels_real = torch.rand((batch_size, 1), device=device) * 0.2 + 0.8  # Range: [0.8, 1.0]
            labels_fake = torch.rand((batch_size, 1), device=device) * 0.2  # Range: [0.0, 0.2]

            # ✅ Train on Real Data
            output_real = discriminator(real_images).view(-1, 1)
            loss_real = criterion(output_real, labels_real)

            # ✅ Train on Fake Data
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = generator(noise)
            output_fake = discriminator(fake_images.detach()).view(-1, 1)
            loss_fake = criterion(output_fake, labels_fake)

            loss_D = loss_real + loss_fake
            loss_D.backward()
            optimizer_D.step()

        ### ---- Train Generator with Feature Matching (MSE Loss) ---- ###
        optimizer_G.zero_grad()
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = generator(noise)

        # ✅ Extract Features from Discriminator (Real & Fake)
        real_features = discriminator(real_images).mean(dim=0)
        fake_features = discriminator(fake_images).mean(dim=0)

        # ✅ Use MSE Loss (Feature Matching) Instead of BCE
        loss_G = F.mse_loss(fake_features, real_features)

        loss_G.backward()
        optimizer_G.step()

    ### ---- Save Images (Using Fixed Noise) ---- ###
    if epoch % 10 == 0:
        with torch.no_grad():
            fake_images_fixed = generator(fixed_noise)  # ✅ Same noise for consistency
        vutils.save_image(fake_images_fixed, f"generated_epoch_{epoch}.png", normalize=True)

    print(f"Epoch {epoch}/{epochs} - Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

# ✅ Save Model (Better Naming)
torch.save(generator.state_dict(), "/kaggle/working/pokemon_generator.pth")

Epoch 0/200 - Loss_D: 1.1417, Loss_G: 2.6706
Epoch 1/200 - Loss_D: 1.4425, Loss_G: 2.1503
Epoch 2/200 - Loss_D: 1.5500, Loss_G: 2.2297
Epoch 3/200 - Loss_D: 1.4698, Loss_G: 3.0955
Epoch 4/200 - Loss_D: 1.4402, Loss_G: 4.5004
Epoch 5/200 - Loss_D: 1.3197, Loss_G: 0.0283
Epoch 6/200 - Loss_D: 1.2118, Loss_G: 6.6246
Epoch 7/200 - Loss_D: 1.2708, Loss_G: 6.3284
Epoch 8/200 - Loss_D: 1.5599, Loss_G: 1.1379
Epoch 9/200 - Loss_D: 2.0484, Loss_G: 4.0903
Epoch 10/200 - Loss_D: 1.1346, Loss_G: 6.1253
Epoch 11/200 - Loss_D: 1.1022, Loss_G: 6.9065
Epoch 12/200 - Loss_D: 0.9108, Loss_G: 5.3443
Epoch 13/200 - Loss_D: 1.1190, Loss_G: 2.0596
Epoch 14/200 - Loss_D: 1.3406, Loss_G: 4.9137
Epoch 15/200 - Loss_D: 1.4400, Loss_G: 2.2633
Epoch 16/200 - Loss_D: 0.9630, Loss_G: 5.6834
Epoch 17/200 - Loss_D: 1.2324, Loss_G: 3.2261
Epoch 18/200 - Loss_D: 0.8719, Loss_G: 3.0519
Epoch 19/200 - Loss_D: 1.3991, Loss_G: 4.5422
Epoch 20/200 - Loss_D: 1.1036, Loss_G: 8.8303
Epoch 21/200 - Loss_D: 1.0547, Loss_G: 4.159

In [9]:
!pip install flask flask-ngrok torch torchvision
!pip install pyngrok

Collecting flask-ngrok
  Downloading flask_ngrok-0.0.25-py3-none-any.whl.metadata (1.8 kB)
Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)
Installing collected packages: flask-ngrok
Successfully installed flask-ngrok-0.0.25
Collecting pyngrok
  Downloading pyngrok-7.2.3-py3-none-any.whl.metadata (8.7 kB)
Downloading pyngrok-7.2.3-py3-none-any.whl (23 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.3


In [10]:
!ngrok authtoken 2uMhf2tjlzaHE9zqXtVSEVh87P2_3GHhmvZ5MNhp1bnsxyokS

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml                                


In [11]:
from pyngrok import ngrok

# Open an HTTP tunnel on port 5000 (Flask's default port)
ngrok_tunnel = ngrok.connect(5000, "http")
print(f"Public URL: {ngrok_tunnel.public_url}")

Public URL: https://4c79-34-147-76-119.ngrok-free.app


In [12]:
from flask import Flask, send_file
from flask_ngrok import run_with_ngrok
import torch
import torchvision.utils as vutils
import torch.nn as nn
import os

# Define Generator Model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()  # Output in range [-1,1]
        )

    def forward(self, x):
        return self.main(x)

# Initialize Flask app
app = Flask(__name__)
run_with_ngrok(app)  # Enable ngrok for Colab

# Load trained Generator
nz = 100
device = torch.device("cpu")
generator = Generator()
generator.load_state_dict(torch.load("/kaggle/working/pokemon_generator.pth", map_location=device))
generator.eval()

@app.route('/')
def home():
    return '''
        <h1>Pokémon GAN</h1>
        <img src="/generate" width="256">
        <br><br>
        <button onclick="location.reload();">Generate Again</button>
    '''

@app.route('/generate')
def generate_image():
    noise = torch.randn(1, nz, 1, 1, device=device)
    with torch.no_grad():
        fake_image = generator(noise)
    vutils.save_image(fake_image, "static/generated.png", normalize=True)
    return send_file("static/generated.png", mimetype='image/png')

# Run Flask App
if __name__ == '__main__':
    if not os.path.exists("static"):
        os.makedirs("static")
    app.run()

 * Serving Flask app '__main__'
 * Debug mode: off


  generator.load_state_dict(torch.load("/kaggle/working/pokemon_generator.pth", map_location=device))


 * Running on http://4c79-34-147-76-119.ngrok-free.app
 * Traffic stats available on http://127.0.0.1:4040
