Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented partial sum cache for SMAIndicator #1140

Merged
merged 7 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Changelog for `ta4j`, roughly following [keepachangelog.com](http://keepachangelog.com/en/1.0.0/) from version 0.9 onwards.

## 0.17

- **Implemented inner cache for SMAIndicator**


## 0.16 (released May 15, 2024)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package org.ta4j.core.indicators;

import org.ta4j.core.Indicator;
import org.ta4j.core.indicators.helpers.RunningTotalIndicator;
import org.ta4j.core.num.Num;

/**
Expand All @@ -34,8 +35,8 @@
*/
public class SMAIndicator extends CachedIndicator<Num> {

private final Indicator<Num> indicator;
private final int barCount;
private RunningTotalIndicator previousSum;

/**
* Constructor.
Expand All @@ -45,21 +46,21 @@ public class SMAIndicator extends CachedIndicator<Num> {
*/
public SMAIndicator(Indicator<Num> indicator, int barCount) {
super(indicator);
this.indicator = indicator;
this.previousSum = new RunningTotalIndicator(indicator, barCount);
this.barCount = barCount;
}

@Override
protected Num calculate(int index) {
Num sum = zero();
for (int i = Math.max(0, index - barCount + 1); i <= index; i++) {
sum = sum.plus(indicator.getValue(i));
}

final int realBarCount = Math.min(barCount, index + 1);
final var sum = partialSum(index);
return sum.dividedBy(numOf(realBarCount));
}

private Num partialSum(int index) {
return this.previousSum.getValue(index);
}

/** @return {@link #barCount} */
@Override
public int getUnstableBars() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
public class RunningTotalIndicator extends CachedIndicator<Num> {
private final Indicator<Num> indicator;
private final int barCount;
private Num previousSum = zero();

// serial access detection
private int previousIndex = -1;

public RunningTotalIndicator(Indicator<Num> indicator, int barCount) {
super(indicator);
Expand All @@ -45,20 +49,53 @@ public RunningTotalIndicator(Indicator<Num> indicator, int barCount) {

@Override
protected Num calculate(int index) {
// serial access can benefit from previous partial sums
// which saves a lot of CPU work for very long barCounts
if (previousIndex != -1 && previousIndex == index - 1) {
return fastPath(index);
}

return slowPath(index);
}

private Num fastPath(final int index) {
var newSum = partialSum(index);
updatePartialSum(index, newSum);
return newSum;
}

private Num slowPath(final int index) {
Num sum = zero();
for (int i = Math.max(getBarSeries().getBeginIndex(), index - barCount + 1); i <= index; i++) {
for (int i = Math.max(0, index - barCount + 1); i <= index; i++) {
sum = sum.plus(indicator.getValue(i));
}

updatePartialSum(index, sum);
return sum;
}

@Override
public int getUnstableBars() {
return 0;
private void updatePartialSum(final int index, final Num sum) {
previousIndex = index;
previousSum = sum;
}

private Num partialSum(int index) {
var sum = this.previousSum.plus(indicator.getValue(index));

if (index >= barCount) {
return sum.minus(indicator.getValue(index - barCount));
}

return sum;
}

@Override
public String toString() {
return getClass().getSimpleName() + " barCount: " + barCount;
}

@Override
public int getUnstableBars() {
return barCount;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/**
* The MIT License (MIT)
*
* Copyright (c) 2017-2023 Ta4j Organization & respective
* authors (see AUTHORS)
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package org.ta4j.core.indicators;

import static org.junit.Assert.assertEquals;
import static org.ta4j.core.TestUtils.assertNumEquals;

import java.util.function.Function;

import org.junit.Before;
import org.junit.Test;
import org.ta4j.core.BarSeries;
import org.ta4j.core.Indicator;
import org.ta4j.core.indicators.helpers.ClosePriceIndicator;
import org.ta4j.core.mocks.MockBar;
import org.ta4j.core.mocks.MockBarSeries;
import org.ta4j.core.num.Num;

public class SMAIndicatorMovingSerieTest extends AbstractIndicatorTest<Indicator<Num>, Num> {

public SMAIndicatorMovingSerieTest(Function<Number, Num> numFunction) {
super((data, params) -> new SMAIndicator(data, (int) params[0]), numFunction);
}

private BarSeries data;

@Before
public void setUp() {
data = new MockBarSeries(numFunction, 1, 2, 3, 4, 7);
data.setMaximumBarCount(4);
}

@Test
public void usingBarCount3MovingSeries() {
firstAddition();
secondAddition();
thirdAddition();
fourthAddition();
randomAccessAfterFourAdditions();
}

private void firstAddition() {
data.addBar(new MockBar(5., numFunction));
Indicator<Num> indicator2 = getIndicator(new ClosePriceIndicator(data), 2);

// unstable bars skipped, unpredictable results
assertNumEquals((3d + 4d) / 2, indicator2.getValue(data.getBeginIndex() + 1));
assertNumEquals((4d + 7d) / 2, indicator2.getValue(data.getBeginIndex() + 2));
assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 3));

Indicator<Num> indicator3 = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((3d + 4d + 7d) / 3, indicator3.getValue(data.getBeginIndex() + 2));
assertNumEquals((4d + 7d + 5d) / 3, indicator3.getValue(data.getBeginIndex() + 3));
}

private void secondAddition() {
data.addBar(new MockBar(10., numFunction));
Indicator<Num> indicator2 = getIndicator(new ClosePriceIndicator(data), 2);

// unstable bars skipped, unpredictable results
assertNumEquals((4d + 7d) / 2, indicator2.getValue(data.getBeginIndex() + 1));
assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 2));
assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 3));

Indicator<Num> indicator3 = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((4d + 7d + 5d) / 3, indicator3.getValue(data.getBeginIndex() + 2));
assertNumEquals((7d + 5d + 10d) / 3, indicator3.getValue(data.getBeginIndex() + 3));
}

private void thirdAddition() {
data.addBar(new MockBar(20., numFunction));
Indicator<Num> indicator2 = getIndicator(new ClosePriceIndicator(data), 2);

// unstable bars skipped, unpredictable results
assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 1));
assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 2));
assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 3));

Indicator<Num> indicator3 = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((7d + 5d + 10d) / 3, indicator3.getValue(data.getBeginIndex() + 2));
assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 3));
}

private void fourthAddition() {
data.addBar(new MockBar(30., numFunction));
Indicator<Num> indicator2 = getIndicator(new ClosePriceIndicator(data), 2);

// unstable bars skipped, unpredictable results
assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 1));
assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 2));
assertNumEquals((20d + 30d) / 2, indicator2.getValue(data.getBeginIndex() + 3));

Indicator<Num> indicator3 = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 2));
assertNumEquals((10d + 20d + 30d) / 3, indicator3.getValue(data.getBeginIndex() + 3));
}

private void randomAccessAfterFourAdditions() {
Indicator<Num> indicator2 = getIndicator(new ClosePriceIndicator(data), 2);

// unstable bars skipped, unpredictable results
assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 2));
assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 1));
assertNumEquals((20d + 30d) / 2, indicator2.getValue(data.getBeginIndex() + 3));

Indicator<Num> indicator3 = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((10d + 20d + 30d) / 3, indicator3.getValue(data.getBeginIndex() + 3));
assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 2));
}

@Test
public void whenBarCountIs1ResultShouldBeIndicatorValue() {
data.addBar(new MockBar(5., numFunction));
data.addBar(new MockBar(5., numFunction));

Indicator<Num> indicator = getIndicator(new ClosePriceIndicator(data), 1);
for (int i = 0; i < data.getBarCount(); i++) {
assertEquals(data.getBar(i).getClosePrice(), indicator.getValue(i));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@
import org.ta4j.core.Indicator;
import org.ta4j.core.TestUtils;
import org.ta4j.core.indicators.helpers.ClosePriceIndicator;
import org.ta4j.core.mocks.MockBar;
import org.ta4j.core.mocks.MockBarSeries;
import org.ta4j.core.num.Num;

public class SMAIndicatorTest extends AbstractIndicatorTest<Indicator<Num>, Num> {

private final ExternalIndicatorTest xls;

public SMAIndicatorTest(Function<Number, Num> numFunction) throws Exception {
super((data, params) -> new SMAIndicator((Indicator<Num>) data, (int) params[0]), numFunction);
public SMAIndicatorTest(Function<Number, Num> numFunction) {
super((data, params) -> new SMAIndicator(data, (int) params[0]), numFunction);
xls = new XLSIndicatorTest(this.getClass(), "SMA.xls", 6, numFunction);
}

Expand All @@ -56,7 +57,7 @@ public void setUp() {
}

@Test
public void usingBarCount3UsingClosePrice() throws Exception {
public void usingBarCount3UsingClosePrice() {
Indicator<Num> indicator = getIndicator(new ClosePriceIndicator(data), 3);

assertNumEquals(1, indicator.getValue(0));
Expand All @@ -75,7 +76,22 @@ public void usingBarCount3UsingClosePrice() throws Exception {
}

@Test
public void whenBarCountIs1ResultShouldBeIndicatorValue() throws Exception {
public void usingBarCount3UsingClosePriceMovingSerie() {
data.setMaximumBarCount(13);
data.addBar(new MockBar(5., numFunction));

Indicator<Num> indicator = getIndicator(new ClosePriceIndicator(data), 3);

// unstable bars skipped, unpredictable results
assertNumEquals((3d + 4d + 3d) / 3, indicator.getValue(data.getBeginIndex() + 3));
assertNumEquals((4d + 3d + 4d) / 3, indicator.getValue(data.getBeginIndex() + 4));
assertNumEquals((3d + 4d + 5d) / 3, indicator.getValue(data.getBeginIndex() + 5));
assertNumEquals((4d + 5d + 4d) / 3, indicator.getValue(data.getBeginIndex() + 6));
assertNumEquals((3d + 2d + 5d) / 3, indicator.getValue(data.getBeginIndex() + 12));
}

@Test
public void whenBarCountIs1ResultShouldBeIndicatorValue() {
Indicator<Num> indicator = getIndicator(new ClosePriceIndicator(data), 1);
for (int i = 0; i < data.getBarCount(); i++) {
assertEquals(data.getBar(i).getClosePrice(), indicator.getValue(i));
Expand Down