In [None]:
# Allow importing the `beatlearning` package from the parent directory of the
# notebook's working directory (presumably the project root — confirm for your layout).
import sys
from pathlib import Path

_project_root = str(Path().resolve().parents[0])
if _project_root not in sys.path:  # keep re-runs of this cell from stacking duplicate entries
    sys.path.insert(0, _project_root)

# Automatically generate osu! beatmaps
1. Initialize a model from the tokenizer settings and a checkpoint
2. Upload an MP3
3. Generate a beatmap based on the audio and sampling settings
4. Convert the generated beatmap to an .osz file
5. Download the file

In [None]:
# download model checkpoint from the Hugging Face Hub;
# `checkpoint` is the local file path returned by hf_hub_download
from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download(
    repo_id="sedthh/BeatLearning",
    filename="quaver_beart_v1.pt",
)

In [None]:
# load the model: build it from the Quaver BEaRT config + tokenizer, then restore
# the weights from the checkpoint downloaded in the previous cell
import torch
from beatlearning.configs import QuaverBEaRT
from beatlearning.tokenizers import BEaRTTokenizer
from beatlearning.models import BEaRT

model_config = QuaverBEaRT()
tokenizer = BEaRTTokenizer(model_config)
model = BEaRT(tokenizer)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    # NOTE(review): map_location is presumably forwarded to torch.load so that
    # GPU-saved tensors can be restored on a CPU-only machine — confirm in BEaRT.load
    model.load(checkpoint, map_location=torch.device("cpu"))
else:
    model.load(checkpoint)
model.to(device)

print("device:", device)

# ONLY USE AUDIO CONTENT FOR WHICH YOU HOLD THE RIGHTS!

In [None]:
# Upload the MP3 the beatmap will be generated for
import os
import tempfile
import uuid

import librosa
from IPython.display import Audio
from ipywidgets import FileUpload

# unique id used later to name the generated .osu / .osz files
beatmap_uid = uuid.uuid4().hex
upload = FileUpload(multiple=False, accept=".mp3")
upload  # last expression: renders the upload widget

In [None]:
# SET the following:
title = "Untitled" 
artist = "Unknown Artist"
source = "MP3"

# note: you can get slightly better results if you get the start of the very first beat right
audio_start = 0.0  # beatmap will start from this point, but all audio will be used
audio_end = None  # if None, the entire audio will be used
difficulty = 0.5  # between 0. - 1. where 1. is harder difficulty
random_seed = 69420
####################

# sanity-check the settings above before doing any work
if not 0.0 <= difficulty <= 1.0:
    raise ValueError(f"difficulty must be between 0. and 1., got {difficulty}")
if audio_end is not None and audio_end <= audio_start:
    raise ValueError(f"audio_end ({audio_end}) must be greater than audio_start ({audio_start})")

# ipywidgets >= 8 stores uploads as a tuple of dicts, ipywidgets 7 as a dict keyed by file name
uploaded_files = list(upload.value.values()) if isinstance(upload.value, dict) else list(upload.value)
if not uploaded_files:
    # check before creating the output file so no empty audio.mp3 is left behind
    raise IndexError("You forgot to upload a file in the previous cell!")
content = uploaded_files[-1]['content']

# write the upload to a fresh temp dir; the path is reused by the generation cell
dirpath = tempfile.mkdtemp()
mp3 = os.path.join(dirpath, 'audio.mp3')
with open(mp3, 'wb') as output_file:
    output_file.write(content)

# preview only the selected [audio_start, audio_end) window
y, sr = librosa.load(mp3, offset=audio_start, duration=None if audio_end is None else audio_end - audio_start)
Audio(data=y, rate=sr)

In [None]:
# Generate beatmap

# The model currently has the tendency to predict a lot of holds. As a quick
# workaround, every hold token is disabled by giving it a logit bias of -inf.
# NOTE(review): token ids 9..59 are treated as holds here — confirm against the tokenizer vocabulary.
logit_bias = {token_id: -float('Inf') for token_id in range(9, 60)}

ibf = model.generate(
    audio_file=mp3,
    audio_start=audio_start,
    audio_end=audio_end,
    use_tracks=["LEFT"],  # only ["LEFT"] for OSU is supported at the moment
    difficulty=difficulty,
    logit_bias=logit_bias,  # comment this line to enable holds
    beams=[2] * (4 if difficulty > 0.5 else 8),  # harder settings use a shorter beam schedule
    max_beam_width=256,  # lower values are faster but less accurate
    temperature=0.1,  # you usually want a low temperature for better accuracy
    random_seed=random_seed,
)

# persist the generated beatmap so the converter cell can read it back
result = os.path.join(dirpath, 'tmp.ibf')
ibf.save(result)

In [None]:
# Upload background image (OPTIONAL)

# only offer image files, since the upload is used as the beatmap background
image = FileUpload(accept="image/*", multiple=False)
image  # last expression: renders the upload widget

In [None]:
# Convert generated beatmap to OSZ file format

from beatlearning.converters import OsuBeatmapConverter

# Try to add the (optional) background image. BG.png is only written when an image
# was actually uploaded, so stale data (e.g. the MP3 bytes left in `content` by the
# upload cell) never ends up on disk.
try:
    # ipywidgets >= 8: tuple of dicts; ipywidgets 7: dict keyed by file name
    uploaded_images = list(image.value.values()) if isinstance(image.value, dict) else list(image.value)
except NameError:  # the optional image-upload cell was not run
    uploaded_images = []

if uploaded_images:
    bg = os.path.join(dirpath, 'BG.png')
    with open(bg, 'wb') as output_file:
        output_file.write(uploaded_images[-1]['content'])
else:
    bg = None

# create osz file
osu = os.path.join(dirpath, 'tmp.osz')
converter = OsuBeatmapConverter()
converter.generate(result, osu,
                   meta={
                    "title": title,
                    "artist": artist,
                    "source": source,
                    # split difficulty in [0, 1] into three equal tiers; the clamp keeps
                    # 1.0 (and any out-of-range value) on a valid name
                    "difficulty_name": ["easy", "normal", "hard"][max(0, min(2, int(difficulty * 3)))],
                    "overall_difficulty": int(7 * difficulty),
                    "hp_drain_rate": 1 + int(6 * difficulty),
                    "approach_rate": 1 + int(6 * difficulty),
                    "bg": bg,  # None when no image was uploaded
                    "osu_file": f"beatmap_{beatmap_uid}.osu"
                   })

In [None]:
# Download OSZ file

from ipywidgets import HTML
from IPython.display import display
import base64

with open(osu, "rb") as f:
    osu_file = f.read()

# embed the archive as a data URI so the browser can save it without a file server
b64 = base64.b64encode(osu_file)
payload = b64.decode()
output_file = f"beatmap_{beatmap_uid}.osz"

# .osz is a binary (zip) archive, so declare a generic binary MIME type instead of
# text/csv; the single `download` attribute supplies the saved file name
html_button = f'''<a download="{output_file}" href="data:application/octet-stream;base64,{payload}">
<button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button></a>
'''
display(HTML(html_button))

Hey hey! Let me know if you have another MP3 to convert! 