-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
SGMatrixList.cpp
127 lines (106 loc) · 3.04 KB
/
SGMatrixList.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Fernando José Iglesias García
* Copyright (C) 2012 Fernando José Iglesias García
*/
#include <shogun/lib/SGMatrixList.h>
namespace shogun {
template <class T>
SGMatrixList<T>::SGMatrixList()
	: SGReferencedData()
{
	// Default-constructed lists own nothing: init_data() resets the
	// array pointer and the matrix count.
	this->init_data();
}
template <class T>
SGMatrixList<T>::SGMatrixList(SGMatrix<T>* ml, int32_t nmats, bool ref_counting)
	: SGReferencedData(ref_counting)
{
	// Adopt the caller-supplied array as-is; no copy is made.
	// NOTE(review): ownership presumably transfers to this list (free_data()
	// releases matrix_list with SG_FREE) — callers should not free ml.
	matrix_list = ml;
	num_matrices = nmats;
}
template <class T>
SGMatrixList<T>::SGMatrixList(int32_t nmats, bool ref_counting)
	: SGReferencedData(ref_counting),
	  matrix_list(SG_MALLOC(SGMatrix<T>, nmats)),
	  num_matrices(nmats)
{
	// Storage for nmats matrices is reserved above; slots are filled
	// afterwards with set_matrix().
	// NOTE(review): whether SG_MALLOC default-constructs SGMatrix<T>
	// elements depends on shogun's memory.h specializations — confirm.
}
template <class T>
SGMatrixList<T>::SGMatrixList(SGMatrixList const & orig)
	: SGReferencedData(orig)
{
	// The base handles the shared reference; copy_data() then makes this
	// object point at the same matrix array and count as orig.
	this->copy_data(orig);
}
template <class T>
SGMatrixList<T>::~SGMatrixList()
{
	// Drop this handle's reference to the shared data.
	// NOTE(review): presumably SGReferencedData::unref() calls free_data()
	// once the last reference goes away — confirm in SGReferencedData.
	this->unref();
}
template <class T>
SGMatrix<T> SGMatrixList<T>::get_matrix(index_t index) const
{
	// Returns a copy of the SGMatrix handle stored at slot `index`.
	// No bounds check: index must lie in [0, num_matrices).
	const SGMatrix<T>* slot = matrix_list + index;
	return *slot;
}
template <class T>
SGMatrix<T> SGMatrixList<T>::operator[](index_t index) const
{
	// Subscript syntax is just an alias for get_matrix(); same lack of
	// bounds checking applies.
	return this->get_matrix(index);
}
template <class T>
void SGMatrixList<T>::set_matrix(index_t index, const SGMatrix<T> matrix)
{
	// Overwrite slot `index` with the given matrix handle (SGMatrix
	// assignment semantics). No bounds check is performed.
	SGMatrix<T>& slot = matrix_list[index];
	slot = matrix;
}
template <class T>
void SGMatrixList<T>::copy_data(const SGReferencedData &orig)
{
	// Shallow copy: this object ends up sharing orig's matrix array and
	// count (no element-wise duplication).
	// orig is presumably always an SGMatrixList<T> — the copy constructor
	// passes one. NOTE(review): other callers of this hook (e.g. in
	// SGReferencedData) are not visible here.
	// A named static_cast replaces the former C-style pointer casts, which
	// also silently discarded the const qualifier on orig.
	const SGMatrixList<T>& other = static_cast<const SGMatrixList<T>&>(orig);
	matrix_list = other.matrix_list;
	num_matrices = other.num_matrices;
}
template <class T>
void SGMatrixList<T>::init_data()
{
	// Put the list into the empty state: zero matrices, no array.
	num_matrices = 0;
	matrix_list = NULL;
}
template <class T>
void SGMatrixList<T>::free_data()
{
	// Release the matrix array, then reset to the empty state so the
	// object never holds a dangling pointer.
	SG_FREE(matrix_list);
	matrix_list = NULL;
	num_matrices = 0;
}
/**
 * Split a matrix column-wise into num_components equally wide matrices.
 *
 * Component i receives columns [i*w, (i+1)*w) of the input, where
 * w = matrix.num_cols / num_components; every component keeps all rows.
 *
 * @param matrix         matrix to split (copied element-wise, not aliased)
 * @param num_components number of pieces; must be positive and divide
 *                       matrix.num_cols exactly
 * @return list holding num_components newly allocated matrices
 */
template <class T>
SGMatrixList<T> SGMatrixList<T>::split(SGMatrix<T> matrix, int32_t num_components)
{
	// Must be checked before num_components is used as a divisor below:
	// zero would be a division by zero (UB), and a negative value would be
	// passed on as an allocation size.
	REQUIRE(num_components > 0,
		"The number of components (%d) must be positive.\n",
		num_components);

	REQUIRE((matrix.num_cols % num_components) == 0,
		"The number of columns (%d) must be multiple of the number "
		"of components (%d).\n",
		matrix.num_cols, num_components);

	const int32_t new_num_cols = matrix.num_cols / num_components;
	SGMatrixList<T> out(num_components);

	for (int32_t i = 0; i < num_components; ++i)
	{
		SGMatrix<T> new_matrix = SGMatrix<T>(matrix.num_rows, new_num_cols);
		const int32_t col_offset = i*new_num_cols;

		// Column-outer / row-inner ordering: with SGMatrix's (presumably
		// column-major) storage this walks both matrices sequentially.
		// Results are identical to the row-outer ordering.
		for (int32_t col = 0; col < new_num_cols; ++col)
		{
			for (int32_t row = 0; row < matrix.num_rows; ++row)
				new_matrix(row, col) = matrix(row, col_offset + col);
		}

		out.set_matrix(i, new_matrix);
	}

	return out;
}
// Explicit instantiations. The member definitions above live in this .cpp
// file rather than the header, so each element type used by other
// translation units is instantiated here; a type missing from this list
// would fail at link time.
template class SGMatrixList<bool>;
template class SGMatrixList<char>;
template class SGMatrixList<int8_t>;
template class SGMatrixList<uint8_t>;
template class SGMatrixList<int16_t>;
template class SGMatrixList<uint16_t>;
template class SGMatrixList<int32_t>;
template class SGMatrixList<uint32_t>;
template class SGMatrixList<int64_t>;
template class SGMatrixList<uint64_t>;
template class SGMatrixList<float32_t>;
template class SGMatrixList<float64_t>;
template class SGMatrixList<floatmax_t>;
} /* namespace shogun */