Skip to content

Commit

Permalink
Reused RunningTotalIndicator
Browse files Browse the repository at this point in the history
  • Loading branch information
sgflt committed May 19, 2024
1 parent 356116f commit 3269e00
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 43 deletions.
44 changes: 5 additions & 39 deletions ta4j-core/src/main/java/org/ta4j/core/indicators/SMAIndicator.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package org.ta4j.core.indicators;

import org.ta4j.core.Indicator;
import org.ta4j.core.indicators.helpers.RunningTotalIndicator;
import org.ta4j.core.num.Num;

/**
Expand All @@ -34,12 +35,9 @@
*/
public class SMAIndicator extends CachedIndicator<Num> {

private final Indicator<Num> indicator;
private final int barCount;
private Num previousSum = zero();
private RunningTotalIndicator previousSum;

// serial access detection
private int previousIndex = -1;

/**
* Constructor.
Expand All @@ -49,52 +47,20 @@ public class SMAIndicator extends CachedIndicator<Num> {
*/
public SMAIndicator(Indicator<Num> indicator, int barCount) {
super(indicator);
this.indicator = indicator;
this.previousSum = new RunningTotalIndicator(indicator,barCount);
this.barCount = barCount;
}

@Override
protected Num calculate(int index) {
// serial access can benefit from previous partial sums
// which saves a lot of CPU work for very long barCounts
if (previousIndex != -1 && previousIndex == index - 1) {
return fastPath(index);
}

return slowPath(index);
}

private Num fastPath(final int index) {
var newSum = partialSum(index);
final int realBarCount = Math.min(barCount, index + 1);
updatePartialSum(index, newSum);
return newSum.dividedBy(numOf(realBarCount));
}

private Num slowPath(final int index) {
Num sum = zero();
for (int i = Math.max(0, index - barCount + 1); i <= index; i++) {
sum = sum.plus(indicator.getValue(i));
}

final int realBarCount = Math.min(barCount, index + 1);
updatePartialSum(index, sum);
final var sum = partialSum(index);
return sum.dividedBy(numOf(realBarCount));
}

private void updatePartialSum(final int index, final Num sum) {
previousIndex = index;
previousSum = sum;
}

private Num partialSum(int index) {
var sum = this.previousSum.plus(indicator.getValue(index));

if (index >= barCount) {
return sum.minus(indicator.getValue(index - barCount));
}

return sum;
return this.previousSum.getValue(index);
}

/** @return {@link #barCount} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,68 @@
public class RunningTotalIndicator extends CachedIndicator<Num> {
private final Indicator<Num> indicator;
private final int barCount;
private Num previousSum = zero();

// serial access detection
private int previousIndex = -1;

public RunningTotalIndicator(Indicator<Num> indicator, int barCount) {
super(indicator);
this.indicator = indicator;
this.barCount = barCount;
}


@Override
protected Num calculate(int index) {
// serial access can benefit from previous partial sums
// which saves a lot of CPU work for very long barCounts
if (previousIndex != -1 && previousIndex == index - 1) {
return fastPath(index);
}

return slowPath(index);
}

private Num fastPath(final int index) {
var newSum = partialSum(index);
updatePartialSum(index, newSum);
return newSum;
}

private Num slowPath(final int index) {
Num sum = zero();
for (int i = Math.max(getBarSeries().getBeginIndex(), index - barCount + 1); i <= index; i++) {
for (int i = Math.max(0, index - barCount + 1); i <= index; i++) {
sum = sum.plus(indicator.getValue(i));
}

updatePartialSum(index, sum);
return sum;
}

@Override
public int getUnstableBars() {
return 0;
private void updatePartialSum(final int index, final Num sum) {
previousIndex = index;
previousSum = sum;
}

private Num partialSum(int index) {
var sum = this.previousSum.plus(indicator.getValue(index));

if (index >= barCount) {
return sum.minus(indicator.getValue(index - barCount));
}

return sum;
}

@Override
public String toString() {
return getClass().getSimpleName() + " barCount: " + barCount;
}


@Override
public int getUnstableBars() {
return barCount;
}
}

0 comments on commit 3269e00

Please sign in to comment.