-
Notifications
You must be signed in to change notification settings - Fork 0
/
fmt_table.cu
184 lines (156 loc) · 5.05 KB
/
fmt_table.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
// Shorthand for the 64-bit unsigned type used for intermediate products.
// Note: this is a preprocessor macro, not a typedef.
#define ulong unsigned long long
// Shorthand for the 32-bit unsigned type used for array elements and indices.
// Note: this is a preprocessor macro, not a typedef.
#define uint unsigned int
// Modular multiplication: returns (a * b) mod MODP.
// The operands are widened to 64 bits before multiplying so the full
// 32x32-bit product is available for the reduction.
// MODP is not defined in this part of the file; it must be provided elsewhere.
__device__ uint ABModC(uint a,uint b){
    const ulong product=(ulong)a*(ulong)b;
    return (uint)(product%MODP);
}
// Modular exponentiation by repeated squaring: returns a^b mod MODP.
// For b == 0 the result is 1 (no reduction is applied in that case).
__device__ uint ModExp(uint a,uint b){
    ulong result=1ULL;
    ulong base=a;
    // Consume the exponent one bit at a time, lowest bit first.
    for(;b!=0;b>>=1){
        if ((b&1u)!=0) result=result*base%MODP;
        base=base*base%MODP;
    }
    return (uint)result;
}
// After the inverse transform the result has to be divided by N (as with an FFT).
// Returns a / arrayLength (mod MODP) without computing a modular inverse.
//
// Write a = q*arrayLength + r. If r == 0 the answer is simply q. Otherwise
// (arrayLength - r) * (MODP / arrayLength) + 1 is added to q.
// NOTE(review): this is only a valid modular division when
// MODP == (MODP / arrayLength) * arrayLength + 1, i.e. MODP ≡ 1 (mod arrayLength);
// MODP is defined outside this section, so confirm against its definition.
__device__ uint DivN_f(uint a,uint arrayLength)
{
    const uint quotient =a/arrayLength;
    const uint remainder=a-quotient*arrayLength;
    if (remainder==0){
        // a is an exact multiple of arrayLength: plain integer division suffices.
        return quotient;
    }
    const uint stepsPerP=MODP/arrayLength;
    return quotient+((arrayLength-remainder)*stepsPerP+1);
}
// One butterfly pass of the forward transform, performed in place on arrayA.
//
// Each thread processes one pair of elements (t0, t1 = t0 + 2^loopCnt):
//   arrayA[t0] <- (x0 + x1) mod MODP
//   arrayA[t1] <- ((x0 - x1) mod MODP) * w mod MODP
// where w = mtable[t2 * (arrayLength >> (loopCnt+1))] and t2 is the thread's
// offset inside its group of 2^loopCnt pairs.
//
// Parameters:
//   arrayA      - data array, updated in place
//   loopCnt     - log2 of the distance between the two elements of a pair
//   omega       - not used by this kernel (the twiddle factor comes from mtable)
//   arrayLength - length used to scale the twiddle-table index
//   mtable      - twiddle table; presumably mtable[i] = omega^i mod MODP as
//                 produced by CreateTable (confirm with the host code)
//
// Launch: 1D grid, one thread per pair. There is no bounds check, so the
// launch must not create more threads than there are pairs.
// NOTE(review): the add/subtract steps assume inputs < MODP and that 2*MODP
// fits in 32 bits; MODP is defined elsewhere, so confirm.
__global__ void uFMT(uint *arrayA,uint loopCnt,uint omega,uint arrayLength ,uint *mtable) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    const uint span = 1u<<loopCnt;      // distance between the two elements of a pair
    const uint t2 = idx&(span-1);       // idx % span (span is a power of two)
    const uint t0 = idx*2-t2;
    const uint t1 = t0+span;
    const uint x0=arrayA[t0];
    const uint x1=arrayA[t1];
    // Twiddle-table index for this pair, wrapped into [0, arrayLength).
    // Original note: w0=ModExp(omega,t2*(arrayLength2/loopCnt_Pow2))
    uint ridx=t2*(arrayLength>>(loopCnt+1));
    if (ridx>=arrayLength)ridx-=arrayLength;
    const uint w0=mtable[ridx];
    // Difference and sum, each brought back into [0, MODP).
    uint r0=x0-x1+MODP;
    uint r1=x0+x1;
    if (r0>=MODP){r0-=MODP;}
    if (r1>=MODP){r1-=MODP;}
    arrayA[t1]=ABModC(r0,w0);
    arrayA[t0]=r1;
}
// One butterfly pass of the inverse transform, performed in place on arrayA.
//
// Each thread processes one pair of elements (t0, t1 = t0 + 2^loopCnt):
//   y           = x1 * w mod MODP
//   arrayA[t0] <- (x0 + y) mod MODP
//   arrayA[t1] <- (x0 - y) mod MODP
// where w = mtable[(arrayLength - t2 * (arrayLength >> (loopCnt+1))) mod arrayLength]
// and t2 is the thread's offset inside its group of 2^loopCnt pairs.
//
// Parameters:
//   arrayA      - data array, updated in place
//   loopCnt     - log2 of the distance between the two elements of a pair
//   omega       - not used by this kernel (the twiddle factor comes from mtable)
//   arrayLength - length used to scale and mirror the twiddle-table index
//   mtable      - twiddle table; presumably mtable[i] = omega^i mod MODP as
//                 produced by CreateTable (confirm with the host code)
//
// Launch: 1D grid, one thread per pair. There is no bounds check, so the
// launch must not create more threads than there are pairs.
// NOTE(review): the add/subtract steps assume inputs < MODP and that 2*MODP
// fits in 32 bits; MODP is defined elsewhere, so confirm.
__global__ void iFMT(uint *arrayA,uint loopCnt,uint omega,uint arrayLength ,uint *mtable) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    const uint span = 1u<<loopCnt;      // distance between the two elements of a pair
    const uint t2 = idx&(span-1);       // idx % span (span is a power of two)
    const uint t0 = idx*2-t2;
    const uint t1 = t0+span;
    const uint x0=arrayA[t0];
    const uint x1=arrayA[t1];
    // The table is read from the far end; t2 == 0 yields arrayLength, which
    // wraps to entry 0.
    // Original note: w0=ModExp(omega,arrayLength2*2-t2*(arrayLength2/loopCnt_Pow2))
    uint ridx=arrayLength-t2*(arrayLength>>(loopCnt+1));
    if (ridx>=arrayLength)ridx-=arrayLength;
    const uint w0=mtable[ridx];
    const uint y=ABModC(x1,w0);
    // Difference and sum, each brought back into [0, MODP).
    uint r0=x0-y+MODP;
    uint r1=x0+y;
    if (r0>=MODP){r0-=MODP;}
    if (r1>=MODP){r1-=MODP;}
    arrayA[t1]=r0;
    arrayA[t0]=r1;
}
// Element-wise modular multiplication: arrayB[i] <- arrayB[i] * arrayA[i] mod MODP.
//
// Launch: 1D grid, one thread per element. The kernel takes no length
// argument and performs no bounds check, so the launch must create exactly
// as many threads as there are elements.
__global__ void Mul_i_i(uint *arrayA,uint *arrayB ) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    arrayB[idx]=ABModC(arrayB[idx],arrayA[idx]);
}
// Divide-by-N step applied after the inverse transform. Division under a
// modulus needs special handling, which DivN_f provides.
// Replaces every arrayA[i] with DivN_f(arrayA[i], arrayLength).
//
// Launch: 1D grid, one thread per element. There is no bounds check, so the
// launch must not create more threads than there are elements.
__global__ void DivN(uint *arrayA,uint arrayLength ) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    arrayA[idx]=DivN_f(arrayA[idx],arrayLength);
}
// Pre-processing for the negacyclic computation.
// sqrt_omega raised to the power 2N is 1 (mod P).
//   a[0] *= ModExp(sqrt_omega, 0)
//   a[1] *= ModExp(sqrt_omega, 1)
//   a[2] *= ModExp(sqrt_omega, 2)
//   a[3] *= ModExp(sqrt_omega, 3)
//
// The weight for element idx is built as mtable[idx/2], multiplied by
// sqrt_omega once more when idx is odd. This matches sqrt_omega^idx provided
// mtable[i] holds (sqrt_omega^2)^i -- presumably the table from CreateTable;
// confirm with the host code.
//
// Side effects: arrayA[idx] is overwritten with its value reduced mod MODP;
// arrayB[idx] receives the weighted value. arrayLength is not used.
//
// Launch: 1D grid, one thread per element, no bounds check.
__global__ void PreNegFMT(uint *arrayA,uint *arrayB,uint sqrt_omega,uint *mtable,uint arrayLength) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    // Equivalent to w0=ModExp(sqrt_omega,idx), using the table for the even part.
    uint w0=mtable[idx/2];
    if ((idx&1u)!=0)
        w0=ABModC(sqrt_omega,w0);
    // Not strictly required, but the very first inputs A and B are not
    // guaranteed to be already reduced modulo MODP.
    const uint reduced=arrayA[idx]%MODP;
    arrayA[idx]=reduced;
    // Reuse the reduced value instead of reading arrayA[idx] back from memory.
    arrayB[idx]=ABModC(reduced,w0);
}
// Post-processing for the negacyclic computation.
// sqrt_omega raised to the power 2N is 1 (mod P).
//   a[0] *= ModExp(sqrt_omega, -0)
//   a[1] *= ModExp(sqrt_omega, -1)
//   a[2] *= ModExp(sqrt_omega, -2)
//   a[3] *= ModExp(sqrt_omega, -3)
//
// The weight for element idx is mtable[((2*arrayLength - idx) mod
// (2*arrayLength)) / 2], multiplied by sqrt_omega once more when idx is odd.
// This matches sqrt_omega^(2*arrayLength - idx) provided mtable[i] holds
// (sqrt_omega^2)^i -- presumably the table from CreateTable; confirm with the
// host code.
//
// arrayA is updated in place.
// Launch: 1D grid, one thread per element, no bounds check.
__global__ void PostNegFMT(uint *arrayA,uint sqrt_omega,uint *mtable,uint arrayLength) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    // Equivalent to w0=ModExp(sqrt_omega,arrayLength*2-idx).
    const uint period=arrayLength*2;
    uint w0=mtable[(period-idx)%period/2];
    if ((idx&1u)!=0)
        w0=ABModC(sqrt_omega,w0);
    arrayA[idx]=ABModC(arrayA[idx],w0);
}
// Combines the negacyclic and the (positive) cyclic results into the upper
// half and the lower half of the output.
//
// For every idx, with a = arrayA[idx] and b = arrayB[idx]:
//   arrayE[idx + arrayLength] <- (a - b) / 2 mod MODP     (upper half)
//   arrayE[idx]               <- a - (a - b) / 2 mod MODP (lower half, i.e. (a + b) / 2)
// Presumably arrayA holds the cyclic and arrayB the negacyclic result --
// confirm with the host code.
//
// NOTE(review): assumes a, b < MODP, MODP odd, and that 2*MODP fits in
// 32 bits; MODP is defined elsewhere, so confirm.
// Launch: 1D grid, one thread per element, no bounds check.
__global__ void PosNeg_To_HiLo(uint *arrayE,uint *arrayA,uint *arrayB,uint arrayLength) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    const uint a=arrayA[idx];
    const uint b=arrayB[idx];
    // First compute (a - b) / 2. Start from a - b shifted into the positive range.
    uint half=(a-b+MODP);
    if ((half&1u)!=0){
        // An odd value cannot be halved directly: move it by MODP (towards
        // the range [0, 2*MODP)) so that it becomes even.
        half=(half>=MODP)?(half-MODP):(half+MODP);
    }
    half/=2;    // (a - b) / 2 mod MODP
    arrayE[idx+arrayLength]=half;   // upper half is (a - b) / 2 mod MODP
    // a - ((a - b) / 2) = a/2 + b/2, so the lower half is (a + b) / 2.
    uint low=a-half;
    if (a<half){low+=MODP;}
    arrayE[idx]=low;
}
// Written to reduce the number of writes to VRAM: a fused version of
// PostNegFMT, DivN and PosNeg_To_HiLo.
//
// For every idx, with a = arrayA[idx] and b = arrayB[idx]:
//   1. b is multiplied by ModExp(sqrt_omega, idx + (idx % 2) * arrayLength)
//      (negacyclic post-processing; a is not weighted).
//   2. a and b are both divided by arrayLength modulo MODP via DivN_f.
//   3. arrayE[idx + arrayLength] <- (a - b) / 2 mod MODP     (upper half)
//      arrayE[idx]               <- a - (a - b) / 2 mod MODP (lower half)
//
// NOTE(review): the exponent used in step 1 differs from the one in
// PostNegFMT (2*arrayLength - idx). Whether both agree depends on which
// sqrt_omega value the host passes to this kernel -- confirm with the caller.
// NOTE(review): step 3 assumes a, b < MODP, MODP odd, and that 2*MODP fits in
// 32 bits; MODP is defined elsewhere, so confirm.
// Launch: 1D grid, one thread per element, no bounds check.
__global__ void PostFMT_DivN_HiLo(uint *arrayE,uint *arrayA,uint *arrayB,uint arrayLength,uint sqrt_omega) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    uint a=arrayA[idx];
    uint b=arrayB[idx];
    // Negacyclic post-processing part.
    const uint w0=ModExp(sqrt_omega,idx+(idx%2)*arrayLength);
    b=ABModC(b,w0);
    // Division by N.
    a=DivN_f(a,arrayLength);
    b=DivN_f(b,arrayLength);
    // The rest is the same as PosNeg_To_HiLo.
    // First compute (a - b) / 2. Start from a - b shifted into the positive range.
    uint half=(a-b+MODP);
    if ((half&1u)!=0){
        // An odd value cannot be halved directly: move it by MODP (towards
        // the range [0, 2*MODP)) so that it becomes even.
        half=(half>=MODP)?(half-MODP):(half+MODP);
    }
    half/=2;    // (a - b) / 2 mod MODP
    arrayE[idx+arrayLength]=half;   // upper half is (a - b) / 2 mod MODP
    // a - ((a - b) / 2) = a/2 + b/2, so the lower half is (a + b) / 2.
    uint low=a-half;
    if (a<half){low+=MODP;}
    arrayE[idx]=low;
}
// Precomputes the table of modular powers up front:
//   mtable[i] <- omega^i mod MODP   (via ModExp)
//
// Launch: 1D grid, one thread per table entry. The kernel takes no length
// argument and performs no bounds check, so the launch must create exactly
// as many threads as the table has entries.
__global__ void CreateTable(uint *mtable,uint omega) {
    // Use the actual block width rather than a hard-coded 256 so the index is
    // correct for any 1D launch configuration.
    const uint idx = threadIdx.x+blockIdx.x*blockDim.x;
    mtable[idx]=ModExp(omega,idx);
}