In [None]:
# 0.0 import packages

import numpy as np
import time

import torch
import torch.nn as nn
import torch.optim as optim
from numba import cuda as numba
from GPUtil import showUtilization as gpu_usage
from tqdm import tqdm
import time
from utils.data_loader import data_provider
import matplotlib.pyplot as plt
from models.bi_lstm import bi_LSTM
from models.transformer import Transformer
from models.bert_inspired import BertInspired
from utils.tools import dotdict

In [None]:
# 0.2 GPU stuff
# Index of the CUDA device to use when CUDA is available; otherwise fall back to CPU.
device_num = 1
device = torch.device(f"cuda:{device_num}" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name only accepts CUDA devices, so it must not be
# called on the CPU fallback (the unguarded call made this cell crash there).
if device.type == "cuda":
    print("torch device: ", torch.cuda.get_device_name(device))
else:
    print("torch device: ", device)
# To force CPU execution, uncomment:
#device = torch.device("cpu")

# function to clear GPU memory
def free_gpu_cache():
    """Release cached GPU memory on CUDA device `device_num` and print utilization.

    Steps, in order:
      1. ``torch.cuda.empty_cache()`` releases unoccupied memory held by
         PyTorch's caching allocator.
      2. numba selects device `device_num`, closes its CUDA context(s), then
         selects the device again.
      3. GPUtil's ``showUtilization`` (imported as ``gpu_usage``) prints the
         current GPU usage.

    Reads the module-level ``device_num``; returns None.

    NOTE(review): per the numba docs, ``cuda.close()`` clears the current
    thread's CUDA contexts and destroys them when called from the main thread.
    If PyTorch shares that context, tensors/model weights already on the GPU
    may become unusable afterwards. Confirm before calling this mid-training.
    """
    torch.cuda.empty_cache()
    numba.select_device(device_num)
    numba.close()
    # Re-select so a fresh context exists for the device (presumably so later GPU work still works).
    numba.select_device(device_num)
    print("GPU Usage after emptying the cache")
    gpu_usage()

In [None]:
# 3.1 helper functions for training

def test_network(model, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            # get data
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            #print ("labels: ", labels)
            #print ("pred: ", outputs)
            total = labels.shape[0] * labels.shape[1]
            correct = 0
            for i, frame in enumerate(labels):
                #print (i, " frame: ", frame)
                #print (i, " outputs[i]: ", outputs[i])
                for val in torch.eq(frame, outputs[i]):
                    if val:
                        correct += 1
            
    return 100 * correct / total

def print_stats(iteration_list, accuracy_list, loss_list):
    """Show two line plots against `iteration_list`: accuracy first, then loss.

    Each plot gets its own title and axis labels and is displayed immediately
    with ``plt.show()``.
    """
    panels = (
        ("accuracy over time", "accuracy", accuracy_list),
        ("loss over time", "loss", loss_list),
    )
    for title, y_label, series in panels:
        plt.plot(iteration_list, series)
        plt.title(title)
        plt.xlabel("iterations")
        plt.ylabel(y_label)
        plt.show()

In [None]:
# Which architecture to train: one of "bert_inspired", "transformer" or "biLSTM".
model_type = "bert_inspired"

# Constructor and hyperparameters for each supported architecture. The chosen
# entry's parameters, plus "model_type", are wrapped in a dotdict and passed
# to the constructor.
MODEL_SPECS = {
    "biLSTM": (bi_LSTM, {
        "input_dim": 128,
        "hidden_dim": 128,
        "output_dim": 9,
        "num_layers": 2,
    }),
    "transformer": (Transformer, {
        "enc_in": 128,
        "dec_in": 128,
        "c_out": 9,
        "d_model": 128,
        "dropout": .05,
        "output_attention": False,
        "n_heads": 8,
        "d_ff": None,
        "activation": "gelu",
        "e_layers": 2,
        "d_layers": 1,
    }),
    "bert_inspired": (BertInspired, {
        "enc_in": (32, 16),  # (#windows, # mel filters)
        "c_out": 9,
        "d_model": 512,
        "dropout": .05,
        "output_attention": False,
        "n_heads": 8,
        "d_ff": None,
        "activation": "gelu",
        "e_layers": 12,
    }),
}

# Explicit check rather than `assert`, which is stripped under `python -O`.
if model_type not in MODEL_SPECS:
    raise ValueError(
        f"Didn't select a valid model: {model_type!r} "
        f"(expected one of {sorted(MODEL_SPECS)})"
    )

model_cls, model_params = MODEL_SPECS[model_type]
config = dotdict({**model_params, "model_type": model_type})
model = model_cls(config)
model.to(device)

In [None]:
# 3.0 Training

# Optionally call free_gpu_cache() here to release GPU memory before training.

# training parameters
batch_size = 512
learning_rate = 0.0000001  # NOTE(review): unusually small for Adam; confirm this is intended
num_epochs = 5
checkpoint_path = 'saved_model.pth'  # best-validation-loss weights are written here

# Data-loading config. This replaces the model config built in the previous cell.
config = dotdict({
        "batch_size": batch_size,
        "num_workers": 0,
        "seq_len": 9,
        "data_id": "32x16"
    })

train_dataset, train_loader = data_provider(config, flag="train")
val_dataset, val_loader = data_provider(config, flag="val")
test_dataset, test_loader = data_provider(config, flag="test")
# Author's note: samples are seq_length x num_windows x num_mel_filters (TODO: confirm against data_provider)

criterion = nn.BCEWithLogitsLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# History buffers for print_stats.
# TODO: they are not filled yet; per-iteration accuracy logging via test_network is still to be wired in.
iteration_list = []
accuracy_list = []
loss_list = []


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for feats, labels in tqdm(loader):
        feats = feats.float().to(device)
        labels = labels.float().to(device)

        optimizer.zero_grad()
        loss = criterion(model(feats), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def _evaluate(model, loader, criterion):
    """Return the mean batch loss of `model` on `loader` without building autograd graphs."""
    # eval() is required: the transformer/bert_inspired configs enable dropout.
    model.eval()
    running_loss = 0.0
    for feats, labels in tqdm(loader):
        feats = feats.float().to(device)
        labels = labels.float().to(device)
        running_loss += criterion(model(feats), labels).item()
    return running_loss / len(loader)


# Perform epochs, checkpointing whenever validation loss improves.
startTime = time.time()
min_valid_loss = np.inf
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    valid_loss = _evaluate(model, val_loader, criterion)
    print(f"Epoch {epoch}\t\tTraining Loss: {train_loss}\t\tValidation Loss: {valid_loss}")
    if min_valid_loss > valid_loss:
        print(f"Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f})\tSaving The Model")
        min_valid_loss = valid_loss
        torch.save(model.state_dict(), checkpoint_path)

print ("time elapsed: ", round((time.time() - startTime), 2), " sec")