# Forward chaining

## Explanation

This section is an explanation of the system you'll be working with. There aren't any problems to solve. Read it carefully anyway.

This problem set will make use of a *production rule system*. The system is given a list of rules and a list of data. The rules look for certain things in the data -- these things are the *antecedents* of the rules -- and usually produce a new piece of data, called the *consequent*. Rules can also delete existing data.

Importantly, rules can contain variables. This allows a rule to match more than one possible datum. The consequent can contain variables that were bound in the antecedent.
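
To make variable binding concrete, here is a minimal sketch, independent of the code later in this notebook, of one way a pattern with `(?x)`-style variables could be matched against a datum and then used to fill in a consequent. The helper names `bind_pattern` and `fill_in` are hypothetical, and both the regular-expression translation and the restriction of variable names to identifier-like words are assumptions made for illustration.

```python
import re

# Matches a variable such as "(?x)" and captures its name.
VARIABLE = re.compile(r'\(\?(\w+)\)')

def bind_pattern(pattern, datum):
    """Return {variable: value} if pattern matches datum, else None."""
    regex, pos, seen = '', 0, set()
    for m in VARIABLE.finditer(pattern):
        regex += re.escape(pattern[pos:m.start()])   # literal text
        name = m.group(1)
        if name in seen:
            regex += '(?P=%s)' % name                # repeat: same value again
        else:
            regex += r'(?P<%s>\S+)' % name           # first use: capture one token
            seen.add(name)
        pos = m.end()
    regex += re.escape(pattern[pos:])
    found = re.fullmatch(regex, datum)
    return found.groupdict() if found else None

def fill_in(template, bindings):
    """Replace each variable in template with its bound value."""
    return VARIABLE.sub(lambda m: bindings[m.group(1)], template)

bindings = bind_pattern('parent (?x) (?y)', 'parent marge bart')
print(bindings)                                      # {'x': 'marge', 'y': 'bart'}
print(fill_in('(?y) is a child of (?x)', bindings))  # bart is a child of marge
```

A pattern that does not fit the datum, such as `'parent (?x) (?y)'` against `'male bart'`, yields `None` in this sketch.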

A rule is an expression that contains certain keywords, like IF, THEN, AND, OR, and NOT. An example of a rule looks like this:

     IF( AND( 'parent (?x) (?y)',
              'parent (?x) (?z)' ),
         THEN( 'sibling (?y) (?z)' ))

This could be taken to mean:

   If *x* is the parent of *y*, and *x* is the parent of *z*, then *y* is the sibling of *z*.
   
Given data that look like '`parent marge bart`' and '`parent marge lisa`', then, it will produce further data like '`sibling bart lisa`'. (It will also produce '`sibling bart bart`', which is something that will need to be dealt with.)

Of course, the rule system doesn't know what these arbitrary words "parent" and "sibling" mean! It doesn't even care that they're at the beginning of the expression. The rule could also be written like this:

    IF (AND( '(?x) is a parent of (?y)',
             '(?x) is a parent of (?z)' ),
        THEN( '(?y) is a sibling of (?z)' ))

Then it will expect its data to look like '`marge is a parent of lisa`'. This gets wordy and
involves some unnecessary matching of symbols like '`is`' and '`a`', and it doesn't help anything for
this problem, but we'll write some later rule systems in this English-like way for clarity.

Just remember that the English is for you to understand, not the computer.
    
## Rule expressions

Here's a more complete description of how the system works.

The rules are given in a specified order, and the system will check each rule in turn: for each rule, it will go through all the data searching for matches to that rule's antecedent, before moving on to the next rule.

A rule is an expression with an IF antecedent and a THEN consequent; both of these parts are required. Optionally, a rule can also have a DELETE clause, which specifies some data to delete.

The IF antecedent can contain AND, OR, and NOT expressions. AND requires that all of its statements are matched in the dataset, OR requires that at least one of its statements is matched in the dataset, and NOT requires that a statement is *not* matched in the dataset. AND, OR, and NOT expressions can be nested within each other. When nested like this, these expressions form an AND/OR tree (or really an AND/OR/NOT tree). At the bottom of this tree are strings, possibly with variables in them.
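
As a rough illustration of such a tree, the sketch below represents nested expressions as plain tuples and checks them against a set of variable-free facts. The tuple encoding and the function `tree_holds` are hypothetical and are not the representation the rule system itself uses; variables are left out to keep the tree structure visible.

```python
def tree_holds(node, facts):
    """Evaluate an AND/OR/NOT tree whose leaves are plain strings."""
    if isinstance(node, str):              # leaf: is the fact present?
        return node in facts
    op, *children = node
    if op == 'AND':
        return all(tree_holds(c, facts) for c in children)
    if op == 'OR':
        return any(tree_holds(c, facts) for c in children)
    if op == 'NOT':                        # unary: exactly one child
        (child,) = children
        return not tree_holds(child, facts)
    raise ValueError('unknown operator: %r' % (op,))

facts = {'opus is a bird', 'opus swims'}
tree = ('AND', 'opus is a bird',
               ('OR', 'opus flies', 'opus swims'),
               ('NOT', 'opus is a fish'))
print(tree_holds(tree, facts))   # True
```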

The data are searched for items that match the requirements of the antecedent. Data items that appear earlier in the data take precedence. Each pattern in an AND clause will match the data in order, so that later ones have the variables of the earlier ones.
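
The left-to-right behaviour of AND can be sketched as a recursive search that carries the bindings found so far into each later pattern. This is a simplified illustration, not the system's implementation: `match_one` and `match_all` are hypothetical names, and the sketch assumes that every variable is a whole whitespace-separated token.

```python
def match_one(pattern, datum, bindings):
    """Extend bindings so that pattern equals datum, or return None."""
    p_tokens, d_tokens = pattern.split(), datum.split()
    if len(p_tokens) != len(d_tokens):
        return None
    result = dict(bindings)
    for p, d in zip(p_tokens, d_tokens):
        if p.startswith('(?') and p.endswith(')'):
            name = p[2:-1]
            if result.setdefault(name, d) != d:   # clashes with an earlier binding
                return None
        elif p != d:                              # literal words must be equal
            return None
    return result

def match_all(patterns, data, bindings=None):
    """Yield every set of bindings that satisfies all patterns, in order."""
    bindings = bindings or {}
    if not patterns:
        yield bindings
        return
    for datum in data:                            # earlier data are tried first
        extended = match_one(patterns[0], datum, bindings)
        if extended is not None:
            yield from match_all(patterns[1:], data, extended)

data = ['parent marge bart', 'parent marge lisa']
for b in match_all(['parent (?x) (?y)', 'parent (?x) (?z)'], data):
    print('sibling %s %s' % (b['y'], b['z']))
```

Because the second pattern is free to match the same datum as the first, the output of this sketch includes `sibling bart bart` and `sibling lisa lisa` alongside the two useful results.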

If there is a NOT clause in the antecedent, the data are searched to make sure that no items in the data match the pattern. A NOT clause should not introduce new variables - the matcher won't know what to do with them. Generally, NOT clauses will appear inside an AND clause, and earlier parts of the AND clause will introduce the variables. For example, this clause will match objects that are asserted to be birds, but are not asserted to be penguins:

    AND( '(?x) is a bird',
         NOT( '(?x) is a penguin' ))

The other way around won't work:
        
    AND( NOT( '(?x) is a penguin' ), # don't do this!
         '(?x) is a bird' )
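
To see why the order matters, consider this small sketch of a NOT check. It is a hypothetical helper, `not_clause_holds`, not the system's own code, and it assumes that a variable left unbound inside a NOT pattern behaves as a wildcard that matches any single token.

```python
import re

def not_clause_holds(pattern, data, bindings):
    """True if no datum matches pattern once known bindings are filled in."""
    # re.split with a capture group alternates literal text and variable names.
    pieces = re.split(r'\(\?(\w+)\)', pattern)
    regex = ''
    for i, piece in enumerate(pieces):
        if i % 2 == 0:
            regex += re.escape(piece)             # literal text
        elif piece in bindings:
            regex += re.escape(bindings[piece])   # bound variable: fixed value
        else:
            regex += r'\S+'                       # unbound variable: wildcard
    return not any(re.fullmatch(regex, d) for d in data)

data = ['tweety is a bird', 'opus is a bird', 'opus is a penguin']

# Correct order: the bird pattern has already bound x.
print(not_clause_holds('(?x) is a penguin', data, {'x': 'tweety'}))  # True
print(not_clause_holds('(?x) is a penguin', data, {'x': 'opus'}))    # False

# Wrong order: x is still unbound, so the pattern matches opus and the
# NOT fails for everyone, tweety included.
print(not_clause_holds('(?x) is a penguin', data, {}))               # False
```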
         
The terms **match** and **fire** are important. A rule **matches** if its antecedent matches the existing data. A rule that matches **fires** only if its THEN or DELETE clauses actually change the data. (If they would leave the data unchanged, the rule fails to fire.)

Only one rule can fire at a time. When a rule successfully fires, the system changes the data appropriately, and then starts again from the first rule. This lets earlier rules take precedence over later ones. (In other systems, the precedence order of rules can be defined differently.)
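
This control strategy, fire one rule and then restart from the first, can be sketched in a few lines. The sketch is an illustration rather than this notebook's `forward_chain`: it assumes each rule is a hypothetical Python function that takes the current data and returns the new data, or the same data when it cannot fire.

```python
def run_rules(rules, data):
    """Fire one rule at a time, restarting from the first rule after each firing."""
    data = frozenset(data)
    while True:
        for rule in rules:
            new_data = frozenset(rule(data))
            if new_data != data:      # the rule fired: it changed the data
                data = new_data
                break                 # restart from the first rule
        else:
            return data               # no rule fired: we are done

# Two toy rules over variable-free facts (hypothetical examples).
def hair_means_mammal(data):
    return data | {'rex is a mammal'} if 'rex has hair' in data else data

def mammal_eating_meat_is_carnivore(data):
    if {'rex is a mammal', 'rex eats meat'} <= data:
        return data | {'rex is a carnivore'}
    return data

result = run_rules([mammal_eating_meat_is_carnivore, hair_means_mammal],
                   ['rex has hair', 'rex eats meat'])
print(sorted(result))
# ['rex eats meat', 'rex has hair', 'rex is a carnivore', 'rex is a mammal']
```

On the last pass both rules still match but neither changes the data, so neither fires and the loop stops.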

## Running the system

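The procedure
`forward_chain(rules, data, apply_only_one=False, verbose=False)` will make inferences as described
above. It returns the final state of its input data. By default (`apply_only_one=False`), a rule that fires does so for all possible bindings of its variables at once; pass `apply_only_one=True` to stop after a single change, which is slower.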


# Utils

In [None]:
from collections.abc import MutableMapping
import re

class ClobberedDictKey(Exception):
    "A flag that a variable has been assigned two incompatible values."
    pass

class NoClobberDict(MutableMapping):
    """
    A dictionary-like object that prevents its values from being
    overwritten by different values. If that happens, it indicates a
    failure to match.
    """
    def __init__(self, initial_dict = None):
        if initial_dict is None:
            self._dict = {}
        else:
            self._dict = dict(initial_dict)

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, value):
        if key in self._dict and self._dict[key] != value:
            raise ClobberedDictKey(key, value)

        self._dict[key] = value

    def __delitem__(self, key):
        del self._dict[key]

    def __contains__(self, key):
        return self._dict.__contains__(key)

    def __iter__(self):
        return self._dict.__iter__()

    def __len__(self):
        return self._dict.__len__()

    def iteritems(self):
        # dict.iteritems() no longer exists in Python 3
        return iter(self._dict.items())

    def keys(self):
        return self._dict.keys()

# A regular expression for finding variables.
AIRegex = re.compile(r'\(\?(\S+)\)')

def AIStringToRegex(AIStr):
    return AIRegex.sub( r'(?P<\1>\\S+)', AIStr )+'$'

def AIStringToPyTemplate(AIStr):
    return AIRegex.sub( r'%(\1)s', AIStr )

def AIStringVars(AIStr):
    # This is not the fastest way of doing things, but
    # it is probably the most explicit and robust
    return set([ AIRegex.sub(r'\1', x) for x in AIRegex.findall(AIStr) ])


### Production


In [None]:
import re
try:
    set()
except NameError:
    from sets import Set as set, ImmutableSet as frozenset

try:
    sorted([])
except NameError:
    def sorted(lst):
        new_lst = list(lst)
        new_lst.sort()
        return new_lst


### We've tried to keep the functions you will need for
### back-chaining at the top of this file. Keep in mind that you
### can get at this documentation from a Python prompt:
###
### >>> import production
### >>> help(production)

def forward_chain(rules, data, apply_only_one=False, verbose=False):
    """
    Apply a list of IF-expressions (rules) through a set of data
    in order.  Return the modified data set that results from the
    rules.

    Set apply_only_one=True to get the behavior we describe in
    class.  When it's False, a rule that fires will do so for
    _all_ possible bindings of its variables at the same time,
    making the code considerably more efficient. In the end, only
    DELETE rules will act differently.
    """
    old_data = ()

    while set(old_data) != set(data):
        old_data = list(data)
        for condition in rules:
            data = condition.apply(data, apply_only_one, verbose)
            if set(data) != set(old_data):
                break

    return data

def instantiate(template, values_dict):
    """
    Given an expression ('template') with variables in it,
    replace those variables with values from values_dict.

    For example:
    >>> instantiate("sister (?x) (?y)", {'x': 'Lisa', 'y': 'Bart'})
    => "sister Lisa Bart"
    """
    if (isinstance(template, AND) or isinstance(template, OR) or
        isinstance(template, NOT)):

        return template.__class__(*[populate(x, values_dict)
                                    for x in template])
    elif isinstance(template, str):
        return AIStringToPyTemplate(template) % values_dict
    else: raise ValueError("Don't know how to populate a %s" % \
      type(template))

# alternate name for instantiate
populate = instantiate

def match(template, AIStr):
    """
    Given two strings, 'template': a string containing variables
    of the form '(?x)', and 'AIStr': a string that 'template'
    matches, with certain variable substitutions.

    Returns a dictionary of the set of variables that would need
    to be substituted into template in order to make it equal to
    AIStr, or None if no such set exists.
    """
    try:
        return re.match( AIStringToRegex(template),
                         AIStr ).groupdict()
    except AttributeError: # The re.match() expression probably
                           # just returned None
        return None

def is_variable(myStr):
    """Is 'myStr' a variable, of the form '(?x)'?"""
    # re.search needs both the regex and the string to search
    return isinstance(myStr, str) and myStr[0] == '(' and \
      myStr[-1] == ')' and \
      re.search( AIStringToRegex(myStr), myStr ) is not None

def variables(exp):
    """
    Return a dictionary containing the names of all variables in
    'exp' as keys, or None if there are no such variables.
    """
    try:
        return re.search( AIStringToRegex(exp), exp ).groupdict()
    except AttributeError: # The re.match() expression probably
                           # just returned None
        return None

class IF(object):
    """
    A conditional rule.

    This should have the form IF( antecedent, THEN(consequent) ),
    or IF( antecedent, THEN(consequent), DELETE(delete_clause) ).

    The antecedent is an expression or AND/OR tree with variables
    in it, determining under what conditions the rule can fire.

    The consequent is an expression or list of expressions that
    will be added when the rule fires. Variables can be filled in
    from the antecedent.

    The delete_clause is an expression or list of expressions
    that will be deleted when the rule fires. Again, variables
    can be filled in from the antecedent.
    """
    def __init__(self, conditional, action = None,
                 delete_clause = ()):
        # Deal with an edge case imposed by type_encode()
        if type(conditional) == list and action is None:
            # Python 3 has no built-in apply(); unpack the list instead
            return self.__init__(*conditional)

        # Allow 'action' to be either a single string or an
        # iterable list of strings
        if isinstance(action, str):
            action = [ action ]

        self._conditional = conditional
        self._action = action
        self._delete_clause = delete_clause

    def apply(self, rules, apply_only_one=False, verbose=False):
        """
        Return a new set of data updated by the conditions and
        actions of this IF statement.

        If 'apply_only_one' is True, after adding one datum,
        return immediately instead of continuing. This is the
        behavior described in class, but it is slower.
        """
        new_rules = set(rules)
        old_rules_count = len(new_rules)
        bindings = RuleExpression().test_term_matches(
            self._conditional, new_rules)

        for k in bindings:
            for a in self._action:
                new_rules.add( populate(a, k) )
                if len(new_rules) != old_rules_count:
                    if verbose:
                        print("Rule:", self)
                        print("Added:", populate(a, k))
                    if apply_only_one:
                        return tuple(sorted(new_rules))
            for d in self._delete_clause:
                try:
                    new_rules.remove( populate(d, k) )
                    if len(new_rules) != old_rules_count:
                        if verbose:
                            print("Rule:", self)
                            print("Deleted:", populate(d, k))
                        if apply_only_one:
                            return tuple(sorted(new_rules))
                except KeyError:
                    pass

        return tuple(sorted(new_rules)) # Uniquify and sort the
                                        # output list


    def __str__(self):
        return "IF(%s, %s)" % (str(self._conditional),
                               str(self._action))

    def antecedent(self):
        return self._conditional

    def consequent(self):
        return self._action

    __repr__ = __str__

class RuleExpression(list):
    """
    The parent class of AND, OR, and NOT expressions.

    Just like Sums and Products from lab 0, RuleExpressions act
    like lists wherever possible. For convenience, you can leave
    out the brackets when initializing them: AND([1, 2, 3]) ==
    AND(1, 2, 3).
    """
    def __init__(self, *args):
        if (len(args) == 1 and isinstance(args[0], list)
            and not isinstance(args[0], RuleExpression)):
            args = args[0]
        list.__init__(self, args)

    def conditions(self):
        """
        Return the conditions contained by this
        RuleExpression. This is the same as converting it to a
        list.
        """
        return list(self)

    def __str__(self):
        return '%s(%s)' % (self.__class__.__name__,
                           ', '.join([repr(x) for x in self]) )

    __repr__ = __str__

    def test_term_matches(self, condition, rules,
                          context_so_far = None):
        """
        Given an expression which might be just a string, check
        it against the rules.
        """
        rules = set(rules)
        if context_so_far == None: context_so_far = {}

        # Deal with nesting first If we're a nested term, we
        # already have a test function; use it
        if not isinstance(condition, str):
            return condition.test_matches(rules, context_so_far)

        # Hm; no convenient test function here
        else:
            return self.basecase_bindings(condition,
                                          rules, context_so_far)

    def basecase_bindings(self, condition, rules, context_so_far):
        for rule in rules:
            bindings = match(condition, rule)
            if bindings is None: continue
            try:
                context = NoClobberDict(context_so_far)
                context.update(bindings)
                yield context
            except ClobberedDictKey:
                pass

    def get_condition_vars(self):
        if hasattr(self, '_condition_vars'):
            return self._condition_vars

        condition_vars = set()

        for condition in self:
            if isinstance(condition, RuleExpression):
                condition_vars |= condition.get_condition_vars()
            else:
                condition_vars |= AIStringVars(condition)

        return condition_vars

    def test_matches(self, rules, context_so_far = None):
        # Subclasses (AND, OR, NOT) provide the actual matching logic.
        raise NotImplementedError

    def __eq__(self, other):
        return type(self) == type(other) and list.__eq__(self, other)

    def __hash__(self):
        # Lists are unhashable, so hash a tuple of the contents
        return hash((self.__class__.__name__, tuple(self)))

class AND(RuleExpression):
    """A conjunction of patterns, all of which must match."""
    class FailMatchException(Exception):
        pass

    def test_matches(self, rules, context_so_far = {}):
        return self._test_matches_iter(rules, list(self))

    def _test_matches_iter(self, rules, conditions = None,
                           cumulative_dict = None):
        """
        Recursively generate all possible matches.
        """
        # Set default values for variables.  We can't set these
        # in the function header because values defined there are
        # class-local, and we need these to be reinitialized on
        # each function call.
        if cumulative_dict == None:
            cumulative_dict = NoClobberDict()

        # If we have no more conditions to analyze, pass the
        # dictionary that we've accumulated back up the
        # function-call stack.
        if len(conditions) == 0:
            yield cumulative_dict
            return

        # Recursive Case
        condition = conditions[0]
        for bindings in self.test_term_matches(condition, rules,
                                               cumulative_dict):
            bindings = NoClobberDict(bindings)

            try:
                bindings.update(cumulative_dict)
                for bindings2 in self._test_matches_iter(rules,
                  conditions[1:], bindings):
                    yield bindings2
            except ClobberedDictKey:
                pass


class OR(RuleExpression):
    """A disjunction of patterns, one of which must match."""
    def test_matches(self, rules, context_so_far = {}):
        for condition in self:
            for bindings in self.test_term_matches(condition, rules):
                yield bindings

class NOT(RuleExpression):
    """A RuleExpression for negation. A NOT clause must only have
    one part."""
    def test_matches(self, data, context_so_far = {}):
        assert len(self) == 1 # We're unary; we can only process
                              # one condition

        try:
            new_key = populate(self[0], context_so_far)
        except KeyError:
            new_key = self[0]

        matched = False
        for x in self.test_term_matches(new_key, data):
            matched = True

        if matched:
            return
        else:
            yield NoClobberDict()


class THEN(list):
    """
    A THEN expression is a container with no interesting semantics.
    """
    def __init__(self, *args):
        if (len(args) == 1 and isinstance(args[0], list)
            and not isinstance(args[0], RuleExpression)):
            args = args[0]
        super(THEN, self).__init__()
        for a in args:
            self.append(a)

    def __str__(self):
        return '%s(%s)' % (self.__class__.__name__, ', '.join([repr(x) for x in self]) )

    __repr__ = __str__


class DELETE(THEN):
    """
    A DELETE expression is a container with no interesting
    semantics. That's why it's exactly the same as THEN.
    """
    pass

def uniq(lst):
    """
    this is like list(set(lst)) except that it gets around
    unhashability by stringifying everything.  If str(a) ==
    str(b) then this will get rid of one of them.
    """
    seen = {}
    result = []
    for item in lst:
        if str(item) not in seen:
            result.append(item)
            seen[str(item)]=True
    return result

def simplify(node):
    """
    Given an AND/OR tree, reduce it to a canonical, simplified
    form, as described in the lab.

    You should do this to the expressions you produce by backward
    chaining.
    """
    if not isinstance(node, RuleExpression): return node
    branches = uniq([simplify(x) for x in node])
    if isinstance(node, AND):
        return _reduce_singletons(_simplify_and(branches))
    elif isinstance(node, OR):
        return _reduce_singletons(_simplify_or(branches))
    else: return node

def _reduce_singletons(node):
    if not isinstance(node, RuleExpression): return node
    if len(node) == 1: return node[0]
    return node

def _simplify_and(branches):
    for b in branches:
        if b == FAIL: return FAIL
    pieces = []
    for branch in branches:
        if isinstance(branch, AND): pieces.extend(branch)
        else: pieces.append(branch)
    return AND(*pieces)

def _simplify_or(branches):
    for b in branches:
        if b == PASS: return PASS
    pieces = []
    for branch in branches:
        if isinstance(branch, OR): pieces.extend(branch)
        else: pieces.append(branch)
    return OR(*pieces)

PASS = AND()
FAIL = OR()
run_conditions = forward_chain

# Basic rule system

Here's an example of using it with a very simple rule system:

In [None]:
copy_rule = IF( 'you have (?x)',
                  THEN( 'i have (?x)' ),)


theft_rule = IF( 'you have (?x)',
                  THEN( 'i have (?x)' ),
                  DELETE( 'you have (?x)' ))

data = ( 'you have apple',
         'you have orange',
         'you have pear',
        'i have red sneakers' )
print(forward_chain([copy_rule], data, verbose=True))

# Try theft rule by uncommenting below
# print(forward_chain([theft_rule], data, verbose=True))

Rule: IF(you have (?x), THEN('i have (?x)'))
Added: i have apple
Rule: IF(you have (?x), THEN('i have (?x)'))
Added: i have pear
Rule: IF(you have (?x), THEN('i have (?x)'))
Added: i have orange
('i have apple', 'i have orange', 'i have pear', 'i have red sneakers', 'you have apple', 'you have orange', 'you have pear')


We provide the system with a list containing a single rule, called `copy_rule`, which, for each datum
like '`you have apple`', adds '`i have apple`' and leaves the original in place. Running `theft_rule`
(commented out above) instead replaces each '`you have ...`' datum with the matching '`i have ...`' datum in turn.

NOTE: The `Rule:` and `Added:` lines come from the verbose printing. The final output is the set of
assertions after applying the forward chaining procedure.

### Zookeeper
A much larger example is `zookeeper`, which classifies animals based on their characteristics.

In [None]:
## ZOOKEEPER RULES
ZOOKEEPER_RULES = (

    IF( AND( '(?x) has hair' ),         # Z1
        THEN( '(?x) is a mammal' )),

    IF( AND( '(?x) gives milk' ),       # Z2
        THEN( '(?x) is a mammal' )),

    IF( AND( '(?x) has feathers' ),     # Z3
        THEN( '(?x) is a bird' )),

    IF( AND( '(?x) flies',              # Z4
             '(?x) lays eggs' ),
        THEN( '(?x) is a bird' )),

    IF( AND( '(?x) is a mammal',        # Z5
             '(?x) eats meat' ),
        THEN( '(?x) is a carnivore' )),

    IF( AND( '(?x) is a mammal',        # Z6
             '(?x) has pointed teeth',
             '(?x) has claws',
             '(?x) has forward-pointing eyes' ),
        THEN( '(?x) is a carnivore' )),

    IF( AND( '(?x) is a mammal',        # Z7
             '(?x) has hoofs' ),
        THEN( '(?x) is an ungulate' )),

    IF( AND( '(?x) is a mammal',        # Z8
             '(?x) chews cud' ),
        THEN( '(?x) is an ungulate' )),

    IF( AND( '(?x) is a carnivore',     # Z9
             '(?x) has tawny color',
             '(?x) has dark spots' ),
        THEN( '(?x) is a cheetah' )),

    IF( AND( '(?x) is a carnivore',     # Z10
             '(?x) has tawny color',
             '(?x) has black stripes' ),
        THEN( '(?x) is a tiger' )),

    IF( AND( '(?x) is an ungulate',     # Z11
             '(?x) has long legs',
             '(?x) has long neck',
             '(?x) has tawny color',
             '(?x) has dark spots' ),
        THEN( '(?x) is a giraffe' )),

    IF( AND( '(?x) is an ungulate',     # Z12
             '(?x) has white color',
             '(?x) has black stripes' ),
        THEN( '(?x) is a zebra' )),

    IF( AND( '(?x) is a bird',          # Z13
             '(?x) does not fly',
             '(?x) has long legs',
             '(?x) has long neck',
             '(?x) has black and white color' ),
        THEN( '(?x) is an ostrich' )),

    IF( AND( '(?x) is a bird',          # Z14
             '(?x) does not fly',
             '(?x) swims',
             '(?x) has black and white color' ),
        THEN( '(?x) is a penguin' )),

    IF( AND( '(?x) is a bird',        # Z15
             '(?x) is a good flyer' ),
        THEN( '(?x) is an albatross' )),

    )



ZOO_DATA = (
    'tim has feathers',
    'tim is a good flyer',
    'mark flies',
    'mark does not fly',
    'mark lays eggs',
    'mark swims',
    'mark has black and white color',
    )

In [None]:
# Run forward chaining using ZOO_DATA and ZOOKEEPER_RULES below


## Rule systems Exercises

### Poker Hands



We can use a production system to rank types of poker hands against each other. If we tell it the basic things like '`three-of-a-kind beats two-pair`' and '`two-pair beats pair`', it should be able to deduce by transitivity that '`three-of-a-kind beats pair`'.

Write a one-rule system that ranks poker hands (or anything else, really) transitively, given some of the
rankings already. The rankings will all be provided in the form '`(?x) beats (?y)`'.

Call the one rule you write `transitive_rule`, so that your list of rules is `[ transitive_rule ]`.

Just for this problem, it is okay if your transitive rule adds '`X beats X`', even though in real life
transitivity does not imply reflexivity.

In [None]:
# Poker hands


# You're given this data about poker hands:
poker_data = ( 'two-pair beats pair',
               'three-of-a-kind beats two-pair',
               'straight beats three-of-a-kind',
               'flush beats straight',
               'full-house beats flush',
               'straight-flush beats full-house' )

# Fill in this rule so that it finds all other combinations of
# which poker hands beat which, transitively. For example, it
# should be able to deduce that a three-of-a-kind beats a pair,
# because a three-of-a-kind beats two-pair, which beats a pair.
transitive_rule = IF( AND(), THEN() )

# You can test your rule like this:
# print(forward_chain([transitive_rule], poker_data))

# Here are some other data sets for the rule. The tester uses
# these, so don't change them.

abc_data = [ 'a beats b', 'b beats c' ]
TEST_RESULTS_TRANS1 = forward_chain([transitive_rule],
                                    abc_data)
rock_paper_scissors_data = [ 'rock beats scissors',
    'scissors beats paper',
    'paper beats rock' ]
TEST_RESULTS_TRANS2 = forward_chain([transitive_rule], rock_paper_scissors_data)

### Rock Paper Scissors

In [None]:
# Write a transitive rule that is still able to find all combinations of which poker hands beat which,
# but also does not create information such as 'rock beats rock' or 'rock beats paper' in the rock-paper-scissors game
rock_paper_scissors_data = [ 'rock beats scissors',
    'scissors beats paper',
    'paper beats rock' ]

improved_transitive_rule = IF( AND(), THEN() )

# You can test your improved transitive rule like this:
# print(forward_chain([improved_transitive_rule], rock_paper_scissors_data))

TEST_RESULTS_TRANS3 = forward_chain([improved_transitive_rule], rock_paper_scissors_data)

### Family relations



You will be given data that includes three kinds of statements:

*  '`male x`': *x* is male
*   '`female x`': *x* is female
*   '`parent x y`': *x* is a parent of *y*

Every person in the data set will be defined to be either male or female.

Your task is to deduce, wherever you can, the following relations:

* '`brother x y`': *x* is the brother of *y* (sharing at least one parent)
* '`sister x y`': *x* is the sister of *y* (sharing at least one parent)
* '`mother x y`': *x* is the mother of *y*
* '`father x y`': *x* is the father of *y*
* '`son x y`': *x* is the son of *y*
* '`daughter x y`': *x* is the daughter of *y*
* '`cousin x y`': *x* and *y* are cousins (a parent of *x* and a parent of *y* are siblings)
* '`uncle x y`': *x* is the uncle of *y*
* '`aunt x y`': *x* is the aunt of *y*
* '`grandparent x y`': *x* is the grandparent of *y*
* '`grandchild x y`': *x* is the grandchild of *y*

You will probably run into the problem that the system wants to conclude that everyone is their own sibling. To avoid this, you will probably want to write a rule that adds '`same-identity (?x) (?x)`' for every person, and make sure that potential siblings don't have `same-identity`. (Hint: You can assume that every person will be mentioned in a clause stating their gender, either male or female.) The order of the rules will matter, of course. Note that it's fine to include statements that are not any of the specified relations (such as `same-identity` or `sibling`).

Some relationships are symmetrical, and you need to include them both ways. For example, if *a* is a cousin of *b*, then *b* is a cousin of *a*.

As the answer to this problem, you should provide a list called `family_rules` that contains the rules you wrote in order, so it can be plugged into the rule system. We've given you two sets of test data: one for the Simpsons family, and one for the Black family from Harry Potter.

The code below defines `black_family_cousins` to include all the '`cousin x y`' relationships you find in the Black family. There should be 14 of them.

In [None]:
# Family relations

# First, define all your rules here individually. That is, give
# them names by assigning them to variables. This way, you'll be
# able to refer to the rules by name and easily rearrange them if
# you need to.

# Then, put them together into a list in order, and call it
# family_rules.
family_rules = [ ]                    # fill me in

# Some examples to try it on:
# Note: These are used for testing, so DO NOT CHANGE
simpsons_data = ("male bart",
                 "female lisa",
                 "female maggie",
                 "female marge",
                 "male homer",
                 "male abe",
                 "parent marge bart",
                 "parent marge lisa",
                 "parent marge maggie",
                 "parent homer bart",
                 "parent homer lisa",
                 "parent homer maggie",
                 "parent abe homer")
TEST_RESULTS_6 = forward_chain(family_rules,
                               simpsons_data,verbose=False)
# You can test your results by uncommenting this line:
# print(forward_chain(family_rules, simpsons_data, verbose=True))

black_data = ("male sirius",
              "male regulus",
              "female walburga",
              "male alphard",
              "male cygnus",
              "male pollux",
              "female bellatrix",
              "female andromeda",
              "female narcissa",
              "female nymphadora",
              "male draco",
              "parent walburga sirius",
              "parent walburga regulus",
              "parent pollux walburga",
              "parent pollux alphard",
              "parent pollux cygnus",
              "parent cygnus bellatrix",
              "parent cygnus andromeda",
              "parent cygnus narcissa",
              "parent andromeda nymphadora",
              "parent narcissa draco")

# This should generate 14 cousin relationships, representing
# 7 pairs of people who are cousins:

black_family_cousins = [
    x for x in
    forward_chain(family_rules, black_data, verbose=False)
    if "cousin" in x ]

# To see if you found them all, uncomment this line:
# print(black_family_cousins)

# To debug what happened in your rules, you can set verbose=True
# in the function call above.

# Some other data sets to try it on. The tester uses these
# results, so don't comment them out.

TEST_DATA_1 = [ 'female alice',
                'male bob',
                'male chuck',
                'parent chuck alice',
                'parent chuck bob' ]
TEST_RESULTS_1 = forward_chain(family_rules,
                               TEST_DATA_1, verbose=False)

TEST_DATA_2 = [ 'female a1', 'female b1', 'female b2',
                'female c1', 'female c2', 'female c3',
                'female c4', 'female d1', 'female d2',
                'female d3', 'female d4',
                'parent a1 b1',
                'parent a1 b2',
                'parent b1 c1',
                'parent b1 c2',
                'parent b2 c3',
                'parent b2 c4',
                'parent c1 d1',
                'parent c2 d2',
                'parent c3 d3',
                'parent c4 d4' ]

TEST_RESULTS_2 = forward_chain(family_rules,
                               TEST_DATA_2, verbose=False)

TEST_RESULTS_BLACK = forward_chain(family_rules,
                               black_data, verbose=False)

TEST_RESULTS_6 = forward_chain(family_rules,
                               simpsons_data,verbose=False)

# Tests

### Tester Code

In [None]:
from xmlrpc import client
import traceback
import sys
import os
import tarfile

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO


def test_summary(dispindex, ntests):
    return "Test %d/%d" % (dispindex, ntests)

tests = []


def show_result(testsummary, testcode, correct, got, expected, verbosity):
    """ Pretty-print test results """
    if correct:
        if verbosity > 0:
            print("%s: Correct." % testsummary)
        if verbosity > 1:
#             print('\t', testcode)
            print()
    else:
        print("%s: Incorrect." % testsummary)
#         print('\t', testcode)
        print("Got:     ", got)
        print("Expected:", expected)

def show_exception(testsummary, testcode):
    """ Pretty-print exceptions (including tracebacks) """
    print("%s: Error." % testsummary)
    print("While running the following test case:")
    print('\t', testcode)
    print("Your code encountered the following error:")
    traceback.print_exc()
    print()


def get_lab_module():
    # Try the easy way first
    try:
        from tests import lab_number
    except ImportError:
        lab_number = None

    if lab_number is not None:
        lab = __import__('lab%s' % lab_number)
        return lab

    lab = None
    found_labnum = None

    for labnum in range(10):
        try:
            lab = __import__('lab%s' % labnum)
            found_labnum = labnum  # remember which lab number actually imported
        except ImportError:
            pass

    if lab is None:
        raise ImportError("Cannot find your lab; or, error importing it.  Try loading it by running 'python labN.py' (for the appropriate value of 'N').")

    if not hasattr(lab, "LAB_NUMBER"):
        lab.LAB_NUMBER = found_labnum

    return lab

def type_decode(arg, lab):
    """
    XMLRPC can only pass a very limited collection of types.
    Frequently, we want to pass a subclass of 'list' in as a test argument.
    We do that by converting the sub-type into a regular list of the form:
    [ 'TYPE', (data) ] (i.e., AND(['x','y','z']) becomes ['AND','x','y','z']).
    This function assumes that TYPE is a valid attr of 'lab' and that TYPE's
    constructor takes a list as an argument; it uses that to reconstruct the
    original data type.
    """
    if isinstance(arg, list) and len(arg) >= 1: # We'll leave tuples reserved for some other future magic
        try:
            mytype = arg[0]
            data = arg[1:]
            return getattr(lab, mytype)([ type_decode(x, lab) for x in data ])
        except AttributeError:
            return [ type_decode(x, lab) for x in arg ]
        except TypeError:
            return [ type_decode(x, lab) for x in arg ]
    else:
        return arg


def type_encode(arg):
    """
    Encode trees as lists in a way that can be decoded by 'type_decode'
    """
    if isinstance(arg, list):
        return [ arg.__class__.__name__ ] + [ type_encode(x) for x in arg ]
    elif hasattr(arg, '__class__') and arg.__class__.__name__ == 'IF':
        return [ 'IF', type_encode(arg._conditional), type_encode(arg._action), type_encode(arg._delete_clause) ]
    else:
        return arg


def run_test(test, lab):
    """
    Takes a 'test' tuple as provided by the online tester
    (or generated by the offline tester) and executes that test,
    returning whatever output is expected (the variable that's being
    queried, the output of the function being called, etc)

    'lab' (the argument) is the module containing the lab code.

    'test' tuples are in the following format:
      'id': A unique integer identifying the test
      'type': One of 'VALUE', 'FUNCTION', 'MULTIFUNCTION', or 'FUNCTION_ENCODED_ARGS'
      'attr_name': The name of the attribute in the 'lab' module
      'args': a list of the arguments to be passed to the function; [] if no args.
      For 'MULTIFUNCTION's, a list of lists of arguments to be passed in
    """
    id, mytype, attr_name, args = test

    attr = getattr(lab, attr_name)

    if mytype == 'VALUE':
        return attr
    elif mytype == 'FUNCTION':
        try:
            return attr(*args)  # Python 3 has no apply(); unpack the argument list instead
        except NotImplementedError:
            print("NotImplementedError: You have to implement this function before we can test it!")
            return None
    elif mytype == 'MULTIFUNCTION':
        return [ run_test( (id, 'FUNCTION', attr_name, FN), lab) for FN in args ]
    elif mytype == 'FUNCTION_ENCODED_ARGS':
        return run_test( (id, 'FUNCTION', attr_name, type_decode(args, lab)), lab )
    else:
        raise Exception("Test Error: Unknown TYPE '%s'.  Please make sure you have downloaded the latest version of the tester script.  If you continue to see this error, contact a TA." % mytype)


def test_offline(verbosity=1):
    """ Run the unit tests in 'tests.py' """
#     import tests as tests_module

#     tests = [ (x[:-8],
#               getattr(tests_module, x),
#               getattr(tests_module, "%s_testanswer" % x[:-8]),
#               getattr(tests_module, "%s_expected" % x[:-8]),
#               "_".join(x[:-8].split('_')[:-1]))
#              for x in tests_module.__dict__.keys() if x[-8:] == "_getargs" ]

#     tests = tests_module.get_tests()
    global tests

    ntests = len(tests)
    ncorrect = 0

    for index, (testname, getargs, testanswer, expected, fn_name, type) in enumerate(tests):
        dispindex = index+1
        summary = test_summary(dispindex, ntests)

        try:
            if callable(getargs):
                getargs = getargs()
            if type == 'FUNCTION_ENCODED_ARGS':
                answer = fn_name(getargs[0],getargs[1])#run_test((index, type, fn_name, getargs), get_lab_module())
            else:
                answer = fn_name
        except Exception:
            show_exception(summary, testname)
            continue

        correct = testanswer(answer)
        show_result(summary, testname, correct, answer, expected, verbosity)
        if correct: ncorrect += 1

    print("Passed %d of %d tests." % (ncorrect, ntests))
#     if ncorrect == ntests:
#         print("You're done! Run 'python %s submit' to submit your code and have it graded." % sys.argv[0])
    tests = []


def get_target_upload_filedir():
    """ Get, via user prompting, the directory containing the current lab """
    cwd = os.getcwd() # Get current directory.  Play nice with Unicode pathnames, just in case.

    print("Please specify the directory containing your lab.")
    print("Note that all files from this directory will be uploaded!")
    print("Labs should not contain large amounts of data; very-large")
    print("files will fail to upload.")
    print()
    print("The default path is '%s'" % cwd)
    target_dir = input("[%s] >>> " % cwd)

    target_dir = target_dir.strip()
    if target_dir == '':
        target_dir = cwd

    print("Ok, using '%s'." % target_dir)

    return target_dir

def get_tarball_data(target_dir, filename):
    """ Return a binary String containing the binary data for a tarball of the specified directory """
    from io import BytesIO  # tarfile writes bytes, so a text StringIO buffer will not work
    data = BytesIO()
    file = tarfile.open(filename, "w|bz2", data)

    print("Preparing the lab directory for transmission...")

    file.add(target_dir)

    print("Done.")
    print()
    print("The following files have been added:")

    for f in file.getmembers():
        print(f.name)

    file.close()

    return data.getvalue()


def test_online(verbosity=1):
    """ Run online unit tests.  Run them against the server via XMLRPC. """
    lab = get_lab_module()

    try:
        server = client.ServerProxy(server_url, allow_none=True)
        tests = server.get_tests(username, password, lab.__name__)
    except NotImplementedError: # Solaris Athena doesn't seem to support HTTPS
        print("Your version of Python doesn't seem to support HTTPS, for")
        print("secure test submission.  Would you like to downgrade to HTTP?")
        answer = input("(Y/n) >>> ")
        if len(answer) == 0 or answer[0] in "Yy":
            server = client.ServerProxy(server_url.replace("https", "http"))
            tests = server.get_tests(username, password, lab.__name__)
        else:
            print("Ok, not running your tests.")
            print("Please try again on another computer.")
            print("Linux Athena computers are known to support HTTPS,")
            print("if you use the version of Python in the 'python' locker.")
            sys.exit(0)

    ntests = len(tests)
    ncorrect = 0

    lab = get_lab_module()

    target_dir = get_target_upload_filedir()

    tarball_data = get_tarball_data(target_dir, "lab%s.tar.bz2" % lab.LAB_NUMBER)

    print("Submitting to the Webserver...")

    server.submit_code(username, password, lab.__name__, client.Binary(tarball_data))

    print("Done submitting code.")
    print("Running test cases...")

    for index, testcode in enumerate(tests):
        dispindex = index+1
        summary = test_summary(dispindex, ntests)

        try:
            answer = run_test(testcode, get_lab_module())
        except Exception:
            show_exception(summary, testcode)
            continue

        correct, expected = server.send_answer(username, password, lab.__name__, testcode[0], type_encode(answer))
        show_result(summary, testcode, correct, answer, expected, verbosity)
        if correct: ncorrect += 1

    response = server.status(username, password, lab.__name__)
    print(response)



# if __name__ == '__main__':
#     test_offline()

def make_test_counter_decorator():

    def make_test(getargs, testanswer, expected_val, name = None, type = 'FUNCTION'):
        if name is not None:
            getargs_name = name
        elif not callable(getargs):
            getargs_name = "_".join(getargs[:-8].split('_')[:-1])
            # Bind the value now; a bare 'lambda: getargs' would return the lambda itself.
            getargs = lambda value=getargs: value
        else:
            getargs_name = "_".join(getargs.__name__[:-8].split('_')[:-1])

        tests.append( ( getargs_name,
                        getargs,
                        testanswer,
                        expected_val,
                        getargs_name,
                        type ) )

    def get_tests():
        return tests

    return make_test, get_tests


make_test, get_tests = make_test_counter_decorator()
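
The `type_encode` and `type_decode` docstrings above describe flattening a `list` subclass into `['TYPE', item, ...]` and rebuilding it by looking `TYPE` up on the lab module. The sketch below illustrates that round trip in simplified form. It is not the tester's code: `AND` here is a hypothetical stand-in for the lab's class, `fake_lab` stands in for the imported lab module, and `encode`/`decode` are cut-down versions written only for this example.

```python
from types import SimpleNamespace


class AND(list):
    """Hypothetical stand-in for the lab's AND class (a list subclass)."""


# Stand-in for the lab module; the real tester passes the imported lab.
fake_lab = SimpleNamespace(AND=AND)


def encode(arg):
    """Flatten list subclasses into ['TypeName', item, ...]."""
    if isinstance(arg, list):
        return [type(arg).__name__] + [encode(x) for x in arg]
    return arg


def decode(arg, lab):
    """Rebuild list subclasses from ['TypeName', item, ...]."""
    if isinstance(arg, list) and arg:
        head = arg[0]
        constructor = getattr(lab, head, None) if isinstance(head, str) else None
        if constructor is not None:
            return constructor([decode(x, lab) for x in arg[1:]])
        # Unknown type name: fall back to decoding every element as-is.
        return [decode(x, lab) for x in arg]
    return arg


rule = AND(["parent (?x) (?y)", "parent (?x) (?z)"])
wire = encode(rule)               # plain nested lists, safe to send over XML-RPC
restored = decode(wire, fake_lab)

print(wire)
print(type(restored).__name__, restored)
```

The fallback branch mirrors the idea in `type_decode` that a list whose first element does not name a lab class is left as a plain list.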

### Transitive Rule Tests

In [None]:
### TEST 6 ###

transitive_rule_1_getargs = TEST_RESULTS_TRANS1

def transitive_rule_1_testanswer(val, original_val = None):
    return ( set(val)  == set([ 'a beats b',
                                'b beats c', 'a beats c' ]) )

# This test checks to make sure that your transitive rule
# produces the correct set of statements given the a/b/c data.

make_test(type = 'VALUE',
          getargs = transitive_rule_1_getargs,
          testanswer = transitive_rule_1_testanswer,
          expected_val = "[ 'a beats b', 'b beats c', 'a beats c' ]",
          name = transitive_rule_1_getargs
          )


### TEST 7 ###

transitive_rule_2_getargs = TEST_RESULTS_TRANS2

def transitive_rule_2_testanswer(val, original_val = None):
    return ( set(val)
             == set([ 'rock beats rock',
                      'rock beats scissors',
                      'rock beats paper',
                      'scissors beats rock',
                      'scissors beats scissors',
                      'scissors beats paper',
                      'paper beats rock',
                      'paper beats scissors',
                      'paper beats paper' ]) )

# This test checks to make sure that your transitive rule produces
# the correct set of statements given the rock-paper-scissors data.

make_test(type = 'VALUE',
          getargs = transitive_rule_2_getargs,
          testanswer = transitive_rule_2_testanswer,
          expected_val = "[ 'rock beats rock', 'rock beats scissors', 'rock beats paper', 'scissors beats rock', 'scissors beats scissors', 'scissors beats paper', 'paper beats rock', 'paper beats scissors', 'paper beats paper' ]",
          name = transitive_rule_2_getargs
          )


### TEST 8 ###

transitive_rule_3_getargs = TEST_RESULTS_TRANS3
def transitive_rule_3_testanswer(val, original_val = None):
    return ( set(val)
             == set([ 'rock beats scissors',
                      'scissors beats paper',
                      'paper beats rock']) )

make_test(type = 'VALUE',
          getargs = transitive_rule_3_getargs,
          testanswer = transitive_rule_3_testanswer,
          expected_val = "['rock beats scissors', 'scissors beats paper', 'paper beats rock']",
          name = transitive_rule_3_getargs
          )

# This test checks to make sure that your transitive rule produces
# the correct set of statements given the rock-paper-scissors data and improved transitive rule.

test_offline()

Test 1/3: Incorrect.
Got:      ('a beats b', 'b beats c')
Expected: [ 'a beats b', 'b beats c', 'a beats c' ]
Test 2/3: Incorrect.
Got:      ('paper beats rock', 'rock beats scissors', 'scissors beats paper')
Expected: [ 'rock beats rock', 'rock beats scissors', 'rock beats paper', 'scissors beats rock', 'scissors beats scissors', 'scissors beats paper', 'paper beats rock', 'paper beats scissors', 'paper beats paper' ]
Test 3/3: Correct.
Passed 1 of 3 tests.
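
In the output above, tests 1 and 2 fail because the "Got" values contain only the original statements and none of the derived ones. As a rough way to see what is missing, the sketch below computes the same kind of closure in plain Python, outside the rule system. It assumes every statement has the form `x beats y`; `transitive_closure` is a hypothetical helper for illustration, not part of the lab code and not a rule in the production system's syntax.

```python
def transitive_closure(statements):
    """Repeatedly add 'x beats z' whenever 'x beats y' and 'y beats z' exist."""
    facts = {tuple(s.split(" beats ")) for s in statements}
    while True:
        derived = {(x, z)
                   for (x, y1) in facts
                   for (y2, z) in facts
                   if y1 == y2}
        if derived <= facts:
            break  # nothing new was derived, so a fixed point is reached
        facts |= derived
    return {"%s beats %s" % pair for pair in facts}


# The a/b/c statements from the "Got" line of test 1.
abc_got = ["a beats b", "b beats c"]
missing = transitive_closure(abc_got) - set(abc_got)
print(sorted(missing))

# The rock-paper-scissors data forms a cycle, so everything beats everything.
rps = ["rock beats scissors", "scissors beats paper", "paper beats rock"]
print(len(transitive_closure(rps)))
```

Comparing a result against the expected set with a set difference, as done for `missing` here, is often quicker than reading the two long lists side by side.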


### Family relations tests

In [None]:
### TEST 9 ###

family_rules_1_getargs = TEST_RESULTS_1
expected_family_relations = [
    'brother bob alice',
    'sister alice bob',
    'father chuck bob',
    'son bob chuck',
    'daughter alice chuck',
    'father chuck alice' ]

def family_rules_1_testanswer(val, original_val = None):
    return ( set( [ x for x in val
                    if x.split()[0] in (
                                         'father',
                                         'son',
                                         'daughter',
                                         'brother',
                                         'sister',
                                         ) ] )
             == set(expected_family_relations))

# This test checks to make sure that your family rules produce
# the correct set of statements given the alice/bob/chuck data.
# Note that it ignores all statements that don't contain any of
# the words 'father', 'son', 'daughter', 'brother', or 'sister',
# so you can include extra statements if it helps you.

make_test(type = 'VALUE',
          getargs = family_rules_1_getargs,
          testanswer = family_rules_1_testanswer,
          expected_val = "added family relations should include: " + str(expected_family_relations),
          name = family_rules_1_getargs
          )


### TEST 10 ###

family_rules_2_getargs = TEST_RESULTS_2

def family_rules_2_testanswer(val, original_val = None):
    return ( set( [ x for x in val
                    if x.split()[0] == 'cousin' ] )
             == set([ 'cousin c1 c3',
                      'cousin c1 c4',
                      'cousin c2 c3',
                      'cousin c2 c4',
                      'cousin c3 c1',
                      'cousin c3 c2',
                      'cousin c4 c1',
                      'cousin c4 c2',
                      'cousin d1 d2',
                      'cousin d2 d1',
                      'cousin d3 d4',
                      'cousin d4 d3' ]) )

# This test checks to make sure that your family rules produce
# the correct set of statements given the a/b/c/d data.

make_test(type = 'VALUE',
          getargs = family_rules_2_getargs,
          testanswer = family_rules_2_testanswer,
          expected_val = "Results including " + str([ 'cousin c1 c3',
                               'cousin c1 c4',
                               'cousin c2 c3',
                               'cousin c2 c4',
                               'cousin c3 c1',
                               'cousin c3 c2',
                               'cousin c4 c1',
                               'cousin c4 c2',
                               'cousin d1 d2',
                               'cousin d2 d1',
                               'cousin d3 d4',
                               'cousin d4 d3' ]),
          name = family_rules_2_getargs
          )


### TEST 11 ###

family_rules_3_getargs = TEST_RESULTS_BLACK

def family_rules_3_testanswer(val, original_val = None):
    return ( set( [ x for x in val
                    if x.split()[0] == 'uncle' ] )
             == set([ 'uncle alphard bellatrix',
                      'uncle alphard andromeda',
                      'uncle alphard narcissa',
                      'uncle alphard sirius',
                      'uncle alphard regulus',
                      'uncle cygnus sirius',
                      'uncle cygnus regulus' ]))

# This test checks to make sure that your family rules produce
# the correct set of uncles for the Black data.

make_test(type = 'VALUE',
          getargs = family_rules_3_getargs,
          testanswer = family_rules_3_testanswer,
          expected_val = "Results including " + str([ 'uncle alphard bellatrix',
                      'uncle alphard andromeda',
                      'uncle alphard narcissa',
                      'uncle alphard sirius',
                      'uncle alphard regulus',
                      'uncle cygnus sirius',
                      'uncle cygnus regulus' ]),
          name = family_rules_3_getargs
          )


### TEST 12 ###

family_rules_4_getargs = TEST_RESULTS_BLACK

def family_rules_4_testanswer(val, original_val = None):
    return ( set( [ x for x in val
                    if x.split()[0] == 'aunt' ] )
             == set([ 'aunt walburga bellatrix',
                      'aunt walburga andromeda',
                      'aunt bellatrix nymphadora',
                      'aunt bellatrix draco',
                      'aunt walburga narcissa',
                      'aunt andromeda draco',
                      'aunt narcissa nymphadora' ]))

make_test(type = 'VALUE',
          getargs = family_rules_4_getargs,
          testanswer = family_rules_4_testanswer,
          expected_val = "Results including " + str([ 'aunt walburga bellatrix',
                      'aunt walburga andromeda',
                      'aunt walburga narcissa',
                      'aunt bellatrix nymphadora',
                      'aunt bellatrix draco',
                      'aunt andromeda draco',
                      'aunt narcissa nymphadora' ]),
          name = family_rules_4_getargs
          )

# This test checks to make sure that your family rules produce
# the correct set of aunts for the Black data.


test_offline()

Test 1/4: Incorrect.
Got:      ['female alice', 'male bob', 'male chuck', 'parent chuck alice', 'parent chuck bob']
Expected: added family relations should include: ['brother bob alice', 'sister alice bob', 'father chuck bob', 'son bob chuck', 'daughter alice chuck', 'father chuck alice']
Test 2/4: Incorrect.
Got:      ['female a1', 'female b1', 'female b2', 'female c1', 'female c2', 'female c3', 'female c4', 'female d1', 'female d2', 'female d3', 'female d4', 'parent a1 b1', 'parent a1 b2', 'parent b1 c1', 'parent b1 c2', 'parent b2 c3', 'parent b2 c4', 'parent c1 d1', 'parent c2 d2', 'parent c3 d3', 'parent c4 d4']
Expected: Results including ['cousin c1 c3', 'cousin c1 c4', 'cousin c2 c3', 'cousin c2 c4', 'cousin c3 c1', 'cousin c3 c2', 'cousin c4 c1', 'cousin c4 c2', 'cousin d1 d2', 'cousin d2 d1', 'cousin d3 d4', 'cousin d4 d3']
Test 3/4: Incorrect.
Got:      ('male sirius', 'male regulus', 'female walburga', 'male alphard', 'male cygnus', 'male pollux', 'female bellatrix', 'femal