In [5]:
import matplotlib.pyplot as plt
import matplotlib
from typing import Tuple
from functools import cache
import itertools
import scipy.special
import scipy.stats
import numpy as np

%matplotlib inline
matplotlib.rcParams['font.size'] = 20
matplotlib.rcParams['figure.figsize'] = (16,9)
matplotlib.rcParams['savefig.bbox'] = 'tight'

# A fast implementation of artifact statistics
This code aims to demonstrate the ideas and maths shown in the following paper: [https://www.overleaf.com/read/nvxmkdpqjprj](https://www.overleaf.com/read/nvxmkdpqjprj)

I offer two contributions:

1. Very fast implementation of GenshinOptimizer's existing Roll Probability Calculator. frzyc noted that it's somewhat straining the front-end, so I've made a cache-based implementation that runs in microseconds per query.
2. A method that, given an optimization target and a current build, finds the artifact in the inventory that has the greatest chance of improving your build.

In [6]:
@cache
def nk3(n, k):
    return sum([scipy.special.binom(n, i) * scipy.special.binom(n, k-2*i) for i in range(k//2+1)])

@cache
def mu(a, j):
    return sum([nk3(j, i-7*j) for i in range(a, 10*j+1)]) / (4**j)

@cache
def multinom(j1, j2=None, j3=None, j4=None, N=5):
    # The standard multinomial distribution. Below logic handles variable arguments
    if j2 is None:
        rem = N-j1
        rv = scipy.stats.multinomial(N, [1/4, 3/4])
        return rv.pmf([j1, rem])
    
    if j3 is None:
        rem = N-j1-j2
        rv = scipy.stats.multinomial(N, [1/4, 1/4, 2/4])
        return rv.pmf([j1, j2, rem])
    
    if j4 is None:
        rem = N-j1-j2-j3
        rv = scipy.stats.multinomial(N, [1/4, 1/4, 1/4, 1/4])
        return rv.pmf([j1, j2, j3, rem])
    
    if j1+j2+j3+j4 != N:
        return 0
    rv = scipy.stats.multinomial(N, [1/4, 1/4, 1/4, 1/4])
    return rv.pmf([j1, j2, j3, j4])

## Type 1 Query
The type 1 query behaves as:
$$ P(A_1\geq a_1 \land A_2\geq a_2\land\cdots)$$

This is already implemented in frzyc's GenshinOptimizer, but issues regarding its computational load have been brought up. The following method I present is quite fast; with appropriate caching it reaches about 200µs per loop. (around .2s for 1000 queries)

----

Can be further optimized:
- iterate with nested loops i=0..N; j=0..(N-i); k=0..(N-i-j) rather than the cartesian products
- check for mu=0 conditions, whenever $a_i> 10(j_i+b)$. Skip computation of those terms.
- wrap the cached functions in a preprocessor to reduce required cache size
  - mu=1 whenever $a_i \leq 7(j_i+b)$
  - sort inputs to `multinom`, discard any zeros in the input

In [7]:
# P(A1 >= a1 AND A2 >= a2)
def prob1(a1, a2=None, a3=None, a4=None, N=5, b=1):
    if a2 is None:
        return sum([multinom(j1, N=N) * mu(a1, j1+b) for j1 in range(N+1)])
    
    if a3 is None:
        ret = 0
        for j1, j2 in filter(lambda x: sum(x)<=N, itertools.product(range(N+1), repeat=2)):
            ret += multinom(j1, j2, N=N) * mu(a1, j1+b) * mu(a2, j2+b)
        return ret
    
    if a4 is None:
        ret = 0
        for j1, j2, j3 in filter(lambda x: sum(x)<=N, itertools.product(range(N+1), repeat=3)):
            ret += multinom(j1, j2, j3, N=N) * mu(a1, j1+b) * mu(a2, j2+b) * mu(a3, j3+b)
        return ret
    
    ret = 0
    for j1, j2, j3 in filter(lambda x: sum(x)<=N, itertools.product(range(N+1), repeat=3)):
        j4 = N-j1-j2-j3
        ret += multinom(j1, j2, j3, j4, N=N) * mu(a1, j1+b) * mu(a2, j2+b) * mu(a3, j3+b) * mu(a4, j4+b)
    return ret


In [8]:
%timeit prob1(*np.random.randint(20, size=4))

180 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Type 1 Query on existing artifact
Existing artifacts reduce to a special case of the previous query. Since they're practically the same query, the runtime is essentially the same, and is faster if `rolls_left` is smaller.

In [9]:
def p1_existing(requirements, existing, rolls_left):
    transformed_requirements = []
    for ai, az in zip(requirements, existing):
        transformed_requirements.append(ai - az)
    
    return prob1(*transformed_requirements, N=rolls_left, b=0)

## Type 2 Query
The type 2 query behaves as:
$$ P(k_1A_1 + k_2A_2 + \cdots\geq a^*)$$

There are many reasons why this kind of query is interesting. For example, suppose I want my artifact to roll 100 EM or 15% crit rate, or anywhere in between. So some acceptable results would be:
- 100 EM, 0% cr
- 90 EM, 1.5% cr
- 80 EM, 3% cr
- 50 EM, 7.5% cr
- 20 EM, 12% cr
- 0 EM, 25% cr

Converting to the language of substat rolls, we're asking for 42.9 IVs in EM or 38.6 IVs in CR. Writing with linear coefficients, we can roughly phrase the problem as:
$$ \begin{array}{ccc}k_1 = 0.9 & k_2 = 1 & a^* = 38.6 \end{array} $$
$$ P(0.9A_1 + A_2 \geq 38.6) $$


The problem becomes rather nontrivial even for the combination of only two artifacts, especially when they have different weights.

In [325]:
# Returns list of (v, p(v)) which represent exact distribution of P(k1A1 + ... = v)
def exact(ks, js):
    if len(ks) == 0:
        return np.array([[0, 1]])
    
    k, j = ks[-1], js[-1]
    pr1 = np.array([(k*i, nk3(j, i-7*j) / 4**j) for i in range(7*j, 10*j+1)])
    pr2 = exact(ks[:-1], js[:-1])
    
    res = {}
    for v1, p1 in pr1:
        for v2, p2 in pr2:
            res[v1 + v2] = res.get(v1+v2, 0) + p1*p2
    return np.array([(k, v) for k, v in res.items()])

# Returns the probability the sum total value of k1A1 + ... exceeds thr.
# We can also guarantee some % error threshold, at the cost of computation time. Set err_thresh=0 to solve exactly.
def pnj(thr, kjs, err_thresh=None):
    # increase all j's by b    
    ks = kjs[0]
    js = (kjs[1]).astype(int)

    # Check for simple situations
    if thr <= 7 * np.sum(ks * js):
        return 1
    if thr > 10 * np.sum(ks*js):
        return 0
        
    # Gaussian estimate
    mu = 8.5 * np.sum(ks * js)
    var = 5/4 * np.sum(ks*ks*js)        
    p_est = scipy.special.erfc((thr - mu) / np.sqrt(2 * var)) / 2

    # error estimate of Gaussian
    err = np.amax(ks) / (2 * np.sqrt(2 * var))
    max_err = err / p_est
    if err_thresh is None or max_err <= err_thresh:
        return p_est
        
    # Partially exact formulation
    ktarg = err_thresh * p_est * (2 * np.sqrt(2 * var))
    approx_select = ks < ktarg
    ext = exact(ks[~approx_select], js[~approx_select])
    
    # Construct partial approximate distribution
    ks, js = ks[approx_select], js[approx_select]
    mu = 8.5 * np.sum(ks * js)
    var = 5/4 * np.sum(ks*ks*js)
    if var > 0:
        p_est = lambda x: scipy.special.erfc((x - mu) / np.sqrt(2 * var)) / 2
    else:
        p_est = lambda x: 1 if x < 0 else 0
    
    # convolve exact distribution with approximate distribution
    ptot = 0
    for v, p in ext:
        ptot += p * p_est(thr - v)
    return ptot

def prob2(thresh, ks = (1), N=5, b=1, err=None):
    reps = 0
    ks = np.array(ks)
    
    ptot = 0
    for js in filter(lambda x: sum(x) <= N, itertools.product(range(N+1), repeat=len(ks))):
        js2 = np.array(js) + b
        if thresh <= 7*np.sum(ks * js2):
            ptot += multinom(*js, N=N)
            continue
        elif thresh > 10*np.sum(ks * js2):
            continue
            
        tmp = pnj(thresh, (ks, js2), err_thresh=err)
        ptot += multinom(*js, N=N) * tmp
    return ptot

A note on the `err` paramter. I typically leave it unconstrained because my approximation is constructed very well. I've also found a way to mathematically guarantee certain error bounds, which naturally makes the computation time suffer.

For example, we could let `err=0.1` to guarantee that the returned result is within 10% of the true probability, or even `err=0` to force the program to find the exact probability.

In [345]:
print(f'Approximate P(100EM or 15CR): {prob2(38.6, [.9, 1])}')
print(f'Exact       P(100EM or 15CR): {prob2(38.6, [.9, 1], err=0)}')

Approximate P(100EM or 15CR): 0.4258741547179381
Exact       P(100EM or 15CR): 0.4154639244079591


In [346]:
%timeit prob2(38.6, [.9, 1])

601 µs ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [347]:
%timeit prob2(38.6, [.9, 1], err=0)

2.82 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Type 2 Query on existing artifact
Same as with Type 1 queries, we just need to shift the threshold.

In [337]:
def p2_existing(thresh, ks, existing, rolls_left):
    new_th = thresh
    for ki, vi in zip(ks, existing):
        new_th -= ki*vi
    
    return prob2(new_th, ks, N=rolls_left, b=0)

## Connection to the Damage Formula
The following couple blocks set up my damage evaluation schema, where I've manually filled in some standard(ish) values.

All the percent- values are multiplied by 1000 for some reason.

In [699]:
import artifact2
import char2
import damage2
from util.common import statnames, statmap

In [755]:
import importlib
importlib.reload(damage2)

<module 'damage2' from '/Users/albertxu/PycharmProjects/playground/Genshin/damage2.py'>

In [756]:
# Stats of lvl 90 Diona w/ protype crescent
diona_stats = np.zeros(len(statnames))
diona_stats[statmap['HP']] = 9570
diona_stats[statmap['BaseATK']] = 722
diona_stats[statmap['ATK%']] = 413
diona_stats[statmap['DEF']] = 601
diona_stats[statmap['CR']] = 50
diona_stats[statmap['CD']] = 500
diona_stats[statmap['ER']] = 1000
diona_stats[statmap['Cryo']] = 240

In [757]:
# The artifacts I have equipped
flower = np.zeros(len(statnames))
flower[statmap['HP']] = 4780
flower[statmap['ATK']] = 33
flower[statmap['DEF']] = 16
flower[statmap['EM']] = 82
flower[statmap['CD']] = 140

feather = np.zeros(len(statnames))
feather[statmap['ATK']] = 311
feather[statmap['ER']] = 162
feather[statmap['DEF']] = 37
feather[statmap['CR']] = 31
feather[statmap['CD']] = 192

sands = np.zeros(len(statnames))
sands[statmap['EM']] = 187
sands[statmap['HP']] = 538
sands[statmap['DEF']] = 39
sands[statmap['CR']] = 124
sands[statmap['CD']] = 78

cup = np.zeros(len(statnames))
cup[statmap['Cryo']] = 466
cup[statmap['HP%']] = 47
cup[statmap['ATK%']] = 111
cup[statmap['ER']] = 220
cup[statmap['EM']] = 16

hat = np.zeros(len(statnames))
hat[statmap['CR']] = 311
hat[statmap['HP']] = 299
hat[statmap['CD']] = 155
hat[statmap['ATK%']] = 105
hat[statmap['EM']] = 56

In [938]:
dmg = damage2.NormalDmg(2.23) * damage2.CritMult() * damage2.VapeMelt(1.5)
stats = diona_stats + flower + feather + sands + cup + hat
print(f'Current damage per charged shot: {dmg.eval(stats)[0]}')

Current damage per charged shot: 12169.443822372452


### Queries for replacing an artifact
Let's say I have a couple candidate artifacts to upgrade. All of them are EM sands, and they are all at lvl 0.

In [943]:
def to_stat(subs):
    st = np.zeros(len(statnames))
    for k, v in subs.items():
        st[statmap[k]] = v * artifact2.sub_vals[k] / 10
    return st

# Want to beat this number
sands0 = np.zeros(len(statnames))
sands0[statmap['EM']] = 187

stats = diona_stats + flower + feather + sands0 + cup + hat
subs_orig = {'HP': 18, 'DEF': 17, 'CR': 32, 'CD': 10}
dmg0 = dmg.eval(stats + to_stat(subs_orig))[0]

def eval_potential(new_subs, rolls=0, err=None):
    rolls_left = 5 - rolls
    v, g, h = dmg.eval(stats + to_stat(new_subs))
    ix = [statmap[k] for k in sub1]
    scale = np.array([artifact2.sub_vals[k]/10 for k in sub1])

    gnorm = g[ix] * scal
    hnorm = h[ix][:,ix] * np.outer(scal, scal)
    
    if len(ix) == 3:
        # force the thing to be 4x4.
        gnorm = np.append(gnorm, 0)
        hnorm = np.pad(hnorm, [0,1])
        rolls_left = 4

    h_lin1 = np.sum(hz) + np.sum(np.diag(hz)) + 2 * (np.sum(hz, axis=0) + np.diag(hz))
    h_lin = 50 / 7 * np.linalg.pinv(np.eye(4) + 1) @ h_lin1
    
    weights = gz + h_lin/2
    weights[weights < 0] = 0
        
    return prob2(dmg0 - v, weights, N=rolls_left, b=0, err=err)

In [944]:
# Pretend all of these are unupgraded EM main-stat sands.
sub1 = {'CD': 7, 'HP': 9, 'ATK%': 8, 'ATK': 7}
sub2 = {'CD': 8, 'CR': 9, 'HP': 10, 'ATK': 7}
sub3 = {'CD': 10, 'CR': 10, 'ATK%': 9}
sub4 = {'ATK': 8, 'DEF': 10, 'ATK%': 9}

In [959]:
%timeit eval_potential(sub4, err=0)

20.6 ms ± 647 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [960]:
%timeit eval_potential(sub4, err=None)

4.2 ms ± 23.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [958]:
print(eval_potential(sub4, err=0))
print(eval_potential(sub4, err=None))

0.27045917510986334
0.26915888130853294


In [936]:
dmg.eval(stats + to_stat(subs_orig))[0]

12172.241709086176

In [916]:
subs_orig

{'HP': 18, 'DEF': 17, 'CR': 13, 'CD': 10}

In [921]:
to_stat(subs_orig)[statmap['CD']]

77.7

In [915]:
to_stat(subs_orig2)

array([  0.   , 539.55 ,   0.   ,  39.355,   0.   ,   0.   ,  50.57 ,
        77.7  ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
         0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,   0.   ,
         0.   ])

In [922]:
subs_orig2 = {'HP': 18, 'DEF': 17, 'CR': 13, 'DEF%': 10}
dmg.eval(stats + to_stat(subs_orig2))[0]

11284.160960709472

In [932]:
subs_orig = {'HP': 18, 'DEF': 17, 'CR': 27, 'DEF%': 10}
eval_potential(subs_orig, rolls=4)

150.76644583209054 [25.21828783  0.         30.95882571 14.23092477]


0.5

In [905]:
1/.0625

16.0

In [874]:
print(eval_potential(sub1))
print(eval_potential(sub2))
print(eval_potential(sub3))
print(eval_potential(sub4))

0.490494105706669
0.6056074687668723
0.9155252194726436
0.2872047984442604


In [865]:
prob2(new_th, ks, N=rolls_left, b=0)

TypeError: p2_existing() missing 4 required positional arguments: 'thresh', 'ks', 'existing', and 'rolls_left'

In [860]:
hz2

array([[0.        , 0.        , 0.06623656],
       [0.        , 0.        , 0.        ],
       [0.06623656, 0.        , 0.        ]])

In [839]:
v, g, h = dmg.eval(stats + to_stat(sub1))
ix = [statmap[k] for k in sub1]
scal = np.array([artifact2.sub_vals[k]/10 for k in sub1])

gz = g[ix] * scal
hz = h[ix][:,ix] * np.outer(scal, scal)

rhs = np.sum(hz) + np.sum(np.diag(hz)) + 2 * (np.sum(hz, axis=0) + np.diag(hz))
h_adjust = 50 / 7 * np.linalg.pinv(np.eye(4) + 1) @ rhs
gz + h_adjust

array([25.77167591, -0.27669404, 31.29359697, 14.31119457])

In [845]:
gz

array([24.66489975,  0.        , 30.62405446, 14.15065497])

In [844]:
delta = to_stat({'CD': 7.5, 'HP': 2.2, 'ATK%': 4.3})
print(dmg.eval(stats + to_stat(sub1) + delta)[0])
print(v + g@delta + delta@h@delta/2 )
print(v + (gz+h_adjust)@[7.5,2.2,4.3,0] )

11722.466872756158
11722.466872756158
11730.901870857402


In [835]:
gz

array([246.64899746,   0.        , 306.24054457, 141.50654965])

In [832]:
gz

array([246.64899746,   0.        , 306.24054457, 141.50654965])

In [828]:
np.sum(hz) + np.sum(np.diag(hz)) + 2 * (np.sum(hz, axis=0) + np.diag(hz))

array([38.7371656 , 19.3685828 , 32.61589455, 25.48985385])

In [813]:
dmg0 - v

768.5811476073068

In [816]:
h[ix][:,ix]

array([[0.        , 0.        , 0.0014622 , 0.00202521],
       [0.        , 0.        , 0.        , 0.        ],
       [0.0014622 , 0.        , 0.        , 0.        ],
       [0.00202521, 0.        , 0.        , 0.        ]])

In [None]:
h[ix][:,ix]

In [814]:
np.linalg.pinv(np.eye(4) + 1)

array([[ 0.8, -0.2, -0.2, -0.2],
       [-0.2,  0.8, -0.2, -0.2],
       [-0.2, -0.2,  0.8, -0.2],
       [-0.2, -0.2, -0.2,  0.8]])

In [809]:
g[ix] * scal

array([246.64899746,   0.        , 306.24054457, 141.50654965])

In [807]:
g[ix] * scal

array([262.90534328, 323.39378847,   0.        , 145.50356492])

In [787]:
artifact2.sub_vals

{'HP': 299.75,
 'ATK': 19.45,
 'DEF': 23.15,
 'HP%': 58.3,
 'ATK%': 58.3,
 'DEF%': 72.9,
 'EM': 23.31,
 'ER': 6.48,
 'CR': 38.9,
 'CD': 77.7}

In [788]:
oneatk = np.zeros(len(statnames))
oneatk[statmap['ATK%']] = 58.3
onecd = np.zeros(len(statnames))
onecd[statmap['CD']] = 77.7


In [791]:
print(dmg.eval(stats + to_stat(sub1) + oneatk)[0])
print(dmg.eval(stats + to_stat(sub1) + onecd)[0])

11709.901106047048
11650.309558935687


In [793]:
zz = stats + to_stat(sub1)

In [795]:
zz[statmap['ATK%']]

675.64

In [799]:
zz[statmap['ATK']] + zz[statmap['BaseATK']] * (1 + zz[statmap['ATK%']]/1000)

1567.42708

In [800]:
zz[statmap['BaseATK']] * .0583

42.0926

In [796]:
zz[statmap['CD']]

1041.39

In [786]:
h[ix][:,ix] * np.outer(scal, scal)

array([[0.        , 0.        , 6.62365588, 3.06063552],
       [0.        , 0.        , 0.        , 0.        ],
       [6.62365588, 0.        , 0.        , 0.        ],
       [3.06063552, 0.        , 0.        , 0.        ]])

In [None]:
dz = hat_subs
print(dmg.eval(stats + dz)[0])
print(v + g @ dz + dz @ h @ dz/2)

12169.443822372452
12169.228615425942
