In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import lmdb
import numpy as np
import io
from tqdm import tqdm



In [2]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent noise vector to a 3x128x128 image.

    Args:
        latent_dim: Length of the input noise vector.

    Input:  tensor of shape (N, latent_dim).
    Output: tensor of shape (N, 3, 128, 128) with values in [-1, 1] (Tanh output).
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) exactly doubles the
        # spatial size: 8 -> 16 -> 32 -> 64 -> 128.
        # NOTE(review): PyTorch's BatchNorm `momentum` is the weight given to the *current*
        # batch statistic (default 0.1), so 0.5 makes the running stats track recent
        # batches closely — confirm this is intended.
        self.model = nn.Sequential(
            # Project the noise vector and reshape it into a 512 x 8 x 8 feature map.
            nn.Linear(latent_dim, 512 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (512, 8, 8)),

            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),  # -> 512 x 16 x 16
            nn.BatchNorm2d(512, momentum=0.5),
            nn.ReLU(),

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # -> 256 x 32 x 32
            nn.BatchNorm2d(256, momentum=0.5),
            nn.ReLU(),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # -> 128 x 64 x 64
            nn.BatchNorm2d(128, momentum=0.5),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),  # -> 3 x 128 x 128
            nn.Tanh(),  # squash pixel values into [-1, 1]
        )

    def forward(self, x):
        """Map noise ``x`` of shape (N, latent_dim) to images of shape (N, 3, 128, 128)."""
        return self.model(x)

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 3x128x128 images as real (1) or fake (0).

    Input:  tensor of shape (N, 3, 128, 128); the final Linear layer's input size
            (1024 * 16 * 16) fixes this resolution.
    Output: tensor of shape (N, 1) with probabilities in (0, 1) (Sigmoid output),
            intended to be paired with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # TODO: consider spectral_norm on the conv layers for more stable training.
        # NOTE(review): PyTorch's BatchNorm `momentum` is the weight of the *new* batch
        # statistic (default 0.1) — the opposite of Keras' convention. If 0.8/0.82 were
        # ported from Keras, ~0.2 may have been intended; confirm.
        self.model = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, stride=2, padding=1),  # -> 128 x 64 x 64
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # -> 256 x 32 x 32
            nn.BatchNorm2d(256, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.1),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # -> 512 x 16 x 16
            nn.BatchNorm2d(512, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),

            # stride=1: widens channels without downsampling.
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),  # -> 1024 x 16 x 16
            nn.BatchNorm2d(1024, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.1),

            nn.Flatten(),  # -> 1024 * 16 * 16 = 262144 features
            nn.Linear(1024 * 16 * 16, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the probability, shape (N, 1), that each image in ``img`` is real.

        To debug intermediate shapes, iterate over ``self.model`` and print each
        layer's output instead of hard-coding layer indices.
        """
        return self.model(img)



In [4]:
class LSUNDataset(Dataset):
    """Map-style dataset over an LMDB whose values are encoded images (LSUN layout).

    All keys are enumerated once at construction, so ``__getitem__`` is a single
    O(1) ``txn.get`` lookup. Each value is decoded with PIL, converted to RGB, and
    passed through ``transform`` if one is given.

    Args:
        lmdb_path: Path to the LMDB environment.
        transform: Optional callable applied to the decoded ``PIL.Image``.

    NOTE(review): the LMDB environment is opened in ``__init__``; if the DataLoader
    is switched to ``num_workers > 0`` (forked workers), confirm that sharing the
    environment across processes is safe or open it lazily per worker.
    """

    def __init__(self, lmdb_path, transform=None):
        self.lmdb_path = lmdb_path
        self.transform = transform

        # Read-only and lock-free: this dataset never writes to the database.
        self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False)

        with self.env.begin(write=False) as txn:
            # Cache every key once so items can be fetched by index directly.
            # (Cursor.iternext() merely returns an iterator; calling it repeatedly
            # without consuming it does not move the cursor, so it cannot be used
            # to "skip ahead" to an index.)
            self.keys = list(txn.cursor().iternext(keys=True, values=False))
        self.length = len(self.keys)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        key = self.keys[idx]
        with self.env.begin(write=False) as txn:
            value = txn.get(key)

        # Force 3 channels so every sample matches the 3-channel models, even if an
        # image is stored in another mode (e.g. grayscale).
        image = Image.open(io.BytesIO(value)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

In [5]:
# Resize to the 128x128 resolution both networks are built for, then map pixels
# from ToTensor's [0, 1] range to [-1, 1] so real images share the range of the
# Generator's Tanh output (otherwise the Discriminator can tell real from fake by
# pixel range alone).
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_db_path = "./data/bridge_train_lmdb"
val_db_path = "./data/bridge_val_lmdb"

train_dataset = LSUNDataset(train_db_path, transform)
# Validation must read its own split, not the training LMDB.
val_dataset = LSUNDataset(val_db_path, transform)

In [6]:
# Sanity-check one sample: expect a (channels, height, width) tensor.
image = train_dataset[0]
print(image.size())

torch.Size([3, 128, 128])


In [7]:
# Both splits share identical batching settings; only the source dataset differs.
loader_settings = dict(batch_size=32, shuffle=True, num_workers=0)
train_dataloader = DataLoader(dataset=train_dataset, **loader_settings)
val_dataloader = DataLoader(dataset=val_dataset, **loader_settings)

In [8]:
# Peek at a single batch to confirm its shape. Pairing the loader with range(1)
# stops after the first batch without pulling a second one.
for i, images in zip(range(1), train_dataloader):
    print(f"Batch {i + 1}:")
    print("Images shape:", images.shape)

Batch 1:
Images shape: torch.Size([32, 3, 128, 128])


In [9]:
# Seed torch's RNGs before the models are built so weight initialization and
# latent-noise sampling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Training hyperparameters (Adam lr=2e-4, betas=(0.5, 0.999) are the settings
# commonly used for DCGAN-style models).
latent_dim = 100   # length of the noise vector fed to the Generator
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
num_epochs = 10

In [10]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [11]:
# Build both networks directly on the selected device.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# The Discriminator ends in a Sigmoid, so plain binary cross-entropy applies.
adversarial_loss = nn.BCELoss()

# Both players use the same Adam configuration.
adam_settings = dict(lr=lr, betas=(beta1, beta2))
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [24]:
import os

# Where sample grids are written; created up front so saving cannot fail mid-run.
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)
sample_every = 10  # show/save a grid of samples every `sample_every` epochs

for epoch in range(num_epochs):
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader),
              desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
        for i, batch in pbar:
            real_images = batch.to(device)
            batch_size = real_images.size(0)

            # Target labels: 1 for real images, 0 for generated ones.
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(z)

            real_loss = adversarial_loss(discriminator(real_images), valid)
            # detach(): the D step must not backpropagate into the Generator.
            fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Reuse fake_images (its Generator graph is intact because the D step
            # only used a detached copy) instead of a second Generator forward pass.
            g_loss = adversarial_loss(discriminator(fake_images), valid)

            g_loss.backward()
            optimizer_G.step()

            if torch.isnan(d_loss).any() or torch.isnan(g_loss).any():
                print(f"NaN detected! D Loss: {d_loss}, G Loss: {g_loss}")

            # ---------------------
            #  Progress Monitoring
            # ---------------------
            if (i + 1) % 100 == 0:
                pbar.set_postfix({
                    'D Loss': d_loss.item(),
                    'G Loss': g_loss.item()
                })

    # Show and save a 4x4 grid of samples every `sample_every` epochs.
    if (epoch + 1) % sample_every == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
        # normalize=True rescales the grid to [0, 1] for display.
        grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
        grid_hwc = grid.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for matplotlib
        plt.imshow(grid_hwc)
        plt.axis("off")
        plt.show()
        # imsave needs the image array itself and an extension to choose the format.
        plt.imsave(os.path.join(output_dir, f"epoch_{epoch+1:03d}.png"), grid_hwc)


Epoch 1/10:   0%|          | 86/25584 [00:39<3:21:36,  2.11it/s]