In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
from spectral_networks.nn.models import ReLUNet
from transform_datasets.patterns.synthetic import *
from transform_datasets.transforms import *
from transform_datasets.utils.wandb import load_or_create_dataset
from torch_tools.trainer import Trainer
from torch_tools.logger import WBLogger
from torch_tools.config import Config
from torch_tools.regularizer import Regularizer, MultiRegularizer
from torch_tools.functional import l1_norm

from torch_tools.data import TrainValLoader
from pytorch_metric_learning import losses, distances
from torch.optim import Adam

from torch_tools.plotter import Plotter, MultiPlotter

In [3]:
# Run-wide constants shared by every cell below.
DATA_PROJECT = "dataset"        # W&B project holding the generated datasets
MODEL_PROJECT = "bispectrum"    # W&B project the training run logs to
ENTITY = "naturalcomputation"   # W&B entity (team) owning both projects
DEVICE = "cuda:0"
SEED = 0

# --- Dataset ---------------------------------------------------------------
# Base patterns. NOTE: the dataset carries its own fixed seed (5), which is
# independent of SEED above.
dataset_config = Config(
    dict(
        type=HarmonicsS1,
        params=dict(dim=256, n_classes=10, seed=5),
    )
)

# Transforms layered on top of the base patterns, keyed "0", "1", ...
# (presumably applied in key order — confirm in load_or_create_dataset).
_translation_transform = Config(
    dict(
        type=CyclicTranslation1D,
        params=dict(fraction_transforms=1.0, sample_method="linspace"),
    )
)
_noise_transform = Config(
    dict(
        type=UniformNoise,
        params=dict(n_samples=1, magnitude=0.1),
    )
)
transforms_config = {"0": _translation_transform, "1": _noise_transform}

tdataset_config = dict(dataset=dataset_config, transforms=transforms_config)

# Presumably reuses a matching dataset from the W&B project or builds a new
# one from the config — confirm in transform_datasets.utils.wandb.
dataset = load_or_create_dataset(tdataset_config, DATA_PROJECT, ENTITY)

# --- Data loader -----------------------------------------------------------
# Batches of 32 with fraction_val=0.2 held out for validation; split seeded by SEED.
data_loader_config = Config(
    dict(
        type=TrainValLoader,
        params=dict(
            batch_size=32,
            fraction_val=0.2,
            num_workers=1,
            seed=SEED,
        ),
    )
)

data_loader = data_loader_config.build()
data_loader.load(dataset)

In [7]:
# --- Model -----------------------------------------------------------------
# ReLUNet whose input width matches the dataset dimension; hdim=[256]
# (presumably a single hidden layer of 256 units — confirm in ReLUNet).
model_config = Config(
    {
        "type": ReLUNet,
        "params": {
            "size_in": dataset.dim,
            "hdim": [256],
            "seed": SEED,
            # Use the shared DEVICE constant (also handed to the Trainer) rather
            # than a second hard-coded "cuda:0", so the two cannot drift apart.
            "device": DEVICE,
        },
    }
)
model = model_config.build()

# --- Optimizer -------------------------------------------------------------
# Only the config is defined here: the Trainer cell passes `optimizer_config`
# (not an optimizer instance), so the optimizer is presumably built inside the
# Trainer once the model's parameters are on the target device.
optimizer_config = Config({"type": Adam, "params": {"lr": 0.001}})


# --- Regularizer -----------------------------------------------------------
# Two L1-norm penalties on the "out" variable (presumably the model output —
# confirm), weighted 0.1 and 1, combined through a MultiRegularizer.
# NOTE(review): both terms use the same function on the same variable, so if
# MultiRegularizer sums its members (confirm) this equals one L1 term with
# coefficient 1.1. In the recorded runs train_reg_loss (~255) dwarfs
# train_loss (~1.2); check that these weights are intended.
def _l1_out_regularizer_config(coefficient):
    """Build a Regularizer config applying l1_norm to "out" with `coefficient`."""
    return Config(
        {
            "type": Regularizer,
            "params": {
                "function": l1_norm,
                "variables": ["out"],
                "coefficient": coefficient,
            },
        }
    )


regularizer_config1 = _l1_out_regularizer_config(0.1)
regularizer_config2 = _l1_out_regularizer_config(1)

multiregularizer_config = Config(
    {
        "type": MultiRegularizer,
        "params": {
            "regularizer_configs": [regularizer_config1, regularizer_config2],
        },
    }
)
regularizer = multiregularizer_config.build()

# --- Plotters --------------------------------------------------------------
# NOTE(review): gen_UVW_analysis_plots_1D and gen_avg_data_spectrum_plot_1D are
# not imported by name in any visible cell. They either come from one of the
# star imports or from a since-deleted cell (execution counts skip In[2] and
# In[4]-In[6]); import them explicitly so Restart & Run All works.
# NOTE(review): `multiplotter` is not passed to the Trainer in this notebook,
# so these plots are built but apparently never invoked.
def _wandb_plotter_config(function, variables):
    """Build a Plotter config for `function` over `variables`, with f_params use_wandb=True."""
    return Config(
        {
            "type": Plotter,
            "params": {
                "function": function,
                "variables": variables,
                "f_params": {"use_wandb": True},
            },
        }
    )


plotter_config1 = _wandb_plotter_config(gen_UVW_analysis_plots_1D, ["model"])
plotter_config2 = _wandb_plotter_config(gen_avg_data_spectrum_plot_1D, ["X"])

multiplotter_config = Config(
    {
        "type": MultiPlotter,
        "params": {"plotter_configs": [plotter_config1, plotter_config2]},
    }
)
multiplotter = multiplotter_config.build()

# --- Loss ------------------------------------------------------------------
# pytorch-metric-learning ContrastiveLoss: positive pairs are penalised above
# pos_margin (0) and negative pairs below neg_margin (1), with distances from
# LpDistance using the library defaults.
loss_config = Config(
    dict(
        type=losses.ContrastiveLoss,
        params=dict(
            pos_margin=0,
            neg_margin=1,
            distance=distances.LpDistance(),
        ),
    )
)
loss = loss_config.build()

# --- Master config ---------------------------------------------------------
# Handed to WBLogger below as the run's recorded configuration. It includes the
# dataset transforms and the regularizer as well: without them, runs that
# differ only in noise magnitude or regularization weights would be
# indistinguishable in W&B. New keys are appended so existing keys are unchanged.
config = {
    "dataset": dataset_config,
    "model": model_config,
    "optimizer": optimizer_config,
    "loss": loss_config,
    "data_loader": data_loader_config,
    "transforms": transforms_config,
    "regularizer": multiregularizer_config,
}

# --- Logging ---------------------------------------------------------------
# W&B logger for MODEL_PROJECT. log_interval=10 (units as defined by WBLogger —
# confirm); watch_interval is 10 x len(data_loader.train), presumably ten epochs'
# worth of training batches.
logging_config = Config(
    dict(
        type=WBLogger,
        params=dict(
            config=config,
            project=MODEL_PROJECT,
            entity=ENTITY,
            log_interval=10,
            watch_interval=10 * len(data_loader.train),
        ),
    )
)

logger = logging_config.build()


[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [8]:
# --- Trainer ---------------------------------------------------------------
# Wires together the model, loss, logger and regularizer built above. The
# optimizer is supplied as a config rather than an instance.
training_config = Config(
    dict(
        type=Trainer,
        params=dict(
            model=model,
            loss=loss,
            logger=logger,
            regularizer=regularizer,
            device=DEVICE,
            optimizer_config=optimizer_config,
        ),
    )
)

trainer = training_config.build()

In [9]:
# Train from scratch on the loader built above.
# NOTE(review): the recorded output stops around epoch 10, and training was
# continued with `trainer.resume` in the next cell (which reached epoch 110).
# On Restart & Run All the two cells together presumably train ~200 epochs;
# confirm which schedule is intended.
n_train_epochs = 100
trainer.train(data_loader, epochs=n_train_epochs)

Epoch 0 ||  N Examples 0 || Train Total Loss 1220.19434 || Validation Total Loss 591.98187
Epoch 1 ||  N Examples 2560 || Train Total Loss 411.13034 || Validation Total Loss 313.99680
Epoch 2 ||  N Examples 5120 || Train Total Loss 311.36636 || Validation Total Loss 295.79642
Epoch 3 ||  N Examples 7680 || Train Total Loss 289.18088 || Validation Total Loss 301.02658
Epoch 4 ||  N Examples 10240 || Train Total Loss 289.71671 || Validation Total Loss 287.59137
Epoch 5 ||  N Examples 12800 || Train Total Loss 283.72217 || Validation Total Loss 281.79419
Epoch 6 ||  N Examples 15360 || Train Total Loss 276.84799 || Validation Total Loss 282.78635
Epoch 7 ||  N Examples 17920 || Train Total Loss 269.70743 || Validation Total Loss 284.57587
Epoch 8 ||  N Examples 20480 || Train Total Loss 268.85480 || Validation Total Loss 273.43777
Epoch 9 ||  N Examples 23040 || Train Total Loss 265.91498 || Validation Total Loss 290.03711
Epoch 10 ||  N Examples 25600 || Train Total Loss 257.00452 || Val

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,1.19611
train_reg_loss,255.80835
train_total_loss,257.00452
epoch,10.0
n_examples,25600.0
_runtime,16.0
_timestamp,1628733480.0
_step,3.0
val_loss,1.18004
val_reg_loss,259.15884


0,1
train_loss,█▁
train_reg_loss,█▁
train_total_loss,█▁
epoch,▁▁██
n_examples,▁▁██
_runtime,▁▂██
_timestamp,▁▂██
_step,▁▃▆█
val_loss,█▁
val_reg_loss,█▁


In [10]:
# Continue the same trainer for another 100 epochs from wherever it stopped
# (in the recorded run: from epoch 10 to epoch 110).
n_resume_epochs = 100
trainer.resume(data_loader, epochs=n_resume_epochs)

[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch 10 ||  N Examples 28160 || Train Total Loss 248.80884 || Validation Total Loss 258.52383
Epoch 11 ||  N Examples 30720 || Train Total Loss 243.10643 || Validation Total Loss 250.33142
Epoch 12 ||  N Examples 33280 || Train Total Loss 241.74835 || Validation Total Loss 246.22932
Epoch 13 ||  N Examples 35840 || Train Total Loss 235.39255 || Validation Total Loss 246.32678
Epoch 14 ||  N Examples 38400 || Train Total Loss 227.36395 || Validation Total Loss 233.06372
Epoch 15 ||  N Examples 40960 || Train Total Loss 220.58293 || Validation Total Loss 234.19165
Epoch 16 ||  N Examples 43520 || Train Total Loss 214.53622 || Validation Total Loss 229.76579
Epoch 17 ||  N Examples 46080 || Train Total Loss 211.59363 || Validation Total Loss 231.73355
Epoch 18 ||  N Examples 48640 || Train Total Loss 204.53821 || Validation Total Loss 219.04488
Epoch 19 ||  N Examples 51200 || Train Total Loss 202.05453 || Validation Total Loss 229.27260
Epoch 20 ||  N Examples 53760 || Train Total Loss 

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
val_reg_loss,70.75697
val_loss,1.29956
n_examples,284160.0
train_reg_loss,53.26998
_step,25.0
epoch,110.0
_timestamp,1628733593.0
_runtime,87.0
val_total_loss,72.05654
train_loss,1.2908


0,1
train_loss,▁▂▃▅▆▇▇█▇▇▆
train_reg_loss,█▆▅▄▃▂▁▁▁▁▁
train_total_loss,█▆▅▄▃▂▁▁▁▁▁
epoch,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
n_examples,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
_runtime,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
_timestamp,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
_step,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
val_loss,▁▂▂▃▆▆▅█▇▇█
val_reg_loss,█▇▅▄▃▂▁▁▁▁▁
