# Colab-traiNNer

victorca25's BasicSR fork: [victorca25/traiNNer](https://github.com/victorca25/traiNNer)

Original colab by [nmkd](https://github.com/n00mkrad) with modifications by [styler00dollar](https://github.com/styler00dollar)

In [None]:
#@title Check GPU

gpu = !nvidia-smi --query-gpu=gpu_name --format=csv
print("GPU: " + gpu[1])
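The cell above assumes the second output line of `nvidia-smi` is a device name. On a runtime without a GPU that line may be missing or unrelated text. A minimal sketch of a more defensive check follows; `parse_gpu_names` and `query_gpu_names` are hypothetical helper names, not part of Colab or traiNNer.

```python
import shutil
import subprocess


def parse_gpu_names(csv_lines):
    """Return device names from `nvidia-smi --format=csv` output lines.

    The first non-empty row is treated as the CSV header and skipped.
    """
    rows = [line.strip() for line in csv_lines if line.strip()]
    return rows[1:]


def query_gpu_names():
    """Run nvidia-smi and return a list of GPU names (empty if unavailable)."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []
    return parse_gpu_names(result.stdout.splitlines())


names = query_gpu_names()
print("GPU: " + (", ".join(names) if names else "none detected"))
```

If the list is empty, switch the Colab runtime type to GPU before continuing.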

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Install

In [None]:
#@title Install

!rm -rf "/content/traiNNer"
!mkdir "/content/traiNNer"

from datetime import datetime, timedelta
from IPython import display as ipythondisplay
from IPython.display import Image as ipythonimage
import os
import fileinput
import sys

sedloc = ""
%cd /content/
# Install apt-fast for faster package installation
!/bin/bash -c "$(curl -sL https://git.io/vokNn)"
# Install basic dependencies (7-Zip, used later to extract the dataset)
!apt-fast install -y -q -q p7zip-full p7zip-rar

# Clone traiNNer
!rm -rf /content/traiNNer
%cd /content/
#!git clone "https://github.com/victorca25/traiNNer.git"
!git clone https://github.com/styler00dollar/Colab-traiNNer traiNNer
%cd /content/traiNNer
!pip install pytorch_lightning==1.2.5 timm adamp opencv-python tensorboardX pyyaml
!pip install git+https://github.com/shijianjian/EfficientNet-PyTorch-3D

# Create empty dataset folders (-p avoids an error if they already exist)
!mkdir -p /content/hr
!mkdir -p /content/lr
!mkdir -p /content/lr_val
!mkdir -p /content/hr_val

print('Done.')
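Because `!pip install` failures do not stop a Colab cell, it can be useful to confirm that the packages are importable before moving on. The sketch below uses only the standard library; `missing_modules` is a hypothetical helper, and the list contains the import names of the packages installed above (`opencv-python` is imported as `cv2`, `pyyaml` as `yaml`).

```python
import importlib.util


def missing_modules(module_names):
    """Return the top-level module names that cannot be found on this runtime."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Import names for the packages from the install cell
REQUIRED = ["pytorch_lightning", "timm", "adamp", "cv2", "tensorboardX", "yaml"]

missing = missing_modules(REQUIRED)
if missing:
    print("Missing modules: " + ", ".join(missing))
else:
    print("All required modules are importable.")
```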

# Dataset

Pre-configured paths:

LR: ```/content/lr```

HR: ```/content/hr```

LR_VAL: ```/content/lr_val```

HR_VAL: ```/content/hr_val```
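For paired datasets, a quick check that every LR image has an HR counterpart can catch problems before training starts. The sketch below assumes that pairs share the same file name (ignoring the extension), which this notebook does not state explicitly. `unpaired_images` and the extension list are illustrative choices.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}


def list_image_stems(folder):
    """Return the file names (without extension) of all images in a folder."""
    return {
        p.stem
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    }


def unpaired_images(lr_dir, hr_dir):
    """Return (LR images without an HR match, HR images without an LR match)."""
    lr_stems = list_image_stems(lr_dir)
    hr_stems = list_image_stems(hr_dir)
    return sorted(lr_stems - hr_stems), sorted(hr_stems - lr_stems)


# Only runs when the pre-configured folders exist
if Path("/content/lr").is_dir() and Path("/content/hr").is_dir():
    only_lr, only_hr = unpaired_images("/content/lr", "/content/hr")
    print(f"LR without HR: {len(only_lr)}, HR without LR: {len(only_hr)}")
```

The same call can be repeated for `/content/lr_val` and `/content/hr_val`.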

# Download data

Upload your dataset archive to Google Drive (mounted above), then copy it into the Colab runtime and extract it there. The cell below does this for a tar archive.

In [None]:
%cd /content/
!cp "/content/drive/MyDrive/animeinterp_mod_datset.tar" "/content/data.tar"
!7z x /content/data.tar
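The config below points `dataroot_HR` and `dataroot_LR` at `/content/data`, so it helps to see which top-level folders the archive actually creates. As an alternative to `7z`, the sketch below uses Python's standard `tarfile` module and reports those folders; `extract_tar` is a hypothetical helper, and it only handles tar archives.

```python
import tarfile
from pathlib import Path


def extract_tar(archive_path, dest_dir):
    """Extract a tar archive and return its sorted top-level entry names."""
    archive = Path(archive_path)
    if not archive.is_file():
        raise FileNotFoundError(f"Archive not found: {archive}")
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive) as tar:
        names = tar.getnames()
        tar.extractall(dest)
    # Keep only the first path component of each member
    return sorted({name.split("/")[0] for name in names})


# Only runs when the archive has been copied into the runtime
if Path("/content/data.tar").is_file():
    print("Top-level entries:", extract_tar("/content/data.tar", "/content"))
```

If the reported folder is not `data`, either rename it or adjust the `dataroot_*` paths in the config.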

# Train

In [None]:
#@title config example
%%writefile /content/traiNNer/code/config.yaml
name: template
scale: 1
gpus: 1 # number of GPUs, 0 = CPU
distributed_backend: ddp # dp, ddp (for multi-gpu training)
tpu_cores: 8 # 8 if you use a Google Colab TPU
use_tpu: False
use_amp: False
use_swa: False
progress_bar_refresh_rate: 20
default_root_dir: '/content'

# Dataset options:
datasets:
  train:
    # DS_inpaint: hr is from dataroot_HR, loads masks
    # DS_inpaint_tiled: hr is from dataroot_HR, but images are grids (16x16 256px currently), loads masks
    # DS_inpaint_tiled_batch: hr is from dataroot_HR, but images are grids (16x16 256px currently) and processed as a batch (batch_size_DL), loads masks
    # DS_lrhr: loads lr from dataroot_LR and hr from dataroot_HR
    # DS_lrhr_batch_oft: loads grayscale-hr (3x3 400px) from dataroot_LR and generates lr by random on-the-fly (otf) downscaling

    mode: DS_video # DS_video | DS_inpaint | DS_inpaint_tiled | DS_inpaint_tiled_batch | DS_lrhr | DS_lrhr_batch_oft
    grayscale: False # If true, reads as 1-Channel image. If false, reads as 3-channel image. Currently only implemented for DS_lrhr_batch_oft
    dataroot_HR: '/content/data' # Original, with a single directory. Inpainting will use this directory as source image.
    dataroot_LR: '/content/data' # Original, with a single directory
    loading_backend: 'PIL' # PIL, OpenCV (currently only in DS_inpaint_tiled_batch)

    n_workers: 2 # 0 to disable CPU multithreading, or an integer giving the number of CPU threads to use for dataloading
    batch_size: 1
    batch_size_DL: 1
    HR_size: 512 # The resolution the network will get. Random crop gets applied if that resolution does not match.
    image_channels: 3 # number of channels to load images in

    masks: '/content/masks/'
    max_epochs: 2000
    save_step_frequency: 5000 # also validation frequency

    # batch
    # If a tiled dataloader is used, specify the image characteristics. Random crop is not applied (may be added in the future).
    image_size: 512 # Size of one tile
    amount_tiles: 2 # Number of tiles inside the merged grid image

    # if edge data is required
    canny_min: 100
    canny_max: 150

  val:
    dataroot_HR: '/content/data'
    dataroot_LR: '/content/data' # Inpainting will use this directory as input

path:
    pretrain_model_G:
    pretrain_model_D: 
    checkpoint_path:
    checkpoint_save_path: '/content/'
    validation_output_path: '/content/val/'
    log_path: '/content/logs'

# Generator options:
network_G:
    # CEM (for esrgan, not 1x)
    CEM: False # uses hardcoded torch.cuda.FloatTensor
    sigmoid_range_limit: False

    finetune: False # Important for further rfr/dsnet training. Apply that after training for a while. https://github.com/jingyuanli001/RFR-Inpainting/issues/33

    # ESRGAN:
    #netG: RRDB_net # RRDB_net (original ESRGAN arch) | MRRDB_net (modified/"new" arch)
    #norm_type: null
    #mode: CNA
    #nf: 64 # of discrim filters in the first conv layer (default: 64, good: 32)
    #nb: 23 # (default: 23, good: 8)
    #in_nc: 3 # of input image channels: 3 for RGB and 1 for grayscale
    #out_nc: 3 # of output image channels: 3 for RGB and 1 for grayscale
    #gc: 32
    #group: 1
    #convtype: Conv2D # Conv2D | PartialConv2D
    #net_act: leakyrelu # swish | leakyrelu
    #gaussian: false # true | false # esrgan plus, does not work on TPU because of cuda()
    #plus: false # true | false
    #finalact: None #tanh # Test. Activation function to make outputs fit in [-1, 1] range. Default = None. Coordinate with znorm.
    #upsample_mode: 'upconv'
    #nr: 3

    # ASRGAN:
    #which_model_G: asr_resnet # asr_resnet | asr_cnn
    #nf: 64

    # PPON:
    #netG: ppon # | ppon
    ##norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 24
    #in_nc: 3
    #out_nc: 3
    ##gc: 32
    #group: 1
    ##convtype: Conv2D #Conv2D | PartialConv2D

    # SRGAN:
    #netG: sr_resnet # RRDB_net | sr_resnet
    #norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 16
    #in_nc: 3
    #out_nc: 3

    # SR:
    #netG: RRDB_net # RRDB_net | sr_resnet
    #norm_type: null
    #mode: CNA
    #nf: 64
    #nb: 23
    #in_nc: 3
    #out_nc: 3
    #gc: 32
    #group: 1

    # PAN:
    # netG: pan_net
    # in_nc: 3
    # out_nc: 3
    # nf: 40
    # unf: 24
    # nb: 16
    # self_attention: true
    # double_scpa: false

    # edge-informed-sisr
    #which_model_G: sisr
    #use_spectral_norm: True

    # USRNet
    #netG: USRNet
    #in_nc: 4
    #out_nc: 3
    #nc: [64, 128, 256, 512]
    #nb: 2
    #act_mode: 'R'
    #downsample_mode: 'strideconv'
    #upsample_mode: 'convtranspose'

    # GLEAN (2021)
    # Warning: Requires "pip install mmcv-full"
    #netG: GLEAN
    #in_size: 512
    #out_size: 512
    #img_channels: 4
    #img_channels_out: 3
    #rrdb_channels: 16 # 64
    #num_rrdbs: 8 # 23
    #style_channels: 512 # 512
    #num_mlps: 4 # 8
    #channel_multiplier: 2
    #blur_kernel: [1, 3, 3, 1]
    #lr_mlp: 0.01
    #default_style_mode: 'mix'
    #eval_style_mode: 'single'
    #mix_prob: 0.9
    #pretrained: False # only works with official settings
    #bgr2rgb: False

    # srflow (upscaling factors: 4, 8, 16)
    # Warning: Can be very unstable with batch_size 1, use higher batch_size
    #netG: srflow
    #in_nc: 3
    #out_nc: 3
    #nf: 64
    #nb: 23
    #train_RRDB: false
    #train_RRDB_delay: 0.5
    #flow:
    #  K: 16
    #  L: 3
    #  noInitialInj: true
    #  coupling: CondAffineSeparatedAndCond
    #  additionalFlowNoAffine: 2
    #  split:
    #    enable: true
    #  fea_up0: true
    #  stackRRDB:
    #    blocks: [ 1, 8, 15, 22 ]
    #    concat: true
    #nll_weight: 1
    #freeze_iter: 100000

    # DFDNet
    # Warning: Expects "DictionaryCenter512" in the current folder, you can get the data here: https://drive.google.com/drive/folders/1bayYIUMCSGmoFPyd4Uu2Uwn347RW-vl5
    # Also wants a folder called "landmarks", you can generate that data yourself. Example: https://github.com/styler00dollar/Colab-DFDNet/blob/local/Colab-DFDNet-lightning-train.ipynb
    # Hardcoded resolution: 512px
    #netG: DFDNet
    #dictionary_path: "/content/DictionaryCenter512"
    #landmarkpath: "/content/landmarks"
    #val_landmarkpath: "/content/landmarks"

    # GFPGAN (2021) [EXPERIMENTAL]
    # requires ninja
    # startup takes quite long because it compiles files
    #netG: GFPGAN
    #input_channels: 4
    #output_channels: 3
    #out_size: 512
    #num_style_feat: 512
    #channel_multiplier: 1
    #resample_kernel: [1, 3, 3, 1]
    #decoder_load_path: # None
    #fix_decoder: True
    #num_mlp: 8
    #lr_mlp: 0.01
    #input_is_latent: False
    #different_w: False
    #narrow: 1
    #sft_half: False

    # GPEN
    # requires ninja
    # startup takes quite long because it compiles files
    # output_channels is hardcoded to 3
    #netG: GPEN
    #input_channels: 4
    #size: 512
    #style_dim: 512
    #n_mlp: 8
    #channel_multiplier: 2
    #blur_kernel: [1, 3, 3, 1]
    #lr_mlp: 0.01
    #pooling: True # Experimental, to have any input size

    # comodgan (2021)
    # needs ninja
    # because it compiles files, the startup time is quite long
    #netG: comodgan
    #dlatent_size: 512
    #num_channels: 3 # amount of channels without mask
    #resolution: 512
    #fmap_base: 16384 # 16 << 10
    #fmap_decay: 1.0
    #fmap_min: 1
    #fmap_max: 512
    #randomize_noise: True
    #architecture: 'skip'
    #nonlinearity: 'lrelu'
    #resample_kernel: [1,3,3,1]
    #fused_modconv: True
    #pix2pix: False
    #dropout_rate: 0.5
    #cond_mod: True
    #style_mod: True
    #noise_injection: True

    # ----Inpainting Generators----
    # DFNet (batch_size: 2+, needs 2^x image input and validation) (2019)
    #netG: DFNet
    #c_img: 3
    #c_mask: 1
    #c_alpha: 3
    #mode: nearest
    #norm: batch
    #act_en: relu
    #act_de: leaky_relu
    #en_ksize: [7, 5, 5, 3, 3, 3, 3, 3]
    #de_ksize: [3, 3, 3, 3, 3, 3, 3, 3]
    #blend_layers: [0, 1, 2, 3, 4, 5]
    #conv_type: normal # partial | normal | deform
    

    # EdgeConnect (2019)
    #netG: EdgeConnect
    #use_spectral_norm: True
    #residual_blocks_edge: 8
    #residual_blocks_inpaint: 8
    #conv_type_edge: 'normal' # normal | partial | deform (has no spectral_norm)
    #conv_type_inpaint: 'normal' # normal | partial | deform

    # CSA (2019)
    #netG: CSA
    #c_img: 3
    #norm: 'instance'
    #act_en: 'leaky_relu'
    #act_de: 'relu'

    # RN (2020)
    #netG: RN
    #input_channels: 3
    #residual_blocks: 8
    #threshold: 0.8

    # deepfillv1 (2018)
    #netG:  deepfillv1

    # deepfillv2 (2019)
    #netG: deepfillv2
    #in_channels:  4
    #out_channels:  3
    #latent_channels:  64
    #pad_type:  'zero'
    #activation:  'lrelu'
    #norm: 'in'
    #conv_type: partial # partial | normal

    # Adaptive (2020)
    #netG: Adaptive
    #in_channels: 3
    #residual_blocks: 1
    #init_weights: True

    # Global (2020)
    #netG: Global
    #input_dim: 5
    #ngf: 32
    #use_cuda: True
    #device_ids: [0]

    # Pluralistic (2019)
    #netG: Pluralistic
    #ngf_E: 32
    #z_nc_E: 128
    #img_f_E: 128
    #layers_E: 5
    #norm_E: 'none'
    #activation_E: 'LeakyReLU'
    #ngf_G: 32
    #z_nc_G: 128
    #img_f_G: 128
    #L_G: 0
    #output_scale_G: 1
    #norm_G: 'instance'
    #activation_G: 'LeakyReLU'

    # crfill (2020)
    #netG: crfill
    #cnum: 48

    # DeepDFNet (experimental)
    #netG: DeepDFNet
    #in_channels:  4
    #out_channels:  3
    #latent_channels:  64
    #pad_type:  'zero'
    #activation:  'lrelu'
    #norm: 'in'

    # partial (2018)
    #netG: partial

    # DMFN (2020)
    #netG: DMFN
    #in_nc: 4
    #out_nc: 3
    #nf: 64
    #n_res: 8
    #norm: 'in'
    #activation: 'relu'

    # pennet (2019)
    #netG: pennet

    # LBAM (2019)
    #netG: LBAM
    #inputChannels: 4
    #outputChannels: 3

    # RFR (use_swa: false, no TPU) (2020)
    #netG: RFR
    #conv_type: partial # partial | deform

    # FRRN (2019)
    #netG: FRRN

    # PRVS (2019)
    #netG: PRVS

    # CRA (HR_size: 512) (2020)
    #netG: CRA
    #activation: 'elu'
    #norm: 'none'

    # atrous (2020)
    #netG: atrous

    # MEDFE (batch_size: 1) (2020)
    #netG: MEDFE

    # AdaFill (2021)
    #netG: AdaFill

    # lightweight_gan (2021)
    #netG: lightweight_gan
    #image_size: 512
    #latent_dim: 256
    #fmap_max: 512
    #fmap_inverse_coef: 12
    #transparent: False
    #greyscale: False
    #freq_chan_attn: False


    # ----Interpolation Generators----
    netG: CAIN
    depth: 3

# Discriminator options:
network_D:
    d_loss_fool_weight: 1 # inside the generator loop, trying to fool the discriminator
    d_loss_weight: 1 # inside own discriminator update
    
    #netD: # in case there is no discriminator, leave it empty

    # VGG
    #netD: VGG
    #size: 256
    #in_nc: 3 #3
    #base_nf: 64
    #norm_type: 'batch'
    #act_type: 'leakyrelu'
    #mode: 'CNA'
    #convtype: 'Conv2D'
    #arch: 'ESRGAN'

    # VGG fea
    #netD: VGG_fea
    #size: 256
    #in_nc: 3
    #base_nf: 64
    #norm_type: 'batch'
    #act_type: 'leakyrelu'
    #mode: 'CNA'
    #convtype: 'Conv2D'
    #arch: 'ESRGAN'
    #spectral_norm: False
    #self_attention: False
    #max_pool: False
    #poolsize: 4


    #netD: VGG_128_SN

    # VGGFeatureExtractor
    #netD: VGGFeatureExtractor
    #feature_layer: 34
    #use_bn: False
    #use_input_norm: True
    #device: 'cpu'
    #z_norm: False

    # PatchGAN
    #netD: NLayerDiscriminator
    #input_nc: 3
    #ndf: 64
    #n_layers: 3
    #norm_layer: nn.BatchNorm2d
    #use_sigmoid: False
    #getIntermFeat: False
    #patch: True
    #use_spectral_norm: False

    # Multiscale
    #netD: MultiscaleDiscriminator
    #input_nc: 3
    #ndf: 64
    #n_layers: 3
    #norm_layer: nn.BatchNorm2d
    #se_sigmoid: False
    #num_D: 3
    #getIntermFeat: False

    #netD: ResNet101FeatureExtractor
    #use_input_norm: True
    #device: 'cpu'
    #z_norm: False

    # MINC
    #netD: MINCNet

    # Pixel
    #netD: PixelDiscriminator
    #input_nc: 3
    #ndf: 64
    #norm_layer: nn.BatchNorm2d

    # EfficientNet (3-channel input)
    #netD: EfficientNet
    #EfficientNet_pretrain: 'efficientnet-b0'
    #num_classes: 1 # should be 1

    # ResNeSt (not working)
    #netD: ResNeSt
    #ResNeSt_pretrain: 'resnest50' # ["resnest50", "resnest101", "resnest200", "resnest269"]
    #pretrained: False # can't be True currently
    #num_classes: 1

    # Transformer (not working)
    #netD: TranformerDiscriminator
    #img_size: 256
    #patch_size: 1
    #in_chans: 3
    #num_classes: 1
    #embed_dim: 64
    #depth: 7
    #num_heads: 4
    #mlp_ratio: 4.
    #qkv_bias: False
    #qk_scale: None
    #drop_rate: 0.
    #attn_drop_rate: 0.
    #drop_path_rate: 0.
    #hybrid_backbone: None
    #norm_layer: 

    # context_encoder (num_classes can't be set, broadcasting warning will be shown, training works, but I am not sure if it will work correctly)
    #netD: context_encoder

    # Transformer (doesn't do init)
    #netD: ViT
    #image_size: 256
    #patch_size: 32
    #num_classes: 1
    #dim: 1024
    #depth: 6
    #heads: 16
    #mlp_dim: 2048
    #dropout: 0.1
    #emb_dropout: 0.1

    # Transformer (doesn't do init)
    #netD: DeepViT
    #image_size: 256
    #patch_size: 32
    #num_classes: 1
    #dim: 1024
    #depth: 6
    #heads: 16
    #mlp_dim: 2048
    #dropout: 0.1
    #emb_dropout: 0.1

    # RepVGG
    #netD: RepVGG
    #RepVGG_arch: RepVGG-A0 # RepVGG-A0, RepVGG-A1, RepVGG-A2, RepVGG-B0, RepVGG-B1, RepVGG-B1g2, RepVGG-B1g4, RepVGG-B2, RepVGG-B2g2, RepVGG-B2g4, RepVGG-B3, RepVGG-B3g2, RepVGG-B3g4
    #num_classes: 1

    # squeezenet
    #netD: squeezenet
    #version: "1_1" # 1_0, 1_1
    #num_classes: 1

    # SwinTransformer (doesn't do init)
    #netD: SwinTransformer
    #hidden_dim: 96
    #layers: [2, 2, 6, 2]
    #heads: [3, 6, 12, 24]
    #channels: 3
    #num_classes: 1
    #head_dim: 32
    #window_size: 8
    #downscaling_factors: [4, 2, 2, 2]
    #relative_pos_embedding: True

    # mobilenetV3 (doesn't do init)
    #netD: mobilenetV3
    #mode: small # small, large
    #n_class: 1
    #input_size: 256

    # resnet
    #netD: resnet
    #resnet_arch: resnet50 # resnet50, resnet101, resnet152
    #num_classes: 1
    #pretrain: True
  
    # NFNet
    #netD: NFNet
    #num_classes: 1
    #variant: 'F0'         # F0 - F7
    #stochdepth_rate: 0.25 # 0-1, the probability that a layer is dropped during one step
    #alpha: 0.2            # Scaling factor at the end of each block
    #se_ratio: 0.5         # Squeeze-Excite expansion ratio
    #activation: 'gelu'    # or 'relu'

    # lvvit (2021)
    # Warning: Needs 'pip install timm==0.4.5'
    #netD: lvvit
    #img_size: 224
    #patch_size: 16
    #in_chans: 3
    #num_classes: 1
    #embed_dim: 768
    #depth: 12
    #num_heads: 12
    #mlp_ratio: 4.
    #qkv_bias: False
    #qk_scale: # None
    #drop_rate: 0.
    #attn_drop_rate: 0.
    #drop_path_rate: 0.
    #drop_path_decay: 'linear'
    #hybrid_backbone: # None
    ##norm_layer: nn.LayerNorm # Default: nn.LayerNorm / can't be configured
    #p_emb: '4_2'
    #head_dim: # None
    #skip_lam: 1.0
    #order: # None
    #mix_token: False
    #return_dense: False

    # timm
    # pip install timm
    # you can look up models here: https://rwightman.github.io/pytorch-image-models/
    #netD: timm
    #timm_model: "tf_efficientnetv2_b0"

    netD: resnet3d
    model_depth: 50 # [10, 18, 34, 50, 101, 152, 200]

train: 
    scheduler: AdamP # Adam, AdamP, SGDP, MADGRAD, cosangulargrad [maybe broken], tanangulargrad [maybe broken]
    lr: 0.000001
    
    # AdamP, SGDP, MADGRAD, cosangulargrad, tanangulargrad
    weight_decay: 0.01

    # SGDP, MADGRAD
    momentum: 0.9

    # AdamP, cosangulargrad, tanangulargrad
    betas0: 0.9
    betas1: 0.999
    
    # SGDP
    nesterov: True

    # MADGRAD, cosangulargrad, tanangulargrad
    eps: 1e-6

    # Losses:
    L1Loss_weight: 0

    # HFENLoss
    HFEN_weight: 0
    loss_f: L1CosineSim # L1Loss | L1CosineSim
    kernel: 'log'
    kernel_size: 15
    sigma: 2.5
    norm: False

    # Elastic
    Elatic_weight: 0
    a: 0.2
    reduction_elastic: 'mean'

    # Relative L1
    Relative_l1_weight: 0
    eps: .01
    reduction_realtive: 'mean'

    # L1CosineSim (3-channel input)
    L1CosineSim_weight: 0
    loss_lambda: 5
    reduction_L1CosineSim: 'mean'

    # ClipL1
    ClipL1_weight: 0
    clip_min: 0.0
    clip_max: 10.0

    # FFTLoss
    FFTLoss_weight: 1
    loss_f_fft: L1Loss
    reduction_fft: 'mean'

    OFLoss_weight: 0

    # GPLoss
    GPLoss_weight: 0
    trace: False
    spl_denorm: False

    # CPLoss
    CPLoss_weight: 0
    rgb: True
    yuv: True
    yuvgrad: True
    trace: False
    spl_denorm: False
    yuv_denorm: False

    StyleLoss_weight: 0

    # TVLoss
    TVLoss_weight: 0.00001
    tv_type: 'tv'
    p: 1

    # Contextual_Loss (3-channel input)
    Contexual_weight: 0
    crop_quarter: False
    max_1d_size: 100
    distance_type: 'cosine' # ["l1", "l2", "cosine"]
    b: 1.0
    band_width: 0.5
    # for vgg
    use_vgg: False
    net_contextual: 'vgg19'
    layers_weights: {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    # for timm
    use_timm: True
    timm_model: "tf_efficientnetv2_b0"
    # for both
    calc_type: 'regular' # ["bilateral" | "symetric" | None]

    # Style (3-channel input)
    StyleLoss_weight: 0

    # PerceptualLoss
    perceptual_weight: 1
    net: PNetLin # PNetLin, DSSIM (?)
    pnet_type: 'vgg' # alex, squeeze, vgg
    pnet_rand: False
    pnet_tune: False
    use_dropout: True
    spatial: False
    version: '0.1' # only version
    lpips: True

    # only if the network outputs 2 images, will use l1
    stage1_weight: 0 

    # Differentiable Augmentation for Data-Efficient GAN Training
    diffaug: False
    policy: 'color,translation,cutout'

    # Metrics
    metrics: [] # PSNR | SSIM | AE | MSE

In [None]:
%cd '/content/traiNNer/code'
!python train.py
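The example config defines a few keys more than once inside the `train:` mapping (`eps`, `trace`, `spl_denorm`, `StyleLoss_weight`). Assuming the file is loaded with PyYAML (installed above), repeated keys do not raise an error and the last value wins, so an earlier setting can be dropped silently. The sketch below is a stdlib-only scan for such repeats. `find_duplicate_keys` is a hypothetical helper, not part of traiNNer, and it only understands plain indented `key: value` lines.

```python
import re
from pathlib import Path

# Matches "<indent><key>:" followed by whitespace or end of line
KEY_RE = re.compile(r"^( *)([A-Za-z_][\w\-]*)\s*:(\s|$)")


def find_duplicate_keys(yaml_text):
    """Return (key, first_line, repeat_line) for keys repeated in one mapping."""
    duplicates = []
    stack = []  # each frame: (indent, {key: line number where first seen})
    for lineno, line in enumerate(yaml_text.splitlines(), start=1):
        if not line.strip() or line.lstrip().startswith("#"):
            continue  # skip blank lines and comments
        match = KEY_RE.match(line)
        if match is None:
            continue
        indent, key = len(match.group(1)), match.group(2)
        # Leaving deeper mappings closes their frames
        while stack and stack[-1][0] > indent:
            stack.pop()
        if not stack or stack[-1][0] != indent:
            stack.append((indent, {}))
        seen = stack[-1][1]
        if key in seen:
            duplicates.append((key, seen[key], lineno))
        else:
            seen[key] = lineno
    return duplicates


# Only runs when the config written by the cell above exists
config_path = Path("/content/traiNNer/code/config.yaml")
if config_path.is_file():
    for key, first, repeat in find_duplicate_keys(config_path.read_text()):
        print(f"'{key}' on line {repeat} repeats line {first}")
```

Running this before `train.py` makes it easier to see which of the repeated values will actually take effect.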