Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements #935

Open
wants to merge 7 commits into
base: cabg_cardio_gmf_htn_trial
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.mitre.synthea.engine.ExpressedSymptom.SymptomInfo;
import org.mitre.synthea.engine.ExpressedSymptom.SymptomSource;
Expand Down Expand Up @@ -140,8 +140,8 @@ public class ModuleConditions implements Cloneable, Serializable {
*/
public ModuleConditions(String source) {
this.source = source;
onsetConditions = new ConcurrentHashMap<String, OnsetCondition>();
state2conditionMapping = new ConcurrentHashMap<String, String>();
onsetConditions = new HashMap<String, OnsetCondition>();
state2conditionMapping = new HashMap<String, String>();
}

/**
Expand Down Expand Up @@ -231,7 +231,7 @@ public ConditionWithSymptoms(String name, Long onsetTime, Long endTime) {
this.conditionName = name;
this.onsetTime = onsetTime;
this.endTime = endTime;
this.symptoms = new ConcurrentHashMap<String, List<Integer>>();
this.symptoms = new HashMap<String, List<Integer>>();
}

/**
Expand Down Expand Up @@ -295,7 +295,7 @@ public Map<String, List<Integer>> getSymptoms() {

public ExpressedConditionRecord(Person person) {
this.person = person;
sources = new ConcurrentHashMap<String, ModuleConditions>();
sources = new HashMap<String, ModuleConditions>();
}

/**
Expand Down Expand Up @@ -384,7 +384,7 @@ public void onConditionEnd(String module, String condition, long time) {
public Map<Long, List<ConditionWithSymptoms>> getConditionSymptoms() {
Map<String, ExpressedSymptom> symptoms = person.getExpressedSymptoms();
Map<Long, List<ConditionWithSymptoms>> result;
result = new ConcurrentHashMap<Long, List<ConditionWithSymptoms>>();
result = new HashMap<Long, List<ConditionWithSymptoms>>();
for (String module : sources.keySet()) {
ModuleConditions moduleConditions = sources.get(module);
for (String condition : moduleConditions.getOnsetConditions().keySet()) {
Expand Down
24 changes: 13 additions & 11 deletions src/main/java/org/mitre/synthea/engine/ExpressedSymptom.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.mitre.synthea.engine;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ExpressedSymptom implements Cloneable, Serializable {

Expand Down Expand Up @@ -64,7 +64,7 @@ public class SymptomSource implements Cloneable, Serializable {
*/
public SymptomSource(String source) {
this.source = source;
timeInfos = new ConcurrentHashMap<Long, ExpressedSymptom.SymptomInfo>();
timeInfos = new HashMap<Long, ExpressedSymptom.SymptomInfo>();
resolved = false;
lastUpdateTime = null;
}
Expand Down Expand Up @@ -114,8 +114,9 @@ public void addInfo(String cause, long time, int value, Boolean addressed) {
* Get the current value of the symptom.
*/
public Integer getCurrentValue() {
if (timeInfos.containsKey(lastUpdateTime)) {
return timeInfos.get(lastUpdateTime).getValue();
final SymptomInfo timeInfo = timeInfos.get(lastUpdateTime);
if (timeInfo != null) {
return timeInfo.getValue();
}
return null;
}
Expand All @@ -134,7 +135,7 @@ public Map<Long, SymptomInfo> getTimeInfos() {

public ExpressedSymptom(String name) {
this.name = name;
sources = new ConcurrentHashMap<String, SymptomSource>();
sources = new HashMap<String, SymptomSource>();
}

/**
Expand Down Expand Up @@ -165,12 +166,13 @@ public void onSet(String module, String cause, long time, int value, Boolean add
*/
public int getSymptom() {
int max = 0;
for (String module : sources.keySet()) {
Integer value = sources.get(module).getCurrentValue();
Boolean isResolved = sources.get(module).isResolved();
if (value != null && value.intValue() > max && !isResolved) {
max = value.intValue();
}
for (SymptomSource module : sources.values()) {
if(!module.isResolved()) {
Integer value = module.getCurrentValue();
if (value != null && value.intValue() > max) {
max = value.intValue();
}
}
}
return max;
}
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/org/mitre/synthea/engine/Logic.java
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ private abstract static class GroupedCondition extends Logic {
public static class And extends GroupedCondition {
@Override
public boolean test(Person person, long time) {
return conditions.stream().allMatch(c -> c.test(person, time));
for(Logic condition: conditions) {
if (!condition.test(person, time))
return false;
}

return true;
}
}

Expand Down
41 changes: 25 additions & 16 deletions src/main/java/org/mitre/synthea/engine/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.mitre.synthea.helpers.Config;
import org.mitre.synthea.helpers.Utilities;
Expand Down Expand Up @@ -301,7 +302,7 @@ public Module(JsonObject definition, boolean submodule) throws Exception {
states.put(entry.getKey(), state);
}
}

/**
* Clone this module. Never provide the original.
*/
Expand All @@ -311,10 +312,12 @@ public Module clone() {
clone.submodule = this.submodule;
clone.remarks = this.remarks;
if (this.states != null) {
clone.states = new ConcurrentHashMap<String, State>();
for (String key : this.states.keySet()) {
clone.states.put(key, this.states.get(key).clone());
}
clone.states = this.states.entrySet().stream()
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue().clone(),
(oldValue, newValue) -> newValue,
ConcurrentHashMap::new));
}
return clone;
}
Expand Down Expand Up @@ -423,11 +426,13 @@ public Collection<String> getStateNames() {
*/
public static class ModuleSupplier implements Supplier<Module> {

private final Object getLock = new Object();

public final boolean core;
public final boolean submodule;
public final String path;

private boolean loaded;
private volatile boolean loaded;
private Callable<Module> loader;
private Module module;
private Throwable fault;
Expand Down Expand Up @@ -461,17 +466,21 @@ public ModuleSupplier(Module module) {
}

@Override
public synchronized Module get() {
public Module get() {
if (!loaded) {
try {
module = loader.call();
} catch (Throwable e) {
e.printStackTrace();
fault = e;
} finally {
loaded = true;
loader = null;
}
synchronized(getLock) {
if (!loaded) {
try {
module = loader.call();
} catch (Throwable e) {
e.printStackTrace();
fault = e;
} finally {
loaded = true;
loader = null;
}
}
}
}
if (fault != null) {
throw new RuntimeException(fault);
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/mitre/synthea/engine/State.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.math.ode.DerivativeException;
Expand Down Expand Up @@ -1724,10 +1725,9 @@ public ObservationGroup clone() {
// we need to ensure we deep clone the list
// (otherwise as this gets passed around, the same objects are used for different patients
// which causes weird and unexpected results)
List<Observation> cloneObs = new ArrayList<>(observations.size());
for (Observation o : observations) {
cloneObs.add(o.clone());
}
List<Observation> cloneObs = observations.stream()
.map(Observation::clone)
.collect(Collectors.toCollection(ArrayList::new));
clone.observations = cloneObs;

return clone;
Expand Down
25 changes: 22 additions & 3 deletions src/main/java/org/mitre/synthea/helpers/Utilities.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
import org.mitre.synthea.world.concepts.HealthRecord.Code;

public class Utilities {

private static final TimeZone UTC_TIME_ZONE = TimeZone.getTimeZone("UTC");

private static final ThreadLocal<Calendar> threadLocalCalendar = new ThreadLocal<Calendar>(){
@Override
protected Calendar initialValue() {
return Calendar.getInstance(UTC_TIME_ZONE);
}
};

/**
* Convert a quantity of time in a specified units into milliseconds.
*
Expand Down Expand Up @@ -94,7 +104,7 @@ public static long convertCalendarYearsToTime(int years) {
* Get the year of a Unix timestamp.
*/
public static int getYear(long time) {
Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
Calendar calendar = threadLocalCalendar.get();
calendar.setTimeInMillis(time);
return calendar.get(Calendar.YEAR);
}
Expand All @@ -103,7 +113,7 @@ public static int getYear(long time) {
* Get the month of a Unix timestamp.
*/
public static int getMonth(long time) {
Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
Calendar calendar = threadLocalCalendar.get();
calendar.setTimeInMillis(time);
return calendar.get(Calendar.MONTH) + 1;
}
Expand Down Expand Up @@ -133,11 +143,20 @@ public static Object primitive(JsonPrimitive p) {
return retVal;
}

private static double timestepCache = Double.NaN;

private static double getTimestep() {
if(timestepCache == Double.NaN)
timestepCache = Double.parseDouble(Config.get("generate.timestep"));

return timestepCache;
}

/**
* Calculates 1 - (1-risk)^(currTimeStepInMS/originalPeriodInMS).
*/
public static double convertRiskToTimestep(double risk, double originalPeriodInMS) {
double currTimeStepInMS = Double.parseDouble(Config.get("generate.timestep"));
double currTimeStepInMS = getTimestep();

return convertRiskToTimestep(risk, originalPeriodInMS, currTimeStepInMS);
}
Expand Down
11 changes: 10 additions & 1 deletion src/main/java/org/mitre/synthea/modules/LifecycleModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,15 @@ private static Map<String, Double> createDrugImpactsMap() {
private static final double[] RESPIRATION_RATE_NORMAL =
BiometricsConfig.doubles("respiratory.respiration_rate.normal");

private static long timestepCache = Long.MIN_VALUE;

private static long getTimestep() {
if(timestepCache == Long.MIN_VALUE)
timestepCache = Long.parseLong(Config.get("generate.timestep"));

return timestepCache;
}

/**
* Calculate this person's vital signs,
* based on their conditions, medications, body composition, etc.
Expand Down Expand Up @@ -743,7 +752,7 @@ private static void calculateVitalSigns(Person person, long time) {
person.setVitalSign(VitalSign.CARBON_DIOXIDE, person.rand(CO2_RANGE));
person.setVitalSign(VitalSign.SODIUM, person.rand(SODIUM_RANGE));

long timestep = Long.parseLong(Config.get("generate.timestep"));
long timestep = getTimestep();
double heartStart = person.rand(HEART_RATE_NORMAL);
double heartEnd = person.rand(HEART_RATE_NORMAL);
person.setVitalSign(VitalSign.HEART_RATE,
Expand Down
Loading