Skip to content

Commit

Permalink
Combine train/valid loops
Browse files Browse the repository at this point in the history
  • Loading branch information
sevagh committed Jan 15, 2022
1 parent 0b9feec commit 4795bf9
Showing 1 changed file with 27 additions and 42 deletions.
69 changes: 27 additions & 42 deletions scripts/train.py
Expand Up @@ -17,7 +17,7 @@
import sys
import torchaudio
import torchinfo
from contextlib import contextmanager
from contextlib import nullcontext
import sklearn.preprocessing
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -32,56 +32,36 @@
tqdm.monitor_interval = 0


def train(args, unmix, encoder, device, train_sampler, criterion, optimizer):
def loop(args, unmix, encoder, device, sampler, criterion, optimizer, train=True):
# unpack encoder object
nsgt, _, cnorm = encoder

losses = utils.AverageMeter()
unmix.train()
pbar = tqdm.tqdm(train_sampler, disable=args.quiet)

for x, y_bass, y_vocals, y_other, y_drums in pbar:
pbar.set_description("Training batch")
x, y_bass, y_vocals, y_other, y_drums = x.to(device), y_bass.to(device), y_vocals.to(device), y_other.to(device), y_drums.to(device)
optimizer.zero_grad()
X = nsgt(x)
Xmag = cnorm(X)
Ymag_bass = cnorm(nsgt(y_bass))
Ymag_vocals = cnorm(nsgt(y_vocals))
Ymag_drums = cnorm(nsgt(y_drums))
Ymag_other = cnorm(nsgt(y_other))

loss = criterion(
*unmix(Xmag), # forward call to unmix returns bass, vocals, other, drums
*(Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums),
*(y_bass, y_vocals, y_other, y_drums),
X,
x.shape[-1]
)

loss.backward()
optimizer.step()
losses.update(loss.item(), x.size(1))

return losses.avg

cm = None
name = ''
if train:
unmix.train()
cm = nullcontext
name = 'Train'
else:
unmix.eval()
cm = torch.no_grad
name = 'Validation'

def valid(args, unmix, encoder, device, valid_sampler, criterion):
# unpack encoder object
nsgt, _, cnorm = encoder
pbar = tqdm.tqdm(sampler, disable=args.quiet)

losses = utils.AverageMeter()
unmix.eval()

with torch.no_grad():
pbar = tqdm.tqdm(valid_sampler, disable=args.quiet)
with cm():
for x, y_bass, y_vocals, y_other, y_drums in pbar:
pbar.set_description("Validation batch")
pbar.set_description(f"{name} batch")

x, y_bass, y_vocals, y_other, y_drums = x.to(device), y_bass.to(device), y_vocals.to(device), y_other.to(device), y_drums.to(device)

if train:
optimizer.zero_grad()

X = nsgt(x)
Xmag = cnorm(X)

Ymag_bass = cnorm(nsgt(y_bass))
Ymag_vocals = cnorm(nsgt(y_vocals))
Ymag_drums = cnorm(nsgt(y_drums))
Expand All @@ -95,8 +75,13 @@ def valid(args, unmix, encoder, device, valid_sampler, criterion):
x.shape[-1]
)

if train:
loss.backward()
optimizer.step()

losses.update(loss.item(), x.size(1))
return losses.avg

return losses.avg


def get_statistics(args, encoder, dataset, time_blocks):
Expand Down Expand Up @@ -426,8 +411,8 @@ def kill_tboard():
for epoch in t:
t.set_description("Training Epoch")
end = time.time()
train_loss = train(args, unmix, encoder, device, train_sampler, criterion, optimizer)
valid_loss = valid(args, unmix, encoder, device, valid_sampler, criterion)
train_loss = loop(args, unmix, encoder, device, train_sampler, criterion, optimizer, train=True)
valid_loss = loop(args, unmix, encoder, device, valid_sampler, criterion, None, train=False)

scheduler.step(valid_loss)
train_losses.append(train_loss)
Expand Down

0 comments on commit 4795bf9

Please sign in to comment.