In this notebook I want to get an overview of which functions and constructs, other than the primitive functions defined in the ReGAL paper, are used in the LOGO domain programs.

In [None]:
import json

# Load the train and test datasets (JSON Lines: one JSON object per line).
# JSON text is UTF-8, so decode explicitly instead of relying on the platform's
# default encoding. Blank lines are skipped so that a stray empty line does not
# abort the whole load with a JSONDecodeError.
with open("external/dependencies/logo_data/python/train_200_dataset.jsonl", 'r', encoding="utf-8") as f:
    train_data = [json.loads(line) for line in f if line.strip()]

with open("external/dependencies/logo_data/python/test_dataset.jsonl", 'r', encoding="utf-8") as f:
    test_data = [json.loads(line) for line in f if line.strip()]

In [None]:
import ast

def extract_non_logo_expressions(data, primitives):
    """
    Collect the non-primitive call names and loop usage found in the 'gpt' code.

    Every message whose 'from' field is 'gpt' has its 'value' parsed as Python
    source. A call to a plain name (``foo(...)``) that is not listed in
    ``primitives`` is recorded under that name, and any ``for`` statement is
    recorded as the marker ``"for-loop"``. Calls made through an attribute or
    another expression (``obj.method(...)``) are not recorded. Snippets that
    cannot be parsed are skipped.

    Args:
        data (list): List of dictionaries, each optionally holding a 'messages'
            list of dictionaries with 'from' and 'value' keys.
        primitives (list): LOGO primitive names to exclude.

    Returns:
        list: Unique non-LOGO primitive expressions, sorted alphabetically so
            the result is reproducible between runs.
    """
    # Build the lookup set once instead of scanning the list for every call node.
    excluded = set(primitives)
    expressions = set()

    for item in data:
        for message in item.get('messages') or []:
            if message.get('from') != 'gpt':
                continue

            code = message.get('value')
            if not isinstance(code, str):
                # Malformed message without source text: nothing to analyse.
                continue

            try:
                tree = ast.parse(code)
            except (SyntaxError, ValueError):
                # Older Python versions raise ValueError (not SyntaxError) when
                # the source contains NUL bytes; treat both as "unparsable".
                continue

            # ast.walk is iterative, so deeply nested programs cannot exhaust
            # the Python call stack the way a recursive visitor could.
            for node in ast.walk(tree):
                if isinstance(node, ast.Call):
                    func = node.func
                    if isinstance(func, ast.Name) and func.id not in excluded:
                        expressions.add(func.id)
                elif isinstance(node, ast.For):
                    expressions.add("for-loop")

    # Sets of strings have no stable order across interpreter runs; sort so the
    # output is deterministic.
    return sorted(expressions)

# Names of the LOGO primitive functions as given in the ReGAL paper
logo_primitives = [
    'forward',
    'left',
    'right',
    'penup',
    'pendown',
    'teleport',
    'heading',
    'isdown',
    'embed',
]

In [6]:
# Run the extraction on both splits first, then print each result on its own line.
train_result, test_result = [
    extract_non_logo_expressions(split, logo_primitives)
    for split in (train_data, test_data)
]

for found in (train_result, test_result):
    print(found)

['range', 'locals', 'for-loop']
['range', 'locals', 'for-loop']
