# Neural Amp Modeler: Online Training Notebook
**Note**:
This notebook is meant to be used on [Google Colab](https://colab.research.google.com/github/sdatkinson/neural-amp-modeler/blob/main/bin/train/easy_colab.ipynb).

🔶**Before you run**🔶

Make sure to get a GPU! (From the menu at the top, click Runtime → Change runtime type, then select a GPU from the "Hardware accelerator" dropdown menu.)

## Step 1: Get data
* **Download the reamp signal.** Here: [v3_0_0.wav](https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link).
* **Reamp your gear.** Reamp the gear you want to model with that signal, and save the recording as "output.wav". *Note: Use 48 kHz, 24-bit, mono.* For other sample rates, use [the CLI trainer](https://github.com/sdatkinson/neural-amp-modeler).
* **Upload your files.** Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon or by dragging the files into the panel.

## Step 2: Train!
Configure your training run below, then hit the Play button to start training!

🕙NOTE: At default settings, training will take about 10 minutes.🕙

In [None]:
try:
    import nam
except ImportError as e:
    print("Installing NAM into Colab. This should take under 2 minutes.")
    !if [ ! -d logs ]; then mkdir logs; fi
    # Check what we're starting with
    !pip list > logs/packages.log
    # Then install
    # !pip install neural-amp-modeler > logs/install.log
    # Hint: use the next line instead for the very latest!
    # FIXME
    !pip install git+https://github.com/sdatkinson/neural-amp-modeler.git@main > logs/install.log

from functools import partial
from pathlib import Path
from typing import Optional, Tuple

from nam.models.metadata import UserMetadata
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION, Version
from nam.train._version import PROTEUS_VERSION, Version  # FIXME
from nam.train.core import TrainOutput, train  # FIXME
from nam.train.metadata import TRAINING_KEY

from nam.models.metadata import GearType, ToneType, UserMetadata


def run(
    epochs: int = 100,
    delay: Optional[int] = None,
    model_type: str = "WaveNet",
    architecture: str = "standard",
    lr: float = 0.004,
    lr_decay: float = 0.007,
    seed: Optional[int] = 0,
    user_metadata: Optional[UserMetadata] = None,
    ignore_checks: bool = False,
    fit_cab: bool = False,
):
    """
    Wrapper around the core trainer functionality.

    Looks for a NAM input signal and the reamped ``output.wav`` in the current
    directory, trains a model on them, and (if a model is returned) exports it to
    a fresh ``exported_models/version_<N>`` directory.

    :param epochs: How many epochs we'll train for.
    :param delay: How far the output lags the input due to round-trip latency during
        reamping, in samples. ``None`` passes no explicit latency to the trainer.
    :param model_type: Model type passed to the trainer (e.g. "WaveNet").
    :param architecture: Architecture preset passed to the trainer (e.g. "standard").
    :param lr: The initial learning rate
    :param lr_decay: The amount by which the learning rate decays each epoch
    :param seed: RNG seed for reproducibility.
    :param user_metadata: User-specified metadata to include in the .nam file.
    :param ignore_checks: Ignores the data quality checks and YOLOs it
    :param fit_cab: Passed through to the trainer as ``fit_cab`` (presumably enables
        cabinet fitting -- confirm against ``nam.train.core.train``).
    :raises RuntimeError: If a known-buggy input signal is found.
    :raises FileNotFoundError: If no input signal or no ``output.wav`` is found.
    """

    BUGGY_INPUT_BASENAMES = {
        # 1.1.0 has the spikes at the wrong spots.
        "v1_1_0.wav"
    }
    OUTPUT_BASENAME = "output.wav"
    TRAIN_PATH = "."

    def check_for_files() -> Tuple[Version, str]:
        """
        Locate the input signal and the reamped output in the working directory.

        :return: ``(input_version, input_basename)`` of the input file to use.
        :raises RuntimeError: If a known-buggy input signal is present.
        :raises FileNotFoundError: If no input signal or no output file is present.
        """
        # TODO use hash logic as in GUI trainer!
        print("Checking that we have all of the required audio files...")
        for name in BUGGY_INPUT_BASENAMES:
            if Path(name).exists():
                raise RuntimeError(
                    f"Detected input signal {name} that has known bugs. Please download the latest input signal, {LATEST_VERSION.name}"
                )
        # The first entry of INPUT_BASENAMES whose file exists is the one we use.
        for input_version, input_basename in INPUT_BASENAMES:
            if Path(input_basename).exists():
                if input_version == PROTEUS_VERSION:
                    print("Using Proteus input file...")
                elif input_version != LATEST_VERSION.version:
                    print(
                        f"WARNING: Using out-of-date input file {input_basename}. "
                        "Recommend downloading and using the latest version, "
                        f"{LATEST_VERSION.name}."
                    )
                break
        else:
            # Loop finished without `break`: none of the known inputs exist.
            raise FileNotFoundError(
                f"Didn't find NAM's input audio file. Please upload {LATEST_VERSION.name}"
            )
        if not Path(OUTPUT_BASENAME).exists():
            raise FileNotFoundError(
                f"Didn't find your reamped output audio file. Please upload {OUTPUT_BASENAME}."
            )
        if input_version != PROTEUS_VERSION:
            print(f"Found {input_basename}, version {input_version}")
        else:
            print(f"Found Proteus input {input_basename}.")
        return input_version, input_basename

    def get_valid_export_directory() -> Path:
        """Return the first ``exported_models/version_<N>`` path that doesn't exist yet."""

        def get_path(version):
            return Path("exported_models", f"version_{version}")

        version = 0
        while get_path(version).exists():
            version += 1
        return get_path(version)

    input_version, input_basename = check_for_files()

    train_output: TrainOutput = train(
        input_basename,
        OUTPUT_BASENAME,
        TRAIN_PATH,
        input_version=input_version,
        epochs=epochs,
        latency=delay,
        model_type=model_type,
        architecture=architecture,
        lr=lr,
        lr_decay=lr_decay,
        seed=seed,
        local=False,
        ignore_checks=ignore_checks,
        fit_cab=fit_cab,
    )
    model = train_output.model
    training_metadata = train_output.metadata

    if model is None:
        print("No model returned; skip exporting!")
    else:
        print("Exporting your model...")
        model_export_outdir = get_valid_export_directory()
        # exist_ok=False: the directory was just chosen to be new, so fail loudly
        # rather than overwrite a previous export.
        model_export_outdir.mkdir(parents=True, exist_ok=False)
        model.net.export(
            model_export_outdir,
            user_metadata=user_metadata,
            other_metadata={TRAINING_KEY: training_metadata.model_dump()},
        )
        print(f"Model exported to {model_export_outdir}. Enjoy!")


%load_ext tensorboard

import ipywidgets as widgets

# NOTE: Enums need to be handled carefully since the values need to be supplied literally here!
# The `#@markdown` / `#@param` comments below are Colab form annotations: Colab renders
# each annotated assignment as a form field, so edit them via the form when possible.

#@markdown # Training parameters
# Number of training epochs, forwarded to run(epochs=...).
epochs = 100 #@param {type: "number"}
# Architecture preset, forwarded to run(architecture=...).
architecture = "standard"  #@param ["standard", "lite", "feather", "nano"] {type: "string"}
# Either "auto" or the reamp latency as an integer number of samples (see _parse_latency).
latency_samples = "auto"  #@param {type: "string"}
fit_cab = False  #@param {type: "boolean"}
# Skip the trainer's data quality checks.
ignore_checks = False #@param {type: "boolean"}

#@markdown # Metadata
# The fields below are only used when use_metadata is True. gear_type and tone_type are
# lowercased and converted to GearType / ToneType by value, so they must match an enum value.
use_metadata = False #@param {type: "boolean"}
name = "My model" #@param {type:"string"}
modeled_by = "Your name" #@param {type:"string"}
gear_make = "GearCo" #@param {type:"string"}
gear_model = "GearName" #@param {type:"string"}
gear_type = "amp" #@param ["amp", "pedal", "pedal_amp", "amp_cab", "amp_pedal_cab", "preamp", "studio"] {type:"string"}
tone_type = "clean" #@param ["clean", "overdrive", "crunch", "hi_gain", "fuzz"] {type:"string"}

def _verbose_enum(E, val):
    try:
        return E(val)
    except ValueError as e:
        raise ValueError(
            str(e)
            + "\nValid choices are: "
            + ", ".join(list(x.value for x in E))
        )

def _parse_latency(ls: str):
    if ls.lower() == "auto":
        return None
    try:
        return int(ls)
    except ValueError as e:
        raise ValueError(
            f"Invalid value for latency {ls} was given. Either use 'auto' or provide "
            f"the reamp latency, in samples.\nOriginal error:\n\n{e}"
        )

user_metadata = None if not use_metadata else UserMetadata(
    name=name,
    modeled_by=modeled_by,
    gear_make=gear_make,
    gear_model=gear_model,
    gear_type=_verbose_enum(GearType, gear_type.lower()),
    tone_type=_verbose_enum(ToneType, tone_type.lower())
)
run_partial = partial(run, user_metadata=user_metadata)

%tensorboard --logdir /content/lightning_logs
run(
    epochs=epochs,
    architecture=architecture,
    fit_cab=fit_cab,
    ignore_checks=ignore_checks,
    user_metadata=user_metadata,
    delay=_parse_latency(latency_samples)
)

## Step 3: Check the results and download your model
We're done!

Have a look at the plot above to see how your model compares to the real gear you're modeling.
Hopefully it looks good!
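Go to the file browser on the left panel ⬅ and download `model.nam` from the `exported_models/version_<N>` directory printed at the end of training (each run gets a new, higher-numbered `version_<N>` folder; you may need to hit the refresh button).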

Additionally, if you want to continue to train this model later you can download the lightning model artifacts from `lightning_logs`. If not, that's fine too.

# 🎸 **ENJOY!** 🎸