diff --git a/basic_pitch/constants.py b/basic_pitch/constants.py
index a78a487b..aede4ae1 100644
--- a/basic_pitch/constants.py
+++ b/basic_pitch/constants.py
@@ -27,6 +27,7 @@
 ANNOTATIONS_BASE_FREQUENCY = 27.5  # lowest key on a piano
 ANNOTATIONS_N_SEMITONES = 88  # number of piano keys
 AUDIO_SAMPLE_RATE = 22050
+AUDIO_SLICE_TIME = 20  # length of each audio slice, in seconds
 AUDIO_N_CHANNELS = 1
 N_FREQ_BINS_NOTES = ANNOTATIONS_N_SEMITONES * NOTES_BINS_PER_SEMITONE
 N_FREQ_BINS_CONTOURS = ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE
diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py
index 1393ff37..3a807752 100644
--- a/basic_pitch/inference.py
+++ b/basic_pitch/inference.py
@@ -31,6 +31,7 @@
     AUDIO_SAMPLE_RATE,
     AUDIO_N_SAMPLES,
     ANNOTATIONS_FPS,
+    AUDIO_SLICE_TIME,
     FFT_HOP,
 )
 from basic_pitch import ICASSP_2022_MODEL_PATH, note_creation as infer
@@ -70,11 +71,70 @@ def window_audio_file(audio_original: Tensor, hop_size: int) -> Tuple[Tensor, Li
     return audio_windowed, window_times
 
 
+def split_unwrapped_data(unwrapped_list: List[Dict[str, np.array]]) -> Dict[str, np.array]:
+    """
+    Merge the per-slice model inference results and return the complete result.
+
+    Returns:
+        A dictionary with the notes, onsets and contours.
+
+    """
+    resDict = unwrapped_list[0]
+
+    if len(unwrapped_list) < 2:
+        return resDict
+
+    for k in resDict.keys():
+        for i in range(1, len(unwrapped_list)):
+            tempDict = unwrapped_list[i]
+            resDict[k] = np.append(resDict[k], tempDict[k], axis=0)
+
+    return resDict
+
+
+def slice_audio(audio_original: np.array) -> List[np.array]:
+    """
+    Cut the audio array into slices of AUDIO_SLICE_TIME (default 20 s) * AUDIO_SAMPLE_RATE samples and return them.
+
+    Returns:
+        resList: List[np.array]
+            list of audio slices.
+
+    """
+    resList = []
+    original_length = audio_original.shape[0]
+
+    sliceLen = AUDIO_SAMPLE_RATE * AUDIO_SLICE_TIME
+    partNums = int(np.ceil(original_length / sliceLen))
+
+    for i in range(partNums):
+        sliceEnd = sliceLen * i + sliceLen
+        if sliceEnd > original_length:
+            sliceEnd = original_length
+        tempSlice = audio_original[sliceLen * i : sliceEnd]
+        resList.append(tempSlice)
+
+    return resList
+
+
+def read_audio(audio_path: Union[pathlib.Path, str]) -> np.array:
+    """
+    Read an audio file (as mono) and return the audio signal.
+
+    Returns:
+        audio_original: np.array
+            original audio signal.
+
+    """
+    audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+    return audio_original
+
+
 def get_audio_input(
-    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
+    audio_original: np.array, overlap_len: int, hop_size: int
 ) -> Tuple[Tensor, List[Dict[str, int]], int]:
     """
-    Read wave file (as mono), pad appropriately, and return as
+    Pad the audio signal appropriately, and return it as a
     windowed signal, with window length = AUDIO_N_SAMPLES
 
     Returns:
@@ -87,7 +147,7 @@ def get_audio_input(
     """
     assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)
 
-    audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+    # audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
 
     original_length = audio_original.shape[0]
     audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
@@ -139,23 +199,39 @@ def run_inference(
     overlap_len = n_overlapping_frames * FFT_HOP
     hop_size = AUDIO_N_SAMPLES - overlap_len
 
-    audio_windowed, _, audio_original_length = get_audio_input(audio_path, overlap_len, hop_size)
+    # slice the audio into AUDIO_SLICE_TIME-second pieces
+    audio_original = read_audio(audio_path)
+    audio_slice_list = slice_audio(audio_original)
 
-    output = model(audio_windowed)
-    unwrapped_output = {k: unwrap_output(output[k], audio_original_length, n_overlapping_frames) for k in output}
+    unwrapped_list = []
 
-    if debug_file:
-        with open(debug_file, "w") as f:
-            json.dump(
-                {
-                    "audio_windowed": audio_windowed.numpy().tolist(),
-                    "audio_original_length": audio_original_length,
-                    "hop_size_samples": hop_size,
-                    "overlap_length_samples": overlap_len,
-                    "unwrapped_output": {k: v.tolist() for k, v in unwrapped_output.items()},
-                },
-                f,
-            )
+    for i in range(len(audio_slice_list)):
+        audio_original_slice = audio_slice_list[i]
+
+        audio_windowed_slice, _, audio_original_length_slice = get_audio_input(
+            audio_original_slice, overlap_len, hop_size
+        )
+        output = model(audio_windowed_slice)
+        unwrapped_output_slice = {
+            k: unwrap_output(output[k], audio_original_length_slice, n_overlapping_frames) for k in output
+        }
+        unwrapped_list.append(unwrapped_output_slice)
+
+        if debug_file:
+            with open(debug_file, "a") as f:
+                json.dump(
+                    {
+                        "slice_ID": i,
+                        "audio_windowed_slice": audio_windowed_slice.numpy().tolist(),
+                        "audio_original_length_slice": audio_original_length_slice,
+                        "hop_size_samples": hop_size,
+                        "overlap_length_samples": overlap_len,
+                        "unwrapped_output_slice": {k: v.tolist() for k, v in unwrapped_output_slice.items()},
+                    },
+                    f,
+                )
+    # merge the per-slice results back into one output
+    unwrapped_output = split_unwrapped_data(unwrapped_list)
 
     return unwrapped_output