In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import javalang
from dataclasses import dataclass
from tqdm.auto import tqdm
from pickle_cache import PickleCache
from javalang import tree
from pprint import pprint
import textwrap
import copy
from iterextras import par_for, unzip
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from utils import *

# Apply seaborn's default theme to the plots drawn in this notebook.
sns.set()
# PickleCache pointed at '.cache' (presumably an on-disk cache directory —
# confirm against pickle_cache); used below to load/store `progs`.
pcache = PickleCache('.cache')

In [None]:
# get_solutions is not imported by name, so it presumably comes from
# `from utils import *`. The cells below use the result as a mapping
# (.values() / .items() / .keys()) and read the last element of each value.
solutions = get_solutions()

In [None]:
class BadStatement(Exception):
    """Raised when an AST fragment cannot be converted to the JSON program form."""

def convert_method(expr):
    """Turn a method-call expression into a ``{'type': Name}`` JSON node.

    Args:
        expr: Expression node; only ``tree.MethodInvocation`` is accepted.

    Returns:
        A dict with the single key ``'type'`` whose value is ``expr.member``
        with its first character capitalized and the remainder unchanged.

    Raises:
        BadStatement: If ``expr`` is not a ``tree.MethodInvocation``.
    """
    if isinstance(expr, tree.MethodInvocation):
        member = expr.member
        # Only the leading character is touched; str.capitalize() on the whole
        # name would lower-case the rest.
        head, tail = member[0], member[1:]
        return {'type': head.capitalize() + tail}
    raise BadStatement('no')

def pred_to_json(expr):
    """Convert a condition expression into a predicate JSON node.

    Args:
        expr: Condition expression node.

    Returns:
        The result of ``convert_method(expr)``.

    Raises:
        BadStatement: If ``expr`` is a ``tree.BinaryOperation``, or if
            ``convert_method`` rejects it.
    """
    is_binary = isinstance(expr, tree.BinaryOperation)
    if not is_binary:
        return convert_method(expr)
    # Binary operations are rejected outright.
    raise BadStatement('no')

def selfref(stmt):
    """Wrap a converted statement in a ``Concrete`` envelope.

    Args:
        stmt: Any value, or ``None``.

    Returns:
        ``None`` when ``stmt`` is ``None``; otherwise
        ``{'type': 'Concrete', 'data': stmt}`` (``stmt`` is stored by reference,
        not copied).
    """
    return None if stmt is None else {'type': 'Concrete', 'data': stmt}


def stmt_to_json(stmt):
    """Recursively convert a statement (or list of statements) to JSON dicts.

    Args:
        stmt: One of: a list of statements, a ``tree.BlockStatement``,
            ``tree.StatementExpression``, ``tree.IfStatement``,
            ``tree.WhileStatement``, or ``None``.

    Returns:
        * list with one element -> the conversion of that element;
        * longer list -> a right-nested ``{'type': 'Seq', 'first', 'second'}``;
        * block -> the conversion of its ``statements`` list;
        * expression statement -> ``{'type': 'Action', 'action': ...}``;
        * if / while -> ``'If'`` / ``'While'`` dicts whose sub-statements are
          wrapped with ``selfref``;
        * ``None`` -> ``None``.

    Raises:
        BadStatement: For an empty list, an unsupported node type, or when
            ``convert_method`` / ``pred_to_json`` reject an expression.
    """
    if isinstance(stmt, list):
        if not stmt:
            raise BadStatement('no')
        head_json = stmt_to_json(stmt[0])
        if len(stmt) == 1:
            return head_json
        # Everything after the first statement becomes the nested 'second'.
        tail_json = stmt_to_json(stmt[1:])
        return {
            'type': 'Seq',
            'first': selfref(head_json),
            'second': selfref(tail_json),
        }

    if isinstance(stmt, tree.BlockStatement):
        return stmt_to_json(stmt.statements)

    if isinstance(stmt, tree.StatementExpression):
        return {'type': 'Action', 'action': convert_method(stmt.expression)}

    if isinstance(stmt, tree.IfStatement):
        pred = pred_to_json(stmt.condition)
        then_branch = selfref(stmt_to_json(stmt.then_statement))
        # A missing else converts to None, and selfref keeps it as None.
        else_branch = selfref(stmt_to_json(stmt.else_statement))
        return {
            'type': 'If',
            'pred': pred,
            'then_': then_branch,
            'else_': else_branch,
        }

    if isinstance(stmt, tree.WhileStatement):
        pred = pred_to_json(stmt.condition)
        loop_body = selfref(stmt_to_json(stmt.body))
        return {'type': 'While', 'pred': pred, 'body': loop_body}

    if stmt is None:
        return None

    raise BadStatement(f'Unknown type {type(stmt)}')

In [None]:
# Convert each entry's last solution into the JSON program form, keeping only
# those that define `run` and convert without error.
progs = []
for solns in tqdm(solutions.values()):
    try:
        methods = get_methods(solns[-1])
    except IndexError:
        # Skip this entry (e.g. nothing to index with solns[-1]).
        continue

    if 'run' in methods:
        try:
            inlined = Inline(methods).visit(methods['run'])
            prog = stmt_to_json(inlined.body)
        except (RecursionError, BadStatement):
            # Best effort: programs that cannot be inlined/converted are dropped.
            pass
        else:
            progs.append(prog)

In [3]:
# Fetch `progs` through the pickle cache under the key 'progs'. The lambda
# supplies the in-memory list built by the conversion loop — presumably it is
# only called on a cache miss (confirm against PickleCache.get).
# NOTE(review): if the lambda is called while `progs` has not been defined in
# this kernel (conversion cell not run), the lookup inside it fails.
progs = pcache.get('progs', lambda: progs)

In [6]:
import json

# Write the converted programs to progs.json. The context manager guarantees
# the file is flushed and closed, even if serialization raises part-way.
with open('progs.json', 'w') as progs_file:
    json.dump(progs, progs_file)

In [4]:
# Imported here rather than in the import cell at the top of the notebook.
import grammar_induction
# NOTE(review): the output saved with this cell is
# "PanicException: no entry found for key", i.e. the last recorded run of this
# call failed (presumably inside a native extension — confirm).
grammar_induction.test(progs)

PanicException: no entry found for key

In [None]:
# Build one Grammar per solution key from the productions GrammarGenerator
# collects while walking that entry's `run` method.
grammars = {}
for k, solns in tqdm(list(solutions.items())):
    try:
        methods = get_methods(solns[-1])
        if 'run' in methods:
            generator = GrammarGenerator(k, methods)
            generator.generate(methods['run'])
            grammars[k] = Grammar(generator.productions)
    except (RecursionError, Unimplemented, IndexError):
        # Known failure modes: leave this key out of `grammars` and move on.
        continue
    except Exception:
        # Anything else is unexpected: show which key failed, then re-raise.
        print(k)
        raise

In [None]:
# Merge the productions of every per-key grammar into one dict. If two
# grammars share a production name, the one iterated later wins.
productions = {k: p for g in grammars.values() for k, p in g.productions.items()}

# Add a 'start' production with one rule per grammar, each pointing at that
# grammar's '<key>_run' production with uniform probability 1 / len(grammars).
productions['start'] = Production(rules=[
    Rule(
        parts=[f'{k}_run'],
        prob=1. / len(grammars)
    )
    for k in grammars.keys()
])
g = Grammar(productions=productions)
# simplify() returns a pair: the grammar that replaces `g`, and `cs`, which a
# later cell iterates as a mapping from a name to a group of names.
g, cs = g.simplify()

In [None]:
# Display at most the first 100 productions of `g`.
list(g.productions.values())[:100]

In [None]:
# Take the first key of `solutions` and the last element of its entry, then
# run GrammarGenerator with grammar=False on its `run` method; whatever
# generate() returns is kept as `prog`.
k = list(solutions.keys())[0]
solns = solutions[k]
methods = get_methods(solns[-1])
prog = GrammarGenerator(k, methods, grammar=False).generate(methods['run'])

In [None]:
# Draw one sample from `g` and display it.
# NOTE(review): presumably random; this notebook sets no seed, so the output
# may differ between runs — confirm.
p = g.sample()
p

In [None]:
# List the 50 largest groups in `cs` (second value returned by g.simplify()),
# largest first. For each group print: its key, its size, and the member names
# that do not contain '_rule' with their first '_'-separated segment removed;
# then pretty-print g.expand(key).
groups = sorted(cs.items(), key=lambda t: len(t[1]))[::-1]
for k, group in groups[:50]:
    other_names = ['_'.join(k2.split('_')[1:]) for k2 in group if '_rule' not in k2]
    print(k, len(group), len(other_names), other_names)
    pprint(g.expand(k))
    print()

In [None]:
def parse(g, prog):
    """Match grammar productions bottom-up against the parts of ``prog``.

    Work-in-progress matcher; see the NOTE(review) comments for spots whose
    behaviour looks unintended.

    Args:
        g: Grammar object; ``g.productions`` maps a production name to an
            object whose ``rules`` each carry a ``parts`` sequence.
        prog: Object with a ``parts`` sequence. Parts that are ``Op``
            instances are parsed recursively through their ``children()``.

    Returns:
        ``levels``: a list with one entry per span length ``k = 1..n``, where
        ``n = len(prog.parts)``. ``levels[k - 1][i]`` is the list of production
        names recorded for start index ``i`` (a name is appended once per
        accepted rule, so it may repeat).

    Side effects:
        Prints the ``ops`` table on every call, including recursive calls.
    """
    # Part index -> list holding one parse() result per child of the Op there.
    ops = {}
    for i, part in enumerate(prog.parts):
        if isinstance(part, Op):
            ops[i] = [parse(g, child) for child in part.children()]

    # Debug output.
    print(ops)

    n = len(prog.parts)
    levels = []
    # k is the span length being filled in; shorter spans are already in `levels`.
    for k in range(1, n+1):
        level = []

        # i is the start index of the span.
        for i in range(0, n-k+1):
            matches = []

            for prod_name, prod in g.productions.items():
                for rule in prod.rules:
                    match = True

                    # j counts how many program parts the rule has consumed so far.
                    j = 0
                    for part in rule.parts:
                        # The rule needs more parts than the program has left.
                        if i + j >= len(prog.parts):
                            match = False
                            break

                        token = prog.parts[i+j]
                        if isinstance(part, str):
                            # String parts are looked up in the match lists of
                            # earlier levels: try spans of length k2 + 1 < k
                            # starting at i + j, shortest first, and skip past
                            # the first one that contains the name.
                            # NOTE(review): levels[k2] has n - k2 entries, so
                            # levels[k2][i+j] can raise IndexError when
                            # i + j >= n - k2.
                            for k2 in range(0, k-1):
                                if part in levels[k2][i+j]:
                                    j += k2 + 1
                                    break
                            else:
                                # for/else: no earlier level had this name at
                                # i + j (always the case when k == 1).
                                match = False
                        elif isinstance(part, Op):
                            # NOTE(review): if the class / child-count / cond
                            # comparison fails, `match` stays True, so a
                            # differing Op does not reject the rule — confirm.
                            if (part.__class__ == token.__class__ and 
                              len(part.children()) == len(token.children()) and
                              part.cond == token.cond):
                                # NOTE(review): ops[i+j][-1][0] is the length-1
                                # level of the LAST child's parse; each rule
                                # child is tested for membership in one of its
                                # match lists — confirm this is the intended
                                # comparison.
                                for (child, child_matches) in zip(part.children(), ops[i+j][-1][0]):
                                    if child not in child_matches:
                                        match = False
                                        break
                            j += 1
                        elif part != token:
                            match = False                       
                            j += 1
                        else:
                            j += 1

                        if not match:
                            break

                    # NOTE(review): acceptance compares j with n - 1 rather
                    # than with the span length k, so k only limits which
                    # earlier levels string parts may consult — confirm.
                    if j == len(prog.parts) - 1 and match:
                        matches.append(prod_name)                            

# Earlier version of the matching loop, kept commented out:
#                     substr = prog.parts[i:i+k]
#                     for j, (l, r) in enumerate(zip(rule.parts, substr)):
#                         if isinstance(l, str) and k > 1 and l not in levels[k-2][i+j]:
#                             match = False
#                         elif (isinstance(l, Op) and 
#                               l.__class__ == r.__class__ and 
#                               len(l.children()) == len(r.children()) and
#                               l.cond == r.cond):
#                             for (child, child_matches) in zip(l.children(), ops[i+j][-1][0]):
#                                 if child not in child_matches:
#                                     match = False
#                         elif l != r:
#                             match = False
#                         if not match:
#                             break
#                     if match:
#                         matches.append(prod_name)

            level.append(matches)

        levels.append(level)

    return levels
    
#parse(g, Block([Action.move]))
# Manual check: parse a one-statement program (an `if` on frontIsClear whose
# then-branch is a single move and whose else-branch is None) and display the
# resulting levels. The commented calls are alternative inputs.
parse(g, Block([IfNode(Predicate.frontIsClear, Block([Action.move]), None)]))
#parse(g, Block([IfNode(Predicate.frontIsBlocked, Block([Action.turnLeft]), None)]))
#parse(g, Block([prog.parts[1]]))

In [None]:
# Look up a single production of `g` by its hard-coded name.
g.productions['student367_rule1']

In [None]:
# Look up a single production of `g` by its hard-coded name.
g.productions['student416_accountForSingleColumn']

In [None]:
# Display the current value of `prog` (when cells run top to bottom, the one
# returned by GrammarGenerator(..., grammar=False).generate(...)).
prog

In [None]:
# # Compute overlap in productions between students
# N = len(grammars)
# dists = np.zeros((N, N))
# index = list(grammars.keys())
# rev_index = {s: i for i, s in enumerate(index)}

# for k1, grammar1 in tqdm(list(grammars.items())):
#     for k2, grammar2 in grammars.items():
#         count = 0
#         grammar1, grammar2 = normalize_grammars(grammar1, grammar2)
#         for p1 in grammar1.productions.values():
#             for p2 in grammar2.productions.values():
#                 if p1 == p2:
#                     count += 1
#                     break
#         dists[rev_index[k1],rev_index[k2]] = count

# dists_sorted = np.dstack(np.unravel_index(np.argsort(dists, axis=None), dists.shape))[0,::-1]
# dists_sorted = dists_sorted[dists_sorted[:,0] != dists_sorted[:,1]]        