In [1]:
import torch
import torch.nn as nn
import torch.optim as optim 
import torchvision.datasets as datasets
import torchvision.transforms as transforms 
from torch.utils.data import DataLoader

In [8]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """WGAN critic: maps an image batch to one raw score per image.

    Four stride-2 blocks (bias-free Conv2d + LeakyReLU(0.2)) widen the channels
    from features_d to features_d * 8 while halving the spatial size; a final
    Conv2d projects to a single channel. For 64x64 inputs the output shape is
    (N, 1, 1, 1). No output activation is applied -- the training loop works
    directly on the raw scores.

    Args:
        channels_img: number of channels in the input images.
        features_d: channel width of the first block.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [channels_img, features_d, features_d * 2, features_d * 4, features_d * 8]
        # Registers conv1..conv4 in this order (setattr on an nn.Module registers
        # the submodule exactly like a plain attribute assignment would).
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{index}", self._block(c_in, c_out, 4, 2, 1))
        # Final projection to one score channel (keeps the default bias).
        self.conv5 = nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free Conv2d followed by LeakyReLU with negative slope 0.2."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run x through conv1..conv5 in order and return the critic scores."""
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = layer(x)
        return x

In [9]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> images with values in [-1, 1].

    The training code samples noise shaped (N, channels_noise, 1, 1). Four
    ConvTranspose2d + ReLU blocks upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32
    while narrowing channels from features_g * 16 to features_g * 2; a final
    ConvTranspose2d emits channels_img channels at 64x64 and Tanh squashes the
    values into [-1, 1].

    Args:
        channels_noise: number of latent channels in the input noise.
        channels_img: number of channels in the generated images.
        features_g: base channel width (the first block uses features_g * 16).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for conv1..conv4; kernel is 4.
        plan = [
            (channels_noise, features_g * 16, 1, 0),   # img: 4x4
            (features_g * 16, features_g * 8, 2, 1),   # img: 8x8
            (features_g * 8, features_g * 4, 2, 1),    # img: 16x16
            (features_g * 4, features_g * 2, 2, 1),    # img: 32x32
        ]
        for index, (c_in, c_out, stride, padding) in enumerate(plan, start=1):
            setattr(self, f"conv{index}", self._block(c_in, c_out, 4, stride, padding))
        # img: 64x64, keeps the default bias
        self.conv5 = nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )
        self.activation = nn.Tanh()

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free ConvTranspose2d followed by ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.ReLU())

    def forward(self, x):
        """Upsample x through conv1..conv5, then apply Tanh."""
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = layer(x)
        return self.activation(x)

In [10]:
def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialize a model's conv and BatchNorm weights in place (DCGAN recipe).

    Conv2d / ConvTranspose2d weights are drawn from N(mean, std). BatchNorm2d
    scales (weight) are drawn from N(1.0, std) and shifts (bias) are zeroed, so
    normalization layers start near the identity -- drawing the scale around 0
    would make them output almost nothing. Conv biases are left untouched.

    Args:
        model: any nn.Module; every submodule is visited via model.modules().
        mean: mean of the normal distribution for conv weights.
        std: standard deviation for conv weights and BatchNorm scales.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init functions already run under no_grad; no need for .data.
            nn.init.normal_(m.weight, mean, std)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight/bias set to None.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [17]:
# Checkpoint files: generator weights plus the last critic / generator loss.
generator_path, critic_loss_path, gen_loss_path = (
    "generator.pth",
    "critic_loss.pth",
    "gen_loss.pth",
)


In [18]:
# Save the current training state. `gen`, `loss_critic` and `loss_gen` are only
# created by the training cell further down, so on a fresh kernel (Restart & Run
# All) they do not exist yet -- skip instead of raising NameError. The training
# loop also writes these same files at the end of every epoch.
if all(name in globals() for name in ("gen", "loss_critic", "loss_gen")):
    torch.save(gen.state_dict(), generator_path)
    torch.save(loss_critic.item(), critic_loss_path)
    torch.save(loss_gen.item(), gen_loss_path)
else:
    print("Nothing to save yet: run the training cell first.")


In [20]:
# Check if saved files exist and load them.
# `gen` is only created by the training cell further down, so on a fresh kernel
# there is no model to load the generator weights into yet; skip it instead of
# raising NameError.
import os

if os.path.exists(generator_path) and "gen" in globals():
    # map_location="cpu" lets a checkpoint saved from a GPU run be read on a
    # CPU-only machine; load_state_dict then copies the tensors onto gen's device.
    gen.load_state_dict(torch.load(generator_path, map_location="cpu"))
if os.path.exists(critic_loss_path):
    critic_loss = torch.load(critic_loss_path)
if os.path.exists(gen_loss_path):
    gen_loss = torch.load(gen_loss_path)


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import os


# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64  # images are resized and center-cropped to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 3
Z_DIM = 128  # number of latent channels in the generator's input noise
NUM_EPOCHS = 20
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
WEIGHT_CLIP = 0.01  # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# Seed torch's RNGs (CPU and CUDA) so weight init, noise sampling and DataLoader
# shuffling repeat across runs; some cuDNN kernels may still be nondeterministic.
torch.manual_seed(SEED)



# Preprocessing: resize the shorter side to IMAGE_SIZE, center-crop to a square,
# convert to a [0, 1] tensor, then map every channel to [-1, 1] so real images
# share the range of the generator's Tanh output.
# Named `transform` (singular) so it does not shadow the torchvision
# `transforms` module for later cells.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

# DatasetFolder treats each sub-directory of flower_data/ as a class; the
# training loop ignores the labels.
dataset = datasets.DatasetFolder(
    root="flower_data/",
    loader=torchvision.datasets.folder.default_loader,
    extensions=('.jpg', '.jpeg', '.png', '.bmp'),
    transform=transform,
)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Networks: build generator and critic on `device`, then re-draw their weights
# with initialize_weights (generator first, then critic).
gen, critic = (
    Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device),
    Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device),
)
initialize_weights(gen)
initialize_weights(critic)

# One RMSprop optimizer per network, both at LEARNING_RATE.
opt_gen, opt_critic = (
    optim.RMSprop(gen.parameters(), lr=LEARNING_RATE),
    optim.RMSprop(critic.parameters(), lr=LEARNING_RATE),
)

# TensorBoard plotting: a fixed batch of 32 latent vectors, one writer per
# image stream, and the global step counter used when logging.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real, writer_fake = SummaryWriter("logs/real"), SummaryWriter("logs/fake")
step = 0

# Resume from checkpoints written by a previous run, if they exist. This runs
# after initialize_weights, so a loaded checkpoint overrides the fresh init.
generator_path = "generator.pth"
critic_loss_path = "critic_loss.pth"
gen_loss_path = "gen_loss.pth"

if os.path.exists(generator_path):
    # map_location="cpu" lets a checkpoint saved from a GPU run be read on a
    # CPU-only machine; load_state_dict then copies the tensors onto gen's device.
    gen.load_state_dict(torch.load(generator_path, map_location="cpu"))
if os.path.exists(critic_loss_path):
    critic_loss = torch.load(critic_loss_path)
if os.path.exists(gen_loss_path):
    gen_loss = torch.load(gen_loss_path)
# NOTE(review): only the generator is restored -- the critic and both optimizers
# start fresh, and critic_loss / gen_loss are loaded but not used below.

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels are not needed: the GAN is trained unsupervised.
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # detach() keeps the critic's backward pass out of the generator:
            # no wasted generator gradients, and gen's graph for `fake` stays
            # intact for the generator step below (so no retain_graph needed).
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Weight clipping: clamp every critic parameter to
            # [-WEIGHT_CLIP, WEIGHT_CLIP] after each update.
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # Reuses `fake` from the last critic iteration; backward flows through gen.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log image grids to TensorBoard.
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # Always sample from the same fixed_noise so successive
                # TensorBoard steps show progress on identical latent vectors.
                fake_fixed = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

    # Save generator, critic loss, and generator loss at the end of every epoch.
    torch.save(gen.state_dict(), generator_path)
    torch.save(loss_critic.item(), critic_loss_path)
    torch.save(loss_gen.item(), gen_loss_path)
    # Push buffered TensorBoard events to disk once per epoch.
    writer_real.flush()
    writer_fake.flush()


 98%|█████████▊| 101/103 [00:50<00:01,  1.90it/s]

Epoch [0/20] Batch 100/103                   Loss D: -1.2555, loss G: 0.6082


100%|██████████| 103/103 [00:50<00:00,  2.03it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.84it/s]

Epoch [1/20] Batch 100/103                   Loss D: -1.4160, loss G: 0.7134


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:47<00:01,  1.99it/s]

Epoch [2/20] Batch 100/103                   Loss D: -1.4707, loss G: 0.7204


100%|██████████| 103/103 [00:47<00:00,  2.15it/s]
 98%|█████████▊| 101/103 [00:45<00:00,  2.00it/s]

Epoch [3/20] Batch 100/103                   Loss D: -1.4615, loss G: 0.7202


100%|██████████| 103/103 [00:46<00:00,  2.23it/s]
 98%|█████████▊| 101/103 [00:47<00:01,  1.89it/s]

Epoch [4/20] Batch 100/103                   Loss D: -1.4504, loss G: 0.7067


100%|██████████| 103/103 [00:48<00:00,  2.12it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.90it/s]

Epoch [5/20] Batch 100/103                   Loss D: -1.4278, loss G: 0.7048


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:49<00:01,  1.81it/s]

Epoch [6/20] Batch 100/103                   Loss D: -1.4264, loss G: 0.7053


100%|██████████| 103/103 [00:49<00:00,  2.06it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.85it/s]

Epoch [7/20] Batch 100/103                   Loss D: -1.3750, loss G: 0.7122


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.92it/s]

Epoch [8/20] Batch 100/103                   Loss D: -1.4076, loss G: 0.6945


100%|██████████| 103/103 [00:48<00:00,  2.11it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.96it/s]

Epoch [9/20] Batch 100/103                   Loss D: -1.3423, loss G: 0.6970


100%|██████████| 103/103 [00:49<00:00,  2.09it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.89it/s]

Epoch [10/20] Batch 100/103                   Loss D: -1.3978, loss G: 0.6996


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.86it/s]

Epoch [11/20] Batch 100/103                   Loss D: -1.3545, loss G: 0.6682


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:47<00:01,  1.97it/s]

Epoch [12/20] Batch 100/103                   Loss D: -1.3646, loss G: 0.6704


100%|██████████| 103/103 [00:48<00:00,  2.13it/s]
 98%|█████████▊| 101/103 [00:45<00:01,  2.00it/s]

Epoch [13/20] Batch 100/103                   Loss D: -1.2902, loss G: 0.6765


100%|██████████| 103/103 [00:46<00:00,  2.21it/s]
 98%|█████████▊| 101/103 [00:48<00:01,  1.88it/s]

Epoch [14/20] Batch 100/103                   Loss D: -1.3332, loss G: 0.6609


100%|██████████| 103/103 [00:49<00:00,  2.08it/s]
 98%|█████████▊| 101/103 [00:49<00:00,  2.02it/s]

Epoch [15/20] Batch 100/103                   Loss D: -1.2435, loss G: 0.6764


100%|██████████| 103/103 [00:50<00:00,  2.05it/s]
 98%|█████████▊| 101/103 [00:46<00:00,  2.02it/s]

Epoch [16/20] Batch 100/103                   Loss D: -1.3002, loss G: 0.6437


100%|██████████| 103/103 [00:46<00:00,  2.20it/s]
 98%|█████████▊| 101/103 [00:45<00:00,  2.01it/s]

Epoch [17/20] Batch 100/103                   Loss D: -1.2605, loss G: 0.6081


100%|██████████| 103/103 [00:46<00:00,  2.21it/s]
 98%|█████████▊| 101/103 [00:46<00:01,  1.96it/s]

Epoch [18/20] Batch 100/103                   Loss D: -1.2905, loss G: 0.6233


100%|██████████| 103/103 [00:47<00:00,  2.19it/s]
 98%|█████████▊| 101/103 [00:46<00:00,  2.02it/s]

Epoch [19/20] Batch 100/103                   Loss D: -1.2170, loss G: 0.6326


100%|██████████| 103/103 [00:46<00:00,  2.20it/s]


In [None]:
import torch
from torchvision.utils import save_image

# Rebuild the generator architecture and load the trained weights into it.
generator = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
# map_location=device lets a checkpoint written on another device (e.g. saved
# from a GPU run) be loaded on this machine.
generator.load_state_dict(torch.load("generator.pth", map_location=device))
# Switch to eval mode for sampling; the trailing ';' suppresses the module repr.
generator.eval();


In [23]:
# Generate flower images from fresh random latent vectors.
num_images = 10  # Number of images to generate
# A separate name, so the training cell's 32-sample `fixed_noise` (used for
# TensorBoard grids) is not silently replaced by this 10-sample batch.
sample_noise = torch.randn(num_images, Z_DIM, 1, 1).to(device)
with torch.no_grad():
    # no_grad already prevents graph construction, so .detach() is unnecessary.
    generated_images = generator(sample_noise).cpu()

# Save the generated images as a grid with 5 images per row; normalize=True
# min-max rescales the pixel values into [0, 1] before writing the PNG.
save_image(generated_images, "generated_flowers.png", nrow=5, normalize=True)
