In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, TensorDataset

# Pick the compute device once: prefer the first CUDA GPU, otherwise fall back to CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ok else "cpu")
if not cuda_ok:
    print("GPU not available, using CPU instead.")
else:
    print("Using GPU:", torch.cuda.get_device_name(0))

  from .autonotebook import tqdm as notebook_tqdm


Using GPU: Tesla V100-PCIE-32GB


In [2]:
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to a flat sample vector.

    :param latent_dim: Size of the input noise vector ``z``.
    :param output_dim: Length of the flat generated sample. It must equal the
        discriminator's ``input_dim`` (for flattened images: channels * H * W).
        When omitted, the notebook-level ``output_dim`` variable is used, which
        is what earlier versions of this class read implicitly.
    """

    def __init__(self, latent_dim, output_dim=None):
        super().__init__()
        if output_dim is None:
            # Backward compatibility: fall back to the global the class used to read.
            try:
                output_dim = globals()["output_dim"]
            except KeyError:
                raise NameError(
                    "output_dim was not passed and no global `output_dim` is defined"
                ) from None
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, output_dim),
            # Real images are normalised to [-1, 1]; Tanh bounds the fakes to the
            # same range so the discriminator cannot separate them by magnitude.
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (N, latent_dim) to samples of shape (N, output_dim)."""
        return self.model(z)

In [3]:
class Discriminator(nn.Module):
    """MLP classifier scoring a flat sample vector with a probability of being real.

    Architecture: input_dim -> 512 -> 256 -> 1, with LeakyReLU(0.2) between the
    hidden layers and a Sigmoid on the single output unit.

    :param input_dim: Length of each flat input vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 512, 256]
        layers = []
        # Hidden stack: Linear followed by LeakyReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Output head: a single logit squashed to (0, 1).
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N, 1) tensor of probabilities for a (N, input_dim) batch."""
        scores = self.model(x)
        return scores

In [4]:
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and noise sampling

latent_dim = 120  # Size of the generator's noise vector; adjust as needed
# The training loop flattens each real image (3 normalised channels, 64x64 after
# Resize/CenterCrop) into one vector, so the generator output and the
# discriminator input must both be 3 * 64 * 64 = 12288 values.
# Keep this in sync with `image_size` in the data-loading cell.
output_dim = 3 * 64 * 64
input_dim = output_dim  # the discriminator consumes exactly what the generator emits

# Initialize models
generator = Generator(latent_dim).to(device)
discriminator = Discriminator(input_dim).to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss function: binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

In [6]:
import numpy as np
# Image pipeline: resize + centre-crop to a square `image_size`, convert to a
# tensor, then shift each of the three channels with (x - 0.5) / 0.5 so pixel
# values land in [-1, 1].
image_size = 64  # Adjust based on your needs
channel_stats = (0.5,) * 3
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
)

data_directory = 'broh/'  # Update with the path to your image directory
batch_size = 64  # Adjust based on your GPU's memory

# NOTE: torchvision's ImageFolder looks for images inside class sub-folders of
# `data_directory`; the class labels are ignored by the GAN training loop.
dataset = datasets.ImageFolder(root=data_directory, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [7]:
from torchvision.utils import save_image
import os

def generate_and_save_images(generator, num_images, latent_dim, output_dir, device,
                             image_shape=(3, 64, 64)):
    """
    Generate and save images using the GAN generator.

    The generator produces flat vectors, so each sample is reshaped to
    ``image_shape`` before being denormalized and written as a PNG.

    :param generator: The generator model.
    :param num_images: Number of images to generate.
    :param latent_dim: Dimension of the latent space (must match the generator's input size).
    :param output_dir: Directory to save the generated images.
    :param device: Device to run the model ('cpu' or 'cuda').
    :param image_shape: (C, H, W) shape of one image; C * H * W must equal the
        generator's output size.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Sample the whole batch in eval mode, then restore the caller's mode so
    # this can be called from inside a training loop.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_images, latent_dim, device=device)
            # save_image needs (C, H, W); the generator emits (N, C*H*W).
            fake_images = generator(noise).cpu().reshape(num_images, *image_shape)
    finally:
        generator.train(was_training)

    # Denormalize and save each image
    for i, fake_image in enumerate(fake_images):
        save_image(denormalize(fake_image), os.path.join(output_dir, f'fake_image_{i}.png'))

def denormalize(image):
    """
    Denormalize the image tensor from [-1, 1] to [0, 1].

    This inverts ``transforms.Normalize(mean=0.5, std=0.5)`` used when loading
    the data: x = y * 0.5 + 0.5. Values that fall outside [-1, 1] are clamped,
    so the result is always a valid [0, 1] image.

    :param image: The image tensor to denormalize.
    :return: The denormalized image tensor, clamped to [0, 1].
    """
    return (image * 0.5 + 0.5).clamp(0.0, 1.0)

In [None]:
epochs = 10  # Adjust as needed
# Log losses and dump sample images every `log_interval` epochs. Keep it
# <= `epochs`, otherwise logging never happens.
log_interval = 1
sample_dir = 'fin'

loss_D = loss_G = None  # stay None if the dataloader yields no batches
for epoch in range(1, epochs + 1):
    for images, _ in dataloader:  # ImageFolder batches are (images, class_labels)
        # Flatten (N, C, H, W) images into (N, C*H*W) vectors for the MLP discriminator.
        real_data = images.view(images.size(0), -1).to(device)
        n_batch = real_data.size(0)  # the final batch may be smaller than the loader's batch size
        real_labels = torch.ones(n_batch, 1, device=device)
        fake_labels = torch.zeros(n_batch, 1, device=device)

        # --- Discriminator step: push D(real) -> 1 and D(fake) -> 0 ---
        discriminator.zero_grad()
        loss_real = criterion(discriminator(real_data), real_labels)
        noise = torch.randn(n_batch, latent_dim, device=device)
        fake_data = generator(noise).view(n_batch, -1)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        loss_fake = criterion(discriminator(fake_data.detach()), fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # --- Generator step: make the discriminator classify the fakes as real ---
        generator.zero_grad()
        loss_G = criterion(discriminator(fake_data), real_labels)
        loss_G.backward()
        optimizer_G.step()

    # Log progress and save sample images
    if epoch % log_interval == 0 and loss_G is not None:
        print(f"Epoch {epoch}, Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")
        # Sampling noise must use the same latent size the generator was built with.
        generate_and_save_images(generator, 5, latent_dim,
                                 os.path.join(sample_dir, f'epoch_{epoch}'), device)
        # Optionally save the models:
        # torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
        # torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pth')