In [17]:
import numpy as np
import uisrnn

# File name under which diarization_experiment() saves the trained model
# (the same name could be passed to model.load() to skip training).
SAVED_MODEL_NAME = 'demo_model.uisrnn'

def diarization_experiment(model_args, training_args, inference_args):
    """Experiment pipeline.

    Load data --> train model --> test model --> output result

    The fitted model is saved to SAVED_MODEL_NAME. For every test sequence
    the ground-truth and predicted labels are printed, followed by the
    summary string built by uisrnn.output_result.

    Args:
        model_args: model configurations
        training_args: training configurations
        inference_args: inference configurations
    """
    # One (accuracy, sequence_length) pair per test sequence.
    test_record = []

    # np.load on an .npz archive returns a lazy NpzFile that keeps the file
    # open; the `with` blocks close it as soon as the arrays have been read.
    # NOTE: allow_pickle=True unpickles object arrays, which can execute
    # arbitrary code -- only load archives that come from a trusted source.
    with np.load('../data/uisrnn/toy_training_data.npz',
                 allow_pickle=True) as train_data:
        train_sequence = train_data['train_sequence']
        train_cluster_id = train_data['train_cluster_id']
    with np.load('../data/uisrnn/toy_testing_data.npz',
                 allow_pickle=True) as test_data:
        test_sequences = test_data['test_sequences'].tolist()
        test_cluster_ids = test_data['test_cluster_ids'].tolist()

    model = uisrnn.UISRNN(model_args)

    # Training.
    # If we have saved a model previously, we can also skip training by
    # calling:
    # model.load(SAVED_MODEL_NAME)
    model.fit(train_sequence, train_cluster_id, training_args)
    model.save(SAVED_MODEL_NAME)

    # Testing.
    # You can also try uisrnn.parallel_predict to speed up with GPU.
    # But that is a beta feature which is not thoroughly tested, so
    # proceed with caution.
    for (test_sequence, test_cluster_id) in zip(test_sequences, test_cluster_ids):
        predicted_cluster_id = model.predict(test_sequence, inference_args)
        accuracy = uisrnn.compute_sequence_match_accuracy(
            test_cluster_id, predicted_cluster_id)
        test_record.append((accuracy, len(test_cluster_id)))
        print('Ground truth labels:')
        print(test_cluster_id)
        print('Predicted labels:')
        print(predicted_cluster_id)
        print('-' * 80)

    output_string = uisrnn.output_result(model_args, training_args, test_record)

    print('Finished diarization experiment')
    print(output_string)


def main():
    """Entry point: build the default arguments and run the experiment."""
    # parse_arguments() yields (model_args, training_args, inference_args),
    # which is exactly the positional order diarization_experiment expects.
    diarization_experiment(*parse_arguments())

In [15]:
import easydict

def parse_arguments():
    """Build the hard-coded configurations used by this notebook.

    Returns:
        A tuple (model_args, training_args, inference_args), each an
        easydict.EasyDict so the settings can be read as attributes.
    """
    model_settings = {
        "observation_dim": 256,
        "rnn_hidden_size": 512,
        "rnn_depth": 1,
        "rnn_dropout": 0.2,
        "transition_bias": None,
        "crp_alpha": 1.0,
        "sigma2": None,
        "verbosity": 2,
        "enable_cuda": True,
    }
    training_settings = {
        "optimizer": 'adam',
        "learning_rate": 1e-3,
        "train_iteration": 1000,
        "batch_size": 10,
        "num_permutations": 10,
        "sigma_alpha": 1.0,
        "sigma_beta": 1.0,
        "regularization_weight": 1e-5,
        "grad_max_norm": 5.0,
        "enforce_cluster_id_uniqueness": True,
    }
    inference_settings = {
        "beam_size": 10,
        "look_ahead": 1,
        "test_iteration": 2,
    }

    # Wrap each plain dict at the very end, in the order callers unpack them.
    return (
        easydict.EasyDict(model_settings),
        easydict.EasyDict(training_settings),
        easydict.EasyDict(inference_settings),
    )

In [18]:
# Run the experiment: build the default arguments, then train and evaluate.
main()

Iter: 0  	Training Loss: -284.1512    
    Negative Log Likelihood: 6.0657	Sigma2 Prior: -290.2176	Regularization: 0.0006
Iter: 10  	Training Loss: -299.1838    
    Negative Log Likelihood: 5.6170	Sigma2 Prior: -304.8014	Regularization: 0.0006
Iter: 20  	Training Loss: -312.5705    
    Negative Log Likelihood: 6.2607	Sigma2 Prior: -318.8318	Regularization: 0.0006
Iter: 30  	Training Loss: -331.4164    
    Negative Log Likelihood: 7.0363	Sigma2 Prior: -338.4533	Regularization: 0.0006
Iter: 40  	Training Loss: -347.3345    
    Negative Log Likelihood: 8.4671	Sigma2 Prior: -355.8022	Regularization: 0.0006
Iter: 50  	Training Loss: -367.5362    
    Negative Log Likelihood: 10.4008	Sigma2 Prior: -377.9376	Regularization: 0.0007
Iter: 60  	Training Loss: -401.8967    
    Negative Log Likelihood: 14.0404	Sigma2 Prior: -415.9377	Regularization: 0.0007
Iter: 70  	Training Loss: -445.7974    
    Negative Log Likelihood: 22.8770	Sigma2 Prior: -468.6751	Regularization: 0.0007
Iter: 80  	Tra

Iter: 640  	Training Loss: -459.2185    
    Negative Log Likelihood: 44.9334	Sigma2 Prior: -504.1533	Regularization: 0.0013
Iter: 650  	Training Loss: -494.8901    
    Negative Log Likelihood: 35.1457	Sigma2 Prior: -530.0371	Regularization: 0.0013
Iter: 660  	Training Loss: -480.8114    
    Negative Log Likelihood: 39.9788	Sigma2 Prior: -520.7915	Regularization: 0.0013
Iter: 670  	Training Loss: -511.3131    
    Negative Log Likelihood: 36.9350	Sigma2 Prior: -548.2495	Regularization: 0.0013
Iter: 680  	Training Loss: -461.6263    
    Negative Log Likelihood: 42.8624	Sigma2 Prior: -504.4901	Regularization: 0.0013
Iter: 690  	Training Loss: -475.2091    
    Negative Log Likelihood: 33.4636	Sigma2 Prior: -508.6740	Regularization: 0.0013
Iter: 700  	Training Loss: -504.2809    
    Negative Log Likelihood: 36.6231	Sigma2 Prior: -540.9053	Regularization: 0.0013
Iter: 710  	Training Loss: -482.2141    
    Negative Log Likelihood: 46.8791	Sigma2 Prior: -529.0946	Regularization: 0.0013


Ground truth labels:
['75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_2', '75_2', '75_2', '75_2', '75_2', '75_2', '75_2', '75_2', '75_2', '75_2', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_0', '75_3', '75_3', '75_3', '75_3', '75_3', '75_3', '75_3', '75_3', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_2', '75_2', '75_2', '75_2', '75_2', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_4', '75_3', '75_3', '75_3', '75_3', '75_3', '75_1', '75_1', '75_1', '75_1', '75_1', '75_1', '75_1', '75_1', '75_2', '75_2', '75_2']
Predicted labels:
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

Ground truth labels:
['315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_0', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_4', '315_4', '315_4', '315_4', '315_4', '315_4', '315_4', '315_4', '315_4', '315_2', '315_2', '315_2', '315_2', '315_2', '315_2', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_1', '315_5', '315_5', '315_5', '315_5']
Predicted labels:
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,

Ground truth labels:
['419_0', '419_0', '419_0', '419_0', '419_0', '419_0', '419_0', '419_0', '419_0', '419_0', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4', '419_3', '419_3', '419_3', '419_3', '419_3', '419_3', '419_3', '419_3', '419_3', '419_3', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_2', '419_2', '419_2', '419_2', '419_2', '419_2', '419_2', '419_0', '419_0', '419_0', '419_0', '419_0', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_1', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_6', '419_4', '419_4', '419_4', '419_4', '419_4', '419_4']
Predicted labels:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,

Ground truth labels:
['491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_0', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_3', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1', '491_1']
Predicted labels:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,