# In [None]:
from collections import defaultdict, deque

def parents(G, n):
    """Return the direct predecessors (parents) of node *n* in *G* as a list."""
    predecessors = G.predecessors(n)
    return list(predecessors)
def aff(G):
    """Return the nodes whose 'phenotype' attribute equals 2 (affected), in node order."""
    return [node for node, pheno in G.nodes(data='phenotype') if pheno == 2]
def unaff(G):
    """Return the nodes whose 'phenotype' is anything other than 2, in node order.

    A missing attribute (reported as None by ``G.nodes(data=...)``) also
    counts as unaffected.
    """
    return [node for node, pheno in G.nodes(data='phenotype') if not pheno == 2]

# Child genotypes (alt-allele counts) permitted for an unordered pair of
# parental genotypes, keyed by inheritance mode. 'AD' and 'AR' currently share
# identical rules; each mode still gets its own independent dict and sets.
# The (2, '_') row is a wildcard for any pair that includes a 2.
# NOTE(review): that wildcard is looser than strict Mendelian transmission
# (2x0 can only give 1, 2x2 only 2) -- confirm this tolerance is intended.
_ALLOWED = {
    mode: {
        (0, 0): {0},
        (1, 0): {0, 1},
        (0, 1): {0, 1},
        (1, 1): {0, 1, 2},
        (2, '_'): {1, 2},
    }
    for mode in ('AD', 'AR')
}

def allowed_child_set(mode, gp, gm):
    """Look up the child genotypes compatible with parental genotypes *gp*, *gm*.

    The pair is matched against ``_ALLOWED[mode]`` in either order. If neither
    order is a key, a pair containing a 2 falls back to the ``(2, '_')``
    wildcard row. Otherwise the ``(1, '_')`` row is tried; the table has no
    such row, so ``{0, 1, 2}`` (no constraint) is returned.

    The returned set is the object stored in the table, not a copy, so treat
    it as read-only. Raises KeyError if *mode* is not a key of ``_ALLOWED``.
    """
    table = _ALLOWED[mode]
    if (gp, gm) in table:
        key = (gp, gm)
    elif (gm, gp) in table:
        key = (gm, gp)
    elif 2 in (gp, gm):
        key = (2, '_')
    else:
        key = (1, '_')
    return table.get(key, {0, 1, 2})

def generations(G):
    """Assign each node a generation index by breadth-first search from founders.

    Founders are nodes with in-degree 0 and get generation 0. All founders
    are enqueued first and the queue is FIFO, so each node receives the
    *minimum* depth over all founder-to-node paths. Nodes unreachable from
    any founder (only possible when G contains a cycle) are omitted.

    Returns
    -------
    dict
        node -> generation index (int, starting at 0).
    """
    lvl = {}
    # deque.popleft() is O(1); the previous list.pop(0) made the BFS quadratic.
    queue = deque((n, 0) for n in G if G.in_degree(n) == 0)
    while queue:
        node, depth = queue.popleft()
        if node in lvl:
            continue  # already reached at an equal or shallower depth
        lvl[node] = depth
        for child in G.successors(node):
            queue.append((child, depth + 1))
    return lvl

def segregation_network_score(G, gt, mode='AD', w_edge=.6, w_gen=.2, w_bet=.2):
    """Score how well one variant's genotypes fit pedigree *G* under *mode*.

    The result is a weighted sum of three components, each in [0, 1]:

    * edge score -- fraction of individuals whose genotype is in the set
      returned by ``allowed_child_set`` for their parents' genotypes. A
      missing parent, and any individual absent from *gt*, is treated as
      genotype 0, so founders are checked against a (0, 0) mating.
    * generation score -- for 'AD', the fraction of pedigree generations
      (from ``generations``) containing at least one alt-allele carrier
      (genotype > 0); otherwise one minus that fraction.
    * betweenness score -- mean betweenness centrality of unaffected
      heterozygous carriers (genotype 1, phenotype != 2), divided by the
      graph's maximum betweenness. It is used as-is for 'AR' and as
      ``1 - value`` otherwise. With no such carriers it is 0 for 'AR' and
      1 otherwise.

    Parameters
    ----------
    G : directed pedigree graph with parent -> child edges. Nodes may carry a
        'phenotype' attribute (2 = affected; missing = unaffected).
    gt : mapping node -> genotype (alt-allele count: 0, 1 = het, 2).
    mode : inheritance model, 'AD' or 'AR' (a key of ``_ALLOWED``).
    w_edge, w_gen, w_bet : weights of the three components.

    Returns
    -------
    float
        Weighted score rounded to 3 decimals.

    Raises
    ------
    ValueError
        If *G* has no nodes.
    """
    if len(G) == 0:
        raise ValueError("segregation_network_score: pedigree graph G has no nodes")

    phen = nx.get_node_attributes(G, 'phenotype')

    # --- Edge (Mendelian consistency) score ---
    bad = 0
    for child in G:
        # Pad the *genotype* list (not the parent-id list) with 0s. Padding the
        # parent ids with a literal 0 would look up the genotype of a node
        # that happens to be named 0.
        parent_gts = [gt.get(p, 0) for p in parents(G, child)]
        gp, gm = (parent_gts + [0, 0])[:2]
        if gt.get(child, 0) not in allowed_child_set(mode, gp, gm):
            bad += 1
    edge_score = 1 - bad / len(G)

    # --- Generation continuity score ---
    gen = generations(G)
    gens_total = max(gen.values()) + 1
    alt_gens = {gen[n] for n in G if gt.get(n, 0) > 0}
    if mode == 'AD':
        gen_score = len(alt_gens) / gens_total
    else:
        gen_score = 1 - len(alt_gens) / gens_total
    gen_score = max(0, min(1, gen_score))

    # --- Carrier betweenness score ---
    het_car = [n for n in G if gt.get(n, 0) == 1 and phen.get(n) != 2]
    if het_car:
        bet = nx.betweenness_centrality(G)
        cb = sum(bet[n] for n in het_car) / len(het_car)
        max_bet = max(bet.values())
        # All betweenness values can be 0 (e.g. a two-generation pedigree,
        # where no node lies between two others). Dividing would then raise
        # ZeroDivisionError; the carriers' share in that case is 0.
        cb = cb / max_bet if max_bet > 0 else 0.0
        bet_score = cb if mode == 'AR' else 1 - cb
    else:
        bet_score = 0 if mode == 'AR' else 1

    return round(w_edge * edge_score + w_gen * gen_score + w_bet * bet_score, 3)

def scan_variants(G, vars_dict, mode='AD', **score_kwargs):
    """Score every candidate variant on pedigree *G* and pick the best one.

    Parameters
    ----------
    G : pedigree graph passed to ``segregation_network_score``.
    vars_dict : mapping variant id -> genotype mapping (node -> genotype).
    mode : inheritance model forwarded to the scorer ('AD' or 'AR').
    **score_kwargs : optional scorer weights (``w_edge``, ``w_gen``,
        ``w_bet``), forwarded unchanged to ``segregation_network_score``.

    Returns
    -------
    (best_id, scores)
        ``scores`` maps each variant id to its score. ``best_id`` is the id
        with the highest score; on ties it is the first one in *vars_dict*
        iteration order.

    Raises
    ------
    ValueError
        If *vars_dict* is empty.
    """
    if not vars_dict:
        raise ValueError("scan_variants: vars_dict contains no variants to score")
    scores = {
        vid: segregation_network_score(G, gt, mode, **score_kwargs)
        for vid, gt in vars_dict.items()
    }
    best = max(scores, key=scores.get)
    return best, scores

# Rank the candidate variants on the AD and AR pedigree graphs, each under its
# own inheritance model, then report the top-scoring variant id and its score.
best_AD, scores_AD = scan_variants(G_ad, vars_AD, mode='AD')
best_AR, scores_AR = scan_variants(G_ar, vars_AR, mode='AR')

print('Best variant in AD:', best_AD, scores_AD[best_AD])
print('Best variant in AR:', best_AR, scores_AR[best_AR])