In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [30]:
class ResidualBlock(nn.Module):

    def __init__(self, kernelSize = 3, inChannels = 64, outChannels = 64, strd = 1, paddng = 1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels = inChannels, out_channels = outChannels, kernel_size = kernelSize, stride = strd, padding = paddng),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(in_channels = inChannels, out_channels = outChannels, kernel_size = kernelSize, stride = strd, padding = paddng),
            nn.BatchNorm2d(64)
        )
    def forward(self, x):
        out = self.block(x)
        return torch.add(out, x)

In [31]:
class UpsampleBlock(nn.Module):
    def __init__(self, inChannels,scaleFactor):
        super().__init__()
        self.conv = nn.Conv2d(in_channels= inChannels, out_channels= inChannels * scaleFactor ** 2, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scaleFactor)
        self.act = nn.PReLU(inChannels)
    def forward(self, x):
        return self.act(self.ps(self.conv(x)))

In [32]:
class SRResnet(nn.Module):
    def __init__(self):
        super(SRResnet, self).__init__()

        self.l1 = nn.Conv2d(kernel_size=9, stride=1, in_channels=3, out_channels=64, padding=4)
        self.l2 = nn.PReLU()

        self.residuals = nn.Sequential()
        for _ in range(0, 16):
            self.residuals.add_module('residualBlock',ResidualBlock())

        self.postResiduals = nn.Sequential(
            nn.Conv2d(in_channels= 64, out_channels=64, kernel_size= 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )

        self.upsample = UpsampleBlock(64, 2)
        self.upsample2 = UpsampleBlock(64, 2)

        self.final = nn.Conv2d(64, 3, kernel_size= 9, stride=1, padding=4)

    def forward(self, x):
        x = self.l1(x)
        x1 = self.l2(x)
        x = self.residuals(x1)
        x = self.postResiduals(x)
        x = torch.add(x, x1)
        x = self.upsample(x)
        x = self.upsample2(x)
        x = self.final(x)

        return x

In [33]:
# Set the path to your data folder

data_path = "dataset/"

# Define the transformations for your data
transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),  # Resize the images to a fixed size
    transforms.ToTensor(),          # Convert images to tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize image pixels to the range [-1, 1]
])
transformLr = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.Resize((64,64)), # Resize the images to a fixed size
    transforms.ToTensor(),          # Convert images to tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize image pixels to the range [-1, 1]
])


# Load the high-resolution and low-resolution images
hr_dataset = ImageFolder(root=data_path + "lr", transform=transform)
lr_dataset = ImageFolder(root=data_path + "lr", transform=transformLr)
#sr_dataset = ImageFolder(root=data_path + "autoencodertrain", transform=transform)

# Create the data loader for high-resolution and low-resolution images
batch_size = 6
num_workers = 2  # Set the number of worker processes for data loading
hr_data_loader = DataLoader(hr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
lr_data_loader = DataLoader(lr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [34]:
num_epochs = 5000

In [35]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

autoencoder = SRResnet()

# Move the models to the device
autoencoder.to(device)

# Define the loss function (adversarial and content losses)
ae_content_loss = nn.MSELoss()

cuda


In [36]:
# Define the optimizers for generator and discriminator
lr = 0.0001
betas = (0.5, 0.9)
autoencoder_optimizer = optim.Adam(autoencoder.parameters(), lr = lr, betas = betas)

In [37]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x2c2e0713490>

In [38]:
# Training loop
for epoch in range(0, num_epochs):
    for i, (hr_images, sr_images) in enumerate(zip(hr_data_loader, lr_data_loader)):
        # Move images to the device
        hr_images = hr_images[0].to(device).float()
        sr_images = sr_images[0].to(device).float()

        if i == 40 and epoch == 0:
            save_image(hr_images, f"AE-results/hr_image_epoch{epoch + 1}_batch{i+1}.png", normalize = True)
            save_image(sr_images, f"AE-results/lr_image_epoch{epoch + 1}_batch{i+1}.png", normalize = True)

        # --------------------
        # Train the Autoencoder
        # --------------------
        ae_sr_images = autoencoder(sr_images)
        autoencoder_optimizer.zero_grad()
        ae_loss = ae_content_loss(hr_images, ae_sr_images)
        ae_loss.backward(retain_graph =True)
        autoencoder_optimizer.step()

        # Print progress
        if(i + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(hr_data_loader)}], "
                f"AELoss: {ae_loss.item():.4f}")

        if (i + 1) % 41 == 0 and ((epoch % 2) == 0):
            torch.save(autoencoder.state_dict(), f"Models/AE/autoencoder_model_epoch{epoch+1}_batch{i+1}.pt")
            print(f"Saved AE model at epoch {epoch+1}, batch {i+1}")

            #save_image(sr_images, f"AE-results/sr_image_epoch{epoch + 1}_batch{i+1}.png", normalize = True)
            save_image(ae_sr_images, f"AE-results/ae_sr_image_epoch{epoch + 1}_batch{i+1}.png", normalize = True)

Epoch [1/5000], Step [10/41], AELoss: 0.1007
Epoch [1/5000], Step [20/41], AELoss: 0.1057
Epoch [1/5000], Step [30/41], AELoss: 0.0932
Epoch [1/5000], Step [40/41], AELoss: 0.1144
Saved AE model at epoch 1, batch 41
Epoch [2/5000], Step [10/41], AELoss: 0.0630
Epoch [2/5000], Step [20/41], AELoss: 0.0866
Epoch [2/5000], Step [30/41], AELoss: 0.0784
Epoch [2/5000], Step [40/41], AELoss: 0.1034
Epoch [3/5000], Step [10/41], AELoss: 0.0549
Epoch [3/5000], Step [20/41], AELoss: 0.0800
Epoch [3/5000], Step [30/41], AELoss: 0.0720
Epoch [3/5000], Step [40/41], AELoss: 0.0964
Saved AE model at epoch 3, batch 41
Epoch [4/5000], Step [10/41], AELoss: 0.0509
Epoch [4/5000], Step [20/41], AELoss: 0.0754
Epoch [4/5000], Step [30/41], AELoss: 0.0676
Epoch [4/5000], Step [40/41], AELoss: 0.0913
Epoch [5/5000], Step [10/41], AELoss: 0.0480
Epoch [5/5000], Step [20/41], AELoss: 0.0723
Epoch [5/5000], Step [30/41], AELoss: 0.0638
Epoch [5/5000], Step [40/41], AELoss: 0.0871
Saved AE model at epoch 5, b

KeyboardInterrupt: 