In [33]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
import re
from collections import defaultdict
from pathlib import Path

class MultiFeatureDataset(Dataset):
    """Dataset pairing per-sample ``.npy`` feature files across feature sets and types.

    Each feature set directory is expected to contain one sub-directory per
    feature type, holding files named ``<sample_id>_<feature_type>.npy`` where
    ``<sample_id>`` looks like ``subj-<digits>_<anything>.wav``. Only sample
    ids that are present for *every* (feature set, feature type) combination
    are kept.

    Every sample is exposed as a 1-D feature vector built by concatenating the
    per-file features (see ``_summarize_feature``), together with the subject
    id (the digits following ``subj-``) as a string.
    """

    # ``<sample_id>_<feature_type>.npy``. The suffix class includes ``-`` so
    # that hyphenated feature types such as ``f0-4096`` are recognised.
    _SAMPLE_FILE_RE = re.compile(r'^(subj-\d+_.+\.wav)_([\w-]+)\.npy$')
    # Extracts the numeric subject id from the beginning of a sample id.
    _SUBJECT_ID_RE = re.compile(r'subj-(\d+)_')

    def __init__(self,
                 feature_dirs: Dict[str, str],
                 feature_types: List[str],
                 transform=None):
        """
        Args:
            feature_dirs (Dict[str, str]): Dictionary mapping feature set names to their directories.
                                           e.g., {
                                               'BoundaryTone': 'path/to/BoundaryTone-features',
                                               'EarlyLate': 'path/to/EarlyLate-features',
                                               'PictureNaming': 'path/to/PictureNaming-features'
                                           }
            feature_types (List[str]): List of feature types (subfolder names) to include.
                                       e.g., ['energy', 'f0', 'f0-4096', 'jitter', 'rp', 'shimmer']
            transform (callable, optional): Optional transform applied to the combined NumPy
                                            feature vector of a sample. When omitted, the vector
                                            is converted to a ``torch.float32`` tensor instead.
        """
        self.feature_dirs = feature_dirs
        self.feature_types = feature_types
        self.transform = transform
        self.samples = self._gather_samples()

    def _gather_samples(self) -> List[Dict]:
        """
        Scan the feature directories and build the list of usable samples.

        Returns:
            A list of dictionaries, one per sample id found in every feature
            set and feature type. Each dictionary holds ``'subject_id'`` plus
            one ``'<feature_set>_<feature_type>'`` key per feature file path.
            Samples are ordered by sample id so the result is reproducible
            across runs.
        """
        # Mapping: feature_set -> feature_type -> sample_id -> file_path
        feature_files = defaultdict(lambda: defaultdict(dict))

        for feature_set, dir_path in self.feature_dirs.items():
            print(f"Processing feature set: {feature_set}")
            for feature_type in self.feature_types:
                feature_type_dir = Path(dir_path) / feature_type
                if not feature_type_dir.is_dir():
                    print(f"  Warning: Feature type directory not found: {feature_type_dir}")
                    continue
                files = list(feature_type_dir.glob('*.npy'))
                print(f"  Feature type '{feature_type}' has {len(files)} files.")
                for file_path in files:
                    filename = file_path.name
                    match = self._SAMPLE_FILE_RE.match(filename)
                    if match is None:
                        print(f"    Warning: Filename does not match pattern and will be skipped: {filename}")
                        continue
                    sample_id, feature_suffix = match.groups()
                    # A file whose suffix disagrees with its directory is ignored.
                    if feature_suffix != feature_type:
                        print(f"    Warning: Feature suffix '{feature_suffix}' does not match feature type '{feature_type}' in file '{filename}'. Skipping.")
                        continue
                    feature_files[feature_set][feature_type][sample_id] = str(file_path)

        # Keep only the sample ids present in every feature set / feature type.
        print("\nGathering common samples across all feature sets and feature types...")
        id_sets = [
            set(feature_files[feature_set][feature_type])
            for feature_set in self.feature_dirs
            for feature_type in self.feature_types
        ]
        all_sample_ids = set.intersection(*id_sets) if id_sets else set()

        if not all_sample_ids:
            print("No common samples found across all feature sets and feature types.")
            return []

        print(f"Total common samples found: {len(all_sample_ids)}\n")

        samples = []
        # Sorted iteration: plain set order of strings varies between runs.
        for sample_id in sorted(all_sample_ids):
            subject_id_match = self._SUBJECT_ID_RE.match(sample_id)
            if subject_id_match is None:
                print(f"    Warning: Could not extract subject_id from sample_id '{sample_id}'. Skipping.")
                continue
            sample_entry = {'subject_id': subject_id_match.group(1)}
            # Every id in the intersection has a file for each combination,
            # so direct indexing cannot fail here.
            for feature_set in self.feature_dirs:
                for feature_type in self.feature_types:
                    key = f"{feature_set}_{feature_type}"
                    sample_entry[key] = feature_files[feature_set][feature_type][sample_id]
            samples.append(sample_entry)

        print(f"Total valid samples after checking all feature sets and feature types: {len(samples)}")
        return samples

    @staticmethod
    def _summarize_feature(feature: np.ndarray) -> np.ndarray:
        """Reduce one loaded feature array to the form used for concatenation.

        * 0-d arrays are wrapped into a length-1 array.
        * Arrays of shape ``(1,)`` are returned unchanged.
        * Anything else is replaced by its mean, standard deviation and median
          taken along axis 0, stacked in that order.
        """
        if feature.shape == ():
            return np.array([feature])
        if feature.shape != (1,):
            return np.array([np.mean(feature, axis=0),
                             np.std(feature, axis=0),
                             np.median(feature, axis=0)])
        return feature

    def _load_feature_parts(self, sample_info: Dict) -> List[np.ndarray]:
        """Load and summarise every feature file of a sample, ordered by key name."""
        return [
            self._summarize_feature(np.load(sample_info[key]))
            for key in sorted(sample_info)
            if key != 'subject_id'
        ]

    def __len__(self):
        """Return the number of samples gathered at construction time."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(features, subject_id)`` for the sample at ``idx``.

        ``features`` is the output of ``self.transform`` when a transform was
        given, otherwise a ``torch.float32`` tensor. If the per-file features
        cannot be concatenated, the error is printed and an empty feature
        vector is used instead.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample_info = self.samples[idx]
        features = self._load_feature_parts(sample_info)

        # Concatenate all features into a single feature vector.
        # This assumes every summarised feature is a 1-D array.
        try:
            combined_features = np.concatenate(features)
        except (ValueError, TypeError) as e:
            print(f"Error concatenating features for sample index {idx}: {e}")
            combined_features = np.array([])

        subject_id = sample_info['subject_id']

        if self.transform:
            combined_features = self.transform(combined_features)
        else:
            combined_features = torch.tensor(combined_features, dtype=torch.float32)

        return combined_features, subject_id

    def get_numpy_data(self) -> Tuple[np.ndarray, List[str]]:
        """
        Returns all data as a NumPy array and a list of subject IDs.
        Suitable for scikit-learn.

        Samples whose features cannot be concatenated are reported and left
        out of both return values. ``self.transform`` is not applied here.
        """
        all_features = []
        subject_ids = []
        for idx, sample in enumerate(self.samples):
            features = self._load_feature_parts(sample)
            try:
                combined_features = np.concatenate(features)
            except (ValueError, TypeError) as e:
                print(f"Error concatenating features for sample index {idx}: {e}")
                continue
            all_features.append(combined_features)
            subject_ids.append(sample['subject_id'])
        return np.array(all_features), subject_ids



In [36]:


# Demo cell: build a dataset from a single feature set, pull one batch through
# a DataLoader, then fetch the whole dataset as NumPy arrays.
if __name__ == "__main__":
    # Define the feature directory paths
    basedir = '/data/storage025/Turntaking/wavs_single_channel_normalized_nosil/'

    # feature_dirs = {'test'  : os.path.join(basedir, 'test-features')}
    feature_dirs = {'PictureNaming'  : os.path.join(basedir, 'PictureNaming-features'),}


    # Define the feature types (subfolder names)
    feature_types = ['energy', 'f0', 'jitter', 'rp']

    # Initialize the dataset
    dataset = MultiFeatureDataset(feature_dirs=feature_dirs, feature_types=feature_types)

    # Check the number of samples
    print(f"\nTotal samples in dataset: {len(dataset)}\n")

    if len(dataset) == 0:
        print("No samples available for DataLoader. Please check your directory structure and file naming conventions.")
    else:
        # Create the DataLoader
        torch_dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

        # Example iteration over the DataLoader
        for batch_idx, (batch_features, batch_subject_ids) in enumerate(torch_dataloader):
            print(f"Batch {batch_idx + 1}:")
            print(f"  Features shape: {batch_features.shape}")
            print(f"  Subject IDs: {batch_subject_ids}")
            # Add your training code here
            break  # Only one iteration, for demonstration
        
    # NOTE: this sits outside the if/else above, so it also runs when the
    # dataset is empty.
    X, subject_ids = dataset.get_numpy_data()
    print(f"\nFeature matrix shape: {X.shape}")
    print(f"Subject IDs (前5个): {subject_ids[:5]}")  

Processing feature set: test
  Feature type 'energy' has 31 files.
  Feature type 'f0' has 31 files.
  Feature type 'jitter' has 31 files.
  Feature type 'rp' has 31 files.

Gathering common samples across all feature sets and feature types...
Total common samples found: 31

Total valid samples after checking all feature sets and feature types: 31

Total samples in dataset: 31

Batch 1:
  Features shape: torch.Size([2, 8])
  Subject IDs: ('2112', '2112')

Feature matrix shape: (31, 8)
Subject IDs (前5个): ['2112', '2112', '2112', '2112', '2112']


In [None]:
# Usage example

if __name__ == "__main__":
    # Define the feature directory paths
    feature_dirs = {
        'BoundaryTone': '/path/to/BoundaryTone-features',
        'EarlyLate': '/path/to/EarlyLate-features',
        'PictureNaming': '/path/to/PictureNaming-features'
    }

    # Define the feature types (subfolder names)
    feature_types = ['energy', 'f0', 'f0-4096', 'jitter', 'rp', 'shimmer']

    # Initialize the dataset
    dataset = MultiFeatureDataset(feature_dirs=feature_dirs, feature_types=feature_types)

    # Check the number of samples
    print(f"\nTotal samples in dataset: {len(dataset)}\n")

    if len(dataset) == 0:
        print("No samples available for DataLoader. Please check your directory structure and file naming conventions.")
    else:
        # Create the DataLoader
        torch_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

        # Example iteration over the DataLoader
        for batch_idx, (batch_features, batch_subject_ids) in enumerate(torch_dataloader):
            print(f"Batch {batch_idx + 1}:")
            print(f"  Features shape: {batch_features.shape}")
            print(f"  Subject IDs: {batch_subject_ids}")
            # Add your training code here
            break  # Only one iteration, for demonstration

        # For scikit-learn: get the data as NumPy arrays
        X, subject_ids = dataset.get_numpy_data()
        print(f"\nFeature matrix shape: {X.shape}")
        print(f"Subject IDs (前5个): {subject_ids[:5]}")  # Print the first 5 subject_ids

        # Example scikit-learn usage
        # from sklearn.model_selection import train_test_split
        # from sklearn.ensemble import RandomForestClassifier
        # Assuming you have labels y
        # y = ...
        # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
        # clf = RandomForestClassifier()
        # clf.fit(X_train, y_train)
        # predictions = clf.predict(X_test)
