Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Newer
Older
100644 350 lines (304 sloc) 7.694 kb
47e4e58 @tomz upgraded to LIBSVM 3.1 and OS X enabled
authored
1 #include <stdlib.h>
2 #include <string.h>
3 #include "svm.h"
4
5 #include "mex.h"
6
/* mwIndex is supplied by mex.h only when MX_API_VER is at least 0x07030000;
 * fall back to a plain int for older headers. */
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* malloc() n objects of the given type; the result is NOT checked for NULL. */
#define Malloc(type, n) ((type *) malloc((n) * sizeof(type)))

/* Field names of the MATLAB model struct. Both converters in this file rely
 * on this exact positional order: fields are written with field_names[i] and
 * read back with mxGetFieldByNumber(..., i). */
static const char *field_names[] = {
	"Parameters",	/* svm_type, kernel_type, degree, gamma, coef0 */
	"nr_class",
	"totalSV",
	"rho",
	"Label",
	"ProbA",
	"ProbB",
	"nSV",
	"sv_coef",
	"SVs"
};

/* Number of entries in field_names. */
#define NUM_OF_RETURN_FIELD 10
27
28 const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
29 {
30 int i, j, n;
31 double *ptr;
32 mxArray *return_model, **rhs;
33 int out_id = 0;
34
35 rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
36
37 // Parameters
38 rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
39 ptr = mxGetPr(rhs[out_id]);
40 ptr[0] = model->param.svm_type;
41 ptr[1] = model->param.kernel_type;
42 ptr[2] = model->param.degree;
43 ptr[3] = model->param.gamma;
44 ptr[4] = model->param.coef0;
45 out_id++;
46
47 // nr_class
48 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
49 ptr = mxGetPr(rhs[out_id]);
50 ptr[0] = model->nr_class;
51 out_id++;
52
53 // total SV
54 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
55 ptr = mxGetPr(rhs[out_id]);
56 ptr[0] = model->l;
57 out_id++;
58
59 // rho
60 n = model->nr_class*(model->nr_class-1)/2;
61 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
62 ptr = mxGetPr(rhs[out_id]);
63 for(i = 0; i < n; i++)
64 ptr[i] = model->rho[i];
65 out_id++;
66
67 // Label
68 if(model->label)
69 {
70 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
71 ptr = mxGetPr(rhs[out_id]);
72 for(i = 0; i < model->nr_class; i++)
73 ptr[i] = model->label[i];
74 }
75 else
76 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
77 out_id++;
78
79 // probA
80 if(model->probA != NULL)
81 {
82 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
83 ptr = mxGetPr(rhs[out_id]);
84 for(i = 0; i < n; i++)
85 ptr[i] = model->probA[i];
86 }
87 else
88 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
89 out_id ++;
90
91 // probB
92 if(model->probB != NULL)
93 {
94 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
95 ptr = mxGetPr(rhs[out_id]);
96 for(i = 0; i < n; i++)
97 ptr[i] = model->probB[i];
98 }
99 else
100 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
101 out_id++;
102
103 // nSV
104 if(model->nSV)
105 {
106 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
107 ptr = mxGetPr(rhs[out_id]);
108 for(i = 0; i < model->nr_class; i++)
109 ptr[i] = model->nSV[i];
110 }
111 else
112 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
113 out_id++;
114
115 // sv_coef
116 rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
117 ptr = mxGetPr(rhs[out_id]);
118 for(i = 0; i < model->nr_class-1; i++)
119 for(j = 0; j < model->l; j++)
120 ptr[(i*(model->l))+j] = model->sv_coef[i][j];
121 out_id++;
122
123 // SVs
124 {
125 int ir_index, nonzero_element;
126 mwIndex *ir, *jc;
127 mxArray *pprhs[1], *pplhs[1];
128
129 if(model->param.kernel_type == PRECOMPUTED)
130 {
131 nonzero_element = model->l;
132 num_of_feature = 1;
133 }
134 else
135 {
136 nonzero_element = 0;
137 for(i = 0; i < model->l; i++) {
138 j = 0;
139 while(model->SV[i][j].index != -1)
140 {
141 nonzero_element++;
142 j++;
143 }
144 }
145 }
146
147 // SV in column, easier accessing
148 rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
149 ir = mxGetIr(rhs[out_id]);
150 jc = mxGetJc(rhs[out_id]);
151 ptr = mxGetPr(rhs[out_id]);
152 jc[0] = ir_index = 0;
153 for(i = 0;i < model->l; i++)
154 {
155 if(model->param.kernel_type == PRECOMPUTED)
156 {
157 // make a (1 x model->l) matrix
158 ir[ir_index] = 0;
159 ptr[ir_index] = model->SV[i][0].value;
160 ir_index++;
161 jc[i+1] = jc[i] + 1;
162 }
163 else
164 {
165 int x_index = 0;
166 while (model->SV[i][x_index].index != -1)
167 {
168 ir[ir_index] = model->SV[i][x_index].index - 1;
169 ptr[ir_index] = model->SV[i][x_index].value;
170 ir_index++, x_index++;
171 }
172 jc[i+1] = jc[i] + x_index;
173 }
174 }
175 // transpose back to SV in row
176 pprhs[0] = rhs[out_id];
177 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
178 return "cannot transpose SV matrix";
179 rhs[out_id] = pplhs[0];
180 out_id++;
181 }
182
183 /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
184 return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
185
186 /* Fill struct matrix with input arguments */
187 for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
188 mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
189 /* return */
190 plhs[0] = return_model;
191 mxFree(rhs);
192
193 return NULL;
194 }
195
196 struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
197 {
198 int i, j, n, num_of_fields;
199 double *ptr;
200 int id = 0;
201 struct svm_node *x_space;
202 struct svm_model *model;
203 mxArray **rhs;
204
205 num_of_fields = mxGetNumberOfFields(matlab_struct);
206 if(num_of_fields != NUM_OF_RETURN_FIELD)
207 {
208 *msg = "number of return field is not correct";
209 return NULL;
210 }
211 rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
212
213 for(i=0;i<num_of_fields;i++)
214 rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
215
216 model = Malloc(struct svm_model, 1);
217 model->rho = NULL;
218 model->probA = NULL;
219 model->probB = NULL;
220 model->label = NULL;
221 model->nSV = NULL;
222 model->free_sv = 1; // XXX
223
224 ptr = mxGetPr(rhs[id]);
225 model->param.svm_type = (int)ptr[0];
226 model->param.kernel_type = (int)ptr[1];
227 model->param.degree = (int)ptr[2];
228 model->param.gamma = ptr[3];
229 model->param.coef0 = ptr[4];
230 id++;
231
232 ptr = mxGetPr(rhs[id]);
233 model->nr_class = (int)ptr[0];
234 id++;
235
236 ptr = mxGetPr(rhs[id]);
237 model->l = (int)ptr[0];
238 id++;
239
240 // rho
241 n = model->nr_class * (model->nr_class-1)/2;
242 model->rho = (double*) malloc(n*sizeof(double));
243 ptr = mxGetPr(rhs[id]);
244 for(i=0;i<n;i++)
245 model->rho[i] = ptr[i];
246 id++;
247
248 // label
249 if(mxIsEmpty(rhs[id]) == 0)
250 {
251 model->label = (int*) malloc(model->nr_class*sizeof(int));
252 ptr = mxGetPr(rhs[id]);
253 for(i=0;i<model->nr_class;i++)
254 model->label[i] = (int)ptr[i];
255 }
256 id++;
257
258 // probA
259 if(mxIsEmpty(rhs[id]) == 0)
260 {
261 model->probA = (double*) malloc(n*sizeof(double));
262 ptr = mxGetPr(rhs[id]);
263 for(i=0;i<n;i++)
264 model->probA[i] = ptr[i];
265 }
266 id++;
267
268 // probB
269 if(mxIsEmpty(rhs[id]) == 0)
270 {
271 model->probB = (double*) malloc(n*sizeof(double));
272 ptr = mxGetPr(rhs[id]);
273 for(i=0;i<n;i++)
274 model->probB[i] = ptr[i];
275 }
276 id++;
277
278 // nSV
279 if(mxIsEmpty(rhs[id]) == 0)
280 {
281 model->nSV = (int*) malloc(model->nr_class*sizeof(int));
282 ptr = mxGetPr(rhs[id]);
283 for(i=0;i<model->nr_class;i++)
284 model->nSV[i] = (int)ptr[i];
285 }
286 id++;
287
288 // sv_coef
289 ptr = mxGetPr(rhs[id]);
290 model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
291 for( i=0 ; i< model->nr_class -1 ; i++ )
292 model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
293 for(i = 0; i < model->nr_class - 1; i++)
294 for(j = 0; j < model->l; j++)
295 model->sv_coef[i][j] = ptr[i*(model->l)+j];
296 id++;
297
298 // SV
299 {
300 int sr, sc, elements;
301 int num_samples;
302 mwIndex *ir, *jc;
303 mxArray *pprhs[1], *pplhs[1];
304
305 // transpose SV
306 pprhs[0] = rhs[id];
307 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
308 {
309 svm_free_and_destroy_model(&model);
310 *msg = "cannot transpose SV matrix";
311 return NULL;
312 }
313 rhs[id] = pplhs[0];
314
315 sr = (int)mxGetN(rhs[id]);
316 sc = (int)mxGetM(rhs[id]);
317
318 ptr = mxGetPr(rhs[id]);
319 ir = mxGetIr(rhs[id]);
320 jc = mxGetJc(rhs[id]);
321
322 num_samples = (int)mxGetNzmax(rhs[id]);
323
324 elements = num_samples + sr;
325
326 model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
327 x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
328
329 // SV is in column
330 for(i=0;i<sr;i++)
331 {
332 int low = (int)jc[i], high = (int)jc[i+1];
333 int x_index = 0;
334 model->SV[i] = &x_space[low+i];
335 for(j=low;j<high;j++)
336 {
337 model->SV[i][x_index].index = (int)ir[j] + 1;
338 model->SV[i][x_index].value = ptr[j];
339 x_index++;
340 }
341 model->SV[i][x_index].index = -1;
342 }
343
344 id++;
345 }
346 mxFree(rhs);
347
348 return model;
349 }
Something went wrong with that request. Please try again.