In [36]:
import os
import json
import uuid
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from torchvision.models import resnet18
from b2aiprep.process import Audio, specgram, plot_spectrogram, plot_waveform
import IPython.display as Ipd
#from torchsummary import summary

import numpy as np
from pydub import AudioSegment
import librosa

In [37]:
def determine_is_stridor_and_phonatory(name, folder_path, patient_status_df):
    """
    Map a recording folder to its ``(is_stridor, is_phonatory, status)`` labels.

    Resolution order:
      1. If any component of ``folder_path`` is named CONTROL or CONTROLS
         (case-insensitive), the recording is labelled ``(0, 0, 0)``.
      2. Otherwise, if ``name`` starts with "Patient ", the integer following the
         first space is looked up in ``patient_status_df['Code']`` and the
         'Status' of the first matching row is translated:
         "No-Stridor" -> (0, 0, 0), "Phonatory" -> (1, 1, 2),
         "Non-Phonatory" -> (1, 0, 1).
      3. Everything else (other folder names, unknown codes, unrecognised
         status strings) falls back to ``(0, 0, 0)``.
    """
    fallback = (0, 0, 0)  # is_stridor, is_phonatory, status
    labels_by_status = {
        "No-Stridor": (0, 0, 0),
        "Phonatory": (1, 1, 2),
        "Non-Phonatory": (1, 0, 1),
    }

    control_names = ["CONTROL", "CONTROLS"]
    if any(part.upper() in control_names for part in folder_path.split(os.sep)):
        return fallback

    if not name.startswith("Patient "):
        return fallback

    # The patient code is the token right after "Patient".
    patient_code = int(name.split(" ")[1])
    if patient_code not in patient_status_df['Code'].values:
        return fallback

    matching_rows = patient_status_df[patient_status_df['Code'] == patient_code]
    status = matching_rows['Status'].values[0]
    return labels_by_status.get(status, fallback)

def get_audio_duration(file_path):
    """
    Return the length, in seconds, of the WAV file at ``file_path``.
    """
    segment = AudioSegment.from_wav(file_path)
    duration_ms = len(segment)  # segment length is reported in milliseconds
    return duration_ms / 1000

def load_audio(file_path):
    """
    Decode ``file_path`` with ``librosa.load`` (default options) and return
    the ``(samples, sample_rate)`` pair it produces.
    """
    samples, sample_rate = librosa.load(file_path)
    return samples, sample_rate

def compute_short_time_energy(y, frame_size=1024, hop_length=512):
    """
    Return the sum of squared samples for each analysis frame of ``y``.

    Frames start every ``hop_length`` samples and span up to ``frame_size``
    samples; frames near the end of the signal are simply shorter (no padding).
    """
    frame_energies = []
    for start in range(0, len(y), hop_length):
        frame = y[start:start + frame_size]
        frame_energies.append(np.sum(np.square(frame)))
    return np.array(frame_energies)

def find_most_active_timeframe(y, sr, hop_length, total_duration, slot_duration):
    """
    Find the most energetic window of the signal and split it into equal slots.

    The signal is cut into non-overlapping frames of ``hop_length`` samples and
    the energy of each frame is computed. A window of ``total_duration`` seconds
    is slid over those frame energies; the window with the largest summed energy
    is the "active" region. That region is then divided into consecutive slots
    of ``slot_duration`` seconds.

    Parameters
    ----------
    y : sequence of samples (1-D).
    sr : sample rate of ``y`` in Hz.
    hop_length : frame length / stride, in samples, used for the energy curve.
    total_duration : length of the active window, in seconds.
    slot_duration : length of each slot, in seconds; must be positive.

    Returns
    -------
    list of dict
        One ``{"start_time": ..., "end_time": ...}`` entry (seconds from the
        start of ``y``) per full slot that fits in ``total_duration``. Empty
        when ``total_duration`` is not positive or the signal has no frames.

    Raises
    ------
    ValueError
        If ``slot_duration`` is not positive.
    """
    if slot_duration <= 0:
        raise ValueError("slot_duration must be positive")
    if total_duration <= 0:
        return []

    energy = compute_short_time_energy(y, frame_size=hop_length, hop_length=hop_length)

    # Window length in frames, clamped so the 'valid' convolution below always
    # slides the window over the energy curve (never the other way round).
    window_frames = min(int(total_duration * sr / hop_length), len(energy))
    if window_frames <= 0:
        return []

    # Summed energy of every window position; the arg-max is the window start.
    window_energy = np.convolve(energy, np.ones(window_frames), mode='valid')
    active_start_frame = np.argmax(window_energy)
    active_start_time = (active_start_frame * hop_length) / sr

    # Count slots from the durations themselves. Deriving the count from two
    # independently truncated frame counts can produce an extra slot for long
    # windows and divides by zero when a slot is shorter than one frame.
    # The epsilon absorbs float noise such as 0.3 / 0.1 == 2.999...
    num_slots = int(total_duration / slot_duration + 1e-9)

    slots = []
    for i in range(num_slots):
        slot_start_time = active_start_time + i * slot_duration
        slot_end_time = slot_start_time + slot_duration
        slots.append({"start_time": slot_start_time, "end_time": slot_end_time})

    return slots

def generate_recording_objects(root_path, patient_status_df,
                               excluded_folders=("Patient 1", "Patient 5", "Patient 10"),
                               slot_duration=3):
    """
    Walk ``root_path`` and build one metadata record per usable ``.wav`` file.

    For every directory that directly contains ``.wav`` files (and whose own
    name is not in ``excluded_folders``), each file is measured, the most
    active region is located and split into ``slot_duration``-second slots,
    and the folder-derived labels are attached.

    Parameters
    ----------
    root_path : directory to traverse recursively.
    patient_status_df : DataFrame passed through to
        ``determine_is_stridor_and_phonatory`` for the label lookup.
    excluded_folders : folder names whose recordings are skipped. Only the
        folder's own files are skipped; its sub-directories are still visited.
    slot_duration : slot length in seconds. Recordings shorter than one slot
        are skipped.

    Returns
    -------
    list of dict
        Records with keys ``uid``, ``name``, ``recording``, ``recording_label``,
        ``is_stridor``, ``is_phonatory``, ``status``, ``duration`` and
        ``recording_slots``. Directories and files are visited in sorted order
        so the list order does not depend on the filesystem; ``uid`` values are
        random (uuid4) and differ on every call.
    """
    excluded = set(excluded_folders)
    recordings = []

    for root, dirs, files in os.walk(root_path):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()

        parent_folder = os.path.basename(root)
        if parent_folder in excluded:
            continue

        wav_files = sorted(file for file in files if file.endswith('.wav'))
        if not wav_files:
            continue

        # Labels depend only on the folder, so resolve them once per folder.
        is_stridor, is_phonatory, status = determine_is_stridor_and_phonatory(
            parent_folder, root, patient_status_df
        )

        for wav_file in wav_files:
            file_path = os.path.join(root, wav_file)
            duration = get_audio_duration(file_path)

            # Number of whole slots that fit in the recording.
            num_slots = duration // slot_duration
            if num_slots == 0:
                continue  # Skip recordings that are too short
            total_duration = num_slots * slot_duration

            y, sr = load_audio(file_path)
            slots = find_most_active_timeframe(
                y, sr, hop_length=512,
                total_duration=total_duration, slot_duration=slot_duration,
            )

            recordings.append({
                "uid": str(uuid.uuid4()),
                "name": parent_folder,
                "recording": file_path,  # full path to the wav file
                "recording_label": os.path.splitext(wav_file)[0],
                "is_stridor": is_stridor,
                "is_phonatory": is_phonatory,
                "status": status,
                "duration": duration,
                "recording_slots": slots,
            })

    return recordings

def save_to_csv(recordings, output_csv):
    """
    Write ``recordings`` to ``output_csv`` with one row per recording slot.

    Every slot row repeats the recording's scalar fields (everything except
    'recording_slots') and adds the slot's own keys. Recordings with no slots
    contribute no rows. The DataFrame index is not written.
    """
    rows = []
    for record in recordings:
        shared_fields = {key: value for key, value in record.items() if key != 'recording_slots'}
        rows.extend({**shared_fields, **slot} for slot in record['recording_slots'])

    pd.DataFrame(rows).to_csv(output_csv, index=False)

def save_to_json(recordings, output_json):
    """Serialise ``recordings`` to the file ``output_json`` as JSON indented by four spaces."""
    with open(output_json, 'w') as handle:
        json.dump(recordings, handle, indent=4)

def count_samples(recordings):
    """
    Tally recordings by their 'status' code.

    Returns a tuple ``(no_stridor, phonatory, non_phonatory)`` counting the
    records whose status equals 0, 2 and 1 respectively. Records with any
    other status are ignored.
    """
    tally = {0: 0, 2: 0, 1: 0}

    for record in recordings:
        status = record['status']
        for code in tally:
            if status == code:
                tally[code] += 1
                break

    return tally[0], tally[2], tally[1]


def filter_recordings_by_label(data, label_prefix):
    """
    Return the recordings whose 'recording_label' begins with ``label_prefix``.

    When the prefix is exactly 'FIMO', the labels "FIMOcricoid", "FIMOthyroid",
    "FIMOC" and "FIMOT" are dropped even though they match the prefix.
    Input order is preserved.
    """
    excluded_by_prefix = {"FIMO": ["FIMOcricoid", "FIMOthyroid", "FIMOC", "FIMOT"]}

    def _is_wanted(recording):
        if not recording['recording_label'].startswith(label_prefix):
            return False
        is_excluded = (
            label_prefix == "FIMO"
            and recording['recording_label'] in excluded_by_prefix["FIMO"]
        )
        return not is_excluded

    return [recording for recording in data if _is_wanted(recording)]

In [38]:
# Usage: build the recording metadata from disk and report the class balance.
# NOTE(review): both paths below are absolute, machine-specific locations.
root_path = '/data/parsak/stridor/Data files'  # Replace with the actual path to your folders
excel_path = '/data/parsak/stridor/Stridor_Labels (1).xlsx'  # Replace with the actual path to your Excel file
output_csv = 'rec_5_time.csv'  # Desired output CSV file name
output_json = 'rec_5_time.json'  # Desired output JSON file name

# Read the Excel file; the label lookup reads its 'Code' and 'Status' columns.
patient_status_df = pd.read_excel(excel_path)

#print_once = True
# Walks root_path and decodes every .wav file it finds.
recordings = generate_recording_objects(root_path, patient_status_df)

# Uncomment the following lines to save to CSV
# save_to_csv(recordings, output_csv)
# print(f"Recording objects have been saved to {output_csv}")

# Uncomment the following lines to save to JSON
# save_to_json(recordings, output_json)
# print(f"Recording objects have been saved to {output_json}")

# Count the samples per status code (0 = no stridor, 2 = phonatory, 1 = non-phonatory)
no_stridor_count, phonatory_count, non_phonatory_count = count_samples(recordings)
print(f"No-Stridor samples: {no_stridor_count}")
print(f"Phonatory samples: {phonatory_count}")
print(f"Non-Phonatory samples: {non_phonatory_count}")

No-Stridor samples: 350
Phonatory samples: 89
Non-Phonatory samples: 267


In [39]:
# Filter the data to only include recordings with recording_label starting with "FIMO"
filtered_data = filter_recordings_by_label(recordings, "FIMO")

# NOTE(review): the DataFrame is built from the unfiltered `recordings`, so
# `filtered_data` above is not used by this cell — confirm which one is intended.
df = pd.DataFrame(recordings)

In [94]:
class MySpecgramDataset(torch.utils.data.Dataset):
    """
    Dataset that turns each recording listed in a DataFrame into a fixed-size
    log-spectrogram plus a one-hot label.

    Parameters
    ----------
    dataframe : DataFrame with at least the columns 'uid', 'recording'
        (path to the audio file), 'recording_label' and the column named by
        ``label``.
    label : name of the label column; must be 'is_stridor' for the two-class
        scheme and 'status' for the three-class scheme.
    classification : 'stridor**nonstridor' (2 classes) or
        'nonstridor**phonatory**nonphonatory' (3 classes).
    segment_size : clip length in seconds. Longer recordings are centre-cropped
        to this length, shorter ones are zero-padded at the end.

    Each item is a dict with keys 'uid', 'signal' (spectrogram resized to
    ``SPECTROGRAM_SIZE``), 'audio' (the cropped/padded waveform), the label
    column name (one-hot float tensor) and 'recording_label'.

    Raises
    ------
    ValueError
        If a required column is missing, ``classification`` is unknown, or
        ``label`` does not match ``classification``.
    """

    SAMPLE_RATE = 16000            # every clip is converted with Audio.to_16khz()
    SPECTROGRAM_SIZE = (512, 512)  # (height, width) the spectrogram is resized to

    # classification scheme -> (label column it must be paired with, number of classes)
    _SCHEMES = {
        'stridor**nonstridor': ('is_stridor', 2),
        'nonstridor**phonatory**nonphonatory': ('status', 3),
    }

    def __init__(self, dataframe, label, classification, segment_size=3):
        self.dataframe = dataframe
        self.label = label
        self.classification = classification
        self.segment_size = segment_size

        # Ensure the DataFrame has the necessary columns
        required_columns = ['uid', 'recording', 'recording_label', self.label]
        for col in required_columns:
            if col not in dataframe.columns:
                raise ValueError(f"DataFrame is missing required column: {col}")

        # Plain lists give cheap positional access in __getitem__.
        self.audio_files = dataframe['recording'].tolist()
        self.labels = dataframe[self.label].tolist()
        self.uids = dataframe['uid'].tolist()
        self.recording_labels = dataframe['recording_label'].tolist()

        if classification not in self._SCHEMES:
            raise ValueError("Error: type 'stridor**nonstridor' Or 'nonstridor**phonatory**nonphonatory' for classification")
        expected_label, self.num_classes = self._SCHEMES[classification]

        if label != expected_label:
            raise ValueError("label and classification are NOT match")

    def __len__(self):
        return len(self.dataframe)

    def _load_features(self, idx):
        """Load recording ``idx`` and return ``(audio, spectrogram)``; may raise on bad files."""
        audio = Audio.from_file(self.audio_files[idx])
        audio = audio.to_16khz()

        # Centre-crop long clips, zero-pad short ones (time is dimension 0).
        target_len = self.segment_size * self.SAMPLE_RATE
        n_samples = audio.signal.size(0)
        if n_samples > target_len:
            start = (n_samples - target_len) // 2
            audio.signal = audio.signal[start:start + target_len]
        else:
            # Pad spec (0, 0, 0, k): nothing on the last dim, k zeros appended on dim 0.
            audio.signal = F.pad(audio.signal, (0, 0, 0, target_len - n_samples))

        # presumably milliseconds: n_fft = win_length * 16 is the window length
        # in samples at 16 kHz — confirm against b2aiprep.process.specgram
        win_length = 60
        hop_length = 1
        nfft = int(win_length * 16)
        features_specgram = specgram(audio, win_length=win_length, hop_length=hop_length, n_fft=nfft)

        # Convert to dB with a 1e-10 floor to avoid log(0), then transpose.
        features_specgram = 10.0 * torch.log10(torch.clamp(features_specgram, min=1e-10)).T

        # interpolate expects (batch, channel, H, W); add and remove the two leading dims.
        features_specgram = F.interpolate(
            features_specgram.unsqueeze(0).unsqueeze(0),
            size=self.SPECTROGRAM_SIZE, mode='bilinear', align_corners=False,
        ).squeeze(0).squeeze(0)

        return audio, features_specgram

    def __getitem__(self, idx):
        """
        Return the item at ``idx``.

        If the recording cannot be loaded, the error is printed and the next
        recording (wrapping around) is returned instead, with that recording's
        own uid and label. Raises ``IndexError`` for out-of-range indices and
        ``RuntimeError`` if no recording in the dataset can be loaded.
        """
        n = len(self)
        if idx >= n or idx < -n:
            raise IndexError(f"index {idx} is out of range for dataset of size {n}")

        # Bounded fallback: try each recording at most once instead of recursing forever.
        for offset in range(n):
            candidate = (idx + offset) % n
            try:
                audio, features_specgram = self._load_features(candidate)
            except Exception as e:
                print(f"Error loading {self.audio_files[candidate]}: {e}")
                continue

            # One-hot encode the integer class label.
            label = torch.tensor(self.labels[candidate], dtype=torch.long)
            label = F.one_hot(label, num_classes=self.num_classes).float()

            return {
                'uid': self.uids[candidate],
                'signal': features_specgram,
                'audio': audio.signal,
                self.label: label,
                'recording_label': self.recording_labels[candidate],
            }

        raise RuntimeError(f"None of the {n} recordings in the dataset could be loaded")


In [95]:

# Build the two-class dataset over every row of `df`, using 5-second clips.
dataset = MySpecgramDataset(df, label = 'is_stridor',  classification = 'stridor**nonstridor', segment_size = 5) # classification:  'stridor**nonstridor' Or 'nonstridor**phonatory**nonphonatory' 
# `label` is read as a global by eval() and get_predictions_and_labels();
# it must name the same column that was passed to the dataset above.
label = 'is_stridor' # 'is_stridor' or 'status'
# Must agree with the classification scheme above (2 for stridor/non-stridor).
number_classes = 2
#classes = ['no-stridor', 'non-phonatory', 'phonatory']
# Display names used for the confusion matrix / classification report.
classes = ['no-stridor', 'stridor']


print(f"dataset: {dataset}")

# Split the dataset into train, validation, and test sets (80% / 10% / remainder)
# NOTE(review): random_split is called without a `generator`, so the split
# changes between runs unless the global torch RNG is seeded elsewhere.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# Create DataLoaders for batching and shuffling (only the training set is shuffled)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

# Example: Iterate through the train DataLoader
# Only the first batch is fetched (the loop breaks immediately); the
# inspection code inside is commented out.
for batch in train_dataloader:
    
    #print(batch['recording_label'], batch[label])
    #print('Label is: ', batch[label][0])
    #for i in range(1,5):
    #    plot_spectrogram(batch['signal'][i])
    #    plot_waveform(batch['audio'][i], 16000, title="Waveform", ax=None)
    #print("the shape is: ", batch['signal'][0].shape)
    #print("the size is: ", batch['signal'].shape)
    break

dataset: <__main__.MySpecgramDataset object at 0x14b7a5924550>


In [96]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [97]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """
    Pre-normalised MLP block.

    Applies, in order: LayerNorm(dim) -> Linear(dim, hidden_dim) -> GELU ->
    Dropout -> Linear(hidden_dim, dim) -> Dropout. The last dimension of the
    output therefore has the same size (``dim``) as the input's.

    Args:
        dim: size of the input/output feature dimension.
        hidden_dim: size of the intermediate layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Kept as a single Sequential: the layer positions determine the
        # parameter names (net.0, net.1, ...) stored in checkpoints.
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Run ``x`` through the MLP; no residual connection is added here."""
        return self.net(x)

class Attention(nn.Module):
    """
    Pre-normalised multi-head self-attention.

    Args:
        dim: feature size of the input tokens.
        heads: number of attention heads.
        dim_head: feature size of each head; queries, keys and values each have
            ``heads * dim_head`` features in total.
        dropout: dropout probability applied to the attention weights and,
            when an output projection is used, to its result.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        # With a single head whose size already equals `dim`, the output
        # projection is skipped (replaced by Identity below).
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # Scaling factor 1 / sqrt(dim_head) applied to the dot products.
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free linear layer produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """
        Attend over the token axis of ``x``.

        ``x`` is treated as (batch, tokens, dim), as fixed by the
        'b n (h d)' rearrange pattern below. The residual connection is added
        by the caller, not here.
        """
        x = self.norm(x)

        # Split the joint projection into q, k, v and move heads to their own axis.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # Scaled dot-product similarity between every pair of tokens, per head.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge the heads back into one feature axis.
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """
    Stack of ``depth`` encoder layers, each an Attention block followed by a
    FeedForward block, with a final LayerNorm.

    Args:
        dim: token feature size.
        depth: number of (Attention, FeedForward) layer pairs.
        heads: attention heads per layer.
        dim_head: feature size of each attention head.
        mlp_dim: hidden size of each FeedForward block.
        dropout: dropout probability passed to both sub-blocks.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        """Apply every layer with residual connections, then the final norm."""
        for attn, ff in self.layers:
            # Residual connections around each sub-block.
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class ViT(nn.Module):
    """
    Vision Transformer classifier.

    The input image is cut into non-overlapping patches, each patch is
    flattened and linearly embedded, a learnable class token is prepended,
    learnable position embeddings are added, and the sequence is passed
    through a Transformer. The pooled representation feeds a linear head.

    Keyword Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch; must divide the
            image size exactly.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of transformer layers.
        heads: attention heads per layer.
        mlp_dim: hidden size of the feed-forward blocks.
        pool: 'cls' to classify from the class token, 'mean' to average all tokens.
        channels: number of input channels (defaults to 1).
        dim_head: feature size of each attention head.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability applied to the embedded tokens.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # Number of values in one flattened patch.
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One position embedding per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        """
        Return class logits for ``img``.

        ``img`` must be laid out as (batch, channels, height, width), as
        required by the patch-embedding rearrange pattern.
        """
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Pool: average over all tokens, or take the class token (index 0).
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [98]:
# ViT over the 512x512 single-channel spectrograms produced by MySpecgramDataset:
# 64x64 patches give an 8x8 grid, i.e. 64 patch tokens plus the class token.
model = ViT(
    image_size = 512,
    patch_size = 64,
    num_classes = number_classes,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)


In [None]:
import torch

def eval(model, dataloader, device, label_key=None):
    """
    Compute classification accuracy of ``model`` over ``dataloader``.

    Note: this function shadows Python's built-in ``eval`` inside the notebook.

    Parameters
    ----------
    model : module mapping a (batch, 1, H, W) tensor to per-class logits.
    dataloader : iterable of batches; each batch is a mapping with a 'signal'
        tensor of shape (batch, H, W) and one-hot labels under ``label_key``.
    device : device the batch tensors are moved to.
    label_key : key of the one-hot labels in each batch. When omitted, the
        notebook-level global ``label`` is used, as before.

    Returns
    -------
    float
        Fraction of samples whose arg-max prediction equals the arg-max of the
        one-hot label; ``0.0`` when the dataloader yields no samples.
    """
    if label_key is None:
        label_key = label  # fall back to the notebook-level global

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            signal = batch['signal'].to(device)
            labels = batch[label_key].to(device)

            # Insert the channel axis the model expects: (B, H, W) -> (B, 1, H, W).
            outputs = model(signal.view(-1, 1, signal.size(1), signal.size(2)))
            predicted = torch.argmax(outputs, dim=1)
            true_labels = torch.argmax(labels, dim=1)  # one-hot -> class index

            correct += (predicted == true_labels).sum().item()
            total += labels.size(0)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move model to GPU if available

num_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-5)

# Best validation accuracy seen so far; used to decide when to checkpoint.
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    closs = []  # per-batch training losses for this epoch
    for batch in train_dataloader:
        optimizer.zero_grad()

        signal = batch['signal'].to(device)  # Move batch to GPU
        labels = batch[label].to(device)  # Move labels to GPU
        # Add the channel axis the model expects: (B, H, W) -> (B, 1, H, W)
        outputs = model(signal.view(-1, 1, signal.size(1), signal.size(2)))

        # Labels arrive one-hot encoded; recover the class indices that
        # cross_entropy is given as targets here.
        _, true_labels = torch.max(labels, 1)
        loss = torch.nn.functional.cross_entropy(outputs, true_labels)
        
        closs.append(loss.item())
        loss.backward()
        optimizer.step()

    # Accuracy is re-measured in eval mode over the full train and validation sets.
    train_acc = eval(model, train_dataloader, device)
    val_acc = eval(model, val_dataloader, device)
    print(f'Epoch: {epoch+1} TrainLoss: {sum(closs)/len(closs):.4f} TrainACC: {train_acc:.4f} ValACC: {val_acc:.4f}')

    # `>=` means a later epoch that only ties the best validation accuracy
    # still overwrites the saved checkpoint.
    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), './model.pth')
        print('Saved!')


Epoch: 1 TrainLoss: 0.7987 TrainACC: 0.6436 ValACC: 0.6571
Saved!
Epoch: 2 TrainLoss: 0.6683 TrainACC: 0.5975 ValACC: 0.6143
Epoch: 3 TrainLoss: 0.6969 TrainACC: 0.7234 ValACC: 0.6429
Epoch: 4 TrainLoss: 0.4541 TrainACC: 0.8670 ValACC: 0.6286
Epoch: 5 TrainLoss: 0.2505 TrainACC: 0.9450 ValACC: 0.6714
Saved!
Epoch: 6 TrainLoss: 0.1925 TrainACC: 0.9965 ValACC: 0.6714
Saved!
Epoch: 7 TrainLoss: 0.1243 TrainACC: 0.9025 ValACC: 0.6143
Epoch: 8 TrainLoss: 0.2200 TrainACC: 0.9255 ValACC: 0.6143
Epoch: 9 TrainLoss: 0.0661 TrainACC: 0.9982 ValACC: 0.7286
Saved!
Epoch: 10 TrainLoss: 0.0048 TrainACC: 1.0000 ValACC: 0.6714
Epoch: 11 TrainLoss: 0.0013 TrainACC: 1.0000 ValACC: 0.6571
Epoch: 12 TrainLoss: 0.0007 TrainACC: 1.0000 ValACC: 0.6571
Epoch: 13 TrainLoss: 0.0006 TrainACC: 1.0000 ValACC: 0.6571
Epoch: 14 TrainLoss: 0.0004 TrainACC: 1.0000 ValACC: 0.6571
Epoch: 15 TrainLoss: 0.0003 TrainACC: 1.0000 ValACC: 0.6571
Epoch: 16 TrainLoss: 0.0002 TrainACC: 1.0000 ValACC: 0.6571


In [None]:
1
2
3
4
5
6
7
8

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import torch

# Same device selection as in the earlier cells, repeated so this cell can run on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_predictions_and_labels(model, dataloader, device, label_key=None):
    """
    Collect predicted and true class indices for every sample in ``dataloader``.

    Parameters
    ----------
    model : module mapping a (batch, 1, H, W) tensor to per-class logits; it
        is moved to ``device`` and put in eval mode.
    dataloader : iterable of batches; each batch is a mapping with a 'signal'
        tensor of shape (batch, H, W) and one-hot labels under ``label_key``.
    device : device used for inference.
    label_key : key of the one-hot labels in each batch. When omitted, the
        notebook-level global ``label`` is used, as before.

    Returns
    -------
    (list, list)
        Predicted class indices and true class indices, in dataloader order.
    """
    if label_key is None:
        label_key = label  # fall back to the notebook-level global

    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            signal = batch['signal'].to(device)
            # Add the channel axis the model expects; the tensor is already on `device`.
            outputs = model(signal.view(-1, 1, signal.size(1), signal.size(2)))
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            # Convert one-hot encoded labels to class indices
            labels = torch.argmax(batch[label_key], dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels)
    return all_preds, all_labels

# Assuming test_dataloader is defined and model is trained
# NOTE(review): this loads './ViT_2_5s.pth', while the training cell saves its
# checkpoint to './model.pth' — confirm which weights are meant to be evaluated.
model.load_state_dict(torch.load('./ViT_2_5s.pth'))

# Get predictions and true labels for the test set
preds, labels = get_predictions_and_labels(model, test_dataloader, device)

# Compute confusion matrix and classification report
# (rows of the matrix are true classes, columns are predictions)
conf_matrix = confusion_matrix(labels, preds)
class_report = classification_report(labels, preds, target_names=classes)

# Plot the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

print('Classification Report:\n', class_report)


In [None]:
# NOTE(review): unfinished cell — the `print(` call on the last line is never
# closed, so this cell does not parse as written.
preds, labels = get_predictions_and_labels(model, val_dataloader, device)
print(

In [None]:

# Scratch check of F.one_hot: seven class indices encoded against three classes.
F.one_hot(torch.tensor([0,1,2,0,1,1,2], dtype=torch.long), num_classes=3)