In [None]:
from google.colab import drive
import os

# Mount Google Drive so the dataset and outputs are reachable under /content/drive.
# Re-running this cell on an already-mounted drive only prints a notice.
drive.mount('/content/drive')

# Root of the BioSR CCP data on Google Drive. Change this single value to point
# the notebook at another copy of the dataset; every path below derives from it.
data_root = '/content/drive/My Drive/BioSR/CCP'

# Paired WF inputs and GT targets for training and validation.
train_wf_path = os.path.join(data_root, 'training_wf')
train_gt_path = os.path.join(data_root, 'training_gt')
validate_wf_path = os.path.join(data_root, 'validate_wf')
validate_gt_path = os.path.join(data_root, 'validate_gt')
output_dir = os.path.join(data_root, 'outputs')

# Create the output directory if it does not already exist (no-op otherwise).
os.makedirs(output_dir, exist_ok=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from PIL import Image
import torchvision.transforms as transforms

# Define image loading function
def load_images_from_folder(folder):
    """Load every image file in ``folder`` as a single-channel tensor.

    Entries are visited in sorted filename order. ``os.listdir`` alone returns
    names in arbitrary, filesystem-dependent order, and callers pair the WF and
    GT lists positionally (``zip``). Sorting makes two folders with
    corresponding filenames line up index-by-index.

    Args:
        folder: path to a directory of image files. Sub-directories and other
            non-file entries are skipped.

    Returns:
        list of tensors, one per image file, produced by
        ``transforms.ToTensor()`` on the grayscale (mode 'L') image.
    """
    to_tensor = transforms.ToTensor()  # stateless; build once, not per file
    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            continue
        # `with` closes the underlying file handle as soon as the pixels are read.
        # NOTE(review): convert('L') quantises to 8 bits; if the source images are
        # 16-bit, precision is lost here — confirm this is intended.
        with Image.open(path) as img:
            images.append(to_tensor(img.convert('L')))
    return images

# Load WF and GT training images
train_wf_images = load_images_from_folder(train_wf_path)
train_gt_images = load_images_from_folder(train_gt_path)

# Load WF and GT validation images
validate_wf_images = load_images_from_folder(validate_wf_path)
validate_gt_images = load_images_from_folder(validate_gt_path)

# The training cell pairs WF and GT images positionally with zip(), which
# silently drops unmatched images when the lists differ in length, and it
# divides by the list lengths when averaging losses. Check both up front.
for split, wf_list, gt_list in (
    ('training', train_wf_images, train_gt_images),
    ('validation', validate_wf_images, validate_gt_images),
):
    assert wf_list, f'no {split} WF images found'
    assert len(wf_list) == len(gt_list), (
        f'{split}: {len(wf_list)} WF images but {len(gt_list)} GT images'
    )


In [None]:
!git clone https://github.com/JingyunLiang/SwinIR.git
!pip install timm

Cloning into 'SwinIR'...
remote: Enumerating objects: 333, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 333 (delta 6), reused 5 (delta 2), pack-reused 320 (from 1)[K
Receiving objects: 100% (333/333), 29.84 MiB | 20.10 MiB/s, done.
Resolving deltas: 100% (119/119), done.
Collecting timm
  Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.4/42.4 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: timm
Successfully installed timm-1.0.9


In [None]:
import torch
import torch.nn as nn
from SwinIR.models.network_swinir import SwinIR

# SwinIR configured for 2x single-channel super-resolution: 32x32 WF inputs are
# reconstructed at 64x64. Keyword arguments are grouped by role below.
model = SwinIR(
    # --- input / output geometry ---
    in_chans=1,       # grayscale images
    img_size=32,      # WF input height/width
    upscale=2,        # 32x32 -> 64x64
    # --- transformer backbone: four groups, each of depth 4 with 4 heads ---
    embed_dim=120,
    depths=[4, 4, 4, 4],
    num_heads=[4, 4, 4, 4],
    window_size=8,    # img_size 32 is an exact multiple of this window size
    mlp_ratio=2,
    # --- reconstruction head ---
    resi_connection='1conv',
    upsampler='pixelshuffledirect',
)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
import torch.optim as optim

# Define normalization function (Normalize images to [0, 1])
def normalize(img):
    """Min-max rescale a tensor to the range [0, 1].

    Args:
        img: tensor of any shape.

    Returns:
        ``(img - min) / (max - min)``. If every element is equal
        (``max == min``), the input is returned unchanged. Otherwise the rescale
        would divide zero by zero and fill the result with NaNs.
    """
    lo, hi = img.min(), img.max()
    if hi == lo:
        return img
    return (img - lo) / (hi - lo)

# Pick the compute device and move the model there *before* building the
# optimizer. The torch.optim docs recommend this ordering so the optimizer is
# constructed over the parameters that will actually be trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Adam with a very small learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-6)

# L1 (mean absolute error) loss. The default reduction='mean' already averages
# over every output element, so no further per-pixel normalisation is needed.
criterion = torch.nn.L1Loss()


In [None]:
import torch
import torch.nn as nn

# Define normalization function (Normalize images to [0, 1])
def normalize(img):
    """Min-max rescale ``img`` into [0, 1]; constant tensors pass through as-is."""
    lo = img.min()
    hi = img.max()
    if hi != lo:
        return (img - lo) / (hi - lo)
    # Every element is identical: rescaling would divide by zero, so skip it.
    return img

# Early stopping parameters
patience = 10  # epochs to wait after the last validation improvement
best_val_loss = float('inf')  # best validation loss seen so far
patience_counter = 0  # consecutive epochs without improvement
epochs = 100

best_model_path = '/content/drive/My Drive/BioSR/CCP/swinir_best_model.pth'
final_model_path = '/content/drive/My Drive/BioSR/CCP/swinir_final_model.pth'

# WF/GT images are paired positionally with zip(), which silently truncates to
# the shorter list. An empty list would also make the loss averages below
# divide by zero. Fail fast on either problem.
for _split, _wf_list, _gt_list in (
    ('training', train_wf_images, train_gt_images),
    ('validation', validate_wf_images, validate_gt_images),
):
    assert _wf_list, f'no {_split} WF images found'
    assert len(_wf_list) == len(_gt_list), (
        f'{_split}: {len(_wf_list)} WF images but {len(_gt_list)} GT images'
    )


def _pair_loss(wf_img, gt_img):
    """Run one WF image through ``model`` and return its loss against the GT image.

    Both images are moved to ``device`` and min-max normalised first, and a batch
    dimension of 1 is added. The returned value is ``criterion``'s output as-is.
    ``criterion`` is ``nn.L1Loss()`` in this notebook, whose default
    ``reduction='mean'`` already averages over every output pixel, so it is NOT
    divided by the pixel count again. Doing so would shrink the loss and
    gradients by 4096x and make gradient clipping ineffective.
    """
    wf_img = normalize(wf_img.to(device))
    gt_img = normalize(gt_img.to(device))
    output = model(wf_img.unsqueeze(0))  # add batch dimension
    return criterion(output, gt_img.unsqueeze(0))


# Training loop with early stopping
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for wf_img, gt_img in zip(train_wf_images, train_gt_images):
        optimizer.zero_grad()  # clear gradients from the previous step
        loss = _pair_loss(wf_img, gt_img)
        loss.backward()
        # Clip the global gradient norm to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss/len(train_wf_images)}')

    # Validation: no gradient tracking, model in eval mode.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for wf_img, gt_img in zip(validate_wf_images, validate_gt_images):
            val_loss += _pair_loss(wf_img, gt_img).item()
    val_loss = val_loss / len(validate_wf_images)
    print(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {val_loss}')

    # Early stopping: checkpoint on improvement, otherwise count towards patience.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("Validation loss improved, model saved!")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} epochs.")
            break

# Save the weights from the last epoch that ran. These are not necessarily the
# best; the best-validation weights were saved separately above.
torch.save(model.state_dict(), final_model_path)


Epoch [1/1000], Training Loss: 1.603611149569133e-05
Epoch [1/1000], Validation Loss: 1.3928509093198551e-05
Validation loss improved, model saved!
Epoch [2/1000], Training Loss: 1.4020608224450371e-05
Epoch [2/1000], Validation Loss: 1.3577499179165089e-05
Validation loss improved, model saved!
Epoch [3/1000], Training Loss: 1.3777634656042892e-05
Epoch [3/1000], Validation Loss: 1.3424711267412527e-05
Validation loss improved, model saved!
Epoch [4/1000], Training Loss: 1.3641464552290472e-05
Epoch [4/1000], Validation Loss: 1.3270514330593466e-05
Validation loss improved, model saved!
Epoch [5/1000], Training Loss: 1.3543990832005192e-05
Epoch [5/1000], Validation Loss: 1.3186985070711268e-05
Validation loss improved, model saved!
Epoch [6/1000], Training Loss: 1.3478532500005259e-05
Epoch [6/1000], Validation Loss: 1.317756565564802e-05
Validation loss improved, model saved!
Epoch [7/1000], Training Loss: 1.3423196314334973e-05
Epoch [7/1000], Validation Loss: 1.3113094035411956e-0