In [42]:
from gym import Env
from gym.spaces import Discrete, Box
import numpy as np
import math
from itertools import combinations
from datetime import datetime
from music21 import *

For the reward function:  
if correct segmentation: output +1 #(change of roughness)  
if incorrect segmentation: output -1#(correct change of roughness in the next segmentation)  
if correct do nothing : output +1  
if incorrect do nothing: output -1 #(correct change of roughness)  
if illegal: output 0  

In [46]:
class SegmentationEnv(Env):  # Fit one particular coin first
    def __init__(self, pieces):
        #Preprocess the pieces
        self.notes = []
        self.offset = []
        self.beat = []
        self.duration = []
        self.octave = []
        self.correct_offset = []
#         self.beatchanges = []
        for piece in pieces:
            xnotes = []
            xoffset = []
            xbeat = []
            xduration = []
            xoctave = []
            xcoroffset = []
            c = converter.parse(piece)
            post = c.flattenParts().flat
            for note in post.notes:
                duration = note.duration.quarterLength
                offset = note.offset
                beat = note.beat
                if note.lyric is not None and note.offset != 0:
                    xcoroffset.append(note.offset)
                allnotes = list(note.pitches)
                for note1 in allnotes:
                    xnotes.append(note1.name)
                    xoffset.append(offset)
                    xbeat.append(beat)
                    xduration.append(duration)
                    xoctave.append(note1.octave)
            self.notes.append(xnotes)
            self.offset.append(xoffset)
            self.beat.append(xbeat)
            self.duration.append(xduration)
            self.octave.append(xoctave)
            self.correct_offset.append(xcoroffset)
            #             xbeatchange = {}
#             for ts in post.recurse().getElementsByClass(meter.TimeSignature):
#                 assert ts.denominator in [2,4,8]
#                 if ts.denominator == 2:
#                     xbeatchange[ts.offset] = 2
#                 elif ts.denominator == 4:
#                     xbeatchange[ts.offset] = 1
#                 else:
#                     xbeatchange[ts.offset] = 0.5
#             self.beatchanges.append(xbeatchange)

        #Actions: Remain segment (0), segment (1)
        self.action_space = Discrete(2)
        
        #Observations: First dim 12 pitch classes, Second dim Octave (1-7), Value is total duration.
        self.observation_space = Box(
            low=np.zeros((12,7),dtype=np.float32),
            high=np.ones((12,7),dtype=np.float32)*20, # Set the maximum duration to 20. If exceed, then just keep at 20.
        )
        
        #internal state: check where the time currently is 
        self.current_piece = 0
        self.current_noteoffset = 0
        self.notelistfirst = 0
        self.notelistlast = 0
        self.latestbeatfirst = 0
        self.latestbeatlast = 0
        self.state = np.zeros((12,7))
        
        #save segmentation for rendering purposes
        self.determined_offset = []
        
    def step(self, action):
        #Calculating reward
        if action == 0: # do nothing
            is_segment = False
            if self.current_noteoffset not in self.correct_offset[self.current_piece]: #correct
                reward = 1
            else:
                reward = max(-self.change_in_roughness()/20,-1)
        else: # segmentation
            is_segment = True
            if self.current_noteoffset == 0: #illegal operations
                reward = 0
            else:
                self.determined_offset.append((self.current_piece,self.current_noteoffset))
                if self.current_noteoffset in self.correct_offset[self.current_piece]:
                    reward = 1
                else:
                    reward = -1    
        #determine new obs state
        if is_segment and self.current_noteoffset != 0:
            self.notelistfirst = self.latestbeatfirst
        done = False
        if self.latestbeatlast == len(self.beat[self.current_piece]): #Finished a piece
            self.current_piece += 1
            if self.current_piece == len(self.notes):
                done = True
            else:
                done = False
                self.current_noteoffset = 0
                self.notelistfirst = 0
                self.notelistlast = 0 
                self.latestbeatfirst = 0
                self.latestbeatlast = 0
        if not done:
            self.current_noteoffset = self.offset[self.current_piece][self.latestbeatlast]
            currentbeat = self.beat[self.current_piece][self.latestbeatlast]
            currentindex = self.latestbeatlast + 1
            self.latestbeatfirst = self.latestbeatlast
            while len(self.beat[self.current_piece]) > currentindex and self.beat[self.current_piece][currentindex]//1 == currentbeat:
                currentindex += 1
            self.notelistlast = currentindex
            self.latestbeatlast = currentindex
        info = {}
        return self.staterender(), reward, done, info

    def render(self):
#         print("Current notelist:",self.notelistfirst,self.notelistlast)
        print("The current segmentation are:")
        for segment in self.determined_offset:
            print(segment)
        return
    
    def change_in_roughness(self):
        def roughness(notes):
            '''
            Calculate the Roughness of notes according to sum of ideal ratio N+M
            Reference: https://www.researchgate.net/publication/276905584_Measuring_Musical_Consonance_and_Dissonance
            '''
            def interval_to_ratio(interval):
                interval_ratio_mapping = {
                    0:1+1,
                    1:18+17,
                    2:9+8,
                    3:6+5,
                    4:5+4,
                    5:4+3,
                    6:17+12,
                    7:3+2,
                    8:8+5,
                    9:5+3,
                    10:16+9,
                    11:17+9,
                    12:2+1
                }
                interval_pitch_mapping = {
                    1:0,
                    2:2,
                    3:4,
                    4:5,
                    5:7,
                    6:9,
                    7:11,
                    8:12
                }
                ans = interval_pitch_mapping[int(interval[-1])]
                if int(interval[-1]) in [4,5,8]:
                    intname = interval[:-1]
                    if intname == "dd":
                        ans -= 2
                    elif intname == "d":
                        ans -= 1
                    elif intname == "A":
                        ans += 1
                    elif intname == "AA":
                        ans += 2
                else:
                    intname = interval[:-1]
                    if intname == "m":
                        ans -= 1
                    elif intname == "d":
                        ans -= 2
                    elif intname == "A":
                        ans += 1
                    elif intname == "AA":
                        ans += 2
                if ans < 0:
                    ans = ans%12
                return interval_ratio_mapping[ans]
            ans = 0
            for combo in combinations(notes,2):
                n1 = note.Note(combo[0])
                n2 = note.Note(combo[1])
                xinterval = interval.Interval(noteStart=n1,noteEnd=n2)
                ans += interval_to_ratio(xinterval.semiSimpleName)
            return ans/len(notes) if len(notes)!= 0 else 0
        notelist1 = []
        for i in range(self.notelistfirst,self.latestbeatfirst):
            notelist1.append(self.notes[self.current_piece][i])
        notelist2 = notelist1.copy()
        for i in range(self.latestbeatfirst,self.latestbeatlast):
            notelist2.append(self.notes[self.current_piece][i])
        return abs(roughness(notelist2)-roughness(notelist1))
    
    def staterender(self):
        pitch_to_index = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
        obsarray = np.zeros((12,7))
        notelist = []
        for idx in range(self.notelistfirst,self.notelistlast):
            current_note = self.notes[self.current_piece][idx]
            notelist.append(current_note)
            current_duration = self.duration[self.current_piece][idx]
            current_octave = self.octave[self.current_piece][idx]
            pitchindex = pitch_to_index[current_note[0]]
            current_note = current_note[1:]
            if current_note == "#":
                pitchindex += 1
            elif current_note == "##":
                pitchindex += 2
            elif current_note == "-":
                pitchindex -= 1
            elif current_note == "--":
                pitchindex -= 2
            pitchindex = pitchindex % 12
            if current_octave < 1 or current_octave > 7:
                continue
            obsarray[pitchindex][current_octave] += current_duration
            obsarray[pitchindex][current_octave] = min(20,obsarray[pitchindex][current_octave])
#         print(notelist)
        return obsarray

    def reset(self):
        self.current_piece = 0
        self.current_noteoffset = 0
        self.notelistfirst = 0
        self.notelistlast = 0 #exclusive
        self.latestbeatfirst = 0
        self.latestbeatlast = 0 #exclusive
        currentbeat = self.beat[self.current_piece][self.latestbeatlast]
        currentindex = self.latestbeatlast + 1
        while len(self.beat[self.current_piece]) > currentindex and self.beat[self.current_piece][currentindex]//1 == currentbeat:
            currentindex += 1
        self.notelistlast = currentindex
        self.latestbeatlast = currentindex
        return self.staterender()

In [47]:
import time
env = SegmentationEnv(["../review/not1_Prelude_I.musicxml"])
cur_state = env.reset()
done = False
while not done:
    env.render()
    action = env.action_space.sample()
    print("Action taken:",action)
    new_state, reward, done, _ = env.step(action)
    print("Reward:",reward)


The current segmentation are:
Action taken: 0
Reward: 1
The current segmentation are:
Action taken: 1
Reward: -1
The current segmentation are:
(0, 1.0)
Action taken: 1
Reward: -1
The current segmentation are:
(0, 1.0)
(0, 2.0)
Action taken: 1
Reward: -1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
Action taken: 1
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
Action taken: 0
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
Action taken: 1
Reward: -1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
Action taken: 0
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
Action taken: 1
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
(0, 8.0)
Action taken: 1
Reward: -1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
(0, 8.0)
(0, 9.0)
Action taken: 0
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0

(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
(0, 8.0)
(0, 9.0)
(0, 14.0)
(0, 16.0)
(0, 18.0)
(0, 20.0)
(0, 22.0)
(0, 24.0)
(0, 26.0)
(0, 28.0)
(0, 29.0)
(0, 30.0)
(0, 33.0)
(0, 34.0)
(0, 35.0)
(0, 36.0)
(0, 37.0)
(0, 40.0)
(0, 43.0)
(0, 45.0)
(0, 46.0)
(0, 47.0)
(0, 49.0)
(0, 51.0)
(0, 52.0)
(0, 54.0)
(0, 56.0)
(0, 58.0)
(0, 59.0)
(0, 63.0)
(0, 66.0)
(0, 70.0)
(0, 71.0)
(0, 72.0)
(0, 75.0)
(0, 76.0)
(0, 79.0)
(0, 83.0)
(0, 84.0)
(0, 85.0)
(0, 88.0)
Action taken: 0
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
(0, 8.0)
(0, 9.0)
(0, 14.0)
(0, 16.0)
(0, 18.0)
(0, 20.0)
(0, 22.0)
(0, 24.0)
(0, 26.0)
(0, 28.0)
(0, 29.0)
(0, 30.0)
(0, 33.0)
(0, 34.0)
(0, 35.0)
(0, 36.0)
(0, 37.0)
(0, 40.0)
(0, 43.0)
(0, 45.0)
(0, 46.0)
(0, 47.0)
(0, 49.0)
(0, 51.0)
(0, 52.0)
(0, 54.0)
(0, 56.0)
(0, 58.0)
(0, 59.0)
(0, 63.0)
(0, 66.0)
(0, 70.0)
(0, 71.0)
(0, 72.0)
(0, 75.0)
(0, 76.0)
(0, 79.0)
(0, 83.0)
(0, 84.0)
(0, 85.0)
(0, 88.0)
Action taken: 0
Reward: 1
The current segmentat

(0, 29.0)
(0, 30.0)
(0, 33.0)
(0, 34.0)
(0, 35.0)
(0, 36.0)
(0, 37.0)
(0, 40.0)
(0, 43.0)
(0, 45.0)
(0, 46.0)
(0, 47.0)
(0, 49.0)
(0, 51.0)
(0, 52.0)
(0, 54.0)
(0, 56.0)
(0, 58.0)
(0, 59.0)
(0, 63.0)
(0, 66.0)
(0, 70.0)
(0, 71.0)
(0, 72.0)
(0, 75.0)
(0, 76.0)
(0, 79.0)
(0, 83.0)
(0, 84.0)
(0, 85.0)
(0, 88.0)
(0, 93.0)
(0, 95.0)
(0, 100.0)
(0, 102.0)
(0, 104.0)
(0, 105.0)
(0, 106.0)
(0, 107.0)
(0, 110.0)
(0, 111.0)
(0, 112.0)
(0, 113.0)
(0, 117.0)
(0, 122.0)
Action taken: 0
Reward: 1
The current segmentation are:
(0, 1.0)
(0, 2.0)
(0, 3.0)
(0, 4.0)
(0, 6.0)
(0, 8.0)
(0, 9.0)
(0, 14.0)
(0, 16.0)
(0, 18.0)
(0, 20.0)
(0, 22.0)
(0, 24.0)
(0, 26.0)
(0, 28.0)
(0, 29.0)
(0, 30.0)
(0, 33.0)
(0, 34.0)
(0, 35.0)
(0, 36.0)
(0, 37.0)
(0, 40.0)
(0, 43.0)
(0, 45.0)
(0, 46.0)
(0, 47.0)
(0, 49.0)
(0, 51.0)
(0, 52.0)
(0, 54.0)
(0, 56.0)
(0, 58.0)
(0, 59.0)
(0, 63.0)
(0, 66.0)
(0, 70.0)
(0, 71.0)
(0, 72.0)
(0, 75.0)
(0, 76.0)
(0, 79.0)
(0, 83.0)
(0, 84.0)
(0, 85.0)
(0, 88.0)
(0, 93.0)
(0, 95.0)
(0, 100.0

IndexError: list index out of range

In [2]:
x = [1,2,3]
y = x.copy()
y.append(5)
x,y

([1, 2, 3], [1, 2, 3, 5])

In [10]:
env = SegmentationEnv("sdf")
for i in range(10):
    print(env.action_space.sample())

1
1
0
0
1
0
1
1
0
1
