Permalink
Browse files

Generate more transpose kernels

  • Loading branch information...
1 parent 4d90bb0 commit 4a5ee157b5db8006e7a7bdbed47e23ad85bf184e @pkhuong committed Dec 25, 2011
Showing with 118 additions and 107 deletions.
  1. +118 −107 generator-macros.lisp
View
@@ -212,104 +212,107 @@
tmp2 tmp1 tmp-offset))))
(tran dst dst-offset tmp1 tmp-offset size2 size1)))))
-(defun build-transpose-funs (outer-name inner-name twiddle half-size base-case)
- `((,inner-name (dst dst-offset dst-stride src src-offset src-stride size
- ,@(and twiddle '(twiddle twiddle-offset)))
- (declare (type complex-sample-array dst src ,@(and twiddle '(twiddle)))
- (type index dst-offset src-offset ,@(and twiddle '(twiddle-offset)))
- (type half-index dst-stride src-stride size))
- (cond ((= size ,half-size)
- ;; other cases are picked off by specialised FFTs
- (,base-case
- dst dst-offset dst-stride
- src src-offset src-stride
- ,@(and twiddle '(twiddle twiddle-offset))))
- (t
- (let* ((size/2 (truncate size 2))
- (long-dst-stride (* size/2 dst-stride))
- (long-src-stride (* size/2 src-stride)))
- (declare (type index long-dst-stride long-src-stride))
- #-xecto-parallel
- (unrolled-for ((i 4))
- (let* ((short (- (logand i 1)))
- (long (- (logand i 2)))
- (dst-delta (+ (logand size/2 short)
- (logand long-dst-stride long)))
- (src-delta (+ (logand long-src-stride short)
- (logand size/2 long))))
- (declare (type index dst-delta src-delta))
- (,inner-name dst (+ dst-offset dst-delta)
- dst-stride
- src (+ src-offset src-delta)
- src-stride
- size/2
- ,@(and twiddle '(twiddle (+ twiddle-offset dst-delta))))))
- #+xecto-parallel
- (parallel:let ((aa (,inner-name dst dst-offset
- dst-stride
- src src-offset
- src-stride
- size/2
- ,@(and twiddle '(twiddle twiddle-offset))))
- (ab (,inner-name dst (+ dst-offset size/2)
- dst-stride
- src (+ src-offset long-src-stride)
- src-stride
- size/2
- ,@(and twiddle '(twiddle (+ twiddle-offset size/2)))))
- (ba (,inner-name dst (+ dst-offset long-dst-stride)
- dst-stride
- src (+ src-offset size/2)
- src-stride
- size/2
- ,@(and twiddle '(twiddle (+ twiddle-offset long-dst-stride)))))
- (bb (,inner-name dst (+ dst-offset long-dst-stride size/2)
- dst-stride
- src (+ src-offset long-src-stride size/2)
- src-stride
- size/2
- ,@(and twiddle '(twiddle (+ twiddle-offset long-dst-stride size/2)))))
- (:parallel (>= (* size/2 size/2) 1024)))
- (declare (ignore aa ab ba bb)))))))
- (,outer-name (dst dst-offset src src-offset size1 size2
- ,@(and twiddle '(twiddle &aux (twiddle-base (+ (* size1 size2)
- +factor-bias+)))))
- (declare (type complex-sample-array dst src ,@(and twiddle '(twiddle)))
- (type index dst-offset src-offset ,@(and twiddle '(twiddle-base)))
- (type half-index size1 size2))
- (cond ((= size1 size2)
- (,inner-name dst dst-offset size2 src src-offset size1 size1
- ,@(and twiddle '(twiddle twiddle-base))))
- ((< size1 size2)
- (let* ((size size1)
- (block (* size size)))
- (#-xecto-parallel let
- #+xecto-parallel parallel:let
- ((a (,inner-name dst dst-offset size2
- src src-offset size1
- size
- ,@(and twiddle '(twiddle twiddle-base))))
- (b (,inner-name dst (+ size dst-offset) size2
- src (+ block src-offset) size1
- size
- ,@(and twiddle '(twiddle (+ size twiddle-base)))))
- #+xecto-parallel (:parallel (>= (* size size) 1024)))
- (declare (ignore a b)))))
- (t
- (let* ((size size2)
- (block (* size size)))
- (#-xecto-parallel let
- #+xecto-parallel parallel:let
- ((a (,inner-name dst dst-offset size2
- src src-offset size1
- size
- ,@(and twiddle '(twiddle twiddle-base))))
- (b (,inner-name dst (+ block dst-offset) size2
- src (+ size src-offset) size1
- size
- ,@(and twiddle '(twiddle (+ block twiddle-base)))))
- #+xecto-parallel (:parallel (>= (* size size) 1024)))
- (declare (ignore a b)))))))))
+(defun build-transpose-funs (outer-name inner-name twiddle
+ size-name-alist)
+ (let* ((size-name-alist (sort (copy-list size-name-alist) #'> :key #'car))
+ (max-size (car (first size-name-alist))))
+ `((,inner-name (dst dst-offset dst-stride src src-offset src-stride size
+ ,@(and twiddle '(twiddle twiddle-offset)))
+ (declare (type complex-sample-array dst src ,@(and twiddle '(twiddle)))
+ (type index dst-offset src-offset ,@(and twiddle '(twiddle-offset)))
+ (type half-index dst-stride src-stride size))
+ (cond ((> size ,max-size)
+ (let* ((size/2 (truncate size 2))
+ (long-dst-stride (* size/2 dst-stride))
+ (long-src-stride (* size/2 src-stride)))
+ (declare (type index long-dst-stride long-src-stride))
+ #-xecto-parallel
+ (unrolled-for ((i 4))
+ (let* ((short (- (logand i 1)))
+ (long (- (logand i 2)))
+ (dst-delta (+ (logand size/2 short)
+ (logand long-dst-stride long)))
+ (src-delta (+ (logand long-src-stride short)
+ (logand size/2 long))))
+ (declare (type index dst-delta src-delta))
+ (,inner-name dst (+ dst-offset dst-delta)
+ dst-stride
+ src (+ src-offset src-delta)
+ src-stride
+ size/2
+ ,@(and twiddle '(twiddle (+ twiddle-offset dst-delta))))))
+ #+xecto-parallel
+ (parallel:let ((aa (,inner-name dst dst-offset
+ dst-stride
+ src src-offset
+ src-stride
+ size/2
+ ,@(and twiddle '(twiddle twiddle-offset))))
+ (ab (,inner-name dst (+ dst-offset size/2)
+ dst-stride
+ src (+ src-offset long-src-stride)
+ src-stride
+ size/2
+ ,@(and twiddle '(twiddle (+ twiddle-offset size/2)))))
+ (ba (,inner-name dst (+ dst-offset long-dst-stride)
+ dst-stride
+ src (+ src-offset size/2)
+ src-stride
+ size/2
+ ,@(and twiddle '(twiddle (+ twiddle-offset long-dst-stride)))))
+ (bb (,inner-name dst (+ dst-offset long-dst-stride size/2)
+ dst-stride
+ src (+ src-offset long-src-stride size/2)
+ src-stride
+ size/2
+ ,@(and twiddle '(twiddle (+ twiddle-offset long-dst-stride size/2)))))
+ (:parallel (>= (* size/2 size/2) 1024)))
+ (declare (ignore aa ab ba bb)))))
+ ,@(loop for (size . name) in size-name-alist
+ collect `((= size ,size)
+ (,name
+ dst dst-offset dst-stride
+ src src-offset src-stride
+ ,@(and twiddle '(twiddle twiddle-offset)))))))
+ (,outer-name (dst dst-offset src src-offset size1 size2
+ ,@(and twiddle '(twiddle &aux (twiddle-base (+ (* size1 size2)
+ +factor-bias+)))))
+ (declare (type complex-sample-array dst src ,@(and twiddle '(twiddle)))
+ (type index dst-offset src-offset ,@(and twiddle '(twiddle-base)))
+ (type half-index size1 size2))
+ (cond ((= size1 size2)
+ (,inner-name dst dst-offset size2 src src-offset size1 size1
+ ,@(and twiddle '(twiddle twiddle-base))))
+ ((< size1 size2)
+ (let* ((size size1)
+ (block (* size size)))
+ (#-xecto-parallel let
+ #+xecto-parallel parallel:let
+ ((a (,inner-name dst dst-offset size2
+ src src-offset size1
+ size
+ ,@(and twiddle '(twiddle twiddle-base))))
+ (b (,inner-name dst (+ size dst-offset) size2
+ src (+ block src-offset) size1
+ size
+ ,@(and twiddle '(twiddle (+ size twiddle-base)))))
+ #+xecto-parallel (:parallel (>= (* size size) 1024)))
+ (declare (ignore a b)))))
+ (t
+ (let* ((size size2)
+ (block (* size size)))
+ (#-xecto-parallel let
+ #+xecto-parallel parallel:let
+ ((a (,inner-name dst dst-offset size2
+ src src-offset size1
+ size
+ ,@(and twiddle '(twiddle twiddle-base))))
+ (b (,inner-name dst (+ block dst-offset) size2
+ src (+ size src-offset) size1
+ size
+ ,@(and twiddle '(twiddle (+ block twiddle-base)))))
+ #+xecto-parallel (:parallel (>= (* size size) 1024)))
+ (declare (ignore a b))))))))))
(defmacro build-fft-routine-vectors (lb-specialized-size lb-max-size)
(when (oddp lb-specialized-size)
@@ -324,9 +327,16 @@
collect (symbolicate 'fft (ash 1 i) :fwd :scale)))
(bwd/s-funs (loop for i from 1 upto lb-specialized-size
collect (symbolicate 'fft (ash 1 i) :bwd :scale)))
- (half-size (ash 1 (truncate lb-specialized-size 2)))
- (%tran-base-case (symbolicate '%transpose half-size))
- (%tran/twiddle-base-case (symbolicate '%transpose half-size :twiddle)))
+ (%tran-base-cases
+ (loop for i from (truncate lb-specialized-size 2)
+ upto lb-specialized-size
+ for size = (ash 1 i)
+ collect (cons size (symbolicate '%transpose size))))
+ (%tran/twiddle-base-cases
+ (loop for i from (truncate lb-specialized-size 2)
+ upto lb-specialized-size
+ for size = (ash 1 i)
+ collect (cons size (symbolicate '%transpose size :twiddle)))))
`(let ((fwd-fft (make-array ,count))
(bwd-fft (make-array ,count))
(fwd-fft/scale (make-array ,count))
@@ -336,11 +346,12 @@
(type (simple-array (fft-function :scale t) (,count))
fwd-fft/scale bwd-fft/scale)
(optimize speed (safety 0)))
- (with-fft-kernels (,lb-specialized-size ,(truncate lb-specialized-size 2)
+ (with-fft-kernels (,lb-specialized-size ,lb-specialized-size
:fwd t :bwd t :scale t)
(declare (inline ,@fwd-funs ,@bwd-funs
,@fwd/s-funs ,@bwd/s-funs
- ,%tran-base-case ,%tran/twiddle-base-case))
+ ,@(mapcar #'cdr %tran-base-cases)
+ ,@(mapcar #'cdr %tran/twiddle-base-cases)))
(setf (aref fwd-fft 0) #'one-point-fft)
(setf (aref bwd-fft 0) #'one-point-fft)
(setf (aref fwd-fft/scale 0) #'one-point-fft/scale)
@@ -353,11 +364,11 @@
,(build-fft-fun 'fft/bwd-scale 'fft/bwd 'bwd-fft lb-specialized-size
t 'bwd-fft/scale)
,@(build-transpose-funs 'tran '%tran
- nil half-size
- %tran-base-case)
+ nil
+ %tran-base-cases)
,@(build-transpose-funs 'tran/twiddle '%tran/twiddle-base-case
- t half-size
- %tran/twiddle-base-case))
+ t
+ %tran/twiddle-base-cases))
(declare (inline tran tran/twiddle))
,@(loop for vec in '(fwd-fft fwd-fft/scale bwd-fft bwd-fft/scale)
for fun in '(fft/fwd fft/fwd-scale fft/bwd fft/bwd-scale)

0 comments on commit 4a5ee15

Please sign in to comment.