In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
from spectral_networks.nn.models import ReLUNet
from transform_datasets.patterns.synthetic import *
from transform_datasets.transforms import *
from transform_datasets.utils.wandb import load_or_create_dataset
from torch_tools.trainer import Trainer
from torch_tools.logger import WBLogger
from torch_tools.config import Config
from torch_tools.data import TrainValLoader
from pytorch_metric_learning import losses, distances
from torch.optim import Adam

In [2]:

# Notebook-wide settings: the project/entity names handed to the dataset and
# logging helpers, the compute device, and the seed shared by loader and model.
DATA_PROJECT, MODEL_PROJECT = "dataset", "bispectrum"
ENTITY = "naturalcomputation"
DEVICE = "cuda:0"
SEED = 0

"""
DATASET
"""

# Base pattern set; note it carries its own seed (5), independent of SEED.
dataset_config = Config(
    {"type": HarmonicsS1, "params": {"dim": 256, "n_classes": 10, "seed": 5}}
)

# Transforms keyed "0", "1" — presumably applied in key order (TODO confirm
# against load_or_create_dataset).
transforms_config = {
    "0": Config(
        {
            "type": CyclicTranslation1D,
            "params": {"fraction_transforms": 1.0, "sample_method": "linspace"},
        }
    ),
    "1": Config(
        {"type": UniformNoise, "params": {"n_samples": 1, "magnitude": 0.1}}
    ),
}

# Bundle the pattern config with its transforms and resolve it to a dataset.
tdataset_config = {"dataset": dataset_config, "transforms": transforms_config}
dataset = load_or_create_dataset(tdataset_config, DATA_PROJECT, ENTITY)

"""
DATA_LOADER
"""

# Train/validation loader: batches of 32, fraction_val of 0.2, seeded with SEED.
data_loader_config = Config(
    {
        "type": TrainValLoader,
        "params": {
            "batch_size": 32, "fraction_val": 0.2,
            "num_workers": 1, "seed": SEED,
        },
    }
)
data_loader = data_loader_config.build()
data_loader.load(dataset)



In [3]:

"""
MODEL
"""
# Single-hidden-layer ReLU network whose input width follows the dataset.
model_config = Config(
    {
        "type": ReLUNet,
        "params": {
            "size_in": dataset.dim,
            "hdim": [256],
            "seed": SEED,
            # Use the notebook-wide DEVICE constant (also passed to the Trainer)
            # instead of a second hard-coded literal, so the model and the
            # trainer cannot end up on different devices when DEVICE changes.
            "device": DEVICE,
        },
    }
)
model = model_config.build()

"""
OPTIMIZER
"""
# Only the config is created here: the Trainer cell receives it under
# "optimizer_config", so the optimizer is deliberately not built in this cell.
optimizer_config = Config({"type": Adam, "params": {"lr": 0.001}})

"""
LOSS
"""
# Contrastive loss with margins 0 (positive pairs) and 1 (negative pairs),
# measured with an LpDistance instance created here at config time.
loss_config = Config(
    {
        "type": losses.ContrastiveLoss,
        "params": {
            "pos_margin": 0,
            "neg_margin": 1,
            "distance": distances.LpDistance(),
        },
    }
)
loss = loss_config.build()

"""
MASTER CONFIG
"""

# Collected component configs; passed to the logger below as its "config".
config = {
    "dataset": dataset_config,
    "model": model_config,
    "optimizer": optimizer_config,
    "loss": loss_config,
    "data_loader": data_loader_config,
}

"""
LOGGING
"""
logging_config = Config(
    {
        "type": WBLogger,
        "params": {
            "config": config,
            "project": MODEL_PROJECT,
            "entity": ENTITY,
            "log_interval": 10,
            # 10 x len(data_loader.train); presumably "every 10 passes over
            # the training loader" — TODO confirm the unit WBLogger expects.
            "watch_interval": 10 * len(data_loader.train),
        },
    }
)

logger = logging_config.build()


[34m[1mwandb[0m: Currently logged in as: [33mshewmake[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
"""
TRAINER
"""

# The Trainer receives the already-built model, loss and logger, the target
# device, and the (unbuilt) optimizer config.
training_config = Config(
    {
        "type": Trainer,
        "params": {
            "model": model, "loss": loss, "logger": logger,
            "device": DEVICE, "optimizer_config": optimizer_config,
        },
    }
)
trainer = training_config.build()

In [8]:
# Initial training run on the train/validation loader.
trainer.train(data_loader, epochs=10)

Epoch 0  ||  N Examples 0 || Training Loss: 1.17979  |  Validation Loss: 1.16959
Epoch 1  ||  N Examples 2560 || Training Loss: 1.16935  |  Validation Loss: 1.15554
Epoch 2  ||  N Examples 5120 || Training Loss: 1.15042  |  Validation Loss: 1.15476
Epoch 3  ||  N Examples 7680 || Training Loss: 1.15296  |  Validation Loss: 1.15147
Epoch 4  ||  N Examples 10240 || Training Loss: 1.13832  |  Validation Loss: 1.15255
Epoch 5  ||  N Examples 12800 || Training Loss: 1.13202  |  Validation Loss: 1.13471
Epoch 6  ||  N Examples 15360 || Training Loss: 1.12699  |  Validation Loss: 1.12398
Epoch 7  ||  N Examples 17920 || Training Loss: 1.11811  |  Validation Loss: 1.13606
Epoch 8  ||  N Examples 20480 || Training Loss: 1.11347  |  Validation Loss: 1.11144
Epoch 9  ||  N Examples 23040 || Training Loss: 1.10666  |  Validation Loss: 1.10335
Epoch 10  ||  N Examples 25600 || Training Loss: 1.09435  |  Validation Loss: 1.10278


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_total_loss,1.09435
epoch,10.0
n_examples,25600.0
_runtime,13.0
_timestamp,1628725879.0
_step,3.0
val_total_loss,1.10278


0,1
train_total_loss,█▁
epoch,▁▁██
n_examples,▁▁██
_runtime,▁▂██
_timestamp,▁▂██
_step,▁▃▆█
val_total_loss,█▁


In [9]:
# Continue training with the same trainer and loader.
# NOTE(review): whether epochs=15 means "15 more epochs" or a target epoch
# index is not visible here — confirm against Trainer.resume.
trainer.resume(data_loader, epochs=15)

[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch 10  ||  N Examples 28160 || Training Loss: 1.10065  |  Validation Loss: 1.11592
Epoch 11  ||  N Examples 30720 || Training Loss: 1.09765  |  Validation Loss: 1.08472
Epoch 12  ||  N Examples 33280 || Training Loss: 1.09369  |  Validation Loss: 1.08569
Epoch 13  ||  N Examples 35840 || Training Loss: 1.08431  |  Validation Loss: 1.09072
Epoch 14  ||  N Examples 38400 || Training Loss: 1.08414  |  Validation Loss: 1.08653
Epoch 15  ||  N Examples 40960 || Training Loss: 1.08248  |  Validation Loss: 1.09383
Epoch 16  ||  N Examples 43520 || Training Loss: 1.08329  |  Validation Loss: 1.08461
Epoch 17  ||  N Examples 46080 || Training Loss: 1.06725  |  Validation Loss: 1.09197
Epoch 18  ||  N Examples 48640 || Training Loss: 1.07342  |  Validation Loss: 1.08721
Epoch 19  ||  N Examples 51200 || Training Loss: 1.06250  |  Validation Loss: 1.07853
Epoch 20  ||  N Examples 53760 || Training Loss: 1.07085  |  Validation Loss: 1.09269
Epoch 21  ||  N Examples 56320 || Training Loss: 1.077

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_total_loss,1.07085
_step,8.0
epoch,20.0
_runtime,25.0
_timestamp,1628725905.0
n_examples,53760.0
val_total_loss,1.09269


0,1
train_total_loss,█▁
epoch,▁▁██
n_examples,▁▁██
_runtime,▁▂▆▆█
_timestamp,▁▂▆▆█
_step,▁▃▅▆█
val_total_loss,█▁
