Get Started (detailed)

In [None]:
import torch
import librosa
import numpy as np

from modules import feature_extractor, DetectionNet, BreathDetector

In [None]:
# model loading
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DetectionNet().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load("respiro-en.pt", map_location=device)
model.load_state_dict(checkpoint["model"])
# Switch to inference mode (disables dropout / batch-norm updates).
model.eval()

In [None]:
wav_path = "samples/train-clean-100_19_198_000010_000003.wav"
#wav_path = "samples/train-clean-360_14_208_000001_000000.wav"
#wav_path = "samples/train-other-500_20_205_000002_000002.wav"
wav, sr = librosa.load(wav_path, sr=16000)
feature, length = feature_extractor(wav)
feature, length = feature.to(device), length.to(device)
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    output = model(feature, length)

# 0.064 is the threshold obtained from our validation set
# You can try more strict thresholds like 0.5 or 0.9
threshold = 0.064

# min_length: length threshold to avoid too short detected breath, which tends to be the end part of speech
# default: 20 (compared against the number of consecutive detected indices)
# NOTE(review): the playback cell converts indices to seconds with a factor of 0.01,
# i.e. 10 ms per index, so 20 indices would be ~200 ms rather than 20 ms — confirm
# against feature_extractor's hop size.
min_length = 20

# Indices whose score exceeds the threshold.
# squeeze(-1) (instead of a bare squeeze()) keeps a 1-D result even when only one
# index is found, so `prediction` is always a list.
# Assumes output[0] is a 1-D sequence of per-frame scores — TODO confirm.
prediction = (output[0] > threshold).nonzero().squeeze(-1).tolist()

# `splits` is always defined, so later cells never hit a NameError when nothing is detected.
splits = []
if prediction:
    frames = np.asarray(prediction)
    # A new segment starts wherever two neighbouring indices are not consecutive.
    boundaries = np.where(np.diff(frames) != 1)[0] + 1
    # Keep only the runs that are strictly longer than min_length.
    splits = [segment for segment in np.split(frames, boundaries) if len(segment) > min_length]

# The segments of breath are printed (including the case of a single segment).
# Each printed value is a frame index.
# NOTE(review): with the 0.01 s-per-index conversion used for playback, index 229
# would correspond to ~2.29 s, not 229 ms — confirm.
for split in splits:
    print(split)

In [None]:
from IPython.display import Audio

# Which detected segment to play back (0-based index into `splits`).
segment_index = 1
# Seconds represented by one index unit when mapping segment indices onto the waveform.
frame_seconds = 0.01

# Fail with an explicit message when fewer segments were detected than requested.
if len(splits) <= segment_index:
    raise IndexError(
        f"segment {segment_index} requested but only {len(splits)} breath segment(s) were detected"
    )

start = splits[segment_index][0]
end = splits[segment_index][-1]

# Convert the index range to sample offsets and play that slice of the waveform.
Audio(data=wav[int((start*frame_seconds)*sr):int((end*frame_seconds)*sr)], rate=sr)

Get Started (quick)

In [None]:
import torch
from modules import DetectionNet, BreathDetector

# model loading
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DetectionNet().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load("respiro-en.pt", map_location=device)
model.load_state_dict(checkpoint["model"])
model.eval()

detector = BreathDetector(model) # Args: model, device=None

# Same sample file as in the detailed walkthrough, which loads it from the samples/ directory.
tree = detector("samples/train-clean-100_19_198_000010_000003.wav") # Args: wav_path, threshold=0.064, min_length=20
print(tree)

In [None]:
# Slice the detection result between 2.6 and 5.2
# (presumably seconds — confirm against BreathDetector's output).
window = slice(2.6, 5.2)
print(tree[window])

In [None]:
# Show the detector's entries in sorted order.
ordered_entries = sorted(tree)
print(ordered_entries)