In [6]:
import torch
from torch import nn
from torch.optim import Adam
from realesrgan import RealESRGANer
from dataset import create_dataloader



In [None]:
# git clone https://github.com/xinntao/Real-ESRGAN
# cd Real-ESRGAN
# pip install -r requirements.txt
# python setup.py install


In [None]:
# Path to the preprocessed dataset
data_dir = "processed_dataset"  # Change this path to where the processed data is saved

# Create dataloaders
train_loader = create_dataloader(data_dir=data_dir, batch_size=8, shuffle=True)
valid_loader = create_dataloader(data_dir=data_dir, batch_size=8, shuffle=False)

# Example: Check one batch
for batch in train_loader:
    print(batch.shape)  # Expected: [batch_size, slice_depth, height, width]
    break

In [7]:

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a pretrained Real-ESRGAN model (You can also train from scratch if you want)
model = RealESRGANer(device, scale=4)
# model.load_weights('weights/RealESRGAN_x4.pth')  # Specify the path to the pretrained weights
model.to(device)


TypeError: RealESRGANer.__init__() got multiple values for argument 'scale'

In [None]:
def train_real_esrgan(model, dataloader, device, num_epochs=10, learning_rate=1e-4):
    """
    Train the Real-ESRGAN model with the provided DataLoader.
    
    Args:
        model (nn.Module): The Real-ESRGAN model.
        dataloader (DataLoader): The data loader.
        device (str): Device to run the training on ('cuda' or 'cpu').
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the optimizer.
    """
    
    # Setup the optimizer and loss function
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.L1Loss()  # Use L1 loss (Mean Absolute Error) or MSELoss

    # Train the model
    model.train()  # Set the model to training mode

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for lr_patches, hr_patches in dataloader:
            lr_patches = lr_patches[0].to(device)  # Take the first patch from the batch
            hr_patches = hr_patches[0].to(device)  # Take the first patch from the batch
            
            # Zero gradients
            optimizer.zero_grad()

            for lr_patch, hr_patch in zip(lr_patches, hr_patches):
                # Forward pass for a single patch
                sr_patch = model(lr_patch.unsqueeze(0))  # Add batch dimension for the model

                # Compute the loss (L1 Loss)
                loss = criterion(sr_patch, hr_patch.unsqueeze(0))  # Compare SR output with HR patch
                epoch_loss += loss.item()

                # Backpropagation
                loss.backward()
                optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader)}")

        # Save model checkpoint every few epochs (optional)
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"real_esrgan_epoch_{epoch+1}.pth")

# Load your dataset (replace with actual paths)
root_dir = "dataset/train"  # Path to your dataset
metadata_csv = "train_metadata.csv"  # Path to the metadata CSV
dataloader = create_dataloader(root_dir, metadata_csv, batch_size=8, shuffle=True, num_workers=4)

# Initialize the Real-ESRGAN model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RealESRGAN(device, scale=4)  # 4x upscale, adjust based on your needs
model.to(device)

# Train the Real-ESRGAN model
train_real_esrgan(model, dataloader, device, num_epochs=10, learning_rate=1e-4)


In [None]:
# Model evaluation (inference)
model.eval()  # Set to evaluation mode

# Example prediction for a single image
test_image = torch.load('path_to_test_image.pt').to(device)  # Replace with actual image loading
with torch.no_grad():
    sr_image = model(test_image)

# Save the output image after super-resolution
sr_image = sr_image.cpu().numpy().squeeze()  # Convert to numpy for saving or visualization


In [None]:
import matplotlib.pyplot as plt

plt.imshow(sr_image[0], cmap='gray')  # Show the super-resolved image
plt.savefig('output_sr_image.png')  # Save the output image
