In [5]:
!pip install torch transformers pytorch_lightning
!cd content
!wget http://www.yeeking.net/llama-midi/MIDI.py
!wget http://www.yeeking.net/llama-midi/midi_model.py
!wget http://www.yeeking.net/llama-midi/midi_synthesizer.py
!wget http://www.yeeking.net/llama-midi/midi_tokenizer.py

/bin/bash: line 1: cd: content: No such file or directory
--2024-09-16 17:52:03--  http://www.yeeking.net/llama-midi/MIDI.py
Resolving www.yeeking.net (www.yeeking.net)... 80.87.143.6
Connecting to www.yeeking.net (www.yeeking.net)|80.87.143.6|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 71151 (69K) [text/x-python]
Saving to: ‘MIDI.py’


2024-09-16 17:52:04 (207 KB/s) - ‘MIDI.py’ saved [71151/71151]



In [None]:
import argparse
import os
import random

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from torch import optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader

import MIDI
from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizer
from types import SimpleNamespace


EXTENSION = [".mid", ".midi"]
tokenizer = MIDITokenizer()


def file_ext(fname):
    return os.path.splitext(fname)[1].lower()


class MidiDataset(Dataset):
    def __init__(self, midi_list, tokenizer: MIDITokenizer, max_len=2048, min_file_size=3000, max_file_size=384000,
                 aug=True, check_alignment=False):
        self.tokenizer = tokenizer
        self.midi_list = midi_list
        self.max_len = max_len
        self.min_file_size = min_file_size
        self.max_file_size = max_file_size
        self.aug = aug
        self.check_alignment = check_alignment

    def __len__(self):
        return len(self.midi_list)

    def load_midi(self, index):
        path = self.midi_list[index]
        try:
            with open(path, 'rb') as f:
                datas = f.read()
            if len(datas) > self.max_file_size:  # large midi file will spend too much time to load
                raise ValueError("file too large")
            elif len(datas) < self.min_file_size:
                raise ValueError("file too small")
            mid = MIDI.midi2score(datas)
            if max([0] + [len(track) for track in mid[1:]]) == 0:
                raise ValueError("empty track")
            mid = self.tokenizer.tokenize(mid)
            if self.check_alignment and not self.tokenizer.check_alignment(mid):
                raise ValueError("not aligned")
            if self.aug:
                mid = self.tokenizer.augment(mid)
        except Exception:
            mid = self.load_midi(random.randint(0, self.__len__() - 1))
        return mid

    def __getitem__(self, index):
        mid = self.load_midi(index)
        mid = np.asarray(mid, dtype=np.int16)
        # if mid.shape[0] < self.max_len:
        #     mid = np.pad(mid, ((0, self.max_len - mid.shape[0]), (0, 0)),
        #                  mode="constant", constant_values=self.tokenizer.pad_id)
        start_idx = random.randrange(0, max(1, mid.shape[0] - self.max_len))
        start_idx = random.choice([0, start_idx])
        mid = mid[start_idx: start_idx + self.max_len]
        mid = mid.astype(np.int64)
        mid = torch.from_numpy(mid)
        return mid


def collate_fn(batch):
    max_len = max([len(mid) for mid in batch])
    batch = [F.pad(mid, (0, 0, 0, max_len - mid.shape[0]), mode="constant", value=tokenizer.pad_id) for mid in batch]
    batch = torch.stack(batch)
    return batch


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


class TrainMIDIModel(MIDIModel):
    def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
                 lr=2e-4, weight_decay=0.01, warmup=1e3, max_step=1e6):
        super(TrainMIDIModel, self).__init__(tokenizer=tokenizer, n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                                             n_inner=n_inner, flash=flash)
        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup = warmup
        self.max_step = max_step

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'norm']  # no decay for bias and Norm
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': self.weight_decay},
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        optimizer = optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
            betas=(0.9, 0.99),
            eps=1e-08,
        )
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=self.warmup,
            num_training_steps=self.max_step,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def training_step(self, batch, batch_idx):
        x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
        y = batch[:, 1:].contiguous()
        hidden = self.forward(x)
        rand_idx = [-1] + random.sample(list(range(y.shape[1] - 2)), min(127, (y.shape[1] - 2) // 2))
        hidden = hidden[:, rand_idx]
        hidden = hidden.reshape(-1, hidden.shape[-1])
        y = y[:, rand_idx]
        y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
        x = y[:, :-1]
        logits = self.forward_token(hidden, x)
        loss = F.cross_entropy(
            logits.view(-1, self.tokenizer.vocab_size),
            y.view(-1),
            reduction="mean",
            ignore_index=self.tokenizer.pad_id
        )
        self.log("train/loss", loss)
        self.log("train/lr", self.lr_schedulers().get_last_lr()[0])
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[:, :-1].contiguous()  # (batch_size, midi_sequence_length, token_sequence_length)
        y = batch[:, 1:].contiguous()
        hidden = self.forward(x)
        hidden = hidden.reshape(-1, hidden.shape[-1])
        y = y.reshape(-1, y.shape[-1])  # (batch_size*midi_sequence_length, token_sequence_length)
        x = y[:, :-1]
        logits = self.forward_token(hidden, x)
        loss = F.cross_entropy(
            logits.view(-1, self.tokenizer.vocab_size),
            y.view(-1),
            reduction="mean",
            ignore_index=self.tokenizer.pad_id
        )
        self.log("val/loss", loss, sync_dist=True)
        return loss

    def on_validation_start(self):
        torch.cuda.empty_cache()

    def on_validation_end(self):
        @rank_zero_only
        def gen_example():
            mid = self.generate()
            mid = self.tokenizer.detokenize(mid)
            img = self.tokenizer.midi2img(mid)
            img.save(f"sample/{self.global_step}_0.png")
            with open(f"sample/{self.global_step}_0.mid", 'wb') as f:
                f.write(MIDI.score2midi(mid))
            prompt = val_dataset.load_midi(random.randint(0, len(val_dataset) - 1))
            prompt = np.asarray(prompt, dtype=np.int16)
            ori = prompt[:512]
            prompt = prompt[:256].astype(np.int64)
            mid = self.generate(prompt)
            mid = self.tokenizer.detokenize(mid)
            img = self.tokenizer.midi2img(mid)
            img.save(f"sample/{self.global_step}_1.png")
            img = self.tokenizer.midi2img(self.tokenizer.detokenize(ori))
            img.save(f"sample/{self.global_step}_1_ori.png")
            with open(f"sample/{self.global_step}_1.mid", 'wb') as f:
                f.write(MIDI.score2midi(mid))

        try:
            gen_example()
        except Exception as e:
            print(e)
        torch.cuda.empty_cache()


def get_midi_list(path):
    all_files = {
        os.path.join(root, fname)
        for root, _dirs, files in os.walk(path)
        for fname in files
    }
    all_midis = sorted(
        fname for fname in all_files if file_ext(fname) in EXTENSION
    )
    print("Found midi : ", len(all_midis), "in path", path)
    return all_midis


# if __name__ == '__main__':
def goforit():


    opt = SimpleNamespace(
        resume='',
        ckpt='',
        data='/content/drive/MyDrive/BCN-july-24/midi/',
        data_val_split=128,
        max_len=4096,
        seed=0,
        lr=2e-05,
        weight_decay=0.01,
        warmup_step=1000.0,
        max_step=1000000.0,
        grad_clip=1.0,
        batch_size_train=6,
        batch_size_val=2,
        workers_train=2,
        workers_val=2,
        acc_grad=2,
        accelerator='gpu',
        devices=-1,
        fp32=False,
        disable_benchmark=False,
        log_step=1,
        val_step=42 # needs to be less than the number of batches
    )


    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")
    if not os.path.exists("sample"):
        os.mkdir("sample")

    pl.seed_everything(opt.seed)
    print("---load dataset---")
    tokenizer = MIDITokenizer()
    midi_list = get_midi_list(opt.data)
    print("Loaded midis ", len(midi_list))
    random.shuffle(midi_list)
    full_dataset_len = len(midi_list)
    train_dataset_len = full_dataset_len - opt.data_val_split
    train_midi_list = midi_list[:train_dataset_len]
    val_midi_list = midi_list[train_dataset_len:]
    train_dataset = MidiDataset(train_midi_list, tokenizer, max_len=opt.max_len)
    val_dataset = MidiDataset(val_midi_list, tokenizer, max_len=opt.max_len, aug=False)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=opt.batch_size_train,
        shuffle=True,
        persistent_workers=True,
        num_workers=opt.workers_train,
        pin_memory=True,
        collate_fn=collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=opt.batch_size_val,
        shuffle=False,
        persistent_workers=True,
        num_workers=opt.workers_val,
        pin_memory=True,
        collate_fn=collate_fn
    )
    print(f"train: {len(train_dataset)}  val: {len(val_dataset)}")
    # model = TrainMIDIModel(tokenizer, flash=True, lr=opt.lr, weight_decay=opt.weight_decay,
    #                        warmup=opt.warmup_step, max_step=opt.max_step)
    # https://github.com/SkyTNT/midi-model/issues/14
    # something broke with better transformers
    model = TrainMIDIModel(tokenizer, flash=False, lr=opt.lr, weight_decay=opt.weight_decay,
                           warmup=opt.warmup_step, max_step=opt.max_step)
    if opt.ckpt:
        ckpt = torch.load(opt.ckpt, map_location="cpu")
        state_dict = ckpt.get("state_dict", ckpt)
        model.load_state_dict(state_dict, strict=False)
    print("---setup trainer---")
    checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        mode="min",
        save_top_k=1,
        save_last=True,
        auto_insert_metric_name=False,
        filename="epoch={epoch},loss={val/loss:.4f}",
    )
    callbacks = [checkpoint_callback]

    trainer = Trainer(
        precision=32 if opt.fp32 else 16,
        accumulate_grad_batches=opt.acc_grad,
        gradient_clip_val=opt.grad_clip,
        accelerator="gpu",
        devices=opt.devices,
        max_steps=opt.max_step,
        benchmark=not opt.disable_benchmark,
        val_check_interval=opt.val_step,
        log_every_n_steps=1,
        strategy="ddp_notebook",
        callbacks=callbacks,
    )
    ckpt_path = opt.resume
    if ckpt_path == "":
        ckpt_path = None
    print("---start train---")
    trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

goforit()


INFO:lightning_fabric.utilities.seed:Seed set to 0


---load dataset---
Found midi :  2732 in path /content/drive/MyDrive/BCN-july-24/midi/
Loaded midis  2732
train: 2604  val: 128


/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


---setup trainer---
---start train---


INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning_fabric.utilities.distributed:Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type       | Pa

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  with bar, torch.cuda.amp.autocast(enabled=amp):
generating:   2%|▏         | 8/511 [00:00<00:51,  9.83it/s]


name 'val_dataset' is not defined


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 5/511 [00:00<00:16, 31.01it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:18, 27.45it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:20, 25.18it/s][A
generating:   1%|          | 6/511 [00:00<00:20, 24.57it/s][A
generating:   2%|▏         | 9/511 [00:00<00:21, 23.66it/s][A
generating:   3%|▎         | 13/511 [00:00<00:19, 25.08it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 1/511 [00:00<00:09, 52.71it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 1/511 [00:00<00:09, 55.05it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.51it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.50it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.40it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.17it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.12it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.15it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 22.20it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.20it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.23it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.13it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.07it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.04it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.96it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.79it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 1/511 [00:00<00:09, 51.08it/s]


name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.67it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.43it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.38it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.45it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.49it/s][A
generating:   4%|▎         | 18/511 [00:00<00:21, 22.54it/s][A
generating:   4%|▍         | 21/511 [00:00<00:21, 22.49it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.41it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.49it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.49it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.53it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.43it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.45it/s][A
generating:   8%|▊         | 42/511 [00:01<00:20, 22.40it/s][A
generating:   9%|▉         | 45/511 [00:02<00:20, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:23, 21.77it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.15it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.27it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.30it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.23it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.00it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 22.04it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 22.13it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.29it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.28it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.33it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.24it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.21it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 22.08it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.42it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.44it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.37it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.26it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.34it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.41it/s][A
generating:   4%|▍         | 21/511 [00:00<00:21, 22.34it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.43it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.30it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.26it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.18it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.21it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.19it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 22.07it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:24, 20.82it/s][A
generating:   1%|          | 6/511 [00:00<00:23, 21.26it/s][A
generating:   2%|▏         | 9/511 [00:00<00:23, 21.33it/s][A
generating:   2%|▏         | 12/511 [00:00<00:23, 21.39it/s][A
generating:   3%|▎         | 15/511 [00:00<00:23, 21.39it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.48it/s][A
generating:   4%|▍         | 21/511 [00:00<00:23, 21.10it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.46it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.65it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.38it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.46it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.63it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.69it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.65it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.61it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.63it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.56it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.48it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.45it/s][A
generating:   4%|▎         | 18/511 [00:00<00:21, 22.47it/s][A
generating:   4%|▍         | 21/511 [00:00<00:21, 22.38it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.22it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.24it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.25it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.39it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.23it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.40it/s][A
generating:   8%|▊         | 42/511 [00:01<00:20, 22.49it/s][A
generating:   9%|▉         | 45/511 [00:02<00:20, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:24, 20.86it/s][A
generating:   1%|          | 6/511 [00:00<00:24, 20.80it/s][A
generating:   2%|▏         | 9/511 [00:00<00:23, 21.41it/s][A
generating:   2%|▏         | 12/511 [00:00<00:23, 21.59it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.75it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.53it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.52it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.58it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.60it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.76it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.38it/s][A
generating:   7%|▋         | 36/511 [00:01<00:22, 21.16it/s][A
generating:   8%|▊         | 39/511 [00:01<00:22, 21.36it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.57it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:23, 21.50it/s][A
generating:   1%|          | 6/511 [00:00<00:23, 21.46it/s][A
generating:   2%|▏         | 9/511 [00:00<00:23, 21.41it/s][A
generating:   2%|▏         | 12/511 [00:00<00:23, 21.47it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.73it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.83it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.79it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.89it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.33it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.25it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.60it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.85it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.11it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 22.19it/s][A
generating:   9%|▉         | 45/511 [00:02<00:20, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.40it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.33it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 21.98it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.11it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.28it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.29it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 22.06it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 22.01it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.10it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.12it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.06it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.06it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.57it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.46it/s][A
generating:   9%|▉         | 45/511 [00:02<00:22, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:20, 24.71it/s][A
generating:   1%|          | 6/511 [00:00<00:21, 22.98it/s][A
generating:   2%|▏         | 9/511 [00:00<00:21, 22.84it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.17it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.15it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.32it/s][A
generating:   4%|▍         | 21/511 [00:00<00:21, 22.36it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.38it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.22it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.20it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.13it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.16it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 22.20it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 22.17it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.25it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.30it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.49it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.26it/s][A
generating:   3%|▎         | 15/511 [00:00<00:23, 21.02it/s][A
generating:   4%|▎         | 18/511 [00:00<00:23, 21.09it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.43it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.56it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.71it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.77it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.24it/s][A
generating:   7%|▋         | 36/511 [00:01<00:22, 21.52it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.64it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.73it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:23, 22.05it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.19it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.15it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.17it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.20it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.19it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 22.19it/s][A
generating:   5%|▍         | 24/511 [00:01<00:21, 22.14it/s][A
generating:   5%|▌         | 27/511 [00:01<00:21, 22.18it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.19it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.09it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.09it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.94it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.94it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.43it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.30it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.27it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.06it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.21it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.57it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.55it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.56it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.66it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.69it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 21.96it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.04it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.82it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.72it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:23, 21.56it/s][A
generating:   1%|          | 6/511 [00:00<00:23, 21.95it/s][A
generating:   2%|▏         | 9/511 [00:00<00:23, 21.83it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 21.82it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.59it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.55it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.70it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.50it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.26it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.07it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.35it/s][A
generating:   7%|▋         | 36/511 [00:01<00:22, 21.49it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.55it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.82it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:19, 26.24it/s][A
generating:   1%|          | 6/511 [00:00<00:21, 23.22it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.25it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.10it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.88it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.72it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.56it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.51it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.52it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.45it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.47it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.70it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.87it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.82it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:20, 24.25it/s][A
generating:   1%|          | 6/511 [00:00<00:21, 23.02it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.49it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.33it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.01it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.11it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 22.05it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 22.02it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.99it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 22.11it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 22.21it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 22.28it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.96it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.82it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:21, 24.04it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.67it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.09it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.09it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 22.10it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 22.04it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.32it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.48it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.55it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.63it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.69it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.80it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.77it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.82it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:21, 23.82it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.65it/s][A
generating:   2%|▏         | 9/511 [00:00<00:22, 22.10it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 21.88it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.80it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.78it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.78it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.76it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.92it/s][A
generating:   6%|▌         | 30/511 [00:01<00:21, 21.89it/s][A
generating:   6%|▋         | 33/511 [00:01<00:21, 21.85it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.82it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.78it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 21.82it/s][A
generating:   9%|▉         | 45/511 [00:02<00:21, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:19, 25.56it/s][A
generating:   1%|          | 6/511 [00:00<00:21, 23.31it/s][A
generating:   2%|▏         | 9/511 [00:00<00:21, 22.90it/s][A
generating:   2%|▏         | 12/511 [00:00<00:22, 22.66it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.97it/s][A
generating:   4%|▎         | 18/511 [00:00<00:22, 21.59it/s][A
generating:   4%|▍         | 21/511 [00:00<00:22, 21.70it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.88it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.36it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.21it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.25it/s][A
generating:   7%|▋         | 36/511 [00:01<00:22, 21.08it/s][A
generating:   8%|▊         | 39/511 [00:01<00:22, 20.86it/s][A
generating:   8%|▊         | 42/511 [00:01<00:22, 20.67it/s][A
generating:   9%|▉         | 45/511 [00:02<00:22, 2

name 'val_dataset' is not defined


Validation: |          | 0/? [00:00<?, ?it/s]


generating:   0%|          | 0/511 [00:00<?, ?it/s][A
generating:   1%|          | 3/511 [00:00<00:22, 22.45it/s][A
generating:   1%|          | 6/511 [00:00<00:22, 22.02it/s][A
generating:   2%|▏         | 9/511 [00:00<00:23, 20.95it/s][A
generating:   2%|▏         | 12/511 [00:00<00:23, 21.22it/s][A
generating:   3%|▎         | 15/511 [00:00<00:22, 21.59it/s][A
generating:   4%|▎         | 18/511 [00:00<00:23, 20.84it/s][A
generating:   4%|▍         | 21/511 [00:00<00:23, 21.11it/s][A
generating:   5%|▍         | 24/511 [00:01<00:22, 21.21it/s][A
generating:   5%|▌         | 27/511 [00:01<00:22, 21.40it/s][A
generating:   6%|▌         | 30/511 [00:01<00:22, 21.63it/s][A
generating:   6%|▋         | 33/511 [00:01<00:22, 21.67it/s][A
generating:   7%|▋         | 36/511 [00:01<00:21, 21.82it/s][A
generating:   8%|▊         | 39/511 [00:01<00:21, 21.83it/s][A
generating:   8%|▊         | 42/511 [00:01<00:21, 22.07it/s][A
generating:   9%|▉         | 45/511 [00:02<00:20, 2

name 'val_dataset' is not defined
