In [7]:
import os
import warnings
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple

import h5py
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import transforms

ModuleNotFoundError: No module named 'pytorch_lightning'

In [4]:
class AudioDataset(Dataset):
    """Map-style dataset reading audio signals and two labels from an HDF5 file.

    The file must contain a top-level group ``"file"`` holding one sub-group
    per entry; every entry provides the datasets ``"signal"``, ``"label1"``
    and ``"label2"``.

    The HDF5 handle is opened lazily and once per process. An h5py handle
    created in the parent process must not be shared with ``DataLoader``
    worker processes, and it cannot be pickled, so it is neither created in
    ``__init__`` nor carried along when the dataset is copied to a worker.

    Args:
        hdf5_file_path: Path of the HDF5 file to read.
        max_num_files: Keep only the first ``max_num_files`` entries; ``None``
            keeps all of them.
        audio_transforms: Callable stored on the instance.
            NOTE(review): it is not applied in ``__getitem__`` (this matches
            the previous behaviour) - confirm whether signals should be
            passed through it.
    """

    def __init__(self, hdf5_file_path: str, max_num_files: Optional[int], audio_transforms: Callable):
        super().__init__()
        self.hdf5_file_path = hdf5_file_path
        self.audio_transforms = audio_transforms
        self.max_num_files = max_num_files
        self._file = None  # opened on first access through the `file` property
        self._owner_pid = None  # pid of the process that opened `_file`

        # Only the entry names are needed up front; release the handle right away.
        with h5py.File(hdf5_file_path, "r") as handle:
            self.keys = list(handle["file"].keys())

        if self.max_num_files is not None:
            self.keys = self.keys[: self.max_num_files]

    @property
    def file(self):
        """Read-only HDF5 handle belonging to the current process."""
        pid = os.getpid()
        if self._file is None or self._owner_pid != pid:
            # First access, or this instance was inherited by a forked worker:
            # open a fresh handle instead of reusing the parent's.
            self._file = h5py.File(self.hdf5_file_path, "r")
            self._owner_pid = pid
        return self._file

    def close(self) -> None:
        """Close the handle opened by this process, if there is one."""
        if self._file is not None and self._owner_pid == os.getpid():
            self._file.close()
        self._file = None
        self._owner_pid = None

    def __getstate__(self):
        """Return the instance state without the unpicklable HDF5 handle."""
        state = self.__dict__.copy()
        state["_file"] = None
        state["_owner_pid"] = None
        return state

    @staticmethod
    def _as_text(value):
        """Decode ``bytes`` as UTF-8; return any other value unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def __getitem__(self, index: int):
        """Return ``(audio_signal, label1, label2)`` for the entry at ``index``."""
        entry = self.file["file"][self.keys[index]]
        audio_signal = entry["signal"][()]  # read the whole dataset
        label1 = self._as_text(entry["label1"][()])
        label2 = self._as_text(entry["label2"][()])

        return audio_signal, label1, label2

    def __len__(self):
        """Number of entries kept after applying ``max_num_files``."""
        return len(self.keys)

In [13]:
# Environment diagnostic cell: lists the packages visible to `pip`.
# The line below is a disabled install of `pytorch_lightning`. If it is
# re-enabled, prefer `%pip install ...` so the package is installed into the
# kernel's own environment.
#!pip install pytorch_lightning
!pip list

Package                   Version
------------------------- -----------
absl-py                   0.11.0
aiohttp                   2.3.10
aiosignal                 1.3.1
annotated-types           0.5.0
anyio                     3.7.1
appdirs                   1.4.4
asttokens                 2.2.1
async-timeout             3.0.1
attrs                     23.1.0
backcall                  0.2.0
brotlipy                  0.7.0
build                     0.10.0
CacheControl              0.12.14
cachetools                4.2.0
certifi                   2020.12.5
cffi                      1.15.1
chardet                   4.0.0
charset-normalizer        2.0.4
cleo                      2.0.1
click                     8.1.5
colorama                  0.4.6
contourpy                 1.1.0
crashtest                 0.4.1
cryptography              39.0.1
cycler                    0.11.0
decorator                 5.1.1
distlib                   0.3.6
docker-pycreds            0.4.0
dulwich            

In [2]:
class AudioDataModule(pl.LightningDataModule):
    """Lightning data module serving train/validation loaders from one HDF5 file.

    Both datasets read the same file. Training uses the first
    ``num_train_files`` entries; validation uses up to ``num_val_files``
    entries taken *after* the training slice, so the two sets do not share
    samples. If no entries remain after the training slice (for example when
    ``num_train_files`` is ``None``), a warning is issued and validation falls
    back to the leading entries, which overlap with training.

    Args:
        hdf5_file_path: Path of the HDF5 file passed to ``AudioDataset``.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        audio_transforms: Callable forwarded to both datasets.
        num_workers: Number of worker processes for both loaders.
        num_train_files: Maximum number of training entries (``None`` = all).
        num_val_files: Maximum number of validation entries (``None`` = all
            remaining).
    """

    def __init__(
        self,
        hdf5_file_path: str,
        train_batch_size: int,
        val_batch_size: int,
        audio_transforms: Callable,
        num_workers: int,
        num_train_files: Optional[int] = None,
        num_val_files: Optional[int] = None,
    ):
        super().__init__()
        self.hdf5_file_path = hdf5_file_path
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.audio_transforms = audio_transforms
        self.num_workers = num_workers
        self.num_train_files = num_train_files
        self.num_val_files = num_val_files

        self.train_dataset = AudioDataset(
            hdf5_file_path=self.hdf5_file_path,
            max_num_files=self.num_train_files,
            audio_transforms=self.audio_transforms,
        )
        # Index every entry first, then narrow the key list to the held-out
        # slice so validation samples are not also training samples.
        self.val_dataset = AudioDataset(
            hdf5_file_path=self.hdf5_file_path,
            max_num_files=None,
            audio_transforms=self.audio_transforms,
        )
        self.val_dataset.keys = self._held_out_keys(self.val_dataset.keys)
        self.val_dataset.max_num_files = self.num_val_files

    def _held_out_keys(self, all_keys):
        """Pick validation keys that come after the training slice."""
        if self.num_train_files is None:
            held_out = []  # training consumes every entry
        else:
            held_out = all_keys[self.num_train_files:]

        if not held_out:
            warnings.warn(
                "No entries remain after the training slice; the validation set "
                "overlaps with the training set.",
                stacklevel=3,
            )
            held_out = all_keys

        if self.num_val_files is not None:
            held_out = held_out[: self.num_val_files]
        return held_out

    def train_dataloader(self):
        """Return a shuffled ``DataLoader`` over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation dataset, in stored order."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

NameError: name 'pl' is not defined

In [3]:
class AudioTransforms(object):
    """Callable wrapper around a fixed torchvision transform pipeline.

    The pipeline runs, in order: ``ToTensor``, the affine map ``2 * x - 1.0``,
    and ``Resize`` to the requested resolution.

    Args:
        resolution: Target size handed to ``transforms.Resize``.
    """

    def __init__(self, resolution: Tuple[int, int]):
        def rescale(X):
            # Same affine map as before: x -> 2x - 1.
            return 2 * X - 1.0

        steps = []
        steps.append(transforms.ToTensor())
        steps.append(transforms.Lambda(rescale))
        steps.append(transforms.Resize(resolution))
        self.transforms = transforms.Compose(steps)

    def __call__(self, input, *args, **kwargs):
        """Run ``input`` through the pipeline; extra arguments are ignored."""
        transformed = self.transforms(input)
        return transformed

In [4]:
# Build the data module: a small slice of the HDF5 file, tiny batches,
# and 128x128 transforms.
audio_data_module = AudioDataModule(
    # Data source and preprocessing
    hdf5_file_path='/om2/user//schen77/overlapping_dataset_test.h5',
    audio_transforms=AudioTransforms(resolution=(128, 128)),
    # How many entries each split may use
    num_train_files=50,
    num_val_files=10,
    # Loader settings
    train_batch_size=2,
    val_batch_size=2,
    num_workers=4,
)

NameError: name 'AudioDataModule' is not defined

In [5]:
# Ask the data module for both loaders.
train_loader = audio_data_module.train_dataloader()
val_loader = audio_data_module.val_dataloader()

# Walk the training loader and report what each batch contains.
for batch_idx, batch in enumerate(train_loader):
    # Each batch is an (audio_signal, label1, label2) triple.
    audio_signal, label1, label2 = batch
    print(f"Batch {batch_idx}: Audio Signal Shape: {audio_signal.shape}, Label1: {label1}, Label2: {label2}")

NameError: name 'audio_data_module' is not defined