In [5]:
import os 
import matplotlib.pyplot as plt
import re
import shutil
import numpy as np
import tqdm as tqdm
import json
import soundfile as sf
from pprint import pprint

gen_file = "generated/mel_no_pretrain_epoch10.log"
test_json_path = "OSUFOLDER/test.json"
# The generation log was produced without shuffling, so its charts appear in
# the same order as the entries (keys) of test.json.

# Keep the path and the parsed data under separate names; later cells use
# `test_json` as the loaded mapping. JSON is read as UTF-8 explicitly because
# the keys are chart paths that can contain non-ASCII song names.
with open(test_json_path, encoding="utf-8") as file:
    test_json = json.load(file)

In [None]:
def group_lines(lines):
    """Merge wrapped log lines into one string per `[INFO]` record.

    A line containing ``[INFO]`` starts a new record (its newline removed);
    any other line is a continuation and is appended, stripped and preceded
    by a single space, to the current record.

    Args:
        lines: raw log lines (as returned by ``readlines``).

    Returns:
        List of merged record strings, in input order.
    """
    ret = []
    for line in lines:
        if '[INFO]' in line:
            ret.append(line.replace("\n", ''))
        elif ret:
            ret[-1] += f" {line.strip()}"
        else:
            # A continuation with no preceding record (chunk starts mid-record):
            # keep it as its own entry instead of raising IndexError on ret[-1].
            ret.append(line.strip())
    return ret

def group_charts(lines, test_json):
    """Split a generation log into per-chart chunks.

    A chunk ends at (and includes) each line mentioning "precision", "recall"
    and "f1". Each chunk is merged with ``group_lines`` and paired with the
    next entry of ``test_json`` (a list of ``*.osu.json`` paths), converted to
    its ``.osu`` path. Pairing is purely positional.

    Returns:
        (charts, chart_paths): list of grouped-line lists, and the matching
        list of .osu paths.
    """
    charts, chart_paths = [], []
    chunk_start = 0
    for end, line in enumerate(lines):
        if not all(key in line for key in ("precision", "recall", "f1")):
            continue
        grouped = group_lines(lines[chunk_start:end + 1])
        if not grouped:
            continue
        charts.append(grouped)
        # len(chart_paths) is the running count of charts paired so far.
        chart_paths.append(test_json[len(chart_paths)].replace(".osu.json", ".osu"))
        chunk_start = end + 1
    return charts, chart_paths

def get_lines_from_file(fn):
    """Return every line of the text file ``fn``, trailing newlines included."""
    with open(fn) as handle:
        return list(handle)

lines = get_lines_from_file(gen_file)
# lines[2:-1]: skip the log's first two lines and its last line, which are
# presumably run preamble / trailer rather than per-chart output
# (TODO: confirm against the generator's log format).
charts, osujson_fn = group_charts(lines[2:-1], list(test_json.keys()))
# Show both values (a bare expression on its own line would be discarded):
# the last parsed chart and the .osu path it was paired with.
charts[-1], osujson_fn[-1]

In [13]:
# Copy every song folder referenced by the test split (the directories of the
# paths in osujson_fn) into `dest`, so the following cells work on copies.
# NOTE: folders are keyed by basename only; two source folders with the same
# name would be merged into one demo folder.
dest = "demo"
os.makedirs(dest, exist_ok=True)
dirs = sorted({os.path.dirname(fn) for fn in osujson_fn})  # sorted -> deterministic order
dests = [os.path.join(dest, os.path.basename(d)) for d in dirs]
for copyfrom, copyto in zip(dirs, dests):
    shutil.copytree(copyfrom, copyto, dirs_exist_ok=True)

In [None]:
# Point each chart path at its copy under `dest` instead of OSUFOLDER.
osujson_fn = list(map(lambda path: path.replace("OSUFOLDER", dest), osujson_fn))
# These are the .osu files we will generate demo charts for.
osujson_fn

Now, for each .osu file, we will try to output a demo chart file.

In [15]:
def action_token_to_actions(num):
    """Decode an action token (>= 96) into a 4-character per-lane flag string.

    ``num - 96`` is read as a base-3 number. Character ``i`` of the result is
    its i-th ternary digit, least-significant first, so lane 0 is the lowest
    digit. Each character is '0', '1' or '2'. Digits above the fourth are
    ignored.
    """
    offset = num - 96
    return "".join(str(offset // 3 ** lane % 3) for lane in range(4))

# Token 99 -> offset 3 -> ternary digits (LSB first) 0,1,0,0 -> only lane 1 fires.
assert action_token_to_actions(99) == '0100'
action_token_to_actions(99)


'0100'

In [16]:
def refine_chart_lines(chart):
    """Parse the token list logged on each line of one grouped chart.

    The first and last lines of ``chart`` are dropped. The last is the
    precision/recall/f1 line that ended the chunk; the first is presumably a
    header (TODO confirm). Every remaining line is expected to end with a
    logged tensor after a double tab. Any device (``cuda:N`` or none, as
    printed for CPU tensors) and any trailing ``size=``/``dtype=`` fields are
    accepted, e.g.::

        ...\\t\\ttensor([[48, 98]], device='cuda:0')
        ...\\t\\ttensor([[48, 98]])
        ...\\t\\ttensor([], device='cuda:0', size=(1, 0), dtype=torch.int64)

    Returns:
        One list of ints per remaining line (empty for an empty tensor).

    Raises:
        ValueError: if a line has no bracketed token list.
    """
    segments = []
    for line in chart[1:-1]:
        payload = line.split('\t\ttensor(')[-1].strip()
        # Take the innermost bracket contents; everything after the first
        # closing bracket (device, size, dtype, ')') is ignored.
        match = re.match(r"\[+([^\[\]]*)\]", payload)
        if match is None:
            raise ValueError(f"cannot parse token list from log line: {line!r}")
        body = match.group(1).strip()
        segments.append([int(token) for token in body.split(',')] if body else [])
    return segments
    
def chart_to_notes(chart):
    """Decode one grouped chart log into a list of note dicts.

    Each parsed segment spans two beats; the first segment starts at beat 2
    and each following one 2 beats later. Within a segment:

    - A token < 96 sets the current phase, in 1/48-beat ticks.
    - A token >= 96 is an action, decoded per lane by
      ``action_token_to_actions``.

    For each lane with a non-'0' flag:

    - If that lane has an open hold, the hold is closed at the current time
      and no new note is emitted.
    - Otherwise a note is emitted at the current time. Flag '2' also opens a
      hold on that lane, to be closed by the lane's next non-'0' flag.

    Returns:
        List of dicts ``{'time', 'line', 'isHold', 'end'}``, with times in
        beats. Unclosed holds stay ``isHold=False`` with ``end == time``.
    """
    segments = refine_chart_lines(chart)
    open_hold = [-1, -1, -1, -1]  # per lane: index into `notes` of an open hold, -1 if none
    notes = []
    phase = 0
    for seg_no, tokens in enumerate(segments):
        segment_start = 2 + 2 * seg_no
        for token in tokens:
            if token < 96:
                phase = token
                continue
            when = segment_start + phase / 48
            for lane, flag in enumerate(action_token_to_actions(token)):
                if flag == '0':
                    continue
                held = open_hold[lane]
                if held >= 0:
                    # Any action on a lane with an open hold ends that hold.
                    notes[held]['end'] = when
                    notes[held]['isHold'] = True
                    open_hold[lane] = -1
                    continue
                if flag == '2':
                    open_hold[lane] = len(notes)
                notes.append({
                    'time': when,
                    'line': lane,
                    'isHold': False,  # set later if the hold is closed
                    'end': when,      # set later if the hold is closed
                })
    return notes

In [17]:
def osu_file_by_metadata(filename):
    """Split an .osu file into its first line and its ``[Section]`` chunks.

    Returns a list whose first element is the file's first line (normally the
    ``osu file format vNN`` line, newline included), followed by one string
    per section. Each section string starts at its ``[Name]`` header line and
    runs up to the next header line, so it keeps its own trailing newlines.
    Text between the first line and the first header (normally blank lines)
    is dropped.

    Only a line consisting solely of ``[Name]`` counts as a header. A ``[``
    inside a value (e.g. ``Title:Foo [Bar]`` or a bracketed background file
    name in [Events]) does not split a section.

    Raises:
        ValueError: if ``filename`` does not end with ``.osu``.
    """
    if not filename.endswith(".osu"):
        raise ValueError("file provided is not .osu chart file")
    with open(filename, 'r') as file:
        lines = file.readlines()
    if not lines:
        return []
    first_line, body = lines[0], "".join(lines[1:])
    header_starts = [
        m.start()
        for m in re.finditer(r"^\[[^\[\]\r\n]+\][ \t]*$", body, flags=re.MULTILINE)
    ]
    header_ends = header_starts[1:] + [len(body)]
    return [first_line] + [body[a:b] for a, b in zip(header_starts, header_ends)]

# How generate_osu treats each part of the source .osu file: "copy" keeps the
# original text verbatim; any other value ("make") regenerates that section
# from the model's notes.
policy_for_metadata_fields = {
    **dict.fromkeys(
        (
            "osu file format v14",
            "[General]",
            "[Editor]",
            "[Metadata]",
            "[Difficulty]",
            "[Events]",
            "[TimingPoints]",
        ),
        "copy",
    ),
    "[HitObjects]": "make",
}

def notes_to_hitobjects(notes, osu):
    """Convert decoded notes into [HitObjects] lines for the chart at ``osu``.

    Reads two sidecar files next to the chart:

    - ``<osu>.json``: provides ``offset``, added in milliseconds.
    - ``<osu>.json.beat.json``: provides ``beat_to_sample_map``. It is indexed
      by tick (beat * 48), and element ``[1]`` of an entry is used as a sample
      position. Schema assumed from usage; verify against the producer.

    The sample rate of the chart's AudioFilename converts samples to ms.

    Lanes 0-3 map to x = 64/192/320/448 with y = 192.

    - Taps are written as ``x,y,time,1,0,0:0:0:0:``.
    - Holds are written as ``x,y,time,128,0,end:0:0:0:0:``.

    Raises:
        IndexError: if a note's tick lies beyond ``beat_to_sample_map``.
    """
    with open(f"{osu}.json") as file:
        offset = json.load(file)['offset']
    with open(f"{osu}.json.beat.json") as file:
        beat_to_sample = json.load(file)["beat_to_sample_map"]
    audiofilename = getAttribute(osu, "AudioFilename")
    # Only the sample rate is needed: sf.info reads the header instead of
    # decoding the whole track as sf.read did.
    sr = sf.info(os.path.join(os.path.dirname(osu), audiofilename)).samplerate

    linemap = {0: 64, 1: 192, 2: 320, 3: 448}
    y = 192

    def to_ms(beat):
        # beat -> 1/48-beat tick -> sample position -> milliseconds (+ offset)
        return round((beat_to_sample[round(beat * 48)][1] * 1000 / sr) + offset)

    ret = []
    for note in notes:
        x = linemap[note['line']]
        time = to_ms(note['time'])
        end = to_ms(note['end'])
        if note['isHold']:
            ret.append(f"{x},{y},{time},128,0,{end}:0:0:0:0:")
        else:
            ret.append(f"{x},{y},{time},1,0,0:0:0:0:")
    return ret


In [18]:
def generate_osu(notes, osu, output_file):
    """Write ``output_file``: a copy of chart ``osu`` with regenerated sections.

    The source is split with ``osu_file_by_metadata``. A chunk whose text
    contains a key of ``policy_for_metadata_fields`` with a non-"copy" policy
    is replaced by that key followed by the lines from
    ``notes_to_hitobjects(notes, osu)``. Every other chunk is kept as is.
    Each chunk is written followed by a newline.
    """
    sections = osu_file_by_metadata(osu)
    for position, section in enumerate(sections):
        for header, policy in policy_for_metadata_fields.items():
            if header not in section or policy == 'copy':
                continue
            generated = '\n'.join(notes_to_hitobjects(notes, osu))
            sections[position] = f"{header}\n{generated}"

    with open(output_file, 'w') as out:
        out.writelines(f"{section}\n" for section in sections)

In [19]:
def writeAttribute(filename, attribute, value):
    """Rewrite every ``attribute:`` line of an .osu file as ``attribute: value``.

    Lines are matched after ignoring leading whitespace; all other lines are
    written back unchanged (only their line terminator is normalised to
    ``\\n``). Leading spaces and underscores matter for nested storyboard
    commands in [Events], so they are preserved.

    Raises:
        ValueError: if ``filename`` does not end with ``.osu``.
    """
    if not filename.endswith(".osu"):
        raise ValueError("file provided is not .osu chart file")
    with open(filename, 'r') as file:
        lines = [line.rstrip('\r\n') for line in file]
    prefix = f"{attribute}:"
    lines = [
        f"{attribute}: {value}" if line.lstrip().startswith(prefix) else line
        for line in lines
    ]
    with open(filename, 'w') as file:
        file.writelines(f"{line}\n" for line in lines)

def getAttribute(filename, attribute):
    """Return the value of the last ``attribute:`` line in an .osu file.

    The value is the line with the ``attribute:`` text removed and surrounding
    whitespace stripped. Returns "" when no line starts with ``attribute:``.

    Raises:
        ValueError: if ``filename`` does not end with ``.osu``.
    """
    if not filename.endswith(".osu"):
        raise ValueError("file provided is not .osu chart file")
    prefix = f"{attribute}:"
    with open(filename, 'r') as handle:
        values = [row.replace(prefix, '').strip() for row in handle if row.startswith(prefix)]
    return values[-1] if values else ""

In [None]:
# Generate one AI chart per test entry, written next to its source chart as
# <idx>.osu, and give each a fixed Creator and a distinct Version.
# `total` lets tqdm show a real progress bar (a zip object has no length).
for idx, (chart, osu) in enumerate(tqdm.tqdm(zip(charts, osujson_fn), total=len(charts))):
    notes = chart_to_notes(chart)
    output_file = os.path.join(os.path.dirname(osu), f"{idx}.osu")
    generate_osu(notes, osu, output_file)
    writeAttribute(output_file, "Creator", "goctai")
    writeAttribute(output_file, "Version", f"AI-GEN-{idx}")


In [None]:
# Sanity check: how many charts were generated, and the last source path.
# (Two bare expressions on separate lines would display only the second.)
len(osujson_fn), osujson_fn[-1]

In [33]:
# make a credits text file
# `content` collects the credits text: this header, then one line per chart.

credits_path = "demo/__credits.txt"

content = ["""
Each .osz file only carries the bare minimum of files (not including artwork or any non-default hitsounds).

The non-note portions of this beatmap were taken from existing beatmaps.
These include all metadata portions of the original beatmap except for the [HitObjects] section.

This file lists the source for each, along with the intended target difficulties fed to the AI.

"""]

def get_targetdiff(osujson):
    """Return ``charts[0].difficulty_fine`` from an .osu.json file, as a float.

    The file is read as UTF-8 (the JSON default) rather than the platform
    encoding, so non-ASCII metadata does not break decoding on e.g. Windows.
    """
    with open(osujson, encoding="utf-8") as file:
        d = json.load(file)
    return float(d['charts'][0]['difficulty_fine'])

for i, source in enumerate(osujson_fn):
    # Path relative to the demo folder (str.replace would also strip a
    # "demo/" substring appearing deeper in the path).
    relative_source = os.path.relpath(source, "demo")
    content.append(f"{i}.osz \t <- {relative_source}; target diff was {get_targetdiff(source+'.json')}")

# Keep `content` as the list of lines; write the joined text as UTF-8 so
# non-ASCII song names cannot fail on platforms with a non-UTF-8 default.
credits_text = "\n".join(content)
with open(credits_path, 'w', encoding='utf-8') as file:
    file.write(credits_text)


In [None]:
# One path per line (pprint wraps long lists) instead of a single long print.
pprint(osujson_fn)

In [27]:
# Wipe the copied original beatmaps so each demo folder keeps only the
# generated <idx>.osu charts.
# Safety: refuse to delete anything outside ./demo, e.g. if osujson_fn still
# points at OSUFOLDER because the remapping cell was skipped. Files that are
# already gone are skipped, so the cell can be re-run.
demo_root = os.path.abspath("demo")
outside_demo = [
    source for source in osujson_fn
    if os.path.commonpath([os.path.abspath(source), demo_root]) != demo_root
]
if outside_demo:
    raise ValueError(f"refusing to delete files outside {demo_root}: {outside_demo[:3]}")
for source in osujson_fn:
    if os.path.exists(source):
        os.remove(source)

In [34]:
# Delete every .json file under demo/. These are presumably the per-chart
# .osu.json and .osu.json.beat.json sidecars; note that any other .json file
# in the tree is removed as well.
stale_json_files = [
    os.path.join(dirpath, name)
    for dirpath, _subdirs, names in os.walk('demo')
    for name in names
    if name.endswith(".json")
]
for stale in stale_json_files:
    os.remove(stale)

In [35]:
# Rename every .zip under demo/ to .osz (the file is renamed only, not
# repacked). The archives are assumed to have been created outside this
# notebook (TODO: document how).
# Uses `osz_path` rather than `dest`, and `_` for the subdirectory list: the
# old code rebound the global `dest` and `dirs` used by earlier cells, which
# broke re-running them.
for root, _, files in os.walk('demo'):
    for file in files:
        if file.endswith(".zip"):
            zip_path = os.path.join(root, file)
            osz_path = os.path.join(root, file[:-len(".zip")] + ".osz")
            os.rename(zip_path, osz_path)

In [45]:
# Smoke test on a single copied chart: read its Title attribute.
SAMPLE_OSU = "demo/984247/AAAA + Silentroom - UNDERVEIL IS REAL!!! (angki) [ANGKI IS REAL!!!].osu"
getAttribute(SAMPLE_OSU, "Title")

'UNDERVEIL IS REAL!!!'

In [None]:






# Inspect how the sample chart is split into its first line plus [Section] chunks.
sample_sections = osu_file_by_metadata("demo/984247/AAAA + Silentroom - UNDERVEIL IS REAL!!! (angki) [ANGKI IS REAL!!!].osu")
sample_sections

In [47]:
# Design note: notes should be organized to carry hold-times for each note.
# Peek at the first ten parsed two-beat segments of the last chart.
first_segments = refine_chart_lines(charts[-1])[:10]
first_segments

[[], [], [], [], [], [], [], [], [48, 98], [0, 133, 48, 99, 72, 105]]

In [48]:
# taps:  x,y,time,type,hitSound,hitSample          (no objectParams for taps)
# holds: x,y,time,type,hitSound,endTime:hitSample
# type is 1 for tap (hit circle), 128 for hold (mania hold note)

# example hold  448,192,19047,128,2,19214:0:0:0:0:
# example tap   64,192,17880,1,0,0:0:0:0:
notes = [] # sorted by time
# each note : { time in beats, line, isHold, hold end (to be filled in later) }


In [49]:
# translates notes to osu hitobjects



In [50]:
# Scratch: single-file experiment with generate_osu's inputs.
output_file = "demo/984247/test.osu"

# Variant table whose [HitObjects] entry holds the builder function itself.
# Kept under its own name: rebinding `policy_for_metadata_fields` here would
# silently replace the table generate_osu reads.
scratch_policy_for_metadata_fields = {
    "osu file format v14": "copy",
    "[General]": "copy",
    "[Editor]": "copy",
    "[Metadata]": "copy",
    "[Difficulty]": "copy",
    "[Events]": "copy",
    "[TimingPoints]": "copy",
    "[HitObjects]": notes_to_hitobjects,
}

osu = "demo/984247/AAAA + Silentroom - UNDERVEIL IS REAL!!! (angki) [ANGKI IS REAL!!!].osu"
osufile_contents = osu_file_by_metadata(osu)

