In [1]:
import sys
import os

# Make the project's Python sources importable from this notebook. They are
# assumed to live one directory above the notebook's working directory
# (presumably where `models`, `data` and `constants` are defined — confirm).
pwd = os.getcwd()
# os.path.dirname is separator-aware; slicing at rfind("/") silently chops the
# last character off a path that contains no "/" (e.g. Windows paths).
python_path = os.path.dirname(pwd)
# Guard the append so re-running this cell does not pile up duplicate entries.
if python_path not in sys.path:
    sys.path.append(python_path)

In [2]:
import torch
import lightning
import jupyter_black

from models import WGAN
from data import SubstratesDataModule
from constants import VALIDATION_SUBSTRATES_PATH
from lightning.pytorch.loggers import TensorBoardLogger

# Let float32 matrix multiplications use a faster reduced-internal-precision
# mode (e.g. TF32) when the hardware supports one.
torch.set_float32_matmul_precision("high")

# Enable black auto-formatting of notebook code cells.
jupyter_black.load()

In [4]:
# Data module hyperparameters
BATCH_SIZE = 256
NUM_WORKERS = 0
SHUFFLE = True

# Trainer hyperparameters
MAX_EPOCHS = 100
ACCELERATOR = "gpu"
DEVICES = 1
LOG_EVERY_N_STEPS = 10
# Relative to the notebook's working directory.
LOG_DIR = "../../logs/substrates_wgan"

# WGAN hyperparameters
D_LR = 0.0002
G_LR = 0.001
G_OPTIM_FREQUENCY = 1
D_OPTIM_FREQUENCY = 3
G_LATENT_DIMS = 100

# Reproducibility: seed the Python / NumPy / torch RNGs (and dataloader
# workers) BEFORE the model and data module are constructed, so weight
# initialisation and shuffling are repeatable across runs.
SEED = 42
lightning.seed_everything(SEED, workers=True)

substrate_data_module = SubstratesDataModule(
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=SHUFFLE
)

wgan = WGAN(
    G_latent_dims=G_LATENT_DIMS,
    D_lr=D_LR,
    G_lr=G_LR,
    G_optim_frequency=G_OPTIM_FREQUENCY,
    D_optim_frequency=D_OPTIM_FREQUENCY,
    val_imgs_dir_path=VALIDATION_SUBSTRATES_PATH,
)

trainer = lightning.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator=ACCELERATOR,
    devices=DEVICES,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    logger=TensorBoardLogger(save_dir=LOG_DIR),
)
# Pass the data module by keyword: the second positional parameter of
# `Trainer.fit` is `train_dataloaders`, not `datamodule`.
trainer.fit(wgan, datamodule=substrate_data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type          | Params
---------------------------------------
0 | D    | Discriminator | 2.8 M 
1 | G    | Generator     | 3.6 M 
---------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.353    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
