diff --git a/src/main/java/com/tamingtext/util/SplitInput.java b/src/main/java/com/tamingtext/util/SplitInput.java index 9f1ba0e..f3c316d 100644 --- a/src/main/java/com/tamingtext/util/SplitInput.java +++ b/src/main/java/com/tamingtext/util/SplitInput.java @@ -26,7 +26,8 @@ import java.io.Writer; import java.nio.charset.Charset; import java.util.BitSet; -import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; @@ -353,16 +354,19 @@ else if (fs.getFileStatus(inputFile).isDir()) { BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset)); Writer trainingWriter = new OutputStreamWriter(fs.create(trainingOutputFile), charset); Writer testWriter = new OutputStreamWriter(fs.create(testOutputFile), charset); + Set writers = new HashSet(); + writers.add(trainingWriter); + writers.add(testWriter); int pos = 0; int trainCount = 0; int testCount = 0; String line; + Writer writer; while ((line = reader.readLine()) != null) { pos++; - Writer writer; if (testRandomSelectionPct > 0) { // Randomly choose writer = randomSel.get(pos) ? testWriter : trainingWriter; } else { // Choose based on location @@ -384,9 +388,8 @@ else if (fs.getFileStatus(inputFile).isDir()) { writer.write(line); writer.write('\n'); } - - IOUtils.close(Collections.singleton(trainingWriter)); - IOUtils.close(Collections.singleton(testWriter)); + + IOUtils.close(writers); log.info("file: {}, input: {} train: {}, test: {} starting at {}", new Object[] {inputFile.getName(), lineCount, trainCount, testCount, testSplitStart}); @@ -580,7 +583,12 @@ public static int countLines(FileSystem fs, Path inputFile, Charset charset) thr lineCount++; } } finally { - IOUtils.close(Collections.singleton(countReader)); + try { + countReader.close(); + } + catch (IOException ex) { + log.warn("Could not close line count reader", ex); + } } return lineCount;