In [1]:
#| default_exp dists

In [2]:
#| exporti
import numpy, re
from collections import defaultdict
from math import ceil, floor, comb, isnan
prange = range # as we use range as argument name so it helps to have alias

In [3]:
# Two profiles for testing
wep = { 'type':'ranged', 'attacks': '3', 'bsws': 3, 'strength': 5, 'AP': -2, 'damage': '2', 'kws': ['lethal hits', 'devastating wounds', 'sustained hits 2', 'anti-infantry 4+', 'rapid fire d2', 'melta d3+2'] }
target = { 'toughness': 3, 'save': 2, 'invuln': 4, 'wounds': 3, 'kws':['infantry'], 'abilities': ['feel no pain 5+', 'stealth'] }

In [78]:
#| export

# Convolution of d1 and d2.
# If n_wounds is given, respect the n_wounds boundaries of units
# If first_n is given, assume first unit has only first_n wounds
# NB! n_wounds makes convolution non-associative, and first_n non-commutative,
#     which makes things a lot more complex to reason about
def convolve(d1,d2, n_wounds=None, first_n=0):
    res = defaultdict(lambda: 0)
    if n_wounds is None:
        for k1,v1 in d1.items():
            for k2,v2 in d2.items():
                res[k1+k2] += v1*v2
    else: # Limit damage to n_wounds units
        for k1,v1 in d1.items():
            for k2,v2 in d2.items():
                kv = k1+k2
                if first_n: # Handle first threshold separate
                    if k1<first_n:
                        mod = k1+k2+1 # i.e. ignore mod part
                        kv = min(k1+k2,first_n) # just set the kv
                    else: # Past first threshold, so business as usual
                        k1,mod = k1-first_n,n_wounds
                else: mod = n_wounds
                
                k1m, k2m = k1%mod, k2%mod
                if k1m + k2m > mod: 
                    kv -= (k1m+k2m)%mod
                res[kv] += v1*v2
    return res


In [None]:
#| export

# Other helper functions for distributions
# TODO: this is better encapsulated into a class

def mult_ddist_vals(d, val):
    res = defaultdict(lambda: 0)
    for k1,v1 in d.items():
        res[int(ceil(k1*val))] += v1 
    return res

def mult_ddist_probs(d,p):
    return { k:v*p for k,v in d.items() }

def threshold_ddist(dd,val,lt=True):
    for k in list(dd.keys()): 
        if (lt and k<val) or (not lt and k>val): 
            dd[val]+=dd[k]
            del dd[k]

def add_ddist(d1,d2):
    res = defaultdict(lambda: 0)
    for k1,v1 in d1.items():
        res[k1] += v1  
    for k2,v2 in d2.items():
        res[k2] += v2 
    return res
    
def flatdist(n):
    return { (i+1):1/n for i in range(n) }

# q is the probability of saving the damage
def fnp_transform(d, q):
    p = 1.0-q
    res = defaultdict(lambda: 0)
    for k,v in d.items():
        for i in range(k+1): # Binomial distribution
            res[i] += v*comb(k,i)*(p**i)*(q**(k-i))
    return res

def dd_rep(d, n, **argv):
    if n == 0 : return { 0: 1 }
    res_d = d
    for _ in range(1,n):
        res_d = dd_prune(convolve(res_d,d,**argv),1e-3)
    return res_d

# Prune all values with prob below ratio * <max prob>
def dd_prune(d, ratio):
    t = ratio*max(d.values())
    return { k:v for k,v in d.items() if v>t }

def dd_mean(dd):
    val = 0.0
    for k, v in dd.items():
        val += k*v
    return val


def dd_above(d, thresh):
    p = 0.0
    for k,v in d.items():
        if k>=thresh: p+=v
    return p


def dd_max(dd):
    return max(dd.keys())

def dd_psum(dd):
    return sum(dd.values())


In [5]:
#| exporti

def dd_from_str(dstr):
    dd = { 0: 1.0 }
    for t in dstr.lower().split('+'):
        if 'd' in t:
            d = t.split('d')
            if d[0]=='': d = [1,int(d[1])]
            else: d = list(map(int,d))
            for _ in range(d[0]):
                nd = flatdist(d[1])
                dd = convolve(dd,nd)
        else:
            d = int(t)
            dd = convolve(dd,{d:1.0})
    return dd

In [6]:
#| export

# Return suffix of the kw if found, and None otherwise
def find_kw(kw,kws):
    sus = None
    for k in kws:
        if k.startswith(kw):
            sus = k[len(kw):].strip()
    return sus

In [7]:
#| export

def single_dam_dist(wep, target, range=False):

    # Create dmgstr distribution
    dd = dd_from_str(wep['damage'])

    # Apply div/mult modifiers - ǸB order matters  
    if 'halve damage' in target['abilities']: dd = mult_ddist_vals(dd,1.0/2)
    if 'double damage' in wep['kws']: dd = mult_ddist_probs(dd,2)

    # Threshold to 1
    dr = find_kw('damage reduction', target['abilities'])
    if dr: dd = convolve(dd,{-int(dr):1.0})
    threshold_ddist(dd,1,True)

    # Melta - after div and mult
    melta = find_kw('melta',wep['kws'])
    if melta and range: 
        #print("MELTA",melta)
        dd = convolve(dd,dd_from_str(melta))

    # Apply FNP
    fnp = find_kw('feel no pain',target['abilities'])
    if fnp: 
        #print("FNP",fnp)
        fnp = int(fnp.strip('+'))
        dd = fnp_transform(dd,(7-fnp)/6)

    # Threshold to n_wounds
    threshold_ddist(dd,target['wounds'],False)

    #print(dd_mean(dd))

    return dd
    

In [8]:
#| exporti

def get_hit_probs(wep,target):

    # Prob to hit
    if 'torrent' in wep['kws']:
        p_hcrit, p_hit = 0, 1
    else:
        hc = find_kw('hit crit',wep['kws'])
        hit_crit = int(hc.strip('+')) if hc else 6
        
        p_hcrit = (7-hit_crit)/6.0

        hit_t = wep['bsws']

        # Stealth
        if wep['type']=='ranged' and 'stealth' in target['abilities']: 
            #print("stealth")
            hit_t +=1
        ih = find_kw('improved hits',wep['kws'])
        if ih: hit_t -= int(ih)

        p_hit = max((6*p_hcrit),min(5, # hit_crit always hits, 1 always misses
                    (7-hit_t)))/6.0
        
        if 'overwatch' in wep['kws'] and wep['type']=='ranged': # Only on an unmodified roll of 6
            p_hit = p_hcrit = 1/6.0
        elif 'indirect' in wep['kws'] and 'indirect fire' in wep['kws']: # Only on an unmodified roll of 3
            p_hit = min(p_hit,0.5)
            p_hcrit = min(p_hcrit,0.5)
        
        if 'reroll hits' in wep['kws']:
            p_hit += (1-p_hit)*p_hit
            p_hcrit += (1-p_hit)*p_hcrit
        elif 'reroll 1s to hit' in wep['kws']:
            p_hit += (1/6.0)*p_hit
            p_hcrit += (1/6.0)*p_hcrit
            

    return p_hit,p_hcrit

In [9]:
#| export
def atk_success_prob(wep, target, crit_hit=None, cover=False, verbose=False):

    # TODO:
    # Rerolls (1 and all)
    # Fish-for-sixes if better.

    # Probs to hit
    
        
    # Check if fn parameter already tells us if it was a crit or not
    # This behavior is needed for Sustained hits as they also affect hit counts
    if crit_hit: p_hcrit,p_hit = 1,1
    elif crit_hit==False: p_hcrit,p_hit = 0, 1
    else: p_hit, p_hcrit = get_hit_probs(wep, target)

    if verbose: print("Hit",p_hit,p_hcrit)

    # Prob to wound

    # Get crit wound threshold (ANTI keywords)
    wc = find_kw('wound crit',wep['kws'])
    wound_crit = int(wc.strip('+')) if wc else 6
    
    for k in wep['kws']:
        if k.startswith('anti-'):
            kw,val = k[5:].split(' ')
            if kw in target['kws']: 
                #print("Anti",kw,val)
                wound_crit = min(wound_crit,int(val.strip('+')))    

    p_wcrit = (7-wound_crit)/6.0

    if wep['strength']>target['toughness']:
        if wep['strength']>=2*target['toughness']:
            p_wound = 5
        else: p_wound = 4
    elif wep['strength']<target['toughness']:
        if 2*wep['strength']<=target['toughness']:
            p_wound = 1
        else: p_wound = 2
    else: p_wound = 3

    iw = find_kw('improved wounds',wep['kws'])
    if iw: p_wound += int(iw)
    p_wound /= 6

    p_wound = min(5/6,max(p_wcrit,p_wound))

    if 'twin-linked' in wep['kws'] or 'reroll wounds' in wep['kws']:
        #print("Twinlinked")
        p_wound += (1-p_wound)*p_wound
        p_wcrit += (1-p_wound)*p_wcrit
    elif 'reroll 1s to wound'  in wep['kws']:
        p_wound += (1/6.0)*p_wound
        p_wcrit += (1/6.0)*p_wcrit

    if verbose: print("Wound",p_wound,p_wcrit)

    # Prob to not save

    # Cover effect
    c_eff = 0 if (not cover or wep['type']!='ranged' or 
                'ignores cover' in wep['kws'] or 
                (wep['AP']==0 and target['save']<=3)) else 1

    iap = find_kw('improved ap',wep['kws'])
    if iap: c_eff += int(iap)
    
    save = min(target['invuln'] or 10,target['save']-c_eff-wep['AP'])
    p_nsave = 1.0 - max(0,min(6,7-save))/6.0

    if 'devastating wounds' in wep['kws']: # Devastating wounds
        #print("devwounds")
        p_nsave = ((p_wound-p_wcrit)*p_nsave + p_wcrit)/p_wound

    if verbose: print("Save",p_nsave)

    # Total probability
    if 'lethal hits' in wep['kws']:
        #print("lethal hits")
        p_dam = p_hcrit*p_nsave + (p_hit-p_hcrit)*p_wound*p_nsave
    else:
        p_dam = p_hit*p_wound*p_nsave

    if verbose: print("Total prob", p_dam)

    return p_dam

In [10]:
#| exporti

# Create res as weighted sum of repeated convolutions with weights given by b_dd and repeated self-convolutons of r_dd
def dd_over_dd(b_dd,r_dd,base=0,**argv):
    cur_d,res_d = {base: 1.0}, {0: b_dd.get(0,0.0)}
    for i in range(1,dd_max(b_dd)+1):
        cur_d = convolve(cur_d,r_dd,**argv)
        if i in b_dd:
            res_d = add_ddist(res_d,mult_ddist_probs(cur_d,b_dd[i]))
    return res_d

In [11]:
#| export

# Wrapper around atk_success_prob that handles sustained hits
def atk_success_dist(wep,target,cover=False,overwatch=False):
   
    # Find number of sustained hits
    sus = find_kw('sustained hits',wep['kws'])
   
    # Handle the easy case (no sustained hits)
    if not sus:
        p = atk_success_prob(wep,target,cover=cover)
        return { 1: p, 0: (1-p) }
    
    # Sustained hits:
    #print("Sustained",sus)
    sus_d = dd_from_str(sus)
    
    p_hit, p_hcrit = get_hit_probs(wep,target)
    pc = atk_success_prob(wep,target,True,cover=cover)
    pn = atk_success_prob(wep,target,False,cover=cover)

    #p = pn*(1-p_hcrit)
    normal = { 1: pn, 0: (1-pn) }
    crit = { 1: pc, 0: (1-pc) }
    crit = convolve(crit,dd_over_dd(sus_d,normal))
    
    normal = mult_ddist_probs(normal,p_hit-p_hcrit)
    crit = mult_ddist_probs(crit,p_hcrit)

    total =  add_ddist(normal, crit)
    total[0] += 1.0-p_hit

    return total


In [12]:
#| export

def successful_atk_dist(wep,target, range=False, cover=False):
    if range not in [True,False]: range = (range<=wep['range']/2)

    # Base attack number dist
    an_d = dd_from_str(wep['attacks'])

    # Rapid fire
    rfire = find_kw('rapid fire',wep['kws'])
    if rfire and range: 
        #print("Rapidfire",rfire)
        an_d = convolve(an_d,dd_from_str(rfire))

    # Other added attacks, incl Blast
    added_attacks = 0
    if 'blast' in wep['kws'] and target.get('models',0)>=5:
        #print("Blast")
        added_attacks += target['models']//5
    if added_attacks!=0:
        an_d = convolve(an_d,{added_attacks:1})

    # Attack successes dist for an individual attack
    as_d = atk_success_dist(wep,target,cover)

    # Create res as weighted sum of repeated convolutions
    res_d = dd_over_dd(an_d,as_d)

    return res_d

In [45]:
#| export

# Final end-to-end calculation for a weapon
# Range can be True (is in half distance), False (is not) or number of inches
def dam_dist(wep,target, n=1, range=False, cover=False, fulldist=False):
    if range not in [True,False]: range = (range<=wep['range']/2)

    # Successful attack dist
    sa_d = successful_atk_dist(wep,target, range, cover)

    # Single damage dist
    sd_d = single_dam_dist(wep,target,range=range)

    # Create res as weighted sum of repeated convolutions
    sar_d = dd_rep(sa_d,n)

    if not fulldist: # Return regular dist    
        res_d = dd_over_dd(sar_d,sd_d,n_wounds=target['wounds'])
    else: # Return a list of dists, one with first_n for each of 0 ... target['wounds']-1 values
        res_d = [ dd_over_dd(sar_d,sd_d,n_wounds=target['wounds'],first_n=n_f) for n_f in prange(target['wounds'],0,-1) ]
            
    return res_d

In [84]:
wep = { 'type':'ranged', 'attacks': '3', 'bsws': 3, 'strength': 5, 'AP': -2, 'damage': '2', 'kws': ['lethal hits', 'devastating wounds'] }#, 'sustained hits 2',  'rapid fire d2', 'melta d3+2'] }
target = { 'toughness': 3, 'save': 2, 'invuln': None, 'wounds': 3, 'kws':['infantry'], 'abilities': [] }#'feel no pain 5+', 'stealth'] }
dam_dist(wep,target,cover=True,range=True)

defaultdict(<function __main__.add_ddist.<locals>.<lambda>()>,
            {0: 0.421875, 2: 0.421875, 3: 0.140625, 5: 0.015625})

In [None]:
#| export

# Convolution of 'full' distributions i.e. when we have one for each mod n_wounds value
def fulldist_convolve(fd1,fd2,n_wounds):
    fres = []
    for f_n,d1 in enumerate(fd1):
        res = defaultdict(lambda: 0)
        for k1, v1 in d1.items():
            d2 = fd2[(k1+f_n)%n_wounds]
            for k2,v2 in d2.items():
                kv = k1+k2
                res[kv] += v1*v2
        fres.append(res)

    return fres


In [81]:
wep = { 'type':'ranged', 'attacks': '3', 'bsws': 3, 'strength': 5, 'AP': -2, 'damage': '2', 'kws': ['lethal hits', 'devastating wounds', 'anti-infantry 4+'] }#, 'rapid fire d2', 'melta d3+2', 'sustained hits d3'] }
target = { 'toughness': 3, 'save': 2, 'invuln': None, 'wounds': 3, 'kws':['infantry'], 'abilities': [] }#'feel no pain 5+', 'stealth'] }
dd = dam_dist(wep,target,cover=True,range=True, fulldist=True)
fulldist_convolve(dd,dd,target['wounds'])

[defaultdict(<function __main__.fulldist_convolve.<locals>.<lambda>()>,
             {0: 0.039400412058470474,
              2: 0.16885890882201637,
              3: 0.30153376575360075,
              5: 0.2871750150034293,
              6: 0.15384375803755143,
              8: 0.043955359439300415,
              9: 0.005232780885631004}),
 defaultdict(<function __main__.fulldist_convolve.<locals>.<lambda>()>,
             {0: 0.039400412058470474,
              2: 0.16885890882201637,
              4: 0.30153376575360075,
              5: 0.2871750150034293,
              7: 0.15384375803755143,
              8: 0.043955359439300415,
              10: 0.005232780885631004}),
 defaultdict(<function __main__.fulldist_convolve.<locals>.<lambda>()>,
             {0: 0.039400412058470474,
              1: 0.16885890882201637,
              3: 0.30153376575360075,
              4: 0.2871750150034293,
              6: 0.15384375803755143,
              7: 0.043955359439300415,
              

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()