Skip to content

Commit d1424a2

Browse files
xuyang1706 authored and walterddr committed
[FLINK-13676][ml] Add class of Vector to Columns mapper
This closes apache#9413.
1 parent 15f8f3c commit d1424a2

File tree

3 files changed

+238
-0
lines changed

3 files changed

+238
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.common.dataproc.vector;
21+
22+
import org.apache.flink.api.common.typeinfo.TypeInformation;
23+
import org.apache.flink.api.common.typeinfo.Types;
24+
import org.apache.flink.ml.api.misc.param.Params;
25+
import org.apache.flink.ml.common.linalg.DenseVector;
26+
import org.apache.flink.ml.common.linalg.SparseVector;
27+
import org.apache.flink.ml.common.linalg.Vector;
28+
import org.apache.flink.ml.common.mapper.Mapper;
29+
import org.apache.flink.ml.common.utils.OutputColsHelper;
30+
import org.apache.flink.ml.common.utils.TableUtil;
31+
import org.apache.flink.ml.params.dataproc.vector.VectorToColumnsParams;
32+
import org.apache.flink.table.api.TableSchema;
33+
import org.apache.flink.types.Row;
34+
35+
/**
36+
* This mapper maps vector to table columns.
37+
*/
38+
public class VectorToColumnsMapper extends Mapper {
39+
private int colSize;
40+
private int idx;
41+
private OutputColsHelper outputColsHelper;
42+
43+
public VectorToColumnsMapper(TableSchema dataSchema, Params params) {
44+
super(dataSchema, params);
45+
String selectedColName = this.params.get(VectorToColumnsParams.SELECTED_COL);
46+
idx = TableUtil.findColIndex(dataSchema.getFieldNames(), selectedColName);
47+
if (idx < 0) {
48+
throw new IllegalArgumentException("Can not find column: " + selectedColName);
49+
}
50+
String[] outputColNames = this.params.get(VectorToColumnsParams.OUTPUT_COLS);
51+
if (outputColNames == null) {
52+
throw new IllegalArgumentException("VectorToTable: outputColNames must set.");
53+
}
54+
this.colSize = outputColNames.length;
55+
TypeInformation[] types = new TypeInformation[colSize];
56+
for (int i = 0; i < colSize; ++i) {
57+
types[i] = Types.DOUBLE;
58+
}
59+
this.outputColsHelper = new OutputColsHelper(dataSchema, outputColNames, types,
60+
this.params.get(VectorToColumnsParams.RESERVED_COLS));
61+
}
62+
63+
@Override
64+
public Row map(Row row) {
65+
Row result = new Row(colSize);
66+
Object obj = row.getField(idx);
67+
if (null == obj) {
68+
for (int i = 0; i < colSize; i++) {
69+
result.setField(i, null);
70+
}
71+
return outputColsHelper.getResultRow(row, result);
72+
}
73+
74+
Vector vec = (Vector) obj;
75+
76+
if (vec instanceof SparseVector) {
77+
for (int i = 0; i < colSize; ++i) {
78+
result.setField(i, 0.0);
79+
}
80+
SparseVector sparseVector = (SparseVector) vec;
81+
int nnz = sparseVector.numberOfValues();
82+
int[] indices = sparseVector.getIndices();
83+
double[] values = sparseVector.getValues();
84+
for (int i = 0; i < nnz; ++i) {
85+
if (indices[i] < colSize) {
86+
result.setField(indices[i], values[i]);
87+
} else {
88+
break;
89+
}
90+
}
91+
} else {
92+
DenseVector denseVector = (DenseVector) vec;
93+
for (int i = 0; i < colSize; ++i) {
94+
result.setField(i, denseVector.get(i));
95+
}
96+
}
97+
return outputColsHelper.getResultRow(row, result);
98+
}
99+
100+
@Override
101+
public TableSchema getOutputSchema() {
102+
return outputColsHelper.getResultSchema();
103+
}
104+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.params.dataproc.vector;
21+
22+
import org.apache.flink.ml.params.shared.colname.HasOutputCols;
23+
import org.apache.flink.ml.params.shared.colname.HasReservedCols;
24+
import org.apache.flink.ml.params.shared.colname.HasSelectedCol;
25+
26+
/**
27+
* parameters of vector to columns.
28+
*/
29+
public interface VectorToColumnsParams<T> extends
30+
HasSelectedCol<T>,
31+
HasOutputCols<T>,
32+
HasReservedCols<T> {
33+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.common.dataproc.vector;
21+
22+
import org.apache.flink.api.common.typeinfo.TypeInformation;
23+
import org.apache.flink.api.common.typeinfo.Types;
24+
import org.apache.flink.ml.api.misc.param.Params;
25+
import org.apache.flink.ml.common.linalg.DenseVector;
26+
import org.apache.flink.ml.common.linalg.SparseVector;
27+
import org.apache.flink.ml.common.utils.RowCollector;
28+
import org.apache.flink.ml.params.dataproc.vector.VectorToColumnsParams;
29+
import org.apache.flink.table.api.TableSchema;
30+
import org.apache.flink.types.Row;
31+
32+
import org.junit.Test;
33+
34+
import java.util.List;
35+
36+
import static org.junit.Assert.assertEquals;
37+
38+
/**
39+
* Unit test for VectorToColumnsMapper.
40+
*/
41+
public class VectorToColumnsMapperTest {
42+
@Test
43+
public void test1() throws Exception {
44+
TableSchema schema = new TableSchema(new String[]{"vec"}, new TypeInformation<?>[]{Types.STRING});
45+
46+
Params params = new Params()
47+
.set(VectorToColumnsParams.SELECTED_COL, "vec")
48+
.set(VectorToColumnsParams.OUTPUT_COLS, new String[]{"f0", "f1"});
49+
50+
VectorToColumnsMapper mapper = new VectorToColumnsMapper(schema, params);
51+
RowCollector collector = new RowCollector();
52+
mapper.flatMap(Row.of(new DenseVector(new double[]{3.0, 4.0})), collector);
53+
List<Row> rows = collector.getRows();
54+
assertEquals(rows.get(0).getField(1), 3.0);
55+
assertEquals(rows.get(0).getField(2), 4.0);
56+
assertEquals(mapper.getOutputSchema(), new TableSchema(new String[]{"vec", "f0", "f1"},
57+
new TypeInformation<?>[]{Types.STRING, Types.DOUBLE, Types.DOUBLE}));
58+
59+
}
60+
61+
@Test
62+
public void test2() throws Exception {
63+
TableSchema schema = new TableSchema(new String[]{"vec"}, new TypeInformation<?>[]{Types.STRING});
64+
65+
Params params = new Params()
66+
.set(VectorToColumnsParams.SELECTED_COL, "vec")
67+
.set(VectorToColumnsParams.RESERVED_COLS, new String[]{})
68+
.set(VectorToColumnsParams.OUTPUT_COLS, new String[]{"f0", "f1"});
69+
70+
VectorToColumnsMapper mapper = new VectorToColumnsMapper(schema, params);
71+
72+
RowCollector collector = new RowCollector();
73+
mapper.flatMap(Row.of(new DenseVector(new double[]{3.0, 4.0})), collector);
74+
List<Row> rows = collector.getRows();
75+
assertEquals(rows.get(0).getField(0), 3.0);
76+
assertEquals(rows.get(0).getField(1), 4.0);
77+
assertEquals(mapper.getOutputSchema(), new TableSchema(new String[]{"f0", "f1"},
78+
new TypeInformation<?>[]{Types.DOUBLE, Types.DOUBLE}));
79+
}
80+
81+
@Test
82+
public void test3() throws Exception {
83+
TableSchema schema = new TableSchema(new String[]{"vec"}, new TypeInformation<?>[]{Types.STRING});
84+
85+
Params params = new Params()
86+
.set(VectorToColumnsParams.SELECTED_COL, "vec")
87+
.set(VectorToColumnsParams.OUTPUT_COLS, new String[]{"f0", "f1", "f2"});
88+
89+
VectorToColumnsMapper mapper = new VectorToColumnsMapper(schema, params);
90+
91+
RowCollector collector = new RowCollector();
92+
mapper.flatMap(Row.of(new SparseVector(3, new int[]{1, 2}, new double[]{3.0, 4.0})), collector);
93+
List<Row> rows = collector.getRows();
94+
assertEquals(rows.get(0).getField(0), new SparseVector(3, new int[]{1, 2}, new double[]{3.0, 4.0}));
95+
assertEquals(rows.get(0).getField(1), 0.0);
96+
assertEquals(rows.get(0).getField(2), 3.0);
97+
assertEquals(rows.get(0).getField(3), 4.0);
98+
assertEquals(mapper.getOutputSchema(), new TableSchema(new String[]{"vec", "f0", "f1", "f2"},
99+
new TypeInformation<?>[]{Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE}));
100+
}
101+
}

0 commit comments

Comments
 (0)