Skip to content

Commit

Permalink
add cell.above to refer to previous row value (#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
MykolaGolubyev committed Apr 2, 2019
1 parent 2cbe09e commit 1a94257
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 17 deletions.
Expand Up @@ -58,6 +58,15 @@ class TableDataExtensionTest {
assert tableData.row(5).toMap() == ['Col A': 'v2a', 'Col B': 20, 'Col C': 'v2c']
}

@Test
void "cell previous should be substituted with value from a previous row"() {
def tableData = createTableDataWithPreviousRef()
assert tableData.numberOfRows() == 3
assert tableData.row(0).toMap() == ["Col A": "v1a", "Col B": "v1b", "Col C": 10]
assert tableData.row(1).toMap() == ["Col A": "v2a", "Col B": "v2b", "Col C": 10]
assert tableData.row(2).toMap() == ["Col A": "v2a", "Col B": "v2b", "Col C": 20]
}

@Test
void "should ignore underscore under header"() {
def table = ["hello" | "world"] {
Expand Down Expand Up @@ -87,6 +96,14 @@ class TableDataExtensionTest {
"v2a" | permute(10, 20) | "v2c" }
}

static TableData createTableDataWithPreviousRef() {
["Col A" | "Col B" | "Col C"] {
__________________________________________
"v1a" | "v1b" | 10
"v2a" | "v2b" | cell.above
"v2a" | "v2b" | cell.above + 10 }
}

private static void validateTableData(TableData tableData) {
tableData.numberOfRows().should == 2
tableData.row(0).toMap().should == ["Col A": "v1a", "Col B": "v1b", "Col C": "v1c"]
Expand Down
3 changes: 3 additions & 0 deletions webtau-core/src/main/java/com/twosigma/webtau/Ddjt.java
Expand Up @@ -18,6 +18,7 @@

import com.twosigma.webtau.data.MultiValue;
import com.twosigma.webtau.data.table.TableData;
import com.twosigma.webtau.data.table.autogen.TableDataCellValueGenFunctions;
import com.twosigma.webtau.data.table.TableDataUnderscoreOrPlaceholder;
import com.twosigma.webtau.expectation.ActualCode;
import com.twosigma.webtau.expectation.ActualCodeExpectations;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class Ddjt {
public static final TableDataUnderscoreOrPlaceholder ________________________________________________________________________________ = TableDataUnderscoreOrPlaceholder.INSTANCE;
public static final TableDataUnderscoreOrPlaceholder ________________________________________________________________________________________________ = TableDataUnderscoreOrPlaceholder.INSTANCE;

public static final TableDataCellValueGenFunctions cell = new TableDataCellValueGenFunctions();

public static TableData table(String... columnNames) {
return new TableData(Arrays.stream(columnNames));
}
Expand Down
Expand Up @@ -17,24 +17,28 @@
package com.twosigma.webtau.data.table;

import com.twosigma.webtau.data.MultiValue;
import com.twosigma.webtau.data.table.autogen.TableDataCellValueGenerator;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;

public class Record {
private final Header header;
private final List<Object> values;
private final CompositeKey key;

private final boolean hasMultiValues;
private final boolean hasValueGenerators;

public Record(Header header, Stream<Object> values) {
this.header = header;
this.values = values.collect(toList());
RecordFromStream recordFromStream = new RecordFromStream(values);

hasMultiValues = recordFromStream.hasMultiValues;
hasValueGenerators = recordFromStream.hasValueGenerators;
this.values = recordFromStream.values;

this.key = header.hasKeyColumns() ?
new CompositeKey(header.getKeyIdxStream().map(this::get)) : null;
}
Expand Down Expand Up @@ -63,7 +67,11 @@ public Stream<Object> values() {
}

public boolean hasMultiValues() {
return this.values.stream().anyMatch(v -> v instanceof MultiValue);
return this.hasMultiValues;
}

public boolean hasValueGenerators() {
return this.hasValueGenerators;
}

@SuppressWarnings("unchecked")
Expand All @@ -78,6 +86,27 @@ public List<Record> unwrapMultiValues() {
return multiValuesUnwrapper.result;
}

public Record evaluateValueGenerators(Record previous, int rowIdx) {
if (!hasValueGenerators()) {
return this;
}

List<Object> newValues = new ArrayList<>(this.values.size());
int colIdx = 0;
for (Object value : this.values) {
if (value instanceof TableDataCellValueGenerator) {
newValues.add(((TableDataCellValueGenerator) value).generate(
this, previous, rowIdx, colIdx, header.columnNameByIdx(colIdx)));
} else {
newValues.add(value);
}

colIdx++;
}

return new Record(header, newValues.stream());
}

public Map<String, Object> toMap() {
Map<String, Object> result = new LinkedHashMap<>();
header.getColumnIdxStream().forEach(i -> result.put(header.columnNameByIdx(i), values.get(i)));
Expand Down Expand Up @@ -118,4 +147,26 @@ void add(Record record) {
result.add(record);
}
}

private static class RecordFromStream {
private boolean hasMultiValues;
private boolean hasValueGenerators;
private List<Object> values;

public RecordFromStream(Stream<Object> valuesStream) {
values = new ArrayList<>();

valuesStream.forEach(v -> {
if (v instanceof MultiValue) {
hasMultiValues = true;
}

if (v instanceof TableDataCellValueGenerator) {
hasValueGenerators = true;
}

values.add(v);
});
}
}
}
Expand Up @@ -120,18 +120,21 @@ public void addRow(Record record) {
int rowIdx = rows.size();
CompositeKey key = getOrBuildKey(rowIdx, record);

Record previous = rowsByKey.put(key, record);
if (previous != null) {
Record existing = rowsByKey.put(key, record);
if (existing != null) {
throw new IllegalArgumentException("duplicate entry found with key: " + key +
"\n" + previous +
"\n" + existing +
"\n" + record);
}

Record previous = rows.isEmpty() ? null : rows.get(rows.size() - 1);
Record withEvaluatedGenerators = record.evaluateValueGenerators(previous, rows.size());

rowIdxByKey.put(key, rowIdx);
rows.add(record);
rows.add(withEvaluatedGenerators);
}

public TableData map(TableDataCellFunction mapper) {
public TableData map(TableDataCellMapFunction mapper) {
TableData mapped = new TableData(header);

int rowIdx = 0;
Expand All @@ -150,7 +153,7 @@ public <T, R> Stream<R> mapColumn(String columnName, Function<T, R> mapper) {
}

@SuppressWarnings("unchecked")
private <T, R> Stream<Object> mapRow(int rowIdx, Record originalRow, TableDataCellFunction mapper) {
private <T, R> Stream<Object> mapRow(int rowIdx, Record originalRow, TableDataCellMapFunction mapper) {
return header.getColumnIdxStream()
.mapToObj(idx -> mapper.apply(rowIdx, idx, header.columnNameByIdx(idx), originalRow.get(idx)));
}
Expand Down
Expand Up @@ -16,6 +16,6 @@

package com.twosigma.webtau.data.table;

public interface TableDataCellFunction<T, R> {
public interface TableDataCellMapFunction<T, R> {
R apply(int rowIdx, int colIdx, String columnName, T v);
}
@@ -0,0 +1,23 @@
/*
* Copyright 2019 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.twosigma.webtau.data.table.autogen;

import com.twosigma.webtau.data.table.Record;

public interface TableDataCellValueGenFullFunction<R> {
R apply(Record row, Record prev, int rowIdx, int colIdx, String columnName);
}
@@ -0,0 +1,42 @@
/*
* Copyright 2019 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.twosigma.webtau.data.table.autogen;

import com.twosigma.webtau.data.table.Record;

/**
* @see com.twosigma.webtau.Ddjt#cell
*/
public class TableDataCellValueGenFunctions {
public final TableDataCellValueGenerator<?> above = new TableDataCellValueGenerator<>(this::previousColumnValue);

public static <R> TableDataCellValueGenerator<R> value(TableDataCellValueGenFullFunction<R> genFunction) {
return new TableDataCellValueGenerator<>(genFunction);
}

public static <R> TableDataCellValueGenerator<R> value(TableDataCellValueGenOnlyRecordFunction<R> genFunction) {
return new TableDataCellValueGenerator<>(genFunction);
}

private <R> R previousColumnValue(Record row, Record prev, int rowIdx, int colIdx, String columnName) {
if (prev == null) {
return null;
}

return prev.get(columnName);
}
}
@@ -0,0 +1,23 @@
/*
* Copyright 2019 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.twosigma.webtau.data.table.autogen;

import com.twosigma.webtau.data.table.Record;

public interface TableDataCellValueGenOnlyRecordFunction<R> {
R apply(Record row);
}
@@ -0,0 +1,64 @@
/*
* Copyright 2019 TWO SIGMA OPEN SOURCE, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.twosigma.webtau.data.table.autogen;

import com.twosigma.webtau.data.table.Record;

public class TableDataCellValueGenerator<R> {
private TableDataCellValueGenFullFunction<R> genFunction;
private TableDataCellValueGenOnlyRecordFunction<R> genOnlyRecordFunction;

public TableDataCellValueGenerator(TableDataCellValueGenFullFunction<R> genFunction) {
this.genFunction = genFunction;
}

public TableDataCellValueGenerator(TableDataCellValueGenOnlyRecordFunction<R> genOnlyRecordFunction) {
this.genOnlyRecordFunction = genOnlyRecordFunction;
}

public Object generate(Record row, Record prev, int rowIdx, int colIdx, String columnName) {
if (genOnlyRecordFunction != null) {
return genOnlyRecordFunction.apply(row);
} else {
return genFunction.apply(row, prev, rowIdx, colIdx, columnName);
}
}

@SuppressWarnings("unchecked")
public <N extends Number> TableDataCellValueGenerator<N> plus(Number v) {
return new TableDataCellValueGenerator<N>(((row, prev, rowIdx, colIdx, columnName) -> {
R calculated = genFunction.apply(row, prev, rowIdx, colIdx, columnName);
return (N) addTwoNumbers((Number) calculated, v);
}));
}

private static Number addTwoNumbers(Number a, Number b) {
if (b instanceof Double) {
return a.doubleValue() + b.doubleValue();
}

if (a instanceof Long || b instanceof Long) {
return a.longValue() + b.longValue();
}

if (b instanceof Integer) {
return a.intValue() + b.intValue();
}

throw new UnsupportedOperationException(a.getClass() + " + " + b.getClass() + " is not supported");
}
}
Expand Up @@ -34,6 +34,17 @@ class TableDataTest {
validateTableData(tableData)
}

@Test
void "cell previous should be substituted with value from a previous row"() {
def tableData = createTableDataWithPreviousRef()
assert tableData.numberOfRows() == 3
assert tableData.row(0).toMap() == ["Col A": "v1a", "Col B": "v1b", "Col C": 10]
assert tableData.row(1).toMap() == ["Col A": "v2a", "Col B": "v2b", "Col C": 10]
assert tableData.row(2).toMap() == ["Col A": "v2a", "Col B": "v2b", "Col C": 20]

DocumentationArtifacts.create(TableDataTest, 'table-with-cell-above.json', tableData.toJson())
}

@Test(expected = IllegalArgumentException)
void "should report columns number mismatch during table creation using header and values vararg methods"() {
table("Col A", "Col B", "Col C").values(
Expand Down Expand Up @@ -71,11 +82,19 @@ class TableDataTest {

static TableData createTableDataWithPermute() {
table("Col A" , "Col B" , "Col C",
________________________________________________________________,
________________________________________________________________,
permute(true, false), "v1b" , permute('a', 'b'),
"v2a" , permute(10, 20) , "v2c")
}

static TableData createTableDataWithPreviousRef() {
table("Col A", "Col B", "Col C",
________________________________________________,
"v1a", "v1b", 10,
"v2a", "v2b", cell.above,
"v2a", "v2b", cell.above.plus(10))
}

private static void validateTableData(TableData tableData) {
assert tableData.numberOfRows() == 2
assert tableData.row(0).toMap() == ["Col A": "v1a", "Col B": "v1b", "Col C": "v1c"]
Expand Down

0 comments on commit 1a94257

Please sign in to comment.