In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import librosa
import librosa.display
from IPython.display import Audio
from scipy.signal import find_peaks
from scipy.signal import butter, sosfilt, sosfreqz
import math
import os

In [3]:
sample_rate = 48_000  # Hz; used for the mel spectrogram and for converting frame indices back to seconds

In [4]:
# Extracts timestamps from a .txt file
# Each line in the file contains a start and finish time separated by a tab.

def get_real_timestamps(audio_timestamp_file):
    
    timestamps = []

    with open(audio_timestamp_file, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                start, end = map(float, line.split("\t"))
                timestamps.append((start, end))

    return timestamps

In [5]:
import os
import librosa

# Function to load audio data along with corresponding timestamps and labels.
def load_data(audio_dir, timestamp_dir, label, sr=48000):
    
    """
    Loads audio data from WAV files along with corresponding timestamps and labels.

    Parameters:
        audio_dir (str): Directory containing audio WAV files.
        timestamp_dir (str): Directory containing timestamp text files.
        label (str): Label to assign to the loaded data.
        sr (int): Sample rate of the audio files (default is 48000).

    Returns:
        tuple: A tuple containing lists of loaded audio data, timestamps, and labels.
    """
    
    audio_files = [file for file in os.listdir(audio_dir) if file.endswith('.wav')]
    timestamp_files = [file for file in os.listdir(timestamp_dir) if file.endswith('.txt')]
    
    
    audios = []
    timestamps = []
    labels = []
    
    for audio_file in audio_files:
    
        try:
            
            base_name = os.path.splitext(audio_file)[0]
            timestamp_file = base_name + '-label.txt'
            
            if timestamp_file not in timestamp_files:
                raise ValueError(f"No corresponding timestamp file found for {audio_file}")
                
            audio_path = os.path.join(audio_dir, audio_file)
            timestamp_path = os.path.join(timestamp_dir, timestamp_file)
            
            # Adding timestamps to the list
            real_timestamps = get_real_timestamps(timestamp_path)

            # Loading audio file
            data, sample_rate = librosa.load(audio_path, sr=sr, mono=True)
            data = librosa.resample(data, orig_sr=sample_rate, target_sr=sr)

            audios.append(data)
            timestamps.append(real_timestamps)
            labels.append(label)

        except Exception as e:
            print(f"Error processing {audio_file}: {e}")
            continue

    
    return audios, timestamps, labels

In [6]:
# This function designs a Butterworth bandpass filter and applies it to the input data.

def butter_bandpass(lowcut, highcut, fs, order=8):

    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    sos = butter(order, [low, high], analog=False, btype='band', output='sos')
    return sos

def butter_bandpass_filter(data, lowcut, highcut, fs, order=8):

    sos = butter_bandpass(lowcut, highcut, fs, order=order)
    y = sosfilt(sos, data)
    return y

In [7]:
# Function to normalize the given data

def normalize_data(data):

    # Find the minimum and maximum values in the data
    data_min = np.min(data)
    data_max = np.max(data)

    # Normalize the data using min-max scaling
    normalized_data = (data - data_min) / (data_max - data_min)

    return normalized_data

In [8]:
# Moving Average of the Data

def compute_moving_average(data, window_size=15):

    kernel = np.ones(window_size) / window_size
    moving_averages = np.convolve(data, kernel, mode='valid')
    moving_averages = np.round(moving_averages, 2)
    
    return moving_averages

In [13]:
# Preprocesses the input data by applying a bandpass filter, computing moving average, and normalizing the data.
def preprocess_data(data):
    
    """
    Parameters:
    - data (ndarray): Input data to be preprocessed.
        
    Returns:
    - ndarray: Preprocessed data.

    Note:
    - This function applies a bandpass filter, computes the moving average, and normalizes the data.
    """

    # Apply bandpass filter to remove unwanted frequencies
    filtered_data = butter_bandpass_filter(data, lowcut=1000, highcut=4000, fs=48000, order=8)

    # Compute moving average of the absolute values of the filtered data
    averaged_data = compute_moving_average(np.abs(filtered_data))

    # Normalize the data
    normalized_data = normalize_data(averaged_data)

    # Flatten the data
    flattened_data = normalized_data.flatten()

    return flattened_data

In [14]:
# Sum of the frequency bins of a spectrogram, one value per time frame

def get_frequency_sums(data):
    
    sums = []
    S_db_tp = np.transpose(mel_spectrogram_db)

    for bin in S_db_tp:
        sums.append(np.sum(bin))

    sums = np.array(sums)

    return sums

In [15]:
def detect_coughs(data, percentile_threshold = 99.8):
    
    # Statistics of the data
    max_value = np.max(data)
    mean_value = np.mean(data)
    std = np.std(data)
    data_length = len(data)
    
    # Threshold 
    threshold = np.percentile(data, percentile_threshold)
    
    # Peak detection
    cough_indices, _ = find_peaks(data, prominence = 5, height=threshold)
    cough_indices = list(cough_indices)

    # Deleting overlaps in the peaks - Avoiding to count same cough more than one
    i = 0
    while i < len(cough_indices):

        peak = cough_indices[i]
        peak_range = (peak - 50, peak + 50) 
        overlap_indices = [index for index in cough_indices
                       if peak_range[0] < index < peak_range[1]]

        if len(overlap_indices) > 1:

            # Find the index with maximum amplitude 
            max = overlap_indices[0]
            for index in overlap_indices:
                if (data[index] > data[max]):
                    max = index

            overlap_indices.remove(max)

            for element in overlap_indices:
                cough_indices.remove(element)

        i += 1

    # Finding the timestamps of the coughs
    predicted_timestamps = [] 
    for index in cough_indices:
        predicted_timestamps.append(round(index / sample_rate, 6))
    

    # Scaling timestamps according to original recording duration
    for i in range(len(predicted_timestamps)):
        
        actual_index = total_samples * cough_indices[i] / data_length
        actual_timestamp = round(actual_index / sample_rate, 6)
        
        predicted_timestamps[i] = actual_timestamp
    

    # Filtering the sound after coughing
    for ts in predicted_timestamps:
        match = list((ts_2 for ts_2 in predicted_timestamps if ts < ts_2 < ts + 0.2))
        if len(match) != 0:
            index = predicted_timestamps.index(match[0])
            predicted_timestamps.remove(match[0])
            cough_indices.remove(cough_indices[index])

    return cough_indices, predicted_timestamps

In [18]:
audio_dir = '../recordings-and-timestamps/audio-recordings/cough-recordings/'
timestamp_dir = '../recordings-and-timestamps/audio-timestamps/cough-timestamps/'

audio_files, timestamp_files,_ = load_data(audio_dir, timestamp_dir, 'cough')

all_timestamp_data = dict()

# Getting cough detection results
i = 1

for data, real_timestamps in zip(audio_files, timestamp_files):
    
    preprocessed_data = preprocess_data(data)
    total_samples = len(data)
        
    # Getting melspectrogram
    mel_spectrogram = librosa.feature.melspectrogram(y=preprocessed_data, sr=sample_rate)
    mel_spectrogram_db = librosa.amplitude_to_db(mel_spectrogram, ref=np.max)    

    frequency_sums = get_frequency_sums(mel_spectrogram_db)
    
    # Cough detection
    cough_indices, predicted_timestamps = detect_coughs(frequency_sums)

    # Adding predicted and real timestamp tuples to the list
    all_timestamp_data[i]=(predicted_timestamps, real_timestamps)

    cough_count = len(cough_indices)

    # The results
    print(f"\nAudio {i}")
    print("---------------------")
    print(f"Cough Count: {cough_count}")
    print("Cough Predicted Timestamps: {}".format([round(timestamp, 6) for timestamp in predicted_timestamps]))
    print("Cough Real Timestamps: {}".format([timestamp for timestamp in real_timestamps]))
    print(f"Cough Indices: {cough_indices}\n")
    
    i += 1


Audio 1
---------------------
Cough Count: 5
Cough Predicted Timestamps: [3.605262, 9.706475, 17.642318, 25.172836, 29.972741]
Cough Real Timestamps: [(3.33773, 3.927718), (9.447733, 9.969193), (12.818076, 13.29121), (17.372347, 17.900496), (25.149178, 25.678278), (29.956115, 30.395156), (40.311107, 40.808312), (49.672679, 50.157047)]
Cough Indices: [338, 910, 1654, 2360, 2810]


Audio 2
---------------------
Cough Count: 9
Cough Predicted Timestamps: [3.765297, 9.503908, 14.154529, 18.581153, 28.639722, 37.642302, 43.039583, 47.380874, 52.810154]
Cough Real Timestamps: [(7.6774, 8.272404), (25.127484, 25.722488), (47.316563, 47.820028), (3.752714, 3.839789), (9.459266, 9.819177), (14.123593, 14.384819), (18.552817, 18.78792), (28.57517, 28.862518), (37.593255, 37.845773), (39.212853, 39.523421), (43.026746, 43.290874), (52.732725, 53.263884)]
Cough Indices: [353, 891, 1327, 1742, 2685, 3529, 4035, 4442, 4951]


Audio 3
---------------------
Cough Count: 2
Cough Predicted Timestamps: 


Audio 23
---------------------
Cough Count: 5
Cough Predicted Timestamps: [8.521292, 31.856194, 38.777744, 40.601449, 47.618984]
Cough Real Timestamps: [(1.216182, 1.584015), (1.642094, 1.990567), (2.100845, 2.440502), (4.454706, 4.894158), (8.483467, 8.911278), (12.592782, 12.959478), (13.166107, 13.471687), (13.602649, 13.943152), (26.621427, 27.024116), (27.268827, 27.581686), (27.779933, 28.086596), (31.356926, 31.811758), (31.829227, 32.18609), (47.58344, 47.939619), (48.053061, 48.356716), (48.488027, 48.794965)]
Cough Indices: [799, 2987, 3636, 3807, 4465]


Audio 24
---------------------
Cough Count: 2
Cough Predicted Timestamps: [25.28653, 44.366075]
Cough Real Timestamps: [(1.179008, 2.023963), (25.216251, 25.871442), (44.311634, 44.958205)]
Cough Indices: [2371, 4160]


Audio 25
---------------------
Cough Count: 3
Cough Predicted Timestamps: [0.693323, 5.439921, 14.837117]
Cough Real Timestamps: [(0.555423, 1.243091), (2.08945, 2.644874), (5.355869, 5.898068), (14.798068, 


Audio 47
---------------------
Cough Count: 4
Cough Predicted Timestamps: [6.612866, 12.21247, 29.939218, 40.050503]
Cough Real Timestamps: [(1.381414, 1.916155), (6.497422, 6.942235), (7.096208, 7.583792), (12.172479, 12.788373), (21.025968, 21.504997), (29.879457, 30.392703), (39.956182, 40.486536), (50.734343, 51.3759)]
Cough Indices: [620, 1145, 2807, 3755]


Audio 48
---------------------
Cough Count: 2
Cough Predicted Timestamps: [13.056, 40.778667]
Cough Real Timestamps: [(13.006581, 13.401668), (40.708352, 41.0746)]
Cough Indices: [1224, 3823]


Audio 49
---------------------
Cough Count: 4
Cough Predicted Timestamps: [1.397328, 4.085318, 7.925303, 30.538551]
Cough Real Timestamps: [(1.359574, 1.767104), (2.006828, 2.376687), (4.044477, 4.291049), (4.479403, 4.73625), (7.852743, 8.202204), (30.480484, 30.843291), (42.53321, 43.004387)]
Cough Indices: [131, 383, 743, 2863]


Audio 50
---------------------
Cough Count: 5
Cough Predicted Timestamps: [12.148794, 33.171859, 37.8649

In [23]:
# Comparing the real timestamps and predicted timestamps.

def check_performance(predicted_timestamps, real_timestamps):

    # Margin of error allowed in matching timestamps, in seconds
    time_margin = 0.40

    # Display the lists of predicted and actual cough timestamps
    #print(predicted_timestamps, "\n")
    #print(real_timestamps, "\n")

    # Initialize counters for true positives and false positives
    true_positive = 0 
    false_positive = 0

    # Determine the longer list between predicted and actual cough timestamps

    if len(predicted_timestamps) > len(real_timestamps):
        for pred in predicted_timestamps:

            # Check for matches between the predicted timestamp and actual cough timestamps within a margin of error
            match = list((rt for rt in real_timestamps if pred - time_margin < rt[0] < pred + time_margin or pred - time_margin < rt[1] < pred + time_margin))

            if len(match) != 0:
                true_positive += 1
            else:
                false_positive += 1

    else:

        for pred in real_timestamps:

            # Check for matches between the actual timestamp and predicted cough timestamps within a margin of error
            match = list((rt for rt in predicted_timestamps  if pred[0] - time_margin < rt < pred[0] + time_margin or pred[1] - time_margin < rt < pred[1] + time_margin))

            if len(match) != 0:
                true_positive += 1
            else:
                false_positive += 1

    return true_positive, false_positive

In [24]:
# General Performance of the Model
performances = []
true_positives = []
false_positives = []

for key in sorted(all_timestamp_data.keys()):
        
    data = all_timestamp_data.get(key)
    
    predicted_timestamps = data[0]
    real_timestamps = data[1]
    
    true_positive, false_positive = check_performance(predicted_timestamps, real_timestamps)
    
    true_positives.append(true_positive)
    false_positives.append(false_positive)
    
    precision = true_positive / (true_positive + false_positive)
    # Recall is assumed as 1 since the model does not predict the absence of coughs
    recall = 1

    f1_score = round(2 * (precision * recall) / (precision + recall), 3)
    print("Audio {}: {}\n".format(key,f1_score))

    
model_performance = np.sum(true_positives) / np.sum(true_positives + false_positives)
print("\nModel Performance: ", model_performance)

Audio 1: 0.769

Audio 2: 0.857

Audio 3: 0.667

Audio 4: 1.0

Audio 5: 1.0

Audio 6: 0.8

Audio 7: 0.4

Audio 8: 0.7

Audio 9: 0.222

Audio 10: 0.25

Audio 11: 0.824

Audio 12: 0.4

Audio 13: 0.667

Audio 14: 0.857

Audio 15: 1.0

Audio 16: 1.0

Audio 17: 0.857

Audio 18: 0.5

Audio 19: 0.519

Audio 20: 1.0

Audio 21: 0.4

Audio 22: 1.0

Audio 23: 0.4

Audio 24: 0.8

Audio 25: 0.857

Audio 26: 1.0

Audio 27: 0.444

Audio 28: 0.667

Audio 29: 0.696

Audio 30: 0.8

Audio 31: 0.857

Audio 32: 1.0

Audio 33: 0.6

Audio 34: 0.667

Audio 35: 0.5

Audio 36: 0.857

Audio 37: 0.522

Audio 38: 0.8

Audio 39: 0.489

Audio 40: 0.727

Audio 41: 0.4

Audio 42: 0.667

Audio 43: 0.444

Audio 44: 0.571

Audio 45: 1.0

Audio 46: 0.778

Audio 47: 0.667

Audio 48: 1.0

Audio 49: 0.833

Audio 50: 0.385

Audio 51: 0.75


Model Performance:  0.45727482678983833
