diff --git a/CHANGELOG.md b/CHANGELOG.md index 335b1498d..2f77e8534 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ Changelog for `ta4j`, roughly following [keepachangelog.com](http://keepachangelog.com/en/1.0.0/) from version 0.9 onwards. ## 0.17 - +- **Implemented inner cache for SMAIndicator** ## 0.16 (released May 15, 2024) diff --git a/ta4j-core/src/main/java/org/ta4j/core/indicators/SMAIndicator.java b/ta4j-core/src/main/java/org/ta4j/core/indicators/SMAIndicator.java index 704647da4..abbcc89ea 100644 --- a/ta4j-core/src/main/java/org/ta4j/core/indicators/SMAIndicator.java +++ b/ta4j-core/src/main/java/org/ta4j/core/indicators/SMAIndicator.java @@ -24,6 +24,7 @@ package org.ta4j.core.indicators; import org.ta4j.core.Indicator; +import org.ta4j.core.indicators.helpers.RunningTotalIndicator; import org.ta4j.core.num.Num; /** @@ -34,8 +35,8 @@ */ public class SMAIndicator extends CachedIndicator { - private final Indicator indicator; private final int barCount; + private RunningTotalIndicator previousSum; /** * Constructor. @@ -45,21 +46,21 @@ public class SMAIndicator extends CachedIndicator { */ public SMAIndicator(Indicator indicator, int barCount) { super(indicator); - this.indicator = indicator; + this.previousSum = new RunningTotalIndicator(indicator, barCount); this.barCount = barCount; } @Override protected Num calculate(int index) { - Num sum = zero(); - for (int i = Math.max(0, index - barCount + 1); i <= index; i++) { - sum = sum.plus(indicator.getValue(i)); - } - final int realBarCount = Math.min(barCount, index + 1); + final var sum = partialSum(index); return sum.dividedBy(numOf(realBarCount)); } + private Num partialSum(int index) { + return this.previousSum.getValue(index); + } + /** @return {@link #barCount} */ @Override public int getUnstableBars() { diff --git a/ta4j-core/src/main/java/org/ta4j/core/indicators/helpers/RunningTotalIndicator.java b/ta4j-core/src/main/java/org/ta4j/core/indicators/helpers/RunningTotalIndicator.java index d36cfb535..7dcc4c40f 100644 --- a/ta4j-core/src/main/java/org/ta4j/core/indicators/helpers/RunningTotalIndicator.java +++ b/ta4j-core/src/main/java/org/ta4j/core/indicators/helpers/RunningTotalIndicator.java @@ -36,6 +36,10 @@ public class RunningTotalIndicator extends CachedIndicator { private final Indicator indicator; private final int barCount; + private Num previousSum = zero(); + + // serial access detection + private int previousIndex = -1; public RunningTotalIndicator(Indicator indicator, int barCount) { super(indicator); @@ -45,20 +49,53 @@ public RunningTotalIndicator(Indicator indicator, int barCount) { @Override protected Num calculate(int index) { + // serial access can benefit from previous partial sums + // which saves a lot of CPU work for very long barCounts + if (previousIndex != -1 && previousIndex == index - 1) { + return fastPath(index); + } + + return slowPath(index); + } + + private Num fastPath(final int index) { + var newSum = partialSum(index); + updatePartialSum(index, newSum); + return newSum; + } + + private Num slowPath(final int index) { Num sum = zero(); - for (int i = Math.max(getBarSeries().getBeginIndex(), index - barCount + 1); i <= index; i++) { + for (int i = Math.max(0, index - barCount + 1); i <= index; i++) { sum = sum.plus(indicator.getValue(i)); } + + updatePartialSum(index, sum); return sum; } - @Override - public int getUnstableBars() { - return 0; + private void updatePartialSum(final int index, final Num sum) { + previousIndex = index; + previousSum = sum; + } + + private Num partialSum(int index) { + var sum = this.previousSum.plus(indicator.getValue(index)); + + if (index >= barCount) { + return sum.minus(indicator.getValue(index - barCount)); + } + + return sum; } @Override public String toString() { return getClass().getSimpleName() + " barCount: " + barCount; } + + @Override + public int getUnstableBars() { + return barCount; + } } diff --git a/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorMovingSerieTest.java b/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorMovingSerieTest.java new file mode 100644 index 000000000..1cf19ce6c --- /dev/null +++ b/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorMovingSerieTest.java @@ -0,0 +1,152 @@ +/** + * The MIT License (MIT) + * + * Copyright (c) 2017-2023 Ta4j Organization & respective + * authors (see AUTHORS) + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +package org.ta4j.core.indicators; + +import static org.junit.Assert.assertEquals; +import static org.ta4j.core.TestUtils.assertNumEquals; + +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.ta4j.core.BarSeries; +import org.ta4j.core.Indicator; +import org.ta4j.core.indicators.helpers.ClosePriceIndicator; +import org.ta4j.core.mocks.MockBar; +import org.ta4j.core.mocks.MockBarSeries; +import org.ta4j.core.num.Num; + +public class SMAIndicatorMovingSerieTest extends AbstractIndicatorTest, Num> { + + public SMAIndicatorMovingSerieTest(Function numFunction) { + super((data, params) -> new SMAIndicator(data, (int) params[0]), numFunction); + } + + private BarSeries data; + + @Before + public void setUp() { + data = new MockBarSeries(numFunction, 1, 2, 3, 4, 7); + data.setMaximumBarCount(4); + } + + @Test + public void usingBarCount3MovingSeries() { + firstAddition(); + secondAddition(); + thirdAddition(); + fourthAddition(); + randomAccessAfterFourAdditions(); + } + + private void firstAddition() { + data.addBar(new MockBar(5., numFunction)); + Indicator indicator2 = getIndicator(new ClosePriceIndicator(data), 2); + + // unstable bars skipped, unpredictable results + assertNumEquals((3d + 4d) / 2, indicator2.getValue(data.getBeginIndex() + 1)); + assertNumEquals((4d + 7d) / 2, indicator2.getValue(data.getBeginIndex() + 2)); + assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 3)); + + Indicator indicator3 = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((3d + 4d + 7d) / 3, indicator3.getValue(data.getBeginIndex() + 2)); + assertNumEquals((4d + 7d + 5d) / 3, indicator3.getValue(data.getBeginIndex() + 3)); + } + + private void secondAddition() { + data.addBar(new MockBar(10., numFunction)); + Indicator indicator2 = getIndicator(new ClosePriceIndicator(data), 2); + + // unstable bars skipped, unpredictable results + assertNumEquals((4d + 7d) / 2, indicator2.getValue(data.getBeginIndex() + 1)); + assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 2)); + assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 3)); + + Indicator indicator3 = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((4d + 7d + 5d) / 3, indicator3.getValue(data.getBeginIndex() + 2)); + assertNumEquals((7d + 5d + 10d) / 3, indicator3.getValue(data.getBeginIndex() + 3)); + } + + private void thirdAddition() { + data.addBar(new MockBar(20., numFunction)); + Indicator indicator2 = getIndicator(new ClosePriceIndicator(data), 2); + + // unstable bars skipped, unpredictable results + assertNumEquals((7d + 5d) / 2, indicator2.getValue(data.getBeginIndex() + 1)); + assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 2)); + assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 3)); + + Indicator indicator3 = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((7d + 5d + 10d) / 3, indicator3.getValue(data.getBeginIndex() + 2)); + assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 3)); + } + + private void fourthAddition() { + data.addBar(new MockBar(30., numFunction)); + Indicator indicator2 = getIndicator(new ClosePriceIndicator(data), 2); + + // unstable bars skipped, unpredictable results + assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 1)); + assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 2)); + assertNumEquals((20d + 30d) / 2, indicator2.getValue(data.getBeginIndex() + 3)); + + Indicator indicator3 = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 2)); + assertNumEquals((10d + 20d + 30d) / 3, indicator3.getValue(data.getBeginIndex() + 3)); + } + + private void randomAccessAfterFourAdditions() { + Indicator indicator2 = getIndicator(new ClosePriceIndicator(data), 2); + + // unstable bars skipped, unpredictable results + assertNumEquals((10d + 20d) / 2, indicator2.getValue(data.getBeginIndex() + 2)); + assertNumEquals((5d + 10d) / 2, indicator2.getValue(data.getBeginIndex() + 1)); + assertNumEquals((20d + 30d) / 2, indicator2.getValue(data.getBeginIndex() + 3)); + + Indicator indicator3 = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((10d + 20d + 30d) / 3, indicator3.getValue(data.getBeginIndex() + 3)); + assertNumEquals((5d + 10d + 20d) / 3, indicator3.getValue(data.getBeginIndex() + 2)); + } + + @Test + public void whenBarCountIs1ResultShouldBeIndicatorValue() { + data.addBar(new MockBar(5., numFunction)); + data.addBar(new MockBar(5., numFunction)); + + Indicator indicator = getIndicator(new ClosePriceIndicator(data), 1); + for (int i = 0; i < data.getBarCount(); i++) { + assertEquals(data.getBar(i).getClosePrice(), indicator.getValue(i)); + } + } +} diff --git a/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorTest.java b/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorTest.java index 1d800989e..f6c604cac 100644 --- a/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorTest.java +++ b/ta4j-core/src/test/java/org/ta4j/core/indicators/SMAIndicatorTest.java @@ -36,6 +36,7 @@ import org.ta4j.core.Indicator; import org.ta4j.core.TestUtils; import org.ta4j.core.indicators.helpers.ClosePriceIndicator; +import org.ta4j.core.mocks.MockBar; import org.ta4j.core.mocks.MockBarSeries; import org.ta4j.core.num.Num; @@ -43,8 +44,8 @@ public class SMAIndicatorTest extends AbstractIndicatorTest, Num> private final ExternalIndicatorTest xls; - public SMAIndicatorTest(Function numFunction) throws Exception { - super((data, params) -> new SMAIndicator((Indicator) data, (int) params[0]), numFunction); + public SMAIndicatorTest(Function numFunction) { + super((data, params) -> new SMAIndicator(data, (int) params[0]), numFunction); xls = new XLSIndicatorTest(this.getClass(), "SMA.xls", 6, numFunction); } @@ -56,7 +57,7 @@ public void setUp() { } @Test - public void usingBarCount3UsingClosePrice() throws Exception { + public void usingBarCount3UsingClosePrice() { Indicator indicator = getIndicator(new ClosePriceIndicator(data), 3); assertNumEquals(1, indicator.getValue(0)); @@ -75,7 +76,22 @@ public void usingBarCount3UsingClosePrice() throws Exception { } @Test - public void whenBarCountIs1ResultShouldBeIndicatorValue() throws Exception { + public void usingBarCount3UsingClosePriceMovingSerie() { + data.setMaximumBarCount(13); + data.addBar(new MockBar(5., numFunction)); + + Indicator indicator = getIndicator(new ClosePriceIndicator(data), 3); + + // unstable bars skipped, unpredictable results + assertNumEquals((3d + 4d + 3d) / 3, indicator.getValue(data.getBeginIndex() + 3)); + assertNumEquals((4d + 3d + 4d) / 3, indicator.getValue(data.getBeginIndex() + 4)); + assertNumEquals((3d + 4d + 5d) / 3, indicator.getValue(data.getBeginIndex() + 5)); + assertNumEquals((4d + 5d + 4d) / 3, indicator.getValue(data.getBeginIndex() + 6)); + assertNumEquals((3d + 2d + 5d) / 3, indicator.getValue(data.getBeginIndex() + 12)); + } + + @Test + public void whenBarCountIs1ResultShouldBeIndicatorValue() { Indicator indicator = getIndicator(new ClosePriceIndicator(data), 1); for (int i = 0; i < data.getBarCount(); i++) { assertEquals(data.getBar(i).getClosePrice(), indicator.getValue(i));