# Thanks
The team at OpenAI for making Jukebox: Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford, and Ilya Sutskever.

SMarioMan and Zag for authoring the original Colab notebooks, and [Johannezz Music](https://deeplearn.art) for making the notebook this one is based on.

Changes from Johannezz's notebook:

- Removed co-composing (letting the batches complete gave more interesting results).
- Fixed lyrics generation (lyrics were written into info.txt but could not be read back from the file).
- Added interpolation between two artists/genres.

# Instructions:
- You need Colab Pro (the free tier will not work).
- Factory-reset the runtime before you start.
- Fresh batch: delete info.txt first, then "Run All".
- Upsampling only: run the cells one by one, skip the "Initial Generation" cell, and go straight to "Upsample".

# Setup

In [None]:
#@title Run this cell to see your GPU
# List the GPU(s) attached to this runtime (shell command run through IPython's `!`).
!nvidia-smi -L

In [None]:
#@title Run this cell to connect Google Drive
# Mount Google Drive so the project folder is reachable under /content/gdrive.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

# Project options

The folder on Google Drive (directly under MyDrive) that contains the primer wav.
This notebook puts everything else in the same folder.

In [None]:
# Folder name directly under MyDrive; it must contain the primer wav (see WAV below).
PROJECT_DIRECTORY = 'jukebox34' #@param {type:"string"}


Enter the lyrics in the code block below.

In [None]:
# Lyrics for the song. They are written to lyrics.txt in the project folder and
# passed to the model as the `lyrics` entry of each metadata dict.
LYRICS = """
enter badass lyrics here
"""

If 'info.txt' exists in the project folder, the settings in the following cell are overridden by the values stored in that file.

Please remember to delete 'info.txt' when starting a fresh batch.


In [None]:
#@title Run this cell to setup the project.  

from pathlib import Path

# Primer wav (file name inside the project folder) and the primary conditioning.
WAV = '' #@param {type:"string"}
ARTIST = '' #@param {type:"string"}
GENRE = '' #@param {type:"string"}

# Optional secondary artist/genre to interpolate with. These values are spliced
# into the installed prior.py by the setup cell further down.
# SEC_WEIGHT needs a numeric default: an empty Colab number field produced
# `SEC_WEIGHT = #@param ...`, which is a SyntaxError and stopped the whole cell.
# NOTE(review): presumably the blend weight of the secondary artist/genre, with
# 0 meaning "primary only" -- confirm against the patched prior.py.
SEC_ARTIST = '' #@param {type:"string"}
SEC_GENRE = '' #@param {type:"string"}
SEC_WEIGHT = 0.0 #@param {type:"number"}

# Lengths in seconds.
PROMPT_LENGTH = 5 #@param {type:"number"}
INITIAL_SONG_LENGTH = 150 #@param {type:"number"}
TOTAL_SONG_LENGTH = 150 #@param {type:"number"}

META = '' #@param {type:"string"}
NOTE = "" #@param {type:"string"}
SONG_LENGTH = 0

sampling_temperature = .985 #@param {type:"number"}
level_2_hops = .125 #@param {type:"number"}
# Hop fraction per level: [level 0, level 1, level 2 (top)].
hops = [1, .5, level_2_hops]

# Zag's variables (not referenced by the other cells visible in this notebook).
leexportlyrics = True
leprogress = True
leautorename = True

# saving and loading project options
info = {}
song_length = 0
# Trailing slash matters: file names are appended by plain string concatenation.
project_dir = f'/content/gdrive/MyDrive/{PROJECT_DIRECTORY}/'
fname = "info.txt"
# define some functions to save and read info.txt
def read_info(dir_path=None):
  """Load the ``KEY : value`` lines of info.txt into the global ``info`` dict.

  Args:
    dir_path: folder containing info.txt, ending with a path separator (it is
      concatenated directly with the file name). Defaults to ``project_dir``.

  Each line is split at its first colon and both sides are stripped, so values
  may themselves contain colons. Lines without any colon (blank lines and the
  continuation lines of multi-line LYRICS) are skipped. Previously they were
  stored under junk keys such as ``''`` or the lyric text itself. A lyric line
  that contains a colon is still indistinguishable from a setting.
  """
  path = (project_dir if dir_path is None else dir_path) + fname
  with open(path) as f:
    for line in f:
      key, sep, value = line.partition(":")
      if not sep:
        continue
      info[key.strip()] = value.strip()
      
def init_info(dir_path):
  """Write a fresh info.txt and lyrics.txt into ``dir_path`` from the form values.

  ``dir_path`` must end with a path separator; it is concatenated directly with
  the file names. info.txt uses the ``KEY : value`` layout that ``read_info``
  parses. Both files are written inside ``with`` blocks, so they are closed even
  if a write fails.
  """
  with open(dir_path + fname, "w") as f:
    f.write(f'''
       PROJECT_DIRECTORY : {PROJECT_DIRECTORY}\n 
       WAV : {WAV}
       ARTIST: {ARTIST} 
       GENRE : {GENRE}
       SEC_ARTIST: {SEC_ARTIST} 
       SEC_GENRE: {SEC_GENRE}
       SEC_WEIGHT: {SEC_WEIGHT}
       PROMPT_LENGTH : {PROMPT_LENGTH}
       INITIAL_SONG_LENGTH : {INITIAL_SONG_LENGTH}
       TOTAL_SONG_LENGTH :  {TOTAL_SONG_LENGTH}
       LYRICS : {LYRICS}
       META : {META} 
       NOTE : {NOTE}
       SONG_LENGTH : {SONG_LENGTH}
  ''')
  # The lyrics span several lines, which info.txt cannot round-trip, so they
  # are also stored verbatim in their own file.
  with open(dir_path + "lyrics.txt", "w") as f:
    f.write(f'''{LYRICS}''')

def save_info(dir_path):
  """Rewrite info.txt in ``dir_path`` from the current global ``info`` dict.

  ``dir_path`` must end with a path separator. Opening with mode "w" already
  truncates an existing file, so the former ``!rm`` shell call was redundant.
  It also printed an error whenever the file did not exist yet (e.g. in a new
  backup folder). Every key used below must be present in ``info``; a missing
  one raises KeyError.
  """
  with open(dir_path + fname, "w") as f:
    f.write(f'''
       PROJECT_DIRECTORY : {info['PROJECT_DIRECTORY']}\n 
       WAV : {info['WAV']}
       ARTIST: {info['ARTIST']} 
       GENRE : {info['GENRE']}
       SEC_ARTIST: {info['SEC_ARTIST']} 
       SEC_GENRE : {info['SEC_GENRE']}
       SEC_WEIGHT: {info['SEC_WEIGHT']} 
       PROMPT_LENGTH : {info['PROMPT_LENGTH']}
       INITIAL_SONG_LENGTH : {info['INITIAL_SONG_LENGTH']}
       TOTAL_SONG_LENGTH :  {info['TOTAL_SONG_LENGTH']}
       LYRICS : {info['LYRICS']}
       META : {info['META']} 
       NOTE : {info['NOTE']}
       SONG_LENGTH : {info['SONG_LENGTH']}
  ''')

def print_fancy(txt):
  """Print the string ``txt`` wrapped in ANSI colour codes (black on green), then reset."""
  highlight, reset = '\x1b[6;30;42m', '\x1b[0m'
  print(''.join((highlight, txt, reset)))

def backup(note):
  """Copy the project's wav files and zs checkpoint into a dated backup folder.

  The folder is ``<hps.name>backup-<MM-DD-YYYY>-<note>/``. It gets its own
  info.txt, written by ``save_info``, whose PROJECT_DIRECTORY and NOTE point at
  the backup. Uses the notebook globals ``hps``, ``info`` and ``save_info``.

  Changes from the shell-based version:
    * The global ``info`` is restored afterwards. It used to be left pointing
      at the backup, so the next ``save_info(hps.name)`` wrote the backup's
      PROJECT_DIRECTORY/NOTE into the main project's info.txt.
    * ``!mkdir``/``!cp`` are replaced with stdlib calls. An existing folder is
      reused, and a missing checkpoint is skipped (best effort, as before,
      when the shell commands just printed an error).
  """
  import datetime
  import glob
  import os
  import shutil

  date_str = datetime.datetime.now().strftime("%m-%d-%Y")
  bak_dir = f'{hps.name}backup-{date_str}-{note}/'
  os.makedirs(bak_dir, exist_ok=True)

  # glob.escape: the project path is literal, only "*.wav" is a pattern.
  for wav in glob.glob(glob.escape(hps.name) + "*.wav"):
    if os.path.isfile(wav):
      shutil.copy(wav, bak_dir)
  cpoint = f"{hps.name}zs-checkpoint.t"
  if os.path.isfile(cpoint):
    shutil.copy(cpoint, bak_dir)

  # save_info() reads the global `info`, so temporarily point it at the backup.
  overridden = ('PROJECT_DIRECTORY', 'NOTE')
  previous = {key: info[key] for key in overridden if key in info}
  info['PROJECT_DIRECTORY'] = bak_dir
  info['NOTE'] = note
  try:
    save_info(bak_dir)
  finally:
    for key in overridden:
      info.pop(key, None)
    info.update(previous)

# New project or continuation? info.txt is created from the form values only
# when it does not exist yet; an existing file always wins.
if not Path(project_dir + fname).is_file():
  init_info(project_dir)

# The variables used in generation are reloaded from info.txt.
read_info()
# int(float(...)) accepts "150.0" (a number field typed with a decimal point)
# as well as "150"; plain int("150.0") raised ValueError. Fractions are truncated.
INITIAL_SONG_LENGTH = int(float(info['INITIAL_SONG_LENGTH']))
PROMPT_LENGTH = int(float(info['PROMPT_LENGTH']))
TOTAL_SONG_LENGTH = int(float(info['TOTAL_SONG_LENGTH']))
SONG_LENGTH = int(float(info['SONG_LENGTH']))
song_length = SONG_LENGTH
# Kept as the raw strings from info.txt; they are spliced into prior.py later.
sec_art, sec_gen, sec_wgt = info['SEC_ARTIST'], info['SEC_GENRE'], info['SEC_WEIGHT']


In [None]:
#@title Run this cell to install Jukebox. Do it only once in a session!

# Install the songeater fork of Jukebox.
# NOTE(review): the prior.py patch in the setup cell expects a
# '#NEWTEXT gets inserted here' marker, presumably provided by this fork -- confirm.
!pip install git+https://github.com/songeater/jukebox
import jukebox
# torch is imported twice; the later cells use the `t` alias (t.save / t.load).
import torch as t
import librosa
import os
import glob
import torch
import datetime
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample, \
                           load_prompts, \
                           sample_level
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
# `device` is what make_vqvae/make_prior receive in the setup cell.
rank, local_rank, device = setup_dist_from_mpi()


In [None]:
#@title Run this cell to complete the setup. (Wait patiently)
# Hyperparameters and models for the 5B-lyrics Jukebox.
model = "5b_lyrics" 
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3           # samples generated side by side (also the number of `metas` entries)
hps.name = project_dir      # project folder: checkpoints and level_0/level_1 renders live here
hps.hop_fraction = hops
hps.levels = 3
lower_batch_size = 16       # max_batch_size for the two upsampling levels
max_batch_size = 3          # max_batch_size for the top-level prior
lower_level_chunk_size = 32 # chunk_size for the two upsampling levels
chunk_size = 16             # chunk_size for the top-level prior

# The first entry is the VQ-VAE config; the rest are prior configs, top level last.
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae,dict(sample_length = 1048576)), device)
top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

# Look for work saved by an earlier session:
#   "data"    -> level_1/data.pth.tar exists, so upsampling can resume from it;
#   "compose" -> only the top-level zs-checkpoint.t exists;
#   ""        -> nothing saved yet.
checkpoint = ""
if os.path.exists(hps.name):
  data = f"{hps.name}/level_1/data.pth.tar"
  compose = f"{hps.name}/zs-checkpoint.t"
  if os.path.isfile(data):
    checkpoint = "data"
  elif os.path.isfile(compose):
    checkpoint = "compose"

if checkpoint:
  print_fancy(f'Found checkpoint of type {checkpoint}')

if checkpoint == 'data':
  # Resume upsampling from the saved level-1 codes; no primer audio needed.
  mode, codes_file, audio_file = 'upsample', data, None
else:
  # Prime generation with the project's wav file.
  mode, codes_file, audio_file = 'primed', None, project_dir + info['WAV']

sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=PROMPT_LENGTH))

sample_length_in_seconds = TOTAL_SONG_LENGTH 
# Total song length in audio samples, rounded down to a whole number of
# top-level tokens (raw_to_tokens audio samples per token).
hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens
# The song must span at least one full context window (n_ctx tokens) of the top
# prior. The old message asked for a larger "sampling rate"; the actual remedy
# is a longer TOTAL_SONG_LENGTH.
min_sample_length = top_prior.n_ctx*top_prior.raw_to_tokens
assert hps.sample_length >= min_sample_length, (
    f'Please choose a larger TOTAL_SONG_LENGTH: got {hps.sample_length} samples, '
    f'need at least {min_sample_length}')

# One metadata dict per sample (the same dict object repeated n_samples times).
# NOTE: the lyrics come from the LYRICS cell, not from info.txt.
metas = [dict(artist = info['ARTIST'],
              genre = info['GENRE'],
              total_length = hps.sample_length,
              offset = 0,
              lyrics = LYRICS,
              ),
        ] * hps.n_samples

# Per-level settings, ordered [level 0, level 1, level 2 (top)]. Only the top
# level gets labels here; the upsampling cell fills in the lower levels once
# the upsamplers are loaded.
labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]
sampling_kwargs = [
    dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
         chunk_size=lower_level_chunk_size)
    for _lower_level in range(2)
] + [
    dict(temp=sampling_temperature, fp16=True,
         max_batch_size=max_batch_size, chunk_size=chunk_size),
]

def seconds_to_tokens(sec, sr, prior, chunk_size):
  """Convert a duration in seconds into a number of top-level tokens.

  Args:
    sec: duration in seconds (int or float, e.g. a librosa duration).
    sr: sample rate in Hz. It used to be ignored in favour of the global
      ``hps.sr``.
    prior: object exposing ``raw_to_tokens`` (audio samples per token) and
      ``n_ctx`` (only printed, not enforced).
    chunk_size: the token count is rounded up to a multiple of this. A count
      that is already a multiple still gains one extra chunk (unchanged).

  Returns:
    The token count as an ``int``. A float ``sec`` previously yielded a float.
  """
  tokens = int(sec * sr // prior.raw_to_tokens)
  tokens = (tokens // chunk_size + 1) * chunk_size
  print(f"tokens {tokens} prior.n_ctx {prior.n_ctx}")
  return tokens
  

def get_length():
  """Return the duration in seconds of the furthest-rendered clip, for resuming.

  Looks for ``level_0/*0.wav`` under ``hps.name`` first, then falls back to
  ``level_1/*0.wav``, and prints which level is being resumed.

  Raises:
    FileNotFoundError: if neither level has a matching wav. This previously
      surfaced as an opaque IndexError.
  """
  for level in ('level_0', 'level_1'):
    found = glob.glob(os.path.join(hps.name, level, '*0.wav'))
    if found:
      print_fancy(f"resuming {level}")
      return librosa.get_duration(filename=found[0])
  raise FileNotFoundError(
      f"No rendered '*0.wav' found in level_0/ or level_1/ under {hps.name}")

# Seconds already rendered: measured from disk when resuming an upsample,
# otherwise the initial generation length loaded from info.txt.
if checkpoint == 'data':
  the_length = get_length()
else:
  the_length = INITIAL_SONG_LENGTH
print(f"rendered length {the_length}")
tokens_to_sample = seconds_to_tokens(the_length, hps.sr, top_prior, chunk_size)

def write_audio(start):
  """Write each of the ``hps.n_samples`` clips in the global ``x`` to ``clip_<i>.wav``.

  Files go to the current working directory. Each clip is sliced from sample
  index ``start * hps.sr`` onwards. This assumes the first axis of ``x[i]`` is
  time -- TODO confirm against the vqvae.decode output.
  """
  first_sample = start * hps.sr
  for idx in range(hps.n_samples):
    clip = x[idx][first_sample:]
    librosa.output.write_wav(f'clip_{idx}.wav', clip, sr=hps.sr)

# Patch the installed jukebox prior.py so it sees the secondary artist/genre/weight.
# The file is located through the imported package. The old hard-coded
# /usr/local/lib/python3.7/... path breaks on any other Python version.
# NOTE(review): prior.py may already have been imported by the make_prior() call
# above; if so, the patched source only takes effect after a runtime restart -- confirm.
filex = os.path.join(os.path.dirname(jukebox.__file__), "prior", "prior.py")
with open(filex, "rt") as fin:
  data = fin.read()
# repr() quotes the names safely: an artist containing `"` or `\` used to produce
# invalid Python. The weight must be numeric; an empty value falls back to 0.
new_text = f"sec_art, sec_gen, sec_wgt = {sec_art!r}, {sec_gen!r}, {float(sec_wgt or 0)!r}"
marker = '#NEWTEXT gets inserted here'
if marker in data:
  data = data.replace(marker, new_text)
  with open(filex, "wt") as fout:
    fout.write(data)
else:
  # The marker is gone after the first patch, so re-running this cell cannot
  # change the values; say so instead of silently rewriting the same file.
  print_fancy("prior.py marker not found (already patched?); secondary artist/genre left unchanged")

In [None]:
#@title Run this cell to free memory
# Show GPU memory before and after calling jukebox's empty_cache()
# (presumably a wrapper around torch.cuda.empty_cache -- confirm).
!nvidia-smi 
empty_cache()
!nvidia-smi 

# Initial Generation

This will generate the starting point of your song.

**Skip this step if you are resuming from an earlier session.**

In [None]:
#@title Run this cell to generate the beginning of your song
  
# Runs only for a fresh project (the setup cell found no checkpoint).
if checkpoint == '':
  # initial generation
  assert sample_hps.audio_file is not None
  # The WAV setting may list several primer files separated by commas.
  audio_files = sample_hps.audio_file.split(',')
  # Prompt length in audio samples, rounded down to whole top-level tokens.
  duration = (int(sample_hps.prompt_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens
  print("audio files",audio_files)
  x = load_prompts(audio_files, duration, hps)
  # Encode the primer audio into codes for levels 0..len(priors).
  zs = top_prior.encode(x, start_level=0, end_level=len(priors), bs_chunks=x.shape[0])
  # NOTE(review): hop_length is computed but never used below.
  hop_length = int(hps.hop_fraction[2]*top_prior.n_ctx)
  # [2] presumably selects the levels to sample -- here only the top level.
  zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)
  # Decode the top-level codes to audio for the preview clips.
  x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()
  song_length = INITIAL_SONG_LENGTH
  print_fancy(f"Song length: {song_length} seconds. Remaining time: {TOTAL_SONG_LENGTH - song_length}") 
  write_audio(0)      
  # Save the codes locally, then copy checkpoint and clips to the project folder
  # (the cp source paths assume the working directory is /content).
  t.save(zs, 'zs-checkpoint.t')
  !cp /content/zs-checkpoint.t "{hps.name}"
  !cp /content/clip_*.wav "{hps.name}"
  # Record the generated length in the project's info.txt.
  info["SONG_LENGTH"] = song_length
  save_info(hps.name)

# Upsample

This last step renders your songs with higher-quality audio, and it takes a while: the first pass takes about 95 minutes for a 45-second clip on a Tesla P100 GPU.

The three upsampled clips can be found in a new folder called **level_1** inside your project directory. They are being saved while the upsampling is happening, and you can listen to them during the process.

You can then render your clips to **level_0** (the best quality), but that takes even longer. On the Colab free tier, a session that runs beyond roughly 8 hours may be terminated before the clips are finished. If that happens, just reopen this notebook and pick up where you left off.

In [None]:
#@title Run this cell to make your clips better sounding


def _release_top_prior_and_load_upsamplers():
  """Drop the notebook's reference to the top-level prior, then build the upsamplers.

  The global ``top_prior`` is rebound to None, which ``upsample`` later receives
  as the placeholder for the already-sampled top level. Returns the level-0/1
  upsamplers, built on the CPU in the order of ``priors[:-1]``.
  """
  global top_prior
  top_prior = None
  empty_cache()
  return [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]


if checkpoint != "data":
  print("!=data")
  # Top-level labels and raw_to_tokens must be read while top_prior still exists.
  labels = top_prior.labeller.get_batch_labels(metas, 'cuda')
  zs = t.load(f'{hps.name}/zs-checkpoint.t')
  assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'
  hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens
  upsamplers = _release_top_prior_and_load_upsamplers()
  # Levels 0 and 1 are upsampled; the top level (None) is not sampled again.
  sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=16, chunk_size=32),
                     dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),
                     None]
  if isinstance(labels, dict):
    labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels]

else:  # upsampling from level 1
  print("upsampling from level 1")
  assert sample_hps.codes_file is not None
  data = t.load(sample_hps.codes_file, map_location='cpu')
  zs = [z.cuda() for z in data['zs']]
  assert zs[-1].shape[0] == hps.n_samples, f"Expected bs = {hps.n_samples}, got {zs[-1].shape[0]}"
  del data
  upsamplers = _release_top_prior_and_load_upsamplers()
  # Keep the top-level labels from setup; fill in the two upsampler levels.
  labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
del upsamplers
empty_cache()
print("Done.")