
## Setup

In [None]:
# Clone the Pop-K repository and make it the working directory, so the
# relative paths used by the later cells (data/..., train.py, src/model.py,
# generate.py) resolve. NOTE(review): assumes the clone lands in /content,
# which is the %cd target below.
!git clone https://github.com/patchbanks/Pop-K.git
%cd /content/Pop-K

In [None]:
# Quietly install the extra packages used by the repo's scripts
# (presumably imported by train.py / generate.py / midi_to_text.py — confirm).
!pip install -q lightning-fabric pretty_midi wandb ninja

## Process MIDI

In [None]:
# Convert the MIDI files in data/midi_data into a text dataset named
# 'pop_k_test'. training_options() below reads data/train_data/pop_k_test.txt,
# presumably the file written by this step — confirm in data/midi_to_text.py.
!python data/midi_to_text.py  \
    --midi_dir='data/midi_data' \
    --dataset_name='pop_k_test'

## Training Options

In [None]:
from shutil import copy
import os

base_model_name = "model" #@param {type: "string"}
# Pattern for the base checkpoint file(s). This must be an f-string so the
# model name is actually substituted; the trailing '*' suggests it is meant
# to be expanded with glob.
base_model_path = f"models/{base_model_name}*.pth"
# The fine-tuned model reuses the base model's name; together with
# output_path it forms epoch_save_path in training_options().
tuned_model_name = base_model_name
# Root directory that training checkpoints are saved under.
output_path = 'checkpoints'


def training_options():
    """Collect the train.py overrides exposed as Colab form fields.

    Every local variable defined below is returned through ``locals()``.
    The notebook merges the result with ``env_vars()`` and passes it to
    ``replace_lines('train.py', ...)``, which rewrites each assignment whose
    left-hand side equals one of these names. The local names must therefore
    match the variable names used in train.py; a name with no matching line
    there is silently ignored.

    Returns:
        dict mapping each local variable name to its value.
    """
    datafile = "data/train_data/pop_k_test.txt" #@param {type: "string"}
    lr_init = 0.00001 #@param {type: "number"}
    # Final LR is pinned to the initial LR; presumably this yields a constant
    # learning rate in train.py — confirm against its schedule.
    lr_final = lr_init
    n_epoch = 1 #@param {type: "number"}
    epoch_save_frequency = 20 #@param {type: "number"}
    batch_size = 4 #@param {type: "number"}
    n_layer = 12 #@param {type: "number"}
    n_embd = 768 #@param {type: "number"}
    ctx_len = 2048 #@param {type: "number"}
    # NOTE(review): LOAD_MODEL / EPOCH_BEGIN presumably control resuming from a
    # saved checkpoint in train.py — confirm there.
    LOAD_MODEL = False # @param {type:"boolean"}
    EPOCH_BEGIN = 0 #@param {type: "number"}
    # Built from the module-level output_path and tuned_model_name globals.
    epoch_save_path = f"{output_path}/{tuned_model_name}"

    return locals()

def model_options():
    """Collect the src/model.py overrides exposed as Colab form fields.

    Each local variable below (before the snapshot is taken) becomes one
    entry of the returned mapping, keyed by its name, so the name must match
    the assignment target in src/model.py that replace_lines() rewrites.

    Returns:
        dict mapping each option name (e.g. ``T_MAX``) to its value.
    """
    T_MAX = 2048 #@param {type: "number"}
    # Snapshot taken before any other local exists, so only the options above
    # are included.
    overrides = dict(locals())
    return overrides

def env_vars():
    """Return environment-variable overrides keyed by their train.py assignment target.

    Each local variable below is turned into a key of the form
    ``os.environ['NAME']``, which is the literal left-hand side that
    replace_lines() matches when it patches train.py.

    Returns:
        dict mapping ``"os.environ['NAME']"`` to the string value for NAME.
    """
    # NOTE(review): the trailing comments look like Colab form annotations that
    # are missing the '@param' marker, and 'bf32' may be intended as 'fp32' —
    # confirm against train.py before relying on these choices.
    RWKV_FLOAT_MODE = 'fp16' # ['fp16', 'bf16', 'bf32'] {type:"string"}
    RWKV_DEEPSPEED = '0' # ['0', '1'] {type:"string"}
    # Snapshot the settings before any helper locals are created.
    form_values = dict(locals())
    overrides = {}
    for name in form_values:
        overrides[f"os.environ['{name}']"] = form_values[name]
    return overrides

def replace_lines(file_name, to_replace):
    """Rewrite simple ``name = value`` assignment lines of *file_name* in place.

    Each line is split at its first ``" ="``. When the stripped left-hand side
    is a key of *to_replace*, the whole line becomes ``<lhs> = <value>``:
    the left-hand side keeps its original indentation, and everything after
    the ``=`` (including any trailing comment) is dropped. String values are
    written as properly escaped double-quoted literals; other values use
    their ``str()`` form (so ``False`` -> ``False``, ``1e-05`` -> ``1e-05``).
    All other lines are copied unchanged.

    Args:
        file_name: Path of the source file to patch.
        to_replace: Mapping of assignment targets, e.g. ``"n_epoch"`` or
            ``"os.environ['RWKV_FLOAT_MODE']"``, to their new values.
    """
    import json  # local: only used to escape string values

    with open(file_name, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    patched = []
    for line in lines:
        key = line.split(" =")[0]
        target = key.strip()
        if target not in to_replace:
            patched.append(line)
            continue
        value = to_replace[target]
        if isinstance(value, str):
            # json.dumps yields a double-quoted literal with quotes,
            # backslashes and newlines escaped, which is also valid Python;
            # plain values like 'fp16' come out exactly as "fp16".
            patched.append(f'{key} = {json.dumps(value)}\n')
        else:
            patched.append(f'{key} = {value}\n')

    tmp_name = f'{file_name}.tmp'
    with open(tmp_name, 'w', encoding='utf-8') as f:
        f.writelines(patched)
    # Atomic rename (same directory), so file_name is never left half-written.
    os.replace(tmp_name, file_name)

# Patch train.py with the training hyperparameters merged with the
# os.environ overrides, then patch src/model.py with the model options.
values = {**training_options(), **env_vars()}
replace_lines('train.py', values)
replace_lines('src/model.py', model_options())

## Train

In [None]:
# Run training using the values patched into train.py / src/model.py by the
# Training Options cell (run that cell first).
!python train.py

## Generate

In [None]:
# Generate 10 samples from the model at checkpoints/model. This path equals
# f"{output_path}/{tuned_model_name}" (epoch_save_path) only while
# base_model_name is "model" — update it if you rename the model.
# NOTE(review): presumably generate.py writes into midi_output/, which the
# Zip MIDI cell archives — confirm. There is no space before the trailing
# backslash on the --temperature line; consider adding one for consistency.
!python generate.py \
    --model_name checkpoints/model \
    --num_samples 10 \
    --temperature 1.0\
    --top_k 20

In [None]:
#@title Zip MIDI
import zipfile
import os

# Archive everything under midi_output_dir into /content/<name>.zip, with all
# entries placed under a top-level folder named after the archive.
midi_output_dir = '/content/Pop-K/midi_output'

# Fail loudly instead of silently writing an empty archive.
if not os.path.isdir(midi_output_dir):
    raise FileNotFoundError(f'MIDI output directory not found: {midi_output_dir}')

# An empty answer would otherwise produce "/content/.zip".
base_name = input("Enter zip file name: ").strip() or 'midi_output'
zip_file_path = f'/content/{base_name}.zip'

# Never overwrite an existing archive: try <name>-1.zip, <name>-2.zip, ...
count = 1
while os.path.exists(zip_file_path):
    zip_file_path = f'/content/{base_name}-{count}.zip'
    count += 1

# ZipFile defaults to ZIP_STORED (no compression); deflate the entries instead.
with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
    for root, _, files in os.walk(midi_output_dir):
        for file in files:
            file_path = os.path.join(root, file)
            # Keep the path relative to midi_output_dir so same-named files in
            # different subfolders do not become duplicate archive entries.
            rel_path = os.path.relpath(file_path, midi_output_dir)
            arcname = os.path.join(base_name, rel_path)
            zipf.write(file_path, arcname)

print('Saved to zip:', zip_file_path)


## Import MIDI

In [None]:
import zipfile
import os
import shutil
import tempfile

import_zip_path = '/content/Archive.zip' #@param {type: "string"}
temp_folder = '/content/temp' #@param {type: "string"}
dst_folder = 'data/midi_data' #@param {type: "string"}

# Unpack the archive, then move every MIDI file it contains (from any
# subfolder) into dst_folder, never overwriting a file that already exists.
os.makedirs(temp_folder, exist_ok=True)
os.makedirs(dst_folder, exist_ok=True)

# Extract into a fresh, uniquely named subfolder of temp_folder. Files left
# behind by an earlier run are then not imported again, and the cleanup in
# `finally` deletes only what this run created (never temp_folder itself).
extract_dir = tempfile.mkdtemp(dir=temp_folder)
try:
    with zipfile.ZipFile(import_zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

    # Archives created on macOS carry a __MACOSX metadata folder whose entries
    # can also end in .mid; drop it before scanning.
    shutil.rmtree(os.path.join(extract_dir, '__MACOSX'), ignore_errors=True)

    for root, _, files in os.walk(extract_dir):
        for file in files:
            # Case-insensitive extension match (.mid, .MID, .Mid, ...); skip
            # "._*" metadata files that some tools store next to real files.
            if not file.lower().endswith('.mid') or file.startswith('._'):
                continue
            src_file = os.path.join(root, file)
            dst_file = os.path.join(dst_folder, file)

            # Avoid overwriting files
            if not os.path.exists(dst_file):
                shutil.move(src_file, dst_file)
                print(f"Moved {src_file} to {dst_file}")
            else:
                print(f"File {dst_file} already exists. Skipping move.")
finally:
    shutil.rmtree(extract_dir, ignore_errors=True)
