In [None]:
import pickle
import os
import csv
import string
import librosa
import matplotlib.pyplot as plt
import numpy as np
import shutil
import soundfile as sf
import torch
import h5py
from PIL import Image
from scipy.ndimage import zoom
from librosa.util import normalize
from librosa.util import fix_length
from sklearn.model_selection import train_test_split
from google.colab import drive
from google.colab import runtime
from collections import defaultdict

# Mount Google Drive at /content/drive; the data paths used by this notebook
# all live under '/content/drive/My Drive/'.
drive.mount('/content/drive')

Mounted at /content/drive


Paths

In [None]:
# Locations on the mounted Google Drive.
# csv_path      -> track metadata table read with the csv module.
# mp3_data_path -> directory of .mp3 files; the trailing slash matters because
#                  some cells build file names by plain string concatenation.
csv_path = "/content/drive/My Drive/Projects/NeuraBeat/tracks.csv"
mp3_data_path = "/content/drive/My Drive/Projects/NeuraBeat/Data/fma_small/"

Clean Filesystem

In [None]:
# One-time dataset setup, intentionally left disabled. Uncommented, these shell
# commands install ffmpeg, download the fma_small archive as fma.zip, extract
# it, and delete the archive afterwards.
# !apt install ffmpeg
# !wget -O fma.zip https://os.unil.cloud.switch.ch/fma/fma_small.zip
# !unzip fma.zip
# !rm fma.zip

In [None]:
# One-time filesystem cleanup, intentionally left disabled. Uncommented, it:
#   1. deletes README.txt and checksums from mp3_data_path,
#   2. moves every file found in a sub-directory up into mp3_data_path,
#   3. removes the sub-directories bottom-up (os.rmdir only removes empty ones).
# os.remove(mp3_data_path + "README.txt")
# os.remove(mp3_data_path + "checksums")
# for root, dirs, files in os.walk(mp3_data_path):
#     for file in files:
#         file_path = os.path.join(root, file)
#         if file_path != mp3_data_path + file:
#             shutil.move(file_path, mp3_data_path)

# for root, dirs, _ in os.walk(mp3_data_path, topdown=False):
#     for folder in dirs:
#         folder_path = os.path.join(root, folder)
#         if os.path.isdir(folder_path):
#             os.rmdir(folder_path)

Create File:Genre Map

In [None]:
# Build a lookup from track id to genre for every .mp3 present on disk.
file_genre_map = {}  # track id (string without leading zeros) -> genre string

# Track ids are the file stems with leading zeros stripped, matching the way
# ids appear in the first CSV column. A set is used because membership is
# tested once per CSV row; a list would be rescanned for every row.
track_ids = {
    file_name.split('.')[0].lstrip('0')
    for file_name in os.listdir(mp3_data_path)
    if file_name.endswith('.mp3')
}

n_header_rows = 3   # the CSV starts with three header rows
track_id_column = 0
genre_column = 40   # presumably FMA's "track / genre_top" column -- TODO confirm

# newline='' is what the csv module requires so that line breaks embedded in
# quoted fields are parsed correctly.
with open(csv_path, 'r', newline='') as csvfile:
    csvreader = csv.reader(csvfile)
    for _ in range(n_header_rows):
        next(csvreader)  # Skip headers
    for row in csvreader:
        if not row:
            # csv.reader yields [] for blank lines; row[0] would raise.
            continue
        if row[track_id_column] in track_ids:
            file_genre_map[row[track_id_column]] = row[genre_column]

In [None]:
# Initialize genre dist dictionary
genre_dist = {}
total_songs = 0

# Count the number of each genre
for genre in file_genre_map.values():
    if genre not in genre_dist:
        genre_dist[genre] = 0
    genre_dist[genre] += 1

# Calculate the total number of songs
total_songs = len(file_genre_map)

# Output the genre distribution and total number of songs
print("Genre distribution:")
for genre, count in genre_dist.items():
    print(f"{genre}: {count}")
print(f"Total number of songs: {total_songs}")

Genre distribution:
Hip-Hop: 1000
Pop: 1000
Folk: 1000
Experimental: 1000
Rock: 1000
International: 1000
Electronic: 1000
Instrumental: 1000
Total number of songs: 8000


Preprocess Training Data

In [None]:
# Build the dataset: for every usable track, cut the audio into fixed-length
# chunks and store one dB-scaled mel spectrogram (with a leading channel axis)
# plus one integer genre label per chunk.
melspec_data = []
labels = []

genre_counts = defaultdict(int)  # songs accepted so far, per genre
max_songs_per_genre = 990        # per-genre cap on accepted songs
target_sr = 22050                # sample rate (Hz) all audio is loaded at
chunk_duration = 3               # seconds per chunk
num_chunks = 10                  # upper bound on chunks taken from one song
full_song_length = 27            # seconds of audio kept per song

genre_to_number = {'Electronic': 0, 'Experimental': 1, 'Folk': 2, 'Hip-Hop': 3, 'Instrumental': 4, 'International': 5, 'Pop': 6, 'Rock': 7}

chunk_length = target_sr * chunk_duration
# Only whole chunks that fit into the kept audio are produced. With 27 s of
# audio and 3 s chunks this is 9, even though num_chunks allows up to 10.
chunks_per_song = min(num_chunks, full_song_length // chunk_duration)

# Sorted because os.listdir order is arbitrary; without it the set of songs
# that makes it under the per-genre cap could differ between runs.
for mp3_file in sorted(os.listdir(mp3_data_path)):
    if not mp3_file.endswith('.mp3'):
        continue

    track_id = mp3_file.split('.')[0].lstrip('0')
    genre = file_genre_map.get(track_id)
    # Resolve the label up front: a missing or unknown genre must never leave
    # melspec_data and labels with different lengths.
    if genre not in genre_to_number:
        print(f"Skipped file without a known genre: {mp3_file}")
        continue

    if genre_counts[genre] >= max_songs_per_genre:
        continue

    try:
        # librosa.load resamples to `sr` itself, so no separate resample step.
        audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file), sr=target_sr)
        if (len(audio) / sr) < full_song_length:
            print(f"Skipped short file: {mp3_file}")
            continue
        # The clip is at least full_song_length long here, so this only trims.
        trimmed_audio = fix_length(audio, size=target_sr * full_song_length)
        normalized_audio = normalize(trimmed_audio, norm=1)

        # Collect this song's chunks locally so that a failure part-way through
        # does not leave a partial song in the dataset.
        song_chunks = []
        for i in range(chunks_per_song):
            start_sample = i * chunk_length
            audio_chunk = normalized_audio[start_sample:start_sample + chunk_length]

            melspec = librosa.feature.melspectrogram(y=audio_chunk, sr=target_sr, n_mels=256)
            melspec = librosa.power_to_db(melspec, ref=np.max)
            # Add a leading channel axis: (n_mels, frames) -> (1, n_mels, frames).
            song_chunks.append(np.expand_dims(melspec, axis=0))
    except Exception as e:
        # Best-effort: skip unreadable files, but report why they failed.
        print(f"Skipped corrupt file: {mp3_file} ({type(e).__name__}: {e})")
        continue

    melspec_data.extend(song_chunks)
    labels.extend([genre_to_number[genre]] * len(song_chunks))
    genre_counts[genre] += 1

# Chunks are kept as numpy arrays throughout, so this is a plain stack.
melspec_data = np.array(melspec_data)
labels = np.array(labels)
print(genre_counts)

  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped short file: 098565.mp3


  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped corrupt file: 133297.mp3


  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped short file: 098569.mp3


  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped corrupt file: 108925.mp3


  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped corrupt file: 099134.mp3


  audio, sr = librosa.load(os.path.join(mp3_data_path, mp3_file))
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Skipped short file: 098567.mp3
defaultdict(<class 'int'>, {'Folk': 990, 'Instrumental': 990, 'Electronic': 990, 'Rock': 990, 'Hip-Hop': 990, 'Pop': 990, 'Experimental': 990, 'International': 990})


In [None]:
# Split by SONG rather than by chunk. A chunk-level split puts chunks of the
# same recording into both training and validation data, which leaks
# information and inflates validation scores.
#
# This relies on the preprocessing cell appending each song's chunks
# consecutively, with the same number of chunks for every song.
chunks_per_song = min(num_chunks, full_song_length // chunk_duration)
if chunks_per_song <= 0 or len(labels) % chunks_per_song != 0:
    raise ValueError(
        f"Expected a whole number of songs with {chunks_per_song} chunks each, "
        f"got {len(labels)} chunks"
    )

# One row per song; every chunk in a row must carry that song's label.
labels_by_song = labels.reshape(-1, chunks_per_song)
if not (labels_by_song == labels_by_song[:, :1]).all():
    raise ValueError("Chunks are not grouped song by song; cannot split at song level")

song_ids = np.arange(len(labels_by_song))
train_song_ids, val_song_ids = train_test_split(song_ids,
                                                test_size=0.2,
                                                stratify=labels_by_song[:, 0],
                                                random_state=42)

# Expand song ids to the indices of all their chunks, then shuffle the chunk
# order with a fixed seed so each split is still returned shuffled.
chunk_offsets = np.arange(chunks_per_song)
split_rng = np.random.default_rng(42)
train_chunk_idx = split_rng.permutation((train_song_ids[:, None] * chunks_per_song + chunk_offsets).ravel())
val_chunk_idx = split_rng.permutation((val_song_ids[:, None] * chunks_per_song + chunk_offsets).ravel())

melspec_training_data = melspec_data[train_chunk_idx]
melspec_val_data = melspec_data[val_chunk_idx]
melspec_training_labels = labels[train_chunk_idx]
melspec_val_labels = labels[val_chunk_idx]

In [None]:
# Write each split to its own HDF5 file on Drive, with the spectrograms under
# 'data' and the integer genre labels under 'labels'.
for h5_path, split_data, split_labels in (
    ('/content/drive/My Drive/Projects/NeuraBeat/Data/train_data_melspec_expanded.h5',
     melspec_training_data, melspec_training_labels),
    ('/content/drive/My Drive/Projects/NeuraBeat/Data/val_data_melspec_expanded.h5',
     melspec_val_data, melspec_val_labels),
):
    # Mode 'w' creates the file, replacing any existing file at that path.
    with h5py.File(h5_path, 'w') as f:
        f.create_dataset('data', data=np.array(split_data))
        f.create_dataset('labels', data=np.array(split_labels))

In [None]:
# Sanity check: show the shape of each split (data, then labels).
for split_array in (melspec_training_data,
                    melspec_training_labels,
                    melspec_val_data,
                    melspec_val_labels):
    print(split_array.shape)

(57024, 1, 256, 130)
(57024,)
(14256, 1, 256, 130)
(14256,)


In [None]:
# Release the Colab runtime now that the datasets have been written to Drive
# (presumably to stop consuming compute once preprocessing is done).
runtime.unassign()