In [1]:
import pandas as pd 
from Chord import Chord
from Rule import Rule
from collections import defaultdict
import pitchtypes as pt
from probabilistic_model import ProbabilisticModel

In [2]:
p_model = ProbabilisticModel()

In [3]:
data = pd.read_json('../eda/JazzHarmonyTreebank/treebank.json')
trees_data = data[data['trees'].notna()]['trees']
ccts = [x[0]['complete_constituent_tree'] for x in list(trees_data)]

In [4]:
p_model.fit(ccts)

In [5]:
p_model.prob_dict

{(('child_intervals', ('0', '0')),
  ('child_qualities', ('minor', 'minor')),
  ('parent_quality', 'minor')): 0.4410143329658214,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('sus', 'minor')),
  ('parent_quality', 'minor')): 0.0033076074972436605,
 (('child_intervals', ('10', '0')),
  ('child_qualities', ('sus', 'sus')),
  ('parent_quality', 'sus')): 0.14285714285714285,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'sus')),
  ('parent_quality', 'sus')): 0.09523809523809523,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'minor')),
  ('parent_quality', 'minor')): 0.3263506063947078,
 (('child_intervals', ('1', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.07003089598352215,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.25060075523515274,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'major')),
  ('parent_quali

In [7]:
def parse_subtree(subtree):
    if isinstance(subtree, dict):
        label = Chord(subtree['label'])
        child_intervals = [str(label.distance_to(Chord(x['label']))) for x in subtree['children']]
        child_qualities = [Chord(x['label']).quality for x in subtree['children']]
        rule = Rule(label.quality, child_intervals, child_qualities)
        
        return rule

In [8]:
def make_hashable(d):
    def convert(value):
        if isinstance(value, dict):
            return tuple(sorted((k, convert(v)) for k, v in value.items()))
        elif isinstance(value, list):
            return tuple(convert(v) for v in value)
        else:
            return value

    return tuple(sorted((k, convert(v)) for k, v in d.items()))

In [None]:
count_dict = defaultdict(int)
lhs_count_dict = defaultdict(int)

def parse_leftmost(tree):
    output = []
    steps = []
    stack = [tree]
    parsed_so_far = []

    while stack:
        current = stack.pop(0)
        rule = parse_subtree(current)
        rule_key = rule.make_hashable()
        count_dict[rule_key] += 1
        lhs_count_dict[rule.lhs()] += 1

        steps.append({
            'current': current['label'],
            'parsed_before': parsed_so_far.copy()
        })

        if current.get('children'):
            stack = current['children'] + stack
        else:
            output.append(current['label'])

        parsed_so_far.append(current['label'])

    return steps, count_dict, lhs_count_dict

In [None]:
for cct in ccts:
    steps, count_dict, lhs_count_dict = parse_leftmost(cct)


In [None]:
sorted_dict = dict(sorted(count_dict.items(), key=lambda x: x[1], reverse=False))

In [None]:
lhs_count_dict = dict(sorted(lhs_count_dict.items(), key=lambda x: x[1], reverse=False))
lhs_count_dict 

{'sus': 47, 'unknown': 321, 'minor': 2125, 'major': 5455}

In [None]:
def compute_conditional_probs(count_dict, lhs_count_dict):
    prob_dict = {}
    for rule, count in count_dict.items():
        lhs = rule[2][1]
        if lhs_count_dict[lhs] == 0:
            print(lhs)
        prob_dict[rule] = count / lhs_count_dict[lhs]
    return prob_dict


In [None]:
def get_rule_probability(node, prob_dict):
    subtree = parse_subtree(node)
    rule_key = subtree.make_hashable()
    return prob_dict.get(rule_key, 0)

In [None]:
def compute_tree_probability(tree, prob_dict):
    if not tree.get('children'):
        return 1.0  

    prob = get_rule_probability(tree, prob_dict)

    for child in tree['children']:
        prob *= compute_tree_probability(child, prob_dict)

    return prob
    

In [None]:
for cct in ccts:
    steps, count_dict, lhs_count_dict = parse_leftmost(cct)
    prob_dict = compute_conditional_probs(count_dict, lhs_count_dict)

In [None]:
compute_tree_probability(ccts[3], expanded_prob_dict)

1.8069074877778802e-11

In [11]:
prob_dict = p_model.prob_dict

In [12]:
prob_dict

{(('child_intervals', ('0', '0')),
  ('child_qualities', ('minor', 'minor')),
  ('parent_quality', 'minor')): 0.4410143329658214,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('sus', 'minor')),
  ('parent_quality', 'minor')): 0.0033076074972436605,
 (('child_intervals', ('10', '0')),
  ('child_qualities', ('sus', 'sus')),
  ('parent_quality', 'sus')): 0.14285714285714285,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'sus')),
  ('parent_quality', 'sus')): 0.09523809523809523,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'minor')),
  ('parent_quality', 'minor')): 0.3263506063947078,
 (('child_intervals', ('1', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.07003089598352215,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('major', 'major')),
  ('parent_quality', 'major')): 0.25060075523515274,
 (('child_intervals', ('7', '0')),
  ('child_qualities', ('minor', 'major')),
  ('parent_quali

In [103]:
def expand_prob_dict(prob_dict, default_prob=1e-6):
    INTERVAL_VOCAB = list(str(i) for i in range(0,12))
    QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']

    all_child_intervals = list(product(INTERVAL_VOCAB, repeat=2))
    all_child_qualities = list(product(QUALITY_VOCAB, repeat=2))
    all_parent_qualities = QUALITY_VOCAB

    count_added = 0

    for interval_combo, quality_combo, parent_quality in product(all_child_intervals, all_child_qualities, all_parent_qualities):
        key = Rule(parent_quality, interval_combo, quality_combo).make_hashable()

        if key not in prob_dict:
            prob_dict[key] = default_prob
            count_added += 1

    print(f"Added {count_added} new rules to prob_dict.")
    return prob_dict

In [108]:
len(prob_dict)

9216

In [100]:
Rule(
    'major',
    ['0', '4'],
    ['major', 'minor']
).make_hashable()

(('child_intervals', ('0', '4')),
 ('child_qualities', ('major', 'minor')),
 ('parent_quality', 'major'))

In [106]:
INTERVAL_VOCAB = [str(i) for i in range(12)]
QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']

missing = []

for parent_quality in QUALITY_VOCAB:
    for int1 in INTERVAL_VOCAB:
        for q1 in QUALITY_VOCAB:
            for int2 in INTERVAL_VOCAB:
                for q2 in QUALITY_VOCAB:
                    key = Rule(parent_quality, [int1, int2], [q1, q2]).make_hashable()
                    if key not in prob_dict:
                        missing.append(key)

print(f"Total missing combinations: {len(missing)}")
for m in missing[:10]:  # just show a few
    print(m)

Total missing combinations: 0


In [105]:
prob_dict = expand_prob_dict(prob_dict)

Added 0 new rules to prob_dict.


# RL Model

In [148]:
rules = [Rule.unhash(i) for i in prob_dict.keys()]

In [None]:
class TreeNode:
    def __init__(self, chord):
        self.chord = chord
        self.left = None           # child node
        self.right = None          # child node
        self.parent = None         # to be filled in later

    def is_leaf(self):
        return self.left is None and self.right is None

    def __repr__(self):
        return f"TreeNode(label={self.chord.label})"


def create_parent_node(left, right, parent, parent_quality):
    parent = TreeNode(Chord(f"{parent.chord.label}"))
    parent.left = left
    parent.right = right
    left.parent = parent
    right.parent = parent

    return parent

def get_top_level_nodes(nodes):
    """
    Given a list of TreeNodes (leaves and partial trees), 
    return the current top-level (root) nodes that are not children of any other node.
    """
    return [node for node in nodes if node.parent is None]


In [15]:
def get_possible_intervals(chord1,chord2):

    intervals = []
    intervals.append([str(chord1.distance_to(chord1)), str(chord1.distance_to(chord2))])
    intervals.append([str(chord1.distance_from(chord2)), str(chord2.distance_to(chord2))])

    return intervals

In [17]:
def get_qualities(chord1,chord2):
    return [chord1.quality, chord2.quality]

In [18]:
def get_possible_rhs_from_str(chord1,chord2):
    return [(get_possible_intervals(chord1, chord2)[0], get_qualities(chord1, chord2)), (get_possible_intervals(chord1, chord2)[1], get_qualities(chord1, chord2)),]

In [19]:
def get_applicable_rules(rhs_list, rules):
    applicable_rules = []
    for rule in rules:
        for rhs in rhs_list:
            if rule.rhs() == rhs:
                applicable_rules.append(rule)
    return applicable_rules

In [20]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output one Q-value
        )

    def forward(self, history_vec, actions_vecs):
        """
        Args:
            history_vec: Tensor (batch_size, history_dim)
            actions_vecs: Tensor (batch_size, num_actions, action_dim)

        Returns:
            q_values: Tensor (batch_size, num_actions)
        """
        batch_size, num_actions, action_dim = actions_vecs.shape

        # Encode the history once
        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)

        # Flatten actions and encode them
        actions_flat = actions_vecs.view(-1, action_dim)  # (batch_size * num_actions, action_dim)
        actions_encoded = self.action_encoder(actions_flat)  # (batch_size * num_actions, hidden_dim)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)  # (batch_size, num_actions, hidden_dim)

        # Repeat history for each action
        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)  # (batch_size, num_actions, hidden_dim)

        # Concatenate encoded history and actions
        concat = torch.cat([history_repeated, actions_encoded], dim=-1)  # (batch_size, num_actions, hidden_dim*2)

        # Score each (history, action) pair
        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values


In [157]:
import random
import numpy as np

class Environment:

    MAX_RULES = 10
    MAX_ACTIONS = 10
    MAX_CHORDS = 10

    def __init__(self, chord_sequence):
        self.initial_sequence = chord_sequence
        self.current_state = chord_sequence.copy()
        self.current_nodes = [TreeNode(Chord(chord)) for chord in chord_sequence]
        self.actions = self.get_actions()
        self.applied_rules = []

    # Each index of rhs are a list of applicable rules to currentstate and currentstate + 1
    def get_actions(self):
        rhs_list = []

        for i in range(len(self.current_state) - 1):
            chord1 = Chord(self.current_state[i])
            chord2 = Chord(self.current_state[i + 1])
            possible_rhs = get_possible_rhs_from_str(chord1, chord2)
            rhs_list.append(get_applicable_rules(possible_rhs, rules))


        return rhs_list

    def build_action_index_map(self):

        index_map = []
        for i, rule_list in enumerate(self.get_actions()):
            for j, _ in enumerate(rule_list):
                index_map.append((i, j))
        return index_map

    def apply_rule_index(self,rule_index_i,rule_index_j):
        rule = self.actions[rule_index_i][rule_index_j]
        node1 = self.current_nodes[rule_index_i]
        node2 = self.current_nodes[rule_index_i + 1]
        nodes = [node1,node2]
        parent_root_index = rule.child_intervals.index('0') 
        parent = create_parent_node(node1, node2, nodes[parent_root_index], rule.lhs())
        self.current_nodes[rule_index_i] = parent
        self.current_nodes.pop(rule_index_i + 1)
        self.applied_rules.append(rule)
        self.current_nodes = get_top_level_nodes(self.current_nodes)
        self.current_state = [node.chord.label for node in self.current_nodes]
        self.actions = self.get_actions()

    def apply_rule(self, rule, node1, node2):
        nodes = [node1,node2]
        parent_root_index = rule.child_intervals.index('0') 
        parent = create_parent_node(node1, node2, nodes[parent_root_index], rule.lhs())

        self.current_nodes.append(parent)
        self.applied_rules.append(rule)

        self.current_nodes = get_top_level_nodes(self.current_nodes)
        self.current_state = [node.chord.label for node in self.current_nodes]
        self.actions = self.get_actions()

    def get_state_tensor(self):
        QUALITY_VOCAB = ['major', 'minor', 'sus', 'unknown']
        INTERVAL_VOCAB = [str(i) for i in range(0,12)]

        def one_hot_encode(value, vocab):
            vec = [0] * len(vocab)
            idx = vocab.index(value if value in vocab else vocab.index('Other'))
            vec[idx] = 1
            return vec

        def encode_rule(rule):
            parent_vec = one_hot_encode(rule.parent_quality, QUALITY_VOCAB)

            interval1_vec = one_hot_encode(rule.child_intervals[0], INTERVAL_VOCAB)
            quality1_vec = one_hot_encode(rule.child_qualities[0], QUALITY_VOCAB)

            interval2_vec = one_hot_encode(rule.child_intervals[1], INTERVAL_VOCAB)
            quality2_vec = one_hot_encode(rule.child_qualities[1], QUALITY_VOCAB)

            return parent_vec + interval1_vec + quality1_vec + interval2_vec + quality2_vec

        chord_vector = []
        for chord in self.current_state[:self.MAX_CHORDS]:
            chord_vector.append(Chord.encode_chord(chord))

        # Pad if needed
        while len(chord_vector) < self.MAX_CHORDS:
            chord_vector.append([0] * len(chord_vector[0]))

        flat_chord_vector = [val for chord in chord_vector for val in chord]
        chord_tensor = torch.tensor(flat_chord_vector, dtype=torch.float32).unsqueeze(0)

        actions = self.get_actions()
        action_vector = []

        count = 0
        for i in range(len(actions)):
            for j in range(len(actions[i])):
                if count >= self.MAX_ACTIONS:
                    break
                action_vector.append(encode_rule(actions[i][j]))
                count += 1
            if count >= self.MAX_ACTIONS:
                break

        rule_dim = len(action_vector[0]) if action_vector else 0
        while len(action_vector) < self.MAX_ACTIONS:
            action_vector.append([0] * rule_dim)

        actions_tensor = torch.tensor(action_vector, dtype=torch.float32).unsqueeze(0)
        return chord_tensor, actions_tensor

    def step(self, epsilon=0.1):
        chord_tensor, actions_tensor = self.get_state_tensor()
        results = self.model.forward(chord_tensor, actions_tensor)
        index_map = self.build_action_index_map()[0:10]

        if len(index_map) == 0:
            return None  # no actions possible
        
        if random.random() < epsilon:
            flat_index = random.randint(0, len(index_map) - 1)
        else:
            q_values = self.model(chord_tensor, actions_tensor).detach().squeeze(0)[0:len(index_map)]  # shape: (num_actions,)
            flat_index = q_values.squeeze(0).argmax().item()

        
        i, j = index_map[flat_index]
        self.apply_rule_index(i, j)
        next_chord_tensor, next_actions_tensor = self.get_state_tensor()
        reward = self.evaluate_tree() if self.is_terminal() else -100
        done = self.is_terminal()

        return {
        "state": chord_tensor,
        "actions": actions_tensor,
        "action_index": flat_index,
        "reward": reward,
        "next_state": next_chord_tensor,
        "next_actions": next_actions_tensor,
        "done": done
    }
 

    def add_model(self, model):
        self.model = model

    def is_terminal(self):
        return len(self.current_state) == 1



    def evaluate_tree(self):
        log_probs = 0
        for rule in self.applied_rules:
            log_probs += prob_dict.get(rule.make_hashable(), 1e-10)  # Add a small value to avoid log(0)

        return log_probs



In [135]:
Chord("Bbsus")

Chord(root=A#, quality=sus)

In [22]:
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = []
        self.capacity = capacity

    def add(self, transition):
        self.buffer.append(transition)
        if len(self.buffer) > self.capacity:
            self.buffer.pop(0)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

In [73]:
import torch.nn as nn

def train_on_batch(model, optimizer, batch, gamma=0.99):
    hist_batch = torch.cat([b["state"] for b in batch])
    actions_batch = torch.cat([b["actions"] for b in batch])
    next_hist_batch = torch.cat([b["next_state"] for b in batch])
    next_actions_batch = torch.cat([b["next_actions"] for b in batch])
    
    action_indices = torch.tensor([b["action_index"] for b in batch]).unsqueeze(1)
    rewards = torch.tensor([b["reward"] for b in batch], dtype=torch.float32)
    dones = torch.tensor([b["done"] for b in batch], dtype=torch.bool)

    q_values = model(hist_batch, actions_batch)
    q_chosen = q_values.gather(1, action_indices).squeeze(1)

    with torch.no_grad():
        next_q_values = model(next_hist_batch, next_actions_batch)  
        next_q_max = next_q_values.max(dim=1)[0]
        target_q = rewards + gamma * next_q_max * (~dones)

    loss = nn.MSELoss()(q_chosen, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print(f"\nBatch size: {len(batch)}")
    # print("Sample Q-values:", q_values[0].tolist())
    # print("Chosen Q:", q_chosen[0].item(), "Target:", target_q[0].item())
    # print("Q diff:", (target_q - q_chosen).abs().mean().item())  # ✅ Add here
    # print("Loss:", loss.item())

    return loss.item()

In [74]:
def train_model(model, dataset, num_episodes=1000, batch_size=32, gamma=0.99):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    buffer = ReplayBuffer(capacity=5000)

    for episode in range(num_episodes):
        # Random starting sequence
        chord_seq = random.choice(dataset)
        env = Environment(chord_seq)
        env.add_model(model)
        loss = None
        while not env.is_terminal():
            transition = env.step(epsilon=0.1)
            if transition and transition["next_actions"].numel() > 0:
                buffer.add(transition)

            if len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                loss = train_on_batch(model, optimizer, batch, gamma)

        if episode % 10 == 0:
            if loss:
                print(f"Episode {episode} complete | Last loss: {loss:.4f}")

In [155]:
model = DQN(history_dim=160, action_dim=36)

In [94]:
chords_dataset = list(data['chords'])[0:10]

In [None]:
train_model(model, chords_dataset, num_episodes=1000, batch_size=32, gamma=0.99)

Episode 10 complete | Last loss: 11471.9736
Episode 20 complete | Last loss: 2012.8856
Episode 30 complete | Last loss: 584.3676
Episode 40 complete | Last loss: 680.2872
Episode 50 complete | Last loss: 242.5954
Episode 60 complete | Last loss: 186.5224
Episode 70 complete | Last loss: 1089.6772
Episode 80 complete | Last loss: 409.9654


In [None]:
env.add_model(model)

In [None]:
model = DQN(300,36)

In [154]:
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, history_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()

        self.history_encoder = nn.Sequential(
            nn.Linear(history_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output one Q-value
        )

    def forward(self, history_vec, actions_vecs):
        """
        Args:
            history_vec: Tensor (batch_size, history_dim)
            actions_vecs: Tensor (batch_size, num_actions, action_dim)

        Returns:
            q_values: Tensor (batch_size, num_actions)
        """
        batch_size, num_actions, action_dim = actions_vecs.shape

        # Encode the history once
        history_encoded = self.history_encoder(history_vec)  # (batch_size, hidden_dim)
        # Flatten actions and encode them
        actions_flat = actions_vecs.view(-1, action_dim)  # (batch_size * num_actions, action_dim)
        actions_encoded = self.action_encoder(actions_flat)  # (batch_size * num_actions, hidden_dim)
        actions_encoded = actions_encoded.view(batch_size, num_actions, -1)  # (batch_size, num_actions, hidden_dim)

        # Repeat history for each action
        history_repeated = history_encoded.unsqueeze(1).repeat(1, num_actions, 1)  # (batch_size, num_actions, hidden_dim)

        # Concatenate encoded history and actions
        concat = torch.cat([history_repeated, actions_encoded], dim=-1)  # (batch_size, num_actions, hidden_dim*2)

        # Score each (history, action) pair
        q_values = self.scorer(concat).squeeze(-1)  # (batch_size, num_actions)

        return q_values
