In [1]:
import os
import sys

import pytorch_lightning as pl
import torch

# Make the parent of the notebook's working directory importable so the
# `src.*` imports below resolve. The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
base_dir = os.path.dirname(os.getcwd())
if base_dir not in sys.path:
    sys.path.append(base_dir)

from src.experiments import cGANCloudTOPtoRGB, Logger
from src.models import Unet, PatchGAN
from src.datasets import CloudTOPtoRGBDataset
from src.utils import load_yaml

# Session Setup

### Specify the session global variables 

In [2]:
# ----- Session-wide settings ------------------------------------------
# YAML file describing the whole experiment (models, dataset, optimizer, ...)
cfg_path = '../config/dummy.yaml'

# Directory receiving this run's outputs (split into parent dir + run name for the Logger)
output_dir = 'sandbox/'

# GPU selection, forwarded to the Trainer's `gpus` argument further down
gpu_id = 0

# Seed handed to `pl.seed_everything` for reproducibility
seed = 73

### Load experiment configuration file

Everything we need to set up the experiment is contained in this configuration file, which we'll use throughout this notebook.

In [3]:
# Parse the YAML file into a nested, dict-like config used by every cell below.
# Evaluate `cfg` in a cell of its own to inspect it.
cfg = load_yaml(cfg_path)

### Define a logger to record outputs during training

In [4]:
# Split `output_dir` into a parent directory (save_dir) and a run name for the
# logger. `normpath` strips any trailing separator first: without it,
# 'sandbox/' splits into ('sandbox', '') and the logger receives an empty name.
# An empty parent (a bare relative dir like 'sandbox') falls back to '.'.
output_path = os.path.normpath(output_dir)
logger = Logger(save_dir=os.path.dirname(output_path) or os.curdir,
                name=os.path.basename(output_path))
logger

<src.experiments.utils.loggers.Logger at 0x12de57438>

### Define a utility that saves model weight checkpoints for us

In [5]:
# Checkpointing callback: every constructor option comes straight from the
# `model_checkpoint` section of the config.
checkpoint_kwargs = cfg['model_checkpoint']
model_checkpoint = pl.callbacks.ModelCheckpoint(**checkpoint_kwargs)
model_checkpoint

<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint at 0x12de57208>

### Finally, don't forget to seed the run

In [6]:
# Seed the global RNGs through Lightning; the call returns the seed it used.
# NOTE(review): the experiment below is given cfg['experiment']['seed'] separately —
# confirm the two seeds are meant to be independent.
pl.seed_everything(seed=seed)

73

# Experiment Setup

__Here we define instances of the dataset, the U-Net generator and the PatchGAN discriminator used for GAN training — it's as simple as:__

In [4]:
# Build each component from its own config section. Keep this order: the global
# seed is already set, so construction order may affect random weight init.
model_cfg = cfg['model']
generator = Unet.build(model_cfg['generator'])
discriminator = PatchGAN.build(model_cfg['discriminator'])
dataset = CloudTOPtoRGBDataset.build(cfg['dataset'])

In [8]:
# Uncomment these lines to print the model architectures
# print(generator)
# print(discriminator)

__Now let's wrap all these components into a single `experiment` instance that specifies how training should be performed__

In [9]:
# Bundle models, data and training hyper-parameters into the single experiment
# object that is handed to `trainer.fit` below.
dataset_cfg = cfg['dataset']
experiment_cfg = cfg['experiment']
# NOTE(review): split values are passed in the config's key order — presumably
# train/val/test; confirm against the YAML.
experiment = cGANCloudTOPtoRGB(
    generator=generator,
    discriminator=discriminator,
    dataset=dataset,
    split=list(dataset_cfg['split'].values()),
    optimizer_kwargs=cfg['optimizer'],
    lr_scheduler_kwargs=cfg['lr_scheduler'],
    dataloader_kwargs=dataset_cfg['dataloader'],
    supervision_weight_l1=experiment_cfg['supervision_weight_l1'],
    supervision_weight_ssim=experiment_cfg['supervision_weight_ssim'],
    seed=experiment_cfg['seed'],
)

# Run Training

__This is the final step: we feed our experiment to a `Trainer`, which will run it__

In [10]:
# Train on GPU `gpu_id` when CUDA is available, otherwise fall back to CPU.
# Lightning reads an *int* `gpus` value as a number of devices (so `gpus=0`
# means "no GPU" and `gpus=1` means "one GPU", not "GPU #1"); a list selects
# devices by index, which is what `gpu_id` is meant to be.
gpus = [gpu_id] if torch.cuda.is_available() else None
trainer = pl.Trainer(logger=logger,
                     checkpoint_callback=model_checkpoint,
                     precision=cfg['experiment']['precision'],
                     max_epochs=cfg['experiment']['max_epochs'],
                     gpus=gpus)
trainer

GPU available: False, used: False
No environment variable for node rank defined. Set as 0.


<pytorch_lightning.trainer.trainer.Trainer at 0x1295f9588>

__Execute training__

In [12]:
# Launch training: a validation sanity check runs first, then the training loop.
trainer.fit(model=experiment)


   | Name                                    | Type            | Params
------------------------------------------------------------------------
0  | model                                   | Unet            | 16 M  
1  | model.encoder                           | Encoder         | 11 M  
2  | model.encoder.encoding_layers           | Sequential      | 11 M  
3  | model.encoder.encoding_layers.0         | Conv2d          | 3 K   
4  | model.encoder.encoding_layers.0.conv    | Conv2d          | 3 K   
5  | model.encoder.encoding_layers.1         | Conv2d          | 131 K 
6  | model.encoder.encoding_layers.1.conv    | Conv2d          | 131 K 
7  | model.encoder.encoding_layers.1.bn      | BatchNorm2d     | 256   
8  | model.encoder.encoding_layers.1.relu    | PReLU           | 1     
9  | model.encoder.encoding_layers.2         | Conv2d          | 525 K 
10 | model.encoder.encoding_layers.2.conv    | Conv2d          | 524 K 
11 | model.encoder.encoding_layers.2.bn      | BatchNorm2d    

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Detected KeyboardInterrupt, attempting graceful shutdown...





1

While training your model, you can visualize output logs with `tensorboard`.

To do so, run `tensorboard --logdir=<output_dir> --port=6008` from this notebook's directory, replacing `<output_dir>` with the `output_dir` defined above (`sandbox/` here).

Then open `localhost:6008` in your browser; it may take some time to load.