In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import networkx as nx
import pandas as pd
import logging
import seaborn as sns
import matplotlib.pyplot as plt
import sys
from collections import Counter

In [None]:
# Default (width, height) used for every matplotlib figure created in this notebook.
plt.rcParams['figure.figsize'] = (14, 8)

In [None]:
# Add the parent of the current working directory to the import path so the
# `VRG` package imported in the next cell can be resolved.
sys.path.append('./..')

In [None]:
# Only show log records at ERROR level or above, printing just the message text.
logging.basicConfig(level=logging.ERROR, format="%(message)s")

In [None]:
from VRG.runner import get_grammars
from VRG.src.utils import load_pickle, dump_pickle

In [None]:
def get_rule_df(dataset, vrg):
    """Tabulate the rules of a grammar, one DataFrame row per rule.

    Parameters
    ----------
    dataset
        Label written into the ``dataset`` column of every row.
    vrg
        Object exposing ``rule_list``. Each rule must provide ``id``,
        ``lhs_nt.size``, ``frequency`` and ``graph``; the graph must support
        ``order()`` and ``size()`` and be accepted by ``nx.is_connected``.

    Returns
    -------
    pandas.DataFrame
        Columns ``dataset, id, size, freq, n, m, is_connected, graph`` where
        ``n`` is ``graph.order()`` and ``m`` is ``graph.size()``. The columns
        are present even when ``vrg.rule_list`` is empty, so frames from
        several grammars always share one schema.

    Notes
    -----
    ``nx.is_connected`` raises for a graph with no nodes; that error is not
    caught here.
    """
    # Explicit schema: same order the row dicts are built in, and it keeps the
    # columns in place for a grammar without any rules.
    cols = ['dataset', 'id', 'size', 'freq', 'n', 'm', 'is_connected', 'graph']
    rows = [
        {
            'dataset': dataset,
            'id': rule.id,
            'size': rule.lhs_nt.size,
            'freq': rule.frequency,
            'n': rule.graph.order(),
            'm': rule.graph.size(),
            'is_connected': nx.is_connected(rule.graph),
            'graph': rule.graph,
        }
        for rule in vrg.rule_list
    ]
    return pd.DataFrame(rows, columns=cols)

In [None]:
# Datasets to analyse; each name selects one grammar file under ../VRG/dumps/grammars/.
datasets = 'grenoble', 'waterloo', 'uppsala'
dfs = []

for dataset in datasets:
    # NOTE: `vrg` stays bound to the last grammar loaded once the loop ends, so
    # later cells that read `vrg` see the grammar of the final dataset only.
    vrg = load_pickle(f'../VRG/dumps/grammars/{dataset}/VRG_leiden_6_0.pkl')
    print(vrg)
    df = get_rule_df(dataset=dataset, vrg=vrg)
    dfs.append(df)

# One combined rule table with a fresh 0..n-1 index across all datasets.
rule_df = pd.concat(dfs, ignore_index=True)

In [None]:
# Display the combined rule table (one row per rule, all datasets).
rule_df

In [None]:
# Kernel-density estimate of the rule `size` column, one curve per dataset.
# sns.kdeplot is used because sns.distplot(..., hist=False) is deprecated in
# seaborn; passing `ax` explicitly keeps every curve on the same axes.
fig, ax = plt.subplots()
for name, group_df in rule_df.groupby('dataset'):
    sns.kdeplot(group_df['size'], label=name, ax=ax)

# Fixed y-range and logarithmic x-axis, as chosen for this figure.
ax.set_ylim((-0.001, 0.02))
ax.set_xscale('log')

ax.legend(loc='best')
ax.set_title('VRG RHS size distribution')
ax.set_ylabel('Frequency')
ax.set_xlabel('Size of RHS (log)');

In [None]:
# Horizontal bar chart: how many rules have each node count `n`.
rule_df['n'].value_counts().plot.barh()
plt.xlabel('freq')
plt.ylabel('n');

In [None]:
# Number of rules observed for each distinct `size` value.
rule_df['size'].value_counts()

In [None]:
# NOTE(review): this cell repeats the preceding `size` value_counts cell verbatim;
# it looks like a leftover copy and can probably be removed.
rule_df['size'].value_counts()

In [None]:
# Number of rules observed for each distinct edge count `m`.
rule_df.m.value_counts()

In [None]:
# `vrg` is whatever the grammar-loading loop bound last, i.e. the grammar of the
# final entry in `datasets`. Draw the second-to-last rule of that grammar.
rule = vrg.rule_list[-2]
rule.draw()

In [None]:
# Tally the 'gender' entry of every node attribute dict in the third-from-last unique rule.
rule = vrg.unique_rule_list[-3]
Counter(attrs['gender'] for attrs in nx.get_node_attributes(rule.graph, name='attr_dict').values())

In [None]:
def get_gender_counts(row):
    """Count the ``gender`` values found on the nodes of ``row.graph``.

    Parameters
    ----------
    row
        Any object with a ``graph`` attribute (e.g. a row of ``rule_df``).
        Node attribute dicts are read from the ``attr_dict`` node attribute.

    Returns
    -------
    collections.Counter
        Maps each gender value to the number of nodes carrying it. Nodes whose
        attribute dict has no ``'gender'`` key are skipped, so the result is
        empty when no node records a gender.
    """
    # Use the graph of the row that was passed in, not the notebook-level `rule`.
    node_attrs = nx.get_node_attributes(row.graph, name='attr_dict')
    return Counter(attrs['gender'] for attrs in node_attrs.values() if 'gender' in attrs)

In [None]:
# Peek at the attribute dict attached to one node of the first rule graph.
# head(1) is taken before apply so only that graph is inspected, and a graph
# without any 'attr_dict' node attributes yields None instead of an IndexError.
rule_df['graph'].head(1).apply(
    lambda g: next(iter(nx.get_node_attributes(g, name='attr_dict').values()), None)
).values

In [None]:
# Number of nodes in each rule graph whose attribute dict records gender == 'male'.
# The column holds a plain integer per rule; nodes without a 'gender' entry are
# ignored rather than raising KeyError.
rule_df['male_count'] = rule_df['graph'].apply(
    lambda g: sum(
        1
        for attrs in nx.get_node_attributes(g, name='attr_dict').values()
        if attrs.get('gender') == 'male'
    )
)

In [None]:
# Inspect the per-node 'attr_dict' mapping of the first rule graph.
# NOTE(review): the lambda in this cell had no body (a SyntaxError); this body is
# a best guess based on the neighbouring cells. Adjust if something else was meant.
rule_df.head(1)['graph'].apply(lambda g: nx.get_node_attributes(g, name='attr_dict'))

In [None]:
def extract_gender(g, gender):
    """Return how many nodes of ``g`` are recorded with the given gender.

    Parameters
    ----------
    g
        Graph whose nodes may carry an ``attr_dict`` attribute; each such
        value is expected to be a dict (assumed from how the other cells index
        it with ``['gender']``).
    gender
        Value to look for under the ``'gender'`` key, e.g. ``'male'``.

    Returns
    -------
    int
        Number of nodes whose attribute dict has ``'gender'`` equal to
        ``gender``; 0 when the graph has no ``attr_dict`` attributes or no
        node matches.
    """
    node_attrs = nx.get_node_attributes(g, 'attr_dict')
    # The 'gender' key has to be looked up inside each node's dict; testing
    # `'gender' in node_attrs.values()` compares a string against dicts and
    # therefore never matches.
    return sum(1 for attrs in node_attrs.values() if attrs.get('gender') == gender)

In [None]:
# Placeholder columns: -1 marks "not computed yet"; the next cell writes the
# actual per-rule counts into them. This also overwrites any earlier
# `male_count` column.
rule_df['male_count'] = -1
rule_df['female_count'] = -1

In [None]:
# Fill the per-rule gender counts one column at a time. Building each column
# directly from the `graph` column keeps every value on its own row whatever
# the index labels are, unlike writing through `.at[i, ...]` with an
# enumerate() position, which is only correct for a 0..n-1 index.
rule_df['male_count'] = [extract_gender(graph, 'male') for graph in rule_df['graph']]
rule_df['female_count'] = [extract_gender(graph, 'female') for graph in rule_df['graph']]

In [None]:
# Show the first six rules, including the male_count / female_count columns.
rule_df.head(6)

In [None]:
# Spot check: count of 'male' nodes in the graph of the rule currently bound to `rule`.
extract_gender(rule.graph, 'male')