In [1]:
from dataclasses import dataclass
from typing import List

import sympy as sp
import sympy.stats as sps
from sympy import Expr, Symbol

# Expectation

In [2]:
# Symbols shared by the "Expectation" cells below.
# lam is the upper bound of the Uniform(0, lam) entries built in calc_E_x_c / calc_E_G_c.
lam = Symbol('lambda', positive=True)
# eps is subtracted from every component in calc_E_x_c.
eps = Symbol('epsilon', positive=True)
# Placeholders used by calc_E_Gx_c for the entry at index c (eps_plus) and all other entries (eps_minus).
eps_plus = Symbol('epsilon_+', positive=True)
eps_minus = Symbol('epsilon_-', positive=True)
# Concrete sizes for the worked examples: d feature rows (calc_E_G_c appends one extra row of ones) and N columns.
d = 4
N = 5

In [3]:
def calc_E_x_c(c: int, d: int) -> sp.Matrix:
    """Return the expected value of the eps-shifted column vector for index ``c``.

    The vector has ``d`` entries: entry ``c`` is the constant 1 and every other
    entry is an independent ``Uniform(0, lam)`` random variable.  The notebook-level
    symbol ``eps`` is subtracted from every entry before the expectation is taken.

    Args:
        c: Index of the entry that is fixed to 1.
        d: Number of entries in the vector.

    Returns:
        A ``d x 1`` matrix holding the entry-wise expectations.
    """
    entries = []
    for row in range(d):
        if row == c:
            entries.append([1])
        else:
            entries.append([sps.Uniform(f'X_{{{row},N+1}}', 0, lam)])

    shifted = sp.Matrix(entries) - sp.Matrix([eps] * d)

    # Take the expectation entry by entry and pack it into a d x 1 column.
    expectations = [sps.E(shifted[row, 0]) for row in range(d)]
    return sp.Matrix(d, 1, expectations)

# Worked example: index c = 1 with the notebook-level d; the result is shown as the cell output.
calc_E_x_c(1, d)

Matrix([
[-epsilon + lambda/2],
[        1 - epsilon],
[-epsilon + lambda/2],
[-epsilon + lambda/2]])

In [4]:
def calc_E_G_c(c: int, d: int, N: int, show: bool = False) -> sp.Matrix:
    """Return the entry-wise expectation of ``X * X.T / N`` for index ``c``.

    ``X`` has ``d + 1`` rows and ``N`` columns: row ``c`` is all ones, every other
    one of the first ``d`` rows holds independent ``Uniform(0, lam)`` variables,
    and a final row of ones is appended.

    Args:
        c: Index of the row that is fixed to ones.
        d: Number of rows before the appended row of ones.
        N: Number of columns.
        show: When true, ``display`` the symbolic matrix ``X`` before computing.

    Returns:
        A ``(d + 1) x (d + 1)`` matrix of expectations.
    """
    feature_rows = []
    for row in range(d):
        if row == c:
            feature_rows.append([1] * N)
        else:
            feature_rows.append([sps.Uniform(f'X_{{{row},{col}}}', 0, lam) for col in range(N)])

    # Append the constant row of ones below the d feature rows.
    design = sp.Matrix(feature_rows).row_insert(d, sp.Matrix([[1] * N]))

    if show:
        display(design)

    gram = design * design.T / N

    size = d + 1
    expectations = [sps.E(gram[row, col]) for row in range(size) for col in range(size)]
    return sp.Matrix(size, size, expectations)

# Worked example: index c = 1; show=True also displays the symbolic matrix X before the expectation.
calc_E_G_c(1, d, N, True)

Matrix([
[X_{0,0}, X_{0,1}, X_{0,2}, X_{0,3}, X_{0,4}],
[      1,       1,       1,       1,       1],
[X_{2,0}, X_{2,1}, X_{2,2}, X_{2,3}, X_{2,4}],
[X_{3,0}, X_{3,1}, X_{3,2}, X_{3,3}, X_{3,4}],
[      1,       1,       1,       1,       1]])

Matrix([
[lambda**2/3, lambda/2, lambda**2/4, lambda**2/4, lambda/2],
[   lambda/2,        1,    lambda/2,    lambda/2,        1],
[lambda**2/4, lambda/2, lambda**2/3, lambda**2/4, lambda/2],
[lambda**2/4, lambda/2, lambda**2/4, lambda**2/3, lambda/2],
[   lambda/2,        1,    lambda/2,    lambda/2,        1]])

In [5]:
def calc_E_Gx_c(c: int, d: int, N: int, i: int) -> sp.Matrix:
    """Return the outer product of row ``i`` of ``calc_E_G_c(c, d, N)`` with a placeholder vector.

    The placeholder vector has ``d`` entries: ``eps_plus`` at index ``c`` and
    ``eps_minus`` everywhere else (both are notebook-level symbols).

    Args:
        c: Index passed to ``calc_E_G_c`` and the position of ``eps_plus``.
        d: Length of the placeholder vector.
        N: Number of columns passed to ``calc_E_G_c``.
        i: Row of the expected matrix to use.

    Returns:
        A ``(d + 1) x d`` matrix.
    """
    expected_gram = calc_E_G_c(c, d, N)
    placeholder_entries = [eps_plus if j == c else eps_minus for j in range(d)]
    placeholder_column = sp.Matrix([[entry] for entry in placeholder_entries])

    gram_column = expected_gram[i, :].T
    return gram_column * placeholder_column.T # type: ignore

# Worked examples for index c = 1, using row 0 of the expected matrix.
E_Gx_c = calc_E_Gx_c(1, d, N, 0)
display(E_Gx_c)

# Same, using row 1 (the row equal to c).
E_Gx_c = calc_E_Gx_c(1, d, N, 1)
display(E_Gx_c)

# Same, using row d (the appended row of ones).
E_Gx_c = calc_E_Gx_c(1, d, N, d)
display(E_Gx_c)

Matrix([
[epsilon_-*lambda**2/3, epsilon_+*lambda**2/3, epsilon_-*lambda**2/3, epsilon_-*lambda**2/3],
[   epsilon_-*lambda/2,    epsilon_+*lambda/2,    epsilon_-*lambda/2,    epsilon_-*lambda/2],
[epsilon_-*lambda**2/4, epsilon_+*lambda**2/4, epsilon_-*lambda**2/4, epsilon_-*lambda**2/4],
[epsilon_-*lambda**2/4, epsilon_+*lambda**2/4, epsilon_-*lambda**2/4, epsilon_-*lambda**2/4],
[   epsilon_-*lambda/2,    epsilon_+*lambda/2,    epsilon_-*lambda/2,    epsilon_-*lambda/2]])

Matrix([
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[         epsilon_-,          epsilon_+,          epsilon_-,          epsilon_-],
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[         epsilon_-,          epsilon_+,          epsilon_-,          epsilon_-]])

Matrix([
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[         epsilon_-,          epsilon_+,          epsilon_-,          epsilon_-],
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[epsilon_-*lambda/2, epsilon_+*lambda/2, epsilon_-*lambda/2, epsilon_-*lambda/2],
[         epsilon_-,          epsilon_+,          epsilon_-,          epsilon_-]])

In [6]:
def calc_E_Gx(d: int, N: int, i: int) -> sp.Matrix:
    """Sum ``calc_E_Gx_c(c, d, N, i)`` over every index ``c`` in ``range(d)``.

    Args:
        d: Number of indices to sum over (also forwarded to ``calc_E_Gx_c``).
        N: Number of columns forwarded to ``calc_E_Gx_c``.
        i: Row index forwarded to ``calc_E_Gx_c``.

    Returns:
        A ``(d + 1) x d`` matrix holding the sum.
    """
    per_index_terms = (calc_E_Gx_c(index, d, N, i) for index in range(d))
    # Start from an explicit zero matrix so the result has the right shape even when d == 0.
    return sum(per_index_terms, sp.zeros(d + 1, d))

# Summed matrix using row 0 of each per-index expected matrix.
E_Gx = calc_E_Gx(d, N, 0)
display(E_Gx)

# Summed matrix using row d (the appended row of ones).
E_Gx = calc_E_Gx(d, N, d)
display(E_Gx)

Matrix([
[                                epsilon_+ + epsilon_-*lambda**2,      epsilon_+*lambda**2/3 + 2*epsilon_-*lambda**2/3 + epsilon_-,      epsilon_+*lambda**2/3 + 2*epsilon_-*lambda**2/3 + epsilon_-,      epsilon_+*lambda**2/3 + 2*epsilon_-*lambda**2/3 + epsilon_-],
[epsilon_+*lambda/2 + epsilon_-*lambda**2/2 + epsilon_-*lambda/2,  epsilon_+*lambda/2 + epsilon_-*lambda**2/2 + epsilon_-*lambda/2, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda],
[epsilon_+*lambda/2 + epsilon_-*lambda**2/2 + epsilon_-*lambda/2, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda,  epsilon_+*lambda/2 + epsilon_-*lambda**2/2 + epsilon_-*lambda/2, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda],
[epsilon_+*lambda/2 + epsilon_-*lambda**2/2 + epsilon_-*lambda/2, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda, epsilon_+*lambda**2/4 + epsilon_-*lambda**2/4 + epsilon_-*lambda

Matrix([
[                 epsilon_+ + 3*epsilon_-*lambda/2, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-],
[epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-,                  epsilon_+ + 3*epsilon_-*lambda/2, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-],
[epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-,                  epsilon_+ + 3*epsilon_-*lambda/2, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-],
[epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-, epsilon_+*lambda/2 + epsilon_-*lambda + epsilon_-,                  epsilon_+ + 3*epsilon_-*lambda/2],
[                          epsilon_+ + 3*epsilon_-,                           epsilon_+ + 3*epsilon_-,                           epsilon_+ + 3*epsilon_-,              

# Optimization

In [7]:
@dataclass(frozen=True)
class Const:
    """Symbolic constants and derived expressions used by the "Optimization" cells.

    All attributes are evaluated once, at class-definition time, in the order they
    appear; later expressions refer to the class-body names defined above them.
    The annotated attributes are dataclass fields whose defaults are those
    expressions, so overriding e.g. ``lam`` in the constructor does NOT recompute
    the derived ``r*``, ``s*`` or ``eps*`` values.
    """

    # Base symbols (all declared positive).
    lam: Symbol = Symbol('lambda', positive=True)
    eps: Symbol = Symbol('epsilon', positive=True)
    d: Symbol = Symbol('d', positive=True)
    dd: Symbol = Symbol("d'", positive=True)
    b_d_plus_1: Symbol = Symbol('b_{d+1}', positive=True)

    # Closed forms of the two distinct entries returned by calc_E_x_c:
    # 1 - eps for the fixed entry and lam/2 - eps for the uniform entries.
    eps_plus: Expr = 1 - eps # type: ignore
    eps_minus: Expr = lam / 2 - eps # type: ignore

    # r1..r7: combinations of eps_plus / eps_minus with symbolic d.
    # NOTE(review): these look like the distinct entries of calc_E_Gx generalized
    # from d = 4 to symbolic d (r1 and r2 match the displayed first row) — confirm the rest.
    r1: Expr = eps_plus + lam**2 / 3 * (d - 1) * eps_minus # type: ignore
    r2: Expr = eps_minus + lam**2 / 3 * eps_plus + lam**2 / 3 * (d - 2) * eps_minus # type: ignore
    r3: Expr = eps_plus + lam / 2 * (d - 1) * eps_minus # type: ignore
    r4: Expr = eps_minus + lam / 2 * eps_plus + lam / 2 * (d - 2) * eps_minus # type: ignore
    r5: Expr = lam / 2 * eps_plus + lam / 2 * eps_minus + lam**2 / 4 * (d - 2) * eps_minus # type: ignore
    r6: Expr = lam / 2 * eps_minus + lam / 2 * eps_minus + lam**2 / 4 * eps_plus + lam**2 / 4 * (d - 3) * eps_minus # type: ignore
    r7: Expr = eps_plus + (d - 1) * eps_minus # type: ignore

    # s1..s8: linear combinations of the r-terms weighted by b_{d+1} and d'.
    s1: Expr = b_d_plus_1 * r7 + r3 + (dd - 1) * r4 # type: ignore
    s2: Expr = b_d_plus_1 * r7 + dd * r4 # type: ignore
    s3: Expr = b_d_plus_1 * r3 + r1 + (dd - 1) * r5 # type: ignore
    s4: Expr = b_d_plus_1 * r3 + dd * r5 # type: ignore
    s5: Expr = b_d_plus_1 * r4 + r2 + r5 + (dd - 2) * r6 # type: ignore
    s6: Expr = b_d_plus_1 * r4 + r2 + (dd - 1) * r6 # type: ignore
    s7: Expr = b_d_plus_1 * r4 + r5 + (dd - 1) * r6 # type: ignore
    s8: Expr = b_d_plus_1 * r4 + dd * r6 # type: ignore

    # eps_k: the first solution SymPy returns for r_k = 0 in eps, in factored form.
    eps1: Expr = sp.solve(r1, eps)[0].factor()
    eps2: Expr = sp.solve(r2, eps)[0].factor()
    eps3: Expr = sp.solve(r3, eps)[0].factor()
    eps4: Expr = sp.solve(r4, eps)[0].factor()
    eps5: Expr = sp.solve(r5, eps)[0].factor()
    eps6: Expr = sp.solve(r6, eps)[0].factor()
    eps7: Expr = sp.solve(r7, eps)[0].factor()

    # Root in eps of s5 with b_{d+1} = 1 and d' = d.
    # Unlike the attributes above this one has no annotation, so it is a plain
    # class attribute rather than a dataclass field (it is still readable as c.eps_s5).
    eps_s5 = sp.solve(s5.subs(b_d_plus_1, 1).subs(dd, d), eps)[0].factor()

    def calc_score(self, booleans: List[int]) -> Expr:
        """Return the factored weighted sum of s1..s8 with each term switched by a flag.

        Args:
            booleans: Exactly eight multipliers, one per term s1..s8 in order
                (the notebook passes 0 or 1 to drop or keep a term).

        Returns:
            The sum of ``booleans[k] * weight_k * s_{k+1}``, factored, where the
            weights are the polynomials in ``d`` and ``d'`` written out below.

        Raises:
            AssertionError: If ``booleans`` does not have length 8.
        """
        assert len(booleans) == 8
        # Terms weighted by d' or (d - d').
        score = booleans[0] * self.dd * self.s1 # type: ignore
        score += booleans[1] * (self.d - self.dd) * self.s2 # type: ignore
        score += booleans[2] * self.dd * self.s3 # type: ignore
        score += booleans[3] * (self.d - self.dd) * self.s4 # type: ignore
        # Terms weighted by products of two such counts.
        score += booleans[4] * self.dd * (self.dd - 1) * self.s5 # type: ignore
        score += booleans[5] * self.dd * (self.d - self.dd) * self.s6 # type: ignore
        score += booleans[6] * self.dd * (self.d - self.dd) * self.s7 # type: ignore
        score += booleans[7] * (self.d - self.dd) * (self.d - self.dd - 1) * self.s8 # type: ignore
        return score.factor()

# Single shared instance; every cell below reads its symbols and expressions through `c`.
c = Const()

## $\epsilon_1, \ldots, \epsilon_7$

In [8]:
# Show eps1..eps7 (each preceded by its attribute name), then eps_s5.
for i in range(1, 8):
    name = f'eps{i}'
    display(name, getattr(c, name))
display(c.eps_s5)

'eps1'

(d*lambda**3 - lambda**3 + 6)/(2*(d*lambda**2 - lambda**2 + 3))

'eps2'

lambda*(d*lambda**2 - 2*lambda**2 + 2*lambda + 3)/(2*(d*lambda**2 - lambda**2 + 3))

'eps3'

(d*lambda**2 - lambda**2 + 4)/(2*(d*lambda - lambda + 2))

'eps4'

lambda*(d*lambda - 2*lambda + 4)/(2*(d*lambda - lambda + 2))

'eps5'

(d*lambda**2 - 2*lambda**2 + 2*lambda + 4)/(2*(d*lambda - 2*lambda + 4))

'eps6'

lambda*(d*lambda - 3*lambda + 6)/(2*(d*lambda - 2*lambda + 4))

'eps7'

(d*lambda - lambda + 2)/(2*d)

lambda*(3*d**2*lambda**2 - 8*d*lambda**2 + 24*d*lambda + 4*lambda**2 - 34*lambda + 48)/(2*(3*d**2*lambda**2 - 5*d*lambda**2 + 18*d*lambda + 2*lambda**2 - 18*lambda + 24))

In [9]:
# Partial-fraction decomposition of eps1 with respect to d.
c.eps1.apart(c.d)

lambda/2 - 3*(lambda - 2)/(2*(d*lambda**2 - lambda**2 + 3))

In [10]:
# Claimed ordering chain; for each adjacent pair show the label "hi > lo" followed by
# the factored difference hi - lo so its sign can be read off.
eps_chain = ['eps1', 'eps3', 'eps5', 'eps7', 'eps_s5', 'eps4', 'eps6', 'eps2']
for upper_name, lower_name in zip(eps_chain, eps_chain[1:]):
    display(f'{upper_name} > {lower_name}', (getattr(c, upper_name) - getattr(c, lower_name)).factor())

'eps1 > eps3'

lambda*(d - 1)*(lambda - 2)*(2*lambda - 3)/(2*(d*lambda - lambda + 2)*(d*lambda**2 - lambda**2 + 3))

'eps3 > eps5'

(lambda - 2)**2/((d*lambda - 2*lambda + 4)*(d*lambda - lambda + 2))

'eps5 > eps7'

(d - 2)*(lambda - 2)**2/(2*d*(d*lambda - 2*lambda + 4))

'eps7 > eps_s5'

(lambda - 2)*(3*d*lambda**2 - 6*d*lambda - 2*lambda**2 + 18*lambda - 24)/(2*d*(3*d**2*lambda**2 - 5*d*lambda**2 + 18*d*lambda + 2*lambda**2 - 18*lambda + 24))

'eps_s5 > eps4'

-lambda**2*(lambda - 2)/((d*lambda - lambda + 2)*(3*d**2*lambda**2 - 5*d*lambda**2 + 18*d*lambda + 2*lambda**2 - 18*lambda + 24))

'eps4 > eps6'

lambda*(lambda - 2)**2/(2*(d*lambda - 2*lambda + 4)*(d*lambda - lambda + 2))

'eps6 > eps2'

-lambda*(lambda - 3)*(lambda - 2)*(lambda - 1)/(2*(d*lambda - 2*lambda + 4)*(d*lambda**2 - lambda**2 + 3))

## Weak Adversarial Setting (Case 1) ($0 \leq \epsilon \leq \epsilon_6$)

In [11]:
# Sum r4 + r2, shown grouped by eps and lam.
t = c.r4 + c.r2 # type: ignore
display(t.factor().collect(c.eps).collect(c.lam))
# Evaluate the same sum at eps = eps6 (the upper end of the range in the heading).
t = t.subs(c.eps, c.eps6)
display(t.factor())

-epsilon*(lambda**2*(4*d - 4) + lambda*(6*d - 6) + 24)/12 + lambda**3*(d/6 - 1/3) + lambda**2*(d/4 - 1/6) + 3*lambda/2

lambda**2*(lambda - 2)*(2*lambda - 5)/(12*(d*lambda - 2*lambda + 4))

## Weak Adversarial Setting (Case 2) ($\epsilon_6 \leq \epsilon \leq \epsilon_4$)

$s_5(d', 1)$ is nonnegative.

In [12]:
# s5 with b_{d+1} = 1 and d' = d, evaluated at eps = eps4 and factored.
c.s5.subs(c.b_d_plus_1, 1).subs(c.dd, c.d).subs(c.eps, c.eps4).factor() # type: ignore

-lambda**2*(lambda - 2)/(12*(d*lambda - lambda + 2))

$s_7(d', 1)$ is nonnegative.

In [13]:
# s7 with b_{d+1} = 1 and d' = d, evaluated at eps = eps4 and factored.
c.s7.subs(c.b_d_plus_1, 1).subs(c.dd, c.d).subs(c.eps, c.eps4).factor() # type: ignore

-lambda*(lambda - 2)**3/(8*(d*lambda - lambda + 2))

$s_8(d', 1)$ is larger than $s_6(d', 1)$.

In [14]:
# Difference s8 - s6 with b_{d+1} = 1, shown grouped by eps and lam.
t = (c.s8 - c.s6).subs(c.b_d_plus_1, 1) # type: ignore
display(t.factor().collect(c.eps).collect(c.lam))
# Evaluate the difference at eps = eps6.
t = t.subs(c.eps, c.eps6)
display(t.factor())

epsilon*(lambda**2*(2*d + 4) - 24*lambda + 24)/24 + lambda**3*(-d/24 - 1/24) + 5*lambda**2/12 - lambda/2

-lambda*(lambda - 3)*(lambda - 2)*(lambda - 1)/(6*(d*lambda - 2*lambda + 4))

Case: $s_6(d', 1) \geq 0$ and $s_8(d', 1) \geq 0$

In [15]:
# Derivative with respect to d' of the score with all eight terms switched on.
t = c.calc_score([1]*8).diff(c.dd).factor()
display(t)
# Take the eps-dependent factor of the numerator and flip its sign.
# NOTE(review): .args[1] relies on SymPy's internal ordering of the product's factors —
# confirm it still selects the intended factor if the SymPy version changes.
t = - sp.numer(t).args[1] # type: ignore
display(t)
# Evaluate that factor at eps = eps4.
display(t.subs(c.eps, c.eps4).factor())

-(2*d*epsilon - d*lambda + lambda - 2)*(3*d**2*lambda**2 - 5*d*lambda**2 + 18*d*lambda + 2*lambda**2 - 18*lambda + 24)/24

-2*d*epsilon + d*lambda - lambda + 2

(lambda - 2)**2/(d*lambda - lambda + 2)

Case: $s_6(d', 1) \leq 0$ and $s_8(d', 1) \geq 0$

In [16]:
# Score with the s6 term switched off and b_{d+1} = 1, differentiated with respect to d'
# and grouped by eps, d' and d for display.
t = c.calc_score([1, 1, 1, 1, 1, 0, 1, 1]).subs(c.b_d_plus_1, 1).diff(c.dd).factor().collect(c.eps).collect(c.dd).collect(c.d) # type: ignore
display(t)
# Evaluate the derivative at eps = eps4.
t = t.subs(c.eps, c.eps4).factor().collect(c.dd)
display(t)
# Pull the d'-dependent terms out of the second numerator factor.
# NOTE(review): .args[1].args[4:6] depends on SymPy's internal term ordering —
# confirm the slice still matches the d' terms shown in the output above.
t_coeff_dd = sum(sp.numer(t).args[1].args[4:6]) # type: ignore
display(t_coeff_dd)
# Both the d'-dependent part and the full derivative at d' = d.
display(t_coeff_dd.subs(c.dd, c.d).factor())
display(t.subs(c.dd, c.d).factor())

d**3*lambda**3/8 + d**2*(-3*lambda**3/8 + 3*lambda**2/4) + d*(lambda**3/4 - lambda**2 + lambda) + d'**2*(3*d*lambda**3/8 - 9*lambda**3/8 + 9*lambda**2/4) + d'*(-d**2*lambda**3/4 + d*(5*lambda**3/6 - lambda**2) + lambda**3/12 - 11*lambda**2/6 + 3*lambda) - epsilon*(6*d**3*lambda**2 + d**2*(-12*lambda**2 + 24*lambda) + d'**2*(18*d*lambda**2 - 36*lambda**2 + 72*lambda) + d'*(-12*d**2*lambda**2 + d*(28*lambda**2 - 24*lambda) + 8*lambda**2 - 72*lambda + 96))/24 - lambda**3/12 + 11*lambda**2/12 - 5*lambda/2 + 2

(lambda - 2)*(3*d**2*lambda**3 - 6*d**2*lambda**2 - 8*d*lambda**3 + 42*d*lambda**2 - 48*d*lambda + d'**2*(-9*lambda**3 + 18*lambda**2) + d'*(6*d*lambda**3 - 12*d*lambda**2 + 6*lambda**3 - 28*lambda**2 + 24*lambda) + 2*lambda**3 - 22*lambda**2 + 60*lambda - 48)/(24*(d*lambda - lambda + 2))

d'**2*(-9*lambda**3 + 18*lambda**2) + d'*(6*d*lambda**3 - 12*d*lambda**2 + 6*lambda**3 - 28*lambda**2 + 24*lambda)

-d*lambda*(3*d*lambda**2 - 6*d*lambda - 6*lambda**2 + 28*lambda - 24)

-(lambda - 2)*(d*lambda**3 - 7*d*lambda**2 + 12*d*lambda - lambda**3 + 11*lambda**2 - 30*lambda + 24)/(12*(d*lambda - lambda + 2))

Case: $s_6(d', 1) \leq 0$ and $s_8(d', 1) \leq 0$

In [17]:
# Score with the s6 and s8 terms switched off and b_{d+1} = 1, differentiated with respect to d'
# and grouped by eps, d' and d for display.
t = c.calc_score([1, 1, 1, 1, 1, 0, 1, 0]).subs(c.b_d_plus_1, 1).diff(c.dd).factor().collect(c.eps).collect(c.dd).collect(c.d) # type: ignore
display(t)
# Evaluate the derivative at eps = eps4.
t = t.subs(c.eps, c.eps4).factor().collect(c.dd)
display(t)
# Derivative at d' = d.
display(t.subs(c.dd, c.d).factor())

d**2*(lambda**3/8 + lambda**2/2) + d*(-lambda**3/8 - 3*lambda**2/2 + 3*lambda) + d'*(d**2*lambda**3/4 + d*(-11*lambda**3/12 + 3*lambda**2/2) + 5*lambda**3/6 - 7*lambda**2/3 + lambda) - epsilon*(d**2*(6*lambda**2 + 24*lambda) + d*(-12*lambda**2 - 12*lambda + 48) + d'*(12*d**2*lambda**2 + d*(-32*lambda**2 + 48*lambda) + 32*lambda**2 - 96*lambda + 48) + 12*lambda - 24)/24 - lambda**3/12 + 17*lambda**2/12 - 7*lambda/2 + 2

(lambda - 2)*(6*d**2*lambda**3 - 12*d**2*lambda**2 - 11*d*lambda**3 + 48*d*lambda**2 - 48*d*lambda + d'*(-6*d*lambda**3 + 12*d*lambda**2 + 12*lambda**3 - 40*lambda**2 + 24*lambda) + 2*lambda**3 - 22*lambda**2 + 60*lambda - 48)/(24*(d*lambda - lambda + 2))

(lambda - 2)*(d*lambda**3 + 8*d*lambda**2 - 24*d*lambda + 2*lambda**3 - 22*lambda**2 + 60*lambda - 48)/(24*(d*lambda - lambda + 2))

## Adversarial Setting ($\epsilon = \epsilon_7$)

$s_1(d', 1)$ is nonnegative.

In [None]:
# s1 = b_{d+1}*r7 + r3 + (d' - 1)*r4 and eps7 is a root of r7, so the b_{d+1} term
# vanishes at eps = eps7 and the result does not depend on b_{d+1}.  Substitute 1 so
# the cell matches its "s_1(d', 1)" heading and the other (d', 1) cells.
c.s1.subs(c.b_d_plus_1, 1).subs(c.eps, c.eps7).factor() # type: ignore

(d - d')*(lambda - 2)**2/(4*d)

Optimal $b_{d+1}$

In [19]:
# Coefficient of b_{d+1} in the score built from the s3, s4, s5 and s7 terms only.
t = c.calc_score([0, 0, 1, 1, 1, 0, 1, 0]).expand().coeff(c.b_d_plus_1)
# Evaluate that coefficient at eps = eps7.
t = t.subs(c.eps, c.eps7).factor()
display(t)

(d - 1)*(d - d')*(lambda - 2)**2/(4*d)

$s_7(d', 1)$ is larger than $s_5(d', 1)$.

In [None]:
# Difference s7 - s5 with b_{d+1} = 1, evaluated at eps = eps7 and factored.
(c.s7 - c.s5).subs(c.b_d_plus_1, 1).subs(c.eps, c.eps7).factor() # type: ignore

-(lambda - 2)*(lambda**2 - 6*lambda + 6)/(12*d)

Case: $s_5(d', 1), s_7(d', 1) \geq 0$

In [21]:
# Score with the s2, s6 and s8 terms switched off and b_{d+1} = 1, differentiated with
# respect to d' and grouped by eps, d' and d for display.
t = c.calc_score([1, 0, 1, 1, 1, 0, 1, 0]).subs(c.b_d_plus_1, 1).diff(c.dd).factor().collect(c.eps).collect(c.dd).collect(c.d) # type: ignore
display(t)
# Evaluate the derivative at eps = eps7.
t = t.subs(c.eps, c.eps7).factor().collect(c.dd)
display(t)
# Value of d' at which s5 (with b_{d+1} = 1, eps = eps7) vanishes.
# Kept in its own name: binding it to `d` would overwrite the integer d = 4 that the
# "Expectation" cells pass to range(), breaking them on any later re-run.
dd_at_s5_zero = sp.solve(c.s5.subs(c.b_d_plus_1, 1).subs(c.eps, c.eps7), c.dd)[0]
display(dd_at_s5_zero.factor())
# Derivative evaluated at that d'.
display(t.subs(c.dd, dd_at_s5_zero).factor().collect(c.d))

d**2*(lambda**3/8 + lambda**2/4) + d*(-lambda**3/8 - lambda**2 + 5*lambda/2) + d'*(d**2*lambda**3/4 + d*(-11*lambda**3/12 + 2*lambda**2) + 5*lambda**3/6 - 10*lambda**2/3 + 3*lambda) - epsilon*(d**2*(6*lambda**2 + 12*lambda) + d*(48 - 12*lambda**2) + d'*(12*d**2*lambda**2 + d*(-32*lambda**2 + 72*lambda) + 32*lambda**2 - 120*lambda + 96) + 12*lambda - 24)/24 - lambda**3/12 + 17*lambda**2/12 - 4*lambda + 3

(lambda - 2)*(3*d**2*lambda**2 - 6*d**2*lambda - 4*d*lambda**2 + 12*d*lambda - 6*d + d'*(-6*d*lambda**2 + 12*d*lambda + 8*lambda**2 - 30*lambda + 24) + 3*lambda - 6)/(12*d)

(3*d*lambda**2 - 6*d*lambda + 2*lambda**2 - 18*lambda + 24)/(6*lambda*(lambda - 2))

-(d*(6*lambda**4 - 57*lambda**3 + 144*lambda**2 - 108*lambda) - 8*lambda**4 + 93*lambda**3 - 354*lambda**2 + 540*lambda - 288)/(36*d*lambda)

Case: $s_5(d', 1) \leq 0, s_7(d', 1) \geq 0$

In [22]:
# Score with the s2, s5, s6 and s8 terms switched off and b_{d+1} = 1, differentiated with
# respect to d' and grouped by eps, d' and d for display.
t = c.calc_score([1, 0, 1, 1, 0, 0, 1, 0]).subs(c.b_d_plus_1, 1).diff(c.dd).factor().collect(c.eps).collect(c.dd).collect(c.d) # type: ignore
display(t)
# Evaluate the derivative at eps = eps7.
t = t.subs(c.eps, c.eps7).factor().collect(c.dd)
display(t)
# Value of d' at which s7 (with b_{d+1} = 1, eps = eps7) vanishes.
# Kept in its own name: binding it to `d` would overwrite the integer d = 4 that the
# "Expectation" cells pass to range(), breaking them on any later re-run.
dd_at_s7_zero = sp.solve(c.s7.subs(c.b_d_plus_1, 1).subs(c.eps, c.eps7), c.dd)[0]
display(dd_at_s7_zero.factor())
# Derivative evaluated at that d'.
display(t.subs(c.dd, dd_at_s7_zero).factor().collect(c.d))

d**2*(lambda**3/8 + lambda**2/4) + d*(-lambda**3/12 - 3*lambda**2/4 + 5*lambda/2) + d'**2*(-3*d*lambda**3/8 + 9*lambda**3/8 - 9*lambda**2/4) + d'*(d**2*lambda**3/4 + d*(-3*lambda**3/4 + 3*lambda**2/2) - lambda**3/4 + lambda**2 - lambda) - epsilon*(d**2*(6*lambda**2 + 12*lambda) + d*(-10*lambda**2 + 12*lambda + 48) + d'**2*(-18*d*lambda**2 + 36*lambda**2 - 72*lambda) + d'*(12*d**2*lambda**2 + d*(-24*lambda**2 + 48*lambda)) + 4*lambda**2 - 24*lambda + 24)/24 + lambda**3/12 - 2*lambda + 3

(lambda - 2)*(6*d**2*lambda**2 - 12*d**2*lambda - 5*d*lambda**2 + 18*d*lambda - 12*d + d'**2*(18*lambda**2 - 36*lambda) + d'*(-18*d*lambda**2 + 36*d*lambda) + 2*lambda**2 - 12*lambda + 12)/(24*d)

(d*lambda - 2)/(2*lambda)

(lambda - 2)*(d**2*(3*lambda**3 - 6*lambda**2) + d*(-10*lambda**3 + 36*lambda**2 - 24*lambda) + 4*lambda**3 - 24*lambda**2 + 60*lambda - 72)/(48*d*lambda)

Case: $s_5(d', 1), s_7(d', 1) \leq 0$

In [23]:
# Score built from the s1, s3 and s4 terms only, with b_{d+1} = 1.
score_s1_s3_s4 = c.calc_score([1, 0, 1, 1, 0, 0, 0, 0]).subs(c.b_d_plus_1, 1)

# Derivative with respect to d', grouped by eps, then d', then d for display.
t = score_s1_s3_s4.diff(c.dd).factor()
for group_symbol in (c.eps, c.dd, c.d):
    t = t.collect(group_symbol)
display(t)

# Derivative at eps = eps7.
t = t.subs(c.eps, c.eps7).factor().collect(c.dd)
display(t)

# Derivative at d' = d - 1.
t = t.subs(c.dd, c.d - 1).factor().collect(c.lam) # type: ignore
display(t)

# Score difference between d' = d and d' = d - 1 at eps = eps7.
score_at_eps7 = score_s1_s3_s4.subs(c.eps, c.eps7)
t1 = score_at_eps7.subs(c.dd, c.d)
t2 = score_at_eps7.subs(c.dd, c.d - 1) # type: ignore
t = t1 - t2 # type: ignore
display(t.factor())

d**2*lambda**3/8 + d*(-5*lambda**3/24 + lambda**2/4 + lambda) + d'*(d*lambda**2/2 - lambda**2 + 2*lambda) - epsilon*(6*d**2*lambda**2 + d*(-10*lambda**2 + 24*lambda + 24) + d'*(24*d*lambda - 24*lambda + 48) + 4*lambda**2 - 24*lambda + 24)/24 + lambda**3/12 - 2*lambda + 3

(lambda - 2)*(3*d**2*lambda**2 - 6*d**2*lambda - 5*d*lambda**2 + 24*d*lambda - 24*d + d'*(24 - 12*lambda) + 2*lambda**2 - 12*lambda + 12)/(24*d)

(lambda - 2)*(lambda**2*(3*d**2 - 5*d + 2) + lambda*(-6*d**2 + 12*d) - 12)/(24*d)

lambda*(d - 1)*(lambda - 2)*(3*d*lambda - 6*d - 2*lambda + 6)/(24*d)