## CV_TasNet - Custom Training Parts

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import glob
import time
import datetime
import numpy as np
%config Completer.use_jedi = False

import tensorflow as tf
import tensorflow.keras as keras

# Workaround for an UnknownError / cuDNN-related error seen at start-up:
# enable memory growth on every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized;
        # if it is too late, report the error and continue.
        print(e)

In [None]:
import sys
import random
import soundfile as sf
sys.getsizeof

In [None]:
# Number of examples per batch; read by generate_ds and used as the default
# BATCH_SIZE argument of calculate_sisnr.
BATCH_SIZE = 8
# Number of speakers (source channels); default num_spks of calculate_sisnr.
num_spks = 2

In [None]:
def random_int(lst, size):
    """Split off the first ``size`` items of ``lst``.

    Used to draw indices without replacement from a pool (e.g. a list of
    50800 or 3000 numbers): the drawn items are removed from the pool so the
    next call cannot return them again. No randomness happens here -- the
    function only slices, so the pool is presumably shuffled by the caller
    (verify against the call site).

    Args:
        lst: Sliceable pool of items.
        size: Number of items to take from the front.

    Returns:
        A ``(drawn, remaining)`` pair: the first ``size`` items and the rest.
    """
    head = lst[:size]
    tail = lst[size:]
    return head, tail

In [None]:
def custom_normalize(tensor):
    """Standardize ``tensor`` along axis 1 (zero mean, unit standard deviation).

    Args:
        tensor: Expected shape (BATCH_SIZE, timesteps, num_spks); the
            statistics are computed over axis 1 independently for every
            remaining index.

    Returns:
        Tensor of the same shape as ``tensor``. A slice whose standard
        deviation is zero is divided by zero (no epsilon is applied).
    """
    mean = tf.math.reduce_mean(tensor, axis=1, keepdims=True)
    std = tf.math.reduce_std(tensor, axis=1, keepdims=True)
    return (tensor - mean) / std

In [None]:
def calculate_sisnr(predicted_batch, target_batch, BATCH_SIZE=BATCH_SIZE, num_spks=num_spks):
    """Compute a batch-wide SI-SNR-style score (in dB) for separated sources.

    Both inputs are standardized along the time axis with ``custom_normalize``.
    Every target channel is then scaled by its dot product with the matching
    predicted channel and divided by that target channel's L2 norm. The score
    is ``10 * log10(||scaled_target|| / ||predicted - target||)``, where both
    norms are taken over the whole tensor (all batch items, time steps and
    speakers at once).

    The computation is fully vectorized, so it works for any batch size and
    any number of speakers; the earlier implementation looped over
    ``BATCH_SIZE`` and hard-coded two speakers.

    Args:
        predicted_batch: Estimated sources, expected shape
            (batch, timesteps, speakers).
        target_batch: Reference sources, same shape as ``predicted_batch``.
        BATCH_SIZE: Unused; kept for backward compatibility. The batch size
            is taken from the tensors themselves.
        num_spks: Unused; kept for backward compatibility. The speaker count
            is taken from the tensors themselves.

    Returns:
        Scalar tensor holding the score in dB.
    """
    N_predicted_batch = custom_normalize(predicted_batch)  # (BS, length, spks)
    N_target_batch = custom_normalize(target_batch)  # (BS, length, spks)

    # Per-item, per-speaker inner product <prediction, target> over time.
    # This is the diagonal of the (spks x spks) cross-product matrix.
    dot_products = tf.reduce_sum(
        N_predicted_batch * N_target_batch, axis=1, keepdims=True
    )  # (BS, 1, spks)

    # Per-item, per-speaker L2 norm of the target over time. No zero check:
    # segments are 4 seconds long, so a completely silent channel is not
    # expected (the check mattered when segments were 16/8000 s long).
    target_norms = tf.norm(N_target_batch, axis=1, keepdims=True)  # (BS, 1, spks)

    # Target scaled by the inner product, then divided by the target norm.
    s_target = N_target_batch * dot_products / target_norms  # (BS, length, spks)
    e_noise = N_predicted_batch - N_target_batch

    # NOTE(review): this appears to differ from the usual SI-SNR definition
    # (projection divided by the squared norm, residual taken against the
    # scaled target, power ratio per utterance). The formula is kept as it
    # was so that loss values stay comparable -- confirm before changing.
    si_snr = 10 * tf.experimental.numpy.log10(tf.norm(s_target) / tf.norm(e_noise))
    return si_snr

In [None]:
# Here we only pick out the input / target1 / target2 file lists for the drawn indices.
def get_file_lst(drawn_numbers, input_files, target_files1, target_files2):
    """Select the entries at ``drawn_numbers`` from three parallel file lists.

    Args:
        drawn_numbers: Indices to pick, in the order they should appear.
        input_files: Indexable collection of input (mixture) files.
        target_files1: Indexable collection of first-target files.
        target_files2: Indexable collection of second-target files.

    Returns:
        A 3-tuple of lists ``(inputs, targets1, targets2)``, each holding the
        items found at ``drawn_numbers`` in the corresponding collection.
    """
    def pick(files):
        # Same index sequence is applied to every collection.
        return [files[i] for i in drawn_numbers]

    return pick(input_files), pick(target_files1), pick(target_files2)

In [None]:
def generate_ds(input_segments, target_segments1, target_segments2):
    """Build a shuffled, batched ``tf.data`` pipeline from three aligned arrays.

    Args:
        input_segments: Input segments; sliced along the first axis.
        target_segments1: First-target segments, aligned with the inputs.
        target_segments2: Second-target segments, aligned with the inputs.

    Returns:
        A ``tf.data.Dataset`` yielding ``(input, target1, target2)`` batches of
        ``BATCH_SIZE`` examples, each with a trailing axis of size 1 appended.
        Incomplete final batches are dropped.
    """
    general_ds = tf.data.Dataset.from_tensor_slices((input_segments, target_segments1, target_segments2))
    # Shuffle individual examples BEFORE batching. Shuffling after batching
    # only reorders whole batches, so the same examples would always end up
    # in the same batch in every epoch.
    general_ds = general_ds.shuffle(50000).batch(BATCH_SIZE, drop_remainder=True)
    # Append a channel axis to each of the three tensors.
    general_ds = general_ds.map(lambda x, y, z: (tf.expand_dims(x, axis=-1), tf.expand_dims(y, axis=-1), tf.expand_dims(z, axis=-1)))
    general_ds = general_ds.prefetch(tf.data.AUTOTUNE)
    return general_ds

In [None]:
# # test
# BATCH_SIZE = 64
# num_spks = 2
# target = tf.random.normal((BATCH_SIZE, 32000, 2))
# pred = tf.random.normal((BATCH_SIZE, 32000, 2))
# calculate_sisnr(pred, target)