In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from spectral_networks.nn.models import ReLUNet
from spectral_networks.nn.cplx.models import BispectralEmbedding
from transform_datasets.patterns.synthetic import *
from transform_datasets.transforms import *
from transform_datasets.utils.wandb import load_or_create_dataset
from torch_tools.trainer import Trainer
from torch_tools.logger import WBLogger
from torch_tools.config import Config
from torch_tools.regularizer import Regularizer, MultiRegularizer
from torch_tools.functional import l1_norm

from torch_tools.data import TrainValLoader
from pytorch_metric_learning import losses, distances
from torch.optim import Adam

from torch_tools.plotter import Plotter, MultiPlotter
from spectral_networks.analysis.wandb import gen_UVW_analysis_plots_1D, gen_train_val_spectrum_plot_1d
from spectral_networks.analysis.wandb import gen_avg_data_spectrum_plot_2D, gen_UVW_analysis_plots_2D

In [3]:
# Turn matplotlib's interactive mode off so figures created later (e.g. by the
# logger's end-of-run plotter) are presumably not auto-displayed; the final
# expression echoes the resulting interactive state as the cell output.
import matplotlib as mpl 
import matplotlib.pyplot as plt
plt.ioff()
mpl.is_interactive()


False

In [4]:
# Run-wide settings: W&B projects/entity, compute device and global seed.
DATA_PROJECT = "dataset"
MODEL_PROJECT = "bispectrum"
ENTITY = "naturalcomputation"
DEVICE = "cpu"
SEED = 0

# --- DATASET ----------------------------------------------------------------
# Base signals: HarmonicsS1 patterns. The dataset carries its own seed (5),
# independent of SEED above.
harmonics_params = {"dim": 256, "n_classes": 10, "seed": 5}
dataset_config = Config({"type": HarmonicsS1, "params": harmonics_params})

# Transform pipeline keyed "0", "1" (presumably applied in key order - confirm):
# cyclic 1-D translations sampled with linspace, then uniform noise of
# magnitude 0.1.
translation_params = {"fraction_transforms": 1.0, "sample_method": "linspace"}
noise_params = {"n_samples": 1, "magnitude": 0.1}
transforms_config = {
    "0": Config({"type": CyclicTranslation1D, "params": translation_params}),
    "1": Config({"type": UniformNoise, "params": noise_params}),
}

# The combined dataset + transforms config is what load_or_create_dataset
# receives, together with the W&B data project and entity.
tdataset_config = {"dataset": dataset_config, "transforms": transforms_config}
dataset = load_or_create_dataset(tdataset_config, DATA_PROJECT, ENTITY)

# --- DATA LOADER ------------------------------------------------------------
# Train/validation loader: 20% held out for validation, batches of 32,
# seeded with SEED.
loader_params = {
    "batch_size": 32,
    "fraction_val": 0.2,
    "num_workers": 1,
    "seed": SEED,
}
data_loader_config = Config({"type": TrainValLoader, "params": loader_params})
data_loader = data_loader_config.build()
data_loader.load(dataset)

In [5]:
# --- MODEL ------------------------------------------------------------------
# Bispectral embedding whose input size matches the dataset's signal dimension.
model_config = Config(
    {
        "type": BispectralEmbedding,
        "params": {
            "size_in": dataset.dim,
            "hdim": 256,
            "seed": SEED,
            "device": DEVICE,
        },
    }
)
model = model_config.build()

# --- OPTIMIZER --------------------------------------------------------------
# Only the config is created here: it is handed to the Trainer, which
# presumably builds the optimizer against the model parameters itself.
optimizer_config = Config({"type": Adam, "params": {"lr": 0.001}})

# --- REGULARIZER ------------------------------------------------------------
# NOTE(review): both terms apply l1_norm to the same "out" variable; if
# MultiRegularizer sums its members this equals a single L1 penalty with
# coefficient 0.11. Confirm whether the second term was meant to target a
# different variable.
regularizer_config1 = Config(
    {
        "type": Regularizer,
        "params": {"function": l1_norm, "variables": ["out"], "coefficient": 0.1},
    }
)
regularizer_config2 = Config(
    {
        "type": Regularizer,
        "params": {"function": l1_norm, "variables": ["out"], "coefficient": 0.01},
    }
)
multiregularizer_config = Config(
    {
        "type": MultiRegularizer,
        "params": {"regularizer_configs": [regularizer_config1, regularizer_config2]},
    }
)
regularizer = multiregularizer_config.build()

# --- LOSS -------------------------------------------------------------------
# Contrastive loss on Lp distances: positive pairs pulled to distance 0,
# negative pairs pushed beyond margin 1.
loss_config = Config(
    {
        "type": losses.ContrastiveLoss,
        "params": {
            "pos_margin": 0,
            "neg_margin": 1,
            "distance": distances.LpDistance(),
        },
    }
)
loss = loss_config.build()

# --- MASTER CONFIG ----------------------------------------------------------
# Passed to the W&B logger as the run config. The transforms and regularizer
# are included so that runs differing only in noise/translation settings or
# L1 coefficients remain distinguishable and reproducible from the logged run.
config = {
    "dataset": dataset_config,
    "transforms": transforms_config,
    "model": model_config,
    "optimizer": optimizer_config,
    "regularizer": multiregularizer_config,
    "loss": loss_config,
    "data_loader": data_loader_config,
}

# --- PLOTTER ----------------------------------------------------------------
# End-of-run plot of the train/validation spectrum, computed from data_loader.
end_plotter_config = Config(
    {
        "type": Plotter,
        "params": {
            "function": gen_train_val_spectrum_plot_1d,
            "variables": ["data_loader"],
            "f_params": {},
        },
    }
)
end_plotter = end_plotter_config.build()

# --- LOGGING ----------------------------------------------------------------
logging_config = Config(
    {
        "type": WBLogger,
        "params": {
            "config": config,
            "project": MODEL_PROJECT,
            "entity": ENTITY,
            # NOTE(review): these small fixed intervals look like debugging
            # values; earlier comments suggested len(data_loader.train) - confirm.
            "log_interval": 2,
            "watch_interval": 1,
            "end_plotter": end_plotter,
        },
    }
)

logger = logging_config.build()



Setting attributes on ParameterDict is not supported.

[34m[1mwandb[0m: Currently logged in as: [33mshewmake[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [6]:
# --- TRAINER ----------------------------------------------------------------
# Wire the already-built model, loss, logger and regularizer into a Trainer.
# The optimizer is supplied as a config rather than an instance (presumably
# built inside the Trainer - confirm in torch_tools.trainer).
trainer_params = {
    "model": model,
    "loss": loss,
    "logger": logger,
    "regularizer": regularizer,
    "device": DEVICE,
    "optimizer_config": optimizer_config,
}
training_config = Config({"type": Trainer, "params": trainer_params})
trainer = training_config.build()

In [7]:
# Run training; per-epoch train/validation losses are printed below and the
# run is logged through the W&B logger configured above.
N_EPOCHS = 10
trainer.train(data_loader, epochs=N_EPOCHS)

Epoch 0 ||  N Examples 0 || Train Total Loss 146.63321 || Validation Total Loss 29.74714
Epoch 1 ||  N Examples 2560 || Train Total Loss 16.41786 || Validation Total Loss 10.43725
Epoch 2 ||  N Examples 5120 || Train Total Loss 7.55866 || Validation Total Loss 6.33664
Epoch 3 ||  N Examples 7680 || Train Total Loss 5.14227 || Validation Total Loss 4.74386
Epoch 4 ||  N Examples 10240 || Train Total Loss 4.07958 || Validation Total Loss 3.94973
Epoch 5 ||  N Examples 12800 || Train Total Loss 3.51584 || Validation Total Loss 3.49008
Epoch 6 ||  N Examples 15360 || Train Total Loss 3.17452 || Validation Total Loss 3.19145
Epoch 7 ||  N Examples 17920 || Train Total Loss 2.95183 || Validation Total Loss 2.99191
Epoch 8 ||  N Examples 20480 || Train Total Loss 2.79136 || Validation Total Loss 2.84261
Epoch 9 ||  N Examples 23040 || Train Total Loss 2.67384 || Validation Total Loss 2.72732
Epoch 10 ||  N Examples 25600 || Train Total Loss 2.58445 || Validation Total Loss 2.64069


VBox(children=(Label(value=' 0.03MB of 0.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,1.41363
train_reg_loss,1.17082
train_total_loss,2.58445
epoch,10.0
n_examples,25600.0
_runtime,18.0
_timestamp,1628808371.0
_step,22.0
val_loss,1.41295
val_reg_loss,1.22774


0,1
train_loss,▁▅██▇▇██▇▇▇
train_reg_loss,█▂▁▁▁▁▁▁▁▁▁
train_total_loss,█▂▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
n_examples,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇██
_runtime,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▆▇▇▇▇▇█
_timestamp,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▆▇▇▇▇▇█
_step,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇██
val_loss,██▃▇▄▇▁▅▇▄▄
val_reg_loss,█▃▂▂▁▁▁▁▁▁▁


In [8]:
# trainer.resume(data_loader, epochs=20)