In [6]:
import os
import re
import numpy as np

import uisrnn
from uisrnn import contrib

In [2]:
# Input/output locations for this notebook. Judging by the paths, these point at the
# VoxConverse RTTM annotations and at the directory holding the sequence artefacts.
# NOTE(review): absolute, machine-specific paths; consider a configurable base directory.
# Alternative (sample subset) RTTM location, kept for quick switching:
# rttm_path = '/home/jovyan/work/voxsrc21-dia/data/voxconverse/sample/rttm'
rttm_path = '/home/jovyan/work/datasets/voxconverse/test/rttm'
save_dir_path = '/home/jovyan/work/voxsrc21-dia/embeddings/sequences'
os.makedirs(save_dir_path, exist_ok=True)  # exist_ok: no error if the directory is already there

dev_intervals_path = save_dir_path + '/voxcon-dev-intervals.npy'
test_intervals_path = save_dir_path + '/voxcon-test-intervals.npy'

# SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects, which can execute
# code. Only load files that were produced by this project.
dev_intervals = np.load(dev_intervals_path,allow_pickle=True)
test_intervals = np.load(test_intervals_path,allow_pickle=True)

# Speaker assignment should have been done while processing the audio files. Since it was not,
# these ids were copied from the logs to replicate the file order used during embedding extraction (and interval generation).
dev_audio_ids = ["abjxc","ahnss","akthc","asxwr","atgpi","aufkn","azisu","bkwns","bspxd","bxpwa","cyyxp","czlvt","dhorc","djqif","dscgs","edixl","ehpau","eqttu","esrit","exymw","eziem","ezsgk","femmv","fkvvo","fsaal","fvyvb","fxgvy","ggvel","gocbm","gqbvk","gqdxy","grzbb","hiyis","hqyok","hycgx","ikgcq","imtug","ipqqq","iqbww","iqtde","irvat","jnivh","jynhe","kctgl","kkghn","ktzmw","kuduk","ldkmv","lfzib","lknjp","luvfz","mdbod","mekog","mesob","mgpok","migzj","mkrcv","mpvoh","mqxsf","mwfmq","nctdh","nfqjx","ntchr","nxgad","oekmc","ooxnm","pgkde","plbbw","pnyir","ppgjx","qfdpp","qouur","qppll","qrzjk","qydmg","qygfk","qzwxa","rcxzg","rxgun","sikkm","sosnj","suuxu","syiwe","szsyz","tguxv","tjkfn","tlprc","tplwz","udjij","uexjc","ulriv","uvnmy","vmbga","whmpa","wjhgf","xiglo","xxwgv","ycxxe","ydlfw","ypwjd","yrsve","yuzyu","ywcwr","zcdsd","zfkap","zmndm","zrlyl","zyffh","nnqfq","aisvi","usbgm","xvllq","oenox","praxo","onpra","kefgo","bauzd","mjgil","blwmj","gofnj","uatlu","rtvuw","wnfoi","evtyi","tcwsn","pilgb","cmfyw","dbugl","mevkw","jsdmu","jiqvr","hkzpa","hgeec","jcako","epdpg","cqaec","kkwkn","spzmn","ngyrk","sldwj","cmhsm","ndkwv","kbkon","bdopb","qhesr","cwryz","djngn","dvngl","qsfzo","gwtwd","cobal","sduml","vysqj","jtagk","hgdez","wdjyj","qpylu","tfvyr","falxo","bydui","willh","wspbh","ufpel","kiadt","nrogz","imbqf","crixb","ylnza","wewoz","qvtia","kszpd","bwzyf","xypdm","mvjuk","jhdav","pqmho","jsmbi","ccokr","ampme","odkzj","tiams","tucrg","bravd","houcx","gpjne","goyli","txcok","jyirt","oxxwk","iwdjy","kckqn","ioasm","paibn","kklpv","vbjlx","jyflp","sqkup","xmfzh","afjiv","eapdk","pnook","yfcmz","gzvkx","oklol","qjgpl","wbqza","wmori","ysgbf","zajzs","zidwg","ztzzr","zvmyn"]
test_audio_ids = ["aepyx","aiqwk","bjruf","bmsyn","bxcfq","byapz","clfcg","cqfmj","crylr","cvofp","dgvwu","dohag","dxbbt","dzsef","eauve","eazeq","eguui","epygx","eqsta","euqef","fijfi","fpfvy","fqrnu","fxnwf","fyqoe","gcfwp","gtjow","gtnjb","gukoa","guvqf","gylzn","gyomp","hcyak","heolf","hhepf","ibrnm","ifwki","iiprr","ikhje","jdrwl","jjkrt","jjvkx","jrfaz","jsbdo","jttar","jxpom","jzkzt","kajfh","kmunk","kpjud","ktvto","kvkje","lbfnx","ledhe","lilfy","ljpes","lkikz","lpola","lscfc","ltgmz","lubpm","luobn","mjmgr","msbyq","mupzb","myjoe","nlvdr","nprxc","ocfop","ofbxh","olzkb","ooxlj","oqwpd","otmpf","ouvtt","poucc","ppexo","pwnsw","qadia","qeejz","qlrry","qwepo","rarij","rmvsh","rxulz","sebyw","sexgc","sfdvy","svxzm","tkybe","tpslg","uedkc","uqxlg","usqam","vncid","vylyk","vzuru","wdvva","wemos","wprog","wwzsk","xggbk","xkgos","xlyov","xmyyy","xqxkt","xtdcl","xtzoq","xvxwv","ybhwz","ylzez","ytmef","yukhy","yzvon","zedtj","zfzlc","zowse","zqidv","zztbo","ralnu","uicid","laoyl","jxydp","pzxit","upshw","gfneh","kzmyi","nkqzr","kgjaa","dkabn","eucfa","erslt","mclsr","fzwtp","dzxut","pkwrt","gmmwm","leneg","sxqvt","pgtkk","fuzfh","vtzqw","rsypp","qxana","optsn","dxokr","ptses","isxwc","gzhwb","mhwyr","duvox","ezxso","jgiyq","rpkso","kmjvh","wcxfk","gcvrb","eddje","pccww","vuewy","tvtoe","oubab","jwggf","aggyz","bidnq","neiye","mkhie","iowob","jbowg","gwloo","uevxo","nitgx","eoyaz","qoarn","mxdpo","auzru","diysk","cwbvu","jeymh","iacod","cawnd","vgaez","bgvvt","tiido","aorju","qajyo","ryken","iabca","tkhgs","tbjqx","mqtep","fowhl","fvhrk","nqcpi","mbzht","uhfrw","utial","cpebh","tnjoh","jsymf","vgevv","mxduo","gkiki","bvyvm","hqhrb","isrps","nqyqm","dlast","pxqme","bpzsc","vdlvr","lhuly","crorm","bvqnu","tpnyf","thnuq","swbnm","cadba","sbrmv","wibky","wlfsf","wwvcs","xffsa","xkmqx","xlsme","ygrip","ylgug","ytula","zehzu","zsgto","zzsba","zzyyo"]


# train/test_sequences   -> list of sequences: M (utterances) * L (segments) * D (embedding dimension)
# train/test_cluster_ids -> list of per-sequence speaker ids: M * L * 1 (string)
# SECURITY: allow_pickle=True unpickles arbitrary objects; only load trusted files.
# NOTE(review): the sequences and the cluster ids are read from different directories;
# confirm both were generated from the same run so that they line up.
train_sequences = np.load('/home/jovyan/work/sequences/fixed-voxcon-dev-sequences.npy', allow_pickle=True).tolist()
train_cluster_ids = np.load('/home/jovyan/work/voxsrc21-dia/embeddings/sequences/voxcon-dev-cluster-ids.npy', allow_pickle=True).tolist()

test_sequences = np.load('/home/jovyan/work/sequences/fixed-voxcon-test-sequences.npy', allow_pickle=True).tolist()
test_cluster_ids = np.load('/home/jovyan/work/voxsrc21-dia/embeddings/sequences/voxcon-test-cluster-ids.npy', allow_pickle=True).tolist()


## TODO: create folds here

# Generic working names, currently bound to the dev split.
intervals = dev_intervals
audio_ids = dev_audio_ids
sequences = train_sequences
cluster_ids = train_cluster_ids

In [39]:
# NOTE(review): `concat_ids` is not assigned in any live cell above; its only assignment is
# inside the commented-out code of the CRP-alpha cell further down, so on a fresh kernel
# this cell raises NameError.
# NOTE(review): this imports from a top-level module named `contrib`, which is not the same
# import as `from uisrnn import contrib` at the top of the notebook. Confirm which module
# actually provides `_get_cdf`.
from contrib import  _get_cdf

# Second argument is passed positionally; presumably the CRP alpha to evaluate (TODO confirm).
cdf = _get_cdf(concat_ids, 1)

In [40]:
# Display the value returned by `_get_cdf` in the previous cell.
cdf

-13138.010143738391

In [37]:
# import contrib

# The commented-out code below strips every non-digit character from each speaker id,
# concatenates the training data and calls `contrib.estimate_crp_alpha` on the result.
# It is also the only place where `concat_ids` would be defined.
# for m, segments_ids in enumerate(cluster_ids):
#     for l, speaker_id in enumerate(segments_ids):
#         cluster_ids[m][l] = str(int(re.sub('[^0-9]', "", speaker_id))) # int to remove 0 in 00, 01, ...

# concat_seq, concat_ids=uisrnn.utils.concatenate_training_data(train_sequences,cluster_ids)
# contrib.estimate_crp_alpha(concat_ids)

# NOTE(review): presumably the value produced by the estimation above, hard-coded to avoid
# re-running it. Confirm before reusing on other data.
estimated_crp_alpha = 0.76

In [None]:
import numpy as np
from functools import partial
import torch.multiprocessing as mp
# Multiprocessing context from which the prediction worker pool is created.
ctx = mp.get_context('forkserver')

import uisrnn

# File name the trained model is saved under after each training chunk.
SAVED_MODEL_NAME = 'voxcon_dev_model.uisrnn'

# Number of worker processes in the prediction pool.
NUM_WORKERS = 2

def diarization_experiment(model_args, training_args, inference_args, chunk_size=53):
    """Train a UIS-RNN model chunk by chunk, then score it on the test sequences.

    The dev sequences and their cluster ids are loaded from fixed ``/app`` paths and
    split into consecutive chunks of ``chunk_size`` sequences. ``model.fit`` is called
    once per chunk and the model is saved to ``SAVED_MODEL_NAME`` after every chunk.
    The test sequences are then predicted in a pool of ``NUM_WORKERS`` processes and
    each prediction is compared with its ground truth through
    ``uisrnn.compute_sequence_match_accuracy``.

    Side effects: overwrites ``training_args.train_iteration`` (set to 300) and
    ``model_args.crp_alpha`` (set to the module-level ``estimated_crp_alpha``), writes
    the model file, and prints the per-sequence labels and the final summary.

    Args:
        model_args: model configuration passed to ``uisrnn.UISRNN``.
        training_args: training configuration passed to ``model.fit``.
        inference_args: inference configuration passed to ``model.predict``.
        chunk_size: number of training sequences handed to each ``model.fit`` call.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        # A negative step would silently yield no chunks, i.e. an untrained model.
        raise ValueError('chunk_size must be a positive integer')

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    train_sequence = np.load('/app/fixed-voxcon-dev-sequences.npy', allow_pickle=True).tolist()
    train_cluster_id = np.load('/app/voxsrc21-dia/embeddings/sequences/voxcon-dev-cluster-ids.npy', allow_pickle=True).tolist()

    test_sequences = np.load('/app/fixed-voxcon-test-sequences.npy', allow_pickle=True).tolist()
    test_cluster_ids = np.load('/app/voxsrc21-dia/embeddings/sequences/voxcon-test-cluster-ids.npy', allow_pickle=True).tolist()

    # Consecutive, non-overlapping chunks; the last one may be shorter.
    split_train_sequence = [train_sequence[i:i + chunk_size] for i in range(0, len(train_sequence), chunk_size)]
    split_train_cluster_id = [train_cluster_id[i:i + chunk_size] for i in range(0, len(train_cluster_id), chunk_size)]

    training_args.train_iteration = 300
    model_args.crp_alpha = estimated_crp_alpha
    model = uisrnn.UISRNN(model_args)

    for sequence, cluster_id in zip(split_train_sequence, split_train_cluster_id):
        # Training: one fit call per chunk, saving after each so progress is kept.
        model.fit(sequence, cluster_id, training_args)
        model.save(SAVED_MODEL_NAME)

    # Testing
    test_record = []
    # Predict sequences in parallel.
    model.rnn_model.share_memory()
    # The context manager terminates the workers if prediction or scoring raises,
    # so a failure no longer leaves worker processes behind.
    with ctx.Pool(NUM_WORKERS, maxtasksperchild=None) as pool:
        pred_gen = pool.imap(func=partial(model.predict, args=inference_args), iterable=test_sequences)
        # Collect and score predictions; imap yields them in input order.
        for idx, predicted_cluster_id in enumerate(pred_gen):
            accuracy = uisrnn.compute_sequence_match_accuracy(test_cluster_ids[idx], predicted_cluster_id)
            test_record.append((accuracy, len(test_cluster_ids[idx])))
            print('Ground truth labels:')
            print(test_cluster_ids[idx])
            print('Predicted labels:')
            print(predicted_cluster_id)
            print('-' * 80)

        # Normal shutdown: stop accepting work and wait for the workers to exit.
        pool.close()
        pool.join()

    print('Finished diarization experiment')
    print(uisrnn.output_result(model_args, training_args, test_record))
    
    
    
    
def main():
    """Parse the uisrnn arguments, echo them, and run the diarization experiment."""
    parsed_args = uisrnn.parse_arguments()
    model_args, training_args, inference_args = parsed_args
    print(model_args, training_args, inference_args)
    diarization_experiment(model_args, training_args, inference_args)


# Only run the experiment when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()
    print('Program completed!')

## Tests

In [None]:
# Smoothed estimate of how often the speaker label changes between consecutive segments:
# bias = (smooth + number of label changes) / (2 * smooth + number of adjacent pairs).
smooth = 1
transit_num = smooth
bias_denominator = 2 * smooth
for cluster_id_seq in train_cluster_ids:
    # Pair every label with its successor and count the pairs that differ.
    transit_num += sum(
        current != following
        for current, following in zip(cluster_id_seq[:-1], cluster_id_seq[1:])
    )
    # A sequence of length n contributes n - 1 adjacent pairs (none when it is empty).
    bias_denominator += max(len(cluster_id_seq) - 1, 0)
bias = transit_num / bias_denominator

print(bias, bias_denominator)

In [None]:
# Maps each toy label to its 2-D centre: the four corners of the unit square.
label_to_center = {
    label: np.array(list(corner))
    for label, corner in (
        ('A', (0.0, 0.0)),
        ('B', (0.0, 1.0)),
        ('C', (1.0, 0.0)),
        ('D', (1.0, 1.0)),
    )
}

In [None]:
 # generate training data
# `random` is not imported by any earlier cell, so import it here to keep this
# scratch cell runnable on a fresh kernel.
import random

# 1000 labels in total (400 A, 300 B, 200 C, 100 D), shuffled in place.
train_cluster_id = ['A'] * 400 + ['B'] * 300 + ['C'] * 200 + ['D'] * 100
random.shuffle(train_cluster_id)
# NOTE(review): `_generate_random_sequence` is not defined in the cells shown here;
# it must be defined or imported before this cell runs.
train_sequence = _generate_random_sequence(
    train_cluster_id, label_to_center, sigma=0.01)
# Split the generated rows into four sequences of 100, 200, 300 and the remaining rows.
# NOTE(review): this rebinds `train_sequences`, which the first cell loaded from disk,
# while `train_cluster_ids` is left untouched. The validation cell below will then
# compare toy sequences with the real labels. Confirm this is intended.
train_sequences = [
    train_sequence[:100, :],
    train_sequence[100:300, :],
    train_sequence[300:600, :],
    train_sequence[600:, :]
]

In [None]:
# Report which container type `train_sequences` is; anything other than a list or a
# numpy array is rejected.
if isinstance(train_sequences, list):
    # A list of separate (un-concatenated) sequences.
    print('tá top pai')
elif isinstance(train_sequences, np.ndarray):
    # A single array, i.e. the sequences are already concatenated.
    print('np.array is bad')
else:
    raise TypeError('train_sequences must be a list or numpy.ndarray')

In [None]:
# Sanity-check the training data: both containers must be lists of equal length.
if not (isinstance(train_sequences, list) and isinstance(train_cluster_ids, list)):
    raise TypeError('train_sequences and train_cluster_ids must be lists')
if len(train_cluster_ids) != len(train_sequences):
    raise ValueError('train_sequences and train_cluster_ids must have same size')

# Label sequences stored as numpy arrays are converted to plain lists; others are kept.
train_cluster_ids = [
    x if not isinstance(x, np.ndarray) else x.tolist()
    for x in train_cluster_ids
]

# Observation dimension of the first sequence; every later sequence must match it.
global_observation_dim = None

for train_sequence, train_cluster_id in zip(train_sequences, train_cluster_ids):
    # Unpacking requires each sequence to be 2-D: (length, observation dimension).
    train_length, observation_dim = train_sequence.shape
    print(train_length,observation_dim)
    if global_observation_dim is None:
        global_observation_dim = observation_dim
    elif observation_dim != global_observation_dim:
        raise ValueError('train_sequences must have consistent observation dimension')
    if not isinstance(train_cluster_id, list):
        raise TypeError('Elements of train_cluster_ids must be list or numpy.ndarray')
    if train_length != len(train_cluster_id):
        raise ValueError('Each train_sequence and its train_cluster_id must have same length')