# Imports

In [2]:
import os
import h5py
import hdf5plugin
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [50]:
torch.cuda.is_bf16_supported()

True

In [3]:
# Directory of preprocessed per-sample HDF5 files. Set the HDF5_DATASET_PATH
# environment variable to point the notebook at a different location.
dataset_path = os.environ.get(
    "HDF5_DATASET_PATH",
    "/mnt/disks/audio-ai-research-speech-data/ace_step_datasets/rjw_all_LoRa_data/processed/",
)
# One sample from that directory, used below to inspect the file layout.
file_path = os.path.join(dataset_path, "MERIDIAN_BTBC_MASTER_2024_08_07_00_24b__000.hdf5")
# f = h5py.File(file_path, 'r')  # 'r' for read-only mode
# f.close()

In [38]:
# {
#     k: torch.from_numpy(np.asarray(f[k])) for k in f.keys() if k != "keys"
# }

In [4]:
def print_structure(item, level=0):
    """Recursively print the hierarchy under an h5py object.

    Every printed line is indented by two spaces per nesting ``level``.
    A group prints its name and then each member, in the order ``item.keys()``
    yields them, one level deeper. A dataset prints its name, shape and dtype.
    Anything that is neither an ``h5py.Group`` nor an ``h5py.Dataset`` is
    skipped without output.
    """
    pad = "  " * level
    if isinstance(item, h5py.Group):
        print(f"{pad}Group: {item.name}")
        for child_key in item.keys():
            print_structure(item[child_key], level + 1)
        return
    if isinstance(item, h5py.Dataset):
        print(f"{pad}Dataset: {item.name}, Shape: {item.shape}, Type: {item.dtype}")

# print_structure(f)

In [5]:
# f['lyric_mask']#[:]

In [6]:
def augment_tags(text_token_ids, mask, shuffle, dropout, comma_id=275):
    """Randomly shuffle and/or drop comma-separated tags in a token sequence.

    The sequence is treated as ``tag, tag, ..., tag <last>``:
    - The final token is held aside and always re-appended at the end.
    - Everything before it is split on ``comma_id``; empty segments
      (leading/trailing/consecutive separators) are discarded.
    - The surviving tags are re-joined with single separators, followed by
      the held-aside final token.

    Args:
        text_token_ids: 1-D tensor of token ids. NOTE(review): the last token
            is assumed to be a terminator (e.g. EOS), not part of a tag;
            confirm against the tokenizer used during preprocessing.
        mask: attention mask aligned with ``text_token_ids``; it is truncated
            to the new sequence length. Presumably all ones (no padding). If it
            contained padding zeros, truncation would not realign them.
        shuffle: if truthy, randomly permute the tags (torch's global RNG).
        dropout: probability of dropping each tag independently; a falsy value
            (0 / None) disables dropout. If every tag is dropped, only the final
            token remains.
        comma_id: token id of the tag separator (defaults to 275, the value this
            function originally hard-coded).

    Returns:
        ``(text_token_ids, mask)``. When neither augmentation is enabled the
        inputs are returned unchanged (the very same objects).
    """
    if not shuffle and not dropout:
        return text_token_ids, mask

    # Hold the final token aside so it always stays at the end.
    last_token = text_token_ids[-1:]
    body = text_token_ids[:-1]

    # Split on the separator. Scanning a Python list avoids creating a 0-dim
    # tensor per element, while the slices still come from the tensor.
    tags = []
    start_idx = 0
    for idx, token in enumerate(body.tolist()):
        if token == comma_id:
            if start_idx < idx:
                tags.append(body[start_idx:idx])
            start_idx = idx + 1
    if start_idx < len(body):
        tags.append(body[start_idx:])

    if shuffle:
        perm = torch.randperm(len(tags))
        tags = [tags[i] for i in perm.tolist()]

    if dropout:
        # Keep each tag with probability (1 - dropout).
        tags = [x for x in tags if torch.rand(()) > dropout]

    # Create the separator with the input's dtype AND device so torch.cat also
    # works when the token ids do not live on the CPU.
    comma = torch.tensor([comma_id], dtype=text_token_ids.dtype, device=text_token_ids.device)
    pieces = []
    for x in tags:
        pieces.append(x)
        pieces.append(comma)
    if pieces:
        # Replace the trailing separator with the held-aside final token.
        pieces[-1] = last_token
    else:
        pieces.append(last_token)

    text_token_ids = torch.cat(pieces)
    mask = mask[: len(text_token_ids)]
    return text_token_ids, mask

In [7]:
def pytree_to_dtype(x, dtype):
    """Recursively cast every floating-point tensor in a nested container to ``dtype``.

    Dicts, lists and tuples (including namedtuples) are traversed and rebuilt:
    dicts and lists come back as plain ``dict`` / ``list``, namedtuples keep
    their type, and other tuples come back as plain ``tuple``. Floating-point
    tensors are converted with ``Tensor.to(dtype)``. Integer/bool tensors and
    non-tensor leaves are returned unchanged, so integer token ids keep their dtype.
    """
    if isinstance(x, dict):
        return {k: pytree_to_dtype(v, dtype) for k, v in x.items()}
    if isinstance(x, list):
        return [pytree_to_dtype(y, dtype) for y in x]
    if isinstance(x, tuple):
        items = [pytree_to_dtype(y, dtype) for y in x]
        # namedtuples are constructed from positional fields; plain tuples from an iterable.
        return type(x)(*items) if hasattr(x, "_fields") else tuple(items)
    if isinstance(x, torch.Tensor) and x.dtype.is_floating_point:
        return x.to(dtype)
    return x


In [8]:
class HDF5Dataset(Dataset):
    """Map-style dataset in which every HDF5 file in a directory is one sample.

    Per sample:
    - Every top-level entry except ``"keys"`` is loaded into a tensor.
    - ``text_token_ids`` / ``text_attention_mask`` are passed through
      ``augment_tags`` (tag shuffle / dropout).
    - The mask is cast to float.
    - All floating-point tensors are cast to ``dtype``.

    Args:
        dataset_path: directory containing the per-sample HDF5 files.
        dtype: target dtype for floating-point tensors (e.g. ``torch.bfloat16``).
        tag_shuffle: forwarded to ``augment_tags`` as ``shuffle``.
        tag_dropout: forwarded to ``augment_tags`` as ``dropout``.
        extensions: file-name suffix (str) or suffixes (iterable of str),
            matched case-insensitively, that identify sample files. Other
            directory entries (sub-directories, ``.DS_Store``, notes, ...) are
            ignored instead of failing later in ``__getitem__``.
    """

    def __init__(self, dataset_path, dtype, tag_shuffle, tag_dropout, extensions=(".hdf5", ".h5")):
        self.dataset_path = dataset_path
        self.dtype = dtype
        self.tag_shuffle = tag_shuffle
        self.tag_dropout = tag_dropout
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is deterministic.
        self.filenames = sorted(
            name
            for name in os.listdir(dataset_path)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(dataset_path, name))
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        file_path = os.path.join(self.dataset_path, self.filenames[idx])
        with h5py.File(file_path, "r") as f:
            # torch.tensor(f[k]) is slow; np.asarray materializes the data as a
            # numpy array, which torch.from_numpy wraps without another copy.
            sample = {
                k: torch.from_numpy(np.asarray(f[k])) for k in f.keys() if k != "keys"
            }
        sample["text_token_ids"], sample["text_attention_mask"] = augment_tags(
            sample["text_token_ids"],
            sample["text_attention_mask"],
            self.tag_shuffle,
            self.tag_dropout,
        )
        sample["text_attention_mask"] = sample["text_attention_mask"].float()
        sample = pytree_to_dtype(sample, self.dtype)
        return sample

# Store floating-point tensors in bfloat16 when available, float16 otherwise.
to_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# Per-tag drop probability used by the tag augmentation.
tag_dropout = 0.5

ds = HDF5Dataset(
    dataset_path=dataset_path,
    dtype=to_dtype,
    tag_dropout=tag_dropout,
    tag_shuffle=True,
)

In [14]:
# Inspect a single sample. TEST_INDEX selects both the file name printed
# and the sample loaded, so the two always agree.
TEST_INDEX = 0
print(ds.filenames[TEST_INDEX])
one_item = ds[TEST_INDEX]

AUGER_Armchair_Cartographer_MASTER_2023_05_19_00_24b__000.hdf5


In [None]:
# TODO: look into "prompts? not tags?" -- check whether the text inputs are free-text prompts rather than comma-separated tags.

In [15]:
# Summarize each field of the sample: shape and mean where .mean() works,
# otherwise the raw value (e.g. integer token-id tensors).
for _key, _item in one_item.items():
    print("# ===== ===== ===== ===== ===== #")
    print(f'Key: {_key}')
    try:
        print(_item.shape, _item.mean())
    except (AttributeError, RuntimeError, TypeError):
        # AttributeError/TypeError: value is not a tensor-like with .shape/.mean().
        # RuntimeError: torch refuses .mean() on integer/bool tensors.
        print(_item)
    print("# ===== ===== ===== ===== ===== #")

# ===== ===== ===== ===== ===== #
Key: attention_mask
torch.Size([323]) tensor(1., dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: lyric_mask
torch.Size([4]) tensor(1., dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: lyric_token_ids
tensor([ 261,  259, 6688,    2])
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: mert_ssl_hidden_states
torch.Size([2244, 1024]) tensor(0.0071, dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: mhubert_ssl_hidden_states
torch.Size([1499, 768]) tensor(-0.0031, dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: speaker_embds
torch.Size([512]) tensor(0., dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #
# ===== ===== ===== ===== ===== #
Key: target_latents
torch.Size([8, 16, 323]) tensor(0.2236, dtype=torch.bfloat16)
# ===== ===== ===== ===== ===== #

In [None]:
def train_dataloader(self):
    """Build the shuffled training DataLoader over the HDF5 sample directory.

    Reads ``dataset_path``, ``tag_dropout``, ``batch_size`` and ``num_workers``
    from ``self.hparams``, plus the tensor dtype from ``self.to_dtype``.
    Tag shuffling is on unless ``self.hparams`` provides ``tag_shuffle``.

    NOTE(review): ``augment_tags`` can drop tags, so ``text_token_ids`` may
    differ in length between samples (and other fields may too). The default
    collate stacks tensors, so this presumably only works with
    ``batch_size=1`` unless a padding ``collate_fn`` is added -- confirm.
    """
    ds = HDF5Dataset(
        dataset_path=self.hparams.dataset_path,
        dtype=self.to_dtype,
        # Falls back to True when hparams has no tag_shuffle entry.
        tag_shuffle=getattr(self.hparams, "tag_shuffle", True),
        tag_dropout=self.hparams.tag_dropout,
    )
    return DataLoader(
        ds,
        batch_size=self.hparams.batch_size,
        shuffle=True,
        num_workers=self.hparams.num_workers,
        # pin_memory=True,
        # persistent_workers=True,
    )