In [1]:
from typing import List, Dict

from mathgap.logicalforms.logicalform import LogicalForm
from mathgap.natlang.templates.template import WHITESPACE
from mathgap.trees.generators import MultiGenerator
from mathgap.trees.sampling.order import OrderSampler
from mathgap.generation_util import *

from data.util import DATA_FOLDER

from mathgap.instantiate.instantiation import Instantiation
from mathgap.properties import PropertyKey
from mathgap.natlang.templates import WHITESPACE, render_problem
from mathgap.natlang.templates.templaterenderer import TemplateRenderer
from templaterendering import interleaved_template_selections, render_interleaved, traverse_reasoning_trace_multitree, interleaved_template_selection_rt, render_interleaved_rt

from generationutil import get_next_problem


In [2]:
min_depth = 2
max_depth = 5
data_folder = "data"
nr_problems = 10
num_irrelevant_trees = 3
seed = 14

In [3]:
weights_by_generator_base = {
    default_generator(use_attribute=False, use_unit=True, 
                        comp_same_entity_prob=1.0, compeq_same_entity_prob=1.0,
                        stopping_criterion=BranchDepthCriterion(i), 
                        start_types=CONT_START_TYPE, inference_rules=COMP_RULESET): 1.0
    for i in range(min_depth, max_depth+1)
}
base_generator = MultiGenerator(weights_by_generator_base)
irrelevant_generator = base_generator # draw from same distribution for now

instantiator = default_instantiator(data_folder=data_folder, dataversion="v1", leaf_min_value=2, leaf_max_value=20, 
                                    inner_min_value=2, inner_max_value=10_000)

problem_order_sampler = CanonicalOrderSampler()
template_renderer = TemplateRenderer()
ps_template_sampler, ps_answers_template_sampler, ps_renderer, rt_template_sampler, rt_renderer \
    = default_templates_and_samplers(data_folder, "v1", WHITESPACE)

In [4]:
get_next_problem(base_generator, irrelevant_generator, instantiator, num_irrelevant_trees, problem_order_sampler, ps_template_sampler, ps_renderer, rt_template_sampler, rt_renderer, template_renderer, seed=seed)

{'disconnected': {'type': 'entities',
  'simple': {'problem': [{'text': 'Mason has 18 kilograms of gold.',
     'inst': {'agents': ['Mason'],
      'attributes': [],
      'entities': ['gold'],
      'units': ['kilogram'],
      'quantities': ['18'],
      'equations': []},
     'type': 'statement',
     'relevant': True},
    {'text': ' Christian has 9 kilograms of gold more than Mason.',
     'inst': {'agents': ['Christian', 'Mason'],
      'attributes': [],
      'entities': ['gold'],
      'units': ['kilogram'],
      'quantities': ['9'],
      'equations': []},
     'type': 'statement',
     'relevant': True},
    {'text': ' Liam has 3 kilograms of gold more than Ava.',
     'inst': {'agents': ['Liam', 'Ava'],
      'attributes': [],
      'entities': ['gold'],
      'units': ['kilogram'],
      'quantities': ['3'],
      'equations': []},
     'type': 'statement',
     'relevant': False},
    {'text': ' Christian has 5 kilograms of gold fewer than James.',
     'inst': {'agents':

In [None]:


# 1. generate relevant and irrelevant trees
base_tree = base_generator.generate(seed=seed)
irrelevant_trees = [irrelevant_generator.generate(seed=seed+i+1) for i in range(num_irrelevant_trees)]

# 2. instantiate the base problem
base_instantiation = instantiator.instantiate(base_tree, seed=seed)

In [5]:
overlap = "none"#random.choice(["entities", "none"]) # TODO: support agents

# 3.a instantiate the irrelevant problems with the same entities but different agents
irrelevant_instantiations = []

taken_agents = set(base_instantiation.get_instantiations_of_type(PropertyType.AGENT).values())
taken_entities = set(base_instantiation.get_instantiations_of_type(PropertyType.ENTITY).values())

for irrelevant_tree in irrelevant_trees:
    irrelevant_instantiation = Instantiation({})

    # avoid agent overlap by adding all agents as pseudo-entries to the instantiation
    for i,agent_name in enumerate(taken_agents, start=1):
        irrelevant_instantiation[PropertyKey(PropertyType.AGENT, -i)] = agent_name
    
    # avoid entity overlap by adding all entities as pseudo-entries to the instantiation
    for i,entity_name in enumerate(taken_entities, start=1):
        irrelevant_instantiation[PropertyKey(PropertyType.ENTITY, -i)] = entity_name

    if overlap == "entities":
        base_entities = [base_instantiation[PropertyKey(PropertyType.ENTITY, pid)] for pid in base_tree.property_tracker.get_by_type(PropertyType.ENTITY)]
        for p in irrelevant_tree.property_tracker.get_by_type(PropertyType.ENTITY):
            irrelevant_instantiation[PropertyKey(PropertyType.ENTITY, p)] = random.choice(base_entities)
    elif overlap == "agents":
        raise NotImplementedError("need to pick agents from original tree but entities that are related to their respective items")
        
    # finish instantiating
    irrelevant_instantiation = instantiator.instantiate(irrelevant_tree, irrelevant_instantiation, skip_existing=True, seed=seed)
    irrelevant_instantiations.append(irrelevant_instantiation)
    
    taken_agents = taken_agents.union(irrelevant_instantiation.get_instantiations_of_type(PropertyType.AGENT).values())
    taken_entities = taken_entities.union(irrelevant_instantiation.get_instantiations_of_type(PropertyType.ENTITY).values())


In [6]:
# 4. render the trees in an interleaved manner (likely best to uniformly distribute the axioms)
# 4.1 simple case
irrelevant_tree, irrelevant_instantiation = random.choice(list(zip(irrelevant_trees, irrelevant_instantiations)))
problem_order_base = problem_order_sampler.sample_order(base_tree, seed=seed)
problem_base,problem_base_meta = render_problem(base_tree, base_instantiation, problem_order_base, ps_template_sampler, ps_renderer, seed=seed)

problem_order_irrelevant = problem_order_sampler.sample_order(irrelevant_tree, seed=seed)
problem_order_irrelevant.body_node_ids = [random.choice(problem_order_irrelevant.body_node_ids)]

interleaved_ts = interleaved_template_selections([
    (base_tree, base_instantiation, problem_order_base, ps_template_sampler, True),
    (irrelevant_tree, irrelevant_instantiation, problem_order_irrelevant, ps_template_sampler, False)
], seed=seed)

order = [(tree,ts.primary_node_id) for tree,inst,ts in interleaved_ts]
rt_interleaved, rt_interleaved_meta = render_interleaved_rt(template_renderer, interleaved_template_selection_rt([base_tree, irrelevant_tree], [base_instantiation, irrelevant_instantiation], order, rt_template_sampler.sampler, seed=seed), eods_separator=WHITESPACE) 

problem_interleaved,problem_interleaved_meta = render_interleaved(template_renderer, interleaved_ts)
print(problem_base)
print(problem_interleaved)
print(rt_interleaved)

Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. Christian has 5 kilograms of gold fewer than James. How many kilograms of gold does James have?
Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. Mila has 4 grams of butter less than Grace. Christian has 5 kilograms of gold fewer than James. How many kilograms of gold does James have?
Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. So Christian has 18 + 9 = 27 kilograms of gold. Mila has 4 grams of butter fewer than Grace. Christian has 5 kilograms of gold less than James. So James has 27 + 5 = 32 kilograms of gold. 


In [7]:
# 4.2 complex case
irrelevant_tree, irrelevant_instantiation = random.choice(list(zip(irrelevant_trees, irrelevant_instantiations)))
problem_order_base = problem_order_sampler.sample_order(base_tree, seed=seed)
problem_base,problem_base_meta = render_problem(base_tree, base_instantiation, problem_order_base, ps_template_sampler, ps_renderer, seed=seed)

problem_order_irrelevant = problem_order_sampler.sample_order(irrelevant_tree, seed=seed)

interleaved_ts = interleaved_template_selections([
    (base_tree, base_instantiation, problem_order_base, ps_template_sampler, True),
    (irrelevant_tree, irrelevant_instantiation, problem_order_irrelevant, ps_template_sampler, False)
], seed=seed)

order = [(tree,ts.primary_node_id) for tree,inst,ts in interleaved_ts]
rt_interleaved, rt_interleaved_meta = render_interleaved_rt(template_renderer, interleaved_template_selection_rt([base_tree, irrelevant_tree], [base_instantiation, irrelevant_instantiation], order, rt_template_sampler.sampler, seed=seed), eods_separator=WHITESPACE) 

problem_interleaved,problem_interleaved_meta = render_interleaved(template_renderer, interleaved_ts)
print(problem_base)
print(problem_interleaved)
print(rt_interleaved)

Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. Christian has 5 kilograms of gold fewer than James. How many kilograms of gold does James have?
Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. Harper has 19 liters of milk. Harper has 10 liters of milk more than Elijah. Logan has 7 liters of milk fewer than Elijah. Christian has 5 kilograms of gold fewer than James. Alexander has 6 liters of milk more than Logan. How many kilograms of gold does James have?
Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. So Christian has 18 + 9 = 27 kilograms of gold. Harper has 19 liters of milk. Harper has 10 liters of milk more than Elijah. So Elijah has 19 - 10 = 9 liters of milk. Logan has 7 liters of milk less than Elijah. So Logan has 9 - 7 = 2 liters of milk. Christian has 5 kilograms of gold less than James. So James has 27 + 5 = 32 kilograms of gold. Alexander has 6 liters of milk more than L

In [8]:
# 4.3 more complex case
problem_order_base = problem_order_sampler.sample_order(base_tree, seed=seed)
problem_base,problem_base_meta = render_problem(base_tree, base_instantiation, problem_order_base, ps_template_sampler, ps_renderer, seed=seed)

interleaved_ts = interleaved_template_selections([
    (base_tree, base_instantiation, problem_order_base, ps_template_sampler, True),
    *[(it, ii, problem_order_sampler.sample_order(it, seed=seed), ps_template_sampler, False) for it,ii in zip(irrelevant_trees, irrelevant_instantiations)]
], seed=seed)

problem_interleaved,problem_interleaved_meta = render_interleaved(template_renderer, interleaved_ts)

order = [(tree,ts.primary_node_id) for tree,inst,ts in interleaved_ts]
all_trees = [base_tree, *irrelevant_trees]
all_instantiations = [base_instantiation, *irrelevant_instantiations]
rt_interleaved, rt_interleaved_meta = render_interleaved_rt(template_renderer, interleaved_template_selection_rt(all_trees, all_instantiations, order, rt_template_sampler.sampler, seed=seed), eods_separator=WHITESPACE) 

print(problem_base)
print(problem_interleaved)
print(rt_interleaved)

Mason has 18 kilograms of gold. Christian has 9 kilograms of gold more than Mason. Christian has 5 kilograms of gold fewer than James. How many kilograms of gold does James have?
Mason has 18 kilograms of gold. Charlotte has 18 grams of butter. Harper has 19 liters of milk. Harper has 10 liters of milk more than Elijah. Logan has 7 liters of milk fewer than Elijah. Christian has 9 kilograms of gold more than Mason. Evelyn has 10 acres of farmland. Alexander has 6 liters of milk more than Logan. Nicholas has 7 acres of farmland less than Evelyn. Nicholas has 19 acres of farmland fewer than Amelia. Hannah has 7 acres of farmland fewer than Amelia. Christian has 5 kilograms of gold fewer than James. Mia has 6 acres of farmland more than Hannah. Charlotte has 8 grams of butter more than Grace. Mila has 4 grams of butter fewer than Grace. Mila has 12 grams of butter fewer than Natalie. Liam has 12 grams of butter less than Natalie. Liam has 3 grams of butter more than Ava. How many kilogram

In [10]:
extract_per_sent_data(rt_interleaved, rt_interleaved_meta, base_tree, irrelevant_trees, base_instantiation, irrelevant_instantiations)

[{'text': 'Mason has 18 kilograms of gold.',
  'inst': {'agents': ['Mason'],
   'attributes': [],
   'entities': ['gold'],
   'units': ['kilogram'],
   'quantities': ['18'],
   'equations': []},
  'type': 'statement',
  'relevant': True},
 {'text': ' Charlotte has 18 grams of butter.',
  'inst': {'agents': ['Charlotte'],
   'attributes': [],
   'entities': ['butter'],
   'units': ['gram'],
   'quantities': ['18'],
   'equations': []},
  'type': 'statement',
  'relevant': False},
 {'text': ' Harper has 19 liters of milk.',
  'inst': {'agents': ['Harper'],
   'attributes': [],
   'entities': ['milk'],
   'units': ['liter'],
   'quantities': ['19'],
   'equations': []},
  'type': 'statement',
  'relevant': False},
 {'text': ' Harper has 10 liters of milk more than Elijah.',
  'inst': {'agents': ['Harper', 'Elijah'],
   'attributes': [],
   'entities': ['milk'],
   'units': ['liter'],
   'quantities': ['10'],
   'equations': []},
  'type': 'statement',
  'relevant': False},
 {'text': ' So 