# Colab-traiNNer

victorca25's BasicSR fork: [victorca25/traiNNer](https://github.com/victorca25/traiNNer)

My fork: [styler00dollar/Colab-traiNNer](https://github.com/styler00dollar/Colab-traiNNer)

In [None]:
#@title Check GPU

gpu = !nvidia-smi --query-gpu=gpu_name --format=csv
print("GPU: " + gpu[1])

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Install

In [None]:
#@title Install dependencies
%cd /content/

# create empty folders
!mkdir /content/hr
!mkdir /content/lr
!mkdir /content/val_hr
!mkdir /content/val_lr
 
!mkdir /content/masks
!mkdir /content/validation
!mkdir /content/data
!mkdir /content/logs/

!git clone https://github.com/styler00dollar/Colab-traiNNer
!pip install git+https://github.com/styler00dollar/pytorch-lightning.git@fc86f4ca817d5ba1702a210a898ac2729c870112
!pip install wget tfrecord x-transformers adamp efficientnet_pytorch tensorboardX vit-pytorch swin-transformer-pytorch madgrad timm pillow-avif-plugin kornia omegaconf

In [None]:
#@title (optional) download precompiled mmcv (for GLEAN)
%cd /content/
!pip uninstall mmcv -y
!pip uninstall mmcv-full -y
!gdown --id 1--PoTPGKwAqGJsmaLYSEiqJy1yWMTx0G
!pip install mmcv_full-1.3.5-cp37-cp37m-linux_x86_64.whl

In [None]:
#@title (optional) compiling and installing mmcv (for GLEAN)
!pip install torch torchvision torchaudio -U
!pip uninstall mmcv -y
!pip uninstall mmcv-full -y
!pip install mmcv-full

In [None]:
#@title (optional) ninja (for GFPGAN / GPEN / co-mod-gan)
%cd /content
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 

In [None]:
#@title (optional) correlation package (for ABME)
%cd /content/
!sudo rm -rf ABME
!git clone https://github.com/JunHeum/ABME
%cd /content/ABME/correlation_package
#!python setup.py install
!python setup.py build install

In [None]:
#@title (optional) install cupy (for EDSC)
!curl https://colab.chainer.org/install | sh -

In [None]:
#@title (optional) upgrade pytorch
!pip3 install --pre torch torchvision torchaudio torchtext -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U --force-reinstall

In [None]:
#@title TPU
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8.1-cp37-cp37m-linux_x86_64.whl

Upload your dataset archive (Google Drive works well for that), then copy it into the Colab runtime and extract it there.

In [None]:
%cd /content/
!cp "/content/drive/MyDrive/dataset.tar" "/content/data.tar"
!7z x /content/data.tar
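
After extracting, it can help to check that the folders from the install cell actually contain images. This is only a rough sketch: the helper name `count_images` and the suffix list are my own assumptions, and the folder list simply mirrors the `mkdir` calls above.

```python
from pathlib import Path

# suffixes treated as images here (an assumption, extend as needed)
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}

def count_images(folder):
    """Count image files below `folder` recursively; returns 0 if the folder is missing."""
    root = Path(folder)
    if not root.is_dir():
        return 0
    return sum(
        1 for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )

for folder in ["/content/hr", "/content/lr", "/content/val_hr", "/content/val_lr", "/content/masks"]:
    print(f"{folder}: {count_images(folder)} images")
```

A count of 0 for a folder the config points at usually means the archive was extracted somewhere else.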

# Train

In [None]:
#@title config.yaml
%%writefile /content/Colab-traiNNer/code/config.yaml 
name: template
scale: 1
gpus: 0 # amount of gpus, 0 = cpu
distributed_backend: ddp # dp, ddp (for multi-gpu training)
tpu_cores: 8 # 8 if you use a Google Colab TPU
use_tpu: False
use_amp: False
use_swa: False
progress_bar_refresh_rate: 20
default_root_dir: '/content'

# Dataset options:
datasets:
  train:
    # DS_inpaint: hr is from dataroot_HR, loads masks
    # DS_inpaint_tiled: hr is from dataroot_HR, but images are grids, loads masks
    # DS_inpaint_tiled_batch: hr is from dataroot_HR, but images are grids and processed as a batch (batch_size_DL), loads masks
    # DS_lrhr: loads lr from dataroot_LR and hr from dataroot_HR
    # DS_lrhr_batch_oft: loads grayscale-hr (3x3 400px) from dataroot_LR and generates lr by downscaling otf randomly
    # DS_video: video dataloader which has 3 frames as input (look into data/data_video.py for more details)
    # DS_inpaint_TF: takes one tfrecord file as dataset input, but the validation is still just green masked images like in DS_inpaint

    mode: DS_inpaint # DS_video | DS_inpaint_TF | DS_inpaint | DS_inpaint_tiled | DS_inpaint_tiled_batch | DS_lrhr | DS_lrhr_batch_oft
    amount_files: 7 # tfrecord files do not store the number of images and are read as an infinite stream, so specify how many images the file contains

    grayscale: False # If true, reads as 1-Channel image. If false, reads as 3-channel image. Currently only implemented for DS_lrhr_batch_oft
    tfrecord_path: "/content/tfrecord/tfrecord-r09.tfrecords"
    dataroot_HR: '/content/hr' # Original, with a single directory. Inpainting will use this directory as source image.
    dataroot_LR: '/content/hr' # Original, with a single directory
    loading_backend: 'PIL' # PIL, OpenCV (currently only in DS_inpaint_tiled_batch)

    n_workers: 1 # 0 to disable CPU multithreading, or an integer representing CPU threads to use for dataloading
    batch_size: 1
    
    # the inpainting dataloaders "DS_inpaint_tiled | DS_inpaint_tiled_batch" randomly crop images out of a grid x-amount of times and return a 
    # batch created from one image (dataloader assume grid images and you must have batch_size: 1, the batch size during training will be determined 
    # by the amount of random crops of one image)
    batch_size_DL: 1 
    HR_size: 512 # The resolution the network will get. Random crop gets applied if that resolution does not match.
    image_channels: 3 # number of channels to load images in

    masks: '/content/masks/'
    mask_invert_ratio: 0.3 # 0.3 = 30% of masks will be inverted
    max_epochs: 2000
    save_step_frequency: 10 # also validation frequency

    # batch
    # If a tiled dataloader is used, specify image characteristics. Random crop will not be applied. Maybe in the future.
    image_size: 256 # Size of one tile
    amount_tiles: 8 # Amount of tiles inside the merged grid image

    # if edge data is required
    canny_min: 100
    canny_max: 150

  val:
    dataroot_HR: '/content/val_hr/'
    dataroot_LR: '/content/val_hr/' # Inpainting will use this directory as input

path:
    pretrain_model_G: 
    pretrain_model_D: 
    checkpoint_path:
    checkpoint_save_path: '/content/'
    validation_output_path: '/content/val/'
    log_path: '/content/logs'

# Generator options:
network_G:
    # CEM (for esrgan, not 1x)
    CEM: False # uses hardcoded torch.cuda.FloatTensor
    sigmoid_range_limit: False

    finetune: False # Important for further rfr/dsnet training. Apply that after training for a while. https://github.com/jingyuanli001/RFR-Inpainting/issues/33

    # ESRGAN:
    #netG: RRDB_net # RRDB_net (original ESRGAN arch) | MRRDB_net (modified/"new" arch)
    #norm_type: null
    #mode: CNA
    #nf: 64 # of discrim filters in the first conv layer (default: 64, good: 32)
    #nb: 23 # (default: 23, good: 8)
    #in_nc: 3 # of input image channels: 3 for RGB and 1 for grayscale
    #out_nc: 3 # of output image channels: 3 for RGB and 1 for grayscale
    #gc: 32
    #group: 1
    #convtype: Conv2D # Conv2D | PartialConv2D
    #net_act: leakyrelu # swish | leakyrelu
    #gaussian: false # true | false # esrgan plus, does not work on TPU because of cuda()
    #plus: false # true | false
    #finalact: None #tanh # Test. Activation function to make outputs fit in [-1, 1] range. Default = None. Coordinate with znorm.
    #upsample_mode: 'upconv'
    #nr: 3

    # ASRGAN:
    #which_model_G: asr_resnet # asr_resnet | asr_cnn
    #nf: 64

    # PPON:
    #netG: ppon # | ppon
    ##norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 24
    #in_nc: 3
    #out_nc: 3
    ##gc: 32
    #group: 1
    ##convtype: Conv2D #Conv2D | PartialConv2D

    # SRGAN:
    #netG: sr_resnet # RRDB_net | sr_resnet
    #norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 16
    #in_nc: 3
    #out_nc: 3

    # SR:
    #netG: RRDB_net # RRDB_net | sr_resnet
    #norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 23
    #in_nc: 3
    #out_nc: 3
    #gc: 32
    #group: 1

    # PAN:
    # netG: pan_net
    # in_nc: 3
    # out_nc: 3
    # nf: 40
    # unf: 24
    # nb: 16
    # self_attention: true
    # double_scpa: false

    # edge-informed-sisr
    #which_model_G: sisr
    #use_spectral_norm: True

    # USRNet
    #netG: USRNet
    #in_nc=4
    #out_nc=3
    #nc=[64, 128, 256, 512]
    #nb=2
    #act_mode='R'
    #downsample_mode='strideconv'
    #upsample_mode='convtranspose'

    # GLEAN (2021)
    # Warning: Does require "pip install mmcv-full"
    #netG: GLEAN
    #in_size: 512
    #out_size: 512
    #img_channels: 4
    #img_channels_out: 3
    #rrdb_channels: 16 # 64
    #num_rrdbs: 8 # 23
    #style_channels: 512 # 512
    #num_mlps: 4 # 8
    #channel_multiplier: 2
    #blur_kernel: [1, 3, 3, 1]
    #lr_mlp: 0.01
    #default_style_mode: 'mix'
    #eval_style_mode: 'single'
    #mix_prob: 0.9
    #pretrained: False # only works with official settings
    #bgr2rgb: False

    # srflow (upscaling factors: 4, 8, 16)
    # Warning: Can be very unstable with batch_size 1, use higher batch_size
    #netG: srflow
    #in_nc: 3
    #out_nc: 3
    #nf: 64
    #nb: 23
    #train_RRDB: false
    #train_RRDB_delay: 0.5
    #flow:
    #  K: 16
    #  L: 3
    #  noInitialInj: true
    #  coupling: CondAffineSeparatedAndCond
    #  additionalFlowNoAffine: 2
    #  split:
    #    enable: true
    #  fea_up0: true
    #  stackRRDB:
    #    blocks: [ 1, 8, 15, 22 ]
    #    concat: true
    #nll_weight: 1
    #freeze_iter: 100000

    # DFDNet
    # Warning: Expects "DictionaryCenter512" in the current folder, you can get the data here: https://drive.google.com/drive/folders/1bayYIUMCSGmoFPyd4Uu2Uwn347RW-vl5
    # Also wants a folder called "landmarks", you can generate that data yourself. Example: https://github.com/styler00dollar/Colab-DFDNet/blob/local/Colab-DFDNet-lightning-train.ipynb
    # Hardcoded resolution: 512px
    #netG: DFDNet
    #dictionary_path: "/content/DictionaryCenter512"
    #landmarkpath: "/content/landmarks"
    #val_landmarkpath: "/content/landmarks"

    # GFPGAN (2021) [EXPERIMENTAL]
    # does require ninja
    # because it compiles files, the startup time is quite long
    #netG: GFPGAN
    #input_channels: 4
    #output_channels: 3
    #out_size: 512
    #num_style_feat: 512
    #channel_multiplier: 1
    #resample_kernel: [1, 3, 3, 1]
    #decoder_load_path: # None
    #fix_decoder: True
    #num_mlp: 8
    #lr_mlp: 0.01
    #input_is_latent: False
    #different_w: False
    #narrow: 1
    #sft_half: False

    # GPEN
    # does require ninja
    # because it compiles files, the startup time is quite long
    # output_channels is hardcoded to 3
    #netG: GPEN
    #input_channels: 4
    #size: 512
    #style_dim: 512
    #n_mlp: 8
    #channel_multiplier: 2
    #blur_kernel: [1, 3, 3, 1]
    #lr_mlp: 0.01
    #pooling: True # Experimental, to have any input size

    # comodgan (2021)
    # needs ninja
    # because it compiles files, the startup time is quite long
    #netG: comodgan
    #dlatent_size: 512
    #num_channels: 3 # amount of channels without mask
    #resolution: 512
    #fmap_base: 16384 # 16 << 10
    #fmap_decay: 1.0
    #fmap_min: 1
    #fmap_max: 512
    #randomize_noise: True
    #architecture: 'skip'
    #nonlinearity: 'lrelu'
    #resample_kernel: [1,3,3,1]
    #fused_modconv: True
    #pix2pix: False
    #dropout_rate: 0.5
    #cond_mod: True
    #style_mod: True
    #noise_injection: True

    # ----Inpainting Generators----
    # DFNet (batch_size: 2+, needs 2^x image input and validation) (2019)
    netG: DFNet
    c_img: 3
    c_mask: 1
    c_alpha: 3
    mode: nearest
    norm: batch
    act_en: relu
    act_de: leaky_relu
    en_ksize: [7, 5, 5, 3, 3, 3, 3, 3]
    de_ksize: [3, 3, 3, 3, 3, 3, 3, 3]
    blend_layers: [0, 1, 2, 3, 4, 5]
    conv_type: normal # partial | normal | deform
    

    # EdgeConnect (2019)
    #netG: EdgeConnect
    #use_spectral_norm: True
    #residual_blocks_edge: 8
    #residual_blocks_inpaint: 8
    #conv_type_edge: 'normal' # normal | partial | deform (has no spectral_norm)
    #conv_type_inpaint: 'normal' # normal | partial | deform

    # CSA (2019)
    #netG: CSA
    #c_img: 3
    #norm: 'instance'
    #act_en: 'leaky_relu'
    #act_de: 'relu'

    # RN (2020)
    #netG: RN
    #input_channels: 3
    #residual_blocks: 8
    #threshold: 0.8

    # deepfillv1 (2018)
    #netG:  deepfillv1

    # deepfillv2 (2019)
    #netG: deepfillv2
    #in_channels:  4
    #out_channels:  3
    #latent_channels:  64
    #pad_type:  'zero'
    #activation:  'lrelu'
    #norm: 'in'
    #conv_type: partial # partial | normal

    # Adaptive (2020)
    #netG: Adaptive
    #in_channels: 3
    #residual_blocks: 1
    #init_weights: True

    # Global (2020)
    #netG: Global
    #input_dim: 5
    #ngf: 32
    #use_cuda: True
    #device_ids: [0]

    # Pluralistic (2019)
    #netG: Pluralistic
    #ngf_E: 32
    #z_nc_E: 128
    #img_f_E: 128
    #layers_E: 5
    #norm_E: 'none'
    #activation_E: 'LeakyReLU'
    #ngf_G: 32
    #z_nc_G: 128
    #img_f_G: 128
    #L_G: 0
    #output_scale_G: 1
    #norm_G: 'instance'
    #activation_G: 'LeakyReLU'

    # crfill (2020)
    #netG: crfill
    #cnum: 48

    # DeepDFNet (experimental)
    #netG: DeepDFNet
    #in_channels:  4
    #out_channels:  3
    #latent_channels:  64
    #pad_type:  'zero'
    #activation:  'lrelu'
    #norm: 'in'

    # partial (2018)
    #netG: partial

    # DMFN (2020)
    #netG: DMFN
    #in_nc: 4
    #out_nc: 3
    #nf: 64
    #n_res: 8
    #norm: 'in'
    #activation: 'relu'

    # pennet (2019)
    #netG: pennet

    # LBAM (2019)
    #netG: LBAM
    #inputChannels: 4
    #outputChannels: 3

    # RFR (use_swa: false, no TPU) (2020)
    #netG: RFR
    #conv_type: partial # partial | deform

    # FRRN (2019)
    #netG: FRRN

    # PRVS (2019)
    #netG: PRVS

    # CRA (HR_size: 512) (2020)
    #netG: CRA
    #activation: 'elu'
    #norm: 'none'

    # atrous (2020)
    #netG: atrous

    # MEDFE (batch_size: 1) (2020)
    #netG: MEDFE

    # AdaFill (2021)
    #netG: AdaFill

    # lightweight_gan (2021)
    #netG: lightweight_gan
    #image_size: 512
    #latent_dim: 256
    #fmap_max: 512
    #fmap_inverse_coef: 12
    #transparent: False
    #greyscale: False
    #freq_chan_attn: False

    # CTSDG (2021)
    #netG: CTSDG

    # lama (2022) (no AMP)
    #netG: lama

    # MST (2021)
    #netG: MST

    # ----Interpolation Generators----
    # cain (2020)
    #netG: CAIN
    #depth: 3
    #conv: MBConv # doconv | conv2d | gated | TBC | dynamic | MBConv
    #RG: 2 # ResidualGroup amount
    # for dynamic
    #nof_kernels: 4
    #reduce: 4

    # rife 3.8
    #netG: rife

    # RRIN (2020)
    #netG: RRIN

    # ABME (2021)
    #netG: ABME

    # EDSC (2021) (2^x image size)
    # pip install cupy
    #netG: EDSC

# Discriminator options:
network_D:
    discriminator_criterion: MSE # MSE

    d_loss_fool_weight: 1 # inside the generator loop, trying to fool the discriminator
    d_loss_weight: 1 # inside own discriminator update
    
    #netD: # in case there is no discriminator, leave it empty

    # VGG
    #netD: VGG
    #size: 256
    #in_nc: 3 #3
    #base_nf: 64
    #norm_type: 'batch'
    #act_type: 'leakyrelu'
    #mode: 'CNA'
    #convtype: 'Conv2D'
    #arch: 'ESRGAN'

    # VGG fea
    #netD: VGG_fea
    #size: 256
    #in_nc: 3
    #base_nf: 64
    #norm_type: 'batch'
    #act_type: 'leakyrelu'
    #mode: 'CNA'
    #convtype: 'Conv2D'
    #arch: 'ESRGAN'
    #spectral_norm: False
    #self_attention: False
    #max_pool: False
    #poolsize: 4


    #netD: VGG_128_SN

    # VGGFeatureExtractor
    #netD: VGGFeatureExtractor
    #feature_layer: 34
    #use_bn: False
    #use_input_norm: True
    #device: 'cpu'
    #z_norm: False

    # PatchGAN
    #netD: NLayerDiscriminator
    #input_nc: 3
    #ndf: 64
    #n_layers: 3
    #norm_layer: nn.BatchNorm2d
    #use_sigmoid: False
    #getIntermFeat: False
    #patch: True
    #use_spectral_norm: False

    # Multiscale
    #netD: MultiscaleDiscriminator
    #input_nc: 3
    #ndf: 64
    #n_layers: 3
    #norm_layer: nn.BatchNorm2d
    #se_sigmoid: False
    #num_D: 3
    #getIntermFeat: False

    #netD: ResNet101FeatureExtractor
    #use_input_norm: True
    #device: 'cpu'
    #z_norm: False

    # MINC
    #netD: MINCNet

    # Pixel
    #netD: PixelDiscriminator
    #input_nc: 3
    #ndf: 64
    #norm_layer: nn.BatchNorm2d

    # EfficientNet (3-channel input)
    #netD: EfficientNet
    #EfficientNet_pretrain: 'efficientnet-b0'
    #num_classes: 1 # should be 1

    # ResNeSt (not working)
    #netD: ResNeSt
    #ResNeSt_pretrain: 'resnest50' # ["resnest50", "resnest101", "resnest200", "resnest269"]
    #pretrained: False # can't be true currently
    #num_classes: 1

    # Transformer (not working)
    #netD: TranformerDiscriminator
    #img_size: 256
    #patch_size: 1
    #in_chans: 3
    #num_classes: 1
    #embed_dim: 64
    #depth: 7
    #num_heads: 4
    #mlp_ratio: 4.
    #qkv_bias: False
    #qk_scale: None
    #drop_rate: 0.
    #attn_drop_rate: 0.
    #drop_path_rate: 0.
    #hybrid_backbone: None
    #norm_layer: 

    # context_encoder (num_classes can't be set, broadcasting warning will be shown, training works, but I am not sure if it will work correctly)
    #netD: context_encoder

    # Transformer (doesn't do init)
    #netD: ViT
    #image_size: 256
    #patch_size: 32
    #num_classes: 1
    #dim: 1024
    #depth: 6
    #heads: 16
    #mlp_dim: 2048
    #dropout: 0.1
    #emb_dropout: 0.1

    # Transformer (doesn't do init)
    #netD: DeepViT
    #image_size: 256
    #patch_size: 32
    #num_classes: 1
    #dim: 1024
    #depth: 6
    #heads: 16
    #mlp_dim: 2048
    #dropout: 0.1
    #emb_dropout: 0.1

    # RepVGG
    #netD: RepVGG
    #RepVGG_arch: RepVGG-A0 # RepVGG-A0, RepVGG-A1, RepVGG-A2, RepVGG-B0, RepVGG-B1, RepVGG-B1g2, RepVGG-B1g4, , RepVGG-B2, RepVGG-B2g2, RepVGG-B2g4, RepVGG-B3, RepVGG-B3g2, RepVGG-B3g4
    #num_classes: 1

    # squeezenet
    #netD: squeezenet
    #version: "1_1" # 1_0, 1_1
    #num_classes: 1

    # SwinTransformer (doesn't do init)
    #netD: SwinTransformer
    #hidden_dim: 96
    #layers: [2, 2, 6, 2]
    #heads: [3, 6, 12, 24]
    #channels: 3
    #num_classes: 1
    #head_dim: 32
    #window_size: 8
    #downscaling_factors: [4, 2, 2, 2]
    #relative_pos_embedding: True

    # mobilenetV3 (doesn't do init)
    #netD: mobilenetV3
    #mode: small # small, large
    #n_class: 1
    #input_size: 256

    # resnet
    #netD: resnet
    #resnet_arch: resnet50 # resnet50, resnet101, resnet152
    #num_classes: 1
    #pretrain: True
  
    # NFNet
    #netD: NFNet
    #num_classes: 1
    #variant: 'F0'         # F0 - F7
    #stochdepth_rate: 0.25 # 0-1, the probability that a layer is dropped during one step
    #alpha: 0.2            # Scaling factor at the end of each block
    #se_ratio: 0.5         # Squeeze-Excite expansion ratio
    #activation: 'gelu'    # or 'relu'

    # lvvit (2021)
    # Warning: Needs 'pip install timm==0.4.5'
    #netD: lvvit
    #img_size: 224
    #patch_size: 16
    #in_chans: 3
    #num_classes: 1
    #embed_dim: 768
    #depth: 12
    #num_heads: 12
    #mlp_ratio: 4.
    #qkv_bias: False
    #qk_scale: # None
    #drop_rate: 0.
    #attn_drop_rate: 0.
    #drop_path_rate: 0.
    #drop_path_decay: 'linear'
    #hybrid_backbone: # None
    ##norm_layer: nn.LayerNorm # Default: nn.LayerNorm / can't be configured
    #p_emb: '4_2'
    #head_dim: # None
    #skip_lam: 1.0
    #order: # None
    #mix_token: False
    #return_dense: False

    # timm
    # pip install timm
    # you can look up models here: https://rwightman.github.io/pytorch-image-models/
    netD: timm
    timm_model: "tf_efficientnetv2_b0"

    #netD: resnet3d
    #model_depth: 50 # [10, 18, 34, 50, 101, 152, 200]

    # from lama (2021)
    #netD: FFCNLayerDiscriminator
    #FFCN_feature_weight: 1

    #netD: effV2
    #conv: fft # fft | conv2d
    #size: s # s | m | l | xl

    # x-transformers
    # pip install x-transformers
    #netD: x_transformers
    #image_size: 512
    #patch_size: 32
    #dim: 512
    #depth: 6
    #heads: 8

    #netD: mobilevit
    #size: xxs # xxs | xs | s

    # because of too many parameters, a separate config file named "hrt_config.yaml" is available
    #netD: hrt

train: 
    scheduler: AdamP # Adam, AdamP, SGDP, MADGRAD, cosangulargrad [maybe broken], tanangulargrad [maybe broken]
    lr_g: 0.0001
    lr_d: 0.0001
    
    # AdamP, SGDP, MADGRAD, cosangulargrad, tanangulargrad
    weight_decay: 0.01

    # SGDP, MADGRAD
    momentum: 0.9

    # AdamP, cosangulargrad, tanangulargrad
    betas0: 0.9
    betas1: 0.999
    
    # SGDP
    nesterov: True

    # MADGRAD, cosangulargrad, tanangulargrad
    eps: 1e-6

    ############################

    # Losses:
    L1Loss_weight: 0

    # HFENLoss
    HFEN_weight: 0
    loss_f: L1CosineSim # L1Loss | L1CosineSim
    kernel: 'log'
    kernel_size: 15
    sigma: 2.5
    norm: False

    # Elastic
    Elatic_weight: 0
    a: 0.2
    reduction_elastic: 'mean'

    # Relative L1
    Relative_l1_weight: 0
    eps: .01
    reduction_realtive: 'mean'

    # L1CosineSim (3-channel input)
    L1CosineSim_weight: 0
    loss_lambda: 5
    reduction_L1CosineSim: 'mean'

    # ClipL1
    ClipL1_weight: 0
    clip_min: 0.0
    clip_max: 10.0

    # FFTLoss
    FFTLoss_weight: 0
    loss_f_fft: L1Loss
    reduction_fft: 'mean'

    OFLoss_weight: 0

    # GPLoss
    GPLoss_weight: 0
    trace: False
    spl_denorm: False

    # CPLoss
    CPLoss_weight: 0
    rgb: True
    yuv: True
    yuvgrad: True
    trace: False
    spl_denorm: False
    yuv_denorm: False

    StyleLoss_weight: 0

    # TVLoss
    TVLoss_weight: 0
    tv_type: 'tv'
    p: 1

    # Contextual_Loss (3-channel input)
    Contexual_weight: 0
    crop_quarter: False
    max_1d_size: 100
    distance_type: 'cosine' # ["l1", "l2", "cosine"]
    b: 1.0
    band_width: 0.5
    # for vgg
    use_vgg: False
    net_contextual: 'vgg19'
    layers_weights: {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    # for timm
    use_timm: True
    timm_model: "tf_efficientnetv2_b0"
    # for both
    calc_type: 'regular' # ["bilateral" | "symetric" | None]

    # Style (3-channel input)
    StyleLoss_weight: 0

    # PerceptualLoss
    perceptual_weight: 0
    net: PNetLin # PNetLin, DSSIM (?)
    pnet_type: 'vgg' # alex, squeeze, vgg
    pnet_rand: False
    pnet_tune: False
    use_dropout: True
    spatial: False
    version: '0.1' # only version
    lpips: True

    # high receptive field (HRF) perceptual loss
    # you can download it manually with this command, but the code will download it automatically if you don't have it
    # wget -P /content/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
    hrf_perceptual_weight: 1

    ColorLoss_weight: 0 # converts rgb to yuv and calculates l1, expected input is rgb
    FrobeniusNormLoss_weight: 0
    GradientLoss_weight: 0
    MultiscalePixelLoss_weight: 0
    SPLoss_weight: 0
    FFLoss_weight: 1

    # only if the network outputs 2 images, will use l1
    stage1_weight: 0 

    Lap_weight: 0

    # pytorch loss functions
    MSE_weight: 0
    BCE_weight: 0
    Huber_weight: 1
    SmoothL1_weight: 0

    # loss for CTSDG
    CTSDG_edge_weight: 0 #0.01
    CTSDG_projected_weight: 0 #0.1

    # Differentiable Augmentation for Data-Efficient GAN Training
    diffaug: False
    policy: 'color,translation,cutout'

    # Metrics
    metrics: [] # PSNR | SSIM | AE | MSE
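
The comments in the config mention a few constraints (DFNet wants a 2^x input size, the tiled dataloaders want `batch_size: 1`). A small check like the one below could catch those before training starts. It is just a sketch: `check_config` and `is_power_of_two` are hypothetical helpers that are not part of the repo, and the dict is assumed to be the config already parsed (for example with OmegaConf or PyYAML).

```python
TILED_MODES = {"DS_inpaint_tiled", "DS_inpaint_tiled_batch"}

def is_power_of_two(n):
    """True for 1, 2, 4, 8, ... (positive ints with a single bit set)."""
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0

def check_config(cfg):
    """Return a list of human-readable problems found in a parsed config dict."""
    problems = []
    train = cfg["datasets"]["train"]
    # DFNet is commented as needing 2^x image input
    if cfg["network_G"].get("netG") == "DFNet" and not is_power_of_two(train["HR_size"]):
        problems.append("DFNet needs a power-of-two HR_size")
    # tiled dataloaders build the batch from one grid image
    if train["mode"] in TILED_MODES and train["batch_size"] != 1:
        problems.append("tiled dataloaders need batch_size: 1 (use batch_size_DL)")
    return problems

# minimal stand-in for the parsed config.yaml (only the keys used above)
example = {
    "datasets": {"train": {"mode": "DS_inpaint", "HR_size": 512, "batch_size": 1}},
    "network_G": {"netG": "DFNet"},
}
print(check_config(example))
```

An empty list means none of the checked constraints were violated; it does not validate the rest of the config.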

In [None]:
%cd '/content/Colab-traiNNer/code'
!python train.py

# Misc (Data)

In [None]:
#@title Resize folder
from tqdm import tqdm
import os
import glob
import cv2
import threading
import shutil
import hashlib
import PIL
from PIL import Image
import numpy as np

rootdir = "/content/val_lr/"
destination = "/content/val_lr/"
broken_folder = "/content/"

resize_method = 'PIL' #@param ["OpenCV", "PIL"] {allow-input: false}

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files_jpeg = glob.glob(rootdir + '/**/*.jpeg', recursive=True)
files_webp = glob.glob(rootdir + '/**/*.webp', recursive=True)
files.extend(files_jpg)
files.extend(files_jpeg)
files.extend(files_webp)
err_files=[]

image_size = 64

for file in tqdm(files):
    image = cv2.imread(file)
    if image is not None:
        #####################################
        # resize with opencv
        if resize_method == "OpenCV":
          resized = cv2.resize(image, (image_size,image_size), interpolation=cv2.INTER_AREA)

        # resize with PIL
        elif resize_method == "PIL":
          image = Image.fromarray(image)
          image = image.resize((image_size,image_size))
          resized = np.asarray(image)
        #####################################

        hash_md5 = hashlib.md5()
        with open(file, "rb") as f:
          for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)

        cv2.imwrite(file, resized)

In [None]:
#@title creating tiled images (image grids) (with skip)
import cv2
import numpy
import glob
import shutil
import tqdm
import os
import PIL
from PIL import Image
import numpy as np

resize_method = 'PIL' #@param ["OpenCV", "PIL"] {allow-input: false}
grayscale = False #@param {type:"boolean"}

rootdir = '/content/' #@param {type:"string"}
destination_dir = "/content/" #@param {type:"string"}
broken_dir = '/content/opencv_fail/' #@param {type:"string"}
 
files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files_jpeg = glob.glob(rootdir + '/**/*.jpeg', recursive=True)
files_webp = glob.glob(rootdir + '/**/*.webp', recursive=True)
files.extend(files_jpg)
files.extend(files_jpeg)
files.extend(files_webp)
err_files=[]

amount_tiles = 8 #@param
image_size = 256 #@param

filepos = 0
img_cnt = 0
filename_cnt = 0

if grayscale == True:
  tmp_img = numpy.zeros((amount_tiles*image_size,amount_tiles*image_size))
elif grayscale == False:
  tmp_img = numpy.zeros((amount_tiles*image_size,amount_tiles*image_size, 3))

with tqdm.tqdm(files) as pbar:
  while filepos < len(files): # stop after the last file instead of running into an IndexError
      if grayscale == True:
        image = cv2.imread(files[filepos], cv2.IMREAD_GRAYSCALE)
      elif grayscale == False:
        image = cv2.imread(files[filepos])

      filepos += 1

      if image is not None:
        
        i = img_cnt % amount_tiles
        j = img_cnt // amount_tiles

        #####################################
        # resize with opencv
        if resize_method == "OpenCV":
          image = cv2.resize(image, (image_size,image_size), interpolation=cv2.INTER_AREA)

        # resize with PIL
        elif resize_method == "PIL":
          if grayscale == True:
            image = Image.fromarray(image)
            image = image.resize((image_size,image_size))
            image = np.asarray(image)
          if grayscale == False:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            image = image.resize((image_size,image_size), resample=PIL.Image.LANCZOS)
            image = np.asarray(image)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        #####################################

        tmp_img[i*image_size:(i+1)*image_size, j*image_size:(j+1)*image_size] = image
        img_cnt += 1
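      else:
        # filepos was already incremented above, so the file that failed to load is files[filepos - 1]
        broken_file = files[filepos - 1]
        print(broken_file)
        os.makedirs(broken_dir, exist_ok=True)
        shutil.move(broken_file, os.path.join(broken_dir, os.path.basename(broken_file)))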

      if img_cnt == (amount_tiles*amount_tiles):
        #cv2.imwrite(destination_dir+str(filename_cnt)+".jpg", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
        cv2.imwrite(destination_dir+str(filename_cnt)+".webp", tmp_img)
        filename_cnt += 1
        img_cnt = 0
      pbar.update(1)
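
The grid cell above places image number `n` at row `n % amount_tiles` and column `n // amount_tiles`, so the grid is filled column by column. To get a single tile back out of a grid, the same arithmetic can be reused. A minimal sketch, where `tile_slices` is a hypothetical helper name and the defaults match the values used above:

```python
def tile_slices(index, image_size=256, amount_tiles=8):
    """Return (row_slice, col_slice) for tile `index` inside a grid image.

    Mirrors the placement used when the grid was built:
    row = index % amount_tiles, column = index // amount_tiles.
    """
    if not 0 <= index < amount_tiles * amount_tiles:
        raise IndexError("tile index outside of the grid")
    row = index % amount_tiles
    col = index // amount_tiles
    return (
        slice(row * image_size, (row + 1) * image_size),
        slice(col * image_size, (col + 1) * image_size),
    )

rows, cols = tile_slices(9)
print(rows, cols)
```

With a grid loaded as a numpy array (for example through `cv2.imread`), `grid[tile_slices(9)]` would then give the tenth image that was written into it.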

In [None]:
#@title convert to onnx
#@markdown Make sure the input dimensions are correct. A runtime restart may be needed if it complains about ``TypeError: forward() missing 1 required positional argument``. Make sure you only run the required cells.
import torch

# CustomTrainClass must already be defined or imported from the training code
model = CustomTrainClass()
checkpoint_path = '/content/Checkpoint_0_0.ckpt' #@param
output_path = '/content/output.onnx' #@param
# note for resuming training from a checkpoint: apparently global_step is reset to zero, which overwrites validation images; you could manually add an offset
model = model.load_from_checkpoint(checkpoint_path)
# dummy input for the export, shaped (batch, channels, height, width); adjust to your network
dummy_input = torch.randn(1, 1, 64, 64)

model.to_onnx(output_path, input_sample=dummy_input)

In [None]:
#@title copy pasting data to create artificial dataset for debugging
import shutil
from random import random
from tqdm import tqdm
for i in tqdm(range(5000)):
  shutil.copy("/content/4k/0.jpg", "/content/4k/"+str(random())+".jpg")

In [None]:
#@title pip list with space
!pip list | tail -n +3 | awk '{print $1}' | xargs pip show | grep -E 'Location:|Name:' | cut -d ' ' -f 2 | paste -d ' ' - - | awk '{print $2 "/" tolower($1)}' | xargs du -sh 2> /dev/null | sort -hr

In [None]:
#@title tiling script
import cv2
import numpy
import glob
import shutil
import tqdm
import os
from multiprocessing.pool import ThreadPool as ThreadPool

rootdir = '/content/' #@param {type:"string"}
destination_dir = "/content/" #@param {type:"string"}
broken_dir = '/content/' #@param {type:"string"}
threads = 2 #@param
tile_size = 256 #@param

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files_jpeg = glob.glob(rootdir + '/**/*.jpeg', recursive=True)
files_webp = glob.glob(rootdir + '/**/*.webp', recursive=True)
files.extend(files_jpg)
files.extend(files_jpeg)
files.extend(files_webp)
err_files=[]

pool = ThreadPool(threads)

def tiling(f):
  image = cv2.imread(f)
  if image is not None:
    counter = 0

    # image.shape is (height, width, channels)
    x = image.shape[0]
    y = image.shape[1]

    # partial tiles at the bottom and right edges are discarded
    x_amount = x // tile_size
    y_amount = y // tile_size

    for i in range(x_amount):
      for j in range(y_amount):
        crop = image[i*tile_size:(i+1)*tile_size, j*tile_size:(j+1)*tile_size]
        cv2.imwrite(os.path.join(destination_dir, os.path.splitext(os.path.basename(f))[0] + str(counter) + ".png"), crop)
        counter += 1

  else:
    print(f'Broken file: {os.path.basename(f)}')
    shutil.move(f, f'{broken_dir}/{os.path.basename(f)}')

        
pool.map(tiling, files)
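
To estimate how many tiles the script above will write for a given image size, the crop boxes can be computed without touching any image data. This is only a sketch of the same arithmetic; `tile_boxes` is a hypothetical helper and not part of the script.

```python
def tile_boxes(height, width, tile_size=256):
    """List (top, bottom, left, right) boxes of all full tiles in an image.

    Same arithmetic as the tiling script: integer division drops
    partial tiles at the bottom and right edges.
    """
    return [
        (i * tile_size, (i + 1) * tile_size, j * tile_size, (j + 1) * tile_size)
        for i in range(height // tile_size)
        for j in range(width // tile_size)
    ]

boxes = tile_boxes(600, 1000)
print(len(boxes), boxes[-1])
```

Images smaller than `tile_size` in either dimension produce no tiles at all, so they are silently skipped by the script.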

In [None]:
#@title create landmarks (for DFDNet)
!pip install face-alignment
!pip install matplotlib --upgrade

import face_alignment
from skimage import io
import numpy as np
import glob
from tqdm import tqdm
import os
import shutil

fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)

unchecked_input_path = '/content/ffhq' #@param {type:"string"}
checked_output_path = '/content/ffhq' #@param {type:"string"}
failed_output_path = '/content/ffhq' #@param {type:"string"}
landmark_output_path = '/content/landmarks' #@param {type:"string"}

if not os.path.exists(unchecked_input_path):
    os.makedirs(unchecked_input_path)
if not os.path.exists(checked_output_path):
    os.makedirs(checked_output_path)
if not os.path.exists(failed_output_path):
    os.makedirs(failed_output_path)
if not os.path.exists(landmark_output_path):
    os.makedirs(landmark_output_path)

files = glob.glob(unchecked_input_path + '/**/*.png', recursive=True)
files_jpg = glob.glob(unchecked_input_path + '/**/*.jpg', recursive=True)
files.extend(files_jpg)
err_files=[]

for f in tqdm(files):
  image = io.imread(f)  # avoid shadowing the built-in input()
  preds = fa.get_landmarks(image)
  if preds is not None:
    # save the landmarks of the first detected face, one point per line
    np.savetxt(os.path.join(landmark_output_path, os.path.basename(f)+".txt"), preds[0], delimiter=' ', fmt='%1.3f')
    shutil.move(f, os.path.join(checked_output_path,os.path.basename(f)))
  else:
    shutil.move(f, os.path.join(failed_output_path,os.path.basename(f)))

In [None]:
#@title download DictionaryCenter512 for DFDNet
!mkdir /content/DictionaryCenter512
%cd /content/DictionaryCenter512
!gdown --id 1sEB9j3s7Wj9aqPai1NF-MR7B-c0zfTin
!gdown --id 1H4kByBiVmZuS9TbrWUR5uSNY770Goid6
!gdown --id 10ctK3d9znZ9nGN3d1Z77xW3GGshbeKBb
!gdown --id 1gcwmrIZjPFVu-cHjdQD6P4luohkPsil-
!gdown --id 1rJ8cORPxbJsIVAiNrjBag0ihaY_Mvurn
!gdown --id 1LkfJv2a3ud-mefAc1eZMJuINuNdSYgYO
!gdown --id 1LH-nxD__icSJvTiAbXAXDch03oDtbpkZ
!gdown --id 1JRTStLFsQ8dwaQjQ8qG5fNyrOvo6Tcvd
!gdown --id 1Z4AkU1pOYTYpdbfljCgNMmPilhdEd0Kl
!gdown --id 1Z4e1ltB3ACbYKzkoMBuVtzZ7a310G4xc
!gdown --id 1fqWmi6-8ZQzUtZTp9UH4hyom7n4nl8aZ
!gdown --id 1wfHtsExLvSgfH_EWtCPjTF5xsw3YyvjC
!gdown --id 1Jr3Luf6tmcdKANcSLzvt0sjXr0QUIQ2g
!gdown --id 1sPd4_IMYgqGLol0gqhHjBedKKxFAxswR
!gdown --id 1eVFjXJRnBH4mx7ZbAmZRwVXZNUbgCQec
!gdown --id 1w0GfO_KY775ZVF3KMk74ya6QL_bNU4cJ

#!mkdir /content/DFDNet/weights/
#%cd /content/DFDNet/weights/
#!gdown --id 1SfKKZJduOGhDD27Xl01yDx0-YSEkL2Aa

In [None]:
#@title getting ffhq test data
%cd /content/
!gdown --id 1VE5tnOKcfL6MoV839IVCCw5FhJxIgml5
!7z x data.zip

# Download pretrain

In [None]:
#@title getting DFDNet pretrain
%cd /content
!gdown --id 1UCo7YEbLLa1_87b0AoWmzhTGyrw-26nb

In [None]:
#@title downloading places2 dfnet
%cd /content/
!gdown --id 1SGJ_Z9kpchdnZ3Qwwf4HnN-Cq-AeK7vH # dfnet places2

In [None]:
#@title download rfr paris model and fix state_dict
# rfr paris
%cd /content/
!gdown --id 1jnUb-EvBw9DcwyWUQyWDdN9o42BPH7uT

#https://discuss.pytorch.org/t/dataparallel-changes-parameter-names-issue-with-load-state-dict/60211
import torch
from collections import OrderedDict
state_dict = torch.load("/content/checkpoint_paris.pth", map_location='cpu')
new_state_dict = OrderedDict()

for k, v in state_dict['generator'].items():
  if k == 'Pconv1.weight':
      name = 'conv1.weight'
  elif k == 'Pconv2.weight':
      name = 'conv2.weight'
  elif k == 'Pconv21.weight':
      name = 'conv21.weight'
  elif k == 'Pconv22.weight':
      name = 'conv22.weight'
  else:
      name = k

  new_state_dict[name] = v

torch.save(new_state_dict, '/content/converted.pth')
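
The if/elif chain above can also be written as a lookup table. A minimal sketch of the same renaming with a small helper; `rename_keys`, the dummy integer values and the key `bn1.bias` are made up for illustration, and in the real cell the values are tensors from `state_dict['generator']`:

```python
from collections import OrderedDict

# Old key -> new key, taken from the if/elif chain in the cell above.
RENAMES = {
    'Pconv1.weight': 'conv1.weight',
    'Pconv2.weight': 'conv2.weight',
    'Pconv21.weight': 'conv21.weight',
    'Pconv22.weight': 'conv22.weight',
}

def rename_keys(state_dict, renames):
    """Return a new OrderedDict with renamed keys; order and values are kept."""
    return OrderedDict((renames.get(k, k), v) for k, v in state_dict.items())

# Dummy values stand in for tensors; 'bn1.bias' is a hypothetical untouched key.
demo = OrderedDict([('Pconv1.weight', 1), ('Pconv22.weight', 2), ('bn1.bias', 3)])
converted = rename_keys(demo, RENAMES)
print(list(converted.keys()))
```

Keys that are not in the table pass through unchanged, which matches the `else` branch of the original loop.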

In [None]:
#@title download and create fixed lama pretrain
!pip install omegaconf

%cd /content
!pip3 install wldhx.yadisk-direct
!curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip
!unzip big-lama.zip

import torch
from collections import OrderedDict
state_dict = torch.load("/content/big-lama/models/best.ckpt", map_location='cpu')
new_state_dict = OrderedDict()

for k, v in state_dict['state_dict'].items():
  name = k.replace("generator.", "")
  new_state_dict[name] = v

# note: this is the same output path as in the rfr cell, so running both cells overwrites the file
torch.save(new_state_dict, '/content/converted.pth')
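
`k.replace("generator.", "")` removes the string wherever it appears in a key and keeps keys that do not belong to the generator. A stricter variant, sketched here as an alternative and not as what the cell above does, strips only a leading prefix and drops all other keys. The helper `strip_prefix` and the demo keys are hypothetical, and whether the extra keys need to be dropped depends on how the checkpoint is loaded later:

```python
from collections import OrderedDict

def strip_prefix(state_dict, prefix="generator."):
    """Keep only keys that start with prefix and remove that leading prefix."""
    out = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            out[key[len(prefix):]] = value
    return out

# Dummy values stand in for tensors; the key names are invented for the demo.
demo = OrderedDict([
    ("generator.model.0.weight", 1),
    ("discriminator.model.0.weight", 2),
    ("generator.sub.generator.bias", 3),
])
stripped = strip_prefix(demo)
print(list(stripped.keys()))
```

With the plain `replace` call the discriminator key would be kept and the third key would become `sub.bias`, because the inner `generator.` is removed as well.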

# Misc

A summary of all interesting inpainting generators that are not trainable with my code.

--------------------------------------------------

``Broken generators:``

Generators that are not included here since I can't seem to make them work properly:

AOT-GAN (2021): [researchmm/AOT-GAN-for-Inpainting](https://github.com/researchmm/AOT-GAN-for-Inpainting)

    Couldn't get generator working.

PenNet [no AMP] (2019): [researchmm/PEN-Net-for-Inpainting](https://github.com/researchmm/PEN-Net-for-Inpainting/)

    Always outputs white for some reason.

CRA [no AMP] (2019): [wangyx240/High-Resolution-Image-Inpainting-GAN](https://github.com/wangyx240/High-Resolution-Image-Inpainting-GAN)

    Likes to create the color pink.

Global [no AMP] (2020): [SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting](https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting)

    Always outputs white for some reason.

crfill (2020): [zengxianyu/crfill](https://github.com/zengxianyu/crfill)

    The lack of clear instructions and code leads to broken results. The training code is unreleased, which makes a correct implementation harder.

--------------------------------------------------

``Non-Pytorch generators:``

PSR-Net (2020): [sfwyly/PSR-Net](https://github.com/sfwyly/PSR-Net)

    Uses Tensorflow 2

co-mod-gan (2021): [zsyzzsoft/co-mod-gan](https://github.com/zsyzzsoft/co-mod-gan)

    Has a web demo and a (broken) link to a Docker image. Relies on Tensorflow 1.15 / StyleGAN2 code. A Colab by me for this can be found at https://github.com/styler00dollar/Colab-co-mod-gan.

Diverse-Structure-Inpainting (2021): [USTC-JialunPeng/Diverse-Structure-Inpainting](https://github.com/USTC-JialunPeng/Diverse-Structure-Inpainting)

    Tensorflow 1

R-MNet (2021): [Jireh-Jam/R-MNet-Inpainting-keras](https://github.com/Jireh-Jam/R-MNet-Inpainting-keras)

    Keras implementation. Not sure whether it offers much that is new or interesting.

Hypergraphs (2021): [GouravWadhwa/Hypergraphs-Image-Inpainting](https://github.com/GouravWadhwa/Hypergraphs-Image-Inpainting)

    Uses a custom conv layer that is implemented in Tensorflow. It sounds interesting, but I got errors when I tried to port it to Pytorch.

PEPSI (2019): [Forty-lock/PEPSI-Fast_image_inpainting_with_parallel_decoding_network](https://github.com/Forty-lock/PEPSI-Fast_image_inpainting_with_parallel_decoding_network)

    The net dcpV2 uses.

Region (2019): [vickyFox/Region-wise-Inpainting](https://github.com/vickyFox/Region-wise-Inpainting)

--------------------------------------------------

``Pytorch generators that I never tested:``

SPL (2021): [WendongZh/SPL](https://github.com/WendongZh/SPL)

WTAM (2020): [ChenWang8750/WTAM](https://github.com/ChenWang8750/WTAM)

MPI (2020): [ChenWang8750/MPI-model](https://github.com/ChenWang8750/MPI-model)

Edge-LBAM (2021): [wds1998/Edge-LBAM](https://github.com/wds1998/Edge-LBAM)

VCNET (2020): [birdortyedi/vcnet-blind-image-inpainting](https://github.com/birdortyedi/vcnet-blind-image-inpainting)

    Blind image inpainting without masks.

DMFA (2020): [mprzewie/dmfa_inpainting](https://github.com/mprzewie/dmfa_inpainting)

GIN (2020): [rlct1/gin-sg](https://github.com/rlct1/gin-sg) and [rlct1/gin](https://github.com/rlct1/gin)

StructureFlow (2019): [RenYurui/StructureFlow](https://github.com/RenYurui/StructureFlow)

    Needs special files.

GMCNN (2018): [shepnerd/inpainting_gmcnn](https://github.com/shepnerd/inpainting_gmcnn)

    The net dcpV1 used, if I remember correctly.

ShiftNet (2018): [Zhaoyi-Yan/Shift-Net_pytorch](https://github.com/Zhaoyi-Yan/Shift-Net_pytorch)

--------------------------------------------------

``Soon:``

ICT (2021): [raywzy/ICT](https://github.com/raywzy/ICT) (code released, waiting for pre-trained models)

MuFA-Net (2021): [ChenWang8750/MuFA-Net](https://github.com/ChenWang8750/MuFA-Net)

GCM-Net (2021): [ZhengHuanCS/GCM-Net](https://github.com/ZhengHuanCS/GCM-Net)

--------------------------------------------------

``No training code:``

SC-FEGAN (2019): [run-youngjoo/SC-FEGAN](https://github.com/run-youngjoo/SC-FEGAN)
