## Stable Model Training (Large Batch/Limited GPU Memory Support)

## IMPORTANT: Training with this notebook has -not- been verified by me. ~jantic

### This notebook has been written to take advantage of the Large Model Support technology created by IBM.

### Information on Large Model Support
Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning - Community Edition (WML-ce) PyTorch > V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with “out-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. One or more elements of a deep learning model can lead to GPU memory exhaustion.

Requires the use of IBM WML-ce (Available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.7.0/welcome/welcome.html)

Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.7.0/navigation/wmlce_getstarted_pytorch.html#wmlce_getstarted_pytorch__lms_section


#### NOTE:
Using Large Model Support (LMS) lets you train the DeOldify models on a GPU that would generally not have enough memory (e.g. GTX 1070 8 GB). The penalty for using LMS on x86_64 is increased training time, due to the lack of a high-bandwidth NVLink connection between the GPU and CPU.

If you are training on a PPC64LE system with NVLink (e.g. IBM AC922), then you will NOT suffer any penalty when using LMS and you can also increase the batch size to decrease the overall training times.

### Changes made

1. Larger ResNet backend (152 vs 101)
2. Easily train on existing models (aka transfer learning)
3. Easily Train via Half Precision
4. Increased Progressive Resizing to 512px
5. WIP: Train using EfficientNet backend (b7). You can optionally train to 600px if using EfficientNet

#### NOTES:  
* This is "NoGAN" based training, described in the DeOldify readme.
* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.

In [1]:
#NOTE:  This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices:  CPU, GPU0...GPU7
device.set(device=DeviceId.GPU0)

<DeviceId.GPU0: 0>

In [2]:
# Dockerfile installs this but for some reason it's not loaded
!pip install tensorboardx==1.6.0
!pip install efficientnet-pytorch



In [3]:
resnet152_backend=True

In [4]:
import tensorflow as tf
import datetime
import os
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
if resnet152_backend == True:
    from deoldify.generatorsResNet152 import *
else:
    from deoldify.generators import *
# Comment out the ResNet generator imports above and uncomment the line below to load the EfficientNet generators instead (Work in Progress)
# from deoldify.generatorsEFFNET import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

## Setup

### Activate Large Model Support for PyTorch

In [5]:
import shutil

In [6]:
# Limit (in GB) of GPU memory used before LMS swaps tensors to host memory.
max_gpu_mem = 29

def gb_to_bytes(gb):
    return gb*1024*1024*1024

# LMS is left disabled for this run; pass True to enable PyTorch LMS
torch.cuda.set_enabled_lms(False)

# Set LMS limit (uncomment when LMS is enabled)
# torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))

In [7]:
# Check LMS is enabled
torch.cuda.get_enabled_lms()

False

In [8]:
# Check LMS Limit has been set
torch.cuda.get_limit_lms()

0
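The `torch.cuda.set_enabled_lms` and `torch.cuda.set_limit_lms` calls used above come from the IBM WML-ce build of PyTorch and are not available in a stock PyTorch install, where they would raise `AttributeError`. A minimal sketch of a guard follows. `configure_lms` is a hypothetical helper name, not part of DeOldify or PyTorch, and it takes the `torch.cuda` module as an argument so the check can be exercised without a GPU.

```python
def gb_to_bytes(gb):
    """Convert gigabytes (GiB) to bytes."""
    return int(gb * 1024 ** 3)


def configure_lms(cuda_module, enable, limit_gb=None):
    """Hypothetical helper: apply LMS settings only if the build supports them.

    Returns True when the LMS API was found and called, False otherwise.
    """
    # Stock PyTorch has no set_enabled_lms; only LMS-enabled builds do.
    if not hasattr(cuda_module, "set_enabled_lms"):
        return False
    cuda_module.set_enabled_lms(enable)
    # Only set a limit when LMS is actually switched on.
    if enable and limit_gb is not None:
        cuda_module.set_limit_lms(gb_to_bytes(limit_gb))
    return True
```

In this notebook it could be called as `configure_lms(torch.cuda, enable=True, limit_gb=max_gpu_mem)`; a `False` return value would indicate that the installed PyTorch build has no LMS support.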

In [9]:
# Enable Half Precision (fp16) - https://docs.fast.ai/callbacks.fp16.html
hp_enable = False

# Load existing Generator model
load_existing_gen_model = False

# Load existing Critic model
load_existing_critic_model = False

In [10]:
# Path to Training Data
path = Path('/home/nolms/imageset/train')
path_hr = path

# Path to Validation Data
path_val = '/home/nolms/imageset/val'

# Path to Black and White images
path_bandw = Path('/home/nolms/generated')
path_lr = path_bandw/'bandw'

# Name of Model
proj_id = 'ColorizeStableNoLMS'

# Name of Generator
if not load_existing_gen_model:
    gen_name = proj_id + '_gen'
else:
    # Path to existing Pre-Trained Model
    model_path = '/mnt/datasets/deoldify-pretrained-models/pretrain_gen-weights/'
    gen_name = model_path + proj_id + '_PretrainOnly_gen'

pre_gen_name = gen_name + '_0'

# Name of Critic
crit_name = proj_id + '_crit'

# Name of Generated Images folder, located within Training Data folder
name_gen = proj_id + '_image_gen'
path_gen = path/name_gen

# Path to tensorboard data
TENSORBOARD_PATH = Path('/home/nolms/tensorboard/' + proj_id)

nf_factor = 2
pct_start = 1e-8

# Specify Pre-Trained model
gen_old_checkpoint_name = pre_gen_name
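The `pct_start = 1e-8` value set above is passed to `fit_one_cycle` in most of the later stages. As a rough illustration of what that does to the learning rate, here is a standalone sketch of a one-cycle schedule. It is modelled on my understanding of fastai v1's scheduler (cosine warm-up from `max_lr / div_factor` over the first `pct_start` of iterations, then cosine decay to `max_lr / (div_factor * 1e4)`); the function names and defaults here are assumptions for illustration, not the library's code.

```python
import math


def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lrs(n_iter, max_lr, pct_start, div_factor=25.0):
    """Illustrative per-iteration learning rates for a one-cycle schedule."""
    warmup = int(n_iter * pct_start)   # iterations spent ramping up
    decay = n_iter - warmup            # iterations spent annealing down
    low = max_lr / div_factor
    final = max_lr / (div_factor * 1e4)
    lrs = [annealing_cos(low, max_lr, i / warmup) for i in range(warmup)]
    lrs += [annealing_cos(max_lr, final, i / decay) for i in range(decay)]
    return lrs


# pct_start=0.8: the peak is reached 80% of the way through training.
late_peak = one_cycle_lrs(100, 1e-3, 0.8)
# pct_start=1e-8: no warm-up iterations, so the rate only decays from max_lr.
no_warmup = one_cycle_lrs(100, 1e-3, 1e-8)
print(late_peak.index(max(late_peak)), no_warmup.index(max(no_warmup)))
```

Under these assumptions, a near-zero `pct_start` effectively turns the schedule into a pure decay from the maximum learning rate, which would suit fine-tuning an already pretrained generator.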

In [11]:
def get_data(bs:int, sz:int, keep_pct:float):
    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, 
                             random_seed=None, keep_pct=keep_pct)

def get_crit_data(classes, bs, sz):
    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)
           .databunch(bs=bs).normalize(imagenet_stats))
    return data

def create_training_images(fn,i):
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn).convert('LA').convert('RGB')
    img.save(dest)  
    
def save_preds(dl):
    i=0
    names = dl.dataset.items
    
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1
    
def save_gen_images():
    if path_gen.exists(): shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(data_gen.fix_dl)
    PIL.Image.open(path_gen.ls()[0])

## Create black and white training images

Only runs if the directory isn't already created.

In [12]:
if not path_lr.exists():
    il = ImageList.from_folder(path_hr)
    parallel(create_training_images, il.items)

In [13]:
# List total number of B&W images
total_bw = len(list(path_lr.rglob('*.*')))
print('Total B&W Images:', total_bw)

Total B&W Images: 750
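`create_training_images` desaturates each image with Pillow's `convert('LA')` and then converts back to RGB, so the saved file still has three (identical) channels. As a rough per-pixel illustration, the sketch below uses the ITU-R 601-2 luma weights that Pillow documents for RGB to L conversion. `to_gray_rgb` is a hypothetical name, and Pillow's internal rounding may differ slightly from this integer arithmetic.

```python
def to_gray_rgb(pixel):
    """Map an (R, G, B) pixel to a gray pixel with three equal channels.

    Uses the ITU-R 601-2 luma weights: L = R*0.299 + G*0.587 + B*0.114.
    """
    r, g, b = pixel
    luma = (r * 299 + g * 587 + b * 114) // 1000
    return (luma, luma, luma)


# A saturated red becomes a fairly dark gray; white and black are unchanged.
for px in [(255, 0, 0), (255, 255, 255), (0, 0, 0)]:
    print(px, "->", to_gray_rgb(px))
```

The generator is then trained to recover the original colors from these gray inputs.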


## Pre-train generator

#### NOTE
Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.

### 64px

In [14]:
bs=160 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
sz=64
keep_pct=1.0

In [15]:
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [16]:
# Build the generator learner, then optionally load weights and switch to fp16
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
if load_existing_gen_model:
    learn_gen = learn_gen.load(gen_old_checkpoint_name, with_opt=False)
if hp_enable:
    learn_gen = learn_gen.to_fp16()

In [17]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-64px'))

In [18]:
learn_gen.fit_one_cycle(30, pct_start=0.8, max_lr=slice(1e-3))

epoch,train_loss,valid_loss,time
0,7.200703,6.881262,02:57
1,7.250248,6.801543,02:53
2,7.165781,6.753059,02:49
3,7.073417,6.691885,02:49
4,6.95978,6.58005,02:53
5,6.838859,6.383619,02:56
6,6.719772,6.060325,02:50
7,6.597372,5.741998,02:59
8,6.477031,5.53464,02:51
9,6.342181,5.356925,02:44


  if ssh != up_out.shape[-2:]:
  .format(op_name, opset_version, op_name))
Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/site-packages/fastai/callbacks/tensorboard.py", line 234, in _queue_processor
    request.write()
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/site-packages/fastai/callbacks/tensorboard.py", line 424, in write
    self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model)
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/site-packages/tensorboardX/writer.py", line 566, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "/home/nolms/anaconda3/envs/nolms/lib/python3.6/site-packages

In [19]:
learn_gen.save(pre_gen_name)

In [20]:
learn_gen.unfreeze()

In [21]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-Unfreeze-64px'))

In [22]:
learn_gen.fit_one_cycle(20, pct_start=pct_start,  max_lr=slice(3e-7, 3e-4))

epoch,train_loss,valid_loss,time
0,3.139999,3.151261,02:58
1,3.142687,3.166601,02:58
2,3.126447,3.138364,02:53
3,3.116589,3.096464,02:56
4,3.098757,3.069941,02:54
5,3.085768,3.031842,02:50
6,3.069254,3.010613,02:51
7,3.056426,3.002199,02:50
8,3.047286,2.996199,02:54
9,3.037825,2.983824,02:50


In [23]:
learn_gen.save(pre_gen_name)

### 128px

In [24]:
bs=80 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
sz=128
keep_pct=1.0

In [25]:
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [26]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-128px'))

In [27]:
learn_gen.unfreeze()

In [28]:
learn_gen.fit_one_cycle(15, pct_start=pct_start, max_lr=slice(1e-7,1e-4))

epoch,train_loss,valid_loss,time
0,2.412763,2.229118,02:20
1,2.397464,2.242486,02:16
2,2.377202,2.234109,02:05
3,2.363073,2.206091,02:02
4,2.352275,2.188318,02:02
5,2.341934,2.177634,02:09
6,2.333969,2.170386,02:07
7,2.331078,2.173041,02:08
8,2.324987,2.16724,02:06
9,2.316201,2.168727,02:06


In [29]:
learn_gen.save(pre_gen_name)

### 192px

In [30]:
bs=40 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
sz=192
keep_pct=1.0

In [31]:
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [32]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-192px'))

In [33]:
learn_gen.unfreeze()

In [34]:
learn_gen.fit_one_cycle(10, pct_start=pct_start, max_lr=slice(5e-8,5e-5))

epoch,train_loss,valid_loss,time
0,2.089684,2.036349,01:55
1,2.17192,2.069926,01:44
2,2.11632,2.051018,01:42
3,2.082998,2.039481,01:44
4,2.059728,2.042892,01:46
5,2.04802,2.036618,01:48
6,2.037714,2.033741,01:47
7,2.027432,2.033447,01:46
8,2.022375,2.037764,01:45
9,2.017369,2.04218,01:53


In [35]:
learn_gen.save(pre_gen_name)

### 256px

In [36]:
bs=20 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
sz=256
keep_pct=1.0

In [37]:
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [38]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-256px'))

In [39]:
learn_gen.unfreeze()

In [40]:
learn_gen.fit_one_cycle(5, pct_start=pct_start, max_lr=slice(5e-8,5e-5))

epoch,train_loss,valid_loss,time
0,1.866832,1.744093,01:58
1,1.907457,1.784688,01:46
2,1.88187,1.770499,01:42
3,1.869818,1.755128,01:47
4,1.87105,1.760842,01:42


In [41]:
learn_gen.save(pre_gen_name)

### 512px

In [42]:
bs=10 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
sz=512
keep_pct=1.0

In [43]:
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [44]:
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-512px'))

In [45]:
learn_gen.unfreeze()

In [46]:
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))

epoch,train_loss,valid_loss,time


RuntimeError: CUDA out of memory. Tried to allocate 2.53 GiB (GPU 0; 31.75 GiB total capacity; 28.56 GiB already allocated; 1.07 GiB free; 821.62 MiB cached; 0 bytes inactive)
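One way to reason about the out-of-memory error above is to compare the number of pixels per batch across the progressive-resizing stages. The sketch below uses the `sz` and `bs` values from the cells in this notebook and assumes, as a rough heuristic only, that activation memory grows in proportion to `bs * sz * sz`. Real usage also depends on the model, precision, and allocator behavior. The helper names are hypothetical.

```python
# (sz, bs) pairs taken from the pre-training stages in this notebook
stages = [(64, 160), (128, 80), (192, 40), (256, 20), (512, 10)]


def pixels_per_batch(sz, bs):
    """Spatial pixels processed per batch (channels ignored)."""
    return bs * sz * sz


def max_bs_for_budget(sz, pixel_budget):
    """Largest batch size whose pixel count stays within pixel_budget."""
    return max(1, pixel_budget // (sz * sz))


for sz, bs in stages:
    print(f"sz={sz:>3} bs={bs:>3} pixels/batch={pixels_per_batch(sz, bs):,}")

# Batch size at 512px that matches the pixel load of the 256px stage
budget = pixels_per_batch(256, 20)
print("suggested bs at 512px:", max_bs_for_budget(512, budget))
```

By this measure the 512px stage at `bs=10` processes twice as many pixels per batch as the 256px stage, so a smaller batch size (or enabling LMS) may be worth trying; this is an estimate, not a guaranteed fix.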

In [None]:
learn_gen.save(pre_gen_name)

### Optional if using EfficientNet backend

### 600px

In [None]:
# bs=10 # This can be increased if using PyTorch LMS. Training can be slower when using x86_64. PPC64LE with NVLink (e.g. Power8 / Power9) does not suffer this training penalty
# sz=600
# keep_pct=1.0

In [None]:
# learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [None]:
# learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPreTrain-600px'))

In [None]:
# learn_gen.unfreeze()

In [None]:
# learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))

In [None]:
# learn_gen.save(pre_gen_name)

## Repeatable GAN Cycle

#### NOTE
Best results so far have come from repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no further improvement in image quality).  Each time you repeat the cycle, increment old_checkpoint_num by 1 so that new checkpoints don't overwrite the old ones.

In [None]:
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1
gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)
gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)
crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)
crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)
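To make the naming scheme in the cell above concrete, here is a small sketch that derives the four checkpoint names for a given cycle. `checkpoint_names` is a hypothetical helper, not part of DeOldify; the base names are the values that `gen_name` and `crit_name` take in this notebook when `load_existing_gen_model` is False.

```python
def checkpoint_names(gen_name, crit_name, old_checkpoint_num):
    """Return the old/new generator and critic checkpoint names for one cycle."""
    new_num = old_checkpoint_num + 1
    return {
        "gen_old": f"{gen_name}_{old_checkpoint_num}",
        "gen_new": f"{gen_name}_{new_num}",
        "crit_old": f"{crit_name}_{old_checkpoint_num}",
        "crit_new": f"{crit_name}_{new_num}",
    }


# Each cycle reads the "_old" checkpoints and writes the "_new" ones,
# so the next cycle's old names equal this cycle's new names.
for cycle in range(3):
    names = checkpoint_names("ColorizeStableNoLMS_gen", "ColorizeStableNoLMS_crit", cycle)
    print(cycle, names["gen_old"], "->", names["gen_new"])
```

For cycle 0 the old generator name is the same as `pre_gen_name`, which is how the first GAN cycle picks up the pretrained generator.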

### Save Generated Images

In [None]:
bs=10
sz=512

In [None]:
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
save_gen_images()

### Pre-train the Critic on Dataset

#### Only need full pretraining of critic when starting from scratch.  Otherwise, just finetune!

In [None]:
if old_checkpoint_num == 0:
    bs=3
    sz=256
    learn_gen=None
    gc.collect()
    data_crit = get_crit_data([name_gen, 'images'], bs=bs, sz=sz)
    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
    learn_critic = colorize_crit_learner(data=data_crit, nf=256)
    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPreTrain-256px'))
    learn_critic.fit_one_cycle(7, 1e-3)
    learn_critic.save(crit_old_checkpoint_name)

### Critic Training

In [None]:
bs=3
sz=512

In [None]:
data_crit = get_crit_data([name_gen, 'images'], bs=bs, sz=sz)

In [None]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)

In [None]:
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPreTrain-512px'))

In [None]:
learn_critic.fit_one_cycle(4, 1e-4)

In [None]:
learn_critic.save(crit_new_checkpoint_name)

### GAN

In [None]:
learn_critic=None  # the critic learner created above is named learn_critic
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
lr=2e-5
sz=512
bs=3

In [None]:
data_crit = get_crit_data([name_gen, 'images'], bs=bs, sz=sz)

In [None]:
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)

In [None]:
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=50))
learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=50))

#### Instructions:  
Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct.

In [None]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(10,lr)
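To get a feel for how many generator checkpoints a run like the one above might produce, the sketch below estimates the number of training iterations. It rests on assumptions that this notebook does not confirm: that `keep_pct` samples that fraction of the images, that 10% of the sample is held out for validation, and that an incomplete final training batch is dropped. The image count of 750 is the B&W total reported earlier, and 50 is the `save_iters` value passed to `GANSaveCallback`. `estimate_iterations` is a hypothetical helper.

```python
def estimate_iterations(n_images, keep_pct, bs, epochs, valid_pct=0.1):
    """Rough count of training batches over a whole run (see assumptions above)."""
    sampled = int(n_images * keep_pct)           # images kept by keep_pct
    n_train = sampled - int(sampled * valid_pct)  # minus the validation split
    batches_per_epoch = n_train // bs            # incomplete last batch dropped
    return batches_per_epoch * epochs


total = estimate_iterations(n_images=750, keep_pct=0.1, bs=3, epochs=10)
save_iters = 50
print("estimated iterations:", total)
print("multiples of save_iters reached:", total // save_iters)
```

Raising `keep_pct` increases the iteration count, and with it the number of saved checkpoints to inspect for glitches.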