package org.apache.logging.log4j.spi;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1, jvmArgs = { "-Xmx3g", "-XX:+UseParallelGC" })
public class DefaultThreadContextMapCopyBenchmark {

    /** Number of copies performed by each poisoning loop in {@link #setup()}. */
    private static final int POISON_ITERATIONS = 10000;

    /**
     * Runs this benchmark class directly, without the JMH command-line launcher.
     *
     * @param args ignored
     * @throws RunnerException if JMH fails to run the benchmark
     */
    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder().include(DefaultThreadContextMapCopyBenchmark.class.getSimpleName()).build();
        new Runner(opt).run();
    }

    /** Number of entries placed in the context map and in each input map. */
    @Param({ "5" })
    private int mapSize;

    /** Map whose {@code getCopy()} call is measured by {@link #copyMap()}. */
    private DefaultThreadContextMap contextMap;

    // Three maps with identical contents but different concrete types. They are only used
    // during setup, to make the map-copy call sites polymorphic before measurement starts.
    private HashMap<String, Integer> inputHashMap;
    private TreeMap<String, Integer> inputTreeMap;
    private LinkedHashMap<String, Integer> inputLinkedHashMap;

    /**
     * Measures {@code DefaultThreadContextMap.getCopy()}.
     *
     * @return the result of {@code getCopy()}, returned so JMH consumes it and the call is not
     *         dead-code eliminated
     */
    @Benchmark
    public Object copyMap() {
        return contextMap.getCopy();
    }

    /**
     * Builds the context map under test and then warms up the map-copy paths with three
     * different {@link Map} implementations, once per trial.
     */
    @Setup(Level.Trial)
    public void setup() {
        // The three input maps hold identical contents to each other (they differ from
        // contextMap, which uses zero-padded String keys and String values).
        inputHashMap = new HashMap<>();
        inputTreeMap = new TreeMap<>();
        inputLinkedHashMap = new LinkedHashMap<>();
        contextMap = new DefaultThreadContextMap();
        for (int i = 0; i < mapSize; i++) {
            // Locale.ROOT keeps the zero-padded digits ASCII regardless of the default locale.
            contextMap.put(String.format(Locale.ROOT, "key%03d", i), "value" + i);
        }

        for (int i = 0; i < mapSize; i++) {
            String key = "key" + i;
            Integer value = i;
            inputHashMap.put(key, value);
            inputTreeMap.put(key, value);
            inputLinkedHashMap.put(key, value);
        }

        // Poison the HashMap.<init>(Map) call site with polymorphic calls, so the constructor
        // becomes a polymorphic call site before benchmarking.
        for (int i = 0; i < POISON_ITERATIONS; i++) {
            HashMap<String, Integer> temp = new HashMap<>(pickSource(i));
            checkCopySize(temp);
        }

        // Exercise the manual copy pattern (presized HashMap + entrySet() iteration + put) with
        // all three map types.
        // NOTE(review): the original comment said this poisons "the manualInlining benchmark
        // method", but no such benchmark exists in this class; confirm this loop is still needed.
        for (int i = 0; i < POISON_ITERATIONS; i++) {
            Map<String, Integer> source = pickSource(i);
            HashMap<String, Integer> temp = new HashMap<>(source.size());
            for (Map.Entry<String, Integer> entry : source.entrySet()) {
                temp.put(entry.getKey(), entry.getValue());
            }
            checkCopySize(temp);
        }
    }

    /** Rotates through the three input map types so each copy call site sees all of them. */
    private Map<String, Integer> pickSource(int iteration) {
        switch (iteration % 3) {
        case 0:
            return inputHashMap;
        case 1:
            return inputTreeMap;
        default:
            return inputLinkedHashMap;
        }
    }

    /**
     * Fails the trial if a warm-up copy lost or gained entries. The check also keeps the copy
     * observable, so the JIT cannot discard it.
     */
    private void checkCopySize(Map<String, Integer> copy) {
        if (copy.size() != mapSize) {
            throw new IllegalStateException(
                    "Expected copy of " + mapSize + " entries but got " + copy.size());
        }
    }

}
