In [None]:
import torch
import yaml
from data import DrumMIDIDataset
from torch.utils.data import DataLoader
import numpy as np
from torchinfo import summary
from models import GrooveIQ
from trainers import GrooveIQ_Trainer
from utils import create_optimizer, create_scheduler, WeightScheduler
import wandb
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Seed every RNG this notebook (and its data pipeline) may draw from, so that
# Restart Kernel -> Run All reproduces the same model init, DataLoader shuffling
# and randomly chosen playback sample. torch.manual_seed seeds all devices.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Config

In [None]:
%%writefile _configs/config_test.yaml

# Run identification. `expt` is passed to the trainer as its run_name in this notebook.
name : "Puru"
expt : "Test"

###### Dataset -----------------------------------------------------------------
data:
  # Serialized splits (4/4, 2-bar windows, 80/10/10 split according to the file names).
  train_path: "dataset/serialized/merged_ts=4-4_2bar_tr0.80-va0.10-te0.10_train.pkl"
  val_path: "dataset/serialized/merged_ts=4-4_2bar_tr0.80-va0.10-te0.10_val.pkl"
  test_path: "dataset/serialized/merged_ts=4-4_2bar_tr0.80-va0.10-te0.10_test.pkl"
  num_bars: 2
  feature_type: "fixed"        # one of: fixed, flexible
  steps_per_quarter: 4
  subset: 0.01                 # fraction of each split to load
  num_workers: 0               # DataLoader worker processes
  batch_size: 16               # samples per batch
  # Augmentation settings, handed to DrumMIDIDataset for every split.
  aug_config:
    win_sizes:
      - size: 2
        prob: 0.75
      - size: 3
        prob: 0.125
      - size: 4
        prob: 0.125
    velocity_range:
      min: 0.45
      max: 0.80
    max_hits_per_win: 2
    win_retain_prob: 1.0
    num_buttons: 2             # NOTE(review): presumably must match model.num_buttons — confirm
    timing_jitter: 0.05
    velocity_jitter: 0.05
    miss_prob: 0.05
    spurious_prob: 0.1
  calc_desc: False

###### Network Specs -------------------------------------------------------------
# The notebook adds T, E and M (sequence length and grid sizes) to this mapping at runtime.
model:
  z_dim: 16
  embed_dim: 32
  encoder_depth: 1
  encoder_heads: 1
  decoder_depth: 1
  decoder_heads: 1
  num_buttons: 2        # NOTE(review): presumably must match data.aug_config.num_buttons — confirm
  is_causal: True
  p: 0.5                # probability of using z_post instead of z_prior
  button_penalty: 1     # 1: L1, 2: group penalty (L1 over T of L2(D))
  button_dropout: 0.2   # dropout probability for button hits

###### Common Training Parameters ------------------------------------------------
training:
  config_file: "_configs/config_test.yaml"   # this file's own path, handed to the trainer
  use_wandb: False                 # toggle wandb logging
  wandb_run_id: "none"             # "none" or an existing run id
  resume: True                     # resume an existing run (requires wandb_run_id != 'none')
  gradient_accumulation_steps: 1
  wandb_project: "thesis"          # wandb project to log to

###### Loss ----------------------------------------------------------------------
loss:
  pos_weight: 21.0              # weight of the positive class in the hit loss
  hit_penalty: 21.0             # hit penalty for the hit loss (21 for the train set)
  threshold: 0.5                # decision threshold for hit prediction
  recons_weight: 1.0            # weight of the reconstruction loss
  kld_weight: 0.01              # weight of the KLD loss
  distill_weight: 0.01          # weight of the distillation loss
  button_penalty_weight: 0.0    # weight of the button penalty loss

###### Optimizer -----------------------------------------------------------------
optimizer:
  name: "adamw" # Options: sgd, adam, adamw
  lr: 0.0004    # Base learning rate

  # Common parameters
  weight_decay: 0.000001

  # Parameter groups
  # Add as many groups as needed, each with its own name patterns and learning rate.
  param_groups:
    - name: self_attn
      patterns: []  # NOTE(review): empty, so this group presumably matches no parameters; list name substrings to use it
      lr: 0.0002    # LR for self_attn
      layer_decay:
        enabled: False
        decay_rate: 0.8

    - name: ffn
      patterns: []  # NOTE(review): empty as well, so presumably matches no parameters
      lr: 0.0002  # LR for ffn
      layer_decay:
        enabled: False
        decay_rate: 0.8

  # Layer-wise learning rates
  layer_decay:
    enabled: False
    decay_rate: 0.75

  # SGD specific parameters
  sgd:
    momentum: 0.9
    nesterov: True
    dampening: 0

  # Adam specific parameters
  adam:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

  # AdamW specific parameters
  adamw:
    betas: [0.9, 0.999]
    eps: 1.0e-8
    amsgrad: False

###### Scheduler -----------------------------------------------------------------
scheduler:
  name: "cosine"  # Options: reduce_lr, cosine, cosine_warm

  # ReduceLROnPlateau specific parameters
  reduce_lr:
    mode: "min"  # Options: min, max
    factor: 0.1  # Factor to reduce learning rate by
    patience: 10  # Number of epochs with no improvement after which LR will be reduced
    threshold: 0.0001  # Threshold for measuring the new optimum
    threshold_mode: "rel"  # Options: rel, abs
    cooldown: 0  # Number of epochs to wait before resuming normal operation
    min_lr: 0.0000001  # Minimum learning rate
    # Keep the decimal point: PyYAML (YAML 1.1) loads a bare "1e-8" as the string "1e-8",
    # not a float, which would break numeric comparisons in the scheduler.
    eps: 1.0e-8  # Minimal decay applied to lr

  # CosineAnnealingLR specific parameters
  cosine:
    T_max: 15  # Maximum number of iterations
    eta_min: 0.0000001  # Minimum learning rate
    last_epoch: -1

  # CosineAnnealingWarmRestarts specific parameters
  cosine_warm:
    T_0: 10    # Number of iterations for the first restart
    T_mult: 10 # Factor increasing T_i after each restart
    eta_min: 0.0000001  # Minimum learning rate
    last_epoch: -1

  # Warmup parameters (can be used with any scheduler)
  # NOTE(review): the notebook trains for 5 epochs by default; if warmup is counted in
  # epochs, a 5-epoch warmup spans the whole run — confirm this is intended.
  warmup:
    enabled: True
    type: "exponential"  # Options: linear, exponential
    epochs: 5
    start_factor: 0.1
    end_factor: 1.0

###### KL Weight Scheduler -----------------------------------------------------------------
# When `enabled` is true, the trainer-setup cell passes `params` as keyword arguments to WeightScheduler.
kl_weight_scheduler:
  enabled: True
  params:
    weight_max: 0.2
    total_epochs: 5
    zero_epochs: 0
    warmup_epochs: 2

In [None]:
# Read back the YAML written by the %%writefile cell above (default text mode).
with open('_configs/config_test.yaml') as config_stream:
    config = yaml.safe_load(config_stream)

# Datasets and DataLoaders

In [None]:
## Load Datasets
def build_dataset(data_cfg: dict, split: str):
    """Build the DrumMIDIDataset for one split from the shared data config.

    Args:
        data_cfg: the `config["data"]` mapping loaded from the YAML file.
        split: split prefix ("train", "val" or "test"); selects `data_cfg[f"{split}_path"]`.

    Returns:
        The constructed DrumMIDIDataset.
    """
    return DrumMIDIDataset(
        path              = data_cfg[f"{split}_path"],
        num_bars          = data_cfg["num_bars"],
        feature_type      = data_cfg["feature_type"],
        steps_per_quarter = data_cfg["steps_per_quarter"],
        subset            = data_cfg["subset"],
        # Same augmentation settings for every split, as before.
        aug_config        = data_cfg["aug_config"],
        # Honour the YAML flag instead of hard-coding False (default kept for old configs).
        calc_desc         = data_cfg.get("calc_desc", False),
    )

train_dataset = build_dataset(config["data"], "train")
val_dataset   = build_dataset(config["data"], "val")
test_dataset  = build_dataset(config["data"], "test")

## Create DataLoaders
def build_loader(dataset, data_cfg: dict, shuffle: bool) -> DataLoader:
    """Wrap a dataset in a DataLoader using the shared batch/worker settings.

    Args:
        dataset: dataset exposing a `collate_fn` attribute (as DrumMIDIDataset does here).
        data_cfg: the `config["data"]` mapping (reads `batch_size` and `num_workers`).
        shuffle: whether to reshuffle every epoch; only the training split should.

    Returns:
        The configured DataLoader.
    """
    return DataLoader(
        dataset,
        batch_size  = data_cfg["batch_size"],
        num_workers = data_cfg["num_workers"],
        shuffle     = shuffle,
        collate_fn  = dataset.collate_fn,
        # Pinned host memory only speeds up host->GPU copies; on CPU PyTorch just warns.
        pin_memory  = device == "cuda",
    )

train_loader = build_loader(train_dataset, config["data"], shuffle=True)
# Evaluation splits are iterated in a fixed order so metrics and logged samples are reproducible.
val_loader   = build_loader(val_dataset, config["data"], shuffle=False)
test_loader  = build_loader(test_dataset, config["data"], shuffle=False)

## Test a sample
# Pull exactly one batch; `grid` from this batch is reused by the model cell below.
batch = next(iter(train_loader))
print(f"Batch size: {len(batch['samples'])}")
print(f"Grid shape: {batch['grid'].shape}")
if batch['button_hvo'] is not None:
    print(f"Button HVO shape: {batch['button_hvo'].shape}")
else:
    print("No button HVO")
if batch['labels'] is not None:
    print(f"Label shape: {batch['labels'].shape}")
else:
    print("No labels")
grid = batch['grid']

# Listen to one random sample of the batch (and its button rendition, if present).
random_idx = np.random.randint(len(batch['samples']))
sample = batch['samples'][random_idx]
sample.feature.play()
if batch['button_hvo'] is not None:
    button_hvo = batch['button_hvo'][random_idx]
    # Use the configured resolution rather than a hard-coded 4 steps per quarter.
    button_hvo_feature = sample.feature.from_button_hvo(
        button_hvo, steps_per_quarter=config["data"]["steps_per_quarter"]
    )
    # NOTE(review): play_button_hvo is called on the feature and also receives it as
    # its argument — confirm this is the intended call.
    button_hvo_feature.play_button_hvo(button_hvo_feature)

In [None]:
# 4/4 time signature (the serialized splits are the ts=4-4 subset).
NUM_QUARTERS_PER_BAR = 4
steps_per_bar = NUM_QUARTERS_PER_BAR * config["data"]["steps_per_quarter"]
# NOTE(review): the trailing "+ 1" adds one position beyond the bars' grid steps; its
# purpose (e.g. a start token) is not visible in this notebook — confirm.
MAX_LENGTH = config["data"]["num_bars"] * steps_per_bar + 1
print(f"Max length: {MAX_LENGTH}")

# Model

In [None]:
# Complete the model hyper-parameters with sizes derived from the data.
# `update` mutates config["model"] in place, so T/E/M are also present in `config`,
# which is handed to the trainer in the next cell.
model_config = config["model"]
model_config.update(
    T=MAX_LENGTH,     # sequence length computed above
    E=grid.shape[2],  # NOTE(review): assumes grid is (batch, T, E, M) — confirm axis meaning
    M=grid.shape[3],
)

inputs = [grid]
# Place the model on the target device explicitly instead of relying on
# torchinfo.summary moving it as a side effect.
model = GrooveIQ(**model_config).to(device)
summary(model, input_data = inputs, device = device)

# Training

In [None]:
train_cfg = config['training']

# Trainer owns the model, full config and device; the run is named after `expt`.
trainer = GrooveIQ_Trainer(
    model=model,
    config=config,
    run_name=config["expt"],
    config_file=train_cfg['config_file'],
    device=device,
)

# The optimizer must be attached before the scheduler, which is built on trainer.optimizer.
optimizer = create_optimizer(model=model, opt_config=config['optimizer'])
trainer.set_optimizer(optimizer)

scheduler = create_scheduler(
    optimizer=trainer.optimizer,
    scheduler_config=config['scheduler'],
    train_loader=train_loader,
    gradient_accumulation_steps=train_cfg['gradient_accumulation_steps'],
)
trainer.set_scheduler(scheduler)

# Optional KL-weight annealing, configured entirely from the YAML `params` mapping.
kl_cfg = config['kl_weight_scheduler']
if kl_cfg['enabled']:
    trainer.set_kl_weight_scheduler(WeightScheduler(**kl_cfg['params']))

In [None]:
# Epoch count is read from `training.epochs` in the YAML when present; 5 keeps the previous default.
NUM_EPOCHS = config['training'].get('epochs', 5)
trainer.train(train_loader, val_loader, epochs=NUM_EPOCHS)