// diag.ts
import { assert } from "@thi.ng/api";
import { ASparseMatrix } from "./amatrix";
import type { NzEntry } from "./api";
import { CSC } from "./csc";
import { CSR } from "./csr";
import { SparseVec } from "./vec";
/**
 * Square sparse matrix whose only stored values lie on the main diagonal.
 *
 * The diagonal is held in a single {@link SparseVec}; every off-diagonal
 * position reads as 0.
 */
export class Diag extends ASparseMatrix {
    /**
     * Creates an `m × m` matrix with every diagonal value set to 1.
     *
     * @param m - number of rows / columns
     */
    static identity(m: number) {
        return new Diag(new Array(m).fill(1));
    }

    /** Diagonal values, indexed by row (= column). */
    data: SparseVec;

    /**
     * @param data - diagonal values. A {@link SparseVec} is stored by
     * reference (not copied) and its `m` defines the matrix size; a plain
     * array is converted via `SparseVec.fromDense()` and its length defines
     * the matrix size.
     */
    constructor(data: SparseVec | number[]) {
        if (data instanceof SparseVec) {
            super(data.m, data.m);
            this.data = data;
        } else {
            super(data.length, data.length);
            this.data = SparseVec.fromDense(data);
        }
    }

    /**
     * Yields one `[row, col, value]` tuple per entry of the backing vector,
     * with `row === col`.
     */
    *nzEntries() {
        for (const e of this.data.nzEntries()) {
            // The vector entry carries its index at [0] and its value at [2];
            // the index is used for both row and column.
            yield <NzEntry>[e[0], e[0], e[2]];
        }
    }

    /**
     * Returns the value at row `m`, column `n` (always 0 off the diagonal).
     *
     * @param m - row index
     * @param n - column index
     * @param safe - if true (default), validates the index via `ensureIndex()`
     */
    at(m: number, n: number, safe = true) {
        if (safe) {
            this.ensureIndex(m, n);
        }
        return m === n ? this.data.at(m, false) : 0;
    }

    /**
     * Sets the diagonal value at position `m`.
     *
     * With `safe` enabled (default), asserts that `m === n` and that `m` is
     * within `[0, this.m)`. With `safe` disabled no check is made and `v` is
     * written to diagonal position `m` regardless of `n`.
     *
     * @param m - row index
     * @param n - column index
     * @param v - new value
     * @param safe - enable index validation
     * @returns this matrix
     */
    setAt(m: number, n: number, v: number, safe = true) {
        if (safe) {
            assert(m === n && m >= 0 && m < this.m, `invalid index: ${m},${n}`);
        }
        this.data.setAt(m, v, false);
        return this;
    }

    /**
     * Returns the number of non-zero diagonal values.
     *
     * Counted from the backing vector's entries rather than taken from its
     * `length`, which may denote the vector's dimension instead of its number
     * of stored values (NOTE(review): confirm against `SparseVec`). Counting
     * non-zero values keeps this in agreement with {@link Diag.nnzCol}.
     */
    nnz(): number {
        let count = 0;
        for (const e of this.data.nzEntries()) {
            if (e[2] !== 0) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns 1 if the diagonal value in column `n` is non-zero, else 0.
     *
     * @param n - column index
     */
    nnzCol(n: number): number {
        return this.data.at(n) !== 0 ? 1 : 0;
    }

    /**
     * Returns 1 if the diagonal value in row `m` is non-zero, else 0.
     *
     * @param m - row index
     */
    nnzRow(m: number): number {
        return this.nnzCol(m);
    }

    /**
     * Returns the row indices of non-zero values in column `n`: `[n]` if the
     * diagonal value is non-zero, otherwise an empty array.
     *
     * @param n - column index
     */
    nzColRows(n: number): number[] {
        return this.data.at(n) !== 0 ? [n] : [];
    }

    /**
     * Returns the non-zero values in column `n`: a single-element array with
     * the diagonal value, or an empty array if that value is 0.
     *
     * @param n - column index
     */
    nzColVals(n: number): number[] {
        const x = this.data.at(n);
        return x !== 0 ? [x] : [];
    }

    /**
     * Returns the column indices of non-zero values in row `m`.
     *
     * @param m - row index
     */
    nzRowCols(m: number): number[] {
        return this.nzColRows(m);
    }

    /**
     * Returns the non-zero values in row `m`.
     *
     * @param m - row index
     */
    nzRowVals(m: number): number[] {
        return this.nzColVals(m);
    }

    /**
     * Returns the matrix as a flat array of `n * n` values in which the
     * diagonal value `i` is stored at index `i * n + i`.
     */
    toDense() {
        const { data, n } = this;
        const res = new Array(n * n).fill(0);
        for (let i = 0; i < n; i++) {
            res[i * n + i] = data.at(i, false);
        }
        return res;
    }

    /** Converts to a {@link CSC} matrix via `CSC.diag()`. */
    toCSC() {
        return CSC.diag(this.data.toDense());
    }

    /** Converts to a {@link CSR} matrix via `CSR.diag()`. */
    toCSR() {
        return CSR.diag(this.data.toDense());
    }
}