Skip to content

Commit

Permalink
Implement Q-learning (reinforcement learning) as in
Browse files Browse the repository at this point in the history
"Learning-Based Controlled Concurrency Testing" by
Mukherjee et al.

https://www.microsoft.com/en-us/research/uploads/prod/2019/12/QL-OOPSLA-2020.pdf

Use the Counter.tla spec to reproduce the calculator example from the
paper where the Reset action is downrated.

[Feature][TLC][Changelog]
  • Loading branch information
lemmy committed Dec 20, 2022
1 parent 213165a commit 8a0da69
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 6 deletions.
173 changes: 173 additions & 0 deletions tlatools/org.lamport.tlatools/src/tlc2/tool/RLSimulationWorker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*******************************************************************************
* Copyright (c) 2022 Microsoft Research. All rights reserved.
*
* The MIT License (MIT)
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is furnished to do
* so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN
* AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
* WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Contributors:
* Markus Alexander Kuppe - initial API and implementation
******************************************************************************/
package tlc2.tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;

import tlc2.tool.liveness.ILiveCheck;
import tlc2.util.RandomGenerator;

// See https://www.microsoft.com/en-us/research/uploads/prod/2019/12/QL-OOPSLA-2020.pdf

// See https://github.com/microsoft/coyote/compare/main...pdeligia/rl-fuzzing

public class RLSimulationWorker extends SimulationWorker {

// Alpha = Learning Rate
private static final double ALPHA = Double.valueOf(System.getProperty(Simulator.class.getName() + ".cct.alpha", ".3d"));
// Gamma = Discount factor
private static final double GAMMA = Double.valueOf(System.getProperty(Simulator.class.getName() + ".cct.gamma", ".7d"));
private static final double REWARD = Double.valueOf(System.getProperty(Simulator.class.getName() + ".cct.reward", "-10d"));

private final Map<Action, Map<Long, Double>> q = new HashMap<>();

public RLSimulationWorker(int id, ITool tool, BlockingQueue<SimulationWorkerResult> resultQueue, long seed,
int maxTraceDepth, long maxTraceNum, boolean checkDeadlock, String traceFile, ILiveCheck liveCheck) {
this(id, tool, resultQueue, seed, maxTraceDepth, maxTraceNum, null, checkDeadlock, traceFile, liveCheck,
new LongAdder(), new AtomicLong(), new AtomicLong());
}

public RLSimulationWorker(int id, ITool tool, BlockingQueue<SimulationWorkerResult> resultQueue, long seed,
int maxTraceDepth, long maxTraceNum, String traceActions, boolean checkDeadlock, String traceFile,
ILiveCheck liveCheck, LongAdder numOfGenStates, AtomicLong numOfGenTraces, AtomicLong m2AndMean) {
super(id, tool, resultQueue, seed, maxTraceDepth, maxTraceNum, traceActions, checkDeadlock, traceFile, liveCheck,
numOfGenStates, numOfGenTraces, m2AndMean);

for (final Action a : tool.getActions()) {
q.put(a, new HashMap<>());
}
}

private final double getReward(final long fp, final Action a) {
// The reward is negative to force RL to find alternative solutions instead of
// finding the best (one) solution over again. For example, in a maze, RL would
// be rewarded +1 if it makes it to the exit. Here, we want to find other paths
// elsewhere.
return REWARD;
// TODO Experiment with other rewards.
}

private final double getMaxQ(final long fp) {
double max = -Double.MAX_VALUE;
for (Action a : q.keySet()) {
// Map#get instead of Map#getOrDefaults causes an NPE in max when the fp is unknown.
double d = this.q.get(a).getOrDefault(fp, -Double.MAX_VALUE);
max = Math.max(max, d);
}
return max;
}

protected final int getNextActionIndex(final RandomGenerator rng, final Action[] actions, final TLCState state) {
final long s = state.fingerPrint();

this.q.values().forEach(m -> m.putIfAbsent(s, 0d));

// Calculate the sum over all actions.
double denum = 0;
double[] d = new double[actions.length];
for (int i = 0; i < d.length; i++) {
d[i] = Math.exp(this.q.get(actions[i]).get(s));
denum += d[i];
}

// Apache commons-math with its Pair impl can replace the rest of this method.
// However, TLC does not come with commons-math.
// final List<Pair<Integer, Double>> arr = new ArrayList<>(d.length);
// for (int i = 0; i < d.length; i++) {
// arr.add(new Pair<>(i, d[i] / denum));
// }
//
// return new EnumeratedDistribution<>(arr).sample();

// Calculate the individual weight.
final ArrayList<Pair> m = new ArrayList<>(d.length);
for (int i = 0; i < d.length; i++) {
m.add(new Pair(i, d[i] / denum));
}

final double nd = rng.nextDouble();

// Sort m and calculate the cumulative weights.
java.util.Collections.sort(m);
for (int i = 0; i < d.length; i++) {
final Pair p = m.get(i);
d[i] = i == 0 ? p.key : d[i - 1] + p.key;

// Preemptively exit if the cumulative weight exceeds
// the uniformly chosen nd at random.
if (d[i] >= nd) {
return p.value;
}
}
// Fallback for issues with double precision above.
return m.get(d.length - 1).value;
}

protected boolean postTrace(TLCState s) {
final int level = s.getLevel();
for (int i = level - 1; i > 0; i--) {
final double maxQ = getMaxQ(s.fingerPrint());

final TLCState p = s.getPredecessor();
final long fp = p.fingerPrint();

final Action ai = s.getAction();

final double qi = this.q.get(ai).get(fp);
final double q = ((1d - ALPHA) * qi) + (ALPHA * (getReward(fp, ai) + (GAMMA * maxQ)));

this.q.get(ai).put(fp, q);

s = p;
}
return true;
}

private static class Pair implements Comparable<Pair> {
public final double key;
public final int value;

public Pair(int v, double k) {
key = k;
value = v;
}

@Override
public int compareTo(Pair o) {
return Double.compare(o.key, key);
}

@Override
public String toString() {
return "[key=" + key + ", value=" + value + "]";
}
}
}
15 changes: 13 additions & 2 deletions tlatools/org.lamport.tlatools/src/tlc2/tool/SimulationWorker.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
******************************************************************************/
package tlc2.tool;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -444,6 +445,10 @@ public SetOfStates getStates() {
}

private final StateVec nextStates = new StateVec(1);

/**
 * Chooses the index of the action to try first when extending the trace from
 * {@code curState}. This default implementation picks uniformly at random and
 * does not look at {@code curState}; subclasses may override it to bias the
 * choice.
 *
 * @param rng      the random generator to draw from
 * @param actions  the candidate actions; the result indexes into this array
 * @param curState the state the trace is extended from (unused here)
 * @return an index in {@code [0, actions.length)}
 */
protected int getNextActionIndex(RandomGenerator rng, Action[] actions, TLCState curState) {
    // Draw from the generator passed in instead of silently ignoring the parameter.
    return (int) Math.floor(rng.nextDouble() * actions.length);
}

/**
* Generates a single random trace.
Expand Down Expand Up @@ -477,7 +482,7 @@ private Optional<SimulationWorkerError> simulateRandomTrace() throws Exception {

// b) Get the current state's successor states.
nextStates.clear();
int index = (int) Math.floor(this.localRng.nextDouble() * len);
int index = getNextActionIndex(this.localRng, actions, curState);
final int p = this.localRng.nextPrime();
for (int i = 0; i < len; i++) {
try {
Expand Down Expand Up @@ -577,12 +582,18 @@ public StateVec get() {
pw.close();
}

postTrace(curState);

// Finished trace generation without any errors.
return Optional.empty();
}

/**
 * Hook that is invoked with the last state of a trace once the trace has been
 * generated without errors. This default implementation does nothing and
 * returns {@code true}; subclasses may override it to post-process the trace.
 *
 * @param finalState the last state of the generated trace
 * @return {@code true} in this implementation
 *         (NOTE(review): the visible call site ignores the result - confirm its intended meaning)
 * @throws FileNotFoundException never thrown here; declared so that overriding methods may throw it
 */
protected boolean postTrace(final TLCState finalState) throws FileNotFoundException {
return true;
}

/**
 * Returns the number of traces, counting the one currently being generated.
 *
 * @return {@code traceCnt + 1}
 */
public final long getTraceCnt() {
    return this.traceCnt + 1; // +1 to account for the currently generated behavior.
}

public final StateVec getTrace() {
Expand Down
14 changes: 10 additions & 4 deletions tlatools/org.lamport.tlatools/src/tlc2/tool/Simulator.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,15 @@ public Simulator(ITool tool, String metadir, String traceFile, boolean deadlock,
this.numWorkers = numWorkers;
this.workers = new ArrayList<>(numWorkers);
for (int i = 0; i < this.numWorkers; i++) {
this.workers.add(new SimulationWorker(i, this.tool, this.workerResultQueue, this.rng.nextLong(),
this.traceDepth, this.traceNum, this.traceActions, this.checkDeadlock, this.traceFile,
this.liveCheck, this.numOfGenStates, this.numOfGenTraces, this.welfordM2AndMean));
if (Boolean.getBoolean(Simulator.class.getName() + ".cct")) {
this.workers.add(new RLSimulationWorker(i, this.tool, this.workerResultQueue, this.rng.nextLong(),
this.traceDepth, this.traceNum, this.traceActions, this.checkDeadlock, this.traceFile,
this.liveCheck, this.numOfGenStates, this.numOfGenTraces, this.welfordM2AndMean));
} else {
this.workers.add(new SimulationWorker(i, this.tool, this.workerResultQueue, this.rng.nextLong(),
this.traceDepth, this.traceNum, this.traceActions, this.checkDeadlock, this.traceFile,
this.liveCheck, this.numOfGenStates, this.numOfGenTraces, this.welfordM2AndMean));
}
}

// Eagerly create the config value in case the next-state relation involves
Expand Down Expand Up @@ -589,7 +595,7 @@ public void run() {
reportCoverage();
count = TLCGlobals.coverageInterval / TLCGlobals.progressInterval;
}

writeActionFlowGraph();
}
} catch (Exception e) {
Expand Down
54 changes: 54 additions & 0 deletions tlatools/org.lamport.tlatools/test-model/qlearning/Counter.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Compares the "rl" and "random" runs recorded in Counter.csv:
#  1. how often each action was taken, and
#  2. how often each counter value was observed (log2-scaled counts).
#
# library() is used instead of require() so that a missing package aborts the
# script immediately instead of failing later with an obscure error.
library(ggplot2)
library(dplyr)
library(gridExtra)

# The CSV uses '#' as the column separator (columns: Mode, Time, Val, Action).
df <- read.csv(header = TRUE, sep = "#", file = "./Counter.csv")
df$Val <- as.numeric(df$Val)
df$Time <- as.numeric(df$Time)

####################

## Bar chart of the number of records per action for one mode.
action_frequency_plot <- function(data, run_mode, title) {
  data %>%
    filter(Mode == run_mode) %>%
    group_by(Action) %>%
    summarize(count = n()) %>%
    ggplot(aes(x = Action, y = count)) +
    geom_bar(stat = "identity") +
    ggtitle(title)
}

## Number of records per counter value for one mode.
value_counts <- function(data, run_mode) {
  data %>%
    filter(Mode == run_mode) %>%
    group_by(Val) %>%
    summarize(count = n())
}

## Bar chart of per-value counts on a log2 y-axis.
value_frequency_plot <- function(counts, title) {
  counts %>%
    ggplot(aes(x = Val, y = count)) +
    geom_bar(stat = "identity") +
    ggtitle(title) +
    scale_y_continuous(trans = "log2")
}

####################

## Action Frequency
rlaf <- action_frequency_plot(df, "rl", "RL Actions")
rndaf <- action_frequency_plot(df, "random", "Random Actions")

grid.arrange(rlaf, rndaf, ncol = 2)

####################

## Value Frequency
dfrl <- value_counts(df, "rl")
rl <- value_frequency_plot(dfrl, "RL")

dfrnd <- value_counts(df, "random")
rnd <- value_frequency_plot(dfrnd, "Random")

grid.arrange(rl, rnd, ncol = 2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
\* Initial-state predicate and next-state relation of the Counter spec.
INIT Init
NEXT Next

\* State constraint of the Counter spec; its definition appends a record to
\* a CSV file rather than restricting the state space.
CONSTRAINT
StatisticsStateConstraint
44 changes: 44 additions & 0 deletions tlatools/org.lamport.tlatools/test-model/qlearning/Counter.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
------ MODULE Counter ------
\* A single integer counter x that is manipulated by five actions. The state
\* constraint at the bottom writes a CSV record (mode, generated-states
\* statistic, value of x, action name) so that runs can be compared afterwards.
EXTENDS Integers, CSV, TLC, TLCExt, IOUtils

\* The value of the counter.
VARIABLE x

\* The counter starts at zero.
Init ==
x = 0

\* Increment and decrement the counter by one.
Add == x' = x + 1
Sub == x' = x - 1



\* Halve (integer division) and double the counter.
Divide == x' = x \div 2
Multiply == x' = x * 2



\* Set the counter back to its initial value.
Reset == x' = 0

\* The next-state relation: any one of the five actions.
\* NOTE(review): the action names end up in the CSV output and the order of the
\* disjuncts presumably determines the order of TLC's actions - change with care.
Next ==
\/ Add
\/ Sub
\/ Reset
\/ Multiply
\/ Divide

---------------------------

\* File that the statistics are appended to.
CSVFile == "Counter.csv"

\* Header row; '#' is the column separator used for all records.
CSVColumnHeaders ==
"Mode#Time#Val#Action"

\* Write the header row only if the CSV file has no records yet.
ASSUME
CSVRecords(CSVFile) = 0 =>
CSVWrite(CSVColumnHeaders, <<>>, CSVFile)

\* Used as a CONSTRAINT in the config purely for its side effect of writing one
\* record: IOEnv.mode (presumably the environment variable "mode", e.g. "rl" or
\* "random" - confirm against how TLC is launched), the "generated" statistic,
\* the current value of x, and the name of the current action.
\* The commented-out guard below would restrict the writes by trace depth.
StatisticsStateConstraint ==
\* (TLCGet("level") > TLCGet("config").depth) =>
TLCDefer(CSVWrite("%1$s#%2$s#%3$s#%4$s",
<< IOEnv.mode,TLCGet("stats").generated,x,TLCGet("action").name>>, CSVFile))

=====

0 comments on commit 8a0da69

Please sign in to comment.