## Tutorial 3. CARN on Super-Resolution

In this tutorial, we show:

- How to train and structurally prune a CARN end-to-end from scratch on DIV2K to obtain a compressed CARN.
- In this specific run, the compressed model (obtained via the pruning mode of OTO) reduces FLOPs and parameters by 81.6% and 82%, respectively.
- The PSNR values on Set14, B100, and Urban100 are 33.05, 31.81, and 30.79, respectively.

### Step 1. Create OTO instance

In [1]:
import sys
sys.path.append('..')
from sanity_check.backends import CarnNet
from only_train_once import OTO
import torch

scale = 2
model = CarnNet(scale=scale, multi_scale=False, group=1)
dummy_input = torch.rand(1, 3, 224, 224)

oto = OTO(model.cuda(), (dummy_input.cuda(), scale))

OTO graph constructor
graph build


#### (Optional) Visualize the pruning dependency graph of the DNN

Setting `display_params=True` displays the parameters and their shapes on each node.

In [2]:
oto.visualize(view=False, out_dir='./cache')

### Step 2. Mark the second-to-last conv operator as unprunable

Including this conv in pruning was observed to cause trouble with the current saliency score calculation.

A node group can be marked as unprunable either by `node_ids` or by `param_names`.

In [3]:
# Node ids may differ across torch versions; use the visualization tool to locate the right one
# oto.mark_unprunable_by_node_ids(['node-158']) 

# Or use the parameter name to locate the node group and mark it as unprunable
oto.mark_unprunable_by_param_names(['exit.weight']) 

# Check the pruning dependency graph after `mark_unprunable`
oto.visualize(view=False, out_dir='./cache', display_params=True)

### Step 3. Dataset Preparation

Follow the instructions at https://github.com/nmhkahn/CARN-pytorch/tree/master?tab=readme-ov-file to prepare the training and validation datasets.

In [4]:
# Download training dataset
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip -P ./data/carn_sr
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip -P ./data/carn_sr
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip -P ./data/carn_sr
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip -P ./data/carn_sr

!unzip  -q './data/carn_sr/*.zip' -d ./data/carn_sr


#### To accelerate training, first convert the training images to h5 format as follows (the `h5py` module must be installed).

In [5]:
import os
import glob
import h5py
import cv2
import numpy as np
from tqdm import tqdm 

dataset_dir = "./data/carn_sr/"
dataset_type = "train"

f = h5py.File(os.path.join(dataset_dir, "DIV2K_{}.h5".format(dataset_type)), "w")
dt = h5py.special_dtype(vlen=np.dtype('uint8'))

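# Only the X2 archive is downloaded above, so the X3 and X4 groups stay empty unless those archives are added.
for subdir in ["HR", "X2", "X3", "X4"]: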
    if subdir in ["HR"]:
        im_paths = glob.glob(os.path.join(dataset_dir, 
                                          "DIV2K_{}_HR".format(dataset_type), 
                                          "*.png"))

    else:
        im_paths = glob.glob(os.path.join(dataset_dir, 
                                          "DIV2K_{}_LR_bicubic".format(dataset_type), 
                                          subdir, "*.png"))
    im_paths.sort()
    grp = f.create_group(subdir)

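    for i, path in enumerate(im_paths):
        # Note: cv2.imread returns channels in BGR order, whereas TestDataset below reads RGB via PIL.
        im = cv2.imread(path)
        print(path)
        grp.create_dataset(str(i), data=im)

# Close the file so that everything is flushed to disk before TrainDataset reopens it.
f.close()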

./data/carn_sr/DIV2K_train_HR/0001.png
./data/carn_sr/DIV2K_train_HR/0002.png
./data/carn_sr/DIV2K_train_HR/0003.png
./data/carn_sr/DIV2K_train_HR/0004.png
......


In [6]:
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms

def random_crop(hr, lr, size, scale):
    h, w = lr.shape[:-1]
    x = random.randint(0, w-size)
    y = random.randint(0, h-size)

    hsize = size*scale
    hx, hy = x*scale, y*scale

    crop_lr = lr[y:y+size, x:x+size].copy()
    crop_hr = hr[hy:hy+hsize, hx:hx+hsize].copy()

    return crop_hr, crop_lr


def random_flip_and_rotate(im1, im2):
    if random.random() < 0.5:
        im1 = np.flipud(im1)
        im2 = np.flipud(im2)

    if random.random() < 0.5:
        im1 = np.fliplr(im1)
        im2 = np.fliplr(im2)

    angle = random.choice([0, 1, 2, 3])
    im1 = np.rot90(im1, angle)
    im2 = np.rot90(im2, angle)

    # flips and rotations return views with negative strides; copy before the transform is applied
    return im1.copy(), im2.copy()


class TrainDataset(data.Dataset):
    def __init__(self, path, size, scale):
        super(TrainDataset, self).__init__()

        self.size = size
        h5f = h5py.File(path, "r")
        
        self.hr = [v[:] for v in h5f["HR"].values()]
        # perform multi-scale training
        if scale == 0:
            self.scale = [2, 3, 4]
            self.lr = [[v[:] for v in h5f["X{}".format(i)].values()] for i in self.scale]
        else:
            self.scale = [scale]
            self.lr = [[v[:] for v in h5f["X{}".format(scale)].values()]]
        
        h5f.close()

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        size = self.size

        item = [(self.hr[index], self.lr[i][index]) for i, _ in enumerate(self.lr)]
        item = [random_crop(hr, lr, size, self.scale[i]) for i, (hr, lr) in enumerate(item)]
        item = [random_flip_and_rotate(hr, lr) for hr, lr in item]
        
        return [(self.transform(hr), self.transform(lr)) for hr, lr in item]

    def __len__(self):
        return len(self.hr)
        

class TestDataset(data.Dataset):
    def __init__(self, dirname, scale):
        super(TestDataset, self).__init__()

        self.name  = dirname.split("/")[-1]
        self.scale = scale
        
        if "DIV" in self.name:
            self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
            self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname), 
                                             "X{}/*.png".format(scale)))
        else:
            all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
            self.hr = [name for name in all_files if "HR" in name]
            self.lr = [name for name in all_files if "LR" in name]

        self.hr.sort()
        self.lr.sort()

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        hr = Image.open(self.hr[index])
        lr = Image.open(self.lr[index])

        hr = hr.convert("RGB")
        lr = lr.convert("RGB")
        filename = self.hr[index].split("/")[-1]

        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        return len(self.hr)
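
The `random_crop` helper above keeps the LR and HR patches aligned by multiplying the LR crop origin and size by `scale`. A minimal pure-Python sketch of that index arithmetic follows; `aligned_crop_boxes` is a hypothetical helper introduced only for illustration and is not part of the tutorial code.

```python
import random

def aligned_crop_boxes(lr_h, lr_w, size, scale, rng=random):
    """Return (top, bottom, left, right) boxes for an LR crop and the matching HR crop."""
    # Pick the top-left corner so that the crop stays inside the LR image.
    x = rng.randint(0, lr_w - size)
    y = rng.randint(0, lr_h - size)
    lr_box = (y, y + size, x, x + size)
    # The HR image is `scale` times larger, so every LR coordinate is multiplied by `scale`.
    hr_box = tuple(v * scale for v in lr_box)
    return lr_box, hr_box

lr_box, hr_box = aligned_crop_boxes(lr_h=100, lr_w=120, size=64, scale=2)
print("LR box:", lr_box, "HR box:", hr_box)
```

With `size=64` and `scale=2`, as used in the next cell, each sample is a 64x64 LR patch paired with a 128x128 HR patch.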

In [7]:
from torch.utils.data import Dataset, DataLoader

train_data = TrainDataset('./data/carn_sr/DIV2K_train.h5', scale=2, size=64)
train_loader = DataLoader(train_data, batch_size=64, num_workers=4, shuffle=True, drop_last=True)
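
With `drop_last=True`, each pass over the loader yields `len(dataset) // batch_size` batches. The sketch below works through that arithmetic, assuming the standard 800-image DIV2K training split and the 600000-step budget used in the training cell later in this tutorial.

```python
import math

num_train_images = 800   # size of the DIV2K training split (assumption stated above)
batch_size = 64
max_step = 600000        # step budget used in the training cell

# drop_last=True discards the final incomplete batch of each epoch.
steps_per_epoch = num_train_images // batch_size
epochs_needed = math.ceil(max_step / steps_per_epoch)

print("steps per epoch:", steps_per_epoch)
print("epochs needed  :", epochs_needed)
```

This is why the training cell wraps the loader in an outer `while True` loop and counts steps rather than epochs.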

#### Download the test datasets following the [CARN repo](https://github.com/nmhkahn/CARN-pytorch)


For this tutorial, we place them in `./data/SR_benchmark`, which includes `B100`, `Set5`, `Set14`, and `Urban100`.


### Step 4. Set up the HESSO optimizer

The following main hyperparameters need attention.

- `variant`: The optimizer used for training the baseline full model. Currently supported: `sgd`, `adam`, and `adamw`.
- `lr`: The initial learning rate.
- `weight_decay`: Weight decay, as in standard DNN optimization.
- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields larger reductions in FLOPs and model size, but may degrade model performance more.
- `start_pruning_step`: The step at which pruning **starts**.
- `pruning_steps`: The number of steps after `start_pruning_step` over which pruning **finishes** (reaches `target_group_sparsity`).
- `pruning_periods`: The number of periods over which the group sparsity is increased incrementally and equally.

We empirically suggest setting `start_pruning_step` to 1/10 of the total number of training steps, and choosing `pruning_steps` so that pruning finishes by 1/5 to 1/4 of the total number of training steps.
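
As a rough sketch of the rule of thumb above, the snippet below derives a pruning schedule from the total step budget. The helper name `suggest_pruning_schedule` and the assumption that the periods are evenly spaced are illustrative only and are not taken from the OTO library.

```python
def suggest_pruning_schedule(max_step, pruning_periods, start_div=10, end_div=5):
    """Derive the start step, pruning length, and period boundaries from a step budget."""
    start = max_step // start_div          # pruning starts at 1/10 of training
    end = max_step // end_div              # pruning finishes by 1/5 of training
    pruning_steps = end - start
    period_len = pruning_steps // pruning_periods
    # Step at which each period ends, assuming evenly spaced periods.
    boundaries = [start + period_len * (i + 1) for i in range(pruning_periods)]
    return start, pruning_steps, boundaries

start, pruning_steps, boundaries = suggest_pruning_schedule(max_step=600000, pruning_periods=10)
print("start step   :", start)
print("pruning steps:", pruning_steps)
print("period ends  :", boundaries)
```

For a 600000-step run this gives a start step of 60000 and 60000 pruning steps, which matches the values passed to `oto.hesso` in the next cell.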

In [8]:
optimizer = oto.hesso(
    variant='adam', 
    lr=1e-4, 
    target_group_sparsity=0.6,
    start_pruning_step=60000,
    pruning_periods=10,
    pruning_steps=60000,
    importance_score_criteria='default'
)

Setup HESSO
Target redundant groups per period:  [84, 84, 84, 84, 84, 84, 84, 84, 84, 88]
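
The per-period targets printed above are consistent with splitting the total number of groups to prune evenly across the periods and assigning the remainder to the last one. The sketch below reproduces the list under that assumption, using the 1408 groups reported in the training log; `groups_per_period` is a hypothetical helper and not the library's actual implementation.

```python
def groups_per_period(num_groups, target_group_sparsity, pruning_periods):
    """Split the number of groups to prune evenly, giving the remainder to the last period."""
    total_to_prune = int(num_groups * target_group_sparsity)
    base = total_to_prune // pruning_periods
    targets = [base] * pruning_periods
    targets[-1] += total_to_prune - base * pruning_periods
    return targets

targets = groups_per_period(num_groups=1408, target_group_sparsity=0.6, pruning_periods=10)
print(targets, "total:", sum(targets))
```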


In [9]:
from skimage.metrics import peak_signal_noise_ratio
from skimage.color import rgb2ycbcr
import math

# helpers
import matplotlib.pyplot as plt
%matplotlib inline

from pylab import rcParams
rcParams['figure.figsize'] = 15, 10

def psnr(im1, im2):
    def im2double(im):
        min_val, max_val = 0, 255
        out = (im.astype(np.float64)-min_val) / (max_val-min_val)
        return out

    im1 = im2double(im1)
    im2 = im2double(im2)
    psnr = peak_signal_noise_ratio(im1, im2, data_range=1)
    return psnr

def display1(img):
    plt.imshow(img, interpolation="nearest")
    plt.show()

def evaluate(model, test_data_dir, scale=2):
    shave = 20
    mean_psnr = 0
    model.eval()
    
    test_data   = TestDataset(test_data_dir, scale=scale)
    test_loader = DataLoader(test_data,
                             batch_size=1,
                             num_workers=1,
                             shuffle=False)

    for step, inputs in enumerate(test_loader):
        hr = inputs[0].squeeze(0)
        lr = inputs[1].squeeze(0)
        name = inputs[2][0]

        h, w = lr.size()[1:]
        h_half, w_half = int(h/2), int(w/2)
        h_chop, w_chop = h_half + shave, w_half + shave

        # split a large image into 4 overlapping patches to avoid OOM errors
        lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop)
        lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
        lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
        lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
        lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
        lr_patch = lr_patch.cuda()
        
        # run the model on the four patches without tracking gradients
        with torch.no_grad():
            sr = model(lr_patch, scale)
        
        h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
        w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
        
        # merge the split patch outputs back into one image
        result = torch.FloatTensor(3, h, w).cuda()
        result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
        result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
        result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
        result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
        sr = result

        hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
        sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
        
        # evaluate PSNR
        bnd = scale
        im1 = hr[bnd:-bnd, bnd:-bnd]
        im2 = sr[bnd:-bnd, bnd:-bnd]

        # evaluate on the Y channel, following an open reproduction issue in the CARN repo
        im1_y = rgb2ycbcr(im1)[..., 0]
        im2_y = rgb2ycbcr(im2)[..., 0]

        mean_psnr += psnr(im1_y, im2_y) / len(test_data)
    return mean_psnr
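
The `psnr` helper above delegates to scikit-image. For reference, PSNR for signals with a given `data_range` is `10 * log10(data_range**2 / MSE)`. Below is a dependency-free sketch on flat lists of pixel values; the function name `psnr_from_lists` is hypothetical and the toy pixel values are made up for illustration.

```python
import math

def psnr_from_lists(ref, test, data_range=1.0):
    """PSNR in dB between two equally long sequences of pixel values."""
    if len(ref) != len(test):
        raise ValueError("inputs must have the same length")
    mse = sum((a - b) ** 2 for a, b in zip(ref, test)) / len(ref)
    if mse == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(data_range ** 2 / mse)

# Toy example: 8-bit values rescaled to [0, 1], as im2double does above.
ref = [v / 255 for v in (0, 128, 255)]
test = [v / 255 for v in (0, 127, 255)]
print(round(psnr_from_lists(ref, test), 2), "dB")
```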

### Step 5. Train and prune via OTO HESSO

In [10]:
# max_step and lr_decay_step are the same as in the official CARN repo.
max_step = 600000
lr_decay_step = 400000 
print_interval = 1000
loss_fn = torch.nn.L1Loss()
step = 0
f_avg_val = 0.0
base_lr = optimizer.get_learning_rate()

while True:
    for inputs in train_loader:
        model.train()
        hr, lr = inputs[-1][0], inputs[-1][1]
        hr, lr = hr.cuda(), lr.cuda()
        sr = model(lr, scale)
        loss = loss_fn(sr, hr)
        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated clip_grad_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
        optimizer.step()

        # Step decay from the base rate; rescaling the already-decayed rate would compound the decay on every step.
        learning_rate = base_lr * (0.5 ** (step // lr_decay_step))
        optimizer.set_learning_rate(learning_rate)
        f_avg_val += loss.item()  
        if step % print_interval == 0:
            group_sparsity, param_norm, _ = optimizer.compute_group_sparsity_param_norm()
            norm_important, norm_redundant, num_grps_important, num_grps_redundant = optimizer.compute_norm_groups()
            print("Step: {step}, loss: {f:.4f}, norm_all:{param_norm:.2f}, grp_sparsity: {gs:.2f}, norm_import: {norm_import:.2f}, norm_redund: {norm_redund:.2f}, num_grp_import: {num_grps_import}, num_grp_redund: {num_grps_redund}"\
                 .format(step=step, f=f_avg_val/print_interval, param_norm=param_norm, gs=group_sparsity, norm_import=norm_important, \
                         norm_redund=norm_redundant, num_grps_import=num_grps_important, num_grps_redund=num_grps_redundant
                ))
            f_avg_val = 0.0
            psnr_B100 = evaluate(model, './data/SR_benchmark/B100', scale=2)
            psnr_Set14 = evaluate(model, './data/SR_benchmark/Set14', scale=2)
            psnr_Urban100 = evaluate(model, './data/SR_benchmark/Urban100', scale=2)
            print("Val PSNR: B100: {p_b:.4f}, Set14: {p_s:.4f}, Urban100: {p_u:.4f}.".format(p_b=psnr_B100, p_s=psnr_Set14, p_u=psnr_Urban100))
        step += 1
    if step > max_step: 
        break

Step: 0, loss: 0.0002, norm_all:952.65, grp_sparsity: 0.00, norm_import: 952.65, norm_redund: 0.00, num_grp_import: 1408, num_grp_redund: 0
Val PSNR: B100: 12.6870, Set14: 11.7356, Urban100: 11.0549.
Step: 1000, loss: 0.0433, norm_all:997.35, grp_sparsity: 0.00, norm_import: 997.35, norm_redund: 0.00, num_grp_import: 1408, num_grp_redund: 0
Val PSNR: B100: 27.9681, Set14: 27.8497, Urban100: 25.0936.
Step: 2000, loss: 0.0204, norm_all:1019.24, grp_sparsity: 0.00, norm_import: 1019.24, norm_redund: 0.00, num_grp_import: 1408, num_grp_redund: 0
Val PSNR: B100: 28.9096, Set14: 28.8339, Urban100: 26.0179.
Step: 3000, loss: 0.0188, norm_all:1030.98, grp_sparsity: 0.00, norm_import: 1030.98, norm_redund: 0.00, num_grp_import: 1408, num_grp_redund: 0
Val PSNR: B100: 29.0512, Set14: 29.0506, Urban100: 26.1827.
Step: 4000, loss: 0.0182, norm_all:1040.63, grp_sparsity: 0.00, norm_import: 1040.63, norm_redund: 0.00, num_grp_import: 1408, num_grp_redund: 0
Val PSNR: B100: 29.1222, Set14: 29.2128, U
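
A step decay halves the learning rate once every `lr_decay_step` steps. The sketch below shows such a schedule, assuming the decay is always computed from the base rate rather than from the previously decayed value; the helper name `step_decay_lr` is illustrative.

```python
def step_decay_lr(base_lr, step, lr_decay_step, gamma=0.5):
    """Learning rate after `step` steps, multiplied by `gamma` every `lr_decay_step` steps."""
    return base_lr * gamma ** (step // lr_decay_step)

# Values used in this tutorial: base rate 1e-4, decay every 400000 steps.
for s in (0, 399999, 400000, 600000):
    print(s, step_decay_lr(1e-4, s, 400000))
```

With these settings the rate stays at the base value for the first 400000 steps and is halved once for the remainder of the 600000-step run.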

### (Optional) Get FLOPs and number of parameters for the full model

This must be executed before `oto.construct_subnet()`. Otherwise, the API computes the numbers for the pruned model.

`oto.compute_flops()` returns a dictionary with the FLOPs of each node group; the `total` key gives the total FLOPs of the whole DNN.

The `in_million` and `in_billion` arguments scale the numbers to units of millions or billions.

In [11]:
full_flops = oto.compute_flops(in_million=True)['total'] 
full_num_params = oto.compute_num_params(in_million=True)

print("Full FLOPs(M)", full_flops, "full Number of parameters (M)", full_num_params)

Full FLOPs(M) 48527.66822400001 full Number of parameters (M) 0.964187


### Step 6. Get pruned model in torch format

In [12]:
# By default, OTO constructs the subnet from the last checkpoint. If an intermediate checkpoint reaches the best performance,
# reinitialize the OTO instance from it:
# oto = OTO(torch.load(ckpt_path), dummy_input)
# then construct the subnetwork
oto.construct_subnet(out_dir='./cache')

### (Optional) Check the compressed model size

In [13]:
full_model_size = os.stat(oto.full_group_sparse_model_path)
compressed_model_size = os.stat(oto.compressed_model_path)
print("Size of full model     : ", full_model_size.st_size / (1024 ** 3), "GBs")
print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs")

Size of full model     :  0.0036377040669322014 GBs
Size of compress model :  0.0006903214380145073 GBs
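
The sizes above are expressed in units of `1024 ** 3` bytes, which is hard to read for a model this small. The short sketch below converts the two reported values to units of `1024 ** 2` bytes (MiB) and computes the relative size reduction; the numbers are copied from the output above.

```python
full_gib = 0.0036377040669322014        # reported size of the full model
compressed_gib = 0.0006903214380145073  # reported size of the compressed model

full_mib = full_gib * 1024
compressed_mib = compressed_gib * 1024
size_reduction = 100 * (1 - compressed_gib / full_gib)

print(f"full: {full_mib:.2f} MiB, compressed: {compressed_mib:.2f} MiB, "
      f"reduction: {size_reduction:.1f}%")
```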


### (Optional) Get FLOPs and number of parameters for the compressed model

In [14]:
pruned_flops = oto.compute_flops(in_million=True)['total'] 
pruned_num_params = oto.compute_num_params(in_million=True)

print("Pruned FLOPs(M)", pruned_flops, "pruned Number of parameters (M)", pruned_num_params)

Pruned FLOPs(M) 8913.716224, pruned Number of parameters (M) 0.173537
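
The reduction percentages quoted at the top of this tutorial can be recomputed from the full and pruned numbers printed above. A small sketch of that arithmetic follows; `reduction_pct` is a helper introduced here for illustration.

```python
def reduction_pct(full, pruned):
    """Relative reduction from `full` to `pruned`, in percent."""
    return 100 * (1 - pruned / full)

# Values copied from the outputs of the two FLOPs/parameter cells.
full_flops_m, pruned_flops_m = 48527.66822400001, 8913.716224
full_params_m, pruned_params_m = 0.964187, 0.173537

flops_reduction = reduction_pct(full_flops_m, pruned_flops_m)
params_reduction = reduction_pct(full_params_m, pruned_params_m)

print(f"FLOPs reduction : {flops_reduction:.1f}%")
print(f"Params reduction: {params_reduction:.1f}%")
```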


### (Optional) Check the difference between the pruned model and the full model
#### Both the full and pruned models should return the same output, up to floating-point rounding, given the same input.

In [15]:
full_model = torch.load(oto.full_group_sparse_model_path).cpu()
compressed_model = torch.load(oto.compressed_model_path).cpu()

full_output = full_model(dummy_input, scale)
compressed_output = compressed_model(dummy_input, scale)

max_output_diff = torch.max(torch.abs(full_output - compressed_output))
print("Maximum output difference ", str(max_output_diff.item()))

Maximum output difference  5.960464477539062e-07
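
A difference of this size is what floating-point rounding would produce rather than a sign of a structural mismatch: the reported value is about five times the float32 machine epsilon (`2 ** -23`). The sketch below expresses the difference in those units and applies a tolerance check, assuming float32 inference; the tolerance of `1e-5` is an arbitrary illustrative choice, not a value from the tutorial.

```python
FLOAT32_EPS = 2.0 ** -23                  # gap between 1.0 and the next float32 value
max_output_diff = 5.960464477539062e-07   # value reported above

# Express the difference as a multiple of the float32 machine epsilon.
eps_multiples = max_output_diff / FLOAT32_EPS
print(f"difference is about {eps_multiples:.1f} float32 epsilons")

tolerance = 1e-5                          # illustrative threshold (assumption)
outputs_match = max_output_diff < tolerance
print("outputs match within tolerance:", outputs_match)
```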
