In [1]:
from tqdm import tqdm

In [2]:
class Mat:
    def __init__(self, nrows, ncols):
        self.nrows: int = nrows
        self.ncols: int = ncols

    def __str__(self):
        return f"Mat({self.nrows}, {self.ncols})"
    
    def __repr__(self):
        return f"Mat({self.nrows}, {self.ncols})"
    
    def __mul__(self, other):
        return Mat(self.nrows, other.ncols)
    
    def __rmul__(self, other):
        return Mat(other.nrows, self.ncols)

In [3]:
def set3(a, b, c, d):
    return Mat(a, b), Mat(b, c), Mat(c, d)

def set4(a, b, c, d, e):
    return Mat(a, b), Mat(b, c), Mat(c, d), Mat(d, e)

In [4]:
def mul_cost(X: Mat, Y: Mat):
    return X.nrows * X.ncols * Y.ncols

In [5]:
def optimal_cost(X: Mat, Y: Mat, Z: Mat, W: Mat):
    xy_z__w: int = mul_cost(X, Y) + mul_cost(X*Y, Z) + mul_cost((X*Y)*Z, W) # cost(A=XY) + cost(B=AZ) + cost(BW)
    x_yz__w: int = mul_cost(Y, Z) + mul_cost(X, Y*Z) + mul_cost(X*(Y*Z), W) # cost(A=YZ) + cost(B=XA) + cost(BW)
    x__yz_w: int = mul_cost(Y, Z) + mul_cost(Y*Z, W) + mul_cost(X, (Y*Z)*W) # cost(A=YZ) + cost(B=AW) + cost(XB)
    x__y_zw: int = mul_cost(Z, W) + mul_cost(Y, Z*W) + mul_cost(X, Y*(Z*W)) # cost(A=ZW) + cost(B=YA) + cost(XB)
    xy__zw: int  = mul_cost(X, Y) + mul_cost(Z, W) + mul_cost(X*Y, Z*W) # cost(A=XY) + cost(B=ZW) + cost(AB)

    return {
        "xy_z__w": xy_z__w,
        "x_yz__w": x_yz__w,
        "x__yz_w": x__yz_w,
        "x__y_zw": x__y_zw,
        "xy__zw": xy__zw
    }

# {'xy_z__w': 282000, 'x_yz__w': 32000, 'x__yz_w': 28200, 'x__y_zw': 55500, 'xy__zw': 393000}
optimal_cost(Mat(100, 10), Mat(10, 150), Mat(150, 8), Mat(8, 15))

{'xy_z__w': 282000,
 'x_yz__w': 32000,
 'x__yz_w': 28200,
 'x__y_zw': 55500,
 'xy__zw': 393000}

In [6]:
optimal_cost(Mat(90, 10), Mat(10, 90), Mat(90, 10), Mat(10, 10))

{'xy_z__w': 171000,
 'x_yz__w': 27000,
 'x__yz_w': 19000,
 'x__y_zw': 27000,
 'xy__zw': 171000}

In [7]:
opt = optimal_cost(*set4(90, 25, 100, 25, 10))
opt

{'xy_z__w': 472500,
 'x_yz__w': 141250,
 'x__yz_w': 91250,
 'x__y_zw': 72500,
 'xy__zw': 340000}

In [8]:
min(opt, key=opt.get)

'x__y_zw'

In [9]:
sorted_opt = sorted(opt.items(), key=lambda x: x[1])
-(sorted_opt[0][1] - sorted_opt[1][1])

18750

In [19]:
def greedy_cost(a, b, c, d, e): # greedy cost of ((XY)Z)W
    # layer 1 for (XY)Z, this layer chooses between (XY)Z and X(YZ)
    xy_z = (a * b * c) + (a * c * d)
    x_yz = (b * c * d) + (a * b * d)

    sol = []
    if xy_z <= x_yz: # (XY)Z < X(YZ), A = XY with dims a x c
        # options are: (AZ)W and A(ZW)
        az_w = (a * c * d) + (a * d * e)
        a_zw = (c * d * e) + (a * c * e)
        
        if az_w < a_zw:
            sol.append("xy_z__w")
        elif a_zw < az_w:
            sol.append("xy__zw")
        else:
            sol.append("xy_z__w")
            sol.append("xy__zw")
    elif x_yz < xy_z: # X(YZ) < (XY)Z, B = YZ with dims b x d
        # options are: (XB)W and X(BW)
        xb_w = (a * b * d) + (a * d * e)
        x_bw = (b * d * e) + (a * b * e)

        if xb_w < x_bw:
            sol.append("x_yz__w")
        elif x_bw < xb_w:
            sol.append("x__yz_w")
        else:
            sol.append("x_yz__w")
            sol.append("x__yz_w")

    return sol

In [11]:
options = []
for a in tqdm(range(90, 201)):
    for b in tqdm(range(10, 16)):
        for c in tqdm(range(90, 201)):
            for d in tqdm(range(10, 16)):
                for e in tqdm(range(10, 16)):
                    X = Mat(a, b)
                    Y = Mat(b, c)
                    Z = Mat(c, d)
                    W = Mat(d, e)
                    
                    opt_cost = optimal_cost(X, Y, Z, W)
                    min_cost = min(opt_cost, key=opt_cost.get)

                    if min_cost == 'x__y_zw':
                        sorted_opt_cost = sorted(opt_cost.items(), key=lambda x: x[1])
                        top_diff = -(sorted_opt_cost[0][1] - sorted_opt_cost[1][1])
                        options.append((top_diff, (a, b, c, d, e)))

sorted_options = sorted(options, key=lambda x: -x[0])
print(sorted_options[:10])

  0%|          | 0/111 [00:00<?, ?it/s]
[A

[A[A



100%|██████████| 6/6 [00:00<00:00, 69711.42it/s]




100%|██████████| 6/6 [00:00<00:00, 42224.54it/s]




100%|██████████| 6/6 [00:00<00:00, 59353.36it/s]




100%|██████████| 6/6 [00:00<00:00, 41053.55it/s]




100%|██████████| 6/6 [00:00<00:00, 29468.18it/s]




100%|██████████| 6/6 [00:00<00:00, 56552.41it/s]
100%|██████████| 6/6 [00:00<00:00, 581.84it/s]


[A[A



100%|██████████| 6/6 [00:00<00:00, 76959.71it/s]




100%|██████████| 6/6 [00:00<00:00, 47934.90it/s]




100%|██████████| 6/6 [00:00<00:00, 58798.65it/s]




100%|██████████| 6/6 [00:00<00:00, 66576.25it/s]




100%|██████████| 6/6 [00:00<00:00, 40787.40it/s]




100%|██████████| 6/6 [00:00<00:00, 71493.82it/s]
100%|██████████| 6/6 [00:00<00:00, 530.54it/s]


[A[A



100%|██████████| 6/6 [00:00<00:00, 42871.93it/s]




100%|██████████| 6/6 [00:00<00:00, 93553.25it/s]




100%|██████████| 6/6 [00:00<00:00, 75800.67it/s]




100%|██████████| 6/6 [00:00<00:00, 77912

KeyboardInterrupt: 