# ResCSNet
Inspired by ConvCSNet. Can I just port the ResNet Multibranch structure in?

# About this notebook
This notebook is intended to run in [Google Colaboratory](https://colab.research.google.com). It may require a lot of changes (mostly deletions) if you want to run it on your local machine.

In order to view TensorBoard plots from the Colab VM during training, I have applied some dirty hacks (using `frpc` and a remote VPS running `frps`). I am also using the `pydrive` module to download the dataset from my Google Drive and to upload the model checkpoint at the end of each epoch.

**Update**: TensorBoard 2.0 adds an inline TensorBoard magic for Jupyter notebooks. It is recommended that you open **another** notebook that shares the same VM with this notebook and run the following:
```
!pip install -q tf-nightly-2.0-preview
# Load the TensorBoard notebook extension
%load_ext tensorboard
%tensorboard --logdir runs
```
This way you don't have to bother with a VPS or `frpc` at all.

# About the dataset
I load the pictures from a copy of the COCO dataset that I downloaded, converted to grayscale, and center-cropped myself. You can download it from https://drive.google.com/open?id=12Nje-yhxcIVyz7L_lVxxcfXcTWa5Ba-m

Some of the code below tries to download it from (your) Google Drive, so it is easiest to upload the archive to your own Google Drive first.

# Install necessary packages and download dataset

In [0]:
!pip install -U -q tensorboardX
!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# List .gz files in the root.
#
# Search query reference:
# https://developers.google.com/drive/v2/web/search-parameters
listed = drive.ListFile({'q': "title contains '.gz' and 'root' in parents"}).GetList()
for file in listed:
  print('title {}, id {}'.format(file['title'], file['id']))

downloaded = drive.CreateFile({'id': "12Nje-yhxcIVyz7L_lVxxcfXcTWa5Ba-m"})
downloaded.GetContentFile("center-crop-100.tar.gz")

# Extract the dataset
print("Extract dataset")
!tar -xzf "center-crop-100.tar.gz"


from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
print("Authenticate and create the PyDrive client")
auth.authenticate_user()
_gauth = GoogleAuth()
_gauth.credentials = GoogleCredentials.get_application_default()
# drive = GoogleDrive(_gauth)

# Create a file instance to upload checkpoint later
# gfile_ckpt = drive.CreateFile()

# pytorch-wavelets-1.0.0 is known to be OK. Higher versions may work as well
!git clone "https://github.com/fbcotter/pytorch_wavelets.git"
!pip install ./pytorch_wavelets
print("Done")

# Download existing parameters
Needed if you want to resume training or test with existing parameters.

Some pretrained models:
- [ResCsNet-colab-5_2_1-r0.25_checkpoint.pth](https://drive.google.com/open?id=1QQJJ3c9SlMK03v_v28J_K0ymvvA2vbhG)
- [ResCsNet-colab-5_2_1-r0.20_checkpoint.pth](https://drive.google.com/open?id=12g8reeeyi4Ei8v9dZudS_jsRE8ImYuq5)
- [ResCsNet-colab-5_2_1-r0.15_checkpoint.pth](https://drive.google.com/open?id=1J2tOS02BJVBWwOT-p2SDwpKRuociEKPO)
- [ResCsNet-colab-5_2_1-r0.10_checkpoint.pth](https://drive.google.com/open?id=1oLLB45GWfGxhdxYqNoYPQRcNX5_v9XFE)
- [ResCsNet-colab-5_2_1-r0.05_checkpoint.pth](https://drive.google.com/open?id=1xJbTgKgjqUwyQJn8gxvWy7rJ5N-9maWu)

In [0]:
drive = GoogleDrive(gauth)
# download trained params
listed = drive.ListFile({'q': "title contains 'ResCsNet-colab-5_2_1-r0.25_checkpoint.pth' and trashed=False"}).GetList()
for file in listed:
  print('title {}, id {}'.format(file['title'], file['id']))
# listed = drive.ListFile({'q': "title contains 'ResCsNet-colab-5_2_1-r0.20_checkpoint.pth' and trashed=False"}).GetList()
# for file in listed:
#   print('title {}, id {}'.format(file['title'], file['id']))

In [0]:
drive = GoogleDrive(gauth)
downloaded2 = drive.CreateFile({'id': "1QQJJ3c9SlMK03v_v28J_K0ymvvA2vbhG"})
downloaded2.GetContentFile("ResCsNet-colab-5_2_1-r0.25_checkpoint.pth")
# downloaded3 = drive.CreateFile({'id': "12g8reeeyi4Ei8v9dZudS_jsRE8ImYuq5"})
# downloaded3.GetContentFile("ResCsNet-colab-5_2_1-r0.20_checkpoint.pth")

-----------------------------------------------------------------

# Runtime configurations
This section mostly covers hacks and tricks. If you are using TensorFlow 2.0 with inline TensorBoard, you probably do not need to run the cells in this section.

## nvidia-smi

In [0]:
!nvidia-smi

## tensorboard

In [0]:
!tar -xjvf runs.tbz

In [0]:
get_ipython().system_raw("tensorboard --logdir runs --host 127.0.0.1 &")
!ps -ef | grep tensorboard

## python http.server

In [0]:
get_ipython().system_raw("python3 -m http.server 8000 --bind 127.0.0.1 &")
!ps -ef | grep http

## frpc

In [0]:
!wget "https://github.com/fatedier/frp/releases/download/v0.24.1/frp_0.24.1_linux_amd64.tar.gz"
!mkdir frp
!tar -xvf "frp_0.24.1_linux_amd64.tar.gz" -C frp

In [0]:
get_ipython().system_raw("./frp/frp_0.24.1_linux_amd64/frpc -c ./frpc.ini &")
#!./frp/frp_0.24.1_linux_amd64/frpc -c ./frpc.ini

In [0]:
!ps -ef | grep frpc

In [0]:
!cat frpc.ini
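The notebook reads `./frpc.ini` but does not show its contents. As a hedged sketch, assuming frp's classic INI format (a `[common]` section with `server_addr`/`server_port`, plus one `type = tcp` proxy per local service), a file forwarding the local TensorBoard (its default port 6006) and the `http.server` started above (port 8000) could be generated like this. The helper name, server address, ports, and proxy names are placeholders, not the author's actual settings.

```python
import configparser


def write_frpc_ini(path, server_addr, server_port=7000):
    """Write a hypothetical frpc.ini that exposes two local services via frps.

    All addresses, ports and proxy names here are placeholders."""
    cfg = configparser.ConfigParser()
    # where the remote frps server listens
    cfg["common"] = {"server_addr": server_addr, "server_port": str(server_port)}
    # TensorBoard listens on 127.0.0.1:6006 by default
    cfg["tensorboard"] = {"type": "tcp", "local_ip": "127.0.0.1",
                          "local_port": "6006", "remote_port": "6006"}
    # the python http.server started above listens on 127.0.0.1:8000
    cfg["http_server"] = {"type": "tcp", "local_ip": "127.0.0.1",
                          "local_port": "8000", "remote_port": "8000"}
    with open(path, "w") as f:
        cfg.write(f)
    return cfg


# write_frpc_ini("frpc.ini", "203.0.113.10")  # documentation-range IP: use your VPS address
```

Keeping every service on `127.0.0.1` means nothing on the Colab VM is exposed directly; only the frps server makes the ports reachable.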

## Control when training stops

In [0]:
# tell the program to stop training
# !touch _stop

# or lift the ban
# !rm _stop
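The `_stop` file works as a simple cross-process flag: `train()` checks `Path('_stop').exists()` at the start of every epoch and, if it is there, saves a checkpoint and breaks out of the loop. A minimal pure-Python equivalent of the two shell commands above (the helper names are hypothetical):

```python
from pathlib import Path

STOP_FILE = Path("_stop")  # the same file name that train() polls


def request_stop(stop_file=STOP_FILE):
    """Equivalent of `!touch _stop`: train() stops before its next epoch."""
    stop_file.touch()


def lift_stop(stop_file=STOP_FILE):
    """Equivalent of `!rm _stop`; does nothing if the file is absent."""
    if stop_file.exists():
        stop_file.unlink()


def should_stop(stop_file=STOP_FILE):
    """The check train() performs at the top of each epoch."""
    return stop_file.exists()
```

Because the check only happens at epoch boundaries, the epoch that is currently running still finishes (including its evaluation and checkpoint) before training stops.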

## tensorboard data archive

In [0]:
get_ipython().system_raw("bash get_runs.sh &")
# get_ipython().system_raw("python3 upload_runs.py &")
#!bash get_runs.sh
#!python3 upload_runs.py

In [0]:
!ps -ef | grep get_runs
# !ps -ef | grep upload_runs.py
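`get_runs.sh` itself is not included in this notebook. Judging from the `tar -xjvf runs.tbz` cell above, it presumably packs the `runs` directory into a bzip2-compressed tarball. A hedged Python sketch of that step with the standard `tarfile` module (the function name and behavior are assumptions, not the script's actual contents):

```python
import tarfile
from pathlib import Path


def archive_runs(logdir="runs", out="runs.tbz"):
    """Pack a TensorBoard log directory into a bzip2 tarball.

    Entries are stored under the directory's own name, so
    `tar -xjvf runs.tbz` recreates ./runs."""
    logdir = Path(logdir)
    with tarfile.open(str(out), "w:bz2") as tar:
        tar.add(str(logdir), arcname=logdir.name)  # adds the tree recursively
    return out
```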

## tail training log

In [0]:
# !tail "ResCsNet-colab-5_2_1-r0.10_trainning.log"

## Google drive re-auth
I do this because I sometimes run into errors if I don't re-authenticate with Google Drive. I'm not sure whether I missed something in the PyDrive documentation.

In [0]:
# if the automatic upload fails, I upload the checkpoint manually with the commented code below
from pathlib import Path
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# !mkdir authentication

gauth = GoogleAuth()
gauth.LoadCredentialsFile("mycreds.txt")
if gauth.credentials is None:
    # Authenticate if they're not there
    # self.gauth.LocalWebserverAuth()
    print("no creds saved")
    auth.authenticate_user()
    gauth.credentials = GoogleCredentials.get_application_default()
elif gauth.access_token_expired:
    # Refresh them if expired
    print("token expired")
    gauth.Refresh()
else:
    # Initialize the saved creds
    print("Initialize the saved creds")
    gauth.Authorize()
# Save the current credentials to a file
gauth.SaveCredentialsFile("mycreds.txt") 


############################################
# the_drive = GoogleDrive(gauth)
# ckpt_name = "ResCsNet-colab-5_2_1-r0.10_checkpoint.pth"


# id_file = Path('_id')
# if id_file.exists():
#     print("_id exists")
#     fileid = id_file.read_text()
#     _gfile_ckpt = the_drive.CreateFile({'id': fileid})
#     _gfile_ckpt.SetContentFile(ckpt_name)
#     _gfile_ckpt.Upload()
# else:
#     print("_id does not exist")
#     _gfile_ckpt = the_drive.CreateFile()      
#     _gfile_ckpt.SetContentFile(ckpt_name)
#     _gfile_ckpt.Upload()
#     id_file.write_text(_gfile_ckpt['id']) 
print("Done")

## Check uptime

In [0]:
!uptime

# ===========================

# The real things

In [0]:
# imports
from pathlib import Path
from PIL import Image, ImageFile
# see https://stackoverflow.com/questions/12984426/python-pil-ioerror-image-file-truncated-with-big-images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# from six.moves import cPickle as pickle
# import platform

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torch.utils.data import sampler
import torch.nn.functional as F

# import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter 

# from trainning_func import get_evaluation


# import ipdb

%matplotlib inline
# %matplotlib tk
# plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
# plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

print("Done")

## Dataset classes

In [0]:
input_width_height = 100
# build something like an image pyramid. For details see the class definition of ResCsNet
TOTAL_PIXELS = input_width_height*input_width_height   # N
SAMPLING_RATE = 0.25
SAMPLED_PIXELS = int(TOTAL_PIXELS * SAMPLING_RATE)     # M
WIDTH_INITIAL = int(np.ceil(np.sqrt(SAMPLED_PIXELS)))
WIDTH_HALF = int(input_width_height/4*2)
WIDTH_THREE_QUARTERS = int(input_width_height/4*3)

class CocoDataset(Dataset):
    def __init__(self, path_dataset, transform=None):
        p_iter = Path(path_dataset).iterdir()
        self.transform = transform
        self.images_list = [name for name in p_iter]
    
    def __len__(self):
        '''provides the size of the dataset'''
        return len(self.images_list)
    
    def __getitem__(self, idx):
        '''
        supporting integer indexing in range from 0 to len(self) exclusive
        '''
        im = Image.open(self.images_list[idx]).convert("L")  # to grayscale
        if self.transform:
            im = self.transform(im)
        return im
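The constants above fix the spatial sizes of the decoder's "pyramid": the M = int(r·N) measurements are stretched into a `WIDTH_INITIAL × WIDTH_INITIAL` map, then upsampled to `WIDTH_HALF` (50×50), `WIDTH_THREE_QUARTERS` (75×75) and finally the original 100×100. A small pure-Python helper (hypothetical, not used by the notebook) that reproduces the same arithmetic for the sampling rates of the pretrained checkpoints:

```python
import math


def pyramid_widths(rate, side=100):
    """Return (M, initial_width, half_width, three_quarter_width, side),
    using the same formulas as the constants above."""
    n = side * side                           # N, TOTAL_PIXELS
    m = int(n * rate)                         # M, SAMPLED_PIXELS
    initial = int(math.ceil(math.sqrt(m)))    # WIDTH_INITIAL
    half = int(side / 4 * 2)                  # WIDTH_HALF
    three_q = int(side / 4 * 3)               # WIDTH_THREE_QUARTERS
    return m, initial, half, three_q, side


for r in (0.05, 0.10, 0.15, 0.20, 0.25, 0.30):
    m, w0, w1, w2, w3 = pyramid_widths(r)
    note = "  (initial map larger than WIDTH_HALF)" if w0 > w1 else ""
    print(f"r={r:.2f}: M={m}, {w0}x{w0} -> {w1}x{w1} -> {w2}x{w2} -> {w3}x{w3}{note}")
```

For r = 0.25 the initial map is already 50×50, so `upsample2` does not change the size; for r = 0.30 it would be 55×55, larger than `WIDTH_HALF`, which is why the forward pass notes that `upsample2` should be removed at that rate.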

## Dataset instances

In [0]:
# Create the train and test dataset instances
# path_dset_train = '/home/xzc/center-crop-100/train2014'
# path_dset_test = '/home/xzc/center-crop-100/val2014'
path_dset_train = 'center-crop-100/train2014'
path_dset_test = 'center-crop-100/val2014'
# path_dset_train = r'D:\dev_workspace\CS-DL\center-crop-100\train2014'
# path_dset_test = r'D:\dev_workspace\CS-DL\center-crop-100\val2014'
input_width_height = 100

train_set = CocoDataset(path_dset_train,
                         transform=T.Compose(
                            [
                                T.RandomHorizontalFlip(),
                                T.RandomVerticalFlip(),
                                T.ToTensor(),
                                T.Normalize((0.5,), (0.5,))
                            ])
                        )
test_set = CocoDataset(path_dset_test,
                       transform=T.Compose(
                            [
                                T.ToTensor(),
                                T.Normalize((0.5,), (0.5,))
                            ])
                      )

print("Done")

In [0]:
# function to denormalize a normalized image
def denormalize(im, mean, std):
    '''
    im: PyTorch tensor of shape (C, H, W), e.g. the output of T.ToTensor() + T.Normalize
    '''
    assert len(im.size()) == 3
    if im.shape[0] == 1: # grayscale
        im = im * std[0] + mean[0]
    else: # rgb: channels are on the first axis
        im = im.clone()  # avoid modifying the caller's tensor in place
        for ch in range(3):
            im[ch] = im[ch] * std[ch] + mean[ch]
    return im

def test_dataset():
    print(len(train_set))
    im = train_set[5]

    print(im.shape)
    im = denormalize(im, (0.5,), (0.5,))
    print(im.shape)

    _im = im.reshape((im.shape[1], im.shape[2]))
    #     plt.figure()
    plt.imshow(_im, cmap='gray')
    plt.show()

# test_dataset()
print("Done")
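`T.Normalize((0.5,), (0.5,))` maps pixel values from [0, 1] to [-1, 1] via `(x - 0.5) / 0.5`, and `denormalize` applies the inverse `x * 0.5 + 0.5`. A NumPy sketch (NumPy arrays standing in for the PyTorch tensors) of that round trip, including a 3-channel case where, as in PyTorch's C×H×W layout, the channel axis comes first:

```python
import numpy as np


def normalize(im, mean, std):
    """Per-channel (x - mean) / std on a C x H x W array, like T.Normalize."""
    mean = np.asarray(mean, dtype=im.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=im.dtype).reshape(-1, 1, 1)
    return (im - mean) / std


def denormalize_np(im, mean, std):
    """Inverse of normalize: x * std + mean, broadcast over H and W."""
    mean = np.asarray(mean, dtype=im.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=im.dtype).reshape(-1, 1, 1)
    return im * std + mean


rng = np.random.default_rng(0)
gray = rng.random((1, 4, 4))   # values in [0, 1), like T.ToTensor() output
rgb = rng.random((3, 4, 4))    # a 3-channel example
```

Reshaping `mean` and `std` to `(C, 1, 1)` lets broadcasting apply each value to a whole channel, which avoids indexing the wrong axis.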

In [0]:
print(f"{len(train_set)}, {len(test_set)}")

## Dataloader instances

In [0]:
# make the pytorch loader

# final run: use this set
TOTAL_SAMPLES = len(train_set)
NUM_TRAIN = TOTAL_SAMPLES // 5 * 4
BATCH_SIZE = 64

# mini test: use this set
# TOTAL_SAMPLES = 2000
# NUM_TRAIN = TOTAL_SAMPLES // 5 * 4
# BATCH_SIZE = 60

# debug only
# TOTAL_SAMPLES = 100
# NUM_TRAIN = TOTAL_SAMPLES // 5 * 4
# BATCH_SIZE = 4

loader_train = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
loader_subtrain = DataLoader(Subset(train_set, [i for i in range(3)]), batch_size=BATCH_SIZE)  # used for checking avg_psnr
loader_val = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, TOTAL_SAMPLES)))
loader_test = DataLoader(test_set, batch_size=BATCH_SIZE)

print(f'len(loader_train)=={len(loader_train)}, len(loader_subtrain)={len(loader_subtrain)}')
print(f'len(loader_val)=={len(loader_val)}, len(loader_test)=={len(loader_test)}')
print('Done')
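The training images are split by index: the first `NUM_TRAIN = TOTAL_SAMPLES // 5 * 4` indices (about 80%) feed `loader_train`, the remaining ones feed `loader_val`, and `len()` of each loader is the number of batches, rounded up because `DataLoader`'s `drop_last` defaults to `False`. A pure-Python check of that arithmetic (the helper is hypothetical, and the totals in the test are made-up examples, not the dataset's real size):

```python
import math


def split_sizes(total, batch_size):
    """Return (num_train, num_val, train_batches, val_batches)."""
    num_train = total // 5 * 4           # same formula as NUM_TRAIN
    num_val = total - num_train          # indices NUM_TRAIN .. TOTAL_SAMPLES-1
    # DataLoader with drop_last=False yields ceil(n / batch_size) batches
    return (num_train, num_val,
            math.ceil(num_train / batch_size), math.ceil(num_val / batch_size))


print(split_sizes(10_000, 64))
```

Because `CocoDataset` builds its file list with `Path.iterdir()`, which yields entries in arbitrary order, this index-based split is only as stable as the directory listing; sorting the list is one way to make it reproducible across sessions.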

In [0]:
# Test the usage of dataloaders
def test_dataloader():
    train_iter = iter(loader_train)
    original_im = next(train_iter)

    print(type(original_im))
    print(original_im.size())

    print("------------------")
    test_iter = iter(loader_test)
    im = next(test_iter)
    print(type(im))
    print(im.size())
    
# test_dataloader()

## Set up device

In [0]:
# set up device
# will use cuda if available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('using device:', device)

# Encoder

In [0]:
# original linear encoder
class EncoderLinear(nn.Module):
    def _init_weights(self, m):
        # the encoder is a single nn.Linear layer, so initialize Linear modules
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)
    
    def __init__(self, N, M):
        super().__init__()
        self.linear = nn.Linear(N, M)
        self.linear.apply(self._init_weights)
    
    def forward(self, x):
        x_vec = x.view(x.shape[0], 1, 1, -1)
        out = self.linear(x_vec)
        return out
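`EncoderLinear` is the compressed-sensing measurement step: the image is flattened to a vector x of length N and mapped to y = Φx + b of length M, where Φ is the learned M×N weight of `nn.Linear` (which also adds a bias b that a classic CS measurement would not have). A NumPy sketch with the same shapes, using a scaled-down 32×32 image so it runs quickly and a random Gaussian Φ with Kaiming-normal scale (std = sqrt(2/N)) as a stand-in for the learned weights:

```python
import numpy as np


def linear_measure(images, phi, bias):
    """images: (B, 1, H, W); phi: (M, N); bias: (M,). Returns (B, 1, 1, M),
    mirroring x.view(B, 1, 1, -1) followed by nn.Linear."""
    b = images.shape[0]
    x_vec = images.reshape(b, 1, 1, -1)   # (B, 1, 1, N)
    return x_vec @ phi.T + bias           # nn.Linear computes x W^T + b


side, rate = 32, 0.25                     # scaled down from 100x100 for speed
n = side * side                           # N
m = int(n * rate)                         # M
rng = np.random.default_rng(0)
phi = rng.normal(0.0, np.sqrt(2.0 / n), size=(m, n))   # Kaiming-normal-like scale
bias = np.zeros(m)
images = rng.standard_normal((4, 1, side, side))
y = linear_measure(images, phi, bias)
print(y.shape)
```

The decoder only ever sees `y`; at r = 0.25 it has to recover N values from N/4 measurements.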

In [0]:
# make instance and test the model
def test_encoder():
#     train_iter = iter(loader_train)
#     im = train_iter.next()
#     encoder = Encoder(1, kernel_size=15, stride=1, padding=0)
    print(f"input_width_height={input_width_height}")
    N = input_width_height*input_width_height
    M = int(N * 0.30)
    encoder = EncoderLinear(N, M)
    im = torch.randn(BATCH_SIZE,1,input_width_height,input_width_height)
    with torch.no_grad():
        y = encoder(im)
    print(y.size())

# test_encoder()

# The Recovery Network (Decoder)

In [0]:
class TheUpsample(nn.Module):
    '''
    Wrapper for torch.nn.functional.interpolate with a fixed output size
    (nn.Upsample(size=..., mode='nearest') would behave the same way)
    '''
    def __init__(self, size):
        super().__init__()
        self.size = size
    
    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, mode='nearest')

class ResCsNet(nn.Module):
#     def _init_weights(self, m):
#         '''
#         used to init weights.
#         '''
#         print(m)
#         if type(m) == nn.Conv2d:
#             nn.init.kaiming_normal_(m.weight.data)
#             nn.init.constant_(m.bias.data, 0)
#         if type(m) == nn.Linear:
#             nn.init.kaiming_normal_(m.weight.data)
#             nn.init.constant_(m.bias.data, 0)
#         if type(m) == nn.ConvTranspose2d:
#             nn.init.kaiming_normal_(m.weight.data)
#             nn.init.constant_(m.bias.data, 0)

    def __init__(self, N, M):
#     def __init__(self, encoder_out_ch, encoder_ksize, encoder_stride, encoder_padding):
        super().__init__()
        # encoder
#         self.encoder= Encoder(encoder_out_ch, encoder_ksize, encoder_stride, encoder_padding)
        self.encoder= EncoderLinear(N, M)
        
        # upsample
        self.initial_width = int(np.ceil(np.sqrt(M))) # should be identical to WIDTH_INITIAL
        self.upsample0 = TheUpsample((1, int(self.initial_width**2)) )
#         upsample1_width = int(input_width_height/4)
#         self.upsample1 = TheUpsample((upsample1_width, upsample1_width))
        upsample2_width = WIDTH_HALF
        self.upsample2 = TheUpsample((upsample2_width, upsample2_width))
        upsample3_width = WIDTH_THREE_QUARTERS
        self.upsample3 = TheUpsample((upsample3_width, upsample3_width))
        # the last upsample restores the original input size
        self.upsample4 = TheUpsample((input_width_height, input_width_height))
        
        # conv_scale
        self.conv_scale1 = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0)
        self.conv_scale2 = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0)
        self.conv_scale3 = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0)
        self.conv_scale4 = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0)
        
        # units
        # unit 1
        self.unit1 = nn.Sequential(
#             nn.Conv2d(encoder_out_ch, 96, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(1, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU()

        )

        # unit2
        self.unit2a_branch1 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
        )
        
        self.unit2a_branch2 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96)
        )
        
        self.unit2b_branch2 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96)
        )
        
        
        # unit 3
        # branch1 is the identity shortcut
        self.unit3_branch2 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96)
        )
        
        # unit 4
        self.unit4_branch1 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
        )
        
        self.unit4_branch2 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96)
        )
        
        # unit 5
        # branch 1 is the identity shortcut
        self.unit5_branch2 = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(),
            
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96)
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x_vec = self.upsample0(x)
        x = x_vec.view(x.shape[0], 1, self.initial_width, self.initial_width)  # 23x23 (r=0.05), 32x32 (0.10), 39x39 (0.15), 45x45 (0.20), 50x50 (0.25), 55x55 (0.30)
        
        # unit 1
        x = self.unit1(x)
#         x = self.upsample1(x)
        
        # unit 2
        x2a_1 = self.unit2a_branch1(x)
        x2a_2 = self.unit2a_branch2(x)
        x = x2a_1 + x2a_2
        x2b_1 = x.clone()
        x2b_2 = self.unit2b_branch2(x)
        x = x2b_1 + x2b_2
        
        x = self.upsample2(x)  # remove this layer for r=0.3
        # unit 3
        x3_1 = x.clone()
        x3_2 = self.unit3_branch2(x)
        x = x3_1 + x3_2
        
        x = self.upsample3(x)
        # unit 4
        x4_1 = self.unit4_branch1(x)
        x4_2 = self.unit4_branch2(x)
        x = x4_1 + x4_2
        
        x = self.upsample4(x)
        # unit5
        x5_1 = x.clone()
        x5_2 = self.unit5_branch2(x)
        x = x5_1 + x5_2
        im4 = self.conv_scale4(x)
        
        return im4

############################################
def ResCsNet_init_weights(m):
    '''
    Initialize Conv2d, ConvTranspose2d and Linear layers with Kaiming-normal
    weights and zero biases. Standalone version of the commented-out
    ResCsNet._init_weights method above, for use with model.apply().
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
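The first decoding step in `ResCsNet.forward` stretches the M measurements to `initial_width**2` values with nearest-neighbour interpolation (`upsample0`) and reshapes them into a square single-channel map. A NumPy sketch of that index mapping, assuming the floor-based rule of `mode='nearest'` (source index = floor(i · in_size / out_size)); PyTorch's exact floating-point rounding may differ in edge cases, and the helper names are hypothetical:

```python
import math
import numpy as np


def nearest_stretch(vec, out_len):
    """1-D nearest-neighbour resize along the last axis."""
    in_len = vec.shape[-1]
    src = np.floor(np.arange(out_len) * in_len / out_len).astype(int)
    src = np.minimum(src, in_len - 1)     # clamp to a valid index
    return vec[..., src]


def measurements_to_map(y):
    """y: (B, M) measurements -> (B, 1, w, w) with w = ceil(sqrt(M))."""
    b, m = y.shape
    w = int(math.ceil(math.sqrt(m)))
    return nearest_stretch(y, w * w).reshape(b, 1, w, w)
```

When M is a perfect square (r = 0.25 gives M = 2500 = 50²) the stretch is the identity; otherwise a few measurements are repeated to fill the square.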

In [0]:
# Make instance and check
# %debug 
def test_rescsnet():
    print(f'using device: {device}')
    # 8, kernel_size=11, stride=5, padding=0 for r=0.28
#     model = ResCsNet(encoder_out_ch=1, encoder_ksize=15, encoder_stride=1, encoder_padding=0)
    N = input_width_height*input_width_height
    M = int(N * 0.05)
    model = ResCsNet(N, M)
    model.apply(ResCsNet_init_weights)
#     train_iter = iter(loader_train)
#     im = train_iter.next()
    im = torch.randn(BATCH_SIZE,1,input_width_height,input_width_height).to(device=device)
    with torch.no_grad():
        model.train()
        model.to(device=device)
        recovered_im = model(im)
    print(recovered_im.size())
    print("------------------")
    with torch.no_grad():
        model.eval()
        model.to(device=device)
        recovered_im = model(im)
    print(recovered_im.size())
    del model
    
# test_rescsnet()

In [0]:
print(torch.cuda.memory_cached())
torch.cuda.empty_cache()
print(torch.cuda.memory_cached())

# Training

## Experiment setup

In [0]:
exp_name = 'ResCsNet-colab'

print(f"input_width_height={input_width_height}")
N = input_width_height*input_width_height
M = int(N * SAMPLING_RATE)
model = ResCsNet(N, M)

## Load existing params

In [0]:
#######################################
# Load the parameters to continue training if desired
# see https://github.com/pytorch/examples/blob/d6b52110bae32cbefeea6d4ffbf8cede98ac16fc/imagenet/main.py#L175
#######################################

want_load_params = True
if want_load_params:
    #######################################
    # You need to make this correct
    fname = 'ResCsNet-colab-5_2_1-r0.25_checkpoint.pth'
    #######################################
    checkpoint = torch.load(fname, map_location=device)
    tfx_steps = checkpoint['tfx_steps']
    print(f"tfx_steps is {tfx_steps}")
    tfx_epochs_done = checkpoint['tfx_epochs_done']
    print(f"tfx_epochs_done is {tfx_epochs_done}")
    model = ResCsNet(N, M)
    model.load_state_dict(checkpoint['state_dict'])
    model.train()
    model.to(device=device)

    optimizer = optim.Adam(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer'])

    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=5e-3, verbose=True)
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    # make sure the optimizer state (e.g. Adam moments) is on the model's device
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
else:
    # start from scratch: fresh weights, optimizer and scheduler
    tfx_steps = 0
    tfx_epochs_done = 0
    model.apply(ResCsNet_init_weights)
    model.train()
    model.to(device=device)
    optimizer = optim.Adam(model.parameters())
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=5e-3, verbose=True)

print("Done")
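The scheduler is created with `mode='max'`, `factor=0.2`, `patience=10` and `threshold=5e-3`, and `train()` steps it once per epoch with `val_mix_psnr`. With the default `threshold_mode='rel'`, an epoch only counts as an improvement if the metric beats the best value so far by more than 0.5% (for example, more than 0.15 dB at 30 dB); after more than `patience` consecutive non-improving epochs the learning rate is multiplied by `factor`. A simplified pure-Python re-implementation of that rule (no cooldown, no `min_lr`), for intuition only:

```python
class PlateauSketch:
    """Simplified ReduceLROnPlateau(mode='max', threshold_mode='rel')."""

    def __init__(self, lr, factor=0.2, patience=10, threshold=5e-3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("-inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # relative improvement test used for mode='max'
        if metric > self.best * (1 + self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # reduce only after *more than* `patience` bad epochs in a row
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr
```

So with these settings the learning rate drops at the 11th consecutive epoch without a 0.5% gain, which is why manually resetting it (as in the next cell) can be useful after resuming.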

In [0]:
# (re)setting learning rate if needed
for g in optimizer.param_groups:
    g['lr'] = 1e-5

In [0]:
print("The current lrs:")
for g in optimizer.param_groups:
    print(g['lr'])

## Train and validation functions

In [0]:
# The training function needs some logging
import logging
logging.basicConfig(format="[%(asctime)s] %(message)s", filename=exp_name+"_trainning.log",level=logging.INFO)

In [0]:
class FrobeniusLoss(nn.Module):
    def __init__(self):
        super().__init__()

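    def forward(self, x2, x1, eps=1e-8):
        '''
        x2: the recovered image, x1: the source (reference); both of shape [N, C, H, W].
        Returns ||x1 - x2||_F / (||x1||_F + eps), computed over the whole batch.
        '''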

        diff = x1 - x2
        num = torch.norm(diff, p='fro')
        den = torch.norm(x1, p='fro') + eps

        frob = num / den

        return frob
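`FrobeniusLoss` is a relative reconstruction error over the whole batch, `norm(x1 - x2) / (norm(x1) + eps)`, and `train()` mixes it with the wavelet SSIM term as `loss = 0.3 * frob + 0.7 * (1 - ssim)` with the default weights. A NumPy sketch of both formulas; the SSIM value is passed in as a plain number here, since `DWT_SSIM` comes from the external `pytorch_dwt_ssim` module, and the helper names are hypothetical:

```python
import numpy as np


def frobenius_rel_error(recovered, source, eps=1e-8):
    """||source - recovered||_F / (||source||_F + eps), over all elements."""
    num = np.linalg.norm((source - recovered).ravel())
    den = np.linalg.norm(source.ravel()) + eps
    return num / den


def mixed_loss(recovered, source, ssim_value, mse_weight=0.3, cwssim_weight=0.7):
    """Weighted sum used in train(); ssim_value is assumed to lie in [0, 1]."""
    return (mse_weight * frobenius_rel_error(recovered, source)
            + cwssim_weight * (1 - ssim_value))
```

Despite the `fn_mse` argument name, this term is not a mean squared error: the norm is not squared and is normalized by the source, so an all-zero output scores about 1 regardless of image brightness.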

In [0]:
from pathlib import Path
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
from trainning_func import get_evaluation
from pytorch_dwt_ssim import DWT_SSIM

stop_file = Path('_stop')


def save_checkpoint(model, optimizer, lr_scheduler, tfx_steps, tfx_epochs_done, ckpt_name):
    # Save the state model, just in case
    logging.info("Saving the state of model")
    state = {
        'tfx_steps': tfx_steps,
        'tfx_epochs_done': tfx_epochs_done,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict()
    }
    torch.save(state, ckpt_name)
    logging.info("State saving done, uploading to google drive")
    
    try:
        gauth = GoogleAuth()
        gauth.LoadCredentialsFile("mycreds.txt")
        if gauth.credentials is None:
            # Authenticate if they're not there
            # self.gauth.LocalWebserverAuth()
            print("[!] No creds saved")
            auth.authenticate_user()
            gauth.credentials = GoogleCredentials.get_application_default()
        elif gauth.access_token_expired:
            # Refresh them if expired
            print("[!] Token expired")
            gauth.Refresh()
        else:
            # Initialize the saved creds
#             print("Initialize the saved creds")
            gauth.Authorize()
        # Save the current credentials to a file
        gauth.SaveCredentialsFile("mycreds.txt")
        the_drive = GoogleDrive(gauth)
        
        id_file = Path('_id')
        if id_file.exists():
            fileid = id_file.read_text()
            _gfile_ckpt = the_drive.CreateFile({'id': fileid})
            _gfile_ckpt.SetContentFile(ckpt_name)
            _gfile_ckpt.Upload()
        else:
            _gfile_ckpt = the_drive.CreateFile()      
            _gfile_ckpt.SetContentFile(ckpt_name)
            _gfile_ckpt.Upload()
            id_file.write_text(_gfile_ckpt['id']) 
        
    except Exception as err:
        print(f"??? Some error occurred when trying to upload to gdrive: {err}")
#         raise

        

def train(model, optimizer, lr_scheduler,
          fn_mse=FrobeniusLoss(), fn_cwssim=DWT_SSIM(J=3, wave='haar'),
          mse_weight=0.3, cwssim_weight=0.7,
          epochs=1, logdir=None, print_every=10, tfx_steps=0, tfx_epochs_done=0, device=torch.device('cuda'),
          ckpt_name="checkpoint.pt"):
    """
    Train a model

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - lr_scheduler: learning rate scheduler, stepped once per epoch with the validation mixed PSNR (may be None)
    - fn_mse, fn_cwssim: reconstruction loss (a relative Frobenius error by default) and wavelet SSIM function
    - mse_weight, cwssim_weight: weights of the two terms in the total loss
    - epochs: A Python integer giving the number of epochs to train for
    - logdir: string. Used to specify the logdir of tensorboard
    - print_every: log the training loss every print_every iterations
    - tfx_steps, tfx_epochs_done: tell the tensorboardX summary writer the current step and epoch
    - device: torch.device('cuda') or torch.device('cpu')
    - ckpt_name: file name of the checkpoint saved (and uploaded) after each epoch

    Returns:
    - (tfx_steps, tfx_epochs_done): the counters when training stops
    """
    try:
        writer = SummaryWriter(log_dir=logdir)
        print(f"Run `tensorboard --logdir={logdir} --host=127.0.0.1` to visualize in realtime")

        fn_mse = fn_mse.to(device=device)
        fn_cwssim = fn_cwssim.to(device=device)

        # PSNR scheduler
#         lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=5e-3 ,verbose=True)
        # CW-SSIM scheduler
    #     lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=10,
    #                                                         verbose=True)
        for e in range(epochs):
            if stop_file.exists():
                print("Stop file found. Will stop training now")
                save_checkpoint(model,
                                optimizer,
                                lr_scheduler,
                                tfx_steps=tfx_steps,
                                tfx_epochs_done=tfx_epochs_done,
                                ckpt_name=ckpt_name)
                break

            model.train()  # ensure the model is in training mode
            logging.info('-----------------------------')
            logging.info(f'* epoch {tfx_epochs_done}')
            for t, original_im in enumerate(loader_train):
                original_im = original_im.to(device=device)

                recovered_im = model(original_im)
                mse_loss = fn_mse(recovered_im, original_im)

                cwssim_loss = 1 - fn_cwssim(recovered_im, original_im)

                # construct the total loss
                loss = mse_weight*mse_loss + cwssim_weight*cwssim_loss

                writer.add_scalars('train/loss',
                                   {
                                       'mse_loss.item()': mse_loss.item(),
                                       'cwssim_loss.item()': cwssim_loss.item(),
                                       'loss.item()': loss.item()
                                   },
                                   tfx_steps)

                optimizer.zero_grad()
                loss.backward()

                optimizer.step()

                if t % int(print_every) == 0:
                    logging.info('Iteration %d/%d, loss = %.4f' % (t, len(loader_train) , loss.item()))

                tfx_steps += 1
            #end for

            # after the end of each epoch
            ## increment tfx_epochs_done counter
            tfx_epochs_done += 1
            ## check the performance of the model
            logging.info("Checking on subtrain and validation set...")
            subtrain_psnr, val_psnr, subtrain_mix_psnr, val_mix_psnr = get_evaluation(
                model,
                loader_val,
                loader_subtrain,
                device,
                fn_mse=nn.MSELoss(),
                fn_cwssim=fn_cwssim,
                mse_weight=mse_weight,
                cwssim_weight=cwssim_weight)
            logging.info(f"Average PSNR for subtrain set is {subtrain_psnr} dB")
            logging.info(f"Average PSNR for validation set is {val_psnr} dB")

            logging.info(f"Average mixed gain for subtrain set is {subtrain_mix_psnr} dB")
            logging.info(f"Average mixed gain for validation set is {val_mix_psnr} dB")

    #         loss_subtrain = mse_weight*subtrain_avg_mse + cwssim_weight*subtrain_avg_cwssim_loss
    #         loss_val = mse_weight*val_avg_mse + cwssim_weight*val_avg_cwssim_loss
            if lr_scheduler is not None:
                lr_scheduler.step(val_mix_psnr)  # check loss and determine if the lr should be decreased

            writer.add_scalars('train/val_evaluation', 
                               {
                                   'subtrain_psnr': subtrain_psnr.item(),
                                   'val_psnr': val_psnr.item(),
                                   'subtrain_mix_psnr': subtrain_mix_psnr.item(),
                                   'val_mix_psnr': val_mix_psnr.item()
                               },
                               tfx_epochs_done
                              )
            # Save the state of the model, just in case
            logging.info("Saving the state of model")
            save_checkpoint(model,
                            optimizer,
                            lr_scheduler,
                            tfx_steps=tfx_steps,
                            tfx_epochs_done=tfx_epochs_done,
                            ckpt_name=ckpt_name)

        #end for

        writer.close()  # tensorboardX writer
        return tfx_steps, tfx_epochs_done
    
    except (KeyboardInterrupt, SystemExit):
        print("Interrupted: saving the state of the model")
        save_checkpoint(model,
                        optimizer,
                        lr_scheduler,
                        tfx_steps=tfx_steps,
                        tfx_epochs_done=tfx_epochs_done,
                        ckpt_name=ckpt_name)
        return tfx_steps, tfx_epochs_done
    except:
        print("Emergency: saving the state of the model")
        save_checkpoint(model,
                        optimizer,
                        lr_scheduler,
                        tfx_steps=tfx_steps,
                        tfx_epochs_done=tfx_epochs_done,
                        ckpt_name=ckpt_name)
        raise
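Both training functions wrap the whole epoch loop in `try`/`except`, so a checkpoint is written whether training finishes, is interrupted (e.g. with the notebook's stop button, which raises `KeyboardInterrupt`), or crashes. The pattern is isolated in the minimal sketch below; `run_with_emergency_save`, `train_step` and `save` are hypothetical names standing in for the epoch loop and `save_checkpoint`, not functions from this notebook. The sketch uses `except Exception:` instead of a bare `except:`; since interrupts are already handled by the first clause, the difference only matters for rarer exits such as `GeneratorExit`.

```python
def run_with_emergency_save(train_step, save, epochs):
    """Run `train_step(epoch)` for each epoch and save state on every exit path.

    `train_step` and `save` are hypothetical callbacks standing in for the
    notebook's epoch loop and `save_checkpoint`.
    """
    epochs_done = 0
    try:
        for epoch in range(epochs):
            train_step(epoch)
            epochs_done += 1
            save(epochs_done)  # regular end-of-epoch checkpoint
        return epochs_done
    except (KeyboardInterrupt, SystemExit):
        # user interrupt: save, then return normally so the caller keeps the counters
        save(epochs_done)
        return epochs_done
    except Exception:
        # unexpected error: save what we have, then re-raise so the error stays visible
        save(epochs_done)
        raise


# usage: simulate a KeyboardInterrupt during the third epoch
saved = []

def fake_step(epoch):
    if epoch == 2:
        raise KeyboardInterrupt

result = run_with_emergency_save(fake_step, saved.append, epochs=5)
print(result, saved)  # 2 [1, 2, 2]
```

Returning (rather than re-raising) on an interrupt is what lets the notebook cells keep reassigning `tfx_steps, tfx_epochs_done` from the return value even when a run is stopped by hand.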

In [0]:
# old style train, mse only
def get_avg_psnr(model, loader_val, loader_subtrain, device, print_every=10):
    model.eval()  # ensure the model is in evaluation mode
    subtrain_avg_psnr = 0
    val_avg_psnr = 0

    with torch.no_grad():
        for t, val_im in enumerate(loader_val):
            if t % int(print_every) == 0:
                logging.info(f"checked {t}/{len(loader_val)} in loader_val")
            val_original = val_im.to(device)
            val_recovered = model(val_original)
            val_mse = F.mse_loss(val_recovered, val_original)
            
            # PSNR
            val_psnr = 10 * np.log10(1 / val_mse.item())
            val_avg_psnr += val_psnr
            
        val_avg_psnr /= len(loader_val)


        for t, subtrain_im in enumerate(loader_subtrain):
            if t % int(print_every) == 0:
                logging.info(f"checked {t}/{len(loader_subtrain)} in loader_subtrain")
            subtrain_original = subtrain_im.to(device)
            subtrain_recovered = model(subtrain_original)
            subtrain_mse = F.mse_loss(subtrain_recovered, subtrain_original)
            
            # PSNR
            subtrain_psnr = 10 * np.log10(1 / subtrain_mse.item())
            subtrain_avg_psnr += subtrain_psnr
            
        subtrain_avg_psnr /= len(loader_subtrain)
    
    return subtrain_avg_psnr, val_avg_psnr

#############################################################

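def train_oldstyle(model, optimizer, lr_scheduler,
                   epochs=1, logdir=None, print_every=10, tfx_steps=0, tfx_epochs_done=0,
                   device=torch.device('cuda'), ckpt_name="checkpoint.pt"):
    """
    Train a model using the MSE (L2) loss only

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - lr_scheduler: learning rate scheduler, stepped with the validation PSNR (may be None)
    - epochs: A Python integer giving the number of epochs to train for
    - logdir: string. Used to specify the logdir of tensorboard
    - print_every: log the training loss every `print_every` iterations
    - tfx_steps: global step counter to resume from (x-axis of the per-iteration plots)
    - tfx_epochs_done: number of epochs already trained (x-axis of the per-epoch plots)
    - device: the torch.device to train on
    - ckpt_name: file name of the checkpoint written by save_checkpoint

    Returns:
    - tfx_steps: the step counter after training
    - tfx_epochs_done: the epoch counter after training
    """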
    try:
        writer = SummaryWriter(log_dir=logdir)
        print(f"Run `tensorboard --logdir={logdir} --host=127.0.0.1` to visualize in realtime")

        fn_mse = nn.MSELoss()
        fn_mse = fn_mse.to(device=device)
        
        for e in range(epochs):
            if stop_file.exists():
                print("Stop file found. Will stop training now")
                save_checkpoint(model,
                                optimizer,
                                lr_scheduler,
                                tfx_steps=tfx_steps,
                                tfx_epochs_done=tfx_epochs_done,
                                ckpt_name=ckpt_name)
                break

            model.train()  # ensure the model is in training mode
            logging.info('-----------------------------')
            logging.info(f'* epoch {tfx_epochs_done}')
            for t, original_im in enumerate(loader_train):
                original_im = original_im.to(device=device)

                recovered_im = model(original_im)
                mse_loss = fn_mse(recovered_im, original_im)

                # construct the total loss
                loss = mse_loss

                writer.add_scalars('train/loss',
                                   {
                                       'mse_loss.item()': mse_loss.item(),
                                       'loss.item()': loss.item()
                                   },
                                   tfx_steps)

                optimizer.zero_grad()
                loss.backward()

                optimizer.step()

                if t % int(print_every) == 0:
                    logging.info('Iteration %d/%d, loss = %.4f' % (t, len(loader_train), loss.item()))

                tfx_steps += 1
            #end for

            # after the end of each epoch
            ## increment tfx_epochs_done counter
            tfx_epochs_done += 1
            ## check the performance of the model
            logging.info("Checking on subtrain and validation set...")
            subtrain_psnr, val_psnr = get_avg_psnr(model, loader_val, loader_subtrain, device)
            
            logging.info(f"Average PSNR for subtrain set is {subtrain_psnr} dB")
            logging.info(f"Average PSNR for validation set is {val_psnr} dB")

            if lr_scheduler is not None:
                lr_scheduler.step(val_psnr)  # pass the validation PSNR; the scheduler lowers the LR when it stops improving

            writer.add_scalars('train/val_evaluation', 
                               {
                                   'subtrain_psnr': subtrain_psnr.item(),
                                   'val_psnr': val_psnr.item(),
                               },
                               tfx_epochs_done
                              )
            # Save the state of the model, just in case
            logging.info("Saving the state of model")
            save_checkpoint(model,
                            optimizer,
                            lr_scheduler,
                            tfx_steps=tfx_steps,
                            tfx_epochs_done=tfx_epochs_done,
                            ckpt_name=ckpt_name)

        #end for

        writer.close()  # tensorboardX writer
        return tfx_steps, tfx_epochs_done
    
    except (KeyboardInterrupt, SystemExit):
        print("Interrupted: saving the state of the model")
        save_checkpoint(model,
                        optimizer,
                        lr_scheduler,
                        tfx_steps=tfx_steps,
                        tfx_epochs_done=tfx_epochs_done,
                        ckpt_name=ckpt_name)
        return tfx_steps, tfx_epochs_done
    except:
        print("Emergency: saving the state of the model")
        save_checkpoint(model,
                        optimizer,
                        lr_scheduler,
                        tfx_steps=tfx_steps,
                        tfx_epochs_done=tfx_epochs_done,
                        ckpt_name=ckpt_name)
        raise
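`train_oldstyle` checks `stop_file.exists()` at the top of every epoch, so a run can be ended cleanly between epochs just by creating that file (for example from a second notebook sharing the same VM). `stop_file` is defined earlier in the notebook; the sketch below assumes it is a `pathlib.Path`, and uses a hypothetical file name `STOP` in a temporary directory plus a hypothetical `on_epoch` callback in place of a real epoch.

```python
import tempfile
from pathlib import Path


def run_epochs(epochs, stop_file, on_epoch):
    """Run up to `epochs` epochs, stopping early if `stop_file` exists.

    Mirrors the check at the top of the epoch loop in train_oldstyle;
    `on_epoch` is a hypothetical callback standing in for one epoch of training.
    """
    done = 0
    for e in range(epochs):
        if stop_file.exists():
            print("Stop file found. Will stop training now")
            break
        on_epoch(e)
        done += 1
    return done


with tempfile.TemporaryDirectory() as tmp:
    stop_file = Path(tmp) / "STOP"  # hypothetical sentinel file name

    def create_stop_during_epoch_1(e):
        # simulate a user creating the stop file while epoch 1 is training
        if e == 1:
            stop_file.touch()

    epochs_run = run_epochs(10, stop_file, create_stop_during_epoch_1)
    print(epochs_run)  # 2
```

Because the check also runs before the very first epoch, the file has to be deleted again before the next training cell, otherwise that call returns immediately.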

## Time-consuming part ...
It is a good idea to train with `train_oldstyle` first, which uses only the MSE loss as the criterion, and then switch to `train`, which combines the L2 loss and the DW-SSIM loss, to fine-tune the model. The wavelet code is still slow, even on a GPU, so this training scheme saves some time.
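The commented-out evaluation lines in `train` combine the two terms as `mse_weight*mse + cwssim_weight*cwssim_loss`. Assuming the training criterion has the same weighted-sum form, the MSE-only stage is the special case `mse_weight=1, cwssim_weight=0`; the time saving comes from `train_oldstyle` never evaluating the wavelet loss at all. The sketch below uses plain Python lists and a hypothetical, precomputed `cwssim_loss` value in place of the notebook's `fn_cwssim`.

```python
def mse(a, b):
    """Mean squared error between two equal-length sequences of pixel values."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def mixed_loss(recovered, original, cwssim_loss, mse_weight, cwssim_weight):
    """Weighted sum of MSE and a wavelet-based loss (assumed form of the criterion in `train`).

    `cwssim_loss` is a hypothetical precomputed value standing in for fn_cwssim(...).
    """
    return mse_weight * mse(recovered, original) + cwssim_weight * cwssim_loss


original = [0.0, 0.5, 1.0, 0.5]
recovered = [0.1, 0.5, 0.9, 0.5]
cwssim_loss = 0.2  # made-up value for illustration

stage1 = mixed_loss(recovered, original, cwssim_loss, mse_weight=1.0, cwssim_weight=0.0)  # like train_oldstyle
stage2 = mixed_loss(recovered, original, cwssim_loss, mse_weight=0.5, cwssim_weight=0.5)  # fine-tuning stage
print(stage1, stage2)
```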

In [0]:
learning_rate = 5e-4
tfx_steps = 0
tfx_epochs_done = 0
model = model.to(device=device)  # move to proper device before constructing the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=5e-3, verbose=True)

# train_oldstyle only incorporates L2 loss (nn.MSELoss())
# to use wavelet loss function, use the train function to train
tfx_steps, tfx_epochs_done = train_oldstyle(model, optimizer, lr_scheduler,
                                            epochs=49, logdir='runs/' + exp_name,
                                            tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, device=device,
                                            ckpt_name=exp_name+"_checkpoint.pth")
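The scheduler is stepped with a validation PSNR, hence `mode='max'`. With PyTorch's default `threshold_mode='rel'`, an epoch only counts as an improvement if the metric exceeds `best * (1 + threshold)`, so at around 30 dB the PSNR has to rise by more than 0.15 dB; once there have been more than `patience` epochs without such an improvement, the learning rate is multiplied by `factor`. The class below is a simplified re-implementation of that rule for illustration only (it ignores `cooldown`, `min_lr` and `eps`), fed with hypothetical PSNR values.

```python
class PlateauSketch:
    """Simplified ReduceLROnPlateau logic for mode='max', threshold_mode='rel'.

    Ignores cooldown, min_lr and eps; for illustration only.
    """

    def __init__(self, lr, factor=0.2, patience=10, threshold=5e-3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("-inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # relative improvement test used for mode='max'
        if metric > self.best * (1 + self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # reduce only after MORE than `patience` bad epochs in a row
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


# hypothetical validation PSNRs: one good epoch, then 11 epochs stuck at +0.1 dB
sched = PlateauSketch(lr=5e-4)
lrs = [sched.step(psnr) for psnr in [30.0] + [30.1] * 11]
print(lrs[-2], lrs[-1])
```

A +0.1 dB gain never clears the 0.5 % relative threshold here, so the rate drops on the 11th flat epoch. Re-creating the scheduler, as a later cell does, resets `best` to `-inf` and the bad-epoch counter to zero.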

In [0]:
tfx_steps, tfx_epochs_done = train_oldstyle(model, optimizer, lr_scheduler,
                                            epochs=35, logdir='runs/' + exp_name,
                                            tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, device=device,
                                            ckpt_name=exp_name+"_checkpoint.pth")

In [0]:
tfx_steps, tfx_epochs_done = train(model, optimizer, lr_scheduler,
                                   mse_weight=0.5, cwssim_weight=0.5,
                                   epochs=22, logdir='runs/' + exp_name,
                                   tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, print_every=10, device=device,
                                   ckpt_name=exp_name+"_checkpoint.pth")

In [0]:
tfx_steps, tfx_epochs_done = train(model, optimizer, lr_scheduler,
                                   mse_weight=0.5, cwssim_weight=0.5,
                                   epochs=22, logdir='runs/' + exp_name,
                                   tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, print_every=10, device=device,
                                   ckpt_name=exp_name+"_checkpoint.pth")

In [0]:
# Manually override the learning rate of every parameter group
for g in optimizer.param_groups:
    print(f"current lr: {g['lr']}")
    g['lr'] = 1e-6

In [0]:
tfx_steps, tfx_epochs_done = train(model, optimizer, lr_scheduler,
                                   mse_weight=0.8, cwssim_weight=0.2,
                                   epochs=22, logdir='runs/' + exp_name,
                                   tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, print_every=10, device=device,
                                   ckpt_name=exp_name+"_checkpoint.pth")

In [0]:
# Re-create the scheduler so its plateau state (best value, bad-epoch counter) starts fresh
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=10, threshold=5e-3, verbose=True)
tfx_steps, tfx_epochs_done = train(model, optimizer, lr_scheduler,
                                   mse_weight=0.8, cwssim_weight=0.2,
                                   epochs=13, logdir='runs/' + exp_name,
                                   tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, print_every=10, device=device,
                                   ckpt_name=exp_name+"_checkpoint.pth")

In [0]:
tfx_steps, tfx_epochs_done = train(model, optimizer, lr_scheduler,
                                   mse_weight=0.5, cwssim_weight=0.5,
                                   epochs=13, logdir='runs/' + exp_name,
                                   tfx_steps=tfx_steps, tfx_epochs_done=tfx_epochs_done, print_every=10, device=device,
                                   ckpt_name=exp_name+"_checkpoint.pth")

# Save the model again (just in case)

In [0]:
print("Saving the state of model")
state = {
    'tfx_steps': tfx_steps,
    'tfx_epochs_done': tfx_epochs_done,
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}
torch.save(state, exp_name+"_checkpoint.bak.pth")

print("Uploading to google drive")
gfile_ckpt.SetContentFile(ckpt_name)
gfile_ckpt.Upload()

print("Done")

# Testing

In [0]:
# Visualize the recovered and original image
def denormalize(im):
    """Undo the dataset normalization of a single H x W x C numpy image."""
    if im.shape[2] == 1: # grayscale
        mean, std = (0.5,), (0.5,)
        im = im.reshape(im.shape[0], im.shape[1])
        im = im * std[0] + mean[0]
    else: # rgb
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        for ch in range(3):
            im[:,:,ch] = im[:,:,ch] * std[ch] + mean[ch]
    return im

# visualize recovered image
def vis_show(model, idx=0):
    model = model.to(device=device)
    model.eval()

    test_iter = iter(loader_test)
    _y = next(test_iter)
    print(_y.shape)

    with torch.no_grad():
        _y = _y.to(device=device)  # same device as the model
        restored_im = model(_y)
    print(restored_im.shape)
    _im = restored_im[idx].cpu().numpy()
    _im = _im.transpose(1, 2, 0).astype(np.float64)  # np.float was removed in NumPy 1.24
    _ori = _y[idx].cpu().numpy()
    _ori = _ori.transpose(1, 2, 0).astype(np.float64)
    # mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

    # _im[_im < -1] = -1.0
    # _im[_im > 1] = 1.0

    _im = denormalize(_im)
    _ori = denormalize(_ori)

    # compute MSE/PSNR on the [0, 1] pixel range, consistent with the test-set PSNR cell
    mse = np.mean((_im - _ori) ** 2)
    psnr = 10 * np.log10(1 / mse)
    print(f"mse is {mse}, psnr is {psnr} dB")
    # im_mx = np.amax(_im)
    # im_mn = np.amin(_im)
    # _im = _im / (im_mx-im_mn) * 1.0

    # print(f'_im is {_im}')
    # print(f'_ori is {_ori}')


    print(f"the original image: mx = {np.amax(_ori)}, mn={np.amin(_ori)}")
    print(f"the restored image: mx = {np.amax(_im)}, mn={np.amin(_im)}")
    plt.subplot(1,2,1)
    plt.imshow(_ori, cmap="gray"); plt.title('original')
    plt.grid(False)
    plt.subplot(1,2,2)
    plt.imshow(_im, cmap="gray"); plt.title('restored')
    plt.grid(False)
    plt.show()
    
vis_show(model, 6)
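For grayscale images `denormalize` inverts `x_n = (x - 0.5) / 0.5`. Assuming the data loaders normalize with exactly these values (as `denormalize` implies), every error on the normalized scale is twice as large, the MSE four times as large, and a PSNR computed with peak 1 on normalized tensors (as `get_avg_psnr` does) comes out `10 * log10(4)`, about 6.02 dB, lower than the same PSNR on the [0, 1] range. The sketch checks this with made-up pixel values; keep the offset in mind when comparing the training logs with test-set numbers.

```python
import math


def to_normalized(x, mean=0.5, std=0.5):
    """Forward grayscale normalization that `denormalize` undoes (assumed mean=std=0.5)."""
    return [(v - mean) / std for v in x]


def psnr_between(a, b, peak=1.0):
    """PSNR in dB between two equal-length pixel sequences."""
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    return 10 * math.log10(peak ** 2 / mse)


# made-up pixel values in [0, 1]
original = [0.2, 0.4, 0.6, 0.8]
recovered = [0.25, 0.38, 0.61, 0.79]

psnr_01 = psnr_between(recovered, original)                                  # on [0, 1]
psnr_norm = psnr_between(to_normalized(recovered), to_normalized(original))  # on [-1, 1], peak still 1
print(f"offset = {psnr_01 - psnr_norm:.2f} dB")  # equals 10*log10(4)
```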

In [0]:
# Calculate PSNR on the test set (pixel values mapped back to [0, 1])
device = torch.device('cuda')
model.to(device)
model.eval()
_psnrs = []
with torch.no_grad():
    for t, batch in enumerate(loader_test):
        if t % 10 == 0:
            print(f"{t}/{len(loader_test)}")

        original = batch.to(device)
        recovered = model(original)

        # `denormalize` handles one H x W x C numpy image, not an N x C x H x W
        # tensor batch, so undo the grayscale (mean=0.5, std=0.5) normalization here
        recovered = recovered * 0.5 + 0.5
        original = original * 0.5 + 0.5
        mse = F.mse_loss(recovered, original).item()
        psnr = 10 * np.log10(1 / mse)  # equivalent to 20 * log10(1 / rmse)
        _psnrs.append(psnr)

psnrs = np.array(_psnrs)
print("===> Avg. PSNR: {:.4f} dB".format(psnrs.mean()))

plt.hist(psnrs, bins=15); plt.xlabel('PSNR/dB'); plt.ylabel('Number of samples')
plt.show()
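The cell above (like `get_avg_psnr`) converts each batch's MSE to a PSNR with `10 * log10(1 / MSE)` and then averages the per-batch values. Because the logarithm is concave, this mean is generally not the same as the PSNR of the mean MSE, and good batches pull it up. The minimal sketch below uses two hypothetical batch MSEs, not results from this model, to show how large the gap can be; neither convention is wrong, but numbers computed the two ways should not be compared directly.

```python
import math


def psnr(mse, peak=1.0):
    """PSNR in dB for a given mean squared error and peak signal value."""
    return 10 * math.log10(peak ** 2 / mse)


# hypothetical per-batch MSE values, for illustration only
batch_mse = [0.01, 0.0001]

mean_of_psnr = sum(psnr(m) for m in batch_mse) / len(batch_mse)  # per-batch average, as in the notebook
psnr_of_mean = psnr(sum(batch_mse) / len(batch_mse))             # PSNR of the pooled MSE

print(f"{mean_of_psnr:.2f} dB vs {psnr_of_mean:.2f} dB")
```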