Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add audio slice prediction function #26

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions basic_pitch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ANNOTATIONS_BASE_FREQUENCY = 27.5 # lowest key on a piano
ANNOTATIONS_N_SEMITONES = 88 # number of piano keys
AUDIO_SAMPLE_RATE = 22050
AUDIO_SLICE_TIME = 20 # seconds of every audio slice
AUDIO_N_CHANNELS = 1
N_FREQ_BINS_NOTES = ANNOTATIONS_N_SEMITONES * NOTES_BINS_PER_SEMITONE
N_FREQ_BINS_CONTOURS = ANNOTATIONS_N_SEMITONES * CONTOURS_BINS_PER_SEMITONE
Expand Down
112 changes: 94 additions & 18 deletions basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AUDIO_SAMPLE_RATE,
AUDIO_N_SAMPLES,
ANNOTATIONS_FPS,
AUDIO_SLICE_TIME,
FFT_HOP,
)
from basic_pitch import ICASSP_2022_MODEL_PATH, note_creation as infer
Expand Down Expand Up @@ -70,11 +71,70 @@ def window_audio_file(audio_original: Tensor, hop_size: int) -> Tuple[Tensor, Li
return audio_windowed, window_times


def split_unwrapped_data(unwrapped_list: List[Dict[str, np.array]]) -> Dict[str, np.array]:
"""
Merge the split model inference results and return the complete result.

Returns:
A dictionary with the notes, onsets and contours.

"""
resDict = unwrapped_list[0]

if len(unwrapped_list) < 2:
return resDict

for k in resDict.keys():
for i in range(1, len(unwrapped_list)):
tempDict = unwrapped_list[i]
resDict[k] = np.append(resDict[k], tempDict[k], axis=0)

return resDict


def slice_audio(audio_original: np.array) -> List[np.array]:
"""
Cut audio Array by AUDIO_SLICE_TIME (default 5 sec) * AUDIO_SAMPLE_RATE and return slice list

Returns:
resList: list of slice
audio slice list.

"""
resList = []
original_length = audio_original.shape[0]

sliceLen = AUDIO_SAMPLE_RATE * AUDIO_SLICE_TIME
partNums = int(np.ceil(original_length / sliceLen))

for i in range(partNums):
sliceEnd = sliceLen * i + sliceLen
if sliceEnd > original_length:
sliceEnd = original_length
tempSlice = audio_original[sliceLen * i : sliceEnd]
resList.append(tempSlice)

return resList


def read_audio(audio_path: Union[pathlib.Path, str]) -> np.array:
"""
Read wave file (as mono) and return audio signal

Returns:
audio_original: np.array
original audio signal.

"""
audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
return audio_original


def get_audio_input(
audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
audio_original: np.array, overlap_len: int, hop_size: int
) -> Tuple[Tensor, List[Dict[str, int]], int]:
"""
Read wave file (as mono), pad appropriately, and return as
padding appropriately of audio signal, and return as
windowed signal, with window length = AUDIO_N_SAMPLES

Returns:
Expand All @@ -87,7 +147,7 @@ def get_audio_input(
"""
assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)

audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
# audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)

original_length = audio_original.shape[0]
audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
Expand Down Expand Up @@ -139,23 +199,39 @@ def run_inference(
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len

audio_windowed, _, audio_original_length = get_audio_input(audio_path, overlap_len, hop_size)
# slice audio
audio_original = read_audio(audio_path)
audio_slice_list = slice_audio(audio_original)

output = model(audio_windowed)
unwrapped_output = {k: unwrap_output(output[k], audio_original_length, n_overlapping_frames) for k in output}
unwrapped_list = []

if debug_file:
with open(debug_file, "w") as f:
json.dump(
{
"audio_windowed": audio_windowed.numpy().tolist(),
"audio_original_length": audio_original_length,
"hop_size_samples": hop_size,
"overlap_length_samples": overlap_len,
"unwrapped_output": {k: v.tolist() for k, v in unwrapped_output.items()},
},
f,
)
for i in range(len(audio_slice_list)):
audio_original_slice = audio_slice_list[i]

audio_windowed_slice, _, audio_original_length_slice = get_audio_input(
audio_original_slice, overlap_len, hop_size
)
output = model(audio_windowed_slice)
unwrapped_output_slice = {
k: unwrap_output(output[k], audio_original_length_slice, n_overlapping_frames) for k in output
}
unwrapped_list.append(unwrapped_output_slice)

if debug_file:
with open(debug_file, "a") as f:
json.dump(
{
"slice_ID": i,
"audio_windowed_slice": audio_windowed_slice.numpy().tolist(),
"audio_original_length_slice": audio_original_length_slice,
"hop_size_samples": hop_size,
"overlap_length_samples": overlap_len,
"unwrapped_output_slice": {k: v.tolist() for k, v in unwrapped_output_slice.items()},
},
f,
)
# merge all
unwrapped_output = split_unwrapped_data(unwrapped_list)

return unwrapped_output

Expand Down