In [14]:
from myTransformer import *
from pedalboard_param_utils import *

dac_model_path = dac.utils.download(model_type="44khz")
dac_model = dac.DAC.load(dac_model_path, weights_only=True).to('cuda')
model_dim = 64
key_dim = 32
num_heads = 8
enc_out_dim = 9
ffn_hidden_dim = 32
num_stack=4
num_vocab = len(vocab)

my_model = MyModel(
    dac=dac_model, 
    num_vocab=num_vocab, 
    model_dim=model_dim, 
    key_dim=key_dim, 
    enc_out_dim=enc_out_dim, 
    ffn_hidden_dim=ffn_hidden_dim, 
    num_heads=num_heads,
    num_stack=num_stack
).to("cuda")


  model_dict = torch.load(location, "cpu")
  WeightNorm.apply(module, name, dim)


In [17]:
from EGFxSetDataParam import *
from ICMTSMTGuitarDataParam import * 
from torch.utils.data import ConcatDataset, DataLoader

egfx_data = EGFxSetData(pedal_dict)
icmt_mono = ICMTSMTGuitarDataMono(pedal_dict)
icmt_poly = ICMTSMTGuitarDataPoly(pedal_dict)

batch_size = 32

combined = ConcatDataset([egfx_data, icmt_mono, icmt_poly])
loader = DataLoader(dataset=combined, batch_size=32, num_workers=4, shuffle=True, collate_fn=collate_data)


In [18]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

lr = 0.0002
optimizer = Adam(my_model.parameters(), lr=lr)
criterion = CrossEntropyLoss(ignore_index=0)


In [26]:
from torch.utils.tensorboard import SummaryWriter
import datetime 
torch.manual_seed(0)

run_dir = "logs"
!rm -rf ./logs
run_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
logger = SummaryWriter(log_dir=Path(run_dir) / run_time, flush_secs=20)
epochs = 20
device = "cuda"

In [27]:
for epoch in range(epochs):
    
    # weight batch losses/scores proportional to batch size
    iter_count = 0
    loss_epoch = 0
    
    for batch_idx, batch_data in enumerate(loader):
        audio, pedal_str = batch_data
        my_model.zero_grad()
        
        # train on a batch of inputs
        logits = my_model(audio.to(device), pedal_str[:, :-1].to(device)).transpose(1,2)
        loss_batch = criterion(logits, pedal_str[:, 1:].to(device))
        loss_batch.backward()
        optimizer.step()
        
        # log loss
        loss_epoch += loss_batch.detach().item() * batch_size
        iter_count += batch_size
    
    # plot loss
    loss_epoch /= iter_count
    logger.add_scalar("cross_entropy_loss", loss_epoch, epoch)
    
    if not epoch % 10:
        print(f"Epoch: {epoch + 1}\tCross Entropy Loss: {loss_epoch :0.4f}")

Epoch: 1	Cross Entropy Loss: 242.5048
Epoch: 11	Cross Entropy Loss: 79.2807


In [35]:
loader_iter = iter(loader)
source, target = next(loader_iter)
predictions = torch.tensor([vocab.to_num(["<start>"]) for _ in range(32)]).to(device)
max_len = 32
while (predictions.shape[-1] < max_len):  
    logits = my_model(source.to(device), target[:, -1:].to(device))[:, -1:, :]
    print(logits.shape)
    next_word_indices = torch.argmax(logits, dim=-1) # [batch_size, 1, num_vocabs]
    
    predictions = torch.cat((predictions, next_word_indices), dim=1)



torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])
torch.Size([32, 1, 83])


In [36]:
for prediction in predictions.detach(): 
    print(vocab.to_str(prediction.cpu().numpy()))

['<start>', '<end>', '<sep>', 'reverb_dry_level_1', 'compressor_release_ms_50', 'compressor_ratio_1.5', 'phaser_centre_frequency_hz_600', 'compressor_release_ms_500', 'phaser_centre_frequency_hz_1500', 'phaser_mix_0.5', 'compressor_release_ms_500', 'reverb_damping_0', 'phaser_depth_1', 'chorus_centre_delay_ms_20', 'chorus_depth_0.15', 'reverb_dry_level_1', 'compressor_release_ms_50', 'compressor_release_ms_150', 'phaser_centre_frequency_hz_300', 'chorus_centre_delay_ms_10', 'reverb_room_size_1', 'compressor_release_ms_300', 'compressor_release_ms_50', 'phaser_depth_0.7', 'phaser_depth_0.7', 'chorus_mix_0.5', '<end>', '<sep>', 'chorus_feedback_0.2', 'phaser_centre_frequency_hz_300', 'phaser_depth_0.4', 'phaser_centre_frequency_hz_1500']
['<start>', 'phaser_mix_0.5', '<sep>', 'chorus_feedback_0.4', '<sep>', 'phaser_centre_frequency_hz_1500', 'chorus_depth_0.25', '<sep>', 'compressor_ratio_2', 'phaser_feedback_0.6', 'compressor_attack_ms_70', 'chorus_feedback_0.4', 'compressor_attack_ms_1