In [1]:
from functools import partial
from ast import List
import bz2
from pathlib import Path
import pickle
from transformers import AutoModelForTextEncoding, AutoModel, MusicgenForCausalLM
import warnings
import omegaconf
import torch
import torch.nn.functional as F
import typing as tp
from transformers import AutoProcessor, MusicgenForConditionalGeneration, MusicgenProcessor
from transformers.optimization import AdamW

from transformers.models.musicgen.modeling_musicgen import shift_tokens_right
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from audiocraft.data.info_audio_dataset import AudioInfo
from audiocraft.data.music_dataset import MusicDataset, MusicInfo, Paraphraser, augment_music_info_description
from torch.utils.data import DataLoader
from audiocraft.models.builders import get_conditioner_provider

from audiocraft.modules.conditioners import AttributeDropout, ClassifierFreeGuidanceDropout, ConditioningAttributes, ConditioningProvider, SegmentWithAttributes, WavCondition
from audiocraft.solvers.builders import DatasetType, get_audio_datasets
from audiocraft.solvers.compression import CompressionSolver
# from audiocraft.utils.cache import CachedBatchLoader, CachedBatchWriter
from audiocraft.utils.utils import get_dataset_from_loader
from accelerate import Accelerator, cpu_offload


# Assemble one config object from the YAML fragments below. The fragments are
# loaded and passed to OmegaConf.merge in exactly this order.
# NOTE(review): presumably later fragments take precedence on conflicting keys;
# confirm against the OmegaConf.merge documentation before reordering.
cfg = omegaconf.OmegaConf.merge(*[
    omegaconf.OmegaConf.load(config_file)
    for config_file in (
        "config/solver/musicgen/default.yaml",
        "config/solver/default.yaml",
        "config/model/lm/default.yaml",
        "config/solver/compression/default.yaml",
        "config/solver/compression/encodec_musicgen_32khz.yaml",
        "config/solver/musicgen/musicgen_melody_32khz.yaml",
        "config/model/lm/musicgen_lm.yaml",
        "config/config.yaml",
        "config/conditioner/chroma2music.yaml",
        "config/dset/audio/example.yaml",
    )
])
cfg.dataset.segment_duration = 30

# Batch caching: the path is set and then immediately reset to None, so the
# cache is effectively disabled in this notebook.
cache_path = "./cache_dir"
cache_path = None
use_cached_writer = use_cached_reader = False
cached_batch_writer = cached_batch_loader = None
# Disabled cache wiring, kept for reference:
# if cache_path is not None and use_cached_writer:
#     cached_batch_writer = CachedBatchWriter(Path(cache_path))
# else:
#     cached_batch_loader = CachedBatchLoader(
#         Path(cache_path), cfg.dataset.batch_size, cfg.dataset.num_workers,
#         min_length=cfg.optim.updates_per_epoch or 1)

# Dropout probabilities; the modules that would consume them are commented out.
attribute_dropout_p = {"default": 0.1}
cfg_dropout_p = 0.1
# cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout_p)
# att_dropout = AttributeDropout(p=attribute_dropout_p)

current_stage = "train"
# The same device string is stored both on the config and in a notebook global.
cfg.device = device = "cpu"

# condition_provider = get_conditioner_provider(cfg["transformer_lm"]["dim"], cfg)
# compression_model = CompressionSolver.wrapped_model_from_checkpoint(cfg, cfg.compression_model_checkpoint, device=device)

from typing import List
import numpy as np
from torch.utils.data import Dataset
from collections import defaultdict

class PreparedDataset(Dataset):
    """Dataset over a directory of pre-computed samples stored as NumPy archives.

    Every regular file directly inside ``path`` is treated as one sample and is
    read with ``np.load``; each file must yield a name -> array mapping
    (i.e. an ``.npz`` archive).
    """

    def __init__(self, path) -> None:
        """Index the sample files found directly inside ``path``.

        Args:
            path: Directory containing one archive per sample.
        """
        super().__init__()
        self.path = Path(path)
        # Sorted so that index -> file is reproducible across runs/filesystems;
        # sub-directories (e.g. ``.ipynb_checkpoints``) are not samples.
        self.files = sorted(p for p in self.path.iterdir() if p.is_file())

    def __len__(self) -> int:
        """Return the number of indexed sample files."""
        return len(self.files)

    def read_file(self, filename: tp.Union[str, Path]) -> dict[str, np.ndarray]:
        """Load one sample archive fully into memory.

        Args:
            filename: Path of the archive to read.

        Returns:
            A plain dict mapping each array name in the archive to its array.

        Raises:
            RuntimeError: If the file cannot be read as an archive. The
                original error is chained and the offending path is part of
                the message.
        """
        try:
            # The context manager closes the underlying file handle once all
            # arrays have been materialised.
            with np.load(filename) as archive:
                return {name: archive[name] for name in archive.files}
        except Exception as err:
            # Fail here, with the file name, rather than returning None and
            # crashing later with an unrelated error.
            raise RuntimeError(f"Failed to load sample file: {filename}") from err

    def __getitem__(self, index) -> dict[str, np.ndarray]:
        """Return the sample stored in the ``index``-th file as a fresh dict."""
        return dict(self.read_file(self.files[index]))

    @staticmethod
    def pad_sequences(samples: List[np.ndarray], pad_value: int) -> List[np.ndarray]:
        """Pad arrays along axis 0 to the length of the longest one.

        Args:
            samples: Arrays that may differ in their first dimension.
            pad_value: Value written into the padded tail.

        Returns:
            A new list of new arrays, each with first dimension equal to the
            maximum over ``samples`` and the dtype of its source array. The
            input list and its arrays are left untouched. An empty input
            yields an empty list.
        """
        if not samples:
            return []
        max_len = max(sample.shape[0] for sample in samples)
        padded = []
        for sample in samples:
            res = np.full((max_len, *sample.shape[1:]), pad_value, dtype=sample.dtype)
            res[:sample.shape[0]] = sample
            padded.append(res)
        return padded

    @staticmethod
    def collate_fn(samples: List[dict[str, np.ndarray]], use_cls: bool = False) -> dict[str, torch.Tensor]:
        """Collate per-sample dicts into one dict of batched tensors.

        ``attention_mask`` and ``encoder_hidden_states`` are zero-padded along
        their first axis to a common length. Every other entry is stacked as
        is, so it must already have the same shape in every sample.

        Args:
            samples: One dict per sample, as returned by ``__getitem__``.
            use_cls: Also pad ``attention_mask_cls`` and
                ``encoder_hidden_states_cls``.

        Returns:
            Dict with the keys of the samples (in first-seen order), each
            mapped to a tensor with a leading batch dimension. The input
            dicts are not modified.
        """
        padded_keys = ["attention_mask", "encoder_hidden_states"]
        if use_cls:
            padded_keys += ["attention_mask_cls", "encoder_hidden_states_cls"]
        padded = {
            key: PreparedDataset.pad_sequences([s[key] for s in samples], 0)
            for key in padded_keys
        }

        res = defaultdict(list)
        for idx, sample in enumerate(samples):
            for k, v in sample.items():
                res[k].append(padded[k][idx] if k in padded else v)
        return {k: torch.from_numpy(np.array(v)) for k, v in res.items()}

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Load the pretrained "facebook/musicgen-small" checkpoint and keep only its
# `.decoder` attribute; the rest of the loaded model is not bound to a name.
decoder = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").decoder
# Decoder configuration; `config.vocab_size` is used below to flatten `labels_cls`.
config = decoder.config

In [7]:
# Loader settings for the pre-generated training samples.
use_cls_loss = True
batch_size = 1
dataset = PreparedDataset("generated_train_dataset")
# `use_cls` is bound up front so the collate step also pads the `*_cls` entries.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=partial(PreparedDataset.collate_fn, use_cls=use_cls_loss),
)

In [71]:
# Pull a single batch for interactive inspection. Unlike a `for ...: break`
# loop, this raises StopIteration if the loader yields nothing, instead of
# leaving `data` undefined (or stale from an earlier run of this cell).
data = next(iter(dataloader))

In [72]:
# Show the model inputs next to their targets. In the output below, each row of
# `labels` is the matching `input_ids` row shifted left by one position, and
# every `input_ids` row starts with 2048 (presumably a start/pad token - confirm).
data["input_ids"], data["labels"]

(tensor([[[2048, 1817, 1883,  ..., 1087,  277,  327],
          [2048, 1200, 1737,  ..., 1540,    9, 1556],
          [2048, 1525, 1056,  ..., 1390, 1453, 2006],
          [2048, 1286,  944,  ...,  225, 1036,  606]]], dtype=torch.int32),
 tensor([[[1817, 1883, 1366,  ...,  277,  327,  236],
          [1200, 1737,   73,  ...,    9, 1556, 1148],
          [1525, 1056, 1146,  ..., 1453, 2006, 1674],
          [1286,  944, 1706,  ..., 1036,  606,  939]]], dtype=torch.int32))

In [78]:
# Compare shapes: per the output below, `labels_cls` has the same leading dims as
# `labels` plus one trailing dimension of size 2048.
data["labels"].shape, data["labels_cls"].shape

(torch.Size([1, 4, 1500]), torch.Size([1, 4, 1500, 2048]))

In [82]:
# Check the flattened shapes for index 0 of dim 1 (presumably the codebook axis -
# confirm): `labels` becomes 1-D and `labels_cls` becomes (positions, last-dim size).
data["labels"][:, 0, ...].view(-1).shape, data["labels_cls"][:, 0, ...].view(-1, data["labels_cls"].size(-1)).shape

(torch.Size([1500]), torch.Size([1500, 2048]))

In [91]:
# Cross-entropy of `labels_cls` (used as class scores) against `labels`,
# computed separately for each index of dim 1 and then averaged.
# The slice count is read from the tensor instead of being hard-coded
# (it is 4 for this batch; presumably the codebook axis - confirm).
num_codebooks = data["labels"].size(1)
ce = 0
for i in range(num_codebooks):
    # `[:, i, ...]` is a non-contiguous slice once batch size > 1, where
    # `.view` raises; `.reshape` copies only when it has to.
    ce += F.cross_entropy(
        data["labels_cls"][:, i, ...].reshape(-1, data["labels_cls"].size(-1)),
        data["labels"][:, i, ...].reshape(-1).long(),
    )
ce /= num_codebooks
print(ce)

tensor(5.8272)


In [97]:
# Same loss in a single call over all positions at once. Every dim-1 slice holds
# the same number of positions, so the mean of per-slice means equals the overall
# mean; the output below matches the per-slice result (5.8272).
F.cross_entropy(data["labels_cls"].view(-1, config.vocab_size), data["labels"].view(-1).long())

tensor(5.8272)

In [73]:
# Highest-scoring class index per flattened position. Softmax is monotonic, so it
# does not change which index `argmax` selects.
data["labels_cls"].view(-1, config.vocab_size).softmax(1).argmax(1)

tensor([ 166, 1799,  457,  ..., 1773,  987, 1280])

In [74]:
# Inspect the padding mask of the sampled batch (the displayed entries below are all 1).
data["padding_mask"]

tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32)

In [75]:
# List every field the collated batch provides.
data.keys()

dict_keys(['input_ids', 'attention_mask', 'attention_mask_cls', 'encoder_hidden_states', 'encoder_hidden_states_cls', 'labels', 'labels_cls', 'padding_mask'])