In [1]:
import os

import torch
import torch.nn as nn
import torchaudio
from torchvision.datasets import DatasetFolder
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from utils import *
from model import BiLSTM
from CheckpointWriter.CheckpointWriter import CheckpointWriter

from tqdm.notebook import tqdm, trange

In [2]:
# helper functions
def load_audio(path):
    """Read an audio file and hand back only its waveform tensor.

    ``torchaudio.load`` produces a ``(waveform, sample_rate)`` pair; the sample
    rate is not needed by the pipeline, so it is dropped here.
    """
    waveform, _sample_rate = torchaudio.load(path)
    return waveform

def custom_collate(batch):
    """Collate ``(sample, target)`` pairs without stacking the samples.

    Samples are variable-length spectrograms and cannot be stacked into one
    tensor, so they are kept as a tuple. The integer class targets are gathered
    into a ``LongTensor``.

    Returns:
        list: ``[samples, targets]``, where ``samples`` is a tuple of the
        per-item samples and ``targets`` is a 1-D ``torch.LongTensor``.
    """
    columns = tuple(zip(*batch))
    sample_column = columns[0]
    label_tensor = torch.LongTensor(columns[1])
    return [sample_column, label_tensor]

def sort_batch(batch):
    """Reshape spectrograms into ``(time, freq)`` sequences sorted longest-first.

    Each item is squeezed on dim 0 and transposed so that time becomes the
    leading dimension, as expected by ``pad_sequence`` /
    ``pack_padded_sequence``. Items are then ordered by that leading (time)
    length, longest first. Ties keep their original relative order.

    Args:
        batch: sequence of tensors, assumed to be ``(1, freq, time)``
            spectrograms of mono audio (TODO confirm for multi-channel data).

    Returns:
        tuple: ``(batch_sorted, perm_idx)``. ``batch_sorted[i]`` is the
        reshaped ``batch[perm_idx[i]]``, so ``perm_idx`` can be used to reorder
        the labels identically.
    """
    # Reshape first so the sort key is the real sequence (time) length. The
    # previous key, len(x), measured dim 0 of the raw spectrogram (the channel
    # count), so it was constant and never reordered anything.
    sequences = [item.squeeze(0).T for item in batch]
    perm_idx = sorted(range(len(sequences)),
                      key=lambda k: sequences[k].size(0),
                      reverse=True)
    batch_sorted = [sequences[k] for k in perm_idx]
    return batch_sorted, perm_idx

## Prepare for training

In [4]:
# device
# NOTE: the GPU index is hard-coded; change `device_id` to match the machine.
device_id = '2'
device = torch.device('cuda:' + device_id)

# reproducibility: data shuffling and weight initialisation are random
seed = 0
torch.manual_seed(seed)

# parameters
n_fft = 512
input_size = n_fft // 2 + 1  # frequency bins produced by a one-sided Spectrogram(n_fft)
hidden_size = 256
num_layers = 2
learning_rate = 0.03
weight_decay = 1e-5
num_epochs = 100
batch_size = 128
# Epochs without validation-loss improvement before the LR is reduced. This was
# previously 100 (== num_epochs), which meant the scheduler could never fire.
lr_patience = 10
expr_log = 'fft_128_2'
data_dir = os.path.join('data', 'F2000_split')

# data loader
spectrogram = torchaudio.transforms.Spectrogram(n_fft=n_fft)
transform = transforms.Compose([spectrogram])


def _make_split(split, shuffle):
    """Build the DatasetFolder for ``data_dir/split`` and a DataLoader over it."""
    dataset = DatasetFolder(root=os.path.join(data_dir, split),
                            loader=load_audio,
                            extensions=('.wav',),  # documented type: tuple of str
                            transform=transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         collate_fn=custom_collate)
    return dataset, loader


# DatasetFolder lists samples class by class, so the training loader must
# shuffle or every batch would contain (nearly) a single class.
train_set, train_loader = _make_split('train', shuffle=True)
test_set, test_loader = _make_split('test', shuffle=False)

# model
num_classes = len(train_set.classes)
model = BiLSTM(input_size, hidden_size, num_layers, num_classes, device_id=device_id).to(device)

# checkpoint
model_specs = {
    'base': {
        'input_size': input_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_classes': num_classes,
    }
}
checkpoint = CheckpointWriter(os.path.join('runs', expr_log), model_specs)

# loss
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# scheduler (the current LR is logged to TensorBoard in the training loop)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=lr_patience)

# Tensorboard
writer = SummaryWriter(os.path.join('runs', expr_log))

## Training loop

In [5]:
def _prepare_batch(images, labels):
    """Sort, pad and pack one collated batch, then move it to ``device``.

    The labels are reordered with the same permutation as the sequences.
    This assumes BiLSTM returns its output rows in the order of the sequences
    it was given (TODO confirm in model.BiLSTM).

    Returns:
        tuple: ``(images_packed, labels_sorted)``, both on ``device``.
    """
    images_sorted, perm_idx = sort_batch(images)
    lengths = [v.size(0) for v in images_sorted]
    images_packed = nn.utils.rnn.pack_padded_sequence(
        nn.utils.rnn.pad_sequence(images_sorted, batch_first=True),
        lengths,
        batch_first=True,
        enforce_sorted=False)
    return images_packed.to(device), labels[perm_idx].to(device)


best_loss_val = float('inf')  # any finite validation loss counts as an improvement
loss_val = 100
accuracy_val = 0
t = trange(num_epochs, desc='{}, {}%'.format(loss_val, accuracy_val))
try:
    for epoch in t:
        # ---- training ----
        model.train()
        correct = 0
        running_loss = torch.zeros((), device=device)
        for images, labels in train_loader:
            images_packed, labels_sorted = _prepare_batch(images, labels)

            # Forward pass
            outputs = model(images_packed, len(images))
            loss = criterion(outputs, labels_sorted)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss() returns a per-batch mean. Weight it by batch size
            # so the epoch average stays exact when the last batch is smaller.
            running_loss += loss.detach() * labels_sorted.size(0)
            # argmax of the logits equals argmax of their softmax
            correct += (outputs.argmax(dim=1) == labels_sorted).sum()
        loss_train = running_loss / len(train_set)
        accuracy = 100 * torch.true_divide(correct, len(train_set))

        # ---- validation ----
        model.eval()
        with torch.no_grad():
            correct_val = 0
            running_loss_val = torch.zeros((), device=device)
            for images, labels in test_loader:
                images_packed, labels_sorted = _prepare_batch(images, labels)
                outputs = model(images_packed, len(images))
                running_loss_val += criterion(outputs, labels_sorted) * labels_sorted.size(0)
                correct_val += (outputs.argmax(dim=1) == labels_sorted).sum()
            # Mean over the whole validation set. Previously only the last
            # batch's loss survived the loop.
            loss_val = running_loss_val / len(test_set)
            accuracy_val = 100 * torch.true_divide(correct_val, len(test_set))

        # Learning rate adjustment on the epoch-mean validation loss
        scheduler.step(loss_val)

        # Tensorboard update
        writer.add_scalar('loss/train', loss_train.item(), epoch+1)
        writer.add_scalar('accuracy/train', accuracy, epoch+1)
        writer.add_scalar('loss/validation', loss_val.item(), epoch+1)
        writer.add_scalar('accuracy/validation', accuracy_val, epoch+1)
        writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch+1)

        # tqdm update
        t.set_description('{:.4f}, {:.2f}%'.format(loss_val.item(), accuracy_val.item()))

        # Save a checkpoint whenever the validation loss improves
        if loss_val.item() < best_loss_val:
            best_loss_val = loss_val.item()
            checkpoint.save(model, optimizer, accuracy_val, loss_val, epoch, loss_train)
finally:
    # close tensorboard even if training is interrupted, so events are flushed
    writer.close()


100, 0%:   0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 