# Colab-DFNet (pytorch lightning)

Porting my pytorch repo [styler00dollar/Colab-DFNet](https://github.com/styler00dollar/Colab-DFNet) to pytorch-lightning. Uses [Yukariin/DFNet](https://github.com/Yukariin/DFNet) as a base. [This tutorial](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09), [this issue](https://stackoverflow.com/questions/65387967/misconfigurationerror-no-tpu-devices-were-found-even-when-tpu-is-connected-in) and [this Colab](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb#scrollTo=3vKszYf6y1Vv) were very helpful. This Colab supports single-GPU, multi-GPU and TPU training.

It can use various loss functions and uses the context_encoder discriminator by default.

In [None]:
!nvidia-smi

In [None]:
#@title GPU
!pip install pytorch-lightning -U

In [None]:
#@title TPU  (restart runtime afterwards)
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
#!pip install pytorch-lightning
!pip install lightning-flash

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

# Pair of version strings: `wheels` is substituted into the wheel file names
# below, `server` is the XRT version requested from the TPU host.
_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
# Resolve the selected VERSION to its (wheels, server) pair. The nightly entry
# builds its server version from yesterday's date formatted as YYYYMMDD.
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
# Bucket and wheel file names consumed by the `gsutil cp` / `pip install`
# commands further down. The file names are pinned to the cp36 (CPython 3.6) tag.
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  """Request that the TPU host switch its server-side XRT to ``CONFIG.server``.

  The host name is the part of the ``COLAB_TPU_ADDR`` environment variable
  before the first colon; the request is a POST to port 8475. Progress and the
  response object are printed.
  """
  wanted_version = CONFIG.server
  print('Updating server-side XRT to {} ...'.format(wanted_version))
  tpu_host = os.environ['COLAB_TPU_ADDR'].split(':')[0]
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=tpu_host,
      XRT_VERSION=wanted_version,
  )
  response = requests.post(url)
  print('Done updating server-side XRT: {}'.format(response))

# Run the XRT update on a background thread so it overlaps with the wheel
# download/install commands below; the thread is joined after they finish.
update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

!pip install pytorch-lightning


!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev > /dev/null
!pip install pytorch-lightning > /dev/null

In [None]:
#@title Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive connected.')

In [None]:
# copy data somehow
!mkdir '/content/data'
!mkdir '/content/data/images'
!cp "/content/drive/MyDrive/data.7z" "/content/data/images/data.7z"
%cd /content/data/images
!7z x "data.7z"
!rm -rf /content/data/images/data.7z

In [None]:
#@title getting pytorch-loss-functions
%cd /content
!git clone https://github.com/styler00dollar/pytorch-loss-functions pytorchloss
%cd /content/pytorchloss

In [None]:
%cd /content/pytorchloss

In [None]:
#@title utils.py
from pathlib import Path

import cv2
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


def resize_like(x, target, mode='bilinear'):
    """Interpolate ``x`` so that its last two dimensions match ``target``'s.

    args:
        x : tensor to resample
        target : tensor whose trailing two dimensions give the output size
        mode : interpolation mode forwarded to ``F.interpolate``
    return the resampled tensor
    """
    height_width = target.shape[-2:]
    return F.interpolate(x, size=height_width, mode=mode, align_corners=False)


def list2nparray(lst, dtype=None):
    """Convert a (possibly nested) list/tuple to an ndarray via a pre-allocated buffer.

    An ndarray input is returned unchanged. A flat sequence of numbers goes
    straight through ``np.array``. Otherwise the shape (and, if ``dtype`` is
    not given, the dtype) is taken from the first element and every element is
    copied into a freshly allocated array.
    """
    if isinstance(lst, np.ndarray):
        return lst
    assert isinstance(lst, (list, tuple)), 'bad type: {}'.format(type(lst))
    assert lst, 'attempt to convert empty list to np array'

    head = lst[0]
    if isinstance(head, np.ndarray):
        inner_shape = head.shape
        assert all(item.shape == inner_shape for item in lst)
        if dtype is None:
            dtype = head.dtype
            assert all(item.dtype == dtype for item in lst), \
                'bad dtype: {} {}'.format(dtype, set(item.dtype for item in lst))
    elif isinstance(head, (int, float, complex, np.number)):
        # Flat numeric sequence: numpy handles it directly.
        return np.array(lst, dtype=dtype)
    else:
        # Nested sequence: convert the first element to learn shape and dtype.
        head_array = list2nparray(head)
        if dtype is None:
            dtype = head_array.dtype
        inner_shape = head_array.shape

    result = np.empty([len(lst)] + list(inner_shape), dtype=dtype)
    for position, item in enumerate(lst):
        result[position] = item
    return result


def get_img_list(path):
    """Return the image files directly inside ``path``.

    The result lists all ``*.png`` files first, then ``*.jpg``, then
    ``*.jpeg``; each group is sorted on its own.
    """
    folder = Path(path)
    found = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        found.extend(sorted(folder.glob(pattern)))
    return found


def gen_miss(img, mask, output):
    """Write copies of the images with the masked-out region painted white.

    Images and masks are paired in the order returned by ``get_img_list``.
    Each mask is resized to its image; mask pixels <= 127 are replaced by 255
    in the output, all other pixels keep the image value.

    args:
        img : directory containing the source images
        mask : directory containing the masks
        output : directory (created if missing) receiving miss_0001.png, ...
    raises IOError when an image or mask file cannot be decoded
    """
    imgs = get_img_list(img)
    masks = get_img_list(mask)
    print('Total images:', len(imgs), len(masks))

    out = Path(output)
    out.mkdir(parents=True, exist_ok=True)

    # Distinct loop names so the directory arguments are not shadowed.
    for i, (img_path, mask_path) in tqdm(enumerate(zip(imgs, masks))):
        path = out.joinpath('miss_%04d.png' % (i+1))
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        hole = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if image is None or hole is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('could not read {} / {}'.format(img_path, mask_path))
        hole = cv2.resize(hole, image.shape[:2][::-1])
        hole = hole[..., np.newaxis]
        # Select per pixel and keep the result uint8: multiplying a Python int
        # by a boolean array would promote to NumPy's default integer dtype.
        miss = np.where(hole > 127, image, 255).astype(np.uint8)
        cv2.imwrite(str(path), miss)

def merge_imgs(dirs, output, row=1, gap=2, res=512):
    """Tile the i-th image of every directory into one white-bordered grid image.

    args:
        dirs : list of image directories; one grid cell per directory
        output : directory (created if missing) receiving merge_0000.png, ...
        row : number of grid rows
        gap : border width in pixels between and around the cells
        res : side length each image is resized to
    """
    image_list = [get_img_list(folder) for folder in dirs]
    img_count = [len(images) for images in image_list]
    print('Total images:', img_count)
    assert min(img_count) > 0, 'Please check the path of empty folder.'

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    column = (len(dirs) - 1) // row + 1
    print('Row:', row)
    print('Column:', column)

    canvas_h = res*row + (row+1)*gap
    canvas_w = res*column + (column+1)*gap
    for index, group in tqdm(enumerate(zip(*image_list))):
        target = output_dir.joinpath('merge_%04d.png' % index)
        # White canvas; the gaps between tiles stay white.
        canvas = np.full([canvas_h, canvas_w, 3], 255, np.uint8)
        for slot, img_path in enumerate(group):
            r, c = divmod(slot, column)
            tile = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
            if tile.shape[:2] != (res, res):
                tile = cv2.resize(tile, (res, res))
            top = (r + 1) * gap + r * res
            left = (c + 1) * gap + c * res
            canvas[top: top + res, left: left + res] = tile
        cv2.imwrite(str(target), canvas)


In [None]:
#@title metrics.py (removing lpips import)
%%writefile /content/pytorchloss/metrics.py
#https://github.com/huster-wgm/Pytorch-metrics/blob/master/metrics.py

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
  @Email:  guangmingwu2010@gmail.com \
           guozhilingty@gmail.com
  @Copyright: go-hiroaki & Chokurei
  @License: MIT
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#import lpips

eps = 1e-6

def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data

def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
    return weights


class CFMatrix(object):
    """Confusion-matrix counts (tp, fp, tn, fn) for binary or multi-channel maps."""

    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs)
            single channel : mperforms is the list [tp, fp, tn, fn], performs is None
            multi channel  : performs is a [chs, 4] tensor of per-class counts,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # Derive the one-vs-rest masks from the label maps so they are
                # created on the same device as the inputs.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                # Stack the scalar tensors into one row; a plain Python list
                # is not accepted as the right-hand side of a tensor assignment.
                performs[ch, :] = torch.stack([nb_tp, nb_fp, nb_tn, nb_fn])
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class OAAcc(object):
    """Overall accuracy: share of pixels whose predicted label equals the target."""

    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return ((tp+tn)/total, None)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        if chs != 1:
            # Multi-channel maps are reduced to per-pixel class indices.
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
        else:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)

        agreeing = torch.sum(y_true == y_pred).float()
        pixel_count = batch_size * img_rows * img_cols
        return agreeing / pixel_count, None


class Precision(object):
    """Precision tp/(tp+fp) for binary or multi-channel maps."""

    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs) with tp/(tp+fp)
            single channel : performs is None
            multi channel  : performs is a [chs, 1] tensor of per-class scores,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            # eps keeps the ratio finite when nothing is predicted positive.
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks, created on the inputs' device.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[ch] = nb_tp / (nb_tp + nb_fp + eps)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Recall(object):
    """Recall tp/(tp+fn) for binary or multi-channel maps."""

    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs) with tp/(tp+fn)
            single channel : performs is None
            multi channel  : performs is a [chs, 1] tensor of per-class scores,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            # eps keeps the ratio finite when the target has no positives.
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks, created on the inputs' device.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[ch] = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class F1Score(object):
    """F1 score (harmonic mean of precision and recall) for binary or multi-channel maps."""

    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs) with 2*precision*recall/(precision+recall)
            single channel : performs is None
            multi channel  : performs is a [chs, 1] tensor of per-class scores,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            # eps guards every denominator against empty prediction/target sets.
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / (_precision + _recall + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks, created on the inputs' device.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[ch] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Kappa(object):
    """Cohen's kappa (Po-Pe)/(1-Pe) for binary or multi-channel maps."""

    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs) with (Po-Pe)/(1-Pe)
            single channel : performs is None
            multi channel  : performs is a [chs, 1] tensor of per-class scores,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            # Po: observed agreement; Pe: agreement expected by chance.
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
            # eps keeps the ratio finite when Pe equals 1.
            mperforms = (Po - Pe) / (1 - Pe + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks, created on the inputs' device.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
                      + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
                performs[ch] = (Po - Pe) / (1 - Pe + eps)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Jaccard(object):
    """Jaccard index (intersection over union) for binary or multi-channel maps."""

    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0], only used for single-channel input
        return (mperforms, performs) with intersection / (sum-intersection)
            single channel : performs is None
            multi channel  : performs is a [chs, 1] tensor of per-class scores,
                             mperforms their class-frequency weighted sum
        """
        _, chs, _, _ = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            # eps keeps the ratio finite when both maps are empty.
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks, created on the inputs' device.
                y_true_ch = (y_true == ch).float()
                y_pred_ch = (y_pred == ch).float()
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[ch] = _intersec / (_sum - _intersec + eps)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class MSE(object):
    """Mean squared error between prediction and target."""

    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            dim : accepted but not used
            threshold : [0.0, 1.0]; when truthy, y_pred is binarized first
        return mean_squared_error, smaller the better
        """
        prediction = _binarize(y_pred, threshold) if threshold else y_pred
        residual = prediction - y_true
        return torch.mean(residual ** 2)


class PSNR(object):
    """Peak signal-to-noise ratio computed as 10*log10(1/MSE)."""

    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            dim : accepted but not used
            threshold : [0.0, 1.0]; when truthy, y_pred is binarized first
        return PSNR, larger the better
        """
        prediction = _binarize(y_pred, threshold) if threshold else y_pred
        squared_error = (prediction - y_true) ** 2
        mse = torch.mean(squared_error)
        # The numerator 1 is the squared peak value used by this formula.
        return 10 * torch.log10(1 / mse)


class SSIM(object):
    '''
    Structural similarity index computed with a Gaussian window.

    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        """Return a normalized 1-d Gaussian of length w_size centred at w_size//2."""
        center = w_size // 2
        two_var = float(2 * sigma ** 2)
        taps = [math.exp(-(x - center) ** 2 / two_var) for x in range(w_size)]
        gauss = torch.Tensor(taps)
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        """Return a [channel, 1, w_size, w_size] stack of the 2-d Gaussian (sigma 1.5)."""
        column = self.gaussian(w_size, 1.5).unsqueeze(1)
        # Outer product of the 1-d kernel with itself gives the 2-d kernel.
        kernel = column.mm(column.t()).float()
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        return kernel.expand(channel, 1, w_size, w_size).contiguous()

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True; False keeps one value per batch item
            full : boolean, default False; True also returns the contrast term
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
        max_val = 255 if torch.max(y_pred) > 128 else 1
        min_val = -1 if torch.min(y_pred) < -0.5 else 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        def smooth(t):
            # Per-channel Gaussian filtering (groups=channel), no padding.
            return F.conv2d(t, window, padding=padd, groups=channel)

        mu1 = smooth(y_pred)
        mu2 = smooth(y_true)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = smooth(y_pred * y_pred) - mu1_sq
        sigma2_sq = smooth(y_true * y_true) - mu2_sq
        sigma12 = smooth(y_pred * y_true) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            # Average channel, rows and columns in turn; the batch axis remains.
            ret = ssim_map.mean(1).mean(1).mean(1)

        return (ret, cs) if full else ret


class LPIPS(object):
    '''
    Learned perceptual image patch similarity, delegating to a model built
    with ``lpips.PerceptualLoss``.

    borrowed from https://github.com/richzhang/PerceptualSimilarity
    '''
    def __init__(self, cuda, des="Learned Perceptual Image Patch Similarity", version="0.1"):
        self.des = des
        self.version = version
        # NOTE(review): the `import lpips` line near the top of this file is
        # commented out, so this lookup presumably raises NameError unless
        # `lpips` is made available some other way — confirm before use.
        self.model = lpips.PerceptualLoss(model='net-lin',net='alex',use_gpu=cuda)

    def __repr__(self):
        return "LPIPS"

    def __call__(self, y_pred, y_true, normalized=True):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            normalized : change [0,1] => [-1,1] (default by LPIPS)
        return LPIPS, smaller the better
        """
        if normalized:
            # Rescale both inputs with x*2-1 before handing them to the model.
            y_pred = y_pred * 2.0 - 1.0
            y_true = y_true * 2.0 - 1.0
        return self.model.forward(y_pred, y_true)


class AE(object):
    """
    Average angular error (in degrees) between per-pixel channel vectors.

    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d tensor in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d tensor in [batch_size, channels, img_rows, img_cols]
        return average AE per batch item (1-d tensor), smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        cosine = dotP / (Norm_pred * Norm_true + eps)
        # Rounding can push the cosine marginally outside [-1, 1], where acos
        # returns NaN; clamping leaves every valid value untouched.
        cosine = torch.clamp(cosine, -1.0, 1.0)
        ae = 180 / math.pi * torch.acos(cosine)
        # Average over rows, then columns; the batch axis remains.
        return ae.mean(1).mean(1)


if __name__ == "__main__":
    # Smoke test: run every metric on a random target and a noisy copy of it,
    # for 3-channel and 1-channel inputs, on CPU and (when present) on CUDA.
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                if not torch.cuda.is_available():
                    # Without a CUDA device the .cuda() calls below would raise.
                    print('#'*20, 'Cuda : True ; skipped (no CUDA device available)')
                    continue
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()

            print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            # LPIPS needs the optional `lpips` module; only run it when that
            # module has actually been imported into this file's namespace.
            if 'lpips' in globals():
                metric = LPIPS(cuda)
                acc = metric(y_pred, y_true).item()
                print("{} ==> {}".format(repr(metric), acc))
            else:
                print("LPIPS ==> skipped (lpips is not imported)")

            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)

            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)

            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)

            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)

            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)

            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)


In [None]:
#@title partialconv.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
#
# Source: https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################

import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable

class PartialConv2d(nn.Conv2d):
    """Conv2d that re-weights its output by how much of each window is unmasked.

    Extra keyword arguments (removed before ``nn.Conv2d`` sees them):
        multi_channel : use a per-channel mask instead of a single-channel one
        return_mask : make ``forward`` return ``(output, updated_mask)``
    """

    def __init__(self, *args, **kwargs):
        # Pull out the custom flags; nn.Conv2d would reject unknown keywords.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        # All-ones kernel used to count the valid mask entries under each window.
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, kernel_h, kernel_w)
        else:
            self.weight_maskUpdater = torch.ones(1, 1, kernel_h, kernel_w)

        updater_shape = self.weight_maskUpdater.shape
        self.slide_winsize = updater_shape[1] * updater_shape[2] * updater_shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        The cached mask statistics are recomputed whenever a mask is passed or
        the input shape differs from the previous call. Without a mask an
        all-ones mask is assumed.
        """
        assert len(input.shape) == 4
        current_size = tuple(input.shape)
        if mask_in is not None or self.last_size != current_size:
            self.last_size = current_size

            with torch.no_grad():
                # Keep the counting kernel on the input's device and dtype.
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is not None:
                    mask = mask_in
                elif self.multi_channel:
                    mask = torch.ones(current_size[0], current_size[1], current_size[2], current_size[3]).to(input)
                else:
                    mask = torch.ones(1, 1, current_size[2], current_size[3]).to(input)

                valid_count = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # Window size over valid count; zeroed where the window holds
                # no valid entry at all.
                ratio = self.slide_winsize/(valid_count + 1e-8)
                self.update_mask = torch.clamp(valid_count, 0, 1)
                self.mask_ratio = torch.mul(ratio, self.update_mask)

        conv_input = input if mask_in is None else torch.mul(input, mask)
        raw_out = super(PartialConv2d, self).forward(conv_input)

        if self.bias is None:
            output = torch.mul(raw_out, self.mask_ratio)
        else:
            # Scale only the bias-free part, then add the bias back.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)

        if self.return_mask:
            return output, self.update_mask
        return output

In [None]:
#@title deformconv2d.py
import torch.nn as nn
import torchvision.ops as O


class DeformConv2d(nn.Module):
    """Deformable convolution whose sampling offsets are predicted from the input.

    A regular convolution (``conv_offset``) produces ``2 * kernel_size**2``
    offset channels, which are then fed to ``torchvision.ops.DeformConv2d``
    together with the input.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        kernel_size: square kernel size (an int; it is squared to size the
            offset map).
        stride, padding, dilation, groups, bias: forwarded to the deformable
            convolution. ``stride``, ``padding``, ``dilation`` and ``bias`` are
            also used for the offset-predicting convolution.
    """

    def __init__(self, in_nc, out_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
        super(DeformConv2d, self).__init__()

        # The offset map must have the same spatial size as the deformable
        # convolution's output, so the offset conv shares its geometry
        # (including dilation, which only matters when dilation != 1).
        self.conv_offset = nn.Conv2d(
            in_nc, 2 * (kernel_size**2), kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias)
        # Start with all-zero offsets.
        self.conv_offset.weight.data.zero_()
        if self.conv_offset.bias is not None:
            # bias is None when the layer is built with bias=False
            self.conv_offset.bias.data.zero_()

        self.dcn_conv = O.DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution."""
        offset = self.conv_offset(x)
        return self.dcn_conv(x, offset=offset)

The loss functions and their weights can be configured inside the model code. Warning: do not use AMP together with StyleLoss.

In [None]:
#@title model.py
# https://github.com/hughplay/DFNet
# https://github.com/Yukariin/DFNet
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

#from utils import resize_like

#from .convolutions import partialconv2d
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d

def resize_like(x, target, mode='bilinear'):
    """Resize ``x`` to the spatial size (last two dimensions) of ``target``.

    Args:
        x: tensor to resize.
        target: tensor whose ``shape[-2:]`` gives the output size.
        mode: interpolation mode passed to ``F.interpolate``.

    Returns:
        The resized tensor.
    """
    size = target.shape[-2:]
    # F.interpolate only accepts align_corners for the interpolating modes;
    # passing it with e.g. 'nearest' or 'area' raises a ValueError.
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        return F.interpolate(x, size, mode=mode, align_corners=False)
    return F.interpolate(x, size, mode=mode)


def get_norm(name, out_channels):
    """Build the 2-D normalisation layer selected by ``name``.

    ``'batch'`` gives ``nn.BatchNorm2d`` and ``'instance'`` gives
    ``nn.InstanceNorm2d``, both over ``out_channels`` features. Any other
    name yields ``None``.
    """
    if name == 'batch':
        return nn.BatchNorm2d(out_channels)
    if name == 'instance':
        return nn.InstanceNorm2d(out_channels)
    return None


def get_activation(name):
    """Build the activation module selected by ``name``.

    Supported names are ``'relu'``, ``'elu'``, ``'leaky_relu'`` (negative
    slope 0.2), ``'tanh'`` and ``'sigmoid'``. Any other name yields ``None``.
    """
    if name == 'relu':
        activation = nn.ReLU()
    elif name == 'elu':
        # was `activation == nn.ELU()`: a comparison, which left `activation`
        # unbound and made 'elu' unusable
        activation = nn.ELU()
    elif name == 'leaky_relu':
        activation = nn.LeakyReLU(negative_slope=0.2)
    elif name == 'tanh':
        activation = nn.Tanh()
    elif name == 'sigmoid':
        activation = nn.Sigmoid()
    else:
        activation = None
    return activation


class Conv2dSame(pl.LightningModule):
    """Convolution padded by ``(kernel_size - stride) / 2`` on each side.

    When ``kernel_size - stride`` is odd the padding is asymmetric, so an
    explicit ``nn.ConstantPad2d`` is placed in front of an unpadded
    convolution; otherwise the convolution pads itself.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        conv_type: ``'normal'`` (``nn.Conv2d``), ``'partial'``
            (``PartialConv2d``) or ``'deform'`` (``DeformConv2d``).
        kernel_size: square kernel size.
        stride: convolution stride.

    Raises:
        ValueError: if ``conv_type`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride):
        super().__init__()

        # Names are resolved lazily so only the selected class has to exist.
        if conv_type == 'normal':
            conv_cls = nn.Conv2d
        elif conv_type == 'partial':
            conv_cls = PartialConv2d
        elif conv_type == 'deform':
            # The symmetric-padding case used to build a PartialConv2d here.
            conv_cls = DeformConv2d
        else:
            # Previously self.conv was silently left undefined and the error
            # only surfaced on the first forward pass.
            raise ValueError('Unsupported conv_type: {!r}'.format(conv_type))

        padding = self.conv_same_pad(kernel_size, stride)

        if isinstance(padding, tuple):
            # (left, right) * 2 -> (left, right, top, bottom) for ConstantPad2d
            self.conv = nn.Sequential(
                nn.ConstantPad2d(padding*2, 0),
                conv_cls(in_channels, out_channels, kernel_size, stride, 0)
            )
        else:
            self.conv = conv_cls(
                in_channels, out_channels, kernel_size, stride, padding)

    def conv_same_pad(self, ksize, stride):
        """Return an int padding, or a ``(left, right)`` tuple when it is uneven."""
        if (ksize - stride) % 2 == 0:
            return (ksize - stride) // 2
        else:
            left = (ksize - stride) // 2
            right = left + 1
            return left, right

    def forward(self, x):
        """Apply the (optionally pre-padded) convolution to ``x``."""
        return self.conv(x)


class ConvTranspose2dSame(pl.LightningModule):
    """Transposed convolution with padding derived from kernel size and stride."""

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        pad, out_pad = self.deconv_same_pad(kernel_size, stride)
        self.trans_conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=pad, output_padding=out_pad)

    def deconv_same_pad(self, ksize, stride):
        """Return ``(padding, output_padding)`` for the given kernel and stride."""
        pad = (ksize - stride + 1) // 2
        return pad, 2 * pad + stride - ksize

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.trans_conv(x)


class UpBlock(pl.LightningModule):
    """Upsampling step: learned transposed convolution or plain interpolation.

    With ``mode='deconv'`` a ``ConvTranspose2dSame`` (``channel`` -> ``channel``,
    stride ``scale``) is used and ``kernel_size`` applies. Any other ``mode``
    is handed to ``F.interpolate`` together with ``scale`` as the scale factor.
    """

    def __init__(self, mode='nearest', scale=2, channel=None, kernel_size=4):
        super().__init__()

        self.mode = mode
        if mode != 'deconv':
            # parameter-free path: interpolate by a fixed factor
            self.up = lambda x: F.interpolate(x, scale_factor=scale, mode=mode)
        else:
            self.up = ConvTranspose2dSame(
                channel, channel, kernel_size, stride=scale)

    def forward(self, x):
        """Upsample ``x`` with the configured operator."""
        return self.up(x)


class EncodeBlock(pl.LightningModule):
    """Encoder stage: ``Conv2dSame`` followed by optional norm and activation.

    ``c_in`` and ``c_out`` are kept as attributes so other blocks can read
    the channel counts.
    """

    def __init__(
            self, in_channels, out_channels, conv_type, kernel_size, stride,
            normalization=None, activation=None):
        super().__init__()

        self.c_in = in_channels
        self.c_out = out_channels

        stages = [
            Conv2dSame(self.c_in, self.c_out, conv_type, kernel_size, stride)]
        # A falsy normalization/activation name means "leave that stage out".
        if normalization:
            stages.append(get_norm(normalization, self.c_out))
        if activation:
            stages.append(get_activation(activation))
        self.encode = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage."""
        return self.encode(x)


class DecodeBlock(pl.LightningModule):
    """Decoder stage: upsample, optionally concatenate a skip tensor, convolve.

    The convolution takes ``c_from_up + c_from_down`` channels and produces
    ``c_out`` channels with stride 1.
    """

    def __init__(
            self, c_from_up, c_from_down, conv_type, c_out, mode='nearest',
            kernel_size=4, scale=2, normalization='batch', activation='relu'):
        super().__init__()

        self.c_from_up = c_from_up
        self.c_from_down = c_from_down
        self.c_in = c_from_up + c_from_down
        self.c_out = c_out

        # The upsampler's kernel size is tied to the scale factor.
        self.up = UpBlock(mode, scale, c_from_up, kernel_size=scale)

        stages = [
            Conv2dSame(self.c_in, self.c_out, conv_type, kernel_size, stride=1)]
        if normalization:
            stages.append(get_norm(normalization, self.c_out))
        if activation:
            stages.append(get_activation(activation))
        self.decode = nn.Sequential(*stages)

    def forward(self, x, concat=None):
        """Upsample ``x``, append ``concat`` on the channel axis if skip channels are configured, then decode."""
        upsampled = self.up(x)
        if self.c_from_down > 0:
            upsampled = torch.cat([upsampled, concat], dim=1)
        return self.decode(upsampled)


class BlendBlock(pl.LightningModule):
    """Three-convolution head ending in a sigmoid.

    Layout: 1x1 conv to ``c_mid`` -> norm -> act -> ``ksize_mid`` conv to
    ``c_out`` -> norm -> act -> 1x1 conv -> sigmoid, where
    ``c_mid = max(c_in // 2, 32)``.
    """

    def __init__(
            self, c_in, c_out, conv_type, ksize_mid=3, norm='batch', act='leaky_relu'):
        super().__init__()
        c_mid = max(c_in // 2, 32)

        # Layers are created in the order they run.
        layers = [
            Conv2dSame(c_in, c_mid, conv_type, 1, 1),
            get_norm(norm, c_mid),
            get_activation(act),
            Conv2dSame(c_mid, c_out, conv_type, ksize_mid, 1),
            get_norm(norm, c_out),
            get_activation(act),
            Conv2dSame(c_out, c_out, conv_type, 1, 1),
            nn.Sigmoid(),
        ]
        self.blend = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated output for ``x``."""
        return self.blend(x)


class FusionBlock(pl.LightningModule):
    """Blend a decoded image with the (resized) masked input using a predicted alpha."""

    def __init__(self, c_feat, conv_type, c_alpha=1):
        super().__init__()
        c_img = 3
        # 1x1 conv + sigmoid: decoder features -> 3-channel image
        self.map2img = nn.Sequential(
            Conv2dSame(c_feat, c_img, conv_type, 1, 1),
            nn.Sigmoid())
        # alpha is predicted from the two images stacked on the channel axis
        self.blend = BlendBlock(c_img*2, c_alpha, conv_type)

    def forward(self, img_miss, feat_de):
        """Return ``(result, alpha, raw)``.

        ``raw`` is the image mapped from ``feat_de``, ``alpha`` the blending
        weights, and ``result = alpha * raw + (1 - alpha) * img_miss`` with
        ``img_miss`` first resized to the spatial size of ``feat_de``.
        """
        background = resize_like(img_miss, feat_de)
        raw = self.map2img(feat_de)
        stacked = torch.cat([background, raw], dim=1)
        alpha = self.blend(stacked)
        blended = alpha * raw + (1 - alpha) * background
        return blended, alpha, raw

from torchvision.utils import save_image

class DFNet(pl.LightningModule):
    """Encoder/decoder inpainting network with fusion blocks on selected decoder levels.

    The encoder halves the resolution at every stage (stride 2) and doubles
    the channel count up to 512. Each decoder stage upsamples, concatenates
    the matching encoder input and convolves. Decoder levels listed in
    ``blend_layers`` get a ``FusionBlock``; ``forward`` returns the result of
    the last fusion block that ran.
    """

    def __init__(
            self, c_img=3, c_mask=1, c_alpha=3,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8,
            blend_layers=[0, 1, 2, 3, 4, 5], conv_type = 'normal'):
        super().__init__()

        self.n_en = len(en_ksize)
        self.n_de = len(de_ksize)
        assert self.n_en == self.n_de, (
            'The number layer of Encoder and Decoder must be equal.')
        assert self.n_en >= 1, (
            'The number layer of Encoder and Decoder must be greater than 1.')

        assert 0 in blend_layers, 'Layer 0 must be blended.'

        # ---- encoder: first stage has no norm/activation ----
        self.en = [
            EncodeBlock(c_img + c_mask, 64, conv_type, en_ksize[0], 2, None, None)]
        for ksize in en_ksize[1:]:
            width = self.en[-1].c_out
            self.en.append(EncodeBlock(
                width, min(width*2, 512), conv_type, ksize, stride=2,
                normalization=norm, activation=act_en))

        # Plain lists are not tracked by nn.Module, so each block is also
        # attached as an attribute to register its parameters.
        for idx, block in enumerate(self.en):
            setattr(self, 'en_{}'.format(idx), block)

        # ---- decoder: built from the deepest level outwards ----
        self.de = []
        self.fuse = []
        for depth, ksize in enumerate(de_ksize):
            up_channels = self.de[-1].c_out if self.de else self.en[-1].c_out
            # output width equals the width of the skip connection
            skip_channels = self.en[-depth-1].c_in
            level = self.n_de - depth - 1

            self.de.append(DecodeBlock(
                up_channels, skip_channels, conv_type, skip_channels, mode,
                ksize, scale=2, normalization=norm, activation=act_de))
            self.fuse.append(
                FusionBlock(skip_channels, conv_type, c_alpha)
                if level in blend_layers else None)

        # Registered in reverse so the attribute index is the decoder level.
        for idx, block in enumerate(reversed(self.de)):
            setattr(self, 'de_{}'.format(idx), block)
        for idx, block in enumerate(reversed(self.fuse)):
            if block is not None:
                setattr(self, 'fuse_{}'.format(idx), block)

    def forward(self, img_miss, mask):
        """Inpaint ``img_miss`` given ``mask`` (concatenated on the channel axis).

        Returns the blended result of the last decoder level that has a
        fusion block.
        """
        feat = torch.cat([img_miss, mask], dim=1)
        skips = [feat]

        for encode in self.en:
            feat = encode(feat)
            skips.append(feat)

        fused = []
        for depth, (decode, fuse) in enumerate(zip(self.de, self.fuse)):
            feat = decode(feat, skips[-depth-2])
            if fuse is not None:
                fused.append(fuse(img_miss, feat)[0])
        return fused[-1]


In [None]:
#@title init.py
import torch.nn.init as init

def weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    """Initialise the weights of ``net`` in place.

    Args:
        net: network (``nn.Module``) to initialise.
        init_type: initialisation for modules whose class name contains
            ``'Conv'``: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``.
        init_gain: standard deviation for ``normal``, gain for ``xavier`` and
            ``orthogonal``. Unused by ``kaiming``.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and ``net`` contains
            a convolution-like module.
    """

    def init_func(m):
        classname = m.__class__.__name__
        # weight/bias are missing on container modules and are None on
        # e.g. BatchNorm2d(affine=False) or Linear(bias=False)
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)

        if weight is not None and classname.find('Conv') != -1:
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            if weight is not None:
                init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif classname.find('Linear') != -1:
            if weight is not None:
                init.normal_(weight, 0, 0.01)
            if bias is not None:
                init.constant_(bias, 0)

    # Apply the initialization function <init_func> to every submodule
    print('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)

In [None]:
#@title data.py
import os

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset


class DS(Dataset):
    """Image dataset that pairs every image with a random free-form stroke mask.

    All files found under ``root`` (recursively, in sorted order) are treated
    as samples. Each item is ``(masked, mask, sample)`` where ``mask`` is 1
    for kept pixels and 0 for pixels covered by a stroke.
    """

    def __init__(self, root, transform=None):
        """Collect sample paths below ``root``.

        Args:
            root: directory to scan recursively.
            transform: optional callable applied to the loaded PIL image.

        Raises:
            RuntimeError: if no files are found.
        """
        self.samples = []
        # Use a separate loop name so `root` still refers to the argument in
        # the error message below.
        for dirpath, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                self.samples.append(os.path.join(dirpath, fname))
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(masked, mask, sample)`` for the image at ``index``."""
        sample_path = self.samples[index]
        sample = Image.open(sample_path).convert('RGB')

        if self.transform:
            sample = self.transform(sample)

        # NOTE(review): the mask is always generated at the default 256x256
        # size, so this assumes the transform yields a 256x256 tensor.
        mask = DS.random_mask()
        mask = torch.from_numpy(mask)

        # apply mask
        masked = sample * mask

        return masked, mask, sample

    @staticmethod
    def random_mask(height=256, width=256,
                    min_stroke=1, max_stroke=4,
                    min_vertex=1, max_vertex=12,
                    min_brush_width_divisor=16, max_brush_width_divisor=10):
        """Draw random poly-line strokes of zeros onto a mask of ones.

        Args:
            height, width: mask size in pixels.
            min_stroke, max_stroke: inclusive range for the number of strokes.
            min_vertex, max_vertex: inclusive range for segments per stroke.
            min_brush_width_divisor, max_brush_width_divisor: the brush width
                is drawn from ``[height // min_divisor, height // max_divisor]``.

        Returns:
            ``float32`` array of shape ``(1, height, width)``.
        """
        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2*np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke+1)
        average_length = np.sqrt(height*height + width*width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex+1)
            start_x = int(np.random.randint(width))
            start_y = int(np.random.randint(height))

            for _ in range(num_vertex):
                # uniform(max_angle) means uniform(low=max_angle, high=1.0),
                # which never samples angles below 1 rad; give both bounds.
                angle = np.random.uniform(0, max_angle)
                length = np.clip(np.random.normal(average_length, average_length//2), 0, 2*average_length)
                brush_width = int(np.random.randint(min_brush_width, max_brush_width+1))
                # plain Python ints (truncating, like astype(int32)) for cv2
                end_x = int(start_x + length * np.sin(angle))
                end_y = int(start_y + length * np.cos(angle))

                # cv2 points are (x, y) = (column, row)
                cv2.line(mask, (start_x, start_y), (end_x, end_y), 0., brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return mask.reshape((1,)+mask.shape).astype(np.float32)


class DS_green_from_mask(Dataset):
    """Dataset whose mask is derived from pure-green pixels in each image.

    Pixels equal to RGB ``(0, 255, 0)`` are treated as missing (mask value 0);
    all other pixels get mask value 1. Each item is ``(sample, mask)``.
    """

    def __init__(self, root, transform=None):
        """Collect sample paths below ``root``.

        Args:
            root: directory to scan recursively.
            transform: optional callable applied to the loaded RGB array.

        Raises:
            RuntimeError: if no files are found.
        """
        self.samples = []
        # Use a separate loop name so `root` still refers to the argument in
        # the error message below.
        for dirpath, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                self.samples.append(os.path.join(dirpath, fname))
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(sample, mask)`` with ``mask`` shaped ``(1, H, W)``.

        Raises:
            RuntimeError: if the file cannot be decoded as an image.
        """
        sample_path = self.samples[index]
        sample = cv2.imread(sample_path)
        if sample is None:
            # cv2.imread signals failure by returning None
            raise RuntimeError("Could not read image: " + sample_path)
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        # float32 like the masks produced by DS, so the mask can be combined
        # with float image tensors without dtype promotion
        green_mask = 1 - np.all(sample == [0,255,0], axis=-1).astype(np.float32)
        green_mask = torch.from_numpy(green_mask)

        if self.transform:
            sample = self.transform(sample)

        return sample, green_mask.unsqueeze(0)


In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    """Lightning data module providing train/validation/test loaders of ``DS``.

    Args:
        training_path: directory with training images.
        validation_path: directory with validation images.
        test_path: directory with test images.
        batch_size: batch size for all loaders.
        num_workers: worker processes per loader.
    """

    def __init__(self, training_path: str = './', validation_path: str = './', test_path: str = './', batch_size: int = 5, num_workers: int = 2):
        super().__init__()
        self.training_dir = training_path
        self.validation_dir = validation_path
        self.test_dir = test_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # images are resized and centre-cropped to size x size
        self.size = 256

    def setup(self, stage=None):
        """Create the three datasets."""
        train_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        # Evaluation must be deterministic, so no random flip here.
        eval_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.ToTensor()
        ])

        self.DFNetdataset_train = DS(self.training_dir, train_tf)
        self.DFNetdataset_validation = DS(self.validation_dir, eval_tf)
        # DS multiplies the sample with a tensor mask, so the test split also
        # needs a transform that turns the PIL image into a tensor.
        self.DFNetdataset_test = DS(self.test_dir, eval_tf)

    def train_dataloader(self):
        # shuffle so batches are not drawn in sorted-path order every epoch
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
#@title discriminator.py
"""
models.py (21-12-20)
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/context_encoder/models.py
"""
class context_encoder(pl.LightningModule):
    """Convolutional discriminator producing a single-channel output map.

    Four 3x3 conv stages (64, 128, 256, 512 filters; the first three with
    stride 2), each followed by LeakyReLU(0.2) and, except for the first,
    preceded by instance normalisation; then a final 3x3 conv to 1 channel.
    """

    def __init__(self, channels=3):
        super(context_encoder, self).__init__()

        # (out_filters, stride, normalize) for each stage
        stage_specs = ((64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True))

        stack = []
        prev_filters = channels
        for n_filters, stride, use_norm in stage_specs:
            stack.append(nn.Conv2d(prev_filters, n_filters, 3, stride, 1))
            if use_norm:
                stack.append(nn.InstanceNorm2d(n_filters))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            prev_filters = n_filters

        # final projection from the last stage's width to one channel
        stack.append(nn.Conv2d(prev_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return the discriminator output map for ``img``."""
        return self.model(img)

In [None]:
#@title CustomTrainClass.py
from vic.loss import CharbonnierLoss, GANLoss, GradientPenaltyLoss, HFENLoss, TVLoss, GradientLoss, ElasticLoss, RelativeL1, L1CosineSim, ClipL1, MaskedL1Loss, MultiscalePixelLoss, FFTloss, OFLoss, L1_regularization, ColorLoss, AverageLoss, GPLoss, CPLoss, SPL_ComputeWithTrace, SPLoss, Contextual_Loss, StyleLoss
from vic.perceptual_loss import PerceptualLoss
from metrics import *
from torchvision.utils import save_image
from torch.autograd import Variable

class CustomTrainClass(pl.LightningModule):
  """Lightning module training a DFNet generator with a context_encoder discriminator.

  The generator loss is a weighted sum of L1CosineSim, TV and perceptual
  losses (other losses are instantiated but currently disabled in
  ``training_step``). A least-squares discriminator loss is added on top.
  """

  def __init__(self):
    super().__init__()

    self.netG = DFNet(c_img=3, c_mask=1, c_alpha=3,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3, 3, 3, 3, 3, 3, 3, 3],
            blend_layers=[0, 1, 2, 3, 4, 5], conv_type='partial')
    weights_init(self.netG, 'kaiming')

    self.netD = context_encoder()
    weights_init(self.netD, 'kaiming')

    # loss functions
    self.l1 = nn.L1Loss()
    l_hfen_type = L1CosineSim()
    self.HFENLoss = HFENLoss(loss_f=l_hfen_type, kernel='log', kernel_size=15, sigma = 2.5, norm = False)
    self.ElasticLoss = ElasticLoss(a=0.2, reduction='mean')
    self.RelativeL1 = RelativeL1(eps=.01, reduction='mean')
    self.L1CosineSim = L1CosineSim(loss_lambda=5, reduction='mean')
    self.ClipL1 = ClipL1(clip_min=0.0, clip_max=10.0)
    self.FFTloss = FFTloss(loss_f = torch.nn.L1Loss, reduction='mean')
    self.OFLoss = OFLoss()
    self.GPLoss = GPLoss(trace=False, spl_denorm=False)
    self.CPLoss = CPLoss(rgb=True, yuv=True, yuvgrad=True, trace=False, spl_denorm=False, yuv_denorm=False)
    self.StyleLoss = StyleLoss()
    self.TVLoss = TVLoss(tv_type='tv', p = 1)
    self.PerceptualLoss = PerceptualLoss(model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], model_path=None)
    layers_weights = {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    self.Contextual_Loss = Contextual_Loss(layers_weights, crop_quarter=False, max_1d_size=100,
        distance_type = 'cosine', b=1.0, band_width=0.5,
        use_vgg = True, net = 'vgg19', calc_type = 'regular')

    self.MSELoss = torch.nn.MSELoss()

    # metrics
    self.psnr_metric = PSNR()
    self.ssim_metric = SSIM()
    self.ae_metric = AE()
    self.mse_metric = MSE()


  def forward(self, image, masks):
    """Run the generator on ``image`` with ``masks``."""
    return self.netG(image, masks)

  #def adversarial_loss(self, y_hat, y):
  #    return F.binary_cross_entropy(y_hat, y)


  def training_step(self, train_batch, batch_idx):
    """Compute the combined generator + discriminator loss for one batch.

    Args:
        train_batch: ``(masked, mask, original)`` as produced by ``DS``.
        batch_idx: index of the batch (unused).

    Returns:
        Scalar loss ``g_loss + d_loss``.
    """
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = original

    # generate fake
    out = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])
    #save_image(out, "output_after_generator.png")

    total_loss = 0
    """
    HFENLoss_forward = self.HFENLoss(out, train_batch[0])
    total_loss += HFENLoss_forward
    ElasticLoss_forward = self.ElasticLoss(out, train_batch[0])
    total_loss += ElasticLoss_forward
    RelativeL1_forward = self.RelativeL1(out, train_batch[0])
    total_loss += RelativeL1_forward
    """
    L1CosineSim_forward = 5*self.L1CosineSim(out, train_batch[2])
    total_loss += L1CosineSim_forward
    self.log('loss/L1CosineSim', L1CosineSim_forward)

    """
    ClipL1_forward = self.ClipL1(out, train_batch[0])
    total_loss += ClipL1_forward
    FFTloss_forward = self.FFTloss(out, train_batch[0])
    total_loss += FFTloss_forward
    OFLoss_forward = self.OFLoss(out)
    total_loss += OFLoss_forward
    GPLoss_forward = self.GPLoss(out, train_batch[0])
    total_loss += GPLoss_forward

    CPLoss_forward = 0.1*self.CPLoss(out, train_batch[0])
    total_loss += CPLoss_forward


    Contextual_Loss_forward = self.Contextual_Loss(out, train_batch[0])
    total_loss += Contextual_Loss_forward
    self.log('loss/contextual', Contextual_Loss_forward)
    """

    #style_forward = 240*self.StyleLoss(out, train_batch[2])
    #total_loss += style_forward
    #self.log('loss/style', style_forward)

    tv_forward = 0.0000005*self.TVLoss(out)
    total_loss += tv_forward
    self.log('loss/tv', tv_forward)

    perceptual_forward = 2*self.PerceptualLoss(out, train_batch[2])
    total_loss += perceptual_forward
    self.log('loss/perceptual', perceptual_forward)

    self.log('loss/g_loss', total_loss)

    # train discriminator
    # The discriminator scores the real image and the generated one. The
    # generated image is detached so this loss only reaches netD and does not
    # push the generator towards "fake" outputs.
    pred_real = self.netD(train_batch[2])
    pred_fake = self.netD(out.detach())
    # Targets are created next to the predictions, which keeps them on the
    # right device/dtype for CPU, GPU, TPU and AMP (the previous code relied
    # on a global `cuda` flag and torch.cuda.FloatTensor).
    valid = torch.ones_like(pred_real)
    fake = torch.zeros_like(pred_fake)
    dis_real_loss = self.MSELoss(pred_real, valid)
    dis_fake_loss = self.MSELoss(pred_fake, fake)

    d_loss = (dis_real_loss + dis_fake_loss) / 2
    self.log('loss/d_loss', d_loss)

    # NOTE(review): the generator receives no adversarial term yet; adding one
    # properly needs separate optimizers for netG and netD.
    return total_loss+d_loss

  def configure_optimizers(self):
    """Return a single Adam optimizer over generator and discriminator parameters."""
    # netD has to be optimised as well, otherwise d_loss trains nothing
    params = list(self.netG.parameters()) + list(self.netD.parameters())
    optimizer = torch.optim.Adam(params, lr=2e-3)
    return optimizer

  def validation_step(self, train_batch, train_idx):
    """Log PSNR, SSIM, MSE and LPIPS between the original and the generator output."""
    out = self(train_batch[0],train_batch[1])
    self.log('metrics/PSNR', self.psnr_metric(train_batch[2], out))
    self.log('metrics/SSIM', self.ssim_metric(train_batch[2], out))
    self.log('metrics/MSE', self.mse_metric(train_batch[2], out))
    self.log('metrics/LPIPS', self.PerceptualLoss(out, train_batch[2]))


  def test_step(self, train_batch, train_idx):
    """Inpaint one sample and save the composited image as ``<train_idx>.png``."""
    # train_batch[0] = masked
    # train_batch[1] = mask
    # NOTE(review): unsqueeze(0) adds a batch axis, so this presumably expects
    # unbatched (C, H, W) samples — confirm against how trainer.test is fed.
    out = self(train_batch[0].unsqueeze(0)*train_batch[1].unsqueeze(0),train_batch[1].unsqueeze(0))
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])
    save_image(out, str(train_idx)+'.png')

  """
  def on_epoch_end(self):
      sample_imgs = self.netG(image, masks)

      # log sampled images
      grid = torchvision.utils.make_grid(sample_imgs)
      self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
  """

In [None]:
#@title checkpoint.py
#https://github.com/PyTorchLightning/pytorch-lightning/issues/2534
import os
import pytorch_lightning as pl

class CheckpointEveryNSteps(pl.Callback):
    """
    Callback that writes a checkpoint whenever the trainer's global step is a
    multiple of ``save_step_frequency``, independent of validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="Checkpoint",
        use_modelcheckpoint_filename=False,
        save_path = '/content/'
    ):
        """
        Args:
            save_step_frequency: number of steps between two checkpoints
            prefix: file name prefix; ignored when
                use_modelcheckpoint_filename=True
            use_modelcheckpoint_filename: take the file name from the
                trainer's ModelCheckpoint callback instead of building
                "<prefix>_<epoch>_<step>.ckpt"
            save_path: directory the checkpoint file is written to
        """
        self.save_path = save_path
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename
        self.save_step_frequency = save_step_frequency

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Save a checkpoint if the current global step is due """
        epoch, step = trainer.current_epoch, trainer.global_step
        if step % self.save_step_frequency != 0:
            return

        if self.use_modelcheckpoint_filename:
            name = trainer.checkpoint_callback.filename
        else:
            name = f"{self.prefix}_{epoch}_{step}.ckpt"
        # always written below save_path, not the ModelCheckpoint dirpath
        trainer.save_checkpoint(os.path.join(self.save_path, name))

#Trainer(callbacks=[CheckpointEveryNSteps()])

In [None]:
# Training
# Build the data module (training and validation read the same folder here),
# create the model and fit it. Exactly one of the `trainer = ...` lines below
# should be active: plain GPU, GPU with AMP, or TPU.
dm = DFNetDataModule(training_path = '/content/data/images/', validation_path = '/content/data/images/', batch_size=5)
model = CustomTrainClass()
# CustomTrainClass already initialises netG and netD in its constructor.
#weights_init(model, 'kaiming')
# GPU
#trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision)
trainer = pl.Trainer(gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(tpu_cores=8, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# A checkpoint is written to /content/ every 1000 steps by the callback above.
trainer.fit(model, dm)

# Testing 

In [None]:
# testing the model
# Test images are only converted to tensors; DS_green_from_mask derives the
# mask from pure-green (0, 255, 0) pixels of each image.
img_tf = transforms.Compose([
    transforms.ToTensor()
])

# NOTE(review): `dm` is a plain Dataset here, not a DataLoader or DataModule —
# confirm that trainer.test accepts it in the installed Lightning version.
dm = DS_green_from_mask('/content/t/test', img_tf)
model = CustomTrainClass()

# GPU
#trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision)
trainer = pl.Trainer(gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(tpu_cores=8, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# The checkpoint name follows CheckpointEveryNSteps' "<prefix>_<epoch>_<step>.ckpt" pattern.
trainer.test(model, dm, ckpt_path='/content/Checkpoint_1_1000.ckpt')