In [None]:
import os

import jax.numpy as jnp
import librosa
import numpy as np
import torch
from transformers import WhisperProcessor, WhisperModel

# Segmentation settings, in samples: at 16 kHz, 480000 samples = 30 s per segment,
# and consecutive segments share 30% of their samples.
segment_length1 = 480000
overlap_rate1 = 0.3
sampling_rate = 16000

# All input directories live under the fold-5 split. Note: the "*_test" variables
# point at this fold's validation ("val") folders.
fold_root = "/home/sichengyu/Downloads/dementiabank/Pitt_new/norm/train_test_split/train/train_split/folds/fold_5"
audio_path_full_enhance_cc_train = os.path.join(fold_root, "train", "cc_enhence")
audio_path_full_enhance_cd_train = os.path.join(fold_root, "train", "cd_enhence")
audio_path_full_enhance_cc_test = os.path.join(fold_root, "val", "cc_enhence")
audio_path_full_enhance_cd_test = os.path.join(fold_root, "val", "cd_enhence")

# One independent (non-shared) empty list per raw / normalized feature bucket.
feature_dic = {
    bucket: []
    for bucket in (
        'list_full_cc_train',
        'list_full_cd_train',
        'list_norm_cc_train',
        'list_norm_cd_train',
        'list_full_cc_test',
        'list_full_cd_test',
        'list_norm_cc_test',
        'list_norm_cd_test',
    )
}

# Pretrained Whisper processor + model, used for inference only.
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperModel.from_pretrained(model_name, output_hidden_states=True)
model.eval()

def get_segment_features(audio_file, audio_file_index, segment_length=segment_length1, sampling_rate=16000, overlap_rate=overlap_rate1, layer_index=12):
    """
    Extract Whisper encoder features for overlapping segments of one audio file.

    The waveform is loaded at ``sampling_rate`` and cut into windows of
    ``segment_length`` samples. Window starts are
    ``int(segment_length * (1 - overlap_rate))`` samples apart. Each window is
    run through the module-level Whisper ``processor`` and ``model.encoder``.
    The hidden state at position ``layer_index`` of ``hidden_states`` is kept.

    Parameters:
        audio_file (str): Audio file path.
        audio_file_index (int): Index of the audio file; stored in every returned
            segment index tuple.
        segment_length (int): Length of each segment, in samples.
        sampling_rate (int): Sampling rate used when loading the audio.
        overlap_rate (float): Fraction of a segment shared with the next segment.
        layer_index (int): Position in ``encoder_outputs.hidden_states`` to keep.
            Presumably index 0 is the embedding output, so 12 would be the output
            of the 12th encoder layer — confirm for the model in use.

    Returns:
        tuple: (list of per-segment feature tensors with the leading batch
        dimension squeezed out, list of ``(audio_file_index, segment_number)``
        tuples)

    Raises:
        ValueError: If ``segment_length`` and ``overlap_rate`` give a
            non-positive step size.
    """
    y, _sr = librosa.load(audio_file, sr=sampling_rate)
    step_size = int(segment_length * (1 - overlap_rate))
    if step_size <= 0:
        raise ValueError(
            f"segment_length={segment_length} and overlap_rate={overlap_rate} "
            f"give a non-positive step size ({step_size})."
        )

    # Count the full hops. Add a trailing partial segment only if it reaches
    # (roughly) past the end of the previous window, i.e. it adds new audio.
    total_segments = len(y) // step_size + (1 if len(y) % step_size > overlap_rate * segment_length else 0)
    # With audio shorter than overlap_rate * segment_length there is no previous
    # window, and the formula above yields 0. The file would then silently drop
    # out of the dataset, so keep a single (short) first segment instead.
    if total_segments == 0 and len(y) > 0:
        total_segments = 1

    segment_features_list = []
    segment_indices = []

    for i in range(total_segments):
        start = i * step_size
        end = start + segment_length
        segment = y[start:end]

        # A short trailing segment (never the first) is topped up with the audio
        # that ends where the previous segment ended.
        if len(segment) < segment_length and i > 0:
            padding_needed = segment_length - len(segment)
            previous_segment_end = start + segment_length - step_size
            padding_start = max(0, previous_segment_end - padding_needed)
            segment = np.concatenate((segment, y[padding_start:previous_segment_end]))

        inputs = processor(segment, sampling_rate=sampling_rate, return_tensors="pt")
        with torch.no_grad():
            # Only the encoder is needed; ask it for every hidden state.
            encoder_outputs = model.encoder(input_features=inputs.input_features, output_hidden_states=True)
        hidden_states = encoder_outputs.hidden_states

        if len(hidden_states) <= layer_index:
            print(f"Encoder returned only {len(hidden_states)} hidden states, cannot extract hidden state index {layer_index}.")
            continue

        # Drop the batch dimension (batch size is 1 here).
        segment_features_list.append(hidden_states[layer_index].squeeze(0))
        segment_indices.append((audio_file_index, i))

    return segment_features_list, segment_indices

def get_features_from_directory(audio_path, segment_length=segment_length1, sampling_rate=16000, overlap_rate=overlap_rate1):
    """
    Extract segment features from every ``.wav`` file in a directory.

    Files are processed in sorted filename order. The file index recorded with
    each segment therefore stays the same across runs and machines;
    ``os.listdir`` itself returns entries in arbitrary order.

    Parameters:
        audio_path (str): Audio file directory path.
        segment_length (int): Length of each segment, in samples.
        sampling_rate (int): Sampling rate used when loading audio.
        overlap_rate (float): Fraction of a segment shared with the next one.

    Returns:
        tuple: (list of all segment features, list of
        ``(file_index, segment_number)`` tuples), concatenated over all files.
    """
    audio_files = [
        os.path.join(audio_path, file_name)
        for file_name in sorted(os.listdir(audio_path))
        if file_name.endswith('.wav')
    ]

    features_list = []
    segments_indices = []

    for audio_file_index, audio_file in enumerate(audio_files):
        print(f"Processing file: {audio_file} (Index: {audio_file_index})")
        segment_features, segment_indices = get_segment_features(
            audio_file,
            audio_file_index,
            segment_length=segment_length,
            sampling_rate=sampling_rate,
            overlap_rate=overlap_rate
        )
        features_list.extend(segment_features)
        segments_indices.extend(segment_indices)

    return features_list, segments_indices

def normalize_features(features_list):
    """
    Min-max scale each feature tensor independently into roughly [0, 1].

    Each tensor is converted to NumPy, shifted by its global minimum, and divided
    by its global range plus 1e-8. The epsilon makes a constant tensor map to
    zeros instead of NaN.

    Parameters:
        features_list (list of torch.Tensor): CPU tensors to scale.

    Returns:
        list of torch.Tensor: Scaled tensors, in the same order as the input.
    """
    def _min_max_scale(tensor):
        values = tensor.numpy()
        lowest = values.min()
        highest = values.max()
        return torch.from_numpy((values - lowest) / (highest - lowest + 1e-8))

    return [_min_max_scale(tensor) for tensor in features_list]


In [None]:
# Extract raw Whisper features for each (group, split) directory, in the order
# cc/train, cd/train, cc/test, cd/test ("test" = this fold's validation folders).
extraction_jobs = [
    ('cc', 'train', audio_path_full_enhance_cc_train),
    ('cd', 'train', audio_path_full_enhance_cd_train),
    ('cc', 'test', audio_path_full_enhance_cc_test),
    ('cd', 'test', audio_path_full_enhance_cd_test),
]

extracted_indices = {}
for job_group, job_split, job_directory in extraction_jobs:
    job_features, extracted_indices[(job_group, job_split)] = get_features_from_directory(
        job_directory,
        segment_length=segment_length1,
        sampling_rate=sampling_rate,
        overlap_rate=overlap_rate1,
    )
    feature_dic[f'list_full_{job_group}_{job_split}'] = job_features

indices_cc_train = extracted_indices[('cc', 'train')]
indices_cd_train = extracted_indices[('cd', 'train')]
indices_cc_test = extracted_indices[('cc', 'test')]
indices_cd_test = extracted_indices[('cd', 'test')]

# Min-max normalized copies of every raw feature list, in the same order.
for job_group, job_split, _job_directory in extraction_jobs:
    feature_dic[f'list_norm_{job_group}_{job_split}'] = normalize_features(
        feature_dic[f'list_full_{job_group}_{job_split}']
    )


In [None]:
# Merge the per-class training segment indices into one (N, 2) array.
# Column 0 is the file index and column 1 the segment number within that file.
# cd file indices are shifted past the largest cc file index, so every file keeps
# a unique id after concatenation.
# reshape(-1, 2) keeps the array 2-D even when a class produced no segments.
indices_cc_train = np.asarray(indices_cc_train, dtype=np.int64).reshape(-1, 2)
indices_cd_train = np.asarray(indices_cd_train, dtype=np.int64).reshape(-1, 2)

# .max() on an empty column raises, so with no cc segments use -1 (offset 0).
max_index_cc_train = indices_cc_train[:, 0].max() if len(indices_cc_train) else -1

indices_cd_train_adjusted = indices_cd_train.copy()
indices_cd_train_adjusted[:, 0] += max_index_cc_train + 1

indices_train = np.vstack((indices_cc_train, indices_cd_train_adjusted))
indices_train = jnp.array(indices_train)

In [None]:
# Merge the per-class validation ("test") segment indices into one (N, 2) array.
# Column 0 is the file index and column 1 the segment number within that file.
# cd file indices are shifted past the largest cc file index, so every file keeps
# a unique id after concatenation.
# reshape(-1, 2) keeps the array 2-D even when a class produced no segments.
indices_cc_test = np.asarray(indices_cc_test, dtype=np.int64).reshape(-1, 2)
indices_cd_test = np.asarray(indices_cd_test, dtype=np.int64).reshape(-1, 2)

# .max() on an empty column raises, so with no cc segments use -1 (offset 0).
max_index_cc_test = indices_cc_test[:, 0].max() if len(indices_cc_test) else -1

indices_cd_test_adjusted = indices_cd_test.copy()
indices_cd_test_adjusted[:, 0] += max_index_cc_test + 1

indices_test = np.vstack((indices_cc_test, indices_cd_test_adjusted))
indices_test = jnp.array(indices_test)


In [None]:
# Stack the per-segment feature tensors of each (group, split) bucket.
features_cc, features_cd, features_cc_test, features_cd_test = (
    torch.stack(feature_dic[bucket], dim=0)
    for bucket in (
        'list_full_cc_train',
        'list_full_cd_train',
        'list_full_cc_test',
        'list_full_cd_test',
    )
)

# features1 = training split, features2 = validation ("test") split;
# cc segments come first, then cd segments.
features1 = torch.cat((features_cc, features_cd), dim=0)
features2 = torch.cat((features_cc_test, features_cd_test), dim=0)

# Labels line up with the concatenation above: cc -> 0, cd -> 1.
labels_cc, labels_cd, labels_cc_test, labels_cd_test = (
    label_factory(len(feature_dic[bucket]))
    for label_factory, bucket in (
        (torch.zeros, 'list_full_cc_train'),
        (torch.ones, 'list_full_cd_train'),
        (torch.zeros, 'list_full_cc_test'),
        (torch.ones, 'list_full_cd_test'),
    )
)
labels1 = torch.cat((labels_cc, labels_cd), dim=0)
labels2 = torch.cat((labels_cc_test, labels_cd_test), dim=0)

# Views with the last two axes swapped.
features1_s = features1.permute(0, 2, 1)
features2_s = features2.permute(0, 2, 1)
print('features1.shape:', features1.shape)

In [None]:
# Output location for this fold's feature, label and index tensors.
base_path = '/home/sichengyu/text/NCDE'
feature_folder = 'feature_tensor/Pittnew/CV5'
save_path = os.path.join(base_path, feature_folder)
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing (no-op if it already exists).
os.makedirs(save_path, exist_ok=True)

file_name = 'feature_Pitt_0.3_train1_whisper_30_new.pt'
file_name2 = 'feature_Pitt_0.3_val_whisper_30_new.pt'
file_name3 = 'labels1_Pitt_0.3_train1_whisper_30_new.pt'
file_name4 = 'labels2_Pitt_0.3_val_whisper_30_new.pt'
file_name5 = 'indices_train1_Pitt_0.3_whisper_30_new.pt'
file_name6 = 'indices_val_Pitt_0.3_whisper_30_new.pt'

full_path = os.path.join(save_path, file_name)
full_path2 = os.path.join(save_path, file_name2)
full_path3 = os.path.join(save_path, file_name3)
full_path4 = os.path.join(save_path, file_name4)
full_path5 = os.path.join(save_path, file_name5)
full_path6 = os.path.join(save_path, file_name6)

# Save every artifact and report each destination, not only the first one.
objects_to_save = (
    (features1, full_path),
    (features2, full_path2),
    (labels1, full_path3),
    (labels2, full_path4),
    (indices_train, full_path5),
    (indices_test, full_path6),
)
for saved_object, target_path in objects_to_save:
    torch.save(saved_object, target_path)
    print(f"Tensor saved to {target_path}")
