In [4]:
from RewritePuzzleLogic import ExpressionNode

expr1 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(
        operator='*',
        left=ExpressionNode(value=2),
        right=ExpressionNode(value=3)
    )
)

expr2 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(value=2)
)

# Compare with a different expr
expr3 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(
        operator='*',
        left=ExpressionNode(value=3),
        right=ExpressionNode(value=2)
    )
)

# The reason you are getting [(<ExpressionNode object>, None, [])] as the output is that:
# - find_matches returns a list of tuples: (node_matched, parent, path_to_node)
# - In this case, expr1.right is the (2 * 3) subtree, which matches the pattern you created.
# - So the method returns [(expr1.right, None, [])], meaning: the whole expr1.right root node matches, its parent is None (since it's the root in this call), and the path is an empty list (you're at the node itself).
# To make the result more readable, you can pretty-print the outputs:
matches = expr1.find_matches_strict(expr1)
print(matches[0][0])

(1 + (2 * 3))


In [1]:
from RewritePuzzleLogic import ExpressionNode

expr1 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(
        operator='*',
        left=ExpressionNode(value=2),
        right=ExpressionNode(value=3)
    )
)

expr2 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(
        operator='*',
        left=ExpressionNode(value=2),
        right=ExpressionNode(value=3)
    )
)

# Compare the two expressions
print(expr1 == expr2)  # Should print True

# Compare with a different expr
expr3 = ExpressionNode(
    operator='+',
    left=ExpressionNode(value=1),
    right=ExpressionNode(
        operator='*',
        left=ExpressionNode(value=3),
        right=ExpressionNode(value=2)
    )
)
# print(expr1 == expr3)  # Should print False
# for subexpr in expr1.get_all_subexpressions():
#     print(subexpr)

# To match an a * b pattern, we need to build a pattern ExpressionNode with operator '*' and two leaves as wildcards.
# Let's create a simple pattern with any two subtrees; in simplest form for this logic, just two leaf nodes (values can be left None).
# To match expr1.right (which is (2 * 3)), use a pattern for '*' with two wildcard leaves.
pattern = ExpressionNode(operator='*', left=ExpressionNode(), right=ExpressionNode())

# The reason you are getting [(<ExpressionNode object>, None, [])] as the output is that:
# - find_matches returns a list of tuples: (node_matched, parent, path_to_node)
# - In this case, expr1.right is the (2 * 3) subtree, which matches the pattern you created.
# - So the method returns [(expr1.right, None, [])], meaning: the whole expr1.right root node matches, its parent is None (since it's the root in this call), and the path is an empty list (you're at the node itself).
# To make the result more readable, you can pretty-print the outputs:
matches = expr1.find_matches('b + 0')
for node, parent, path in matches:
    print(f"Matched node: {node}, parent: {parent}, path: {path}")

True


In [7]:
from RewritePuzzleLogic import RewriteRule

rule2 = RewriteRule(
    name="commute_add2",
    pattern="a + b",
    replacement="b + a",
    is_commutative=True
)

print(rule2.matches(expr1))

True


In [None]:
from RewritePuzzleLogic import RewriteRule
# Example set of RewriteRule instances for common arithmetic rewrites:

rule1 = RewriteRule(
    name="commute_add",
    pattern="a + b",
    replacement="b + a",
    is_commutative=True
)

rule2 = RewriteRule(
    name="commute_add2",
    pattern="b + c",
    replacement="c + b",
    is_commutative=True
)

# To get the new expression after applying the rule (instead of immediately evaluating it),
# assign the result of .apply() to a variable:
new_expr = rule1.apply(expr1)
print("New expr after applying rule:", new_expr)
# If you want to see its value:
if new_expr is not None:
    print("Evaluated value:", new_expr.evaluate())

# rules = [
#     # a + 0 -> a
#     RewriteRule(
#         name="add_zero_right",
#         pattern="a + 0",
#         replacement=None
#     ),
#     # 0 + a -> a
#     RewriteRule(
#         name="add_zero_left",
#         pattern="0 + a",
#         replacement=None
#     ),
#     # a * 1 -> a
#     RewriteRule(
#         name="mult_one_right",
#         pattern="a * 1",
#         replacement=None
#     ),
#     # 1 * a -> a
#     RewriteRule(
#         name="mult_one_left",
#         pattern="1 * a",
#         replacement=None
#     ),
#     # a + b -> b + a (commutativity)
#     RewriteRule(
#         name="add_commutative",
#         pattern="a + b",
#         replacement="b + a",
#         is_commutative=True
#     ),
#     # a + (b + c) -> (a + b) + c (associativity)
#     RewriteRule(
#         name="add_associative",
#         pattern=ExpressionNode(operator='+',
#                                left=ExpressionNode(),
#                                right=ExpressionNode(operator='+', left=ExpressionNode(), right=ExpressionNode())),
#         replacement="(a + b) + c"
#     ),
#     # a * (b + c) -> a*b + a*c (distribution)
#     RewriteRule(
#         name="distribute",
#         pattern="a * (b + c)",
#         replacement="a*b + a*c"
#     ),
# ]

# # Example usage: print the names of all rules
# for rule in rules:
#     print(f"Rule: {rule.name}")

New expr after applying rule: ((2 * 3) + 1)
Evaluated value: 7


In [27]:
from RewritePuzzleLogic import RewritePuzzleBoard

# Create a simple puzzle board
board = RewritePuzzleBoard("1 + 2 * 3", 7, max_steps=10)

print("Initial expression:", board.current_expr)
print("Goal:", board.goal_expr)
print("Solved?", board.is_solved())
print("Available rules:")
for rule in board.rules:
    print("  -", rule.name)

print("\nValid actions (rule_idx, path):")
valid_actions = board.get_all_valid_actions()
for i, (rule_idx, path) in enumerate(valid_actions):
    print(f"{i}: rule={board.rules[rule_idx].name} at path {path}")

# Apply the first valid action, if any, and print the result
if valid_actions:
    rule_idx, path = valid_actions[0]
    board.apply_action(rule_idx, path)
    print("\nAfter one action:")
    print("  Expression:", board.current_expr)
    print("  Steps taken:", board.steps_taken)
    print("  Solved?", board.is_solved())
else:
    print("No valid actions available.")



Initial expression: ((1.0 + 2.0) * 3.0)
Goal: 7
Solved? False
Available rules:
  - add_zero_left
  - add_zero_right
  - mult_one_left
  - mult_one_right
  - commute_add
  - assoc_add
  - distribute

Valid actions (rule_idx, path):
0: rule=commute_add at path ['left']

After one action:
  Expression: ((2.0 + 1.0) * 3.0)
  Steps taken: 1
  Solved? False


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pretend game with 1D board of length 8 and 5 possible actions
class DummyGame:
    def getBoardSize(self): return (8,)
    def getActionSize(self): return 5

# Minimal args object
class Args:
    num_channels = 64   # hidden width
    dropout = 0.3

from pytorch.RewritePuzzleNNet import RewritePuzzleNNet

game = DummyGame()
args = Args()
net = RewritePuzzleNNet(game, args)

# Batch of 2 board states (shape: batch_size x board_size)
# Example: numbers could be token IDs, features, etc.
boards = torch.tensor([
    [0, 1, 0, 2, 1, 0, 0, 3],
    [1, 0, 1, 0, 2, 2, 1, 0]
], dtype=torch.float32)

log_pi, v = net(boards)
print(game.getBoardSize())
print(net)
print("log_pi shape:", log_pi.shape)  # (2, 5)
print("v shape:", v.shape)            # (2, 1)
print("log_pi row 0 (per-action log-probs):", log_pi[0])
print("v row 0 (value):", v[0])


(8,)
RewritePuzzleNNet(
  (fc1): Linear(in_features=8, out_features=64, bias=True)
  (fc_bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc_bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc_bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=64, out_features=64, bias=True)
  (fc_bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc5): Linear(in_features=64, out_features=512, bias=True)
  (fc_bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc6): Linear(in_features=512, out_features=256, bias=True)
  (fc_bn6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_pi): Linear(in_features=256, out_features=5, bias=True

ImportError: attempted relative import with no known parent package

# Testing MCTC module
