diff --git a/scripts/train.py b/scripts/train.py index 33fa8b5..68ed5ee 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -17,7 +17,7 @@ import sys import torchaudio import torchinfo -from contextlib import contextmanager +from contextlib import nullcontext import sklearn.preprocessing from torch.autograd import Variable from torch.utils.tensorboard import SummaryWriter @@ -32,56 +32,36 @@ tqdm.monitor_interval = 0 -def train(args, unmix, encoder, device, train_sampler, criterion, optimizer): +def loop(args, unmix, encoder, device, sampler, criterion, optimizer, train=True): # unpack encoder object nsgt, _, cnorm = encoder losses = utils.AverageMeter() - unmix.train() - pbar = tqdm.tqdm(train_sampler, disable=args.quiet) - - for x, y_bass, y_vocals, y_other, y_drums in pbar: - pbar.set_description("Training batch") - x, y_bass, y_vocals, y_other, y_drums = x.to(device), y_bass.to(device), y_vocals.to(device), y_other.to(device), y_drums.to(device) - optimizer.zero_grad() - X = nsgt(x) - Xmag = cnorm(X) - Ymag_bass = cnorm(nsgt(y_bass)) - Ymag_vocals = cnorm(nsgt(y_vocals)) - Ymag_drums = cnorm(nsgt(y_drums)) - Ymag_other = cnorm(nsgt(y_other)) - - loss = criterion( - *unmix(Xmag), # forward call to unmix returns bass, vocals, other, drums - *(Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums), - *(y_bass, y_vocals, y_other, y_drums), - X, - x.shape[-1] - ) - - loss.backward() - optimizer.step() - losses.update(loss.item(), x.size(1)) - - return losses.avg + cm = None + name = '' + if train: + unmix.train() + cm = nullcontext + name = 'Train' + else: + unmix.eval() + cm = torch.no_grad + name = 'Validation' -def valid(args, unmix, encoder, device, valid_sampler, criterion): - # unpack encoder object - nsgt, _, cnorm = encoder + pbar = tqdm.tqdm(sampler, disable=args.quiet) - losses = utils.AverageMeter() - unmix.eval() - - with torch.no_grad(): - pbar = tqdm.tqdm(valid_sampler, disable=args.quiet) + with cm(): for x, y_bass, y_vocals, y_other, y_drums in pbar: - pbar.set_description("Validation batch") + pbar.set_description(f"{name} batch") + x, y_bass, y_vocals, y_other, y_drums = x.to(device), y_bass.to(device), y_vocals.to(device), y_other.to(device), y_drums.to(device) + if train: + optimizer.zero_grad() + X = nsgt(x) Xmag = cnorm(X) - Ymag_bass = cnorm(nsgt(y_bass)) Ymag_vocals = cnorm(nsgt(y_vocals)) Ymag_drums = cnorm(nsgt(y_drums)) @@ -95,8 +75,13 @@ def valid(args, unmix, encoder, device, valid_sampler, criterion): x.shape[-1] ) + if train: + loss.backward() + optimizer.step() + losses.update(loss.item(), x.size(1)) - return losses.avg + + return losses.avg def get_statistics(args, encoder, dataset, time_blocks): @@ -426,8 +411,8 @@ def kill_tboard(): for epoch in t: t.set_description("Training Epoch") end = time.time() - train_loss = train(args, unmix, encoder, device, train_sampler, criterion, optimizer) - valid_loss = valid(args, unmix, encoder, device, valid_sampler, criterion) + train_loss = loop(args, unmix, encoder, device, train_sampler, criterion, optimizer, train=True) + valid_loss = loop(args, unmix, encoder, device, valid_sampler, criterion, None, train=False) scheduler.step(valid_loss) train_losses.append(train_loss)