In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [2]:
# Define the generator
class Generator(nn.Module):
    def __init__(self, latent_dim, channels):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),#seecond changed to 7 from 4
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

# Define the discriminator
class Discriminator(nn.Module):
    def __init__(self, channels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),  # Changed output channels to 1
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


In [3]:
# Set up the GAN
latent_dim = 100
channels = 3  # RGB channels
img_size = 64  # Desired image size    64 to 224 change

# Initialize the generator and discriminator
generator = Generator(latent_dim, channels)
discriminator = Discriminator(channels)

# Set up the loss function and optimizers
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [4]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class SkinLesionDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.png') or f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image

# Load your skin lesion dataset
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*channels, [0.5]*channels)
])

dataset = SkinLesionDataset('images', transform=transform)
total_images = len(dataset)

# Calculate the batch size
batch_size = min(total_images, 64)  # Use 64 as the maximum batch size

# Create the data loader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [6]:
# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    for real_images in dataloader:
        # Train the discriminator
        d_optimizer.zero_grad()
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, 1, 1)
        fake_labels = torch.zeros(batch_size, 1, 1, 1)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        latent_vectors = torch.randn(batch_size, latent_dim, 1, 1)
        fake_images = generator(latent_vectors)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        g_optimizer.zero_grad()
        latent_vectors = torch.randn(batch_size, latent_dim, 1, 1)
        fake_images = generator(latent_vectors)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)  # Use real labels for generator loss
        g_loss.backward()
        g_optimizer.step()

    # Print losses and save generated images
    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            latent_vectors = torch.randn(16, latent_dim, 1, 1)
            fake_images = generator(latent_vectors)
            image_path = os.path.join('gan_images', f'generated_image_epoch_{epoch+1}.png')
            save_image(fake_images, image_path, normalize=True)

# Save the trained generator model
torch.save(generator.state_dict(), 'generator.pth')


Epoch [1/100], d_loss: 0.02593955211341381, g_loss: 6.403222560882568
Epoch [2/100], d_loss: 0.007860735058784485, g_loss: 6.515342712402344
Epoch [3/100], d_loss: 0.015968013554811478, g_loss: 6.424894332885742
Epoch [4/100], d_loss: 0.0032134975772351027, g_loss: 6.879193305969238
Epoch [5/100], d_loss: 0.008223801851272583, g_loss: 6.21870231628418
Epoch [6/100], d_loss: 0.006766670383512974, g_loss: 6.731940746307373
Epoch [7/100], d_loss: 0.0075263953767716885, g_loss: 6.514240741729736
Epoch [8/100], d_loss: 0.01137425284832716, g_loss: 6.28379487991333
Epoch [9/100], d_loss: 0.005279887467622757, g_loss: 6.377618789672852
Epoch [10/100], d_loss: 0.010681990534067154, g_loss: 6.394758701324463
Epoch [11/100], d_loss: 0.011862666346132755, g_loss: 5.996876239776611
Epoch [12/100], d_loss: 0.00313474889844656, g_loss: 6.748960494995117
Epoch [13/100], d_loss: 0.009139335714280605, g_loss: 6.2197346687316895
Epoch [14/100], d_loss: 0.006484953220933676, g_loss: 6.692654609680176
Epo