Skip to content

Commit

Permalink
Refs #130: Prevent reading the ipiv array beyond its bounds in ?laswp. Use …
Browse files: browse the repository at this point in the history
…laswp instead of laswp_oncopy in getrf.
  • Loading branch information
xianyi committed Aug 9, 2012
1 parent e8306f6 commit 1b056c5
Show file tree
Hide file tree
Showing 13 changed files with 2,693 additions and 674 deletions.
5 changes: 3 additions & 2 deletions lapack/getrf/getrf_parallel.c
Expand Up @@ -118,7 +118,7 @@ static void inner_basic_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *ra
min_jj = js + min_j - jjs;
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;

if (GEMM_UNROLL_N <= 8) {
if (0 && GEMM_UNROLL_N <= 8) {

LASWP_NCOPY(min_jj, off + 1, off + k,
c + (- off + jjs * lda) * COMPSIZE, lda,
Expand Down Expand Up @@ -245,7 +245,8 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
min_jj = MIN(n_to, xxx + div_n) - jjs;
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;

if (GEMM_UNROLL_N <= 8) {
if (0 && GEMM_UNROLL_N <= 8) {
printf("helllo\n");

LASWP_NCOPY(min_jj, off + 1, off + k,
b + (- off + jjs * lda) * COMPSIZE, lda,
Expand Down
11 changes: 11 additions & 0 deletions lapack/getrf/getrf_parallel_omp.c
Expand Up @@ -77,10 +77,21 @@ static void inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
min_jj = js + min_j - jjs;
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;

#if 0
LASWP_NCOPY(min_jj, off + 1, off + k,
c + (- off + jjs * lda) * COMPSIZE, lda,
ipiv, sb + k * (jjs - js) * COMPSIZE);

#else
LASWP_PLUS(min_jj, off + 1, off + k, ZERO,
#ifdef COMPLEX
ZERO,
#endif
c + (- off + jjs * lda) * COMPSIZE, lda, NULL, 0, ipiv, 1);

GEMM_ONCOPY (k, min_jj, c + jjs * lda * COMPSIZE, lda, sb + (jjs - js) * k * COMPSIZE);
#endif

for (is = 0; is < k; is += GEMM_P) {
min_i = k - is;
if (min_i > GEMM_P) min_i = GEMM_P;
Expand Down
2 changes: 1 addition & 1 deletion lapack/getrf/getrf_single.c
Expand Up @@ -113,7 +113,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
min_jj = js + jmin - jjs;
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;

#if 0
#if 1
LASWP_PLUS(min_jj, j + offset + 1, j + jb + offset, ZERO,
#ifdef COMPLEX
ZERO,
Expand Down
97 changes: 88 additions & 9 deletions lapack/laswp/generic/laswp_k_1.c
Expand Up @@ -48,7 +48,7 @@
int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG lda,
FLOAT *dummy2, BLASLONG dumy3, blasint *ipiv, BLASLONG incx){

BLASLONG i, j, ip1, ip2;
BLASLONG i, j, ip1, ip2, rows;
blasint *piv;
FLOAT *a1;
FLOAT *b1, *b2;
Expand All @@ -58,13 +58,34 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
k1 --;

#ifndef MINUS
ipiv += k1
;
ipiv += k1;
#else
ipiv -= (k2 - 1) * incx;
#endif

if (n <= 0) return 0;

rows = k2-k1;
if (rows <=0) return 0;
if (rows == 1) {
//Only have 1 row
ip1 = *ipiv;
a1 = a + k1 + 1;
b1 = a + ip1;

if(a1 == b1) return 0;

for(j=0; j<n; j++){
A1 = *a1;
B1 = *b1;
*a1 = B1;
*b1 = A1;

a1 += lda;
b1 += lda;
}
return 0;
}

j = n;
if (j > 0) {
Expand All @@ -85,10 +106,11 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
b1 = a + ip1;
b2 = a + ip2;

i = ((k2 - k1) >> 1);

if (i > 0) {
do {
i = (rows >> 1);

i--;
//Main Loop
while (i > 0) {
#ifdef OPTERON
#ifndef MINUS
asm volatile("prefetchw 2 * 128(%0)\n" : : "r"(a1));
Expand Down Expand Up @@ -172,12 +194,69 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
a1 -= 2;
#endif
i --;
} while (i > 0);
}

//Loop Ending
A1 = *a1;
A2 = *a2;
B1 = *b1;
B2 = *b2;
if (b1 == a1) {
if (b2 == a1) {
*a1 = A2;
*a2 = A1;
} else
if (b2 != a2) {
*a2 = B2;
*b2 = A2;
}
} else
if (b1 == a2) {
if (b2 != a1) {
if (b2 == a2) {
*a1 = A2;
*a2 = A1;
} else {
*a1 = A2;
*a2 = B2;
*b2 = A1;
}
}
} else {
if (b2 == a1) {
*a1 = A2;
*a2 = B1;
*b1 = A1;
} else
if (b2 == a2) {
*a1 = B1;
*b1 = A1;
} else
if (b2 == b1) {
*a1 = B1;
*a2 = A1;
*b1 = A2;
} else {
*a1 = B1;
*a2 = B2;
*b1 = A1;
*b2 = A2;
}
}

#ifndef MINUS
a1 += 2;
#else
a1 -= 2;
#endif

i = ((k2 - k1) & 1);
//Remain
i = (rows & 1);

if (i > 0) {
ip1 = *piv;
b1 = a + ip1;

A1 = *a1;
B1 = *b1;
*a1 = B1;
Expand Down

0 comments on commit 1b056c5

Please sign in to comment.