In [None]:
# Make the project root (the parent of the notebook's working directory)
# importable, so `beatlearning` resolves without installing the package.
import sys
from pathlib import Path

project_root = Path().resolve().parents[0]
sys.path.insert(0, str(project_root))

# Automatically generate osu! beatmaps
1. Initialize a model from the tokenizer settings and a checkpoint
2. Upload an MP3
3. Generate a beatmap from the audio and the sampling settings
4. Convert the generated beatmap to an .osz file
5. Download the file

In [None]:
# load the model
# Build the BEaRT model from the Quaver configuration, restore trained weights
# from the checkpoint, and move it to the best available device.
import torch
from beatlearning.configs import QuaverBEaRT
from beatlearning.tokenizers import BEaRTTokenizer
from beatlearning.models import BEaRT

CHECKPOINT_PATH = "../models/checkpoint.pt"  # relative to the notebook's working directory

model_config = QuaverBEaRT()
tokenizer = BEaRTTokenizer(model_config)
model = BEaRT(tokenizer)
model.load(CHECKPOINT_PATH)

# Prefer the first CUDA GPU; fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)
print(device)

# ONLY USE AUDIO CONTENT FOR WHICH YOU HOLD THE RIGHTS!

In [None]:
# Upload MP3
# Shows a single-file picker restricted to .mp3; the chosen file is read from
# `upload.value` by the settings cell that follows.
import os
import tempfile

import librosa
from IPython.display import Audio
from ipywidgets import FileUpload

upload = FileUpload(multiple=False, accept=".mp3")
upload

In [None]:
# SET the following:
title = "Untitled"
artist = "Unknown Artist"
source = "Unknown Source"

audio_start = 0.0  # beatmap will start from this point, but all audio will be used
audio_end = None  # if None, all audio will be used
difficulty = 0.5  # between 0. - 1. where 1. is insane difficulty
####################

# Fail fast on settings the later cells cannot handle sensibly.
if not 0.0 <= difficulty <= 1.0:
    raise ValueError(f"difficulty must be between 0 and 1, got {difficulty}")
if audio_end is not None and audio_end <= audio_start:
    raise ValueError(f"audio_end ({audio_end}) must be greater than audio_start ({audio_start})")

# FileUpload.value is a tuple of upload dicts in ipywidgets 8, but a dict keyed
# by filename in ipywidgets 7; normalise both to a sequence of upload records.
# Check for an upload BEFORE creating any file, so a missing upload does not
# leave an empty audio.mp3 behind.
uploads = upload.value
if isinstance(uploads, dict):
    uploads = list(uploads.values())
if not uploads:
    raise IndexError("You forgot to upload a file in the previous cell!")
content = uploads[-1]['content']

# Write the upload into a fresh temp directory; later cells also write the
# generated beatmap files into `dirpath`.
dirpath = tempfile.mkdtemp()
mp3 = os.path.join(dirpath, 'audio.mp3')
with open(mp3, 'wb') as output_file:
    output_file.write(content)

# Preview the selected audio range.
y, sr = librosa.load(mp3, offset=audio_start, duration=None if audio_end is None else audio_end - audio_start)
Audio(data=y, rate=sr)

In [None]:
# Generate beatmap

# Ignore all holds, use only hits: a -inf logit bias is meant to rule these
# token ids out of sampling.
# NOTE(review): the upper bound 60 is a hard-coded token id, presumably the end
# of the hold-token range — TODO confirm / derive it from the tokenizer.
hits_only = {hold: -float("inf") for hold in range(len(tokenizer.RESERVED_TOKENS) + 1, 60)}
# Optional experimental tweak (meaning of this token id is not documented here):
# hits_only[len(tokenizer.RESERVED_TOKENS)] = -0.01

# Pass the user's audio_end from the settings cell instead of a hard-coded None,
# so generation honours the same range the settings describe.
ibf = model.generate(audio_file=mp3, audio_start=audio_start, audio_end=audio_end,
                     use_tracks=["LEFT"],  # only ["LEFT"] for OSU is supported at the moment
                     difficulty=difficulty,
                     beams=[2] * 8,    # use lower when on CPU
                     top_k=2,          # top_k is randomly sampled after beam search
                     temperature=1.0,  # < 1 more conservative, > 1 more creative but is off more often
                     logit_bias=hits_only)

result = os.path.join(dirpath, 'tmp.idf')
ibf.save(result)
ibf.data.head()

In [None]:
# Convert generated beatmap to OSZ file format

from beatlearning.converters import OsuBeatmapConverter

DIFFICULTY_NAMES = ["easy", "normal", "hard", "insane"]

# Split difficulty in [0, 1] into four equal-width buckets. The clamp maps
# difficulty == 1.0 (which would give index 4) to "insane", and keeps any
# out-of-range value from producing an invalid or wrapped-around index.
difficulty_index = min(len(DIFFICULTY_NAMES) - 1,
                       max(0, int(difficulty * len(DIFFICULTY_NAMES))))

osu = os.path.join(dirpath, 'tmp.osz')
converter = OsuBeatmapConverter()
converter.generate(result, osu,
                   meta={
                       "lead_in": 0,
                       "title": title,
                       "artist": artist,
                       "source": source,
                       "difficulty_name": DIFFICULTY_NAMES[difficulty_index],
                       "hp_drain_rate": 5,
                       "overall_difficulty": int(7 * difficulty),
                       "approach_rate": 5,
                       "slider_multiplier": 1.8,
                       "bg": None,
                   })

In [None]:
# Download OSZ file
# Embed the generated .osz as a base64 data: URI behind a download button so
# it can be saved from the browser.

from ipywidgets import HTML
from IPython.display import display
import base64
import html

with open(osu, "rb") as f:
    payload = base64.b64encode(f.read()).decode("ascii")
download_name = "beatmap.osz"

# The archive is binary, not CSV: use a generic binary MIME type so browsers
# save it unchanged. A single `download` attribute carries the filename
# (escaped, since it is interpolated into HTML).
html_button = f'''<a download="{html.escape(download_name)}" href="data:application/octet-stream;base64,{payload}">
<button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button></a>
'''
display(HTML(html_button))

Done! To convert another MP3, upload it in the upload cell and re-run the cells that follow it.