In [None]:
from pathlib import Path
import torch
import os
import re

In [None]:
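# For every lvl2 file:
#   load it
#   split it along dim 1 into fixed-length chunks
#   (subdivide() below uses 4096; this note originally said "6144 ~ 17sec" -- TODO reconcile)
#   save every chunk under a renumbered part index, e.g. with 4 chunks per old part:
#     old part001 -> new part001..part004
#     old part002 -> new part005..part008
#   new part  = (old part - 1) * n_chunks + chunk_index + 1 | chunk_index = 0 .. n_chunks - 1
#   new total = old total parts * n_chunks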

In [None]:
def subdivide(file, chunk_len=4096, min_len=20668, dry_run=True):
    """Split one ``partNNN-of-MMM`` tensor file into fixed-length chunks with renumbered part names.

    The tensor stored in ``file`` is split along dim 1 into chunks of ``chunk_len``;
    a trailing chunk shorter than ``chunk_len`` is dropped. With ``n`` kept chunks,
    old part ``p`` of ``total`` maps to new parts ``(p - 1) * n + 1 .. p * n`` of
    ``total * n``.

    Args:
        file: Path (str or Path) of the ``.pt`` file; its name must contain
            ``partNNN-of-MMM``.
        chunk_len: Chunk length along dim 1.
        min_len: The tensor's dim-1 length must be strictly greater than this.
        dry_run: If True (default, matching the original exploratory behaviour),
            only print the planned renames. If False, also write every chunk
            next to the source file with ``torch.save``.

    Returns:
        List of new file paths, in chunk order.

    Raises:
        ValueError: if the file name contains no ``partNNN-of-MMM`` tag.
        AssertionError: if the tensor is not longer than ``min_len`` along dim 1.
    """
    file = Path(file)
    # Parse the tag from the file name only, since that is where it is substituted below.
    match = re.search(r"part(\d+)-of-(\d+)", file.name)
    if match is None:
        raise ValueError(f"no 'partNNN-of-MMM' tag in file name: {file.name}")
    current_part = int(match.group(1))
    total_parts = int(match.group(2))

    data = torch.load(file, map_location="cpu")
    seq_len = data.shape[1]
    assert seq_len > min_len, "seq_len is too short to be split"
    # Keep only full-length chunks; the remainder (< chunk_len) is discarded.
    data_splits = [x for x in data.split(chunk_len, dim=1) if x.shape[1] == chunk_len]
    n_splits = len(data_splits)
    # NOTE(review): n_splits depends on this file's seq_len, so parts of one track with
    # different lengths would disagree on new_total_parts -- confirm they all yield the same count.
    new_total_parts = total_parts * n_splits

    new_files = []
    for i, data_split in enumerate(data_splits):
        new_part = (current_part - 1) * n_splits + i + 1
        new_name = re.sub(r"part\d+-of-\d+", f"part{new_part:03d}-of-{new_total_parts:03d}", file.name)
        new_file = file.parent / new_name
        if not dry_run:
            # split() returns views; torch.save of a view serializes the whole underlying
            # storage, so clone to write only this chunk's data.
            torch.save(data_split.clone(), new_file)
        print(f"{file.name} -> {new_file}")
        new_files.append(new_file)
    return new_files
    

In [None]:
# Dry-run the renaming scheme on one file. The dataset root is taken from
# MAESTRO_DATASET_DIR when set (presumably the maestro-v3.0.0 root -- confirm);
# otherwise it falls back to the original author's local path.
_example_root = Path(os.environ.get("MAESTRO_DATASET_DIR", "/storage/user/steiger/dataset/maestro-v3.0.0"))
f = str(_example_root / "2004" / "MIDI-Unprocessed_SMF_22_R1_2004_01-04_ORIG_MID--AUDIO_22_R1_2004_10_Track10_wav.part004-of-014.jukebox.lvl2.pt")
subdivide(f)

In [None]:
# Dataset root; MAESTRO_DATASET_DIR must be set in the kernel's environment (KeyError otherwise).
dataset_dir_env = os.environ["MAESTRO_DATASET_DIR"]
root_dir = Path(dataset_dir_env)
root_dir

In [None]:
# Collect all lvl2 files under the dataset root. Path.glob yields entries in
# filesystem order, so sort to make lvl2_files[0] / lvl2_files[:64] reproducible.
# NOTE(review): this matches "*lvl2.v2.pt", while the subdivide example uses
# "*.jukebox.lvl2.pt" -- confirm which naming is intended here.
lvl2_files = sorted(root_dir.glob("**/*lvl2.v2.pt"))
print(len(lvl2_files))
lvl2_files[:3]

In [None]:
# Peek at the first file to see the shape of the stored tensor.
first_file = lvl2_files[0]
sample = torch.load(first_file, map_location="cpu")
sample.shape

In [None]:
N_BATCH_FILES = 64  # number of files loaded per timed batch
batch_files = lvl2_files[:N_BATCH_FILES]

In [None]:
def load_batch(files):
    """Load every file in ``files`` with ``torch.load`` onto the CPU.

    Args:
        files: Iterable of paths accepted by ``torch.load``.

    Returns:
        List of the loaded objects, in the same order as ``files``.
    """
    cpu = torch.device('cpu')
    loaded = []
    for path in files:
        loaded.append(torch.load(path, map_location=cpu))
    return loaded

In [None]:
%timeit load_batch(batch_files)

In [None]:
import sys
import os
import torch
from einops import rearrange

# Make the parent directory importable so `src.*` resolves. Guard the append so
# re-running this cell does not add the same sys.path entry again.
if '../' not in sys.path:
    sys.path.append('../')
from src.dataset.jukebox_dataset import JukeboxDataset

In [None]:
# Training split of the level-2 data, with caching enabled.
dataset = JukeboxDataset(
    root_dir=root_dir,
    split="train",
    lvl=2,
    sequence_len=4000,
    use_cache=True,
)

In [None]:
# Smoke test: fetch the first item to check that __getitem__ works end to end.
# Fail loudly on an empty dataset instead of silently skipping (as the old loop+break did).
assert len(dataset) > 0, "dataset is empty"
sample = dataset[0]

In [None]:
def load_dataset_batch(dataset, batch_size=64, generator=None):
    """Draw ``batch_size`` items from ``dataset`` at uniformly random indices (with replacement).

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        batch_size: Number of items to draw.
        generator: Optional ``torch.Generator`` for reproducible sampling;
            ``None`` uses torch's global RNG, as before.

    Returns:
        List of ``batch_size`` items.
    """
    indices = torch.randint(0, len(dataset), (batch_size,), generator=generator)
    # Index with plain Python ints rather than 0-dim tensors, so datasets that
    # expect an ``int`` (dict lookups, caching keys, string formatting) behave normally.
    return [dataset[idx] for idx in indices.tolist()]

In [None]:
%timeit load_dataset_batch(dataset, batch_size=64)

In [None]:
from src.datamodule.jukebox_datamodule import JukeboxDataModule

# Build the train dataloader (2 worker processes, caching disabled) and keep an iterator over it.
datamodule = JukeboxDataModule(
    root_dir=root_dir,
    batch_size=64,
    num_workers=2,
    lvl=2,
    sequence_len=4000,
    use_cache=False,
)
datamodule.setup()
dataloader = datamodule.train_dataloader()
data_iter = iter(dataloader)

In [None]:
%timeit data_iter.__next__()