## Determine Reward Function Weighting

#### Procedure
We will be determining the frequency that each of the voice leading rules is violated in the training data (major chorales). The idea being that the rules that are broken more frequently are "less important". We will penalize these violations to a lesser extent when they are made by our model. 

1. Load chorale data.
2. Iterate through each training chorale, counting rule violations. 

In [2]:
### Load data ###
import yaml
with open('./data/chorales_in_c_smallshift.yaml', 'r') as file:
    chorales = yaml.safe_load(file)
maj_chorales_training_data = chorales['train']

In [22]:
### RULE DEFINITIONS ###
from state_space_def import * 

def illegal_leaps(state, next_state): 
    # returns number of instances of an illegal leap 
    bass_interval = abs(next_state[0] - state[0]) # BASS FIRST!!!
    tenor_interval = abs(next_state[1] - state[1])
    alto_interval = abs(next_state[2] - state[2])
    soprano_interval = abs(next_state[3] - state[3])
    num_leaps = 0

    all_intervals = [bass_interval, tenor_interval, alto_interval, soprano_interval]

    for interval in all_intervals:
        if interval > 12 or interval==6 or interval == 11:
            num_leaps += 1
    return num_leaps

def voice_crossing(state,next_state):
    # returns number of instances of voice crossing
    num_crosses = 0
    # assumes state, next_state are LISTS!
    # voice cross between bass and tenor
    if state[0] > next_state[1] or state[1] < next_state[0]:
        num_crosses += 1
    # between tenor and alto
    if state[1] > next_state[2] or state[2] < next_state[1]:
        num_crosses += 1
    # between alto and soprano:
    if state[2] > next_state[3] or state[3] < next_state[2]:
        num_crosses += 1
    return num_crosses

def parallel_fifths_octaves(state, next_state):
    # two parts separated by a p5 or p8 move to 
    # new pitch classes separated by the same interval 
    num_parallels = 0
    p5 = 7
    p8 = 12
    bass_tenor_intervals = [state[1]-state[0], next_state[1]-next_state[0]]
    bass_alto_intervals = [state[2]-state[0], next_state[2]-next_state[0]]
    bass_soprano_intervals = [state[3]-state[0], next_state[3]-next_state[0]]
    tenor_alto_intervals = [state[2]-state[1], next_state[2]-next_state[1]]
    tenor_soprano_intervals = [state[3]-state[1], next_state[3]-next_state[1]]
    alto_soprano_intervals = [state[3]-state[2], next_state[3]-next_state[2]]
    # CHECK EVERY PAIR 
    # Get intervals for each pair
    all_intervals = [bass_tenor_intervals, bass_soprano_intervals, tenor_alto_intervals , alto_soprano_intervals, bass_alto_intervals, tenor_soprano_intervals]
    for interval in all_intervals:
        if interval[0] == interval[1] and interval[0] !=0: # don't care if we don't see movement or if no parallel motion
            if interval[0]%12 == 7: # already know they're equal and positive... and abs(interval[1])%12 == 7:
                num_parallels+=1
            elif interval[0]%12 ==0: # and interval[1]%12 == 0:
                num_parallels += 1
    return num_parallels

def direct_fifths_octaves(state, next_state):
    num_d58 = 0
    bass_interval = state[0] - next_state[0]
    tenor_interval = state[1] - next_state[1]
    alto_interval = state[2] - next_state[2]
    soprano_interval = state[3] - next_state[3]
    bass_soprano_interval_2 = next_state[3]-next_state[0]
    bass_tenor_interval_2 = next_state[1]-next_state[0]
    tenor_alto_interval_2 = next_state[2]-next_state[1]
    alto_soprano_interval_2 = next_state[3]-next_state[2]
    # BASS SOP
    if bass_interval != 0 and soprano_interval != 0 and (bass_interval * soprano_interval) > 0 and abs(soprano_interval) > 2:
        # they move in the same direction, leap in the soprano part
        if bass_soprano_interval_2%12 == 0: # move into an octave
            num_d58 += 1
        elif bass_soprano_interval_2%12 == 7: # move into a fifth
            num_d58 += 1
    # BASS TENOR
    if bass_interval != 0 and tenor_interval != 0 and (bass_interval * tenor_interval) > 0 and abs(tenor_interval) > 2:
        # they move in the same direction, leap in the soprano part
        if bass_tenor_interval_2%12 == 0: # move into an octave
            num_d58 += 1
        elif bass_tenor_interval_2%12 == 7: # move into a fifth
            num_d58 += 1
    # TENOR ALTO
    if tenor_interval != 0 and alto_interval != 0 and (tenor_interval * alto_interval) > 0 and abs(alto_interval) > 2:
        # they move in the same direction, leap in the soprano part
        if tenor_alto_interval_2%12 == 0: # move into an octave
            num_d58 += 1
        elif tenor_alto_interval_2%12 == 7: # move into a fifth
            num_d58 += 1
    # ALTO SOP
    if alto_interval != 0 and soprano_interval != 0 and (alto_interval * soprano_interval) > 0 and abs(soprano_interval) > 2:
        # they move in the same direction, leap in the soprano part
        if alto_soprano_interval_2%12 == 0: # move into an octave
            num_d58 += 1
        elif alto_soprano_interval_2%12 == 7: # move into a fifth
            num_d58 += 1    
    return num_d58

def illegal_common_tones(state, next_state):
    # return true if there are three illegal common tones or 4 common tones 
    num_common_tones = 0
    for i in range(4):
        if state[i] == next_state[i]:
            num_common_tones += 1
    if num_common_tones == 3: 
        if state[0] == next_state[0]:
            return 1
        else:
            return 0 # bass arpeggiation!
    elif num_common_tones == 4: # this is fine
        return 0
    return 0

def leading_tone_resolution(state, next_state): 
    # IMPORTANT IN THE OUTER VOICE!
    # find leading tone:
    for i, note in enumerate(state): 
        if note%12 == 11: # LEADING TONE! Should resolve up by step
            resolution_note = next_state[i]
            res_step = resolution_note - note 
            if not (res_step == 1 or res_step == 2): 
                if i == 0 or i == 3:
                    return 1
                return 0
    return 0

def seventh_approach(state, next_state):
    # a 7th must not be approached by descending leap
    # Figure out if the chord is a 7th:
    # chord_1 = determine_chord_from_voicing(state)
    chord_2 = determine_chord_from_voicing(next_state)

    if chord_2 > 7: # second chord is a 7th
        # find the 7th
        for i, note in enumerate(next_state): 
            if note%12 == notes_in_chords[chord_2][-1]: # find the 7th 
                approach_note = state[i]
                seventh_note = next_state[i]
                if approach_note - seventh_note > 2:
                    # approached by descending leap
                    return 1 
    return 0

def seventh_resolve(state,next_state): 
    # A seventh MUST resolve DOWN by step!
    chord_1 = determine_chord_from_voicing(state)

    if chord_1 > 7: # first chord is a 7th
        for i, note in enumerate(next_state): 
            if note%12 == notes_in_chords[chord_1][-1]:
                seventh_note = state[i]
                resolution_note = next_state[i]
                if seventh_note-resolution_note > 2 or seventh_note-resolution_note < 1: 
                    return 1 # did not resolve down by step
    return 0

# RULES THAT ARE NOT DEPENDENT ON PREVIOUS STATE!
### JUST REMOVE THESE FROM THE STATE SPACE??? ### Probably makes sense to.... so ignore these for now
def doubled_leading_tone(next_state):
    leading_tone_count = 0 
    for note in next_state:
        if note%12 == 11:
            leading_tone_count += 1
    if leading_tone_count > 1:
        return 1
    return 0

def dim_triad_first_inversion(next_state):  # WILL ONLY WORK FOR MAJOR
    chord = determine_chord_from_voicing(next_state)
    inversion = determine_inversion(next_state, chord)
    if chord == 7: # diminished 7th 
        if inversion != 1:
            return 1
    return 0

def check_inv_triad_complete(next_state): 
    # INVERTED TRIADS SHOULD BE COMPLETE!
    # Return true if NOT complete
    chord = determine_chord_from_voicing(next_state)
    inversion = determine_inversion(next_state, chord)
    if inversion == 1: 
        if not is_complete(next_state, chord):
            return 1
    return 0

def second_inversion_triad_doubling(next_state): 
    # return true if INCORRECT doubling
    chord = determine_chord_from_voicing(next_state)
    inversion = determine_inversion(next_state, chord)
    if inversion == 2 and chord < 8:
        # 5th should be doubled
        fifth = notes_in_chords[chord][2]
        num_fifths = 0
        for pitch in next_state:
            if pitch%12 == fifth:
                num_fifths += 1
        if num_fifths >=2:
            return 0
        else:
            return 1
    return 0




In [23]:
from tqdm import tqdm
ill, vc, p58, d58, ill_ct, lt_res, sev_app, sev_res = 0,0,0,0,0,0,0,0

for chorale in tqdm(maj_chorales_training_data):
    for i in range(len(chorale)-1): 
        state, next_state = chorale[i], chorale[i+1]
        if len(next_state) < 4: 
            break
        chord_1 = determine_chord_from_voicing(state)
        chord_2 = determine_chord_from_voicing(next_state)
        ill += illegal_leaps(state, next_state)
        vc += voice_crossing(state, next_state)
        p58 += parallel_fifths_octaves(state, next_state)
        d58 += direct_fifths_octaves(state, next_state)
        ill_ct += illegal_common_tones(state,next_state)
        lt_res += leading_tone_resolution(state, next_state)
        if chord_1 and chord_2:
            sev_app += seventh_approach(state, next_state)
            sev_res += seventh_resolve(state,next_state)


100%|██████████| 229/229 [00:00<00:00, 575.47it/s]


In [24]:
print("# of illegal leaps: {}\n# of voice crossings: {}\n# of parallel 5ths/8ves: {}\n# of direct 5ths/8ves: {}\n# of illegal common tones: {}\n# of illegal lt resolutions: {}\n# of illegal 7th approaches: {}\n# of illegal 7th resolutions: {}".format(ill, vc, p58, d58, ill_ct, lt_res, sev_app, sev_res))



# of illegal leaps: 240
# of voice crossings: 662
# of parallel 5ths/8ves: 21478
# of direct 5ths/8ves: 953
# of illegal common tones: 2843
# of illegal lt resolutions: 3106
# of illegal 7th approaches: 14
# of illegal 7th resolutions: 591


In [38]:
rules_sum = ill + vc + p58 + d58 + ill_ct + lt_res + sev_app + sev_res
ill_unscaled = (rules_sum - ill)/rules_sum
vc_unscaled = (rules_sum - vc)/rules_sum
p58_unscaled = (rules_sum - p58)/rules_sum
d58_unscaled = (rules_sum - d58)/rules_sum
illct_unscaled = (rules_sum - ill_ct)/rules_sum
ltres_unscaled = (rules_sum - lt_res)/rules_sum
sevapp_unscaled = (rules_sum - sev_app)/rules_sum
sevres_unscaled = (rules_sum - sev_res)/rules_sum
unscaled_sum = ill_unscaled + vc_unscaled + p58_unscaled + d58_unscaled + illct_unscaled + ltres_unscaled + sevapp_unscaled + sevres_unscaled

ill_weight = round(ill_unscaled/unscaled_sum,4)
vc_weight = round(vc_unscaled/unscaled_sum,4)
p58_weight = round(p58_unscaled/unscaled_sum,4)
d58_weight = round(d58_unscaled/unscaled_sum,4)
illct_weight = round(illct_unscaled/unscaled_sum,4)
ltres_weight = round(ltres_unscaled/unscaled_sum,4)
sevapp_weight = round(sevapp_unscaled/unscaled_sum,4)
sevres_weight = round(sevres_unscaled/unscaled_sum,4)

print("illegal leaps weight: {}\nvoice crossing weight: {}\nparallel 5ths/8ves weight: {}\ndirect 5ths/8ves weight: {}\nillegal common tones weight: {}\nillegal lt resolutions weight:{}\nillegal 7th approaches weight: {}\nillegal 7th resolutions weight: {}".format(ill_weight, vc_weight, p58_weight, d58_weight
                                                                                    , illct_weight, ltres_weight, sevapp_weight, sevres_weight))


illegal leaps weight: 0.1417
voice crossing weight: 0.1397
parallel 5ths/8ves weight: 0.0402
direct 5ths/8ves weight: 0.1383
illegal common tones weight: 0.1293
illegal lt resolutions weight:0.128
illegal 7th approaches weight: 0.1428
illegal 7th resolutions weight: 0.14


In [30]:
print(f'Illegal leaps weight: {ill_weight:.5f}')

Illegal leaps weight: 0.00803


In [6]:
from datetime import datetime
datem = datetime.today().strftime("%m_%d")
print(datem)

02_19


In [7]:
import pickle
import numpy as np
Qvalues = np.zeros((10,10))
epochs = 12
data = {
    "Q": Qvalues,
    "epochs": epochs
}

with open('test.p', 'wb') as f:
    pickle.dump(data, f)
        #np.save(modelpath, self.Qvalues)

In [9]:
with open('test.p', 'rb') as f: 
        data = pickle.load(f)
Qvalues = data['Q']
eps = data['epochs']
print(type(Qvalues))
print(eps)

<class 'numpy.ndarray'>
12


In [14]:
import glob, os
list_of_files = glob.glob('./models/harmmodel_*.p]') # * means all if need specific format then *.csv
print(list_of_files)
latest_file = max(list_of_files, key=os.path.getctime)
print(latest_file)

['./models/freemodel.npy', './models/harmmodel_fulldata.npy', './models/harmmodel2.npy', './models/voicemodel.npy', './models/freemodel2.npy', './models/harmmodel_fulldata_2_16.npy', './models/harmmodel_fulldata_2_15.npy', './models/voicemodel2.npy', './models/harmmodel.npy']
./models/harmmodel_fulldata_2_16.npy
