In [None]:
import datetime
import sys
import os
import torch
import numpy as np
import random
from pathlib import Path
from collections import namedtuple
from miditok import REMI, MMM
from torch.nn import functional as F
from importlib import reload

# Work from the project root so the relative './src/...' and './out/...'
# paths used throughout this notebook resolve.
os.chdir('/home/nico/dev/projects/ai/musai')

# Make the project's own modules importable by bare name.
sys.path.append('./src/tools')
sys.path.append('./src/model')

# Local module; presumably found under one of the paths appended above — verify.
import tokenizer

# Re-import so edits to the module are picked up without restarting the kernel.
reload(tokenizer)

from tokenizer import get_tokenizer, parse_bpe_tokens, TOKEN_PARAMS_NAME

# Name of the sub-directory of ./out that receives this notebook's output files.
PROJ_NAME = 'all'
# When True, the '/bpe' sub-directory of the token directory is used instead.
IS_BPE = False
TOKENS_PATH = f"/media/nico/nvme/data/tokens/tmp{'/bpe' if IS_BPE else ''}"
# Every *.json file directly inside TOKENS_PATH.
TOKENS_FILE_PATHS = list(Path(TOKENS_PATH).glob('*.json'))
# First CUDA device when available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Context length; also used for the model args and the length of the sampling schedules.
CTX_LEN = 1024
PRECISION = 'bf16'

# Exported as environment variables; presumably read by the RWKV model code
# when it is imported — TODO confirm against the runner module.
os.environ['RWKV_T_MAX'] = str(CTX_LEN)
os.environ['RWKV_FLOAT_MODE'] = PRECISION

# Last expression of the cell: shows the working directory as the cell output.
os.getcwd()

In [None]:
# Display the torch device selected in the configuration cell.
DEVICE

In [None]:
# Build the tokenizer from the parameters file located inside TOKENS_PATH.
TOKENIZER = get_tokenizer(params=f'{TOKENS_PATH}/{TOKEN_PARAMS_NAME}')

# Make sure the per-project output directory exists (no error if it already does).
Path(f'./out/{PROJ_NAME}').mkdir(parents=True, exist_ok=True)

# Size of the tokenizer's `vocab` attribute.
ORIG_VOCAB_SIZE = len(TOKENIZER.vocab)
# 1.25x the base vocabulary size, truncated to an int.
# NOTE(review): not referenced anywhere else in the visible notebook.
BPE_VOCAB_SIZE = int(ORIG_VOCAB_SIZE * 1.25)

# Cell output: both sizes next to len(TOKENIZER), which is what the model's
# vocab_size is later set to.
(ORIG_VOCAB_SIZE, BPE_VOCAB_SIZE, len(TOKENIZER))

In [None]:
N_EMBED = 768
N_LAYER = 12

# Model hyperparameters, keyed by the names the model-loading cell reads back.
params = dict(ctx_len=CTX_LEN, n_embd=N_EMBED, n_layer=N_LAYER)

# Read-only, attribute-style view of the same values (one field per dict key,
# in the dict's insertion order).
params_obj = namedtuple('RWKVParams', list(params))(**params)

In [None]:
import matplotlib.pyplot as plt

# Number of schedule points to generate; the generation loop runs for the
# same number of iterations and indexes the schedules by step.
MAX_ITER = CTX_LEN

# Parameters of the secondary ("noise") sine that gen_sin_wave adds on top of
# the main wave. Despite the name, no random numbers are involved: the noise
# term is itself a sine, so the resulting schedules are fully deterministic.
NOISE_LEVEL = 0.55
NOISE_FREQ = 10
PHASE = 0


def gen_sin_wave(total_iterations, min_value, max_value, noise_scale, noise_frequency, main_phase):
    """
    Build a clipped sine-wave schedule with a second sine superimposed on it.

    One period of a sine (shifted by ``main_phase``) is sampled at
    ``total_iterations`` evenly spaced points and rescaled so that it spans
    ``min_value``..``max_value``. A second, deterministic sine of amplitude
    ``noise_scale`` and ``noise_frequency`` periods (shifted by
    ``-main_phase / 2``) is then added, and the sum is clipped back into
    ``min_value``..``max_value``.

    Args:
        total_iterations (int): Number of points to generate.
        min_value (float): Lower clipping bound and bottom of the main wave.
        max_value (float): Upper clipping bound and top of the main wave.
        noise_scale (float): Amplitude of the superimposed sine, in the same
            absolute units as the output (it is not scaled to the value range).
        noise_frequency (float): Number of periods of the superimposed sine
            over the whole schedule.
        main_phase (float): Phase offset of the main wave, in radians.

    Returns:
        list: ``total_iterations`` floats, clipped to ``min_value``..``max_value``.

    """
    two_pi = 2 * np.pi
    t = np.linspace(0, 1, total_iterations)

    carrier = np.sin(two_pi * t + main_phase)
    ripple = noise_scale * np.sin(two_pi * noise_frequency * t - main_phase / 2)

    span = max_value - min_value
    unclipped = min_value + span * (1 + carrier) / 2 + ripple

    return np.clip(unclipped, min_value, max_value).tolist()


# Per-step sampling schedules consumed by the generation loop: one value per
# iteration, temperature in 0.05..0.25 and top_p in 0.65..0.999.
temp_values = gen_sin_wave(MAX_ITER, 0.05, 0.25, NOISE_LEVEL, NOISE_FREQ, PHASE)
top_p_values = gen_sin_wave(MAX_ITER, 0.65, 0.999, NOISE_LEVEL, NOISE_FREQ*2, PHASE+6)

# Plot both schedules on labelled axes so the figure is readable on its own.
fig, ax = plt.subplots()
ax.plot(temp_values, label='temperature')
ax.plot(top_p_values, label='top_p')
ax.set(title='Sampling schedules', xlabel='generation step', ylabel='value')
# Trailing semicolon keeps the Legend repr out of the cell output.
ax.legend();

In [None]:
import runner

# Re-import so edits to the runner module are picked up without restarting the kernel.
reload(runner)

from runner import RWKV
import types

# Draw a fresh seed for this run, report it so the run can be reproduced, and
# apply it to every RNG the notebook may touch (Python, NumPy and torch).
SEED = random.randint(1000, 10000)
print(f'SEED = {SEED}')
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Plain attribute bag with the settings read by the RWKV constructor.
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda"
args.FLOAT_MODE = PRECISION
args.map_location = 'cpu'
args.base_model = '/home/nico/dev/projects/ai/musai/dist/main_1.pth'
args.n_layer = params['n_layer']
args.n_embd = params['n_embd']
args.ctx_len = int(params['ctx_len'])
args.vocab_size = len(TOKENIZER)
args.head_size_a = 64

model_rnn = RWKV(args)
# Cast to bfloat16 (PRECISION is 'bf16') and move to the GPU before loading weights.
model_rnn.to(torch.bfloat16).cuda()

# NOTE: torch.load unpickles the checkpoint; only load files from a trusted source.
# strict=False tolerates key mismatches; the returned missing/unexpected key
# lists are shown as the cell output, so check them after running.
model_rnn.load_state_dict(torch.load(args.base_model, map_location=args.map_location), strict=False)

In [None]:
import json
from miditoolkit import MidiFile

init_state = None
out_tokens = []
tokens_file_paths = list(Path(TOKENS_PATH).glob('*.json'))

# Fail with an explicit message instead of an IndexError further down.
if not tokens_file_paths:
    raise FileNotFoundError(f'No token files (*.json) found in {TOKENS_PATH}')

# Shuffle so that the first entry is a randomly chosen token file.
random.shuffle(tokens_file_paths)

# SAMPLE_TOKENS_FILE = f'/home/nico/data/ai/models/midi/{PROJ_NAME}/8f33606fa1a6040e5ba230ea7bff8546_mid.json'

# Read the token ids of the chosen file; the context manager closes the
# handle as soon as the JSON has been parsed.
with open(tokens_file_paths[0]) as tokens_file:
    token_ids = json.load(tokens_file)['ids']
# token_ids = json.load(open(SAMPLE_TOKENS_FILE))['ids']
# sample = TOKENIZER.midi_to_tokens(MidiFile('examples/2023-06-11T12-59-05-739764.mid'))
# token_ids = sample.ids

# The prompt is the first CTX_LEN tokens, or the whole file when it is shorter.
max_seq = min(CTX_LEN, len(token_ids))
init_tokens = token_ids[:max_seq]
print(init_tokens)

In [None]:
import re
import datetime
from collections import deque
import torch
import torch.nn.functional as F
from runner import sample_logits, repetition_penalty  # Ensure this is properly defined in your runner module
from tqdm import tqdm  # Import tqdm for the progress bar

# Put the model in evaluation mode before sampling
model_rnn.eval()

# Without a prompt the first forward pass would receive an empty sequence.
if not init_tokens:
    raise ValueError('init_tokens is empty: a non-empty prompt is required to start generation')

# Sliding window over the most recent tokens. The deque drops the oldest
# token automatically once it holds `context_length` entries.
context_length = CTX_LEN
buffer = deque(init_tokens[-context_length:], maxlen=context_length)

max_iterations = MAX_ITER  # Number of sampling steps performed by the loop below

# Full history: the prompt followed by every sampled token, each exactly once.
# Only the (currently disabled) repetition penalty would consume it.
generated_tokens = list(buffer)

# Convert buffer to a batch of one sequence
input_ids = torch.tensor([list(buffer)], dtype=torch.long).cuda()
print(f"Input shape: {input_ids.shape}")  # [1, len(buffer)], i.e. at most [1, context_length]

# Token ids that must never be sampled (replace with actual pad token IDs)
pad_token_ids = [0]  # Example: [0, 50256]
out_tokens = []

# Optional per-step sampling schedules defined by earlier cells. A schedule
# that does not exist in the notebook namespace falls back to a constant.
# Looked up once here instead of on every iteration.
temp_schedule = globals().get('temp_values')
top_k_schedule = globals().get('top_k_values')
top_p_schedule = globals().get('top_p_values')

with torch.no_grad():
    # Initial forward pass over the prompt; the output is indexed below as
    # [batch, position, vocabulary entry]
    init_out = model_rnn.forward(input_ids)
    print(f"Initial output shape: {init_out.shape}")

    # Ignore padding by setting its logits to -inf
    init_out[0, :, pad_token_ids] = -float('inf')

    # Sample the first token from the last position, with fixed settings
    logits = init_out[0, -1, :].cpu()  # Move to CPU for sampling if necessary
    out_token = sample_logits(logits, temperature=0.25, top_k=50)  # Example values

    generated_tokens.append(out_token)
    out_tokens.append(out_token)

    with tqdm(range(max_iterations), desc="Generating tokens", unit="token") as pbar:
        for i in pbar:
            # Slide the window forward by the token sampled in the previous
            # step. It is already in generated_tokens, so it is NOT appended
            # there again (doing so would record every token twice).
            buffer.append(out_token)

            # Convert buffer to tensor
            input_ids = torch.tensor([list(buffer)], dtype=torch.long).cuda()

            # Forward pass over the whole window
            output = model_rnn.forward(input_ids)

            # Ignore padding
            output[0, :, pad_token_ids] = -float('inf')

            # Apply repetition penalty if desired
            # Uncomment and adjust the following lines if repetition_penalty is needed
            # output = repetition_penalty(
            #     output, generated_tokens, [4, 5, 6, 7],
            #     repetition_penalty=1.25,
            #     seq_len=256,
            #     decay_factor=0.8
            # )

            # Get logits for the last position
            logits = output[0, -1, :].cpu()

            # Sampling settings for this step: scheduled value when a schedule
            # exists, otherwise the constant default
            current_temp = temp_schedule[i] if temp_schedule is not None else 1.0
            current_top_k = top_k_schedule[i] if top_k_schedule is not None else 50
            current_top_p = top_p_schedule[i] if top_p_schedule is not None else None

            # Sample the next token
            out_token = sample_logits(
                logits,
                temperature=current_temp,
                top_k=current_top_k,
                top_p=current_top_p
            )

            generated_tokens.append(out_token)
            out_tokens.append(out_token)

            # Show the latest token in the progress bar
            pbar.set_postfix(token=out_token)

In [None]:
# Dump the sampled token ids: the first sampled token plus one per loop iteration.
print(out_tokens)

In [None]:
# Timestamp that gives each run's files a unique name. `d` was previously
# never defined, so this cell raised NameError on a fresh kernel. The format
# mirrors the existing example file name '2023-06-11T12-59-05-739764.mid'.
d = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S-%f')

fname = f'out/{PROJ_NAME}/{d}.mid'
fname_orig = f'out/{PROJ_NAME}/{d}_orig.mid'

# Decode the generated tokens and the full source token file back to MIDI and
# write both, so the continuation can be compared with the original.
TOKENIZER(out_tokens).dump_midi(fname)
TOKENIZER(token_ids).dump_midi(fname_orig)

# Cell output: the two paths that were written.
[fname, fname_orig]