In [None]:
import torch
from ray import tune, air
from ray.air import session
from ray.tune.search.optuna import OptunaSearch
import gc
from torch.autograd import Variable
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader

from model import *
from data import *

In [None]:
# --- Signal / batching parameters -------------------------------------------
# NOTE(review): none of these three are referenced by the cells visible in
# this notebook (the Tune trainable takes its batch size from the sampled
# config); presumably they are consumed by code outside this view - confirm.
SAMPLE_RATE = 48_000       # presumably samples per second (Hz) - confirm
BATCH_SIZE = 32
WINDOW_SIZE = 1 << 16      # 65536

# --- Dataset locations --------------------------------------------------------
# Absolute paths; the spelling of the labels path is kept exactly as given
# because it has to match the directory name on disk.
TRAIN_DATA_LABELS = "/Corrupted_speach/labels.csv"
SERIALIZED_DATA_FOLDER_TRAIN = "/serialized_train/"
SERIALIZED_DATA_FOLDER_TEST = "/serialized_test/"

In [None]:
def train(model, criterion, optimizer, train_data_loader):
    """Run one optimisation pass over ``train_data_loader`` and return the model.

    For every batch the gradients are reset, the loss is back-propagated, the
    gradient norm is clipped to 3 and ``optimizer.step()`` is called.

    Args:
        model: ``torch.nn.Module`` that is called as ``model(noisy)``.
        criterion: Callable invoked as
            ``criterion(source=clean, estimate_source=output)``; it must
            return a scalar tensor that supports ``backward()``.
        optimizer: Optimizer updating the parameters of ``model``.
        train_data_loader: Iterable yielding ``(clean, noisy)`` tensor pairs.
            Dimension 1 is squeezed from both tensors before use (presumably
            a singleton channel axis - confirm against the dataset).

    Returns:
        The same ``model`` instance, updated in place and left in train mode.
    """
    # Send every batch to wherever the model actually lives. Moving data to
    # CUDA merely because a GPU exists breaks when the model stayed on the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to "GPU if there is one".
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Nothing inside the loop switches the mode, so set it once up front.
    model.train()

    for train_clean, train_noisy in train_data_loader:
        train_clean = train_clean.to(device)
        train_noisy = train_noisy.to(device)

        model.zero_grad()

        output = model(train_noisy.squeeze(1))
        loss = criterion(source=train_clean.squeeze(1), estimate_source=output)

        loss.backward()
        # Clip the global gradient norm to keep the update bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
        optimizer.step()

    # Release Python garbage and cached CUDA blocks before the next phase.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model

In [None]:
def test(model, criterion, test_data_loader):
    test_batch_counter = 0
    loss_test = 0
    model.eval()
    with torch.no_grad():
        for test_clean, test_noisy in test_data_loader:    
            if torch.cuda.is_available():
                test_clean, test_noisy = test_clean.cuda(), test_noisy.cuda()
            test_clean, test_noisy = Variable(test_clean), Variable(test_noisy)

            output = model(test_noisy.squeeze(1))
            loss = criterion(source=test_clean.squeeze(1), estimate_source=output)

            loss_test += loss.item()
            test_batch_counter += 1

    return loss_test/test_batch_counter

In [None]:
def objective(config):
    """Ray Tune function trainable: train and evaluate a DTLN model forever.

    Builds the model, optimizer, loss and data loaders from ``config``, then
    alternates one call to ``train`` with one call to ``test`` and reports the
    resulting loss to Tune under the key ``"test_loss"``. The loop never
    returns by itself; it relies on Tune's stop criterion to end the trial.

    Args:
        config: Mapping of sampled hyperparameters. Required keys:
            ``"separation_dropout"``, ``"encoder_size"``, ``"lr"`` and
            ``"batch_size"``. Optional key: ``"LSTM_size"`` (defaults to 4,
            the value that used to be hard-coded).
    """
    encoder_size = config['encoder_size']
    model = Pytorch_DTLN(frame_len=1536,
                         frame_hop=384,
                         dropout=config['separation_dropout'],
                         encoder_size=encoder_size,
                         # Half the encoder width; encoder_size is expected
                         # to be an integer.
                         hidden_size=encoder_size // 2,
                         # Use the sampled value instead of silently ignoring
                         # the "LSTM_size" entry of the search space.
                         LSTM_size=config.get('LSTM_size', 4),
                         )
    # Put the weights on the GPU when one is visible to this trial, before the
    # optimizer is built, so the model and the batches share a device.
    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    criterion = SiSnr()

    # Dataset folders come from the module-level configuration constants.
    train_dataset = AudioDataset(data_path=SERIALIZED_DATA_FOLDER_TRAIN)
    test_dataset = AudioDataset(data_path=SERIALIZED_DATA_FOLDER_TEST)

    # NOTE(review): the training loader is not shuffled, so every epoch sees
    # the batches in the same order - confirm this is intentional.
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'], shuffle=False)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=config['batch_size'], shuffle=False)

    while True:
        print('iteration')
        model = train(model, criterion, optimizer, train_data_loader)  # One pass over the training data
        loss = test(model, criterion, test_data_loader)  # Mean loss on the test data
        session.report({"test_loss": loss})  # Report to Tune

In [None]:
# Hyperparameter search space handed to the Tuner as ``param_space``. Each
# entry is a categorical choice; the keys are the names under which the
# sampled values appear in the trainable's ``config`` mapping.
search_space = {
    "lr": tune.choice([1e-5]),  # single option: effectively fixed
    "separation_dropout": tune.choice([0.2, 0.25, 0.3]),
    "encoder_size": tune.choice([256, 512, 1024]),
    "batch_size": tune.choice([8, 16, 32]),
    # A list, like every other entry: a set has no defined order and is not
    # the sequence type ``tune.choice`` is documented to take.
    "LSTM_size": tune.choice([2, 3, 4]),
}

# Search algorithm that proposes configurations, built with default settings.
algo = OptunaSearch()

In [None]:
# Number of configurations sampled from ``search_space``. TuneConfig defaults
# to a single sample, which would make the "best config" simply the only
# config that was tried. Adjust this budget to the available compute.
NUM_SAMPLES = 20

tuner = tune.Tuner(
    objective,
    tune_config=tune.TuneConfig(
        metric="test_loss",   # key reported by the trainable
        mode="min",           # lower loss is better
        search_alg=algo,
        num_samples=NUM_SAMPLES,
    ),
    run_config=air.RunConfig(
        # The trainable loops forever; Tune ends each trial after 50 reports.
        stop={"training_iteration": 50},
    ),
    param_space=search_space,
)
results = tuner.fit()
print("Best config is:", results.get_best_result().config)