In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pathlib import Path

print("Imports loaded")

Imports loaded


In [None]:
class SRCNN(nn.Module):
    def __init__(self, activation="relu", residual=False):
        super().__init__()
        act = nn.ReLU() if activation == "relu" else nn.PReLU()
        self.residual = residual
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, 9, padding=4),
            act,
            nn.Conv2d(64, 32, 5, padding=2),
            act,
            nn.Conv2d(32, 1, 5, padding=2)
        )

    def forward(self, x):
        out = self.net(x)
        return x + out if self.residual else out

class ImprovedSRCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
        self.act1 = nn.PReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.act2 = nn.PReLU()
        self.conv3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.act3 = nn.PReLU()
        self.conv4 = nn.Conv2d(16, 1, kernel_size=5, padding=2)
        self.res_scale = 0.1

    def forward(self, x):
        identity = x
        out = self.act1(self.conv1(x))
        out = self.act2(self.conv2(out))
        out = self.act3(self.conv3(out))
        out = self.conv4(out)
        return identity + self.res_scale * out

In [None]:
def load_model(model_path, device='cpu'):

    with open(model_path, 'rb') as f:
        model_data = pickle.load(f)

    model_map = {
        'SRCNN': SRCNN,
        'ImprovedSRCNN': ImprovedSRCNN,
    }

    state = model_data.get('model_state_dict') or model_data.get('state_dict')
    if state is None:
        raise KeyError(f"No 'model_state_dict' or 'state_dict' found in {model_path}")

    model_obj = None

    if 'model_class' in model_data:
        mc = model_data['model_class']

        if isinstance(mc, str):
            cls = model_map.get(mc)
            if cls is None:
                raise KeyError(f"Unknown model_class name in pickle: {mc}")
            model_obj = cls()
        else:

            try:
                model_obj = mc()
            except Exception:
                # Fallback by name lookup
                name = getattr(mc, '__name__', None)
                if name and name in model_map:
                    model_obj = model_map[name]()
                else:
                    raise

    elif 'model_name' in model_data:
        name = model_data['model_name']
        cls = model_map.get(name)
        if cls is None:
            raise KeyError(f"Unknown model_name in pickle: {name}")
        model_obj = cls()
    else:

        try:
            model_obj = ImprovedSRCNN()
        except Exception:
            model_obj = SRCNN()

    model_obj.load_state_dict(state)
    model_obj.to(device)
    model_obj.eval()

    print(f"✓ Loaded model from {model_path}")
    print(f"  Parameters: {sum(p.numel() for p in model_obj.parameters()):,}")
    print(f"  Trained for {model_data.get('epochs', 'N/A')} epochs")
    return model_obj


def enhance_image(model, image_path, device='cpu', scale_factor=4):

    img = cv2.imread(str(image_path))
    if img is None:
        raise FileNotFoundError(f"Image not found: {image_path}")

    img_ycrcb = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
    y_channel = img_ycrcb[:, :, 0]

    h, w = y_channel.shape
    y_bicubic = cv2.resize(y_channel, (w * scale_factor, h * scale_factor), 
                           interpolation=cv2.INTER_CUBIC)

    y_tensor = torch.tensor(y_bicubic / 255.).float().unsqueeze(0).unsqueeze(0)
    y_tensor = y_tensor.to(device)

    with torch.no_grad():
        sr_y = model(y_tensor).cpu().squeeze().numpy()

    sr_y = np.clip(sr_y * 255., 0, 255).astype(np.uint8)

    cr_channel = cv2.resize(img_ycrcb[:, :, 1], (w * scale_factor, h * scale_factor), 
                           interpolation=cv2.INTER_CUBIC)
    cb_channel = cv2.resize(img_ycrcb[:, :, 2], (w * scale_factor, h * scale_factor), 
                           interpolation=cv2.INTER_CUBIC)

    sr_ycrcb = np.stack([sr_y, cr_channel, cb_channel], axis=2)
    sr_bgr = cv2.cvtColor(sr_ycrcb, cv2.COLOR_YCrCb2BGR)

    return sr_bgr, y_bicubic

In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Load improved model
improved_model = load_model('base_srcnn_model.pkl', device=DEVICE)

Using device: cpu
✓ Loaded model from base_srcnn_model.pkl
  Parameters: 57,281
  Trained for 19 epochs


In [5]:
# Process multiple images from a folder
from pathlib import Path

def batch_enhance(model, input_folder, output_folder, device='cpu'):
    """Enhance all images in a folder"""
    input_path = Path(input_folder)
    output_path = Path(output_folder)
    output_path.mkdir(exist_ok=True)
    
    image_files = list(input_path.glob("*.png")) + list(input_path.glob("*.jpg"))
    
    print(f"Found {len(image_files)} images to process")
    
    for img_file in image_files:
        print(f"Processing {img_file.name}...")
        sr_img, _ = enhance_image(model, img_file, device=device, scale_factor=4)
        output_file = output_path / f"sr_{img_file.name}"
        cv2.imwrite(str(output_file), sr_img)
    
    print(f"✓ All images saved to {output_folder}")

# Example usage (uncomment to use):
# batch_enhance(improved_model, "dataset/DIV2K_train_LR_bicubic/X4", "result/batch_output", device=DEVICE)

## Compare Before and After

In [8]:
# Specify your low-resolution image path
lr_image_path = "images2.jpeg"  # Change this to your image

# Create output directory
output_dir = Path("result")
output_dir.mkdir(exist_ok=True)

# Enhance with improved model
print("\nEnhancing image...")
improved_sr, bicubic = enhance_image(improved_model, lr_image_path, device=DEVICE, scale_factor=4)
cv2.imwrite(str(output_dir / "enhanced_output2.png"), improved_sr)
print(f"✓ Saved enhanced image to {output_dir / 'enhanced_output.png'}")


Enhancing image...
✓ Saved enhanced image to result/enhanced_output.png
✓ Saved enhanced image to result/enhanced_output.png
