In [1]:
import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from rotary_embedding_torch import RotaryEmbedding
from transformers import AutoTokenizer
master_process = True 
class ChineseTokenizer:
    def __init__(self):
        self.special_tokens = {}
        self.special_token_ids = {}
        self.next_special_token_id = 70000  # Start of Unicode Private Use Area

    def add_special_token(self, token):
        if token not in self.special_tokens:
            self.special_tokens[token] = self.next_special_token_id
            self.special_token_ids[self.next_special_token_id] = token
            self.next_special_token_id += 1

    def encode(self, text):
        encoded = []
        i = 0
        while i < len(text):
            # Check for special tokens
            matched = False
            for token, token_id in self.special_tokens.items():
                if text[i:].startswith(token):
                    encoded.append(token_id)
                    i += len(token)
                    matched = True
                    break
            if not matched:
                encoded.append(ord(text[i]))
                i += 1
        return encoded

    def decode(self, ids):
        return ''.join([self.special_token_ids.get(id, chr(id)) for id in ids])
    
    
tokz = ChineseTokenizer()
tokz.add_special_token('<|answer|>')
tokz.add_special_token('None')


def log_pytorch(tensor):
    return torch.log(tensor + 1e-8)  # Adding a small constant to avoid NaN for very small values

class expertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_tone_attn = nn.Linear(config.n_embd_pro , config.n_embd//config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.EXPERT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.rotary_emb = RotaryEmbedding( dim = 32,)

    def forward(self, x, tone_embeds):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        tone_k = self.c_tone_attn(tone_embeds)
        em_x = self.rotary_emb.rotate_queries_or_keys(x)
        qkv = self.c_attn(em_x)
        q, k, v = qkv.split(self.n_embd, dim=2) 
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        k = k + tone_k.unsqueeze(1)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class expertMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        # Gate projection layers
        self.gate_proj = nn.Linear(config.n_embd, 4*config.n_embd)
        self.act_fn = nn.SiLU()
        self.up_proj = nn.Linear(config.n_embd, 4*config.n_embd)
        
        # Down projection layer
        self.down_proj = nn.Linear(4*config.n_embd + config.n_embd_pro, config.n_embd)
        self.down_proj.NANOGPT_SCALE_INIT = 1


    def forward(self, x, tone_embeds):
        # Gating mechanism
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        
        combined = torch.cat([gate + up, tone_embeds], dim= -1)
        
        # Apply activation function
        activated = self.act_fn(combined)
        
        # Down projection
        output = self.down_proj(activated)
        
        return output
    
class expertLM_head(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.lm_head = nn.Linear(config.n_embd+ config.n_embd_pro, config.vocab_size, bias=False)
        self.act_fn = nn.GELU()
        self.lm_head2 = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.gate = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x, tone_embeds):
       
        combined = torch.cat([x, tone_embeds], dim= -1)
        # Apply activation function
        activated = self.act_fn(combined)
        activated_x = self.act_fn(x)
        
        # Down projection
        output = self.lm_head(activated)
        output2 = self.lm_head2(activated_x)

        
        return self.gate * output + (1 - self.gate) * output2
    
class expertBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = expertAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = expertMLP(config)

    def forward(self, x, tone_embeds):
        x = x + self.attn(self.ln_1(x), tone_embeds)
        x = x + self.mlp(self.ln_2(x), tone_embeds)
        return x

@dataclass
class expertConfig:
    block_size: int = 32 # max sequence length
    vocab_size: int = tokz.next_special_token_id # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 20 # number of layers
    n_head: int = 8 # number of heads
    n_embd: int = 2048 # embedding dimension
    n_embd_pro : int = 512 # tone embedding for each token

# tone_list: int = [9]

class tone_expert(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd- config.n_embd_pro),
            wpe = nn.Embedding(config.block_size,  config.n_embd- config.n_embd_pro),
            poe = nn.Embedding(config.vocab_size, config.n_embd_pro),
            ppoe = nn.Embedding(config.block_size, config.n_embd_pro),
            h = nn.ModuleList([expertBlock(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = expertLM_head(config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'EXPERT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, idx, tone_list , targets=None):
        # idx is of shape (B, T)
        # tone_list is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and posisition embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd- n_embd_pro)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd - n_embd_pro)
        tone_emb = self.transformer.poe(tone_list) #token embeddings of shape (B, T, n_embd_pro)
        # tone_poe_emb = self.transformer.ppoe(pos) #token embeddings of shape (T, n_embd_pro)
        x = tok_emb + pos_emb 
        tone_x = tone_emb # (B, T , n_embd_pro)
        x = torch.cat([x, tone_x ], dim=-1)
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x, tone_x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x, tone_x) # (B, T, vocab_size)
        
        # with torch.no_grad():
        #     logits_last = logits[:,-1,:].squeeze(dim = 1)
        #     logits_last = [torch.argmax(logits_last, dim=-1), torch.max(logits_last,dim = -1)]
            
        #     t_d = {0:[4], 2:[6], 4:[3,5], 3:[1,2]}
        #     loss_tone = 0
        #     for b in range(B):
        #         req_tone = tone_list[b,0] - 15;print(f'{req_tone=}')
        #         jp = pinyin(tokz.decode(logits_last[0][b]), variant='cantonese')
        #         print(f'{jp=}')
        #         jp = re.findall(r'\d+', jp)
        #         jp = [int(num) for num in jp]
        #         v = []
        #         for r,p in t_d.items():
        #             for q in p:
        #                 if q in jp:
        #                     jp.remove(q)
        #                     v.append(r)
        #         v = set(v)
        #         v = list(v)
        #         print(f'{logits_last[1][b]=}')
        #         if req_tone in v:
        #             loss_tone += log_pytorch(logits_last[1][b])
        #         else:
        #             loss_tone += 2*log_pytorch(logits_last[1][b])
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

In [2]:
import numpy as np
import json
import random
class DataLoaderLite:
    def __init__(self, B, T, tokz):
        self.B = B
        self.T = T
        self.tokz = tokz


        with open('chatml_lyrics1_2 copy 4.json', 'r', encoding='utf-8') as file:
            data = json.load(file)
        self.data = data
        self.data_length = len(data)
        self.current_position = random.randint(0, len(data)-2*B-1)

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.data[self.current_position : self.current_position+B]
        # gathered_sample = 0
        # while gathered_sample < B:
        #     if 'theme' in self.data[self.current_position].keys():
        #         if self.data[self.current_position]['theme'] != None and len(self.data[self.current_position]['lyric'])>=4:
        #             buf.append(self.data[self.current_position])
        #             gathered_sample += 1
        #             self.current_position += 1
        #             if self.current_position >= len(self.data):
        #                 self.current_position = 0
        #     else:
        #         print('weird samlpe found!')
        #         self.current_position += 1
        # print('self.current_position = ')  
        # print(self.current_position)
                
        x = []
        t = []
        y = []
        for sample in buf:
            # tx = self.tokz.encode(sample['messages'][0]['content']+'<|answer|>'+sample['messages'][1]['content'])
            if sample['theme'] == None:
                tx = self.tokz.encode(sample['lyric']+'<|answer|>'+sample['lyric'])
            else:   
                tx = self.tokz.encode(sample['theme']+'<|answer|>'+sample['lyric'])
            
            for token in tx:
                if token>=tokz.next_special_token_id:
                    print('there is weird token' , token)
            tt = self.tokz.encode(sample['tone'])[:T]
            ty = tx[1:][:T]
            tx = tx[:-1][:T]
            x.append( [948] * (T - len(tx)) + tx)
            y.append([948] * (T - len(ty)) + ty)
            t.append([948] * (T - len(tt)) + tt)
            
        self.current_position += B
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B) > len(self.data) :
            self.current_position = 0    
        return torch.tensor(x),torch.tensor(t,) , torch.tensor(y,)

In [3]:
device = 'cuda'
device_type = 'cuda'
num_return_sequences =1
model = tone_expert(expertConfig())
model.to(device)
model = torch.compile(model).to(device)
train_loader = DataLoaderLite(B=16 , T =32, tokz=tokz)
ddp = False
use_compile = True
epoch = 5
max_steps = int(train_loader.data_length/16/4*epoch)
print(max_steps)


# Load the weights from the .pth file
weights = torch.load("model_weights_step_480_trial13.pth")

# Create a new dictionary with keys adjusted to remove the '_orig_mod.' prefix
# adjusted_weights = {}
# for key in weights.keys():
#     # Remove the '_orig_mod.' prefix if present
#     new_key = key.replace("_orig_mod.", "")
#     adjusted_weights[new_key] = weights[key]

# Assuming 'model' is your model instance
# Load the adjusted weights into the model
model.load_state_dict(weights)
model.to('cuda')
# Assuming 'weights' is the variable holding the loaded weights
del weights  # Delete the variable

# Invoke garbage collection
import gc
gc.collect()

# Clear the CUDA cache
torch.cuda.empty_cache()




max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 4
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    # 2) if it > lr_decay_iters, return min learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)

# optimize!
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device_type)


log_dir = "log"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"log13.txt")
with open(log_file, "w") as f: # open for writing to clear the file
    pass





4550
num decayed parameter tensors: 126, with 1,831,769,088 parameters
num non-decayed parameter tensors: 203, with 705,537 parameters
using fused AdamW: True


In [4]:

for step in range(max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)
    # once in a while generate from the model (except step 0, which is noise)
    if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
        model.eval()
        tokens = tokz.encode('身體健康')
        tones = tokz.encode('333332')
        tokens = torch.tensor(tokens, dtype = torch.long)
        tones = torch.tensor(tones, dtype = torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        tones = tones.unsqueeze(0).repeat(num_return_sequences, 1)

        xgen = tokens.to(device)
        xtone = tones.to(device)
        sample_rng = torch.Generator(device=device)
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits, loss = model(xgen,xtone ) # (B, T, vocab_size)
                # take the logits at the last position
                logits = logits[:, -1, :] # (B, vocab_size)
                # get the probabilities
                probs = F.softmax(logits, dim=-1)
                # do top-k sampling of 50 (huggingface pipeline default)
                # topk_probs here becomes (5, 50), topk_indices is (5, 50)
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                # select a token from the top-k probabilities
                # note: multinomial does not demand the input to sum to 1
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                # gather the corresponding indices
                xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                # append to the sequence
                xgen = torch.cat((xgen, xcol), dim=1)
                pad_tone = torch.tensor([tokz.next_special_token_id-1], dtype = torch.long)
                pad_tone = pad_tone.unsqueeze(0).repeat(num_return_sequences, 1)
                pad_tone = torch.cat((pad_tone,pad_tone), dim = 1)
                pad_tone = pad_tone.to(device)
                xtone = torch.cat((xtone[:, 1:], pad_tone), dim=1) 
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = tokz.decode(tokens)
            print(decoded)                                              
            print('***')
    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(4):
        x, t, y = train_loader.next_batch()       
        x,t,  y = x.to(device),t.to(device), y.to(device)
        # added after video, this field is also used by the forward pass.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, t ,y)
        # we have to scale the loss to account for gradient accumulation,
        # because the gradients just add on each successive backward().
        # addition of gradients corresponds to a SUM in the objective, but
        # instead of a SUM we want MEAN. Scale the loss here so it comes out right
        loss = loss / 4
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # determine and set the learning rate for this iteration
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work
    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    tokens_processed = train_loader.B * train_loader.T * 5
    tokens_per_sec = tokens_processed / dt
    print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
    with open(log_file, "a") as f:
        f.write(f"{step} train {loss_accum.item():.6f}\n")
    if step == max_steps-1:
        torch.save(model.state_dict(), f"model_weights_step_{step}_trial14.pth")

step     0 | loss: 1.553728 | lr 1.5000e-04 | norm: 3.5668 | dt: 58652.16ms | tok/sec: 43.65
step     1 | loss: 2.595336 | lr 3.0000e-04 | norm: 3.5324 | dt: 477.33ms | tok/sec: 5363.19
step     2 | loss: 1.776985 | lr 4.5000e-04 | norm: 3.3171 | dt: 440.75ms | tok/sec: 5808.26
step     3 | loss: 1.625846 | lr 6.0000e-04 | norm: 2.0644 | dt: 451.41ms | tok/sec: 5671.07
step     4 | loss: 1.804182 | lr 6.0000e-04 | norm: 7.8530 | dt: 439.78ms | tok/sec: 5821.05
step     5 | loss: 1.770599 | lr 6.0000e-04 | norm: 5.4370 | dt: 450.27ms | tok/sec: 5685.50
step     6 | loss: 1.964005 | lr 6.0000e-04 | norm: 1.1620 | dt: 443.06ms | tok/sec: 5777.94
step     7 | loss: 1.574155 | lr 6.0000e-04 | norm: 1.7519 | dt: 449.02ms | tok/sec: 5701.33
step     8 | loss: 1.301316 | lr 6.0000e-04 | norm: 0.6474 | dt: 443.35ms | tok/sec: 5774.27
step     9 | loss: 1.640643 | lr 6.0000e-04 | norm: 6.3813 | dt: 444.58ms | tok/sec: 5758.24
step    10 | loss: 1.599122 | lr 6.0000e-04 | norm: 0.9553 | dt: 441.2

In [None]:
import torch
model = tone_expert(expertConfig())  # Replace with your model class
# Load the weights from the .pth file
weights = torch.load("model_weights_1.pth")

# Create a new dictionary with keys adjusted to remove the '_orig_mod.' prefix
adjusted_weights = {}
for key in weights.keys():
    # Remove the '_orig_mod.' prefix if present
    new_key = key.replace("_orig_mod.", "")
    adjusted_weights[new_key] = weights[key]

# Assuming 'model' is your model instance
# Load the adjusted weights into the model
model.load_state_dict(adjusted_weights)

# Verify that the weights are loaded correctly
for name, param in model.named_parameters():
    print(name, "\t", param.size())

# Print the total number of parameters in the model
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
model.eval()
model.to('cuda')

In [5]:
model.eval()

OptimizedModule(
  (_orig_mod): tone_expert(
    (transformer): ModuleDict(
      (wte): Embedding(70002, 1536)
      (wpe): Embedding(32, 1536)
      (poe): Embedding(70002, 512)
      (ppoe): Embedding(32, 512)
      (h): ModuleList(
        (0-19): 20 x expertBlock(
          (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (attn): expertAttention(
            (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
            (c_tone_attn): Linear(in_features=512, out_features=256, bias=True)
            (c_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (rotary_emb): RotaryEmbedding()
          )
          (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): expertMLP(
            (gate_proj): Linear(in_features=2048, out_features=8192, bias=True)
            (act_fn): SiLU()
            (up_proj): Linear(in_features=2048, out_features=8192, bias=True)
            (down_proj): Linear(in_features=

In [6]:
num_return_sequences =1
device = 'cuda'


In [7]:
tokens = tokz.encode('相聚之少<|answer|>')
tones = tokz.encode('33023033342433')
lt = max(len(tokens) , len(tones))
raw_tones = tones
tones = [948] * (lt  - 1) + [raw_tones[0]]
tokens = [948] * (lt  - len(tokens)) + tokens
tokens = torch.tensor(tokens, dtype = torch.long)
tones = torch.tensor(tones, dtype = torch.long)
max_length = lt + len(raw_tones)
print(max_length)
print(tokens)
print(len(raw_tones))

28
tensor([  948,   948,   948,   948,   948,   948,   948,   948,   948, 30456,
        32858, 20043, 23569, 70000])
14


In [8]:
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
tones = tones.unsqueeze(0).repeat(num_return_sequences, 1)

xgen = tokens.to(device)
xtone = tones.to(device)
current_tone = 1
sample_rng = torch.Generator(device=device)
while xgen.size(1) < max_length:
    # forward the model to get the logits
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(xgen,xtone ) # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)
        
        if xgen.size(1) == max_length:
            break
        pad_tone = torch.tensor([raw_tones[current_tone]], dtype = torch.long)
        pad_tone = pad_tone.unsqueeze(0).repeat(num_return_sequences, 1)
        pad_tone = pad_tone.to(device)
        xtone = torch.cat((xtone[:, :], pad_tone), dim=1)
        current_tone += 1
        print(f'{xtone=}')
for i in range(num_return_sequences):
    tokens = xgen[i, :max_length].tolist()
    decoded = tokz.decode(tokens)
    print(decoded)
    print('***')

xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48,  50]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48,  50,  51]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48,  50,  51,  48]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48,  50,  51,  48,  51]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,  51,
          51,  48,  50,  51,  48,  51,  51]], device='cuda:0')
xtone=tensor([[948, 948, 948, 948, 948, 948, 948, 948, 948, 948, 948,