# Baseline Demarkify Training Notebook

This notebook gets the basics of the new fastai v2 library down and builds a working training loop using the "NoGAN" method (or another approach if that proves easier).

In [1]:
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
from fastai2.vision.gan import *
from PIL import ImageDraw, ImageFont, ImageFile
# NOTE(review): False is PIL's default, so truncated image files still raise an error when loaded.
# Set this to True if partially written / truncated images should be loaded instead.
ImageFile.LOAD_TRUNCATED_IMAGES = False
import pdb
from tqdm import tqdm_notebook
import gc

from preprocess import *
from train import *
from config import *

In [2]:
# Pin all CUDA work in this notebook to one GPU and report which one is in use.
gpu_idx = 0
torch.cuda.set_device(gpu_idx)
gpu_name = torch.cuda.get_device_name(gpu_idx)
print(f"Using {gpu_name} at index {torch.cuda.current_device()}")

Using GeForce GTX 1080 Ti at index 0


## Preprocessing

Resize all "clean" images so that their width and height do not exceed a reasonable maximum (1080px), and save the results to disk. This step takes a while to run!

In [None]:
# Set to True to (re)create the resized "clean" images from the raw images.
# Disabled by default because it is slow and only needs to run once.
RUN_RESIZE = False

if RUN_RESIZE:
    raw_images = get_image_files(RAW)
    parallel(partial(create_clean_image, CLEAN), raw_images)

Create the marked images from the clean images (retrying any that fail).

In [None]:
# Set to True to (re)generate the marked images from the clean images.
# Disabled by default because it only needs to run once.
RUN_MARKING = False

if RUN_MARKING:
    fonts = L(Path(DATA/"fonts").rglob("*.ttf"))
    clean_images = get_image_files(CLEAN)
    markr = Markr(CLEAN, MARKED, fonts)

    parallel(markr, clean_images, n_workers=12)

    # NOTE(review): if `parallel` runs `markr` in worker processes, failures recorded on
    # `markr.failed_images` inside those workers would not be visible in this process,
    # and this retry would never trigger. Confirm that Markr collects failures in a
    # process-safe way.
    if len(markr.failed_images) != 0:
        print("Retrying failed images")
        failed_images = markr.failed_images
        parallel(markr, failed_images)

## Modeling

### Pretrain Generator

In [None]:
bs = 88
sz = 64
keep_pct = 1.0
arch = resnet34

In [None]:
dls_gen = get_dls(bs,sz,keep_pct)

In [None]:
dls_gen.show_batch(max_n=4)

In [None]:
wd = 1e-3
y_range = (-3.,3.)
loss_gen = MSELossFlat()

In [None]:
learn_gen = create_gen_learner(dls_gen, arch, loss_gen, y_range)

In [None]:
learn_gen.fit_one_cycle(1, pct_start=0.8, wd=wd)

In [None]:
learn_gen.unfreeze()

In [None]:
learn_gen.fit_one_cycle(1, pct_start=0.8, lr_max=slice(3e-7, 3e-4), wd=wd)

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

In [None]:
bs = 22
sz = 128
keep_pct = 1.0

In [None]:
learn_gen.dls = get_dls(bs, sz, keep_pct)

In [None]:
learn_gen.fit_one_cycle(1, pct_start=0.8, lr_max=slice(1e-7,1e-4))

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

In [None]:
bs = 11
sz = 192
keep_pct=1.0

In [None]:
learn_gen.dls = get_dls(bs, sz, keep_pct)

In [None]:
learn_gen.fit_one_cycle(1, pct_start=0.8, lr_max=slice(5e-8, 5e-5))

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

In [None]:
learn_gen.load(f"gen-pre-{sz}px");

In [None]:
bs = 12
sz = 300
keep_pct=1.0

In [None]:
learn_gen.dls = get_dls(bs, sz, keep_pct)

In [None]:
learn_gen.fit_one_cycle(1, lr_max=slice(5e-8, 5e-7))

In [None]:
# load_image('/mnt/nvme/data/demarkr/marked/27705.png')

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

In [None]:
bs = 6
sz = 480
keep_pct=1.0

In [None]:
learn_gen.dls = get_dls(bs, sz, keep_pct)

In [None]:
learn_gen.fit_one_cycle(1, lr_max=slice(5e-8, 5e-7))

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

In [None]:
bs = 4
sz = 600
keep_pct=1.0

In [None]:
learn_gen.dls = get_dls(bs, sz, keep_pct)

In [None]:
learn_gen.fit_one_cycle(1, lr_max=slice(5e-8, 5e-7))

In [None]:
learn_gen.show_results(max_n=4)

In [None]:
learn_gen.save(f"gen-pre-{sz}px")

### GAN Cycle

In [3]:
# Which GAN cycle this run performs. Cycle 1 starts from the pretrained ("-pre-") checkpoints;
# later cycles load checkpoints named with `prev_cycle_iter`.
cycle_iter = 3
# Guard against values that would build checkpoint names such as "gen--1-300".
assert isinstance(cycle_iter, int) and cycle_iter >= 1, "cycle_iter must be an int >= 1"
prev_cycle_iter = cycle_iter - 1

#### Save Generated Images

In [4]:
arch = resnet34
wd = 1e-3
loss_gen = MSELossFlat()

bs = 16
sz = 300
keep_pct = 0.15

# create learner
dls_gen = get_dls(bs,sz,keep_pct)
learn_gen = create_gen_learner(dls_gen, arch, loss_gen)
if cycle_iter == 1:
    learn_gen.load(f"gen-pre-{sz}px")
else:
    learn_gen.load(f"gen-{prev_cycle_iter}-{sz}")

# get dataloader
dl = dls_gen.train.new(shuffle=False, drop_last=False, 
                       after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

In [None]:
save_preds(learn_gen, dl)

HBox(children=(FloatProgress(value=0.0, description='batches', max=436.0, style=ProgressStyle(description_widt…

#### Train Critic

In [None]:
# Drop references to the generator-stage objects so their GPU memory can be reclaimed.
learn_gen, dl = None, None
# Collect garbage *before* emptying the CUDA cache. Objects kept alive only by reference
# cycles (presumably e.g. learner <-> callback links) still own their tensors until the
# cyclic GC runs, and empty_cache can only release memory that is already free.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Critic weight decay and loss: BCE-with-logits wrapped in AdaptiveLoss.
wd = 1e-3
bce_logits = nn.BCEWithLogitsLoss()
loss_critic = AdaptiveLoss(bce_logits)

Pretrain the critic at 128px (only when `cycle_iter` is 1).

In [None]:
# Only the first cycle pretrains the critic: 128px images, batch size 128.
if cycle_iter == 1:
    bs = sz = 128
    dls_crit = get_crit_dls(bs=bs, size=sz)
    learn_critic = create_critic_learner(dls_crit, accuracy_thresh_expand, loss_critic)
    learn_critic.fit_one_cycle(6, lr_max=1e-3, wd=wd)
    learn_critic.save(f"critic-pre-{sz}px")

In [None]:
bs = 10
sz = 300

# Critic checkpoint naming, matching the save below and the load in the GAN section:
#   critic trained during cycle 1       -> f"critic-pre-{sz}px"
#   critic trained during cycle k >= 2  -> f"critic-{k - 1}-{sz}"
# Continue from the most recent critic, i.e. the one trained in cycle `cycle_iter - 1`.
# On the first cycle, that is the 128px critic from the pretraining cell above (its `sz`).
if cycle_iter == 1:
    critic_load_name = "critic-pre-128px"
elif cycle_iter == 2:
    critic_load_name = f"critic-pre-{sz}px"
else:
    critic_load_name = f"critic-{cycle_iter - 2}-{sz}"
critic_save_name = f"critic-pre-{sz}px" if cycle_iter == 1 else f"critic-{prev_cycle_iter}-{sz}"

dls_crit = get_crit_dls(bs=bs, size=sz)
learn_critic = create_critic_learner(dls_crit, accuracy_thresh_expand, loss_critic)

# load last model
learn_critic.load(critic_load_name)

# fit model
learn_critic.fit_one_cycle(6, 1e-5, wd=wd)

# save model
learn_critic.save(critic_save_name)

#### GAN

In [None]:
# Drop references to the critic-training objects so their GPU memory can be reclaimed.
learn_critic, dls_crit = None, None
# Collect garbage *before* emptying the CUDA cache. Objects kept alive only by reference
# cycles (presumably e.g. learner <-> callback links) still own their tensors until the
# cyclic GC runs, and empty_cache can only release memory that is already free.
gc.collect()
torch.cuda.empty_cache()

In [None]:
bs = 6
sz = 300

In [None]:
dls_crit = get_crit_dls(bs=bs, size=sz)

In [None]:
learn_crit = create_critic_learner(dls_crit, metrics=None, loss_critic=loss_critic)
if cycle_iter == 1:
    learn_crit.load(f"critic-pre-{sz}px")
else:
    learn_crit.load(f"critic-{prev_cycle_iter}-{sz}")

In [None]:
dls_gen = get_dls(bs=bs, size=sz, keep_pct=1.0)

In [None]:
keep_pct = 1.0
arch = resnet34
loss_gen = MSELossFlat()

learn_gen = create_gen_learner(dls_gen, arch, loss_gen)
if cycle_iter == 1:
    learn_gen.load(f"gen-pre-{sz}px")
else:
    learn_gen.load(f"gen-{prev_cycle_iter}-{sz}")

In [None]:
class GANDiscriminativeLR(Callback):
    "Multiply the learning rate by `mult_lr` for batches that train the critic, then restore it."
    def __init__(self, mult_lr=5.):
        self.mult_lr = mult_lr
        # Records whether begin_batch scaled the lr, so after_batch undoes exactly that.
        # Without it, the lr would be divided on batches where it was never multiplied
        # (e.g. non-training batches, or if `gen_mode` changes between the two events),
        # making it drift.
        self._lr_scaled = False

    def begin_batch(self):
        "Scale the lr up when the critic is about to be trained on this batch."
        self._lr_scaled = bool(not self.learn.gan_trainer.gen_mode and self.training)
        if self._lr_scaled:
            # Use this callback's own learner (not a notebook-global `learn`).
            opt = self.learn.opt
            # The new value is derived from param group 0's current lr.
            opt.set_hyper('lr', opt.hypers[0]['lr'] * self.mult_lr)

    def after_batch(self):
        "Restore the lr if begin_batch scaled it for this batch."
        if self._lr_scaled:
            opt = self.learn.opt
            opt.set_hyper('lr', opt.hypers[0]['lr'] / self.mult_lr)
            self._lr_scaled = False

In [None]:
class GANSaveCallback(Callback):
    "Save `learn_gen` as `gen-{iteration}-{size}` when fitting ends."
    def __init__(self, learn_gen: Learner, iteration: int, size: int = None):
        # `learn_gen` presumably wraps the same generator model the GAN learner trains,
        # so saving it persists the GAN-trained weights — TODO confirm.
        self.learn_gen = learn_gen
        self.iteration = iteration
        # Optional explicit image size for the checkpoint name. When None, it is inferred
        # from the training dataloader (see `_infer_size`).
        self.save_size = size

    def after_fit(self):
        self._save_gen_learner()

    def _save_gen_learner(self):
        "Save the generator under this cycle's checkpoint name."
        sz = self.save_size if self.save_size is not None else self._infer_size()
        self.learn_gen.save(f"gen-{self.iteration}-{sz}")

    def _infer_size(self):
        "Read the image size from the first item transform of the first dataloader."
        # Assumes that transform exposes `final_size`; its first entry is used in the name.
        return self.learn.dls.loaders[0].after_item[0].final_size[0]

In [None]:
switcher = AdaptiveGANSwitcher(critic_thresh=0.65)
learn = GANLearner.from_learners(
    learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher, 
    opt_func=partial(Adam, mom=0.), 
    cbs=[GANDiscriminativeLR(mult_lr=5.), GANSaveCallback(learn_gen, cycle_iter)])

In [None]:
lr = 1e-5

In [None]:
learn_gen.freeze_to(-1)

In [None]:
learn.fit(1, lr, wd=wd)

In [None]:
learn.show_results(max_n=4)

## Callback Testing

In [None]:
# NOTE(review): leftover placeholder for callback tests. No definition of `TstLearner` is
# visible in this notebook, so this cell likely raises NameError unless one of the star
# imports provides it. Replace it with real callback tests or remove it.
TstLearner

## Inference

In [None]:
# Real-world watermarked sample used to eyeball the generator checkpoints.
img_fp = Path(DATA, "real_world", "watermarked_dude_300px.png")
img = Image.open(img_fp)
# `img.shape` presumably comes from fastai's patch on PIL images — TODO confirm.
display(img, img.shape)

In [None]:
# Generator checkpoints for 300px, sorted by name, as names accepted by `Learner.load`.
# Match on the file name only: the full path's parent directories could themselves
# contain "gen" or "300". Use `.stem` so names containing extra dots are not truncated.
gen_models = sorted(
    model.stem
    for model in Path(DATA/"models").ls()
    if "gen" in model.name and "300" in model.name
)
gen_models

In [None]:
keep_pct = 1.0
arch = resnet34
loss_gen = MSELossFlat()

# NOTE(review): relies on `dls_gen` from the GAN section above; define it first when
# running this section in a fresh kernel.
learn_gen = create_gen_learner(dls_gen, arch, loss_gen)

# Predict each checkpoint once on the same image. The previous loop also ran a
# predict(img_fp) whose result was discarded, doubling the inference cost.
# NOTE(review): this predicts on watermarked_dude.png, while the cell above displays img_fp
# (watermarked_dude_300px.png). Confirm which image is intended.
pred_fp = DATA/"real_world"/"watermarked_dude.png"
preds = dict()

for model in gen_models:
    print("predicting for model", model)
    learn_gen.load(model)
    pred = learn_gen.predict(pred_fp)
    # Undo the dataloaders' batch transforms on the prediction, then convert to a PIL image.
    dec = learn_gen.dls.after_batch.decode((TensorImage(pred[1].to('cpu')[None]),))[0][0]
    arr = dec.numpy().transpose(1, 2, 0)  # channels-first -> channels-last for PIL
    preds[model] = Image.fromarray(np.uint8(arr), mode='RGB')

In [None]:
# Show each checkpoint's name followed by its predicted image.
for model_name, pred_img in preds.items():
    display(model_name, pred_img)

gan pred

learn_gen pred

In [None]:
pred = learn_gen.predict(DATA/"real_world"/"watermarked_dude.png")
# Wrap the prediction as a one-item batch and undo the dataloaders' batch transforms.
pred_batch = (TensorImage(pred[1].to('cpu')[None]),)
dec = learn_gen.dls.after_batch.decode(pred_batch)[0][0]
arr = dec.numpy().transpose(1, 2, 0)  # channels-first -> channels-last for PIL
Image.fromarray(np.uint8(arr), mode='RGB')