# ESPCN SR with TensorBoard & Network Visualization

This notebook trains a stronger ESPCN model, logs scalars and images to TensorBoard, and shows model summaries and a computation graph.

**Features:**
- stronger ESPCN (64->32 filters)
- Charbonnier loss
- data augmentation
- checkpointing
- TensorBoard logging: loss, PSNR, SSIM, and example images
- network summary (torchsummary) and computation graph (torchviz)

Run on Colab with GPU runtime for best results.


In [ ]:
# Install required packages.
# NOTE: the leading `!` hands the line to the shell, so this only works inside an
# IPython/Jupyter (e.g. Colab) cell, not as plain Python.
# torchviz renders through Graphviz; see the notes at the end of the notebook if the
# graph rendering step fails.
!pip install torch torchvision tensorboard pillow scikit-image torchsummary torchviz --quiet

In [ ]:
# Imports
import os, zipfile, requests, io, random
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from torch.utils.tensorboard import SummaryWriter
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

# visualization helpers
from torchsummary import summary
from torchviz import make_dot

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

## Download DIV2K (training HR images)
We download the full `DIV2K_train_HR.zip` training archive and extract it. This is a large download and may take a few minutes.

In [ ]:
# Download and extract the DIV2K training HR archive.
# The whole step is skipped when the extraction folder already exists.
url = 'http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip'
extract_dir = 'DIV2K_train_HR'
zip_path = 'DIV2K_train_HR.zip'
images_dir = 'DIV2K_train_HR/images'
if not os.path.exists(extract_dir):
    print('Downloading DIV2K_train_HR.zip (this may be large)...')
    # Stream the archive to disk chunk by chunk. Reading `r.content` would buffer the
    # entire download in memory, which defeats the purpose of `stream=True`.
    with requests.get(url, stream=True, timeout=60) as r:
        # Fail loudly on an HTTP error instead of saving an error page as a .zip file.
        r.raise_for_status()
        with open(zip_path, 'wb') as zip_file:
            for chunk in r.iter_content(chunk_size=1 << 20):
                zip_file.write(chunk)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(extract_dir)
    # The archive may unpack into a nested DIV2K_train_HR/DIV2K_train_HR folder;
    # normalize either layout to DIV2K_train_HR/images.
    nested_dir = 'DIV2K_train_HR/DIV2K_train_HR'
    if os.path.exists(nested_dir):
        os.rename(nested_dir, images_dir)
    else:
        # PNGs were extracted flat: move them into the images folder.
        os.makedirs(images_dir, exist_ok=True)
        for f in os.listdir(extract_dir):
            # Case-insensitive match, consistent with how images are counted and loaded.
            if f.lower().endswith('.png'):
                os.rename(os.path.join(extract_dir, f), os.path.join(images_dir, f))

# Pick the folder that actually holds the images.
data_root = images_dir if os.path.exists(images_dir) else extract_dir

print('Data root:', data_root)
num_png = sum(1 for f in os.listdir(data_root) if f.lower().endswith('.png'))
print('Found', num_png, 'images')

## Dataset and augmentations

In [ ]:
class SRDataset(Dataset):
    """Super-resolution dataset built from the PNG files found directly in ``root``.

    Each item is a triplet of tensors ``(lr, hr, lr_up)``:

    * ``hr``    - a ``patch_size`` x ``patch_size`` RGB patch,
    * ``lr``    - ``hr`` downscaled by ``scale`` with bicubic resampling,
    * ``lr_up`` - ``lr`` resized back to the size of ``hr`` with bicubic resampling.

    In training mode the patch is a random crop with random flips and a random
    90-degree rotation; otherwise the whole image is resized to the patch size.

    Args:
        root: Directory that is listed for ``.png`` files (case-insensitive).
        scale: Integer downscaling factor used to produce ``lr``.
        patch_size: Side length of the square HR patch.
        train: Selects random-crop/augmentation mode when true.
        max_images: If truthy, keep only the first ``max_images`` files (sorted order).
    """

    def __init__(self, root, scale=2, patch_size=96, train=True, max_images=None):
        png_paths = [
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith('.png')
        ]
        png_paths.sort()
        self.files = png_paths[:max_images] if max_images else png_paths
        self.scale = scale
        self.patch_size = patch_size
        self.train = train
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return the ``(lr, hr, lr_up)`` tensor triplet."""
        image = Image.open(self.files[idx]).convert('RGB')
        size = self.patch_size

        if not self.train:
            # Evaluation: squash the full image down to a fixed-size patch.
            hr = image.resize((size, size), Image.BICUBIC)
        else:
            width, height = image.size
            # Images smaller than the patch are resized so that a crop is possible.
            if width < size or height < size:
                image = image.resize((size, size), Image.BICUBIC)
                width, height = image.size
            left = random.randint(0, width - size)
            top = random.randint(0, height - size)
            hr = image.crop((left, top, left + size, top + size))
            # Augmentation: each transform is applied independently with probability 0.5.
            for flip in (Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM):
                if random.random() < 0.5:
                    hr = hr.transpose(flip)
            if random.random() < 0.5:
                hr = hr.rotate(90)

        hr_w, hr_h = hr.size
        lr = hr.resize((hr_w // self.scale, hr_h // self.scale), Image.BICUBIC)
        lr_up = lr.resize(hr.size, Image.BICUBIC)
        return self.to_tensor(lr), self.to_tensor(hr), self.to_tensor(lr_up)

# Build datasets and loaders. `max_images=None` uses every PNG under data_root;
# pass a small integer instead for quick runs.
# NOTE(review): train and val are built from the same image files, so the validation
# numbers are not measured on held-out data.
train_ds = SRDataset(data_root, train=True, max_images=None)
val_ds = SRDataset(data_root, train=False, max_images=None)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
# Keep the validation order fixed: a consumer that only reads the first few batches
# then sees the same images every epoch, which keeps per-epoch metrics comparable.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)
print('Train size:', len(train_ds), 'Val size:', len(val_ds))

## Model: stronger ESPCN + Charbonnier loss

In [ ]:
class ESPCN(nn.Module):
    """Sub-pixel convolutional network for RGB image upscaling.

    Three convolutions (3 -> 64 -> 32 -> 3 * scale**2 channels) run at the input
    resolution; a final ``PixelShuffle`` rearranges the channels into an output that
    is ``scale`` times larger in height and width.

    Args:
        scale: Integer upscaling factor.
    """

    def __init__(self, scale=2):
        super().__init__()
        shuffle_channels = 3 * scale * scale
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=shuffle_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a ``(N, 3, H, W)`` batch to ``(N, 3, H * scale, W * scale)``."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        # No activation after the last convolution: its output is only rearranged.
        return self.pixel_shuffle(self.conv3(features))

class CharbonnierLoss(nn.Module):
    """Smooth L1-like loss: the mean over all elements of ``sqrt(diff**2 + eps)``.

    Args:
        eps: Constant added under the square root (note: added as-is, not squared).
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar loss between ``pred`` and ``target``."""
        residual = pred - target
        per_element = torch.sqrt(residual ** 2 + self.eps)
        return per_element.mean()

# Instantiate the x2 network, then move its parameters to the selected device.
model = ESPCN(scale=2)
model = model.to(device)
print('Model created on', device)

In [ ]:
# Network text summary (best-effort: a failure here must not stop the notebook)
try:
    summary(model, (3, 48, 48))
except Exception as e:
    print('torchsummary failed:', e)

# Computation graph via torchviz (make_dot), also best-effort
try:
    dummy = torch.randn(1, 3, 48, 48).to(device)
    out = model(dummy)
    g = make_dot(out, params=dict(model.named_parameters()))
    g.format = 'png'
    g.render('espcn_graph', cleanup=True)
    # Import IPython's Image under an alias. Binding the bare name `Image` here would
    # overwrite the PIL `Image` imported at the top of the notebook, and SRDataset
    # (which calls Image.open / Image.BICUBIC) would break in every later cell.
    from IPython.display import Image as IPyImage, display
    display(IPyImage('espcn_graph.png'))
except Exception as e:
    print('torchviz graph failed:', e)

## Training loop with TensorBoard logging and checkpoints
We will log: training loss per step, and per-epoch validation PSNR/SSIM. We also log image triplets (LR upsample, SR, HR) to TensorBoard every epoch.

In [ ]:
# Training setup: TensorBoard writer, loss and optimizer
writer = SummaryWriter(log_dir='runs/espcn_experiment')
criterion = CharbonnierLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Resume from the checkpoint file if one exists
start_epoch = 0
global_step = 0
ckpt_path = 'espcn_tb_checkpoint.pth'
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optim'])
    start_epoch = ckpt['epoch']
    # Continue the per-step x-axis where the previous run stopped, so resumed runs do
    # not overwrite earlier 'train/loss' points. Checkpoints written before
    # 'global_step' was stored fall back to an estimate from the epoch count.
    global_step = ckpt.get('global_step', start_epoch * len(train_loader))
    print('Resumed from', start_epoch)

num_epochs = 200
max_val_batches = 10  # validation only looks at the first few batches each epoch
try:
    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        for i, (lr, hr, lr_up) in enumerate(train_loader):
            lr = lr.to(device)
            hr = hr.to(device)
            sr = model(lr)
            loss = criterion(sr, hr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            writer.add_scalar('train/loss', loss.item(), global_step)
            global_step += 1
        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{num_epochs} - train loss {avg_loss:.6f}')

        # Validation: average PSNR/SSIM over the first `max_val_batches` batches,
        # for both the bicubic baseline (lr_up) and the network output (sr).
        model.eval()
        psnr_b, psnr_sr, ssim_b, ssim_sr = [], [], [], []
        with torch.no_grad():
            for j, (lr, hr, lr_up) in enumerate(val_loader):
                if j >= max_val_batches:
                    break
                lr = lr.to(device)
                hr = hr.to(device)
                lr_up = lr_up.to(device)
                # Clamp to the valid image range so the metrics are computed on the
                # same values as the image that gets logged below.
                sr = model(lr).clamp(0, 1)
                # val_loader uses batch_size=1: drop only the batch dimension, then
                # go from CHW to HWC for scikit-image.
                hr_np = hr.squeeze(0).permute(1, 2, 0).cpu().numpy()
                bic_np = lr_up.squeeze(0).permute(1, 2, 0).cpu().numpy()
                sr_np = sr.squeeze(0).permute(1, 2, 0).cpu().numpy()
                psnr_b.append(psnr(hr_np, bic_np, data_range=1.0))
                psnr_sr.append(psnr(hr_np, sr_np, data_range=1.0))
                ssim_b.append(ssim(hr_np, bic_np, channel_axis=2, data_range=1.0))
                ssim_sr.append(ssim(hr_np, sr_np, channel_axis=2, data_range=1.0))
                # log images for the first val sample only
                if j == 0:
                    # make a grid: LR up, SR, HR
                    grid = utils.make_grid(
                        [lr_up.squeeze(0).cpu(), sr.squeeze(0).cpu(), hr.squeeze(0).cpu()],
                        nrow=3,
                    )
                    writer.add_image('val/sample_epoch_%d' % (epoch+1), grid, epoch)
        mean_psnr_b = np.mean(psnr_b)
        mean_psnr_sr = np.mean(psnr_sr)
        mean_ssim_b = np.mean(ssim_b)
        mean_ssim_sr = np.mean(ssim_sr)
        print(f'Val PSNR Bicubic: {mean_psnr_b:.3f}, SR: {mean_psnr_sr:.3f} | SSIM Bicubic: {mean_ssim_b:.4f}, SR: {mean_ssim_sr:.4f}')

        writer.add_scalar('val/psnr_bic', mean_psnr_b, epoch)
        writer.add_scalar('val/psnr_sr', mean_psnr_sr, epoch)
        writer.add_scalar('val/ssim_bic', mean_ssim_b, epoch)
        writer.add_scalar('val/ssim_sr', mean_ssim_sr, epoch)

        # checkpoint: 'epoch' is the next epoch to run; 'global_step' keeps the
        # TensorBoard step counter continuous across resumes.
        torch.save(
            {
                'epoch': epoch+1,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'global_step': global_step,
            },
            ckpt_path,
        )
finally:
    # Always flush and close the event file, even if training is interrupted.
    writer.close()
print('Training finished. Run `tensorboard --logdir runs/espcn_experiment` to inspect.')

## How to run TensorBoard in Colab
In Colab run:

```python
%load_ext tensorboard
%tensorboard --logdir runs/espcn_experiment
```

This will open an inline TensorBoard where you can view scalars and images.

## Notes
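- For better final performance, train for longer than the default `num_epochs = 200`. The datasets above already use every image (`max_images=None`); pass a small `max_images` value only when you want a quick test run.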
- If `torchviz` fails due to Graphviz not installed, you can install system packages in Colab (`apt-get install graphviz`) or rely on the textual `torchsummary` output.
