-
Notifications
You must be signed in to change notification settings - Fork 2
/
matrixproduct.h
108 lines (93 loc) · 2.9 KB
/
matrixproduct.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#ifndef MATRIXPRODUCT_H
#define MATRIXPRODUCT_H
//Standard headers
#include <cstdint>

//Custom headers
#include "squarematrix.h"
namespace qmatrixproduct {
/**
 * @brief Class for matrix product calculation with operation statistics.
 *
 * The values returned by multiplications(), additions() and recursiveCalls()
 * are running totals accumulated since object creation, not per-call figures.
 * All counters start at zero through their default member initializers, even
 * if the out-of-line constructor does not set them explicitly.
 */
class MatrixProduct
{
public:
    /**
     * @brief MatrixProduct constructor
     */
    MatrixProduct();

    /**
     * @brief Standard O(n^3) matrix multiplication algorithm
     * @param A Square matrix
     * @param B Square matrix
     * @return Matrix C = AB
     */
    SquareMatrix standardMultiply(const SquareMatrix &A, const SquareMatrix &B);

    /**
     * @brief Strassen matrix multiplication algorithm, O(n^log2(7)) ~ O(n^2.807)
     * @param A Square matrix
     * @param B Square matrix
     * @return Matrix C = AB
     */
    SquareMatrix strassenMultiply(const SquareMatrix &A, const SquareMatrix &B);

    /**
     * @brief Winograd-Strassen matrix multiplication algorithm
     * @param A Square matrix
     * @param B Square matrix
     * @return Matrix C = AB
     */
    SquareMatrix winogradMultiply(const SquareMatrix &A, const SquareMatrix &B);

    /**
     * @brief Check whether n is a power of 2
     * @param n Number to test
     * @return true if n is a power of 2
     *
     * NOTE(review): the result for n <= 0 is decided by the implementation,
     * which is not visible in this header; confirm before relying on it.
     */
    static bool isPowerOfTwo(int n);

    /**
     * @brief Get the closest power of 2 greater than n (e.g. 64 for n = 63)
     * @param n Positive number
     * @return Closest power of 2
     *
     * NOTE(review): whether an n that is already a power of two maps to n
     * itself or to 2n is not visible from this header; confirm against the
     * implementation.
     */
    static int closestPowerOfTwo(int n);

    /**
     * @brief Partition matrix A(n) into 4 (n/2) square matrices for the Strassen algorithm
     * @param A Matrix to partition
     * @param A11 Top left
     * @param A12 Top right
     * @param A21 Bottom left
     * @param A22 Bottom right
     *
     * NOTE(review): presumably n must be even (e.g. a power of two) for the
     * quadrants to be equal-sized; confirm in the implementation.
     */
    static void strassenPartition(const SquareMatrix &A,
                                  SquareMatrix &A11, SquareMatrix &A12,
                                  SquareMatrix &A21, SquareMatrix &A22);

    /**
     * @brief Compose matrix C(n) from 4 (n/2) matrices
     * @param C Resulting matrix
     * @param C11 Top left
     * @param C12 Top right
     * @param C21 Bottom left
     * @param C22 Bottom right
     */
    static void strassenCompose(SquareMatrix &C,
                                const SquareMatrix &C11, const SquareMatrix &C12,
                                const SquareMatrix &C21, const SquareMatrix &C22);

    /**
     * @brief Get number of multiplications done since object creation
     * @return Number of multiplications
     */
    std::uint64_t multiplications() const;

    /**
     * @brief Get number of additions done since object creation
     * @return Number of additions
     */
    std::uint64_t additions() const;

    /**
     * @brief Get number of recursive function calls
     * (only the multiplication functions are counted)
     * @return Number of recursive calls
     */
    std::uint64_t recursiveCalls() const;

private:
    // Default member initializers guarantee the statistics start at zero
    // regardless of how the out-of-line constructor is written; a
    // mem-initializer in the constructor still overrides these values.
    std::uint64_t m_multiplications = 0; // running total reported by multiplications()
    std::uint64_t m_additions = 0;       // running total reported by additions()
    std::uint64_t m_recursiveCalls = 0;  // running total reported by recursiveCalls()
};
} // namespace qmatrixproduct
#endif // MATRIXPRODUCT_H