In [1]:
import pandas as pd 
from Chord import Chord
from Rule import Rule
from collections import defaultdict
import pitchtypes as pt
from probabilistic_model import ProbabilisticModel

In [2]:
p_model = ProbabilisticModel()

In [3]:
data = pd.read_json('../eda/JazzHarmonyTreebank/treebank.json')
trees_data = data[data['trees'].notna()]['trees']
ccts = [x[0]['complete_constituent_tree'] for x in list(trees_data)]

In [4]:
p_model.fit(ccts)

In [7]:
def parse_subtree(subtree):
    if isinstance(subtree, dict):
        label = Chord(subtree['label'])
        child_intervals = [str(label.distance_to(Chord(x['label']))) for x in subtree['children']]
        child_qualities = [Chord(x['label']).quality for x in subtree['children']]
        rule = Rule(label.quality, child_intervals, child_qualities)
        
        return rule

In [8]:
def make_hashable(d):
    def convert(value):
        if isinstance(value, dict):
            return tuple(sorted((k, convert(v)) for k, v in value.items()))
        elif isinstance(value, list):
            return tuple(convert(v) for v in value)
        else:
            return value

    return tuple(sorted((k, convert(v)) for k, v in d.items()))

In [None]:
count_dict = defaultdict(int)
lhs_count_dict = defaultdict(int)

def parse_leftmost(tree):
    output = []
    steps = []
    stack = [tree]
    parsed_so_far = []

    while stack:
        current = stack.pop(0)
        rule = parse_subtree(current)
        rule_key = rule.make_hashable()
        count_dict[rule_key] += 1
        lhs_count_dict[rule.lhs()] += 1

        steps.append({
            'current': current['label'],
            'parsed_before': parsed_so_far.copy()
        })

        if current.get('children'):
            stack = current['children'] + stack
        else:
            output.append(current['label'])

        parsed_so_far.append(current['label'])

    return steps, count_dict, lhs_count_dict

In [None]:
for cct in ccts:
    steps, count_dict, lhs_count_dict = parse_leftmost(cct)


In [None]:
sorted_dict = dict(sorted(count_dict.items(), key=lambda x: x[1], reverse=False))

In [None]:
lhs_count_dict = dict(sorted(lhs_count_dict.items(), key=lambda x: x[1], reverse=False))
lhs_count_dict 

{'sus': 47, 'unknown': 321, 'minor': 2125, 'major': 5455}

In [None]:
def compute_conditional_probs(count_dict, lhs_count_dict):
    prob_dict = {}
    for rule, count in count_dict.items():
        lhs = rule[2][1]
        if lhs_count_dict[lhs] == 0:
            print(lhs)
        prob_dict[rule] = count / lhs_count_dict[lhs]
    return prob_dict


In [None]:
def get_rule_probability(node, prob_dict):
    subtree = parse_subtree(node)
    rule_key = subtree.make_hashable()
    return prob_dict.get(rule_key, 0)

In [None]:
def compute_tree_probability(tree, prob_dict):
    if not tree.get('children'):
        return 1.0  

    prob = get_rule_probability(tree, prob_dict)

    for child in tree['children']:
        prob *= compute_tree_probability(child, prob_dict)

    return prob
    

In [None]:
for cct in ccts:
    steps, count_dict, lhs_count_dict = parse_leftmost(cct)
    prob_dict = compute_conditional_probs(count_dict, lhs_count_dict)

In [None]:
compute_tree_probability(ccts[3], expanded_prob_dict)

1.8069074877778802e-11

In [5]:
prob_dict = p_model.prob_dict

In [6]:
prob_dict

{(('child_intervals', ('0', '0')),
  ('child_qualities', ('minor', 'minor')),
  ('parent_quality', 'minor')): 0.4410143329658214,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('sus', 'minor')),
  ('parent_quality', 'minor')): 0.0033076074972436605,
 (('child_intervals', ('10', '0')),
  ('child_qualities', ('sus', 'sus')),
  ('parent_quality', 'sus')): 0.14285714285714285,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'sus')),
  ('parent_quality', 'sus')): 0.09523809523809523,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'minor')),
  ('parent_quality', 'minor')): 0.3263506063947078,
 (('child_intervals', ('1', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.07003089598352215,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.25060075523515274,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'major')),
  ('parent_quali

In [587]:
from itertools import product

def expand_prob_dict(prob_dict, default_prob=1e-6):
    INTERVAL_VOCAB = list(str(i) for i in range(12))
    QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']

    all_child_intervals = list(product(INTERVAL_VOCAB, repeat=2))
    all_child_qualities = list(product(QUALITY_VOCAB, repeat=2))
    all_parent_qualities = QUALITY_VOCAB

    count_added = 0

    for interval_combo, quality_combo, parent_quality in product(all_child_intervals, all_child_qualities, all_parent_qualities):
        if parent_quality not in quality_combo:
            continue  # skip if parent quality not present in children

        key = Rule(parent_quality, interval_combo, quality_combo).make_hashable()

        if key not in prob_dict:
            prob_dict[key] = default_prob
            count_added += 1

    print(f"Added {count_added} new rules to prob_dict.")
    return prob_dict

In [108]:
len(prob_dict)

9216

In [100]:
Rule(
    'major',
    ['0', '4'],
    ['major', 'minor']
).make_hashable()

(('child_intervals', ('0', '4')),
 ('child_qualities', ('major', 'minor')),
 ('parent_quality', 'major'))

In [None]:
INTERVAL_VOCAB = [str(i) for i in range(12)]
QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']

missing = []

for parent_quality in QUALITY_VOCAB:
    for int1 in INTERVAL_VOCAB:
        for q1 in QUALITY_VOCAB:
            for int2 in INTERVAL_VOCAB:
                for q2 in QUALITY_VOCAB:
                    key = Rule(parent_quality, [int1, int2], [q1, q2]).make_hashable()
                    if key not in prob_dict:
                        missing.append(key)

print(f"Total missing combinations: {len(missing)}")
for m in missing[:10]:  # just show a few
    print(m)

In [8]:
prob_dict = expand_prob_dict(prob_dict,default_prob=1e-6)

NameError: name 'expand_prob_dict' is not defined

# RL Model

In [10]:
rules = [Rule.unhash(i) for i in prob_dict.keys()]

In [11]:
class TreeNode:
    def __init__(self, chord):
        self.chord = chord
        self.left = None           # child node
        self.right = None          # child node
        self.parent = None         # to be filled in later

    def is_leaf(self):
        return self.left is None and self.right is None

    def __repr__(self):
        return f"TreeNode(label={self.chord.label})"


def create_parent_node(left, right, parent, parent_quality):
    parent = TreeNode(Chord(f"{parent.chord.label}"))
    parent.left = left
    parent.right = right
    left.parent = parent
    right.parent = parent

    return parent

def get_top_level_nodes(nodes):
    """
    Given a list of TreeNodes (leaves and partial trees), 
    return the current top-level (root) nodes that are not children of any other node.
    """
    return [node for node in nodes if node.parent is None]


In [12]:
def get_possible_intervals(chord1,chord2):

    intervals = []
    intervals.append([str(chord1.distance_to(chord1)), str(chord1.distance_to(chord2))])
    intervals.append([str(chord1.distance_from(chord2)), str(chord2.distance_to(chord2))])

    return intervals

In [13]:
def get_qualities(chord1,chord2):
    return [chord1.quality, chord2.quality]

In [14]:
def get_possible_rhs_from_str(chord1,chord2):
    return [(get_possible_intervals(chord1, chord2)[0], get_qualities(chord1, chord2)), (get_possible_intervals(chord1, chord2)[1], get_qualities(chord1, chord2)),]

In [15]:
def get_applicable_rules(rhs_list, rules):
    applicable_rules = []
    for rule in rules:
        for rhs in rhs_list:
            if rule.rhs() == rhs:
                applicable_rules.append(rule)
    return applicable_rules

In [13]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output one Q-value
        )

    def forward(self, history_vec, actions_vecs):
        """
        Args:
            history_vec: Tensor (batch_size, history_dim)
            actions_vecs: Tensor (batch_size, num_actions, action_dim)

        Returns:
            q_values: Tensor (batch_size, num_actions)
        """
        batch_size, num_actions, action_dim = actions_vecs.shape

        # Encode the history once
        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)

        # Flatten actions and encode them
        actions_flat = actions_vecs.view(-1, action_dim)  # (batch_size * num_actions, action_dim)
        actions_encoded = self.action_encoder(actions_flat)  # (batch_size * num_actions, hidden_dim)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)  # (batch_size, num_actions, hidden_dim)

        # Repeat history for each action
        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)  # (batch_size, num_actions, hidden_dim)

        # Concatenate encoded history and actions
        concat = torch.cat([history_repeated, actions_encoded], dim=-1)  # (batch_size, num_actions, hidden_dim*2)

        # Score each (history, action) pair
        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values


In [16]:
import random
import numpy as np

class Environment:

    MAX_RULES = 10
    MAX_ACTIONS = 10
    MAX_CHORDS = 10

    def __init__(self, chord_sequence):
        self.initial_sequence = chord_sequence
        self.current_state = chord_sequence.copy()
        self.current_nodes = [TreeNode(Chord(chord)) for chord in chord_sequence]
        self.actions = self.get_actions()
        self.applied_rules = []

    # Each index of rhs are a list of applicable rules to currentstate and currentstate + 1
    def get_actions(self):
        rhs_list = []

        for i in range(len(self.current_state) - 1):
            chord1 = Chord(self.current_state[i])
            chord2 = Chord(self.current_state[i + 1])
            possible_rhs = get_possible_rhs_from_str(chord1, chord2)
            rhs_list.append(get_applicable_rules(possible_rhs, rules))


        return rhs_list

    def build_action_index_map(self):

        index_map = []
        for i, rule_list in enumerate(self.get_actions()):
            for j, _ in enumerate(rule_list):
                index_map.append((i, j))
        return index_map

    def apply_rule_index(self,rule_index_i,rule_index_j):
        rule = self.actions[rule_index_i][rule_index_j]
        node1 = self.current_nodes[rule_index_i]
        node2 = self.current_nodes[rule_index_i + 1]
        nodes = [node1,node2]
        parent_root_index = rule.child_intervals.index('0') 
        parent = create_parent_node(node1, node2, nodes[parent_root_index], rule.lhs())
        self.current_nodes[rule_index_i] = parent
        self.current_nodes.pop(rule_index_i + 1)
        self.applied_rules.append(rule)
        self.current_nodes = get_top_level_nodes(self.current_nodes)
        self.current_state = [node.chord.label for node in self.current_nodes]
        self.actions = self.get_actions()

    def apply_rule(self, rule, node1, node2):
        nodes = [node1,node2]
        parent_root_index = rule.child_intervals.index('0') 
        parent = create_parent_node(node1, node2, nodes[parent_root_index], rule.lhs())

        self.current_nodes.append(parent)
        self.applied_rules.append(rule)

        self.current_nodes = get_top_level_nodes(self.current_nodes)
        self.current_state = [node.chord.label for node in self.current_nodes]
        self.actions = self.get_actions()

    def get_state_tensor2(self):
        QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']
        INTERVAL_VOCAB = [str(i) for i in range(0,12)]

        def one_hot_encode(value, vocab):
            vec = [0] * len(vocab)
            idx = vocab.index(value if value in vocab else vocab.index('Other'))
            vec[idx] = 1
            return vec

        def encode_rule(rule):
            parent_vec = one_hot_encode(rule.parent_quality, QUALITY_VOCAB)

            interval1_vec = one_hot_encode(rule.child_intervals[0], INTERVAL_VOCAB)
            quality1_vec = one_hot_encode(rule.child_qualities[0], QUALITY_VOCAB)

            interval2_vec = one_hot_encode(rule.child_intervals[1], INTERVAL_VOCAB)
            quality2_vec = one_hot_encode(rule.child_qualities[1], QUALITY_VOCAB)

            return parent_vec + interval1_vec + quality1_vec + interval2_vec + quality2_vec

        chord_vector = []
        for chord in self.current_state[:self.MAX_CHORDS]:
            chord_vector.append(Chord.encode_chord(chord))

        # Pad if needed
        while len(chord_vector) < self.MAX_CHORDS:
            chord_vector.append([0] * len(chord_vector[0]))

        flat_chord_vector = [val for chord in chord_vector for val in chord]
        # chord_tensor = torch.tensor(flat_chord_vector, dtype=torch.float32).unsqueeze(0)

        actions = self.get_actions()
        action_vector = []

        count = 0
        for i in range(len(actions)):
            for j in range(len(actions[i])):
                if count >= self.MAX_ACTIONS:
                    break
                action_vector.append(encode_rule(actions[i][j]))
                count += 1
            if count >= self.MAX_ACTIONS:
                break
        
        flat_action_vector = [val for action in action_vector for val in action]

        return torch.tensor(flat_chord_vector + flat_action_vector, dtype=torch.float32)
        # rule_dim = len(action_vector[0]) if action_vector else 0
        # while len(action_vector) < self.MAX_ACTIONS:
        #     action_vector.append([0] * rule_dim)

        # actions_tensor = torch.tensor(action_vector, dtype=torch.float32).unsqueeze(0)
        # return chord_tensor, actions_tensor


    def get_state_tensor(self):
        QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']
        INTERVAL_VOCAB = [str(i) for i in range(0,12)]

        def one_hot_encode(value, vocab):
            vec = [0] * len(vocab)
            idx = vocab.index(value if value in vocab else vocab.index('Other'))
            vec[idx] = 1
            return vec

        def encode_rule(rule):
            parent_vec = one_hot_encode(rule.parent_quality, QUALITY_VOCAB)

            interval1_vec = one_hot_encode(rule.child_intervals[0], INTERVAL_VOCAB)
            quality1_vec = one_hot_encode(rule.child_qualities[0], QUALITY_VOCAB)

            interval2_vec = one_hot_encode(rule.child_intervals[1], INTERVAL_VOCAB)
            quality2_vec = one_hot_encode(rule.child_qualities[1], QUALITY_VOCAB)

            return parent_vec + interval1_vec + quality1_vec + interval2_vec + quality2_vec

        chord_vector = []
        for chord in self.current_state[:self.MAX_CHORDS]:
            chord_vector.append(Chord.encode_chord(chord))

        # Pad if needed
        while len(chord_vector) < self.MAX_CHORDS:
            chord_vector.append([0] * len(chord_vector[0]))

        flat_chord_vector = [val for chord in chord_vector for val in chord]
        chord_tensor = torch.tensor(flat_chord_vector, dtype=torch.float32).unsqueeze(0)

        actions = self.get_actions()
        action_vector = []

        count = 0
        for i in range(len(actions)):
            for j in range(len(actions[i])):
                if count >= self.MAX_ACTIONS:
                    break
                action_vector.append(encode_rule(actions[i][j]))
                count += 1
            if count >= self.MAX_ACTIONS:
                break

        rule_dim = len(action_vector[0]) if action_vector else 0
        while len(action_vector) < self.MAX_ACTIONS:
            action_vector.append([0] * rule_dim)

        actions_tensor = torch.tensor(action_vector, dtype=torch.float32).unsqueeze(0)
        return chord_tensor, actions_tensor

    def step(self, epsilon=0.3):
        chord_tensor, actions_tensor = self.get_state_tensor()
        results = self.model.forward(chord_tensor, actions_tensor)
        index_map = self.build_action_index_map()[0:10]
        # print(len(index_map))
        if len(index_map) == 0:
            return None  # no actions possible
        
        if random.random() < epsilon:
            flat_index = random.randint(0, len(index_map) - 1)
        else:
            q_values = self.model(chord_tensor, actions_tensor).detach().squeeze(0)[0:len(index_map)]  # shape: (num_actions,)
            flat_index = q_values.squeeze(0).argmax().item()

        # reward_before = self.evaluate_tree()
        i, j = index_map[flat_index]
        self.apply_rule_index(i, j)
        # reward_after = self.evaluate_tree()
        last_rule = self.applied_rules[-1]
        # reward = prob_dict.get(last_rule.make_hashable(), 1e-10)  # Use same small value as in evaluate_tree
        # delta_reward = reward_after - reward_before
        next_chord_tensor, next_actions_tensor = self.get_state_tensor()
        # reward = self.evaluate_tree() if self.is_terminal() else np.log(prob_dict.get(last_rule.make_hashable(), 1e-10))
        reward = np.log(prob_dict.get(last_rule.make_hashable(), 1e-10))
        done = self.is_terminal()
        # if done:
        #     delta_reward = self.evaluate_tree()
        return {
        "state": chord_tensor,
        "actions": actions_tensor,
        "action_index": flat_index,
        "reward": reward,
        "next_state": next_chord_tensor,
        "next_actions": next_actions_tensor,
        "done": done
    }
    

    def nn_simulate(self, epsilon=0.1):
        while not self.is_terminal():
            result = self.step()
            if result is None:
                break
 

    def add_model(self, model):
        self.model = model

    def is_terminal(self):
        return len(self.current_state) == 1

    def random_baseline_step(self):
        chord_tensor, actions_tensor = self.get_state_tensor()
        index_map = self.build_action_index_map()

        if len(index_map) == 0:
            return None
        flat_index = random.randint(0, len(index_map) - 1)
        i, j = index_map[flat_index]
        self.apply_rule_index(i, j)

    def simulate_random_baseline(self):
        while not self.is_terminal():
            self.random_baseline_step()

    def simulate_greedy_baseline(self):
        while not self.is_terminal():
            actions = self.get_actions()
            index_map = self.build_action_index_map()
            max_prob = float('-inf')
            max_i = 0
            max_j = 0
            for i, rule_list in enumerate(actions):
                for j, rule in enumerate(rule_list):
                    prob = prob_dict.get(rule.make_hashable(), 1e-10)  # Use same small value as in evaluate_tree
                    if prob > max_prob:
                        max_prob = prob
                        max_i = i
                        max_j = j
            if len(index_map) == 0:
                break
            self.apply_rule_index(max_i, max_j)

    def evaluate_tree(self):
        log_probs = 0
        for rule in self.applied_rules:
            log_probs += np.log(prob_dict.get(rule.make_hashable(), 1e-10))  # Add a small value to avoid log(0)

        return log_probs 



In [15]:
def decode_state_tensor(chord_tensor, actions_tensor):
    # Constants from encoding
    QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']
    INTERVAL_VOCAB = [str(i) for i in range(12)]

    def one_hot_decode(vec, vocab):
        idx = int(np.argmax(vec))
        return vocab[idx]

    def decode_chord(chord_vec):
        # Assuming Chord.encode_chord() returns a known-length flat vector
        # You'll need to know how many features make up one chord
        # Example assumes each chord is encoded as 12 pitch classes + 1 quality index (simplified)
        root = np.argmax(chord_vec[:12])
        quality_idx = np.argmax(chord_vec[12:12 + len(QUALITY_VOCAB)])
        quality = QUALITY_VOCAB[quality_idx]
        return f"{root}_{quality}"

    def decode_action(action_vec):
        split = np.array_split(action_vec, [len(QUALITY_VOCAB),
                                            len(QUALITY_VOCAB) + len(INTERVAL_VOCAB),
                                            len(QUALITY_VOCAB) + len(INTERVAL_VOCAB)*2,
                                            len(QUALITY_VOCAB) + len(INTERVAL_VOCAB)*2 + len(QUALITY_VOCAB)])
        parent_vec, interval1_vec, quality1_vec, interval2_vec, quality2_vec = split
        return {
            "parent_quality": one_hot_decode(parent_vec, QUALITY_VOCAB),
            "child_interval_1": one_hot_decode(interval1_vec, INTERVAL_VOCAB),
            "child_quality_1": one_hot_decode(quality1_vec, QUALITY_VOCAB),
            "child_interval_2": one_hot_decode(interval2_vec, INTERVAL_VOCAB),
            "child_quality_2": one_hot_decode(quality2_vec, QUALITY_VOCAB),
        }

    chord_tensor_np = chord_tensor.squeeze(0).detach().numpy()
    actions_tensor_np = actions_tensor.squeeze(0).detach().numpy()

    # Decode chords
    chord_size = len(Chord.encode_chord("0_major"))  # adjust if you have a different encoding
    chords = []
    for i in range(0, len(chord_tensor_np), chord_size):
        chunk = chord_tensor_np[i:i + chord_size]
        if not np.allclose(chunk, 0):  # skip padding
            chords.append(decode_chord(chunk))

    # Decode actions
    rule_size = len(actions_tensor_np[0])
    actions = []
    for action_vec in actions_tensor_np:
        if not np.allclose(action_vec, 0):  # skip padding
            actions.append(decode_action(action_vec))

    return chords, actions

In [17]:
chords = ['Cmaj7', 'Dmin7', 'Emin7', 'Fmaj7', 'G7', 'Amin7b5', 'Bbsus']
env = Environment(chords)
env.simulate_greedy_baseline()
env.evaluate_tree()

-26.116676455677904

In [19]:
chords = ['C^', 'Dmin7', 'Emin7', 'F^7', 'G7', 'Amin7b5', 'Bbsus']
env = Environment(chords)

In [20]:
chords = ['Cmaj7', 'Dmin7', 'Emin7', 'Fmaj7', 'G7', 'Amin7b5', 'Bbsus']
env = Environment(chords)
env.simulate_random_baseline()
env.evaluate_tree()

NameError: name 'torch' is not defined

In [34]:
chords_dataset[0]

['D7', 'G^7', 'F#m7', 'Bm', 'Em7', 'Bm', 'Am7']

In [38]:
len(chords_dataset)

1170

In [71]:
greedy = []
random_vals = []
nn_vals = []
counter = 0
for chord in chords_dataset:
    counter += 1
    print('num ',counter)
    print('a')
    env = Environment(chord)
    env.add_model(model)
    env.simulate_random_baseline()
    random_vals.append(env.evaluate_tree())
    print('b')
    env = Environment(chord)
    env.add_model(model)
    env.simulate_greedy_baseline()
    greedy.append(env.evaluate_tree())
    print('c')
    env = Environment(chord)
    env.add_model(model)
    env.nn_simulate()
    nn_vals.append(env.evaluate_tree())
    

num  1
a
b
c
num  2
a
b
c
num  3
a
b
c
num  4
a
b
c
num  5
a
b
c
num  6
a
b
c
num  7
a
b
c
num  8
a
b
c
num  9
a
b
c
num  10
a
b
c
num  11
a
b
c
num  12
a
b
c
num  13
a
b
c
num  14
a
b
c
num  15
a
b
c
num  16
a
b
c
num  17
a
b
c
num  18
a
b
c
num  19
a
b
c
num  20
a
b
c
num  21
a
b
c
num  22
a
b
c
num  23
a
b
c
num  24
a
b
c
num  25
a
b
c
num  26
a
b
c
num  27
a
b
c
num  28
a
b
c
num  29
a
b
c
num  30
a
b
c
num  31
a
b
c
num  32
a
b
c
num  33
a
b
c
num  34
a
b
c
num  35
a
b
c
num  36
a
b
c
num  37
a
b
c
num  38
a
b
c
num  39
a
b
c
num  40
a
b
c
num  41
a
b
c
num  42
a
b
c
num  43
a
b
c
num  44
a
b
c
num  45
a
b
c
num  46
a
b
c
num  47
a
b
c
num  48
a
b
c
num  49
a
b
c
num  50
a
b
c
num  51
a
b
c
num  52
a
b
c
num  53
a
b
c
num  54
a
b
c
num  55
a
b
c
num  56
a
b
c
num  57
a
b
c
num  58
a
b
c
num  59
a
b
c
num  60
a
b
c
num  61
a
b
c
num  62
a
b
c
num  63
a
b
c
num  64
a
b
c
num  65
a
b
c
num  66
a
b
c
num  67
a
b
c
num  68
a
b
c
num  69
a
b
c
num  70
a
b
c
num  71
a
b
c
num  72
a
b
c
n

In [84]:
np.mean(greedy)

-124.60714835425685

In [85]:
np.mean(random_vals)

-270.6164459674597

In [86]:
np.mean(nn_vals)     

-167.27215844990178

In [83]:
chords = ['Cmaj7', 'Dmin7', 'Emin7', 'Fmaj7', 'G7', 'Amin7b5', 'Bbsus']
env = Environment(chords)
env.add_model(model)
env.simulate_greedy_baseline()
env.evaluate_tree()

-26.116676455677904

In [62]:
chords = ['C^', 'D^7', 'Emin7', 'Fmaj7', 'G7', 'Amin7b5', 'Bbsus']
env = Environment(chords)

In [66]:
env.get_actions()

[[Rule(parent_quality=major, child_intervals=['10', '0'], child_qualities=['major', 'major']),
  Rule(parent_quality=major, child_intervals=['0', '2'], child_qualities=['major', 'major'])],
 [Rule(parent_quality=minor, child_intervals=['10', '0'], child_qualities=['major', 'minor']),
  Rule(parent_quality=major, child_intervals=['0', '2'], child_qualities=['major', 'minor']),
  Rule(parent_quality=minor, child_intervals=['0', '2'], child_qualities=['major', 'minor']),
  Rule(parent_quality=major, child_intervals=['10', '0'], child_qualities=['major', 'minor'])],
 [Rule(parent_quality=minor, child_intervals=['0', '1'], child_qualities=['minor', 'minor']),
  Rule(parent_quality=minor, child_intervals=['11', '0'], child_qualities=['minor', 'minor'])],
 [Rule(parent_quality=major, child_intervals=['10', '0'], child_qualities=['minor', 'major']),
  Rule(parent_quality=major, child_intervals=['0', '2'], child_qualities=['minor', 'major']),
  Rule(parent_quality=minor, child_intervals=['0', '

In [181]:
from sklearn.model_selection import train_test_split

# Split the data into training and test sets (80-20 split)
train_ccts, test_ccts = train_test_split(ccts, test_size=0.2, random_state=42)

print(f"Training set size: {len(train_ccts)}")
print(f"Test set size: {len(test_ccts)}")

Training set size: 120
Test set size: 30


In [21]:
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = []
        self.capacity = capacity

    def add(self, transition):
        self.buffer.append(transition)
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

In [67]:
import torch.nn as nn

def train_on_batch(model, optimizer, batch, gamma=0.99):
    hist_batch = torch.cat([b["state"] for b in batch])
    actions_batch = torch.cat([b["actions"] for b in batch])
    next_hist_batch = torch.cat([b["next_state"] for b in batch])
    next_actions_batch = torch.cat([b["next_actions"] for b in batch])
    
    action_indices = torch.tensor([b["action_index"] for b in batch]).unsqueeze(1)
    rewards = torch.tensor([b["reward"] for b in batch], dtype=torch.float32)
    dones = torch.tensor([b["done"] for b in batch], dtype=torch.bool)

    q_values = model(hist_batch, actions_batch)
    q_chosen = q_values.gather(1, action_indices).squeeze(1)

    with torch.no_grad():
        next_q_values = model(next_hist_batch, next_actions_batch)  
        next_q_max = next_q_values.max(dim=1)[0]
        target_q = rewards + gamma * next_q_max * (~dones)

    loss = torch.nn.SmoothL1Loss()(q_chosen, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print("Sample rewards:", rewards[:5])
    # print("Sample dones:", dones[:5])
    # print("Sample target Q:", target_q[:5])
    # print("Sample chosen Q:", q_chosen[:5])
    # print("Loss:", loss.item())

    return loss.item()

In [23]:
def train_model(model, dataset, num_episodes=1000, batch_size=32, gamma=0.99):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    buffer = ReplayBuffer(capacity=5000)

    for episode in range(num_episodes):
        # Random starting sequence
        chord_seq = random.choice(dataset)
        env = Environment(chord_seq)
        env.add_model(model)
        loss = None
        while not env.is_terminal():
            transition = env.step(epsilon=0.3)
            if transition and transition["next_actions"].numel() > 0:
                buffer.add(transition)

            if len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                loss = train_on_batch(model, optimizer, batch, gamma)

        if episode % 10 == 0:
            if loss:
                print(f"Episode {episode} complete | Last loss: {loss:.4f}")

In [25]:
model = DQN(history_dim=160, action_dim=36)

In [28]:
chords_dataset = list(data['chords'])

In [70]:
train_model(model, chords_dataset, num_episodes=2000, batch_size=32, gamma=0.99)

Episode 0 complete | Last loss: 4.3394
Episode 10 complete | Last loss: 1.4783
Episode 20 complete | Last loss: 4.3424
Episode 30 complete | Last loss: 3.9391
Episode 40 complete | Last loss: 8.3106
Episode 50 complete | Last loss: 7.3341
Episode 60 complete | Last loss: 0.8100
Episode 70 complete | Last loss: 4.4357
Episode 80 complete | Last loss: 1.5888
Episode 90 complete | Last loss: 1.0954
Episode 100 complete | Last loss: 0.6362
Episode 110 complete | Last loss: 2.5546
Episode 120 complete | Last loss: 2.1352
Episode 130 complete | Last loss: 2.5209
Episode 140 complete | Last loss: 1.8690
Episode 150 complete | Last loss: 1.6016
Episode 160 complete | Last loss: 6.4681
Episode 170 complete | Last loss: 2.2542
Episode 180 complete | Last loss: 3.3504
Episode 190 complete | Last loss: 2.0882
Episode 200 complete | Last loss: 2.1531
Episode 210 complete | Last loss: 4.4122
Episode 220 complete | Last loss: 2.6154
Episode 230 complete | Last loss: 1.5049
Episode 240 complete | Last

In [None]:
env.add_model(model)

In [69]:
model = DQN(160,36)

In [24]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        # New input dim is raw history + raw action
        combined_dim = history_dim + action_dim

        self.q_net = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # One Q-value per (history, action) pair
        )

    def forward(self, history_vec, actions_vecs):
        """
        Args:
            history_vec: Tensor (batch_size, history_dim)
            actions_vecs: Tensor (batch_size, num_actions, action_dim)

        Returns:
            q_values: Tensor (batch_size, num_actions)
        """
        batch_size, num_actions, action_dim = actions_vecs.shape
        history_dim = history_vec.shape[1]

        # Repeat history for each action
        history_repeated = history_vec.unsqueeze(1).repeat(1, num_actions, 1)  # (batch_size, num_actions, history_dim)

        # Concatenate raw history and actions
        concat = torch.cat([history_repeated, actions_vecs], dim=-1)  # (batch_size, num_actions, history_dim + action_dim)

        # Flatten for input to network
        concat_flat = concat.view(-1, history_dim + action_dim)  # (batch_size * num_actions, combined_dim)

        # print(concat_flat)
        # Pass through Q-network
        q_flat = self.q_net(concat_flat)  # (batch_size * num_actions, 1)
        q_values = q_flat.view(batch_size, num_actions)  # (batch_size, num_actions)

        return q_values

In [43]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output one Q-value
        )

    def forward(self, history_vec, actions_vecs):
        """
        Args:
            history_vec: Tensor (batch_size, history_dim)
            actions_vecs: Tensor (batch_size, num_actions, action_dim)

        Returns:
            q_values: Tensor (batch_size, num_actions)
        """
        batch_size, num_actions, action_dim = actions_vecs.shape

        # Encode the history once
        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)
        # Flatten actions and encode them
        actions_flat = actions_vecs.view(-1, action_dim)  # (batch_size * num_actions, action_dim)
        actions_encoded = self.action_encoder(actions_flat)  # (batch_size * num_actions, hidden_dim)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)  # (batch_size, num_actions, hidden_dim)

        # Repeat history for each action
        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)  # (batch_size, num_actions, hidden_dim)

        # Concatenate encoded history and actions
        concat = torch.cat([history_repeated, actions_encoded], dim=-1)  # (batch_size, num_actions, hidden_dim*2)

        # Score each (history, action) pair
        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values


In [68]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 1)
        )

        # Apply custom initialization
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
            nn.init.zeros_(m.bias)

    def forward(self, history_vec, actions_vecs):
        batch_size, num_actions, action_dim = actions_vecs.shape

        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)

        actions_flat = actions_vecs.view(-1, action_dim)
        actions_encoded = self.action_encoder(actions_flat)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)

        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)

        concat = torch.cat([history_repeated, actions_encoded], dim=-1)

        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values

In [653]:
class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),  # <-- new layer
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),  # <-- new layer
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),  # <-- new layer
            nn.Linear(hidden_dim, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
            nn.init.zeros_(m.bias)

    def forward(self, history_vec, actions_vecs):
        batch_size, num_actions, action_dim = actions_vecs.shape

        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)

        actions_flat = actions_vecs.view(-1, action_dim)
        actions_encoded = self.action_encoder(actions_flat)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)

        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)

        concat = torch.cat([history_repeated, actions_encoded], dim=-1)

        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values


In [554]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1)
        )

        # Apply custom initialization
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
            nn.init.zeros_(m.bias)

    def forward(self, history_vec, actions_vecs):
        batch_size, num_actions, action_dim = actions_vecs.shape

        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)

        actions_flat = actions_vecs.view(-1, action_dim)
        actions_encoded = self.action_encoder(actions_flat)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)

        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)

        concat = torch.cat([history_repeated, actions_encoded], dim=-1)

        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values
