# Music Genre Classifier

In [None]:
import os
import wget
import shutil
import tarfile
from pylab import imshow
from essentia import Pool 
import matplotlib.pyplot as plt
from essentia.standard import FrameGenerator, MonoLoader, \
    Windowing, Spectrum, MFCC, UnaryOperator

## Load the GTZAN Dataset

In [None]:
# True when the ./dataset/preprocessing tree did not exist yet and has just
# been created, i.e. features still have to be extracted.
PREPROCESSING = False

# Work inside ./dataset/ while downloading and extracting, and always return to
# the directory we started from. Returning blindly to '..' would end up in the
# wrong place if entering ./dataset/ itself had failed.
_start_dir = os.getcwd()
os.makedirs('./dataset/', exist_ok=True)
os.chdir('./dataset/')
try:
    if 'genres.tar.gz' not in os.listdir('.'):
        if os.getenv('COLAB_RELEASE_TAG'):
            # download the GTZAN dataset
            wget.download("https://huggingface.co/datasets/marsyas/gtzan/resolve/main/data/genres.tar.gz")
        else:
            raise FileNotFoundError("Download the GTZAN dataset.")

    # extract the whole dataset
    if 'genres' not in os.listdir('.'):
        with tarfile.open('genres.tar.gz', 'r:gz') as tar:
            # The archive comes from the network: use the 'data' extraction
            # filter (rejects absolute paths, '..' members, links escaping the
            # target, ...) whenever this Python version provides it.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(filter='data')
            else:
                tar.extractall()

    # The preprocessing tree is built only if it does not exist yet.
    PREPROCESSING = 'preprocessing' not in os.listdir('.')

    if PREPROCESSING:
        features = ('mfcc', 'mfcc_bands', 'mfcc_bands_log')
        for feature in features:
            # also creates ./preprocessing on the first iteration
            os.makedirs(f'./preprocessing/{feature}')

        # remove all unnecessary (hidden, '.'-prefixed) files and create one
        # output directory per genre and feature
        for genre in os.listdir('./genres'):
            if genre.startswith('.'):
                os.remove(f'./genres/{genre}')
                continue

            for feature in features:
                os.mkdir(f'./preprocessing/{feature}/{genre}')

            for wav in os.listdir(f'./genres/{genre}'):
                if wav.startswith('._'):
                    os.remove(f'./genres/{genre}/{wav}')

        for file in os.listdir('.'):
            if file.startswith('._'):
                os.remove(file)
finally:
    os.chdir(_start_dir)

# Genre names in a deterministic (sorted) order, skipping hidden entries.
# Computed only on success so a failure above is not masked by a second error.
GENRES = sorted(
    genre for genre in os.listdir('./dataset/genres/') if not genre.startswith('.')
)

## Feature Extraction

In [None]:
# Default size, in inches, of every matplotlib figure created from here on.
plt.rc('figure', figsize=(16, 9))

def extract_mfcc(src: str, dst: str, name: str, genre: str, *,
                 frame_size: int = 1024, hop_size: int = 512):
    """Compute MFCC features of one audio file and save them as images.

    Three axis-less PNG images are written:

    * ``{dst}/mfcc/{genre}/{name}.png``: MFCC coefficients, without the
      0th coefficient (row 0 is dropped),
    * ``{dst}/mfcc_bands/{genre}/{name}.png``: the mel band energies,
    * ``{dst}/mfcc_bands_log/{genre}/{name}.png``: their natural log.

    The target directories must already exist.

    Args:
        src: path of the audio file to load.
        dst: root directory of the preprocessed images.
        name: file name (without extension) of the written images.
        genre: genre sub-directory the images are written into.
        frame_size: analysis frame length, in samples.
        hop_size: distance between consecutive frames, in samples.
    """
    # Load the whole file as a mono signal.
    audio = MonoLoader(filename=src)()

    window = Windowing(type='hann')
    # FFT() would return the complex FFT; here only the magnitude spectrum is needed.
    spectrum = Spectrum()
    # Spectrum() yields frame_size // 2 + 1 bins; tell MFCC so its filter bank
    # is built for that size instead of its default input size.
    mfcc = MFCC(inputSize=frame_size // 2 + 1)
    log_norm = UnaryOperator(type='log')

    pool = Pool()
    for frame in FrameGenerator(audio, frameSize=frame_size, hopSize=hop_size,
                                startFromZero=True):
        mfcc_bands, mfcc_coeffs = mfcc(spectrum(window(frame)))
        pool.add('lowlevel.mfcc', mfcc_coeffs)
        pool.add('lowlevel.mfcc_bands', mfcc_bands)
        pool.add('lowlevel.mfcc_bands_log', log_norm(mfcc_bands))

    # Transpose so time runs along x; saved in the same order as before.
    images = {
        'mfcc': pool['lowlevel.mfcc'].T[1:, :],
        'mfcc_bands': pool['lowlevel.mfcc_bands'].T,
        'mfcc_bands_log': pool['lowlevel.mfcc_bands_log'].T,
    }
    for feature, data in images.items():
        _save_feature_image(data, f'{dst}/{feature}/{genre}/{name}.png')


def _save_feature_image(data, path: str):
    """Render a 2-D feature matrix without axes and save it to *path*."""
    # Use a fresh figure per image. Drawing on pyplot's current figure would
    # stack every image of every call onto the same axes, growing memory and
    # save time, so the figure is closed as soon as it has been saved.
    fig, ax = plt.subplots()
    try:
        ax.imshow(data, aspect='auto', origin='lower', interpolation='none')
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

## Dataset Preprocessing

In [None]:
# 1-based progress counter of successfully extracted files.
count = 1

# extract features from the data set (only the jazz and rock genres)
for genre in ['jazz', 'rock']:
    for wav in sorted(os.listdir(f'./dataset/genres/{genre}')):
        if PREPROCESSING:
            try:
                extract_mfcc(
                    src=f'./dataset/genres/{genre}/{wav}',
                    dst='./dataset/preprocessing',
                    name=wav.removesuffix('.wav'),
                    genre=genre,
                )
            except Exception as err:
                # Best effort: report the failing file and keep going. A bare
                # `except:` would also swallow KeyboardInterrupt, making this
                # long loop impossible to stop, and hide every error.
                print(f'{wav} FAILED: {err!r}')
            else:
                print(f'{count} - {wav} EXTRACTED')
                count += 1