# In [None]:
import torch
import numpy as np
from BEATs import BEATs, BEATsConfig

def load_model(model_path, map_location="cpu"):
    """Load a BEATs checkpoint and return the model in evaluation mode.

    Args:
        model_path: Path to a checkpoint file holding at least the ``'cfg'``
            (model configuration) and ``'model'`` (state dict) entries.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a GPU can still be opened on a CPU-only
            machine. ``load_state_dict`` copies the values into the freshly
            built model's own parameters, so this only affects where the
            checkpoint tensors live temporarily while loading.

    Returns:
        Tuple ``(model, label_dict)``. ``label_dict`` is ``None`` when the
        checkpoint has no ``'label_dict'`` entry (presumably checkpoints
        without a classification head -- confirm against your files).
    """
    checkpoint = torch.load(model_path, map_location=map_location)
    cfg = BEATsConfig(checkpoint['cfg'])
    model = BEATs(cfg)
    model.load_state_dict(checkpoint['model'])
    # Disable dropout etc. for deterministic inference.
    model.eval()
    # .get: not every checkpoint carries a label dictionary, and callers in
    # this file do not need one, so its absence must not be fatal.
    return model, checkpoint.get('label_dict')

def extract_features(model, audio_input):
    """Run ``model.extract_features`` on an audio tensor without tracking gradients.

    Args:
        model: Object exposing ``extract_features(source, padding_mask=...)``
            that returns a tuple whose first element is the representation
            (a loaded BEATs model in this script).
        audio_input: Tensor of raw audio; presumably ``(batch, samples)`` --
            TODO confirm against ``load_audio``.

    Returns:
        The first element of the tuple returned by ``model.extract_features``.
    """
    # All-False mask: no sample is flagged as padding. zeros_like keeps the
    # mask on the same device (and with the same shape) as the input, so a
    # GPU input is never paired with a CPU mask.
    padding_mask = torch.zeros_like(audio_input, dtype=torch.bool)
    with torch.no_grad():
        representation = model.extract_features(audio_input, padding_mask=padding_mask)[0]
    return representation

def cosine_similarity(a, b):
    """Return the cosine similarity between two embeddings.

    Both inputs are flattened to 1-D first. Model outputs keep a leading
    batch axis, and ``np.dot`` on 2-D arrays is a matrix product rather than
    a vector dot product.

    Args:
        a: Array-like embedding.
        b: Array-like embedding with the same number of elements as ``a``.

    Returns:
        A float in ``[-1, 1]``. Returns ``0.0`` when either input has zero
        norm; without this guard the result would be NaN.

    Raises:
        ValueError: if ``a`` and ``b`` have different numbers of elements.
    """
    a = np.ravel(np.asarray(a, dtype=float))
    b = np.ravel(np.asarray(b, dtype=float))
    if a.size != b.size:
        raise ValueError(
            f"cosine_similarity needs inputs with the same number of elements, "
            f"got {a.size} and {b.size}"
        )
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)

def find_closest_song(model, song_embeddings, new_song_embedding):
    """Return the index of the stored embedding most similar to a new one.

    Args:
        model: Unused; kept so existing positional call sites still line up.
        song_embeddings: Iterable of reference embeddings (array-likes).
        new_song_embedding: Embedding of the query song.

    Returns:
        ``int`` index into ``song_embeddings`` with the highest cosine
        similarity. Ties resolve to the earliest index (``np.argmax`` rule).

    Raises:
        ValueError: if ``song_embeddings`` is empty.
    """
    similarities = np.array(
        [cosine_similarity(embedding, new_song_embedding) for embedding in song_embeddings],
        dtype=float,
    )
    if similarities.size == 0:
        raise ValueError("song_embeddings must contain at least one embedding")
    # np.argmax returns the position of the first NaN, so a degenerate
    # (e.g. all-zero) embedding would otherwise be reported as the best match.
    similarities[np.isnan(similarities)] = -np.inf
    return int(np.argmax(similarities))

def _embedding_vector(representation):
    """Mean-pool a representation tensor over every axis but the last into a 1-D numpy vector."""
    # Model outputs presumably carry a batch axis and possibly a time axis
    # whose length depends on clip duration. Pooling gives every song a
    # vector of the same length, which cosine_similarity requires.
    array = representation.detach().cpu().numpy()
    return array.reshape(-1, array.shape[-1]).mean(axis=0)


def _embed_song(model, song_path):
    """Load one audio file and return its pooled embedding vector."""
    # NOTE: load_audio is not defined in this code; it must be provided
    # (e.g. in another notebook cell) and return the model's input tensor.
    audio_input = load_audio(song_path)
    return _embedding_vector(extract_features(model, audio_input))


def main(song_paths, new_song_path, model_path):
    """Print which of ``song_paths`` sounds most like ``new_song_path``.

    Args:
        song_paths: Sequence of reference audio file paths.
        new_song_path: Path of the query audio file.
        model_path: Path of the BEATs checkpoint to load.
    """
    model, _label_dict = load_model(model_path)

    # Embed every reference song, then the query, with the same pooling.
    song_embeddings = [_embed_song(model, song_path) for song_path in song_paths]
    new_song_embedding = _embed_song(model, new_song_path)

    closest_index = find_closest_song(model, song_embeddings, new_song_embedding)

    print(f'The new song most closely resembles: {song_paths[closest_index]}')

if __name__ == "__main__":
    # Demo invocation. Every path below is a placeholder: point these at real
    # audio files and a real checkpoint before running. The names stay
    # module-level so later notebook cells can reuse them.
    song_paths = ['path/to/song1.wav', 'path/to/song2.wav',
                  'path/to/song3.wav', 'path/to/song4.wav']
    new_song_path = 'path/to/new_song.wav'
    model_path = '/path/to/model.pt'

    main(song_paths=song_paths, new_song_path=new_song_path, model_path=model_path)
