In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from wgan_image_synthesis import Generator, Discriminator
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import torch.optim as optim
import torch.nn as nn
from torchvision.utils import save_image

In [None]:
class DiseaseDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single directory.

    Each item is the image at the given index decoded as RGB and, when a
    ``transform`` was supplied, passed through it. No label is returned.

    Args:
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to every loaded PIL image.
        extensions: Optional iterable of file-name suffixes to keep
            (e.g. ``(".jpg", ".png")``), matched case-insensitively.
            ``None`` (the default) keeps every regular file.
    """

    def __init__(self, img_dir, transform=None, extensions=None):
        self.img_dir = img_dir
        self.transform = transform

        suffixes = None
        if extensions is not None:
            suffixes = tuple(ext.lower() for ext in extensions)

        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories; sort for a stable index -> file mapping and keep
        # regular files only so __getitem__ never tries to open a folder.
        self.image_files = sorted(
            name
            for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
            and (suffixes is None or name.lower().endswith(suffixes))
        )

    def __len__(self):
        """Return the number of files kept by ``__init__``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file as an RGB image and apply the transform, if any."""
        img_path = os.path.join(self.img_dir, self.image_files[idx])
        # The context manager closes the underlying file as soon as the
        # pixels have been converted into a new RGB image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then shift/scale each of the three channels with mean 0.5 and
# std 0.5, which maps ToTensor's [0, 1] output onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Run on the first CUDA device when one is present, otherwise on the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)

# Number of GPUs the models are spread over with nn.DataParallel (device ids
# 0..ngpu-1). Cap the requested count at what the machine actually has, so a
# single-GPU host does not hand DataParallel a non-existent device id.
# On CPU the value is left as requested; DataParallel is never used there.
ngpu = 2
if device.type == 'cuda':
    ngpu = min(ngpu, torch.cuda.device_count())

In [None]:
# DataLoader settings: images per batch and whether to reshuffle each epoch.
batch_size = 32
shuffle = True

# Directory the training images are read from.
path = "/kaggle/input/vce-dataset/training/Erosion"

In [None]:
# Wrap the image directory in a Dataset that applies the preprocessing pipeline.
dataset = DiseaseDataset(img_dir=path, transform=transform)

In [None]:
# Batches the dataset using the batch size / shuffle flag configured above.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Preview one batch of (already transformed) training images as a grid.
real_batch = next(iter(dataloader))

# The grid is only drawn, so it is built directly from the batch the
# DataLoader returned instead of copying it to `device` and straight back.
# normalize=True rescales the values into [0, 1] for display.
preview_grid = vutils.make_grid(real_batch, padding=2, normalize=True)

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training_imgs")
# make_grid returns channels-first data; imshow expects channels last.
plt.imshow(preview_grid.permute(1, 2, 0).numpy())
plt.show()

In [None]:
# Build the discriminator and move it to the selected device.
# NOTE(review): the meaning of the positional arguments (ngpu, 64, 3) is set
# by wgan_image_synthesis.Discriminator - confirm there.
netD = Discriminator(ngpu, 64, 3)
netD = netD.to(device)

# Replicate over GPUs 0..ngpu-1 only on CUDA with more than one GPU requested.
if device.type == 'cuda' and ngpu > 1:
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# Show the architecture
print(netD)

In [None]:
# Build the generator and move it to the selected device.
# NOTE(review): the meaning of the positional arguments (ngpu, 64, 3, 100) is
# set by wgan_image_synthesis.Generator - confirm there.
netG = Generator(ngpu, 64, 3, 100)
netG = netG.to(device)

# Replicate over GPUs 0..ngpu-1 only on CUDA with more than one GPU requested.
if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Show the architecture
print(netG)

In [None]:
def filtered_params(model):
    """Collect the parameters of ``model`` whose name does not contain 'batch'.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        list: The parameters, in iteration order, whose name lacks the
        substring ``'batch'``.
    """
    kept = []
    for param_name, param in model.named_parameters():
        if 'batch' in param_name:
            continue
        kept.append(param)
    return kept
# Generator parameters whose name does not contain 'batch'.
params = filtered_params(model=netG)

In [None]:
# One fixed batch of latent vectors, reused whenever the training loop
# renders sample images from the generator.
fixed_noise = torch.randn((batch_size, 100, 1, 1), device=device)

In [None]:
# Learning rates handed to the discriminator (D) and generator (G) optimizers.
D_lr = G_lr = 1e-3

In [None]:
# One RMSprop optimizer per network, each with its own learning rate.
optimizerD = optim.RMSprop(params=netD.parameters(), lr=D_lr)
optimizerG = optim.RMSprop(params=netG.parameters(), lr=G_lr)

In [None]:
# Number of full passes over the dataloader made by the training loop.
num_epochs = 20

In [None]:
print("Starting Training Loop...")

# Loss history: one entry per generator step (i.e. per batch), plotted later.
G_losses = []
D_losses = []
D_real_losses = []
D_fake_losses = []
# Kept so the name still exists after this cell; the original emptied this
# list after every preview, so it never carried data out of the loop.
img_list = []
# Number of generator updates performed so far (one per batch).
iters = 0

for epoch in range(1, num_epochs + 1):
    for i, data in enumerate(dataloader, 0):
        # Move the real batch to the device once; every critic step reuses it.
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)

        ############################
        # (1) Update D network: 5 discriminator steps per generator step
        ###########################
        for _ in range(5):
            netD.zero_grad()

            # Score of the real batch; the loss is its negated mean.
            output = netD(real_cpu)
            errD_real = -torch.mean(output)
            D_x = output.mean().item()

            # Score of a freshly generated batch. detach() keeps this
            # backward pass from reaching the generator.
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            fake = netG(noise)
            output = netD(fake.detach())
            D_G_z1 = output.mean().item()
            errD_fake = torch.mean(output)

            # Total discriminator loss = fake term + real term.
            errD = errD_fake + errD_real
            errD.backward()
            optimizerD.step()

            # Clamp every discriminator parameter whose name does not
            # contain 'batch' to [-0.01, 0.01] after the step.
            with torch.no_grad():
                for name, param in netD.named_parameters():
                    if 'batch' not in name:
                        param.clamp_(-0.01, 0.01)

        ############################
        # (2) Update G network
        ###########################
        netG.zero_grad()
        noise = torch.randn(b_size, 100, 1, 1, device=device)
        fake = netG(noise)
        output = netD(fake)
        errG = -torch.mean(output)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress line for every batch.
        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, num_epochs, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())
        D_real_losses.append(errD_real.item())
        D_fake_losses.append(errD_fake.item())

        # Count iterations, not epochs: incremented once per processed batch.
        iters += 1

    # Every 5th epoch, render the generator's output for the fixed noise.
    if epoch % 5 == 0:
        print(f"EPOCH_{epoch} OUTPUT")
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        # A single make_grid call is enough: running make_grid again over
        # this one already-normalised grid leaves it unchanged.
        grid_image = vutils.make_grid(fake, padding=2, normalize=True)
        fig = plt.figure(figsize=(16, 16))
        plt.axis("off")
        # make_grid returns channels-first data; imshow expects channels last.
        plt.imshow(grid_image.permute(1, 2, 0).numpy())
        plt.show()
        # Release the figure so previews do not pile up in memory.
        plt.close(fig)
    

In [None]:
# Plot the four recorded loss curves on one set of axes.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")

# Plot order is kept so each curve receives the same default colour as before.
loss_ax.plot(D_real_losses, label="D_real_loss")
loss_ax.plot(D_fake_losses, label="D_fake_loss")
loss_ax.plot(G_losses, label="G_loss")
loss_ax.plot(D_losses, label="D_net_loss")

loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Destination of the final checkpoint.
PATH = "/kaggle/working/checkpoints/last_GAN.pt"

# Create the directory the file is actually written to. Deriving it from PATH
# (instead of a relative "checkpoints") keeps the save working when the
# notebook's working directory is not /kaggle/working.
os.makedirs(os.path.dirname(PATH), exist_ok=True)

# Persist both networks and both optimizers in one file.
# NOTE(review): when a network was wrapped in nn.DataParallel its state_dict
# keys carry a "module." prefix - keep that in mind when loading.
torch.save({
    'Gen_state_dict': netG.state_dict(),
    'Disc_state_dict': netD.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
}, PATH)

In [None]:
# Output folder for the synthesized images.
save_dir = 'SynthesizedImages'
os.makedirs(save_dir, exist_ok=True)

# How many images to produce, and the length of each latent vector.
n = 32
latent_dim = 100

# Put the generator in evaluation mode before sampling.
netG.eval()

# Draw n latent vectors shaped (latent_dim, 1, 1) and move them to the device.
random_noise = torch.randn(n, latent_dim, 1, 1).to(device)

# Run the generator without tracking gradients.
with torch.no_grad():
    generated_images = netG(random_noise)

# Map the output from [-1, 1] to [0, 1].
# NOTE(review): assumes the generator output lies in [-1, 1] - confirm.
generated_images = (generated_images + 1) / 2

# Write every image to its own PNG; file names are numbered from 1.
for i, image in enumerate(generated_images):
    save_image(image, os.path.join(save_dir, f'generated_image_{i+1}.png'))

In [None]:
# Shell escape: recursively pack the SynthesizedImages folder into SynthesizedImages.zip.
!zip -r SynthesizedImages.zip SynthesizedImages 