<a href="https://colab.research.google.com/github/wandb/examples/blob/restorers/colabs/keras/restorers/Train_MirNetv2_Restorers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q --upgrade pip setuptools
!pip install git+https://github.com/soumik12345/restorers.git

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
cvxpy 1.2.3 requires setuptools<=64.0.2, but you have setuptools 67.6.0 which is incompatible.[0m[31m
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/soumik12345/restorers.git
  Cloning https://github.com/soumik12345/restorers.git to /tmp/pip-req-build-xgghxtc_
  Running command git clone --filter=blob:none --quiet https://github.com/soumik12345/restorers.git /tmp/pip-req-build-xgghxtc_
  Resolved https://github.com/soumik

In [6]:
import wandb
import tensorflow as tf
from restorers.dataloader import LOLDataLoader

In [3]:
wandb.init(project="low-light-enhancement")

# Data loader for the LoL dataset.
data_loader = LOLDataLoader(
    image_size=128,  # size of the image crops used for training
    bit_depth=8,  # bit depth of the images
    val_split=0.2,  # fraction of the images held out for validation
    visualize_on_wandb=True,  # whether to visualize the dataset on WandB
    # WandB artifact address of the dataset; it can be copied from the
    # `Usage` tab of the dataset's weave panel on WandB
    dataset_artifact_address="ml-colabs/dataset/LoL:v0",
)

# `get_datasets` hands back the TensorFlow datasets for the
# training and validation splits, in that order
datasets = data_loader.get_datasets(batch_size=2)
train_dataset, val_dataset = datasets

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Downloading large artifact LoL:v0, 331.95MB. 1003 files... 
[34m[1mwandb[0m:   1003 of 1003 files downloaded.  
Done. 0:0:41.9


Generating visualizations for Train images:   0%|          | 0/388 [00:00<?, ?it/s]

Generating visualizations for Validation images:   0%|          | 0/97 [00:00<?, ?it/s]

Generating visualizations for Test images:   0%|          | 0/15 [00:00<?, ?it/s]

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [4]:
# import MirNetv2 from restorers
from restorers.model import MirNetv2


# define the MirNetv2 model; this gives us a `tf.keras.Model`
model = MirNetv2(
    # number of channels in the feature map
    channels=80,
    # factor by which the number of output channels varies
    # (presumably between scales -- confirm against the restorers docs)
    channel_factor=1.5,
    # number of multi-scale residual blocks
    num_mrb_blocks=2,
    # presumably whether to add a residual (skip) connection between the
    # model's input and its output -- confirm against the restorers docs
    add_residual_connection=True,
)

In [8]:
# The Charbonnier loss, together with the Peak Signal-to-Noise Ratio and
# Structural Similarity metrics, are all implemented as part of restorers.
from restorers.losses import CharbonnierLoss
from restorers.metrics import PSNRMetric, SSIMMetric

loss = CharbonnierLoss(
    epsilon=1e-3,  # small constant to avoid division by zero
    # the reduction type must be spelled out explicitly
    # when training with a distribution strategy
    reduction=tf.keras.losses.Reduction.SUM,
)

optimizer = tf.keras.optimizers.experimental.AdamW(learning_rate=2e-4)

# peak signal-to-noise ratio and structural similarity metrics
psnr_metric = PSNRMetric(max_val=1.0)
ssim_metric = SSIMMetric(max_val=1.0)

model.compile(optimizer=optimizer, loss=loss, metrics=[psnr_metric, ssim_metric])

In [9]:
# WandB callbacks for Keras
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

callbacks = [
    # metrics logger: `log_freq="batch"` is set explicitly so that
    # the metrics are logged both batch-wise and epoch-wise
    WandbMetricsLogger(log_freq="batch"),
    # model checkpoint callback writing to `checkpoint`
    WandbModelCheckpoint(
        filepath="checkpoint",
        monitor="val_loss",
        save_best_only=False,
        save_weights_only=False,
        initial_value_threshold=None,
    ),
]

# train the model, validating on the validation split
model.fit(train_dataset, validation_data=val_dataset, epochs=50, callbacks=callbacks)

Epoch 1/50

[34m[1mwandb[0m: Adding directory to artifact (./checkpoint)... Done. 0.1s


Epoch 2/50

[34m[1mwandb[0m: Adding directory to artifact (./checkpoint)... Done. 0.1s


Epoch 3/50
 24/194 [==>...........................] - ETA: 54s - loss: 0.1467 - psnr_metric: 15.4093 - ssim_metric: 0.6662

KeyboardInterrupt: ignored

In [None]:
# mark the WandB run started by `wandb.init` above as finished
wandb.finish()