In [1]:
class Mat:
    """Shape-only stand-in for a matrix.

    Only the dimensions are stored; there is no element data. Multiplying two
    instances yields a new ``Mat`` carrying the shape of the product, which is
    all that is needed to reason about multiplication cost.

    Attributes:
        nrows: number of rows.
        ncols: number of columns.
    """

    def __init__(self, nrows, ncols):
        self.nrows: int = nrows
        self.ncols: int = ncols

    def __repr__(self):
        return f"Mat({self.nrows}, {self.ncols})"

    def __str__(self):
        return f"Mat({self.nrows}, {self.ncols})"

    def __mul__(self, other):
        """Return the shape of ``self * other`` (``self`` on the left).

        Returns ``NotImplemented`` when ``other`` exposes no ``ncols``, so
        Python can try the reflected operation or raise ``TypeError``.

        Raises:
            ValueError: if ``other`` exposes ``nrows`` and it differs from
                ``self.ncols`` (the product is undefined).
        """
        if not hasattr(other, "ncols"):
            return NotImplemented
        # Operands that only expose `ncols` are still accepted; the inner
        # dimension is checked whenever the right operand reports one.
        right_rows = getattr(other, "nrows", self.ncols)
        if right_rows != self.ncols:
            raise ValueError(
                f"cannot multiply {self!r} by {other!r}: "
                f"inner dimensions {self.ncols} and {right_rows} differ"
            )
        return Mat(self.nrows, other.ncols)

    def __rmul__(self, other):
        """Return the shape of ``other * self`` (``self`` on the right).

        Python calls this only when the left operand could not handle the
        multiplication itself. Returns ``NotImplemented`` when ``other``
        exposes no ``nrows``.

        Raises:
            ValueError: if ``other`` exposes ``ncols`` and it differs from
                ``self.nrows`` (the product is undefined).
        """
        if not hasattr(other, "nrows"):
            return NotImplemented
        left_cols = getattr(other, "ncols", self.nrows)
        if left_cols != self.nrows:
            raise ValueError(
                f"cannot multiply {other!r} by {self!r}: "
                f"inner dimensions {left_cols} and {self.nrows} differ"
            )
        return Mat(other.nrows, self.ncols)

In [2]:
def set3(a, b, c, d):
    """Build a three-matrix chain with shapes a x b, b x c and c x d.

    Returns the matrices as a tuple, in chain order.
    """
    dims = (a, b, c, d)
    # Each matrix takes its shape from a pair of neighbouring dimensions.
    return tuple(Mat(rows, cols) for rows, cols in zip(dims, dims[1:]))

def set4(a, b, c, d, e):
    """Build a four-matrix chain with shapes a x b, b x c, c x d and d x e.

    Returns the matrices as a tuple, in chain order.
    """
    dims = (a, b, c, d, e)
    # Each matrix takes its shape from a pair of neighbouring dimensions.
    return tuple(Mat(rows, cols) for rows, cols in zip(dims, dims[1:]))

In [3]:
def mul_cost(X: Mat, Y: Mat):
    """Cost of the single product X * Y.

    Computed as rows(X) * cols(X) * cols(Y). Y's row count is not consulted,
    so the caller is responsible for passing shapes that can be multiplied.
    """
    out_rows, inner, out_cols = X.nrows, X.ncols, Y.ncols
    return out_rows * inner * out_cols

In [4]:
def optimal_cost(X: Mat, Y: Mat, Z: Mat, W: Mat):
    """Total cost of each way to parenthesize the chain product X * Y * Z * W.

    Returns a dict mapping a label to the summed ``mul_cost`` of its three
    multiplications. It reports every option; it does not pick the minimum.

        "xy_z__w"  ->  ((XY)Z)W
        "x_yz__w"  ->  (X(YZ))W
        "x__yz_w"  ->  X((YZ)W)
        "x__y_zw"  ->  X(Y(ZW))
        "xy__zw"   ->  (XY)(ZW)
    """
    # Pairwise products of neighbours, shared by several orderings.
    XY, YZ, ZW = X * Y, Y * Z, Z * W
    cost_xy = mul_cost(X, Y)
    cost_yz = mul_cost(Y, Z)
    cost_zw = mul_cost(Z, W)

    # Each entry: cost of the first pair + the second product + the final product.
    return {
        "xy_z__w": cost_xy + mul_cost(XY, Z) + mul_cost(XY * Z, W),
        "x_yz__w": cost_yz + mul_cost(X, YZ) + mul_cost(X * YZ, W),
        "x__yz_w": cost_yz + mul_cost(YZ, W) + mul_cost(X, YZ * W),
        "x__y_zw": cost_zw + mul_cost(Y, ZW) + mul_cost(X, Y * ZW),
        "xy__zw": cost_xy + cost_zw + mul_cost(XY, ZW),
    }

# Worked example for the chain 100x10, 10x150, 150x8, 8x15. Expected result:
# {'xy_z__w': 282000, 'x_yz__w': 32000, 'x__yz_w': 28200, 'x__y_zw': 55500, 'xy__zw': 393000}
optimal_cost(*set4(100, 10, 150, 8, 15))

{'xy_z__w': 282000,
 'x_yz__w': 32000,
 'x__yz_w': 28200,
 'x__y_zw': 55500,
 'xy__zw': 393000}

In [5]:
# Costs for the chain 90x10, 10x90, 90x10, 10x10.
optimal_cost(*set4(90, 10, 90, 10, 10))

{'xy_z__w': 171000,
 'x_yz__w': 27000,
 'x__yz_w': 19000,
 'x__y_zw': 27000,
 'xy__zw': 171000}

In [6]:
# Cost of every parenthesization for the chain 90x25, 25x100, 100x25, 25x10.
opt = optimal_cost(*set4(90, 25, 100, 25, 10))
opt  # bare last expression so the notebook displays the dict

{'xy_z__w': 472500,
 'x_yz__w': 141250,
 'x__yz_w': 91250,
 'x__y_zw': 72500,
 'xy__zw': 340000}

In [7]:
# Cost of every parenthesization for the chain 200x175, 175x250, 250x150, 150x10.
# This rebinds `opt`; cells that read `opt` afterwards see this result.
opt = optimal_cost(*set4(200, 175, 250, 150, 10))
opt  # bare last expression so the notebook displays the dict

{'xy_z__w': 16550000,
 'x_yz__w': 12112500,
 'x__yz_w': 7175000,
 'x__y_zw': 1162500,
 'xy__zw': 9625000}

In [8]:
# Label of the cheapest ordering in `opt`: iterating a dict yields its keys,
# and `opt.get` ranks each key by its stored cost.
min(opt, key=opt.get)

'x__y_zw'

In [9]:
# Rank the (label, cost) pairs from cheapest to most expensive.
sorted_opt = sorted(opt.items(), key=lambda pair: pair[1])
# Gap between the runner-up and the cheapest ordering.
sorted_opt[1][1] - sorted_opt[0][1]

6012500

In [10]:
def greedy_cost(a, b, c, d, e):
    """Greedy, two-stage choice of multiplication order for X * Y * Z * W.

    The chain has shapes X: a x b, Y: b x c, Z: c x d, W: d x e. Stage one
    settles how the first three factors are grouped; stage two then decides
    where W is multiplied in, given that grouping. Only four orderings can be
    produced this way; the label "x__y_zw" is never returned.

    Returns a list of ordering labels: one label normally, two (in a fixed
    order) when the stage-two costs tie.
    """

    def cheaper(first, second):
        # Each argument is a (label, cost) pair; a tie keeps both labels,
        # in the order given.
        (first_label, first_cost), (second_label, second_cost) = first, second
        if first_cost < second_cost:
            return [first_label]
        if second_cost < first_cost:
            return [second_label]
        return [first_label, second_label]

    # Stage 1: cost of forming the triple product as (XY)Z versus X(YZ).
    left_first = (a * b * c) + (a * c * d)
    right_first = (b * c * d) + (a * b * d)

    if left_first <= right_first:
        # A stage-one tie goes to (XY)Z. With A = XY (a x c), choose between
        # (AZ)W and A(ZW).
        return cheaper(
            ("xy_z__w", (a * c * d) + (a * d * e)),
            ("xy__zw", (c * d * e) + (a * c * e)),
        )
    if right_first < left_first:
        # With B = YZ (b x d), choose between (XB)W and X(BW).
        return cheaper(
            ("x_yz__w", (a * b * d) + (a * d * e)),
            ("x__yz_w", (b * d * e) + (a * b * e)),
        )
    # Reached only when the two stage-one costs cannot be ordered.
    return []

In [15]:
# Scan a small grid of chain dimensions for cases where X(Y(ZW)) is strictly
# the cheapest ordering, and print the dimensions together with the ratio of
# the second-cheapest cost to the cheapest one.
options = []
for a in range(0, 30, 10):
    for b in range(0, 30, 10):
        for c in range(0, 30, 10):
            for d in range(0, 30, 10):
                for e in range(0, 30, 10):
                    X, Y, Z, W = set4(a, b, c, d, e)

                    opt_cost = optimal_cost(X, Y, Z, W)
                    min_cost = min(opt_cost, key=opt_cost.get)
                    if min_cost != 'x__y_zw':
                        continue

                    sorted_opt_cost = sorted(opt_cost.items(), key=lambda item: item[1])
                    cheapest = sorted_opt_cost[0][1]
                    runner_up = sorted_opt_cost[1][1]
                    # A zero cost would make the ratio meaningless (or divide by zero).
                    if cheapest == 0 or runner_up == 0:
                        continue

                    top_diff = runner_up / cheapest
                    if top_diff > 1:
                        print(a, b, c, d, e, top_diff)

20 10 10 20 10 1.2
20 20 20 20 10 1.3333333333333333
