In [2]:
class Mat:
    """Shape-only stand-in for a matrix: records dimensions, holds no entries.

    ``Mat(r, k) * Mat(k, c)`` yields ``Mat(r, c)``, which is all the cost
    helpers need to walk through a chained product.  Multiplying shapes whose
    inner dimensions disagree raises ``ValueError`` instead of silently
    producing a meaningless shape.
    """

    def __init__(self, nrows, ncols):
        self.nrows: int = nrows  # number of rows
        self.ncols: int = ncols  # number of columns

    def __repr__(self):
        return f"Mat({self.nrows}, {self.ncols})"

    # str() and repr() render identically; alias rather than duplicate.
    __str__ = __repr__

    @staticmethod
    def _check_compatible(left, right):
        """Raise ValueError unless ``left.ncols`` equals ``right.nrows``."""
        if left.ncols != right.nrows:
            raise ValueError(
                "incompatible shapes for multiplication: "
                f"({left.nrows}, {left.ncols}) x ({right.nrows}, {right.ncols})"
            )

    def __mul__(self, other):
        """Shape of ``self * other``.

        Raises:
            ValueError: if ``self.ncols != other.nrows``.
        """
        Mat._check_compatible(self, other)
        return Mat(self.nrows, other.ncols)

    def __rmul__(self, other):
        """Shape of ``other * self``.

        Python calls this only when the left operand's own ``__mul__`` is
        missing or returns NotImplemented.

        Raises:
            ValueError: if ``other.ncols != self.nrows``.
        """
        Mat._check_compatible(other, self)
        return Mat(other.nrows, self.ncols)

In [3]:
def _mat_chain(*dims):
    """Return Mat(d0, d1), Mat(d1, d2), ... for consecutive ``dims`` as a tuple."""
    return tuple(Mat(rows, cols) for rows, cols in zip(dims, dims[1:]))


def set3(a, b, c, d):
    """Three chain-compatible shapes: a x b, b x c, c x d."""
    return _mat_chain(a, b, c, d)


def set4(a, b, c, d, e):
    """Four chain-compatible shapes: a x b, b x c, c x d, d x e."""
    return _mat_chain(a, b, c, d, e)

In [4]:
def mul_cost(X: Mat, Y: Mat):
    """Scalar multiplications needed for the naive product of X and Y.

    Computed as X.nrows * X.ncols * Y.ncols (rows x inner x cols); Y.nrows
    is not read, so X.ncols is used as the shared inner dimension.
    """
    rows, inner = X.nrows, X.ncols
    cols = Y.ncols
    return rows * inner * cols

In [5]:
def optimal_cost(X: Mat, Y: Mat, Z: Mat, W: Mat):
    """Cost of every parenthesization of the chain product X*Y*Z*W.

    Despite the name, all five candidate costs are returned; the optimum is
    ``min(result, key=result.get)``.  In each key, letters written together
    are multiplied first, ``_`` marks the next multiplication and ``__`` the
    final (outermost) one, e.g. ``"x__y_zw"`` is X(Y(ZW)).
    """
    # Shared first-level products, each reused by two groupings below.
    xy, yz, zw = X * Y, Y * Z, Z * W

    return {
        "xy_z__w": mul_cost(X, Y) + mul_cost(xy, Z) + mul_cost(xy * Z, W),  # ((XY)Z)W
        "x_yz__w": mul_cost(Y, Z) + mul_cost(X, yz) + mul_cost(X * yz, W),  # (X(YZ))W
        "x__yz_w": mul_cost(Y, Z) + mul_cost(yz, W) + mul_cost(X, yz * W),  # X((YZ)W)
        "x__y_zw": mul_cost(Z, W) + mul_cost(Y, zw) + mul_cost(X, Y * zw),  # X(Y(ZW))
        "xy__zw": mul_cost(X, Y) + mul_cost(Z, W) + mul_cost(xy, zw),       # (XY)(ZW)
    }

# X: 100x10, Y: 10x150, Z: 150x8, W: 8x15
# expected: {'xy_z__w': 282000, 'x_yz__w': 32000, 'x__yz_w': 28200, 'x__y_zw': 55500, 'xy__zw': 393000}
optimal_cost(*set4(100, 10, 150, 8, 15))

{'xy_z__w': 282000,
 'x_yz__w': 32000,
 'x__yz_w': 28200,
 'x__y_zw': 55500,
 'xy__zw': 393000}

In [6]:
# X: 90x10, Y: 10x90, Z: 90x10, W: 10x10
optimal_cost(*set4(90, 10, 90, 10, 10))

{'xy_z__w': 171000,
 'x_yz__w': 27000,
 'x__yz_w': 19000,
 'x__y_zw': 27000,
 'xy__zw': 171000}

In [7]:
# X: 90x25, Y: 25x100, Z: 100x25, W: 25x10 -- the following cells analyse `opt`.
opt = optimal_cost(*set4(90, 25, 100, 25, 10))
opt

{'xy_z__w': 472500,
 'x_yz__w': 141250,
 'x__yz_w': 91250,
 'x__y_zw': 72500,
 'xy__zw': 340000}

In [34]:
# Larger chain: X 200x175, Y 175x250, Z 250x150, W 150x10.
# Bound to its own name so it does not overwrite `opt`: the next cells
# inspect `opt`, and their recorded outputs come from the 90/25/100/25/10 chain.
opt_large = optimal_cost(*set4(200, 175, 250, 150, 10))
opt_large

{'xy_z__w': 16550000,
 'x_yz__w': 12112500,
 'x__yz_w': 7175000,
 'x__y_zw': 1162500,
 'xy__zw': 9625000}

In [8]:
# Label of the cheapest parenthesization (first one in dict order on ties).
min(opt.items(), key=lambda item: item[1])[0]

'x__y_zw'

In [9]:
sorted_opt = sorted(opt.items(), key=lambda kv: kv[1])
# Extra cost of the runner-up over the cheapest parenthesization.
sorted_opt[1][1] - sorted_opt[0][1]

18750

In [10]:
def _cheapest_labels(options):
    """Return the labels of all (label, cost) pairs with minimal cost, in input order."""
    best = min(cost for _, cost in options)
    return [label for label, cost in options if cost == best]


def greedy_cost(a, b, c, d, e):
    """Label(s) of the parenthesization a two-step greedy heuristic picks for XYZW.

    X is a x b, Y is b x c, Z is c x d and W is d x e.  Step 1 compares
    (XY)Z against X(YZ); step 2 keeps the winning first product (A = XY or
    B = YZ) and picks the cheaper way to finish the chain.  Despite the name,
    the result is a list of labels using ``optimal_cost``'s keys, not a cost.
    Ties at either step keep every tied option, so several labels may be
    returned.  X(Y(ZW)) ("x__y_zw") is never produced by this heuristic.
    """
    # Step 1: cost of (XY)Z vs X(YZ).
    xy_z = (a * b * c) + (a * c * d)
    x_yz = (b * c * d) + (a * b * d)

    sol = []
    if xy_z <= x_yz:
        # A = XY (a x c); finish with (AZ)W or A(ZW).
        sol += _cheapest_labels([
            ("xy_z__w", (a * c * d) + (a * d * e)),
            ("xy__zw", (c * d * e) + (a * c * e)),
        ])
    if x_yz <= xy_z:
        # B = YZ (b x d); finish with (XB)W or X(BW).
        # A tie at step 1 explores both branches, matching step-2 tie handling.
        sol += _cheapest_labels([
            ("x_yz__w", (a * b * d) + (a * d * e)),
            ("x__yz_w", (b * d * e) + (a * b * e)),
        ])
    return sol

In [28]:
# Search dimension chains (multiples of DIM_STEP below DIM_LIMIT) whose optimal
# parenthesization is X(Y(ZW)) and beats the runner-up by more than
# RATIO_THRESHOLD times.  Dims start at DIM_STEP: with a zero dimension
# X(Y(ZW)) either loses or costs 0, so such chains can never qualify.
DIM_STEP = 10
DIM_LIMIT = 200
RATIO_THRESHOLD = 2
TARGET = "x__y_zw"

options = []  # (a, b, c, d, e, ratio) for every qualifying chain
descending = range(DIM_LIMIT, 0, -DIM_STEP)     # 200, 190, ..., 10
ascending = range(DIM_STEP, DIM_LIMIT, DIM_STEP)  # 10, 20, ..., 190

for a in descending:
    for b in ascending:
        for c in descending:
            for d in ascending:
                for e in ascending:
                    X = Mat(a, b)
                    Y = Mat(b, c)
                    Z = Mat(c, d)
                    W = Mat(d, e)

                    opt_cost = optimal_cost(X, Y, Z, W)
                    best_key = min(opt_cost, key=opt_cost.get)
                    if best_key != TARGET:
                        continue

                    # All dims >= DIM_STEP, so every cost is positive.
                    best, runner_up = sorted(opt_cost.values())[:2]
                    ratio = runner_up / best
                    if ratio > RATIO_THRESHOLD:
                        options.append((a, b, c, d, e, ratio))
                        print(a, b, c, d, e, ratio)

200 30 200 80 10 2.0142857142857142
200 30 200 90 10 2.09
200 30 200 100 10 2.15625
200 30 200 110 10 2.2147058823529413
200 30 200 120 10 2.2666666666666666
200 30 200 130 10 2.3131578947368423
200 30 200 140 10 2.355
200 30 200 150 10 2.392857142857143
200 30 200 160 10 2.4272727272727272
200 30 200 170 10 2.458695652173913
200 30 200 180 10 2.4875
200 30 200 190 10 2.514
200 30 190 80 10 2.007434944237918
200 30 190 90 10 2.0833333333333335
200 30 190 100 10 2.1498371335504887
200 30 190 110 10 2.208588957055215
200 30 190 120 10 2.260869565217391
200 30 190 130 10 2.3076923076923075
200 30 190 140 10 2.349869451697128
200 30 190 150 10 2.388059701492537
200 30 190 160 10 2.4228028503562946
200 30 190 170 10 2.4545454545454546
200 30 190 180 10 2.4836601307189543
200 30 190 190 10 2.510460251046025
200 30 180 90 10 2.0760869565217392
200 30 180 100 10 2.142857142857143
200 30 180 110 10 2.201923076923077
200 30 180 120 10 2.2545454545454544
200 30 180 130 10 2.3017241379310347
200 3