# Colab-BasicSR (pytorch lightning)

[This tutorial](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09), [this Stack Overflow question](https://stackoverflow.com/questions/65387967/misconfigurationerror-no-tpu-devices-were-found-even-when-tpu-is-connected-in) and [this Colab](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb#scrollTo=3vKszYf6y1Vv) were very helpful. This Colab supports single-GPU, multi-GPU and TPU training.

Supports various loss functions and uses the context_encoder discriminator by default. Currently, only inpainting generators from [my BasicSR fork](https://github.com/styler00dollar/Colab-BasicSR) are included.

What is not included in this Colab, but is included in [my normal BasicSR Colab](https://colab.research.google.com/github/styler00dollar/Colab-BasicSR/blob/master/Colab-BasicSR.ipynb):
- [edge-informed-sisr](https://github.com/knazeri/edge-informed-sisr/blob/master/src/models.py)
- [USRNet](https://github.com/cszn/KAIR/blob/master/models/network_usrnet.py)
- [OFT Dataloader](https://github.com/styler00dollar/Colab-BasicSR/tree/master/codes/data)
- Some loss functions (though most of them are included here)
- DiffAug / Mixup

What is currently here but not in the other Colab:
- Custom mask loading
- New discriminators (EfficientNet, ResNeSt, Transformer)
- [AdamP](https://github.com/clovaai/AdamP)

Sidenotes:
- Validation runs at the configured validation frequency and at the end of each epoch
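The validation schedule amounts to two independent triggers. A minimal sketch, assuming the frequency is counted in global training steps; `should_validate` and `val_frequency` are hypothetical names, not the notebook's actual code:

```python
def should_validate(global_step: int, val_frequency: int, is_epoch_end: bool) -> bool:
    """Hypothetical helper: decide whether to run validation at this point.

    Validation always runs when an epoch ends, and additionally every
    `val_frequency` global steps (a value <= 0 disables step-based validation).
    """
    if is_epoch_end:
        return True
    return val_frequency > 0 and global_step > 0 and global_step % val_frequency == 0
```

For example, with `val_frequency=1000` validation runs at steps 1000, 2000, ... and once more whenever an epoch finishes.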

In [None]:
!nvidia-smi

In [None]:
#@title GPU
# create empty folders (-p: no error if they already exist)
!mkdir -p /content/masks
!mkdir -p /content/validation
!mkdir -p /content/data
!mkdir -p /content/logs

#!pip install pytorch-lightning -U
# Hotfix: install pytorch-lightning from GitHub master to avoid a bug in the released version
!pip install git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install tensorboardX

In [None]:
#@title TPU (restart runtime afterwards)
# create empty folders (-p: no error if they already exist)
!mkdir -p /content/masks
!mkdir -p /content/validation
!mkdir -p /content/data
!mkdir -p /content/logs

#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
#!pip install pytorch-lightning
!pip install lightning-flash

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

!pip install pytorch-lightning


!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev > /dev/null
!pip install pytorch-lightning > /dev/null

!pip install tensorboardX

In [None]:
#@title Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive connected.')

Paths:
```
/content/data (rgb data)
/content/masks (1 channel masks, black = mask, white = original image)
/content/validation (images for validation)
/content/validation_output (validation destination, will be created if not present)
/content/test (rgb data)
/content/test_output (test output, will be created if not present)
```
By default, each training sample uses a random mask with 50% probability and a custom mask from `/content/masks` otherwise. Validation currently does not compute metrics; it takes a green-masked LR image as input. Metric code is already included and only needs a custom dataloader.
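A minimal sketch of the mask selection and green masking described above, assuming NumPy arrays in RGB order and custom masks that are already resized to the image size. The helpers `random_rect_mask`, `pick_mask` and `apply_green` are hypothetical, and the single-rectangle random mask is only a stand-in for whatever random mask generator the training code actually uses. The mask convention follows the paths listed above: 0 (black) marks pixels to inpaint, 255 (white) keeps the original pixel.

```python
import random
import numpy as np

def random_rect_mask(h, w, rng):
    """Hypothetical random mask: one rectangle of 0 (masked) on a 255 (keep) background."""
    mask = np.full((h, w), 255, dtype=np.uint8)
    mh, mw = rng.randint(h // 8, h // 2), rng.randint(w // 8, w // 2)
    y, x = rng.randint(0, h - mh), rng.randint(0, w - mw)
    mask[y:y + mh, x:x + mw] = 0
    return mask

def pick_mask(h, w, custom_masks, rng, p_random=0.5):
    """Random mask with probability p_random, otherwise one of the custom masks."""
    if not custom_masks or rng.random() < p_random:
        return random_rect_mask(h, w, rng)
    return rng.choice(custom_masks)

def apply_green(image_rgb, mask):
    """Paint masked pixels (mask == 0) pure green, as in the green-masked input."""
    out = image_rgb.copy()
    out[mask == 0] = (0, 255, 0)
    return out
```

Falling back to a random mask when no custom masks are available keeps training running even if `/content/masks` is empty.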

In [None]:
#@title copy data (example: extract a 7z archive from Google Drive, adapt to your dataset)
!mkdir -p '/content/data/images'
!cp "/content/drive/MyDrive/classification_v3.7z" "/content/data/images/data.7z"
%cd /content/data/images
!7z x "data.7z"
!rm -rf /content/data/images/data.7z

# Optional

In [None]:
# EfficientNet
!pip install efficientnet_pytorch
# AdamP
!pip install adamp

# Training

In [None]:
#@title delete validation, logs and checkpoints if needed
%cd /content/
!sudo rm -rf /content/validation_output
!sudo rm -rf /content/lightning_logs
!sudo rm -rf /content/logs
#!mkdir /content/logs/
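# exclude the mounted Google Drive so checkpoints stored there are not deleted
!find /content -type f -name "*.ckpt" -not -path "/content/drive/*" -delete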

In [None]:
!python train.py

# Testing 

In [None]:
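#@title testing the model
# Note: DS_green_from_mask, CustomTrainClass, CheckpointEveryNSteps and pl (pytorch_lightning) are not defined in this notebook and must already be available in the session
dm = DS_green_from_mask('/content/test')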
model = CustomTrainClass()
# GPU
#trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision)
trainer = pl.Trainer(gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(tpu_cores=8, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
trainer.test(model, dm, ckpt_path='/content/Checkpoint_0_0.ckpt')

# Misc

In [None]:
#@title creating 16x16 image grids (4096x4096 mosaics of 256x256 tiles)
import cv2
import numpy
import glob
import os
rootdir = '/content/data' #@param {type:"string"}
destination_dir = "/content/4k/" #@param {type:"string"}
os.makedirs(destination_dir, exist_ok=True)

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files.extend(files_jpg)

img_cnt = 0
# only complete grids are written; leftover files (fewer than 256) are ignored
for start in range(0, len(files) - 255, 256):
  tmp_img = numpy.zeros((4096, 4096, 3), dtype=numpy.uint8)
  for k in range(256):
    i, j = divmod(k, 16)  # row and column of the tile, filled row by row
    # no check for unreadable files here; use the "with skip" cell for that
    image = cv2.imread(files[start + k])
    image = cv2.resize(image, (256, 256))
    tmp_img[i*256:(i+1)*256, j*256:(j+1)*256] = image
  #cv2.imwrite(destination_dir+str(img_cnt)+".png", tmp_img)
  cv2.imwrite(destination_dir + str(img_cnt) + ".jpg", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
  img_cnt += 1
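The cell above places tiles row by row with explicit slicing. When the tiles are already resized and stacked, the same layout can be built in one step with NumPy reshapes. This is a sketch: `make_mosaic` is a hypothetical helper, not part of the notebook, and it assumes every tile has the same shape.

```python
import numpy as np

def make_mosaic(tiles: np.ndarray, rows: int, cols: int) -> np.ndarray:
    """Arrange rows*cols equally sized tiles into one image, filling row by row.

    tiles: array of shape (rows*cols, h, w, c).
    Returns an array of shape (rows*h, cols*w, c).
    """
    n, h, w, c = tiles.shape
    if n != rows * cols:
        raise ValueError(f"expected {rows * cols} tiles, got {n}")
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c): each output pixel row then
    # concatenates the matching pixel rows of all tiles in one grid row
    grid = tiles.reshape(rows, cols, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)
```

For 256 tiles of 256x256x3, `make_mosaic(np.stack(tiles), 16, 16)` gives a 4096x4096x3 image. Note that the "with skip" cells fill column by column instead, because they compute the row index as `img_cnt % 16`.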

In [None]:
#@title creating 16x16 images (with skip)
import cv2
import numpy
import glob
import shutil
import tqdm
import os
rootdir = '/content/data' #@param {type:"string"}
destination_dir = "/content/4k/" #@param {type:"string"}
broken_dir = '/content/opencv_fail/' #@param {type:"string"}
os.makedirs(destination_dir, exist_ok=True)
os.makedirs(broken_dir, exist_ok=True)

files = []
for ext in ('png', 'jpg', 'jpeg', 'webp'):
  files.extend(glob.glob(rootdir + f'/**/*.{ext}', recursive=True))

img_cnt = 0       # tiles placed in the current grid
filename_cnt = 0  # grids written so far
tmp_img = numpy.zeros((4096, 4096, 3), dtype=numpy.uint8)

for path in tqdm.tqdm(files):
  image = cv2.imread(path)
  if image is None:
    # OpenCV could not decode the file: move it aside and skip it
    print(path)
    shutil.move(path, os.path.join(broken_dir, os.path.basename(path)))
    continue

  i = img_cnt % 16
  j = img_cnt // 16
  image = cv2.resize(image, (256, 256))
  tmp_img[i*256:(i+1)*256, j*256:(j+1)*256] = image
  img_cnt += 1

  # a grid is complete after 16*16 tiles; leftover tiles at the end are not written
  if img_cnt == 256:
    cv2.imwrite(destination_dir + str(filename_cnt) + ".jpg", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
    filename_cnt += 1
    img_cnt = 0

In [None]:
#@title creating 3x3 grayscale images (with skip)
import cv2
import numpy
import glob
import shutil
import tqdm
import os
rootdir = '/content/data' #@param {type:"string"}
destination_dir = "/content/merged/" #@param {type:"string"}
broken_dir = '/content/opencv_fail/' #@param {type:"string"}
os.makedirs(destination_dir, exist_ok=True)
os.makedirs(broken_dir, exist_ok=True)

files = []
for ext in ('png', 'jpg', 'jpeg', 'webp'):
  files.extend(glob.glob(rootdir + f'/**/*.{ext}', recursive=True))

image_size = 400 #@param

img_cnt = 0       # tiles placed in the current grid
filename_cnt = 0  # grids written so far
tmp_img = numpy.zeros((image_size*3, image_size*3), dtype=numpy.uint8)

for path in tqdm.tqdm(files):
  image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  if image is None:
    # OpenCV could not decode the file: move it aside and skip it
    print(path)
    shutil.move(path, os.path.join(broken_dir, os.path.basename(path)))
    continue

  i = img_cnt % 3
  j = img_cnt // 3
  image = cv2.resize(image, (image_size, image_size))
  tmp_img[i*image_size:(i+1)*image_size, j*image_size:(j+1)*image_size] = image
  img_cnt += 1

  # a grid is complete after 3*3 tiles; leftover tiles at the end are not written
  if img_cnt == 9:
    cv2.imwrite(destination_dir + str(filename_cnt) + ".png", tmp_img)
    filename_cnt += 1
    img_cnt = 0

In [None]:
#@title convert to onnx
#@markdown Make sure the input dimensions are correct. A runtime restart may be needed if it raises ``TypeError: forward() missing 1 required positional argument``. Make sure you only run the required cells.
import torch
checkpoint_path = '/content/Checkpoint_0_0.ckpt' #@param
output_path = '/content/output.onnx' #@param
# load_from_checkpoint is a classmethod that returns a new model instance
model = CustomTrainClass.load_from_checkpoint(checkpoint_path)
model.eval()
# dummy input (batch, channels, height, width); adjust to the shape the model expects
dummy_input = torch.randn(1, 1, 64, 64)

model.to_onnx(output_path, input_sample=dummy_input)

In [None]:
#@title copy pasting data to create an artificial dataset for debugging
import shutil
from random import random
from tqdm import tqdm
for i in tqdm(range(5000)):
  shutil.copy("/content/4k/0.jpg", "/content/4k/" + str(random()) + ".jpg")