In [215]:
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os

genres = ['disco', 'reggae', 'pop', 'rock',
          'metal', 'jazz', 'blues', 'hiphop', 'country']

chroma_type = ['stft', 'cq', 'cens']
# Naming mistake, cq should be cqt

binary_templates = [[1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1], [1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0]]

ks_templates = [[6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88], 
                [6.33, 2.68, 3.52, 5.38, 2.6, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17]]

# Tested with an irrelevent vector, perform a differnt result than binary_templates
#ks_templates = [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]

alpha = 0.9

uc = [1 + alpha + alpha**3 + alpha**7,
        0,
        0,
        0,
        alpha**4,
        0,
        0,
        alpha**2+alpha**5,
        0,
        0,
        alpha**6,
        0]

#[1 0 1 1 0 1 0 1 1 0 1 0] C minor
#[C C# D D# E F F# G G# A A# B]
#[1 0  1 0  1 1 0  1 0  1 0  1] C major

#C Major key is then 𝒖𝑪 + 𝒖𝑫 + 𝒖𝑬 + 𝒖𝑭 + 𝒖𝑮 + 𝒖𝑨 + 𝒖B

u = [None]*12
for i in range(12):
    # 0 means ua, 1 means ua#, ...
    u[i] = np.roll(uc, i-3)

harmonic_major_template = [0] * 12
harmonic_minor_template = [0] * 12

for i in range(12):
    if binary_templates[0][i] == 1:
        harmonic_major_template += u[(i+3)%12]
    if binary_templates[1][i] == 1:
        harmonic_minor_template += u[(i+3)%12]    

harmonic_templates = [harmonic_major_template, harmonic_minor_template]
    

In [189]:
def get_GTZAN_keys(genres):
    """
    Get keys for each song in each genres in GTZAN dataset. The path of the set is fixed.
    
    :param: None
    :return: A dict that contains the true key of each genres.
    
    Example:
        key0 = d['disco'][0]
        # key0 == 12
    """
    
    d = {}
    for g in genres:
        
        genres_key_path = "data/GTZAN/key/" + g + "/" + g + "."
        
        # Initial the dict
        with open(genres_key_path + "00000.lerch.txt") as f:
            d[g] = [f.readline().rstrip()]
        
        for i in range(1, 100): 
            # 100 is the numbers songs of each genres'data 
            
            with open("data/GTZAN/key/" + g + "/" + g + "." + str(i).zfill(5) + ".lerch.txt") as f:
                # 5 is the length of numbers in each file name
                d[g] += (f.readline().rstrip())
    return d

In [190]:
GTZAN_keys = get_GTZAN_keys(genres)

In [191]:
def create_GTZAN_Chroma(genres):
    """
    Create chroma file from wav and save them to result/GTZAN/
    
    :param: genres(list): The list of genres. 
    :return: None
    """
    
    if(not os.path.isdir('result/GTZAN/chroma')):
        for g in genres:
            genres_wav_path = "data/GTZAN/wav/" + g + "/" + g + "."

            for i in range(100):
                # 100 is the numbers songs of each genres'data 

                filename = genres_wav_path + str(i).zfill(5) + ".wav"
                y, sr = librosa.load(filename)

                # STFT chromagram
                chroma_stft = librosa.feature.chroma_stft(
                    y=y, sr=sr, n_chroma=12, n_fft=4096)
                # Why is n_fft=4096?
                
                # CQT chromagram
                chroma_cq = librosa.feature.chroma_cqt(y=y, sr=sr)
                
                # Chroma Energy Normalized chromagram
                chroma_cens = librosa.feature.chroma_cens(y=y, sr=sr)

                
                # Create new files for three types of chroma
                
                filename = "result/GTZAN/chroma" + g + "/stft/" + g + "." + str(i).zfill(5) + ".chroma_stft.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    np.savetxt(filename, chroma_stft)
                    
                filename = "result/GTZAN/chroma" + g + "/cq/" + g + "." + str(i).zfill(5) + ".chroma_cq.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    np.savetxt(filename, chroma_cq)
                    
                filename = "result/GTZAN/chroma" + g + "/cens/" + g + "." + str(i).zfill(5) + ".chroma_cens.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    np.savetxt(filename, chroma_cens)
                    
    return

In [192]:
create_GTZAN_Chroma(genres)

In [19]:
def get_GTZAN_Chroma(genres:str, chroma_type:str, index:int):
    """
    Get chroma from txt
    
    :return: (ndarray)
    """
    filename = "result/GTZAN/chroma/" + genres + "/" + chroma_type + "/" + genres + "." + str(index).zfill(5) + ".chroma_" + chroma_type + ".txt"
    with open(filename, mode='r') as f:
        return np.loadtxt(filename)

In [20]:
c = get_GTZAN_Chroma("blues", "cens", 3)

In [193]:
def create_GTZAN_tonic_pitch(genres:list, chroma_type:list):
    """
    Create tonic pitch list from chroma files if not already exsists.
    Start from A == 0
    
    """
    if(not os.path.isdir('result/GTZAN/tonic_pitch')):
        tonic_pitch_path = "result/GTZAN/tonic_pitch/"
        
        for g in genres:
            tonic_pitch_genres_path = tonic_pitch_path + g + "/"
            
            for c in chroma_type:
                
                filename = tonic_pitch_genres_path + c + ".txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    
                    for i in range(100):
                    # 100 is the numbers of songs of each genres'data
                        chroma = get_GTZAN_Chroma(g, c, i)
                        mean_pooling_chroma = chroma.sum(axis=1)
                        tonic_pitch = (np.argmax(mean_pooling_chroma)+3)%12 # Shift to A=0
                        f.write(str(tonic_pitch) + '\n')
    return

In [194]:
create_GTZAN_tonic_pitch(genres, chroma_type)

In [41]:
def get_GTZAN_tonic_pitch(genres:list, chroma_type:list):
    """
    :Return: pitches(dict): 
    """
    
    d = {}
    tonic_pitch_path = "result/GTZAN/tonic_pitch/"
    
    for g in genres:
        tonic_pitch_genres_path = tonic_pitch_path + g + "/"
        d[g] = {}
        
        for c in chroma_type:
            filename = tonic_pitch_genres_path + c + ".txt"
    
            with open(filename, mode='r+') as f:
                d[g][c] = [int(f.readline().rstrip())]
                
                for i in range(1, 100):
                    d[g][c].append(int(f.readline().rstrip()))
    
    return d

In [67]:
pitches = get_GTZAN_tonic_pitch(genres, chroma_type)
pitches['blues']['cq'][0]

10

In [220]:
def predict_GTZAN_key(chroma, major_template, minor_template, tonic_pitch: int):
    """
    Tonic_pitch start with A
    
    """
    
    shifted_major_template = np.roll(major_template, -3 + tonic_pitch)
    shifted_minor_template = np.roll(minor_template, -3 + tonic_pitch)
    # -3 becuase input tempalte tonic is C
    
    mean_pooling_chroma = chroma.sum(axis=1)
    
    major_pcc = stats.pearsonr(mean_pooling_chroma, shifted_major_template)[0]
    minor_pcc = stats.pearsonr(mean_pooling_chroma, shifted_minor_template)[0]
    
    if major_pcc > minor_pcc:
        return tonic_pitch # Major
    else:
        return tonic_pitch+12 # Minor
    

In [221]:
predict_GTZAN_key(get_GTZAN_Chroma('blues', 'cq', 0), 
                  binary_template_C_major, 
                  binary_template_C_minor, 
                  pitches['blues']['cq'][0])

10

In [222]:
def create_GTZAN_predict(genres:list, chroma_type:list, pitches:dict, binary_templates:list, ks_templates:list, harmonic_templates:list):
    
    if(not os.path.isdir('result/GTZAN/predict/')):
        predict_path = "result/GTZAN/predict/"
        
        for g in genres:
            predict_genres_path = predict_path + "/" + g + "/"
        
            for c in chroma_type:
                predict_genres_chroma_path = predict_genres_path + "/" + c + "/"
                
                # Get binary_templates results
                filename = predict_genres_chroma_path + "binary_templates.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    for i in range(100):
                        predict_key = predict_GTZAN_key(
                        get_GTZAN_Chroma(g, c, i),
                            binary_templates[0],
                            binary_templates[1],
                            pitches[g][c][i]
                        )
                        f.write(str(predict_key) + '\n')
                        
                # Get ks_templates results
                filename = predict_genres_chroma_path + "ks_templates.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    for i in range(100):
                        predict_key = predict_GTZAN_key(
                        get_GTZAN_Chroma(g, c, i),
                            ks_templates[0],
                            ks_templates[1],
                            pitches[g][c][i]
                        )
                        f.write(str(predict_key) + '\n')
                        
                # Get harmonic_templates results
                filename = predict_genres_chroma_path + "harmonic_templates.txt"
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                with open(filename, mode='w') as f:
                    for i in range(100):
                        predict_key = predict_GTZAN_key(
                        get_GTZAN_Chroma(g, c, i),
                            harmonic_templates[0],
                            harmonic_templates[1],
                            pitches[g][c][i]
                        )
                        f.write(str(predict_key) + '\n')

In [224]:
create_GTZAN_predict(genres, chroma_type, pitches, binary_templates, ks_templates, harmonic_templates)

In [225]:
def get_raw_acc(p, y):
    
    count = 0
    for i in range(len(y)):
        if p[i] == y[i]:
            count += 1
    return round(count / len(y), 5)
    
def get_weighted_acc(p, y):
    """
    Refer to the weighted accuracy of Task 1 to see how weighting works.
    """
    
    score = 0
    for i in range(len(y)):
        if p[i] == y[i]:
            score += 1
        elif (p[i]+7)%12 == y[i]:
            score += 0.5
        elif p[i] < 12:
            if p[i] + 9 == y[i]:
                score += 0.3
            elif p[i] + 12 == y[i]:
                score += 0.2
        elif p[i] >= 12:
            if (p[i] + 3)%12 == y[i]:
                score += 0.3
            elif p[i] - 12 == y[i]:
                score += 0.2
    return round(score / len(y), 5)

def create_GTZAN_acc(genres:list, chroma_type:list):
    
    if(not os.path.isdir('result/GTZAN/acc/')):
        acc_path = "result/GTZAN/acc/"
        predict_path = "result/GTZAN/predict/"
        y_path = "data/GTZAN/key/"
        
        acc = {}
        for g in genres:
            acc_genres_path = acc_path + g + ".txt"
            predict_genres_path = predict_path + "/" + g + "/"
            y_genres_path = y_path + "/" + g + "/"
            
            acc[g] = {}
            
            # Get ground truth
            y = []
            for i in range(100):
                y_filename = y_genres_path + g + '.' + str(i).zfill(5) + ".lerch.txt"
                with open(y_filename, mode='r+') as f:
                    y.append(int(f.readline().rstrip()))
            
            for c in chroma_type:
                acc_genres_chroma_path = acc_genres_path + "/" + c + "/"
                predict_genres_chroma_path = predict_genres_path + "/" + c + "/"
                acc[g][c] = {}
                
                # Get binary_templates acc
                acc[g][c]["binary_templates"] = {}
                p_bt_filename = predict_genres_chroma_path + "binary_templates.txt"
                p = []
                with open(p_bt_filename, mode='r+') as f:
                    for i in range(100):
                        p.append(int(f.readline().rstrip()))
                
                # Raw and weighted acc
                acc[g][c]["binary_templates"]["raw"] = get_raw_acc(p, y)
                acc[g][c]["binary_templates"]["weighted"] = get_weighted_acc(p, y)
                
                
                # Get ks_templates acc
                acc[g][c]["ks_templates"] = {}
                p_ks_filename = predict_genres_chroma_path + "ks_templates.txt"
                p = []
                with open(p_ks_filename, mode='r+') as f:
                    for i in range(100):
                        p.append(int(f.readline().rstrip()))
                
                # Raw and weighted acc
                acc[g][c]["ks_templates"]["raw"] = get_raw_acc(p, y)
                acc[g][c]["ks_templates"]["weighted"] = get_weighted_acc(p, y)
                
                
                
                # Get harmonic_templates acc
                acc[g][c]["harmonic_templates"] = {}
                p_hm_filename = predict_genres_chroma_path + "harmonic_templates.txt"
                p = []
                with open(p_hm_filename, mode='r+') as f:
                    for i in range(100):
                        p.append(int(f.readline().rstrip()))
                
                # Raw and weighted acc
                acc[g][c]["harmonic_templates"]["raw"] = get_raw_acc(p, y)
                acc[g][c]["harmonic_templates"]["weighted"] = get_weighted_acc(p, y)
                
        return acc

In [228]:
ans = create_GTZAN_acc(genres, chroma_type)

In [229]:
ans

{'disco': {'stft': {'binary_templates': {'raw': 0.34, 'weighted': 0.364},
   'ks_templates': {'raw': 0.32, 'weighted': 0.348},
   'harmonic_templates': {'raw': 0.37, 'weighted': 0.388}},
  'cq': {'binary_templates': {'raw': 0.4, 'weighted': 0.425},
   'ks_templates': {'raw': 0.36, 'weighted': 0.393},
   'harmonic_templates': {'raw': 0.39, 'weighted': 0.417}},
  'cens': {'binary_templates': {'raw': 0.35, 'weighted': 0.379},
   'ks_templates': {'raw': 0.32, 'weighted': 0.355},
   'harmonic_templates': {'raw': 0.36, 'weighted': 0.387}}},
 'reggae': {'stft': {'binary_templates': {'raw': 0.42, 'weighted': 0.43},
   'ks_templates': {'raw': 0.41, 'weighted': 0.422},
   'harmonic_templates': {'raw': 0.41, 'weighted': 0.422}},
  'cq': {'binary_templates': {'raw': 0.32, 'weighted': 0.346},
   'ks_templates': {'raw': 0.38, 'weighted': 0.394},
   'harmonic_templates': {'raw': 0.41, 'weighted': 0.418}},
  'cens': {'binary_templates': {'raw': 0.31, 'weighted': 0.337},
   'ks_templates': {'raw': 0.37