Skip to content

Commit

Permalink
fix multiple reads in same scope (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
tersec committed Jun 27, 2023
1 parent d78b9dc commit 720fc5e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 48 deletions.
78 changes: 30 additions & 48 deletions faststreams/inputs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ when fsAsyncSupport:
template disconnectInputDevice(s: AsyncInputStream) =
disconnectInputDevice InputStream(s)

proc preventFurtherReading(s: InputStream) =
func preventFurtherReading(s: InputStream) =
s.vtable = nil
s.span = default(PageSpan)

Expand Down Expand Up @@ -203,7 +203,7 @@ let memFileInputVTable = InputStreamVTable(
except OSError as err:
raise newException(IOError, "Failed to close file", err)
,
getLenSync: proc (s: InputStream): Option[Natural]
getLenSync: func (s: InputStream): Option[Natural]
{.nimcall, gcsafe, raises: [IOError, Defect].} =
some s.span.len
)
Expand Down Expand Up @@ -262,7 +262,7 @@ func getNewSpanOrDieTrying(s: InputStream) =
getNewSpan s
fsAssert s.span.hasRunway

proc readableNow*(s: InputStream): bool =
func readableNow*(s: InputStream): bool =
if s.span.hasRunway: return true
getNewSpan s
s.span.hasRunway
Expand Down Expand Up @@ -337,7 +337,7 @@ func totalUnconsumedBytes*(s: InputStream): Natural =

localRunway + runwayInBuffers

proc limitReadableRange(s: InputStream, rangeLen: Natural): Natural =
func limitReadableRange(s: InputStream, rangeLen: Natural): Natural =
s.vtable = nil

let runway = s.span.len
Expand Down Expand Up @@ -441,7 +441,7 @@ proc fileInput*(filename: string,
let file = system.open(filename, fmRead)
return fileInput(file, offset, pageSize)

proc unsafeMemoryInput*(mem: openArray[byte]): InputStreamHandle =
func unsafeMemoryInput*(mem: openArray[byte]): InputStreamHandle =
let head = cast[ptr byte](mem)

makeHandle InputStream(
Expand All @@ -450,7 +450,7 @@ proc unsafeMemoryInput*(mem: openArray[byte]): InputStreamHandle =
endAddr: offset(head, mem.len)),
spanEndPos: mem.len)

proc unsafeMemoryInput*(str: string): InputStreamHandle =
func unsafeMemoryInput*(str: string): InputStreamHandle =
unsafeMemoryInput str.toOpenArrayByte(0, str.len - 1)

proc len*(s: InputStream): Option[Natural] {.raises: [Defect, IOError].} =
Expand Down Expand Up @@ -494,7 +494,7 @@ func memoryInput*(data: openArray[byte]): InputStreamHandle =
func memoryInput*(data: openArray[char]): InputStreamHandle =
memoryInput charsToBytes(data)

proc resetBuffers*(s: InputStream, buffers: PageBuffers) =
func resetBuffers*(s: InputStream, buffers: PageBuffers) =
# This should be used only on safe memory input streams
fsAssert s.vtable == nil and s.buffers != nil and buffers.len > 0
s.spanEndPos = 0
Expand Down Expand Up @@ -710,7 +710,7 @@ when fsAsyncSupport:
template read*(s: AsyncInputStream): byte =
read InputStream(s)

proc peekAt*(s: InputStream, pos: int): byte {.inline.} =
func peekAt*(s: InputStream, pos: int): byte {.inline.} =
# TODO implement page flipping
let peekHead = offset(s.span.startAddr, pos)
fsAssert cast[uint](peekHead) < cast[uint](s.span.endAddr)
Expand All @@ -720,13 +720,13 @@ when fsAsyncSupport:
template peekAt*(s: AsyncInputStream, pos: int): byte =
peekAt InputStream(s), pos

proc advance*(s: InputStream) =
func advance*(s: InputStream) =
if hasRunway(s.span):
bumpPointer s.span
else:
getNewSpan s

proc advance*(s: InputStream, n: Natural) =
func advance*(s: InputStream, n: Natural) =
# TODO This is silly, implement it properly
for i in 0 ..< n:
advance s
Expand All @@ -738,7 +738,7 @@ when fsAsyncSupport:
template advance*(s: AsyncInputStream, n: Natural) =
advance InputStream(s), n

proc drainBuffersInto*(s: InputStream, dstAddr: ptr byte, dstLen: Natural): Natural =
func drainBuffersInto*(s: InputStream, dstAddr: ptr byte, dstLen: Natural): Natural =
var
dst = dstAddr
remainingBytes = dstLen
Expand Down Expand Up @@ -872,31 +872,12 @@ when fsAsyncSupport:
let (dstAddr, dstLen) = openArrayToPair(dst)
readIntoExImpl(s, dstAddr, dstLen, fsAwait, readAsync) == dstLen

template useHeapMem(_: Natural) =
var buffer: seq[byte]

when (NimMajor, NimMinor) > (1, 6):
template allocMem(n: Natural): ptr byte {.redefine.} =
buffer.setLen(n)
addr buffer[0]
else:
template allocMem(n: Natural): ptr byte =
buffer.setLen(n)
addr buffer[0]

template useStackMem(n: static Natural) =
var buffer: array[n + 1, byte]

when (NimMajor, NimMinor) > (1, 6):
template allocMem(_: Natural): ptr byte {.redefine.} =
addr buffer[0]
else:
template allocMem(_: Natural): ptr byte =
addr buffer[0]
type MemAllocType {.pure.} = enum
StackMem, HeapMem

template readNImpl(sp: InputStream,
np: Natural,
createAllocMemOp: untyped): openArray[byte] =
memAllocType: static MemAllocType): openArray[byte] =
let
s = sp
n = np
Expand All @@ -909,19 +890,20 @@ template readNImpl(sp: InputStream,
# an `openArray` from the existing span.
var startAddr: ptr byte

# This defines the `allocMem` operation used below.
# See `useHeapMem` and `useStackMem` for the possible definitions.
#
# If the "var buffer" from `useHeapMem` or `useStackMem` is in the
# `block`, the ARC and ORC memory managers free it, when the block
# scope ends. The default approach, without -d:useMalloc, tends to
# obscure the resulting use-after-free which only means it becomes
# slightly rarer and less predictable.
createAllocMemOp(np)
# If the "var buffer" is in the `block`, the ARC and ORC memory managers free
# it when the block scope ends.
when memAllocType == MemAllocType.StackMem:
var buffer: array[np + 1, byte]
elif memAllocType == MemAllocType.HeapMem:
var buffer: seq[byte]
else:
static: doAssert false

block:
if n > runway:
startAddr = allocMem(n)
when memAllocType == MemAllocType.HeapMem:
buffer.setLen(n)
startAddr = addr buffer[0]
let drained {.used.} = drainBuffersInto(s, startAddr, n)
fsAssert drained == n
else:
Expand All @@ -933,18 +915,18 @@ template readNImpl(sp: InputStream,
template read*(sp: InputStream, np: static Natural): openArray[byte] =
const n = np
when n < maxStackUsage:
readNImpl(sp, n, useStackMem)
readNImpl(sp, n, MemAllocType.StackMem)
else:
readNImpl(sp, n, useHeapMem)
readNImpl(sp, n, MemAllocType.HeapMem)

template read*(s: InputStream, n: Natural): openArray[byte] =
readNImpl(s, n, useHeapMem)
readNImpl(s, n, MemAllocType.HeapMem)

when fsAsyncSupport:
template read*(s: AsyncInputStream, n: Natural): openArray[byte] =
read InputStream(s), n

proc lookAheadMatch*(s: InputStream, data: openArray[byte]): bool =
func lookAheadMatch*(s: InputStream, data: openArray[byte]): bool =
for i in 0 ..< data.len:
if s.peekAt(i) != data[i]:
return false
Expand All @@ -967,7 +949,7 @@ when fsAsyncSupport:
else:
none byte

proc pos*(s: InputStream): int {.inline.} =
func pos*(s: InputStream): int {.inline.} =
s.spanEndPos - s.span.len

when fsAsyncSupport:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_inputs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,25 @@ suite "input stream":
fileContents.add input.read.char
break

elif r < 60:
# Test the ability to call readable() and read() multiple times from
# the same scope.
let readSize = 6 + rand(10)

if input.readable(readSize):
fileContents.add input.read(readSize).str
else:
while input.readable:
fileContents.add input.read.char
break

if input.readable(readSize):
fileContents.add input.read(readSize).str
else:
while input.readable:
fileContents.add input.read.char
break

else:
if input.readable:
fileContents.add input.read.char
Expand Down

0 comments on commit 720fc5e

Please sign in to comment.