forked from cloudflare/redoctober
/
matrix.go
149 lines (114 loc) · 2.73 KB
/
matrix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Package msp implements matrix operations for elements in GF(2^128).
package msp
// Row is a vector of FieldElem values. It serves both as one row of a Matrix
// and as a standalone vector (e.g. the argument of Matrix.Mul).
type Row []FieldElem
// NewRow builds a row holding s zero entries, each obtained from its own
// NewFieldElem call.
func NewRow(s int) Row {
	row := make(Row, s)
	for i := range row {
		row[i] = NewFieldElem()
	}
	return row
}
// AddM adds f into e entry by entry, accumulating the sums into e's own
// elements. It panics when the two rows differ in length.
func (e Row) AddM(f Row) {
	if e.Size() != f.Size() {
		panic("Can't add rows that are different sizes!")
	}
	for i := range e {
		e[i].AddM(f[i])
	}
}
// MulM scales the row in place: each entry of e is replaced by its product
// with the scalar f.
func (e Row) MulM(f FieldElem) {
	for i, x := range e {
		e[i] = x.Mul(f)
	}
}
// Mul returns a new row whose entries are the entries of e multiplied by the
// scalar f. Unlike MulM, which replaces e's entries, the products are stored
// in a freshly allocated row.
func (e Row) Mul(f FieldElem) Row {
	// Every entry is assigned below, so there is no need to pre-fill the row
	// with zero elements via NewRow (that would allocate one element per entry
	// only to discard it).
	out := make(Row, len(e))
	for i, x := range e {
		out[i] = x.Mul(f)
	}
	return out
}
// DotProduct returns the sum over i of e[i]*f[i]. Two empty rows yield the
// zero element from NewFieldElem; rows of different lengths cause a panic.
func (e Row) DotProduct(f Row) FieldElem {
	if len(e) != len(f) {
		panic("Can't get dot product of rows of different length!")
	}
	sum := NewFieldElem()
	for i, x := range e {
		sum.AddM(x.Mul(f[i]))
	}
	return sum
}
// Size reports how many entries the row holds.
func (e Row) Size() int {
	return len(e)
}
// Matrix is a matrix stored row by row. All rows are assumed to have the same
// length (Size takes the column count from the first row); nothing here
// enforces that.
type Matrix []Row
// Mul returns the matrix-vector product e·f. Entry i of the result is the dot
// product of row i of e with f, so f acts as a column vector whose length
// must equal the width of e. It panics if any row of e differs in length from
// f. An empty matrix yields an empty result.
func (e Matrix) Mul(f Row) Row {
	// Every entry is assigned below, so no zero elements need pre-allocating.
	res := make(Row, len(e))
	for i, row := range e {
		// Checking each row (not just the first) gives a consistent panic
		// message for ragged matrices and avoids indexing e[0] when e is empty.
		if row.Size() != f.Size() {
			panic("Can't multiply by row that is wrong size!")
		}
		res[i] = row.DotProduct(f)
	}
	return res
}
// Recovery returns a row vector r, with one entry per row of e, such that the
// linear combination sum_i r[i]*e[i] equals the target vector [1 0 0 ... 0]
// (whose length is the width of e). The bool is false when no such r exists,
// or when e is empty or its rows have zero width; in that case the returned
// row is all zeros and must not be used.
//
// It solves eᵀ·r = [1 0 ... 0]ᵀ by Gauss-Jordan elimination on a transposed
// copy of e, so e itself is never modified. Unknowns whose column has no
// usable pivot are left at zero, so any solvable system is solved whatever
// e's shape (more rows than columns or fewer) and row order.
func (e Matrix) Recovery() (Row, bool) {
	a := len(e) // number of unknowns: one per row of e
	if a == 0 {
		return NewRow(0), false
	}
	b := e[0].Size() // number of equations: one per column of e
	if b == 0 {
		return NewRow(a), false
	}

	// f is the transpose of e (row j of f is column j of e). Entries are
	// duplicated so that the elimination below cannot mutate e.
	f := make([]Row, b)
	for j := range f {
		f[j] = make(Row, a)
		for i := 0; i < a; i++ {
			f[j][i] = e[i][j].Dup()
		}
	}

	// rhs is the right-hand side of the system, initially the target vector.
	rhs := NewRow(b)
	rhs[0] = One.Dup()

	// pivotCol[r] is the column (unknown) whose pivot ended up in row r.
	pivotCol := make([]int, 0, b)
	row := 0
	for col := 0; col < a && row < b; col++ {
		// Find a row at or below `row` with a non-zero entry in this column.
		cand := -1
		for j := row; j < b; j++ {
			if !f[j][col].IsZero() {
				cand = j
				break
			}
		}
		if cand == -1 {
			continue // no pivot: this unknown is free and stays zero
		}

		// Move the pivot row into place.
		f[row], f[cand] = f[cand], f[row]
		rhs[row], rhs[cand] = rhs[cand], rhs[row]

		// Make the pivot 1.
		inv := f[row][col].Invert()
		f[row].MulM(inv)
		rhs[row] = rhs[row].Mul(inv)

		// Clear this column in every other row. Adding c times the pivot row
		// cancels the entry c because GF(2^128) has characteristic 2
		// (c + c = 0). c is duplicated because AddM accumulates into f[i] in
		// place, which zeroes f[i][col] before c is used for rhs.
		for i := range f {
			if i != row && !f[i][col].IsZero() {
				c := f[i][col].Dup()
				f[i].AddM(f[row].Mul(c))
				rhs[i].AddM(rhs[row].Mul(c))
			}
		}

		pivotCol = append(pivotCol, col)
		row++
	}

	// Rows without a pivot are now entirely zero, so the system is consistent
	// only if their right-hand side is zero as well.
	for i := row; i < b; i++ {
		if !rhs[i].IsZero() {
			return NewRow(a), false
		}
	}

	// Pivot unknowns take the reduced right-hand side; free unknowns stay zero.
	out := NewRow(a)
	for r, col := range pivotCol {
		out[col] = rhs[r]
	}
	return out, true
}
// Size returns the dimensions of the matrix as (rows, columns). The column
// count is taken from the first row. An empty matrix reports (0, 0) instead
// of panicking on the missing first row.
func (e Matrix) Size() (int, int) {
	if len(e) == 0 {
		return 0, 0
	}
	return len(e), e[0].Size()
}