In [1]:
import torch
import time
from utils.utils import generate_mask, load_model, writeDACFile

In [2]:
### params
# Which checkpoint to load: its number and the run directory that holds it.
cptnum = 50
checkpoint_dir = 'runs/scratchpistons'

# Base file name: model tag followed by the checkpoint number zero-padded to 4 digits.
fnamebase = f"pistons.e256.l4.h8_chkpt_{cptnum:04d}"

# Full path of the checkpoint file to load.
checkpoint_path = f"{checkpoint_dir}/{fnamebase}.pth"

# for saving sound (written next to the checkpoint)
outdir = checkpoint_dir

DEVICE = 'cuda'

# 86 frames per second, so this asks for 20 seconds' worth of frames.
inference_steps = 20 * 86

In [3]:
# Build the torch device named by DEVICE in the params cell.
# If the docker was started with --gpus all, a specific GPU can be chosen with
# e.g. 'cuda:0'; 'cpu' is also accepted.
device = torch.device(DEVICE)

# Only query CUDA when a CUDA device was actually requested, so this cell also
# runs on CPU-only setups instead of raising.
if device.type == 'cuda':
    # A bare 'cuda' has no explicit index; report the current CUDA device then.
    cuda_index = device.index if device.index is not None else torch.cuda.current_device()
    # total_memory is in bytes; dividing by 1e9 reports it in GB.
    print(f'memeory on cuda {cuda_index} is  {torch.cuda.get_device_properties(cuda_index).total_memory/1e9}')

# Last expression of the cell: display the selected device.
device

memeory on cuda 0 is  25.216745472


device(type='cuda')

In [4]:

# Restore the model from the checkpoint; load_model also hands back the context
# length (Ti), the vocabulary size and the number of codebooks used below.
model, Ti, vocab_size, num_codebooks = load_model(checkpoint_path)

# Move the model to the selected device; binding the result to `_` keeps the
# cell from echoing the return value.
_ = model.to(device)


In [5]:
def inference(model, Ti, vocab_size, num_tokens, inference_steps, fname):
    """Autoregressively generate token frames and write them with writeDACFile.

    Starts from a random context window of shape (1, Ti, num_tokens), then
    repeatedly feeds the window to the model, greedily picks the next frame of
    tokens, appends it and drops the oldest frame (sliding window).

    Args:
        model: callable as ``model(input_data, mask)``; must provide ``eval()``
            and ``parameters()``. Its output is indexed as ``[:, -1, :, :]``,
            so it is assumed to be (batch, time, num_tokens, vocab) -- TODO confirm.
        Ti: length of the context window (number of frames).
        vocab_size: exclusive upper bound for the random seed tokens.
        num_tokens: number of tokens per frame.
        inference_steps: number of frames to generate (must be >= 1, since the
            generated frames are concatenated at the end).
        fname: output name passed to writeDACFile; '_unmasked' is appended
            when generate_mask returns None.

    Returns:
        None. The generated sequence, shaped (1, num_tokens, inference_steps),
        is handed to writeDACFile.
    """
    model.eval()

    # Run on the device the model lives on rather than relying on a
    # notebook-level global; fall back to CPU for a parameter-less model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Only move the mask when there is one, so an unmasked run is possible.
    mask = generate_mask(Ti, Ti)
    if mask is not None:
        mask = mask.to(run_device)

    # Random seed context of Ti frames.
    input_data = torch.randint(0, vocab_size, (1, Ti, num_tokens)).to(run_device)
    predictions = []

    t0 = time.time()
    # No gradients are needed for generation; without this every step would
    # record an autograd graph that is never used.
    with torch.no_grad():
        for _ in range(inference_steps):
            output = model(input_data, mask)

            # Take the last position of the sequence (the newly predicted frame)
            # and, for each token of the frame independently, the index of the
            # highest score along the last dimension (greedy decoding).
            # .max(-1) returns (values, indices); [1] selects the indices.
            next_token = output[:, -1, :, :].max(-1)[1]
            predictions.append(next_token)

            # Slide the window: append the new frame, drop the oldest one.
            input_data = torch.cat([input_data, next_token.unsqueeze(1)], dim=1)[:, 1:]

    # CUDA work is queued asynchronously; wait for it so the timing is real.
    if run_device.type == 'cuda':
        torch.cuda.synchronize(run_device)
    t1 = time.time()
    inf_time = t1 - t0
    # 86 is the frame rate (frames per second) used to convert steps to seconds.
    print(f'inference time for {inference_steps} steps, or {inference_steps/86} seconds of sound is {inf_time}' )

    # Stack the frames to (steps, num_tokens), add a batch dimension and swap
    # the last two axes -> (1, num_tokens, steps).
    dacseq = torch.cat(predictions, dim=0).unsqueeze(0).transpose(1, 2)
    if mask is None:
        writeDACFile(fname + '_unmasked', dacseq)
    else:
        writeDACFile(fname, dacseq)

In [6]:


# Output name: <outdir>/<checkpoint base name>_steps_<step count, zero-padded to 4 digits>
sound_fname = f"{outdir}/{fnamebase}_steps_{str(inference_steps).zfill(4)}"

# Generate inference_steps frames with the loaded model and write them out.
inference(model, Ti, vocab_size, num_codebooks, inference_steps, sound_fname)


inference time for 1720 steps, or 20.0 seconds of sound is 4.36799693107605
