In [1]:
import warnings

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim

from dataset import create_dataloader

In [None]:
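# Setup — run once in a terminal (these are shell commands, not Python):
#   git clone https://github.com/xinntao/Real-ESRGAN
#   cd Real-ESRGAN
#   pip install -r requirements.txt
#   python setup.py install
# NOTE(review): xinntao's `realesrgan` package exposes `RealESRGANer`, not the
# `RealESRGAN(device, scale)` / `load_weights(...)` API used below, which looks like
# ai-forever's Real-ESRGAN package — confirm which library is actually intended.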


In [None]:
# Path to the preprocessed dataset -- change this to where the processed data is saved.
data_dir = "processed_dataset"
# NOTE(review): validation currently reads the same directory as training, so any
# validation metric would be measured on training data. Point this at a held-out
# split once one exists.
valid_dir = data_dir
batch_size = 8

if valid_dir == data_dir:
    warnings.warn("valid_dir == data_dir: validation data is the training data (no held-out split).")

# Create dataloaders; only the training loader is shuffled.
train_loader = create_dataloader(data_dir=data_dir, batch_size=batch_size, shuffle=True)
valid_loader = create_dataloader(data_dir=valid_dir, batch_size=batch_size, shuffle=False)

# Sanity check: inspect the shape of one training batch.
# Expected: [batch_size, slice_depth, height, width] -- assumes the loader yields bare
# tensors rather than (input, target) tuples; TODO confirm against dataset.py.
batch = next(iter(train_loader), None)
if batch is None:
    warnings.warn("train_loader is empty -- check data_dir.")
else:
    print(batch.shape)

In [None]:
from realesrgan import RealESRGAN

# NOTE(review): `RealESRGAN(device, scale)` + `load_weights(...)` looks like the API of
# ai-forever's Real-ESRGAN package (imported there as `from RealESRGAN import RealESRGAN`),
# not xinntao's `realesrgan` package cloned above -- confirm which library is installed.
# If it is ai-forever's wrapper, it may not be an nn.Module itself (the network would live
# in an inner attribute), so `.to()` / `.train()` / `.parameters()` / calling it may need
# to target that inner module -- verify.

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build a 4x upscaler and initialise it from pretrained weights
# (skip load_weights to train from scratch instead).
weights_path = 'weights/RealESRGAN_x4.pth'
model = RealESRGAN(device, scale=4)
model.load_weights(weights_path)
model.to(device)


In [None]:
import torch  # `import torch.optim as optim` alone does not bind the name `torch`
import torch.optim as optim

# Pixel-wise reconstruction loss. NOTE: this is plain MSE (L2) only. The Real-ESRGAN
# recipe combines a pixel loss with perceptual and adversarial (GAN) losses, so this
# setup only fine-tunes the generator on a pixel objective; swap in torch.nn.L1Loss or
# a perceptual loss to follow that recipe more closely.
criterion = torch.nn.MSELoss()

# Adam over all model parameters.
# NOTE(review): requires `model` to expose `.parameters()` (i.e. be an nn.Module) --
# confirm against the RealESRGAN wrapper in use.
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
# Example fine-tuning loop.
# NOTE(review): each item from train_loader is used only as the low-resolution input, so
# there is no real high-resolution ground truth here. As a placeholder target, the LR
# batch is bicubically upsampled to the model's output size. This teaches the network to
# imitate bicubic interpolation, not true super-resolution -- replace it with genuine HR
# images when they are available.
# Assumes lr_images is a 4-D float (batch, channels, H, W) tensor whose channel count
# matches the model output -- TODO confirm (the loader note says channels = slice_depth).
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for lr_images in train_loader:
        lr_images = lr_images.to(device)  # Move data to GPU if available

        # Forward pass: super-resolution output.
        sr_images = model(lr_images)

        # Placeholder HR target built at exactly the output's spatial size, so the loss
        # never hits a shape mismatch (a fixed 512x512 only matched 128x128 inputs at 4x).
        hr_images = F.interpolate(
            lr_images, size=sr_images.shape[-2:], mode="bicubic", align_corners=False
        )

        loss = criterion(sr_images, hr_images)

        # Backward pass and optimization.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Mean per-batch loss; max(..., 1) guards against an empty loader.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}")


In [None]:
# Model evaluation (inference) on a single example.
model.eval()  # Set to evaluation mode

# Replace with actual image loading.
# NOTE: torch.load unpickles the file, which can execute arbitrary code -- only load
# files you trust (recent PyTorch versions accept weights_only=True for plain tensors).
test_image_path = 'path_to_test_image.pt'
test_image = torch.load(test_image_path).to(device)
# The model is fed batches in training; add a batch axis if a single (C, H, W) image was saved.
if test_image.dim() == 3:
    test_image = test_image.unsqueeze(0)

with torch.no_grad():
    sr_image = model(test_image)

# Convert to a numpy stack of 2-D planes, shape (N, H, W), for saving or visualization.
# Unlike .squeeze(), this keeps a leading axis even for a single-channel output, so
# sr_image[0] is always a full 2-D image rather than its first row.
sr_image = sr_image.cpu().numpy()
sr_image = sr_image.reshape(-1, *sr_image.shape[-2:])


In [None]:
import matplotlib.pyplot as plt

# Show and save the first 2-D plane of the super-resolved output.
# Collapse all leading axes first: if sr_image was squeezed down to (H, W), indexing [0]
# directly would return a single row, which imshow cannot display.
sr_plane = sr_image.reshape(-1, *sr_image.shape[-2:])[0]

fig, ax = plt.subplots()
ax.imshow(sr_plane, cmap='gray')  # Show the super-resolved image
ax.set_title("Super-resolved output (first plane)")
fig.savefig('output_sr_image.png')  # Save the output image
plt.show()


In [None]:
# All-in-one version of the steps above: load the pretrained model, fine-tune it, then
# run inference on one test tensor and save a preview image.
# NOTE(review): running this after the earlier cells reloads the pretrained weights and
# trains for another num_epochs starting from that checkpoint.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from realesrgan import RealESRGAN
from torchvision import transforms
import matplotlib.pyplot as plt

from dataset import create_dataloader

# Load model and weights.
# NOTE(review): confirm the RealESRGAN class in use exposes .to/.train/.parameters and is
# callable; some wrappers keep the trainable nn.Module in an inner attribute.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RealESRGAN(device, scale=4)
model.load_weights('weights/RealESRGAN_x4.pth')
model.to(device)

# Dataset and DataLoader (assumes preprocessed data is ready)
train_loader = create_dataloader(data_dir="processed_dataset", batch_size=8, shuffle=True)

# Optimizer and loss function (plain MSE pixel loss only)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# Training loop.
# Placeholder target: bicubic upsample of the LR input to the output's size (no real HR
# data is loaded here). This keeps shapes consistent for any input size / scale, but it
# only teaches bicubic interpolation -- replace with genuine HR images when available.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for lr_images in train_loader:
        lr_images = lr_images.to(device)

        # Forward pass
        sr_images = model(lr_images)
        hr_images = F.interpolate(
            lr_images, size=sr_images.shape[-2:], mode="bicubic", align_corners=False
        )

        # Compute loss
        loss = criterion(sr_images, hr_images)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Mean per-batch loss; max(..., 1) guards against an empty loader.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / max(num_batches, 1)}")

# Evaluate or predict using the model.
# NOTE: torch.load unpickles the file -- only load files you trust.
model.eval()
with torch.no_grad():
    test_image = torch.load('path_to_test_image.pt').to(device)
    if test_image.dim() == 3:  # add a batch axis for a single (C, H, W) image
        test_image = test_image.unsqueeze(0)
    sr_image = model(test_image)

# Convert to (N, H, W) planes instead of squeezing, so a single-channel output stays 2-D
# and sr_image[0] is a full image rather than one row.
sr_image = sr_image.cpu().numpy()
sr_image = sr_image.reshape(-1, *sr_image.shape[-2:])

# Save and display the first plane of the super-resolved output
fig, ax = plt.subplots()
ax.imshow(sr_image[0], cmap='gray')
ax.set_title("Super-resolved output (first plane)")
fig.savefig('output_sr_image.png')
plt.show()
