-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathMatrixChainMultiplication.java
57 lines (47 loc) · 1.42 KB
/
MatrixChainMultiplication.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/**
 * Matrix-chain multiplication: finds the cheapest parenthesization (fewest
 * scalar multiplications) for multiplying a sequence of matrices.
 *
 * <p>Matrix {@code M[t]} has dimension {@code dims[t-1] x dims[t]}, so an
 * array of {@code n + 1} dimensions describes a chain of {@code n} matrices.
 */
class MatrixChainMultiplication
{
    /**
     * Returns the minimum number of scalar multiplications needed to compute
     * the product {@code M[i+1] x M[i+2] x ... x M[j]}.
     *
     * <p>The answer is built bottom-up over sub-chains of increasing width, so
     * every sub-chain is solved exactly once (cubic time and quadratic space in
     * {@code j - i}) instead of being re-solved recursively.
     *
     * @param dims dimension array; matrix {@code M[t]} is {@code dims[t-1] x dims[t]}
     * @param i    index into {@code dims} of the row count of the first matrix
     * @param j    index into {@code dims} of the column count of the last matrix
     * @return the minimum multiplication cost, or {@code 0} when the range holds
     *         at most one matrix ({@code j <= i + 1})
     * @throws ArrayIndexOutOfBoundsException if the range holds two or more
     *         matrices but {@code i} or {@code j} lies outside {@code dims}
     * @throws ArithmeticException if the minimum cost does not fit in an
     *         {@code int} (or an intermediate cost does not fit in a {@code long})
     */
    public static int MatrixChainMultiplication(int[] dims, int i, int j)
    {
        // base case: zero or one matrix, nothing to multiply
        if (j <= i + 1) {
            return 0;
        }

        // fail up front with a descriptive message rather than deep inside the table fill
        if (i < 0 || i > j || j >= dims.length) {
            throw new ArrayIndexOutOfBoundsException(
                "range [" + i + ", " + j + "] is outside dims of length " + dims.length);
        }

        // number of matrices in the requested chain
        int span = j - i;

        // cost[lo][hi] holds the minimum cost of M[i+lo+1]..M[i+hi].
        // Entries with hi - lo < 2 describe at most one matrix and stay 0.
        // Costs are kept as long because a sub-chain that is NOT part of the
        // optimal solution may legitimately cost more than Integer.MAX_VALUE.
        long[][] cost = new long[span + 1][span + 1];

        for (int width = 2; width <= span; width++) {
            for (int lo = 0; lo + width <= span; lo++) {
                int hi = lo + width;
                long best = Long.MAX_VALUE;

                // try every split point:
                //   (M[i+lo+1]..M[i+k]) x (M[i+k+1]..M[i+hi])
                for (int k = lo + 1; k < hi; k++) {
                    // cost of multiplying the two resulting matrices,
                    // (dims[i+lo] x dims[i+k]) by (dims[i+k] x dims[i+hi])
                    long product = Math.multiplyExact(
                        Math.multiplyExact((long) dims[i + lo], (long) dims[i + k]),
                        (long) dims[i + hi]);
                    long candidate = Math.addExact(
                        Math.addExact(cost[lo][k], cost[k][hi]), product);
                    if (candidate < best) {
                        best = candidate;
                    }
                }
                cost[lo][hi] = best;
            }
        }

        // min cost to multiply M[i+1]..M[j]; refuse to truncate silently
        return Math.toIntExact(cost[0][span]);
    }

    // main function
    public static void main(String[] args)
    {
        // Matrix M[i] has dimension dims[i-1] x dims[i] for i = 1..n
        // input is 10 x 30 matrix, 30 x 5 matrix, 5 x 60 matrix
        int[] dims = { 10, 30, 5, 60 };

        System.out.print("Minimum cost is " +
            MatrixChainMultiplication(dims, 0, dims.length - 1));
    }
}