Skip to content

Commit

Permalink
cs: fix ptr-ref on result of ptr-add
Browse files Browse the repository at this point in the history
Fix `ptr-ref` to properly handle the offset in a pointer for
specialized references, like extracting ` _double`.

Thanks to Laurent Orseau for the bug report and Jens Axel Søgaard for
intial debugging.
  • Loading branch information
mflatt committed May 10, 2020
1 parent e301d30 commit b0e6519
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 118 deletions.
223 changes: 117 additions & 106 deletions pkgs/racket-test-core/tests/racket/foreign-test.rktl
Original file line number Diff line number Diff line change
Expand Up @@ -1163,118 +1163,129 @@
;; ----------------------------------------
;; Test JIT inlining

(define bstr (cast (make-bytes 64) _pointer _pointer))

(for/fold ([v 1.0]) ([i (in-range 100)])
(ptr-set! bstr _float v)
(ptr-set! bstr _float 1 (+ v 0.5))
(ptr-set! bstr _float 'abs 8 (+ v 0.25))
(unless (= v (ptr-ref bstr _float))
(error 'float "failed"))
(unless (= (+ v 0.5) (ptr-ref bstr _float 'abs 4))
(error 'float "failed(2) ~s ~s" (+ v 0.5) (ptr-ref bstr _float 'abs 4)))
(unless (= (+ v 0.25) (ptr-ref bstr _float 2))
(error 'float "failed(3)"))
(+ 1.0 v))

(for/fold ([v 1.0]) ([i (in-range 100)])
(ptr-set! bstr _double v)
(ptr-set! bstr _double 1 (+ v 0.5))
(ptr-set! bstr _double 'abs 16 (+ v 0.25))
(unless (= v (ptr-ref bstr _double))
(error 'double "failed"))
(unless (= (+ v 0.5) (ptr-ref bstr _double 'abs 8))
(error 'double "failed(2)"))
(unless (= (+ v 0.25) (ptr-ref bstr _double 2))
(error 'double "failed(3)"))
(+ 1.0 v))

(for ([i (in-range 256)])
(ptr-set! bstr _uint8 i)
(ptr-set! bstr _uint8 1 (- 255 i))
(unless (= i (ptr-ref bstr _uint8))
(error 'uint8 "fail ~s vs. ~s" i (ptr-ref bstr _uint8)))
(unless (= (- 255 i) (ptr-ref bstr _uint8 'abs 1))
(error 'uint8 "fail(2) ~s vs. ~s" (- 255 i) (ptr-ref bstr _uint8 'abs 1))))

(for ([i (in-range -128 128)])
(ptr-set! bstr _int8 i)
(unless (= i (ptr-ref bstr _int8))
(error 'int8 "fail ~s vs. ~s" i (ptr-ref bstr _int8))))

(for ([i (in-range (expt 2 16))])
(ptr-set! bstr _uint16 i)
(ptr-set! bstr _uint16 3 (- (sub1 (expt 2 16)) i))
(unless (= i (ptr-ref bstr _uint16))
(error 'uint16 "fail ~s vs. ~s" i (ptr-ref bstr _uint16)))
(unless (= (- (sub1 (expt 2 16)) i) (ptr-ref bstr _uint16 'abs 6))
(error 'uint16 "fail(2) ~s vs. ~s" (- (sub1 (expt 2 16)) i) (ptr-ref bstr _uint16 'abs 6))))

(for ([j (in-range 100)])
(for ([i (in-range (- (expt 2 15)) (sub1 (expt 2 15)))])
(ptr-set! bstr _int16 i)
(unless (= i (ptr-ref bstr _int16))
(error 'int16 "fail ~s vs. ~s" i (ptr-ref bstr _int16)))))
(define (test-ptr-jit-inline bstr)
(for/fold ([v 1.0]) ([i (in-range 100)])
(ptr-set! bstr _float v)
(ptr-set! bstr _float 1 (+ v 0.5))
(ptr-set! bstr _float 'abs 8 (+ v 0.25))
(unless (= v (ptr-ref bstr _float))
(error 'float "failed"))
(unless (= (+ v 0.5) (ptr-ref bstr _float 'abs 4))
(error 'float "failed(2) ~s ~s" (+ v 0.5) (ptr-ref bstr _float 'abs 4)))
(unless (= (+ v 0.25) (ptr-ref bstr _float 2))
(error 'float "failed(3)"))
(+ 1.0 v))

(for/fold ([v 1.0]) ([i (in-range 100)])
(ptr-set! bstr _double v)
(ptr-set! bstr _double 1 (+ v 0.5))
(ptr-set! bstr _double 'abs 16 (+ v 0.25))
(ptr-set! (ptr-add bstr 24) _double (+ v 0.125))
(unless (= v (ptr-ref bstr _double))
(error 'double "failed"))
(unless (= (+ v 0.5) (ptr-ref bstr _double 'abs 8))
(error 'double "failed(2)"))
(unless (= (+ v 0.25) (ptr-ref bstr _double 2))
(error 'double "failed(3)"))
(unless (= (+ v 0.5) (ptr-ref (ptr-add bstr 8) _double))
(error 'double "failed(4)"))
(unless (= (+ v 0.125) (ptr-ref (ptr-add bstr 24) _double))
(error 'double "failed(5)"))
(+ 1.0 v))

(for ([i (in-range 256)])
(ptr-set! bstr _uint8 i)
(ptr-set! bstr _uint8 1 (- 255 i))
(unless (= i (ptr-ref bstr _uint8))
(error 'uint8 "fail ~s vs. ~s" i (ptr-ref bstr _uint8)))
(unless (= (- 255 i) (ptr-ref bstr _uint8 'abs 1))
(error 'uint8 "fail(2) ~s vs. ~s" (- 255 i) (ptr-ref bstr _uint8 'abs 1))))

(for ([i (in-range -128 128)])
(ptr-set! bstr _int8 i)
(unless (= i (ptr-ref bstr _int8))
(error 'int8 "fail ~s vs. ~s" i (ptr-ref bstr _int8))))

(for ([i (in-range (expt 2 16))])
(ptr-set! bstr _uint16 i)
(ptr-set! bstr _uint16 3 (- (sub1 (expt 2 16)) i))
(unless (= i (ptr-ref bstr _uint16))
(error 'uint16 "fail ~s vs. ~s" i (ptr-ref bstr _uint16)))
(unless (= (- (sub1 (expt 2 16)) i) (ptr-ref bstr _uint16 'abs 6))
(error 'uint16 "fail(2) ~s vs. ~s" (- (sub1 (expt 2 16)) i) (ptr-ref bstr _uint16 'abs 6))))

(for ([j (in-range 100)])
(for ([i (in-range (- (expt 2 15)) (sub1 (expt 2 15)))])
(ptr-set! bstr _int16 i)
(unless (= i (ptr-ref bstr _int16))
(error 'int16 "fail ~s vs. ~s" i (ptr-ref bstr _int16)))))

(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _uint32 i)
(ptr-set! bstr _uint32 1 (- hi (- i lo) 1))
(unless (= i (ptr-ref bstr _uint32))
(error 'uint32 "fail ~s vs. ~s" i (ptr-ref bstr _uint32)))
(unless (= (- hi (- i lo) 1) (ptr-ref bstr _uint32 'abs 4))
(error 'uint32 "fail ~s vs. ~s" (- hi (- i lo) 1) (ptr-ref bstr _uint32)))))
(go 0 256)
(go (- (expt 2 31) 256) (+ (expt 2 31) 256))
(go (- (expt 2 32) 256) (expt 2 32)))
(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _uint32 i)
(ptr-set! bstr _uint32 1 (- hi (- i lo) 1))
(unless (= i (ptr-ref bstr _uint32))
(error 'uint32 "fail ~s vs. ~s" i (ptr-ref bstr _uint32)))
(unless (= (- hi (- i lo) 1) (ptr-ref bstr _uint32 'abs 4))
(error 'uint32 "fail ~s vs. ~s" (- hi (- i lo) 1) (ptr-ref bstr _uint32)))))
(go 0 256)
(go (- (expt 2 31) 256) (+ (expt 2 31) 256))
(go (- (expt 2 32) 256) (expt 2 32)))

(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _int32 i)
(unless (= i (ptr-ref bstr _int32))
(error 'int32 "fail ~s vs. ~s" i (ptr-ref bstr _int32)))))
(go -256 256)
(go (- (expt 2 31) 256) (sub1 (expt 2 31)))
(go (- (expt 2 31)) (- 256 (expt 2 31))))
(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _int32 i)
(unless (= i (ptr-ref bstr _int32))
(error 'int32 "fail ~s vs. ~s" i (ptr-ref bstr _int32)))))
(go -256 256)
(go (- (expt 2 31) 256) (sub1 (expt 2 31)))
(go (- (expt 2 31)) (- 256 (expt 2 31))))

(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _uint64 i)
(ptr-set! bstr _uint64 1 (- hi (- i lo) 1))
(unless (= i (ptr-ref bstr _uint64))
(error 'uint64 "fail ~s vs. ~s" i (ptr-ref bstr _uint64)))
(unless (= (- hi (- i lo) 1) (ptr-ref bstr _uint64 'abs 8))
(error 'uint32 "fail ~s vs. ~s" (- hi (- i lo) 1) (ptr-ref bstr _uint64)))))
(go 0 256)
(go (- (expt 2 63) 256) (+ (expt 2 63) 256))
(go (- (expt 2 64) 256) (expt 2 64)))
(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _uint64 i)
(ptr-set! bstr _uint64 1 (- hi (- i lo) 1))
(unless (= i (ptr-ref bstr _uint64))
(error 'uint64 "fail ~s vs. ~s" i (ptr-ref bstr _uint64)))
(unless (= (- hi (- i lo) 1) (ptr-ref bstr _uint64 'abs 8))
(error 'uint32 "fail ~s vs. ~s" (- hi (- i lo) 1) (ptr-ref bstr _uint64)))))
(go 0 256)
(go (- (expt 2 63) 256) (+ (expt 2 63) 256))
(go (- (expt 2 64) 256) (expt 2 64)))

(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _int64 i)
(unless (= i (ptr-ref bstr _int64))
(error 'int64 "fail ~s vs. ~s" i (ptr-ref bstr _int64)))))
(go -256 256)
(go (- (expt 2 63) 256) (sub1 (expt 2 63)))
(go (- (expt 2 63)) (- 256 (expt 2 63))))
(let ()
(define (go lo hi)
(for ([i (in-range lo hi)])
(ptr-set! bstr _int64 i)
(unless (= i (ptr-ref bstr _int64))
(error 'int64 "fail ~s vs. ~s" i (ptr-ref bstr _int64)))))
(go -256 256)
(go (- (expt 2 63) 256) (sub1 (expt 2 63)))
(go (- (expt 2 63)) (- 256 (expt 2 63))))

(let ()
(define p (cast bstr _pointer _pointer))
(for ([i (in-range 100)])
(ptr-set! bstr _pointer (ptr-add p i))
(ptr-set! bstr _pointer 2 p)
(unless (ptr-equal? p (ptr-add (ptr-ref bstr _pointer) (- i)))
(error 'pointer "fail ~s vs. ~s"
(cast p _pointer _intptr)
(cast (ptr-ref bstr _pointer) _pointer _intptr)))
(unless (ptr-equal? p (ptr-ref bstr _pointer 'abs (* 2 (ctype-sizeof _pointer))))
(error 'pointer "fail ~s vs. ~s"
(cast p _pointer _intptr)
(cast (ptr-ref bstr _pointer 'abs (ctype-sizeof _pointer)) _pointer _intptr))))))

(let ()
(define p (cast bstr _pointer _pointer))
(for ([i (in-range 100)])
(ptr-set! bstr _pointer (ptr-add p i))
(ptr-set! bstr _pointer 2 p)
(unless (ptr-equal? p (ptr-add (ptr-ref bstr _pointer) (- i)))
(error 'pointer "fail ~s vs. ~s"
(cast p _pointer _intptr)
(cast (ptr-ref bstr _pointer) _pointer _intptr)))
(unless (ptr-equal? p (ptr-ref bstr _pointer 'abs (* 2 (ctype-sizeof _pointer))))
(error 'pointer "fail ~s vs. ~s"
(cast p _pointer _intptr)
(cast (ptr-ref bstr _pointer 'abs (ctype-sizeof _pointer)) _pointer _intptr)))))
(test-ptr-jit-inline (make-bytes 64))
(test-ptr-jit-inline (cast (make-bytes 64) _pointer _pointer))
(let ([p (malloc 'raw 64)])
(test-ptr-jit-inline p)
(free p)))

;; ----------------------------------------

Expand Down
24 changes: 12 additions & 12 deletions racket/src/cs/rumble/foreign.ss
Original file line number Diff line number Diff line change
Expand Up @@ -969,12 +969,12 @@
[(and simple-p
(fixnum? offset)
(or (not abs?) (fx= 0 (fxand offset (fx- (fxsll 1 type-bits) 1)))))
(if (bytevector? simple-p)
(bytes-ref simple-p (if abs? offset (fxsll offset type-bits)))
(let ([offset (let ([offset (if abs? offset (fxsll offset type-bits))])
(if (cpointer+offset? p)
(+ offset (cpointer+offset-offset p))
offset))])
(let ([offset (let ([offset (if abs? offset (fxsll offset type-bits))])
(if (cpointer+offset? p)
(+ offset (cpointer+offset-offset p))
offset))])
(if (bytevector? simple-p)
(bytes-ref simple-p offset)
(foreign-ref 'foreign-type simple-p offset)))]
[else
(if abs?
Expand All @@ -993,12 +993,12 @@
(fixnum? offset)
(or (not abs?) (fx= 0 (fxand offset (fx- (fxsll 1 type-bits) 1))))
(ok-v? v))
(if (bytevector? simple-p)
(bytes-set simple-p (if abs? offset (fxsll offset type-bits)) v)
(let ([offset (let ([offset (if abs? offset (fxsll offset type-bits))])
(if (cpointer+offset? p)
(+ offset (cpointer+offset-offset p))
offset))])
(let ([offset (let ([offset (if abs? offset (fxsll offset type-bits))])
(if (cpointer+offset? p)
(+ offset (cpointer+offset-offset p))
offset))])
(if (bytevector? simple-p)
(bytes-set simple-p offset v)
(foreign-set! 'foreign-type simple-p offset v)))]
[else
(if abs?
Expand Down

0 comments on commit b0e6519

Please sign in to comment.