Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune thomaswue submission. #1

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions calculate_average_thomaswue.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# limitations under the License.
#

NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native"
JAVA_OPTS=""
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -march=native --enable-preview"
JAVA_OPTS="--enable-preview"

if [ -z "${JVM_MODE}" ]; then
# Chosing native image mode, set JVM_MODE variable to select JVM mode.
Expand Down
188 changes: 106 additions & 82 deletions src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,48 @@
*/
package dev.morling.onebrc;

import sun.misc.Unsafe;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.lang.foreign.Arena;
import java.lang.reflect.Field;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.IntStream;

public class CalculateAverage_thomaswue {
private static final String FILE = "./measurements.txt";
private static final int MAX_CITY_NAME_LENGTH = 100;

// Segment in the file that will be processed in parallel.
private record Segment(long start, int size) {
};

// Holding the current result for a single city.
private static class Result {
int max;
int min;
int max;
long sum;
int count;
byte[] name;
final long nameAddress;
final int nameLength;

private Result(long nameAddress, int nameLength, int value) {
this.nameAddress = nameAddress;
this.nameLength = nameLength;
this.min = value;
this.max = value;
this.sum = value;
this.count = 1;
}

public String toString() {
return round(((double) min) / 10.0) + "/" + round((((double) sum) / 10.0) / count) + "/" + round(((double) max) / 10.0);
}

private double round(double value) {
private static double round(double value) {
return Math.round(value * 10.0) / 10.0;
}

Expand All @@ -56,35 +69,29 @@ private void add(Result other) {
}
}

public static void main(String[] args) {
public static void main(String[] args) throws IOException {
// Calculate input segments.
List<Segment> segments = getSegments();
int numberOfChunks = Runtime.getRuntime().availableProcessors();
long[] chunks = getSegments(numberOfChunks);

// Parallel processing of segments.
List<HashMap<String, Result>> allResults = segments.stream().map(s -> {
HashMap<String, Result> cities = new HashMap<>();
byte[] name = new byte[MAX_CITY_NAME_LENGTH];
Result[] results = new Result[1 << 18];
try (FileChannel ch = (FileChannel) java.nio.file.Files.newByteChannel(Paths.get(FILE), StandardOpenOption.READ)) {
ByteBuffer bf = ch.map(FileChannel.MapMode.READ_ONLY, s.start(), s.size());
parseLoop(bf, name, results, cities);
}
catch (IOException e) {
throw new RuntimeException(e);
}
List<HashMap<String, Result>> allResults = IntStream.range(0, chunks.length - 1).mapToObj(chunkIndex -> {
HashMap<String, Result> cities = HashMap.newHashMap(1 << 10);
Result[] results = new Result[1 << 14];
parseLoop(chunks[chunkIndex], chunks[chunkIndex + 1], results, cities);
return cities;
}).parallel().toList();

// Accumulate results sequentially.
HashMap<String, Result> result = allResults.getFirst();
for (int i = 1; i < allResults.size(); ++i) {
for (Map.Entry<String, Result> r : allResults.get(i).entrySet()) {
Result current = result.get(r.getKey());
for (Map.Entry<String, Result> entry : allResults.get(i).entrySet()) {
Result current = result.get(entry.getKey());
if (current != null) {
current.add(r.getValue());
current.add(entry.getValue());
}
else {
result.put(r.getKey(), r.getValue());
result.put(entry.getKey(), entry.getValue());
}
}
}
Expand All @@ -93,91 +100,108 @@ public static void main(String[] args) {
System.out.println(new TreeMap<>(result));
}

private static void parseLoop(ByteBuffer bf, byte[] name, Result[] results, HashMap<String, Result> cities) {
int pos = 0;
private static final Unsafe UNSAFE = initUnsafe();

private static Unsafe initUnsafe() {
try {
Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
theUnsafe.setAccessible(true);
return (Unsafe) theUnsafe.get(Unsafe.class);
}
catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

static boolean unsafeEquals(long aStart, long aLength, long bStart, long bLength) {
if (aLength != bLength) {
return false;
}
for (int i = 0; i < aLength; ++i) {
if (UNSAFE.getByte(aStart + i) != UNSAFE.getByte(bStart + i)) {
return false;
}
}
return true;
}

private static void parseLoop(long chunkStart, long chunkEnd, Result[] results, HashMap<String, Result> cities) {
long scanPtr = chunkStart;
byte b;
while (pos < bf.limit()) {
int hash = 0;
int nameIndex = 0;
while ((b = bf.get(pos++)) != ';') {
while (scanPtr < chunkEnd) {
long nameAddress = scanPtr;

int hash = UNSAFE.getByte(scanPtr++);
while ((b = UNSAFE.getByte(scanPtr++)) != ';') {
hash += b;
hash += hash << 10;
hash ^= hash >> 6;
name[nameIndex++] = b;
}

int nameLength = (int) (scanPtr - 1 - nameAddress);
hash = hash & (results.length - 1);

int number;
byte sign = bf.get(pos++);
boolean isMinus = false;
byte sign = UNSAFE.getByte(scanPtr++);
if (sign == '-') {
isMinus = true;
number = bf.get(pos++) - '0';
number = UNSAFE.getByte(scanPtr++) - '0';
}
else {
number = sign - '0';
}
while ((b = bf.get(pos++)) != '.') {
while ((b = UNSAFE.getByte(scanPtr++)) != '.') {
number = number * 10 + b - '0';
}
number = number * 10 + bf.get(pos++) - '0';
if (isMinus) {
number = number * 10 + UNSAFE.getByte(scanPtr++) - '0';
if (sign == '-') {
number = -number;
}

while (true) {
Result existingResult = results[hash];
if (existingResult == null) {
Result r = new Result();
r.name = new byte[nameIndex];
r.max = number;
r.min = number;
r.count = 1;
r.sum = number;
System.arraycopy(name, 0, r.name, 0, nameIndex);
cities.put(new String(r.name), r);
Result r = new Result(nameAddress, nameLength, number);
results[hash] = r;
byte[] bytes = new byte[nameLength];
UNSAFE.copyMemory(null, nameAddress, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, nameLength);
cities.put(new String(bytes, StandardCharsets.UTF_8), r);
break;
}
else if (unsafeEquals(existingResult.nameAddress, existingResult.nameLength, nameAddress, nameLength)) {
existingResult.min = Math.min(existingResult.min, number);
existingResult.max = Math.max(existingResult.max, number);
existingResult.sum += number;
existingResult.count++;
break;
}
else {
if (Arrays.equals(existingResult.name, 0, nameIndex, name, 0, nameIndex)) {
existingResult.count++;
existingResult.max = Math.max(existingResult.max, number);
existingResult.min = Math.min(existingResult.min, number);
existingResult.sum += number;
break;
}
else {
// Collision error, try next.
hash = (hash + 1) & (results.length - 1);
}
// Collision error, try next.
hash = (hash + 1) & (results.length - 1);
}
}

// Skip new line.
pos++;
scanPtr++;
}
}

private static List<Segment> getSegments() {
try (RandomAccessFile file = new RandomAccessFile(FILE, "r")) {
long totalSize = file.length();
int cores = Runtime.getRuntime().availableProcessors();
int segmentSize = ((int) (totalSize / cores));
List<Segment> segments = new ArrayList<>();
long filePos = 0;
while (filePos < totalSize - segmentSize) {
file.seek(filePos + segmentSize);
while (file.read() != '\n')
;
segments.add(new Segment(filePos, (int) (file.getFilePointer() - filePos)));
filePos = file.getFilePointer();
private static long[] getSegments(int numberOfChunks) throws IOException {
try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) {
long fileSize = fileChannel.size();
long segmentSize = (fileSize + numberOfChunks - 1) / numberOfChunks;
long[] chunks = new long[numberOfChunks + 1];
long mappedAddress = fileChannel.map(MapMode.READ_ONLY, 0, fileSize, Arena.global()).address();
chunks[0] = mappedAddress;
for (int i = 1; i < numberOfChunks; ++i) {
long chunkAddress = mappedAddress + i * segmentSize;
// Align to first row start.
while (UNSAFE.getByte(chunkAddress++) != '\n') {
// nop
}
chunks[i] = chunkAddress;
}
segments.add(new Segment(filePos, (int) (totalSize - filePos)));
return segments;
}
catch (IOException e) {
throw new RuntimeException(e);
chunks[numberOfChunks] = mappedAddress + fileSize;
return chunks;
}
}
}
}