In [None]:
import os
import json
from dotenv import load_dotenv
from openai import OpenAI

In [None]:
# Load environment configuration first (python-dotenv), so that the key lookup
# below can see values supplied through a .env file.
load_dotenv()

# `api_key` is None when OPENAI_API_KEY is not set in the environment.
api_key = os.environ.get("OPENAI_API_KEY")

# Shared module-level client; classify_song_with_gpt() reads this global.
client = OpenAI(api_key=api_key)

def classify_song_with_gpt(song_name):
    """Ask the chat model for three genre/style labels for a song.

    A single user message is sent through the module-level ``client`` to the
    ``gpt-3.5-turbo`` model, and the text of the first returned choice is
    handed back with surrounding whitespace stripped.

    Args:
        song_name: Song title; it is interpolated into the prompt as-is.

    Returns:
        The model's reply as a string. If anything raises while building the
        request, sending it, or reading the reply, the exception is not
        propagated; instead a string starting with
        ``"Error: Unable to classify song"`` is returned.
    """
    try:
        request_messages = [
            {
                "role": "user",
                "content": (
                    f"Classify the following song '{song_name}' into three music genres or styles. "
                    "Provide three classifications separated by commas, for example: ['disco', 'funky', '80s']"
                ),
            }
        ]
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=request_messages,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content.strip()
    except Exception as exc:
        # Failures are reported in-band: callers tell them apart from real
        # classifications by inspecting the returned text.
        return f"Error: Unable to classify song {song_name} due to {str(exc)}"

def process_lakh_midi_directory(clean_midi_dir=None, classifier=None):
    """Classify one song per artist folder of the clean MIDI dataset.

    Every sub-directory of ``clean_midi_dir`` is treated as an artist folder.
    For each folder that contains at least one ``.mid`` file, the
    alphabetically first such file is chosen, its name (minus the ``.mid``
    suffix) is passed to ``classifier``, and successful results are collected.
    Progress is reported with ``print``.

    Args:
        clean_midi_dir: Directory holding the artist folders. Defaults to
            ``<current working directory>/../data/clean_midi``.
        classifier: Callable taking a song name and returning a classification
            string, or a string starting with
            ``"Error: Unable to classify song"`` on failure. Defaults to
            ``classify_song_with_gpt``.

    Returns:
        A list of ``{"song": ..., "classification": ...}`` dicts, one per
        artist folder whose song was classified without an error, ordered by
        artist folder name.

    Raises:
        OSError: If ``clean_midi_dir`` (or an artist folder) cannot be listed.
    """
    # Failure sentinel produced by classify_song_with_gpt. Matching the full
    # prefix (rather than looking for "Error" anywhere) keeps valid replies
    # that merely mention the word, e.g. a song called "Trial and Error".
    error_prefix = "Error: Unable to classify song"
    midi_suffix = '.mid'

    if clean_midi_dir is None:
        clean_midi_dir = os.path.join(os.getcwd(), '..', 'data', 'clean_midi')
    if classifier is None:
        classifier = classify_song_with_gpt

    recognized_songs_classifications = []

    # os.listdir returns entries in arbitrary order; sorting makes both the
    # traversal and the "first song" choice reproducible across runs.
    for artist_folder in sorted(os.listdir(clean_midi_dir)):
        artist_folder_path = os.path.join(clean_midi_dir, artist_folder)
        if not os.path.isdir(artist_folder_path):
            continue

        song_files = sorted(
            f for f in os.listdir(artist_folder_path) if f.endswith(midi_suffix)
        )
        if not song_files:
            continue

        # Drop only the trailing extension; str.replace would also remove
        # ".mid" occurring inside the name (e.g. "b.middle.mid").
        song_name = song_files[0][:-len(midi_suffix)]
        print(f"Found song: {song_name} in {artist_folder_path}")

        classification = classifier(song_name)

        if classification.startswith(error_prefix):
            print(f"Skipping {song_name}: {classification}")
            continue

        print(f"Classifications for {song_name}: {classification}")
        recognized_songs_classifications.append({
            "song": song_name,
            "classification": classification
        })

    return recognized_songs_classifications

# Script driver: classify the dataset, persist the results as JSON in the
# current working directory, then report completion.
recognized_songs_classifications = process_lakh_midi_directory()

with open('recognized_songs_classifications.json', mode='w') as output_file:
    # 4-space indentation keeps the output file human-readable.
    output_file.write(json.dumps(recognized_songs_classifications, indent=4))

print("Finished processing songs.")