In [None]:
# Show the GPU status reported by the nvidia-smi command-line tool (IPython shell escape).
!nvidia-smi

In [None]:
import os
# Expose only the GPU with index 2 to this process. This cell runs before the
# cell that imports torch, so the setting is in place when CUDA is first used.
# NOTE(review): the index is machine-specific -- adjust it for your host.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [None]:
import os

In [None]:
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as T
from PIL import Image
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm
import time

In [None]:
from model_RestormerBA import Restormer
from utils_restormerba import parse_args, BlurDataset, rgb_to_y, psnr, ssim

In [None]:
# ---- Experiment configuration (read as globals by the cells below) ----
# data_path / data_name are passed to BlurDataset; data_name also names the output files.
data_path = 'data'
data_name = 'GOPRO' # other values used so far: 'HIDE_dataset_restormer', 'GoPro_NU_restormer'
# Root folder for restored images, the statistics CSV, the best-score text file and checkpoints.
save_paths = 'result'
# Architecture arguments forwarded to Restormer(num_blocks, num_heads, channels,
# num_refinement, expansion_factor); their exact meaning is defined in model_RestormerBA.
num_blocks = [4, 6, 6, 8]
num_heads = [1, 2, 4, 8]
channels = [48, 96, 192, 384]
expansion_factor = 2.66
num_refinement = 4
# Total number of training iterations (previous runs: 20000, 300000).
# NOTE: checkpoints are written every 1000 iterations and evaluation runs every
# 10000 iterations, so with num_iter = 1 neither is ever triggered.
num_iter = 1 #20000 #300000
# Progressive-learning schedule: stage i trains with batch_size[i] and patch_size[i].
batch_size = [8, 5, 4, 2, 1, 1] # alternative: [16, 10, 8, 4, 2, 2]
patch_size = [128, 160, 192, 256, 320, 384]
# Learning rate handed to AdamW.
lr = 0.0003 # 0.0003, or 0.0001 for the modified model
# Iteration counts after which training advances to the next progressive stage.
milestone = [3000, 5200, 6800, 8000, 9200]
# NOTE(review): seed is not referenced by any visible cell, so no seeding happens.
seed = -1 #no manual seed
# model_file = None        -> train a new model.
# model_file = '<path>.pth' -> load those weights and run a single evaluation pass.
model_file = None # keep this line active to train (and leave the line below commented)
# model_file = 'result/GOPRO_mod3_10k_bam.pth' # uncomment to test a saved model (and comment the line above)

In [None]:
def test_loop(net, data_loader, num_iter):
    """Evaluate ``net`` on ``data_loader`` and save every restored image.

    Each restored image is written to ``{save_paths}/{data_name}/{name}`` and
    PSNR/SSIM are computed on the output of ``rgb_to_y`` for the 8-bit
    quantised prediction and ground truth.

    Args:
        net: Network to evaluate. It is switched to eval mode and called as
            ``net(blur)``.
        data_loader: Iterable yielding ``(blur, sharp, name, h, w)``. The code
            assumes a batch size of 1 (it saves ``name[0]`` and squeezes the
            batch dimension) -- TODO confirm against ``BlurDataset``.
        num_iter: Iteration number; only used in the progress-bar caption.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` averaged over all batches.

    Raises:
        ValueError: If ``data_loader`` yields no batches (previously this
            surfaced as an opaque ``ZeroDivisionError``).
    """
    net.eval()
    total_psnr, total_ssim, count = 0.0, 0.0, 0
    with torch.no_grad():
        test_bar = tqdm(data_loader, initial=1, dynamic_ncols=True)
        for blur, sharp, name, h, w in test_bar:
            blur, sharp = blur.cuda(), sharp.cuda()
            # Use the network that was passed in (not the global ``model``).
            # Crop to h x w, then quantise to uint8; clamping to [0, 1] before
            # scaling already bounds the result to [0, 255].
            out = torch.clamp(net(blur)[:, :, :h, :w], 0, 1).mul(255).byte()
            sharp = torch.clamp(sharp[:, :, :h, :w].mul(255), 0, 255).byte()
            y, gt = rgb_to_y(out.double()), rgb_to_y(sharp.double())
            current_psnr, current_ssim = psnr(y, gt), ssim(y, gt)
            total_psnr += current_psnr.item()
            total_ssim += current_ssim.item()
            count += 1
            save_path = '{}/{}/{}'.format(save_paths, data_name, name[0])
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            # (1, C, H, W) tensor -> (H, W, C) array as expected by PIL.
            Image.fromarray(out.squeeze(dim=0).permute(1, 2, 0).contiguous().cpu().numpy()).save(save_path)
            test_bar.set_description('Test Iter: [{}/{}] PSNR: {:.2f} SSIM: {:.3f}'
                                     .format(num_iter, 1 if model_file else num_iter,
                                             total_psnr / count, total_ssim / count))
    if count == 0:
        raise ValueError('test_loop: data_loader produced no batches to evaluate')
    return total_psnr / count, total_ssim / count

In [None]:
def save_loop(net, data_loader, num_iter):
    """Evaluate ``net``, record the scores, and checkpoint it if it is the best so far.

    Appends the formatted PSNR/SSIM to the global ``results`` dict, rewrites
    the statistics CSV, and -- when both PSNR and SSIM beat the global
    ``best_psnr`` / ``best_ssim`` -- updates those globals, writes the best
    scores to ``{save_paths}/{data_name}.txt``, appends them to
    ``training_log_restormerba.txt`` and saves ``net``'s weights.

    Args:
        net: Network to evaluate and (possibly) checkpoint.
        data_loader: Evaluation loader, forwarded to ``test_loop``.
        num_iter: Current iteration number; used for captions and log text.
    """
    global best_psnr, best_ssim
    val_psnr, val_ssim = test_loop(net, data_loader, num_iter)
    results['PSNR'].append('{:.2f}'.format(val_psnr))
    results['SSIM'].append('{:.3f}'.format(val_ssim))
    # save statistics
    # One row per recorded evaluation. Sizing the index from the results
    # themselves keeps it in step with the lists regardless of how often the
    # caller evaluates (an index derived from num_iter // 1000 does not match
    # when evaluation runs every 10000 iterations, and pandas then raises).
    num_rows = len(results['PSNR'])
    data_frame = pd.DataFrame(data=results, index=range(1, num_rows + 1))
    data_frame.to_csv('{}/{}_restormerba.csv'.format(save_paths, data_name), index_label='Iter', float_format='%.3f')
    if val_psnr > best_psnr and val_ssim > best_ssim:
        best_psnr, best_ssim = val_psnr, val_ssim
        with open('{}/{}.txt'.format(save_paths, data_name), 'w') as f:
            f.write('Iter: {} PSNR:{:.2f} SSIM:{:.3f}'.format(num_iter, best_psnr, best_ssim))

        # Best Epoch PSNR and SSIM
        with open("training_log_restormerba.txt", "a+") as log_file:
            log_file.write("Training epoch: {}\n".format(num_iter))
            log_file.write("PSNR: {}\n".format(best_psnr))
            log_file.write("SSIM: {}\n".format(best_ssim))

        # Checkpoint the network that was actually evaluated (not the global ``model``).
        torch.save(net.state_dict(), '{}/{}_restormerba.pth'.format(save_paths, data_name))

In [None]:
def _log_timestamp(phase, n_iter, log_path="training_log_restormer_mod3_10k_bam2.txt"):
    """Print the current wall-clock time and append it to the text log.

    Args:
        phase: Label written before ``" epoch: "`` (e.g. ``"Training"``).
        n_iter: Iteration number recorded in the log entry.
        log_path: Log file to append to.
    """
    curr_time = time.strftime("%H:%M:%S", time.localtime())
    print("Current Time is :", curr_time)
    # The context manager closes the file even if a write fails.
    with open(log_path, "a+") as log_file:
        log_file.write("{} epoch: {}\n".format(phase, n_iter))
        log_file.write("Time: {}".format(curr_time))
        log_file.write("\n\n")


if __name__ == '__main__':
    test_dataset = BlurDataset(data_path, data_name, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    results, best_psnr, best_ssim = {'PSNR': [], 'SSIM': []}, 0.0, 0.0
    model = Restormer(num_blocks, num_heads, channels, num_refinement, expansion_factor).cuda()
    if model_file:
        # Test mode: load the saved weights and run one evaluation pass.
        model.load_state_dict(torch.load(model_file))
        save_loop(model, test_loader, 1)
    else:
        # Train mode.
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-6)
        # ``i`` is the index of the current progressive-learning stage.
        total_loss, total_num, results['Loss'], i = 0.0, 0, [], 0
        train_bar = tqdm(range(1, num_iter + 1), initial=1, dynamic_ncols=True)
        for n_iter in train_bar:
            # progressive learning: rebuild the loader at the first iteration
            # and right after each milestone, using that stage's patch and
            # batch size.
            if n_iter == 1 or n_iter - 1 in milestone:
                end_iter = milestone[i] if i < len(milestone) else num_iter
                start_iter = milestone[i - 1] if i > 0 else 0
                # Enough samples for one batch per iteration of this stage.
                length = batch_size[i] * (end_iter - start_iter)
                train_dataset = BlurDataset(data_path, data_name, 'train', patch_size[i], length)
                train_loader = iter(DataLoader(train_dataset, batch_size=batch_size[i], shuffle=True))
                i += 1
            # train (model.train() is re-applied each step because evaluation
            # switches the network to eval mode)
            model.train()
            blur, sharp, name, h, w = next(train_loader)
            blur, sharp = blur.cuda(), sharp.cuda()
            out = model(blur)
            loss = F.l1_loss(out, sharp)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            # Running, sample-weighted mean of the loss over the whole run.
            total_num += blur.size(0)
            total_loss += loss.item() * blur.size(0)
            train_bar.set_description('Train Iter: [{}/{}] Loss: {:.3f}'
                                      .format(n_iter, num_iter, total_loss / total_num))

            lr_scheduler.step()
            if n_iter % 1000 == 0:
                # log training per 1k iterations into txt file and checkpoint
                _log_timestamp("Training", n_iter)
                torch.save(model.state_dict(), '{}/{}_mod3_10k_bam2.pth'.format(save_paths, data_name))

            if n_iter % 10000 == 0:
                results['Loss'].append('{:.3f}'.format(total_loss / total_num))
                save_loop(model, test_loader, n_iter)
                # log testing per 10k iterations into txt file
                _log_timestamp("Testing", n_iter)