# Music Genre Classifier

In [None]:
import os
import re
import wget
import shutil
import tarfile
from pylab import imshow
from essentia import Pool 
import matplotlib.pyplot as plt
from essentia.standard import FrameGenerator, MonoLoader, \
    Windowing, Spectrum, MFCC, UnaryOperator

## Load and Split the GTZAN Dataset

In [None]:
# When True, (re)build ./dataset/split from the raw GTZAN archive even if the
# split already exists. It is switched on automatically below when the split
# directory is missing, so a normal re-run of this cell is a no-op.
PREPROCESSING = False

SETS = ['training', 'validation', 'tests']
FEATURES = ['mfcc', 'mfcc_bands', 'mfcc_bands_log']
GENRES = ['blues', 'classical', 'country', 'disco',
          'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']

# Track files are expected to look like "<name>.<index>.wav"; the dots are
# escaped and the whole name must match (fullmatch below), so stray files such
# as "blues.00000.wav.bak" are ignored.
WAV_NAME_PATTERN = re.compile(r"(\w+)\.(\d+)\.wav")

os.makedirs('./dataset', exist_ok=True)

if not os.path.isdir('./dataset/split'):
    PREPROCESSING = True

if PREPROCESSING:
    # The raw archive is only needed when the split has to be (re)built.
    if not os.path.isfile('./dataset/genres.tar.gz'):
        if os.getenv('COLAB_RELEASE_TAG'):
            # download the GTZAN dataset, then move it into ./dataset using the
            # name wget actually saved it under (no stray copy left in the cwd)
            downloaded = wget.download(
                "https://huggingface.co/datasets/marsyas/gtzan/resolve/main/data/genres.tar.gz"
            )
            shutil.move(downloaded, './dataset/genres.tar.gz')
        else:
            raise FileNotFoundError("Download the GTZAN dataset.")

    # Extract the archive; the split code below reads ./dataset/genres/<genre>.
    # That extracted copy is deleted once the split is built, so it is only
    # (re)extracted when actually needed.
    if not os.path.isdir('./dataset/genres'):
        with tarfile.open('./dataset/genres.tar.gz', 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # reject absolute paths / links escaping ./dataset in the
                # downloaded archive (available on patched Pythons and 3.12+)
                tar.extractall('./dataset', filter='data')
            else:
                tar.extractall('./dataset')

    # Directory tree: split/<set>/data/<genre> and
    # split/<set>/features/<feat>/<genre>; exist_ok allows a forced rebuild.
    for split_set in SETS:
        for genre in GENRES:
            os.makedirs(f'./dataset/split/{split_set}/data/{genre}', exist_ok=True)
            for feat in FEATURES:
                os.makedirs(
                    f'./dataset/split/{split_set}/features/{feat}/{genre}',
                    exist_ok=True,
                )

    # Copy every track into its split by file index:
    # 0-79 -> training, 80-89 -> validation, anything higher -> tests.
    for genre in GENRES:
        for wav in os.listdir(f'./dataset/genres/{genre}'):
            matched = WAV_NAME_PATTERN.fullmatch(wav)
            if not matched:
                continue

            index = int(matched.group(2))  # \d+ guarantees index >= 0
            if index <= 79:
                split_set = 'training'
            elif index <= 89:
                split_set = 'validation'
            else:
                split_set = 'tests'

            shutil.copy2(
                src=f'./dataset/genres/{genre}/{wav}',
                dst=f'./dataset/split/{split_set}/data/{genre}',
            )

    # the extracted originals are no longer needed once copied into the split
    shutil.rmtree('./dataset/genres')

## Feature Extraction

In [None]:
plt.rcParams['figure.figsize'] = (16, 9)

def extract_mfcc(src: str, dst: str, name: str, genre: str):
    """Compute MFCC-based features for one audio file and save them as images.

    The audio is cut into 1024-sample frames with a 512-sample hop, each frame
    is Hann-windowed, and its magnitude spectrum is fed to Essentia's MFCC.
    Per frame the MFCC coefficients, the MFCC bands and the log of the bands
    are collected. Each of those matrices is then rendered without axes and
    written to ``{dst}/{feat}/{genre}/{name}.png`` for every ``feat`` in
    FEATURES.

    Args:
        src: path of the audio file, loaded with Essentia's MonoLoader.
        dst: root of the features tree; ``{dst}/{feat}/{genre}`` must already
            exist (it is not created here).
        name: output image file name, without extension.
        genre: genre sub-directory the images are written into.
    """
    audio = MonoLoader(filename=src)()

    w = Windowing(type='hann')
    spectrum = Spectrum()  # FFT() would return the complex FFT, here we just want the magnitude spectrum
    mfcc = MFCC()
    logNorm = UnaryOperator(type='log')

    pool = Pool()
    for frame in FrameGenerator(audio, frameSize=1024, hopSize=512, startFromZero=True):
        mfcc_bands, mfcc_coeffs = mfcc(spectrum(w(frame)))
        pool.add('lowlevel.mfcc', mfcc_coeffs)
        pool.add('lowlevel.mfcc_bands', mfcc_bands)
        pool.add('lowlevel.mfcc_bands_log', logNorm(mfcc_bands))

    for feat in FEATURES:
        # Use a fresh figure per image and always close it. Drawing onto the
        # shared current figure (never cleared) made every call stack another
        # image on it, so memory and savefig time grew with each file.
        fig, ax = plt.subplots()
        try:
            # Assumes the pool returns a (frames, values) array, so .T puts
            # time on the x axis; the first row (index 0) is left out.
            ax.imshow(pool[f'lowlevel.{feat}'].T[1:, :], aspect='auto',
                      origin='lower', interpolation='none')
            ax.axis('off')
            fig.savefig(f'{dst}/{feat}/{genre}/{name}.png',
                        bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)


## Dataset Preprocessing: Extract Features for Every Split

In [None]:
count = 1
# Set to False to skip feature extraction when re-running the notebook.
PREPROCESSING = True

# Extract features from every file of every split. Files that fail are
# reported and skipped (best effort), but unlike a bare `except:` this no
# longer swallows KeyboardInterrupt, so the long loop can be interrupted.
if PREPROCESSING:
    for genre in GENRES:
        for split_set in SETS:  # not `set`: avoid shadowing the builtin
            data_dir = f'./dataset/split/{split_set}/data/{genre}'
            # sorted for a reproducible processing order and log numbering
            for wav in sorted(os.listdir(data_dir)):
                try:
                    extract_mfcc(
                        src=f'{data_dir}/{wav}',
                        dst=f'./dataset/split/{split_set}/features',
                        name=wav.removesuffix('.wav'),
                        genre=genre
                    )
                except Exception as err:
                    print(f'{wav} FAILED: {err!r}')
                else:
                    print(f'{count} - {wav} EXTRACTED')
                    count += 1