In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize
from torchvision.utils import save_image
from torchvision.transforms import InterpolationMode
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision.utils import save_image

In [12]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in an identity skip connection.

    Args:
        num_channels: Channel count of the input. Both convolutions keep this
            count (and, with stride 1 / padding 1, the spatial size), so the
            result can be summed element-wise with the block's input.
    """

    def __init__(self, num_channels):
        super().__init__()

        # Shared geometry for both convolutions: 3x3, size-preserving.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        # Stage 1: conv -> batch norm -> PReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, **conv_kwargs),
            nn.BatchNorm2d(num_channels),
            nn.PReLU(),
        )

        # Stage 2: conv -> batch norm, with no activation before the skip sum.
        self.conv2 = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, **conv_kwargs),
            nn.BatchNorm2d(num_channels),
        )

    def forward(self, x):
        """Return ``x`` plus the two conv stages applied to ``x``."""
        return x + self.conv2(self.conv1(x))

In [13]:

class Generator(nn.Module):
    """Super-resolution generator: conv stem, residual trunk, pixel-shuffle upsampler.

    Args:
        scale_factor: Spatial upscaling factor applied by the pixel shuffle.
        num_channels: Channel count of both the input and the output image.
        num_residual_blocks: Number of ``ResidualBlock(64)`` modules in the trunk.

    The output has ``num_channels`` channels and ``scale_factor`` times the
    input's height and width.
    """

    def __init__(self, scale_factor=4, num_channels=3, num_residual_blocks=16):
        super().__init__()

        # Channels remaining after the pixel shuffle. The conv feeding the
        # shuffle must emit ``shuffled_channels * scale_factor ** 2`` channels
        # (16 * 4**2 = 256 for the default factor, matching the old constant).
        shuffled_channels = 16

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU()
        )

        # Residual blocks. Passing them positionally gives every block its own
        # slot ("0", "1", ...); registering them all under a single name would
        # keep only the last one.
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(64) for _ in range(num_residual_blocks))
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        # Upsampling layers
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, shuffled_channels * scale_factor ** 2,
                      kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(scale_factor),  # Upsampling using pixel shuffle
            nn.PReLU()
        )

        # Final convolutional layer
        self.conv3 = nn.Conv2d(shuffled_channels, num_channels,
                               kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Upscale a batch of images ``x`` of shape (N, num_channels, H, W)."""
        stem = self.conv1(x)

        # Trunk output is added back onto the stem features (long skip).
        trunk = self.conv2(self.residual_blocks(stem))
        fused = torch.add(stem, trunk)

        upsampled = self.upsampling(fused)
        return self.conv3(upsampled)

In [14]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that emits a one-channel score map.

    Args:
        num_channels: Channel count of the input images.
        num_features: Output width of the first convolution; the following
            stages use 2x, 4x and 8x this width.

    The final layer is a plain convolution with no activation, so the values
    returned by ``forward`` are unbounded. Two of the stages use stride 2, so
    the score map has roughly a quarter of the input's height and width.
    """

    def __init__(self, num_channels=3, num_features=64):
        super().__init__()

        width_1x, width_2x, width_4x, width_8x = (
            num_features * factor for factor in (1, 2, 4, 8)
        )

        # Stem: conv + LeakyReLU, no normalisation.
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, width_1x, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
        )

        # Two down-sampling stages (stride 2) followed by one stride-1 stage.
        self.conv2 = self._conv_bn_lrelu(width_1x, width_2x, stride=2)
        self.conv3 = self._conv_bn_lrelu(width_2x, width_4x, stride=2)
        self.conv4 = self._conv_bn_lrelu(width_4x, width_8x, stride=1)

        # Head: collapse the features to a single channel of raw scores.
        self.conv5 = nn.Conv2d(width_8x, 1, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _conv_bn_lrelu(in_channels, out_channels, stride):
        """Build a 3x3 conv -> batch norm -> LeakyReLU(0.2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the score map."""
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return features


In [15]:
# Number of image channels handed to both networks. It is defined before the
# models are built so that it actually configures them.
num_channels = 3

# Initialize the generator and discriminator
generator = Generator(num_channels=num_channels)
discriminator = Discriminator(num_channels=num_channels)


In [16]:
# Define loss function and optimizers
# Discriminator.forward ends in a bare Conv2d (no sigmoid), so its outputs are
# unbounded logits. nn.BCELoss only accepts probabilities in [0, 1];
# BCEWithLogitsLoss applies the sigmoid internally and is the matching loss.
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy on raw logits

# NOTE(review): these optimizers track the `generator` / `discriminator`
# instances that exist when this cell runs. A later cell re-instantiates both
# models and builds its own optimizers, after which these become stale.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))


In [17]:
# Set the path to your data folder
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
data_path = "/home/gustavosmc/Documentos/GPT_Gan/dataset/"

# Transform for the low-resolution images (128x128)
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize the images to a fixed size
    transforms.ToTensor(),          # Convert images to tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize image pixels to the range [-1, 1]
])

# Transform for the high-resolution images (256x256)
transformHr = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the images to a fixed size
    transforms.ToTensor(),          # Convert images to tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize image pixels to the range [-1, 1]
])

# Load the high-resolution and low-resolution images.
# The HR folder gets the larger (256 px) transform and the LR folder the
# smaller (128 px) one; the two were previously swapped.
hr_dataset = ImageFolder(root=data_path + "hr/", transform=transformHr)
lr_dataset = ImageFolder(root=data_path + "lr/", transform=transform)

# Create the data loader for high-resolution and low-resolution images
batch_size = 10
# The training loop zips these two loaders together, so batch k of one must
# contain the same images as batch k of the other. Two independently shuffled
# loaders draw different permutations, which would pair every LR image with an
# unrelated HR image. Keep the order deterministic instead.
# NOTE(review): this assumes hr/ and lr/ hold correspondingly named files so
# that ImageFolder enumerates them in the same order - confirm.
shuffle = False
num_workers = 6  # Set the number of worker processes for data loading
hr_data_loader = DataLoader(hr_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
lr_data_loader = DataLoader(lr_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


# Load and preprocess the dataset
# ToTensor is required here: without it the dataset yields PIL images, which
# the DataLoader's default collate function cannot batch.
dataset = ImageFolder(
    "dataset",
    transform=transforms.Compose([
        Resize((64, 64), interpolation=InterpolationMode.BICUBIC),
        ToTensor(),
    ]),
)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)


In [18]:
# Training schedule.
save_interval = 3      # checkpoint + sample image every `save_interval` batches
sample_interval = 10   # not read by any cell visible in this notebook
num_epochs = 10        # number of passes of the outer training loop

In [19]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of image channels handed to both networks. It is defined before the
# models are built so that it actually configures them.
num_channels = 3

# Instantiate the generator and discriminator and move them to the device
generator = Generator(num_channels=num_channels).to(device)
discriminator = Discriminator(num_channels=num_channels).to(device)

# Define the loss function (adversarial and content losses)
# BCEWithLogitsLoss matches the discriminator, whose last layer is a bare
# convolution that returns raw logits.
adversarial_loss = nn.BCEWithLogitsLoss()
content_loss = nn.MSELoss()

# Define the optimizers for generator and discriminator
lr = 0.0005
betas = (0.5, 0.999)
generator_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)


In [20]:
# Training loop
for epoch in tqdm(range(num_epochs)):
    # zip() pairs the k-th HR batch with the k-th LR batch and stops at the
    # shorter loader.
    # NOTE(review): the content loss below is only meaningful if both loaders
    # yield the same images in the same order - confirm the loader settings.
    for i, (hr_batch, lr_batch) in enumerate(zip(hr_data_loader, lr_data_loader)):
        # Each batch is (images, class_indices); only the images are used.
        hr_images = hr_batch[0].to(device).float()
        lr_images = lr_batch[0].to(device).float()

        # Generate high-resolution images from low-resolution images, then
        # resample onto the HR grid if the generator's output size differs.
        sr_images = generator(lr_images)
        if sr_images.shape[2:] != hr_images.shape[2:]:
            sr_images = nn.functional.interpolate(
                sr_images, size=hr_images.shape[2:], mode='bicubic', align_corners=False
            )

        # --------------------
        # Train the discriminator
        # --------------------
        discriminator_optimizer.zero_grad()

        # Discriminator loss for real images. The labels are shaped like the
        # discriminator's output map.
        real_outputs = discriminator(hr_images)
        real_labels = torch.ones_like(real_outputs)
        d_loss_real = adversarial_loss(real_outputs, real_labels)

        # Discriminator loss for fake images. detach() stops this backward
        # pass from running through the generator, so the generator's graph
        # stays intact for its own update and retain_graph is unnecessary.
        fake_outputs = discriminator(sr_images.detach())
        fake_labels = torch.zeros_like(fake_outputs)
        d_loss_fake = adversarial_loss(fake_outputs, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake

        # Update discriminator weights
        d_loss.backward()
        discriminator_optimizer.step()

        # --------------------
        # Train the generator
        # --------------------
        generator_optimizer.zero_grad()

        # Adversarial loss: score the non-detached fakes with the updated
        # discriminator and push them towards the "real" label.
        fake_outputs = discriminator(sr_images)
        g_loss_adversarial = adversarial_loss(fake_outputs, torch.ones_like(fake_outputs))

        # Content loss
        g_loss_content = content_loss(sr_images, hr_images)

        # Total generator loss
        g_loss = g_loss_adversarial + 0.01 * g_loss_content

        # Update generator weights. This also leaves gradients on the
        # discriminator; they are cleared by zero_grad() at the top of the
        # next iteration.
        g_loss.backward()
        generator_optimizer.step()

        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(hr_data_loader)}], "
                f"Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")

        if (i + 1) % save_interval == 0:
            torch.save(generator.state_dict(), f"generator_model_epoch{epoch+1}_batch{i+1}.pt")
            print(f"Saved generator model at epoch {epoch+1}, batch {i+1}")

            # Re-run the just-updated generator on the current LR batch and
            # write the result as an image grid.
            with torch.no_grad():
                sample_images = generator(lr_images)
            save_image(sample_images, f"sr_image_epoch{epoch + 1}_batch{i+1}.png", normalize = True)


  0%|          | 0/10 [00:00<?, ?it/s]