-
-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathuint.lua
369 lines (335 loc) · 8.88 KB
/
uint.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
-- Unsigned integers with a given base.
-- Supports basic arithmetic (increment, decrement, addition, subtraction, multiplication, exponentiation)
-- with both mutating variants (via methods) and copying variants (via arithmetic operators),
-- as well as comparisons and efficient base conversion.
-- Class tables, memoized per base. Values are weak, so a class no longer referenced
-- (by any instance or caller) can be collected.
-- Handing out one class (and hence one metatable) per base keeps `==` working
-- between instances of the same base.
local weak_values = { __mode = "v" }
local uints = setmetatable({}, weak_values)
return function(
-- number, base to be used; digits will be 0 to base (exclusive);
-- should be at most `2^26` for products to stay in exact float bounds,
-- default is `2^24` (3 bytes per digit)
base
)
base = base or 2 ^ 24
-- Validate the base explicitly; checking the type first avoids a confusing
-- "attempt to compare" error for non-number arguments.
assert(
	type(base) == "number" and base >= 2 and base % 1 == 0 and base <= 2 ^ 26,
	"base must be an integer between 2 and 2^26"
)
-- Reuse the cached class so that all instances of this base share one metatable
if uints[base] then
	return uints[base]
end
local uint = { base = base }
-- Instances are plain digit lists; their methods are looked up in `uint`
local mt = { __index = uint }
-- Installs `f` as the metamethod `mt.__<name>`.
-- A plain-number left operand is converted to a uint; otherwise a plain-number
-- right operand is converted. The (possibly converted) operands are passed to `f`.
local function bin_op(name, f)
	local function metamethod(lhs, rhs)
		if type(lhs) == "number" then
			return f(uint.from_number(lhs), rhs)
		end
		if type(rhs) == "number" then
			return f(lhs, uint.from_number(rhs))
		end
		return f(lhs, rhs)
	end
	mt["__" .. name] = metamethod
end
--> fresh uint representing 0 (an empty digit list)
function uint.zero()
	return uint.from_digits({})
end
--> fresh uint representing 1
function uint.one()
	return uint.from_digits({ 1 })
end
--> `digits` itself, turned into an instance of this uint class
function uint.from_digits(
	digits -- list of digits in the appropriate base, little endian; is consumed
)
	local instance = setmetatable(digits, mt)
	return instance
end
--> uint with the same value as `number`
function uint.from_number(
	number -- exact integer >= 0
)
	assert(number >= 0 and number % 1 == 0)
	local digits, count = {}, 0
	-- Peel off the least significant digit until nothing is left (little endian order)
	while number > 0 do
		local remainder = number % base
		count = count + 1
		digits[count] = remainder
		number = (number - remainder) / base
	end
	return uint.from_digits(digits)
end
--> the integer as a Lua number, or `nil` if `exact` and it cannot be represented without loss of precision
function uint:to_number(
	exact -- whether to require an exact result (returning `nil` otherwise); defaults to true
)
	local strict = exact ~= false
	local number = 0
	local pow = 1 -- `base^(i - 1)` for the digit at index `i`
	for _, digit in ipairs(self) do
		-- Zero digits contribute nothing. Skipping them also avoids `0 * pow` turning into NaN
		-- once `pow` has overflowed to infinity, which would otherwise poison the result
		-- (and slip past the exactness check below, since comparisons with NaN are false).
		if digit ~= 0 then
			local val = digit * pow
			-- `number + val > 2^53`: the sum would no longer be guaranteed exact
			if strict and 2 ^ 53 - val < number then
				return -- not representable without loss of precision
			end
			number = number + val
		end
		pow = pow * base
	end
	return number -- exact if `strict`, otherwise possibly rounded (or infinite)
end
--> new uint of the same class holding the same digits as `self`
function uint:copy()
	local digits = {}
	for i = 1, #self do
		digits[i] = self[i]
	end
	return uint.from_digits(digits)
end
-- Overwrites `dst` in place with the digits of `src` (both must have the same base)
function uint.copy_from(dst, src)
	assert(dst.base == src.base)
	local src_len, dst_len = #src, #dst
	for i = 1, src_len do
		dst[i] = src[i]
	end
	-- Drop digits of `dst` above the length of `src`
	for i = src_len + 1, dst_len do
		dst[i] = nil
	end
end
--> sign(a - b)
--! For better efficiency, prefer a single `uint.compare` over multiple `<`/`>`/`<=`/`>=`/`==` comparisons
function uint.compare(a, b)
	assert(a.base == b.base)
	local len_a, len_b = #a, #b
	-- Without leading zeros, more digits means a larger number
	if len_a ~= len_b then
		return len_a < len_b and -1 or 1
	end
	-- Same length: the most significant differing digit decides
	for i = len_a, 1, -1 do
		local x, y = a[i], b[i]
		if x ~= y then
			return x < y and -1 or 1
		end
	end
	return 0
end
-- Note: `uint == number` is always `false`, since Lua only consults `__eq` when both operands are tables.
-- For uints of *different* bases: in Lua 5.1/5.2 `__eq` is only called when both operands share the
-- same metamethod, so `==` yields `false`. From Lua 5.3 on, the left operand's `__eq` is used, and
-- the base assertion in `uint.compare` raises an error instead. Ordering comparisons (`<`, `<=`)
-- across bases raise an error in every version.
bin_op("eq", function(a, b)
	return a:compare(b) == 0
end)
bin_op("lt", function(a, b)
	return a:compare(b) < 0
end)
bin_op("le", function(a, b)
	return a:compare(b) <= 0
end)
-- In place: `self = self + 1`
function uint:increment()
	local max_digit = base - 1
	-- Digits equal to `base - 1` roll over to zero and carry into the next position.
	-- Position `#self + 1` is absent and acts as a zero digit, growing the number by one digit.
	for i = 1, #self + 1 do
		local digit = self[i]
		if digit ~= max_digit then
			self[i] = (digit or 0) + 1
			return
		end
		self[i] = 0
	end
end
-- In place: `self = self - 1`
--! Errors ("result < 0") when called on zero
function uint:decrement()
	local i = 1
	-- Borrow through trailing zero digits, which become `base - 1`
	while self[i] == 0 do
		self[i] = base - 1
		i = i + 1
	end
	self[i] = assert(self[i], "result < 0") - 1
	-- Only the most significant digit may be dropped when it reaches zero.
	-- A lower digit that becomes zero must stay, otherwise a hole is created
	-- (e.g. in base 10, 51 - 1 = 50 must become {0, 5}, not {nil, 5}).
	if self[i] == 0 and self[i + 1] == nil then
		self[i] = nil
	end
end
-- Adds `src * base^srcshift` to `dst` in place.
-- `src` may be any little-endian digit list (not necessarily a uint).
-- If `dst` is shorter than `srcshift`, it is padded with zeros first, so no holes are created.
local function add_shifted(dst, src, srcshift)
	if src[1] == nil then
		return -- adding zero changes nothing
	end
	for k = #dst + 1, srcshift do
		dst[k] = 0
	end
	local i, j = srcshift + 1, 1
	local carry = 0
	-- Stop as soon as `src` is exhausted and nothing is carried:
	-- the remaining (higher) digits of `dst` are unaffected.
	while src[j] ~= nil or carry > 0 do
		local digit_sum = (dst[i] or 0) + (src[j] or 0) + carry
		local digit = digit_sum % base
		dst[i] = digit
		carry = (digit_sum - digit) / base
		i, j = i + 1, j + 1
	end
end
-- In place: `dst = dst + src`.
-- Both operands must have the same base, as in `copy_from`, `compare` and `multiply`.
-- Mixing bases would otherwise silently produce a wrong value.
function uint.add(dst, src)
	assert(dst.base == src.base)
	return add_shifted(dst, src, 0)
end
-- `a + b`: copying counterpart of `uint.add`; neither operand is modified
local function copying_add(augend, addend)
	local sum = augend:copy()
	sum:add(addend)
	return sum
end
bin_op("add", copying_add)
-- Removes zero digits from the most significant end of `dst`, in place.
-- A list consisting only of zeros ends up empty (the representation of zero).
local function strip_leading_zeros(dst)
	for i = #dst, 1, -1 do
		if dst[i] ~= 0 then
			return
		end
		dst[i] = nil
	end
end
-- In place: `dst = dst - src` (both of the same base).
-- Errors with "result < 0" if `src > dst`; `dst` is then left partially modified.
function uint.subtract(dst, src)
	assert(dst.base == src.base)
	local i, borrow = 1, 0
	while src[i] or borrow > 0 do
		-- Running out of `dst` digits while still subtracting means `src > dst`
		local digit_diff = assert(dst[i], "result < 0") - (src[i] or 0) - borrow
		-- Lua's `%` rounds towards negative infinity, so a negative `digit_diff`
		-- wraps around to `digit_diff + base`
		dst[i] = digit_diff % base
		-- `digit_diff >= -base` (as `dst[i] >= 0`, `src[i] < base` and `borrow <= 1`),
		-- so at most one unit is ever borrowed from the next digit
		borrow = digit_diff < 0 and 1 or 0
		i = i + 1
	end
	-- The loop only exits with `borrow == 0`, so no final borrow check is needed
	strip_leading_zeros(dst)
end
-- `a - b`: copying counterpart of `uint.subtract`; neither operand is modified
local function copying_subtract(minuend, subtrahend)
	local difference = minuend:copy()
	difference:subtract(subtrahend)
	return difference
end
bin_op("sub", copying_subtract)
--> fresh uint holding `a * b`, computed by schoolbook multiplication
local function product_naive(a, b)
	-- Inputs must not carry leading zeros
	assert(a[#a] ~= 0 and b[#b] ~= 0)
	local res = uint.zero()
	-- Let the shorter factor drive the outer loop
	local short, long = a, b
	if #long < #short then
		short, long = long, short
	end
	for shift = 1, #short do
		local multiplier = short[shift]
		if multiplier ~= 0 then
			-- term = multiplier * long, digit by digit with carry
			local term, carry = {}, 0
			for j = 1, #long do
				local partial = multiplier * long[j] + carry
				local digit = partial % base
				term[j] = digit
				carry = (partial - digit) / base
			end
			if carry > 0 then
				term[#long + 1] = carry
			end
			add_shifted(res, term, shift - 1)
		elseif res[shift] == nil then
			-- A zero digit adds nothing, but `res` must not get a hole at this position
			res[shift] = 0
		end
	end
	assert(res[#res] ~= 0)
	return res
end
-- Note: Returns `{}` (zero) if `from > to`
--> fresh uint made of the digits `self[from..to]` (missing digits skipped), leading zeros removed
local function slice(self, from, to)
	local digits, count = {}, 0
	for i = from, to do
		local digit = self[i]
		if digit ~= nil then
			count = count + 1
			digits[count] = digit
		end
	end
	strip_leading_zeros(digits)
	return uint.from_digits(digits)
end
-- Karatsuba multiplication: splits both factors at `mid` digits and recurses on only three
-- products (high parts, low parts, and the sums of the parts) instead of four.
--> fresh uint holding `a * b`; falls back to `product_naive` when the shorter factor has fewer than 10 digits
local function product_karatsuba(a, b)
-- Ensure that `b` is the longer number
if #b < #a then
a, b = b, a
end
if #a < 10 then -- base case: Naive multiplication, to be tweaked
return product_naive(a, b)
end
local mid = math.floor(#b / 2) -- split at the middle of the longer of the two numbers
local prod_high, prod_low, prod_sum
-- The `do` block limits the scope of the temporary halves and sums
do
-- With B = base^mid: a = a_high * B + a_low and b = b_high * B + b_low.
-- `slice` strips leading zeros, so a part may be zero (e.g. `a_high` when #a <= mid).
local a_low, a_high = slice(a, 1, mid), slice(a, mid + 1, #a)
local b_low, b_high = slice(b, 1, mid), slice(b, mid + 1, #b)
prod_high = product_karatsuba(a_high, b_high)
prod_low = product_karatsuba(a_low, b_low)
-- Note: We can mutate a_low and b_low here since we already used them above.
-- (`slice` builds new tables, and the recursive products are new tables too, so nothing aliases them.)
local a_sum = a_low
a_sum:add(a_high)
local b_sum = b_low
b_sum:add(b_high)
prod_sum = product_karatsuba(a_sum, b_sum)
-- At this point we have `prod_sum = (a_low + a_high) * (b_low + b_high)`
prod_sum:subtract(prod_high)
prod_sum:subtract(prod_low)
-- This leaves us with `prod_sum = a_high * b_low + b_high * a_low`
end
-- Result: prod_low + prod_sum * B + prod_high * B^2, accumulated in place into prod_low
local res = prod_low
-- Ensure that we produce no holes
-- (`add_shifted` writes from index `shift + 1` on, so `res` is zero-padded up to the highest
-- shift that will receive a nonzero summand; padding stops there to avoid leading zeros)
local min_len = (prod_high[1] and 2 * mid) or (prod_sum[1] and mid) or 0
for i = #res + 1, min_len do
res[i] = res[i] or 0
end
add_shifted(res, prod_sum, mid)
add_shifted(res, prod_high, 2 * mid)
return res
end
bin_op("mul", product_karatsuba)
-- In place: `dst = dst * src` (both of the same base)
function uint.multiply(dst, src)
	assert(dst.base == src.base)
	-- The product is a fresh table, so `dst` may safely be overwritten afterwards
	local product = product_karatsuba(dst, src)
	dst:copy_from(product)
end
-- TODO division, modulo
-- Exponentiation by squaring. Some details have to be different for uints.
--> fresh uint equal to `expbase^exp`; never returns `expbase` itself, so the result may be mutated freely
local function fastpow(
	expbase, -- Base (uint)
	exp -- Exponent, non-negative integer
)
	if exp == 1 then
		-- Copy so that mutating the result (e.g. `(u ^ 1):increment()`) cannot affect `expbase`
		return expbase:copy()
	end
	local res = expbase.one()
	-- Loop invariant: `res * expbase^exp` equals the ORIGINAL `expbase^exp`
	while exp > 0 do
		if exp % 2 == 1 then
			-- `res * expbase^exp = (res * expbase) * expbase^(exp - 1)`
			res = res * expbase
			exp = exp - 1
		else
			-- `expbase^exp = (expbase^2)^(exp / 2)`
			expbase = expbase * expbase
			exp = exp / 2
		end
	end
	return res
end
-- `expbase ^ exp`, where either operand may be a plain number.
-- The exponent must be a non-negative integer; `0^0` is an error.
function mt.__pow(expbase, exp)
	if type(exp) == "number" then
		-- Negative or fractional exponents have no uint result; without this check they would
		-- silently yield 1 (negative) or square huge numbers for ~1000 iterations (fractional)
		assert(exp >= 0 and exp % 1 == 0, "exponent must be a non-negative integer")
	end
	local zero_base = expbase == 0 or expbase == uint.zero()
	if exp == 0 or exp == uint.zero() then
		assert(not zero_base, "0^0")
		return uint.one()
	end
	if zero_base then
		return uint.zero()
	end
	if type(expbase) == "number" then
		expbase = uint.from_number(expbase)
	end
	-- Try conversion of exponent to number for consistency;
	-- taking n^m with n >= 2 and m >= 2^53 wouldn't fit into memory anyways
	if type(exp) ~= "number" then
		exp = exp:to_number()
		-- `exp == exp` rules out NaN, which would otherwise make `fastpow` return 1
		assert(exp and exp == exp, "exponent too large")
	end
	return fastpow(expbase, exp)
end
-- Divide and conquer: converts the low and high halves of the digits separately
-- and recombines them as `low + high * base^mid` in the target base.
local function convert_base(self, other_uint)
	local first, second = self[1], self[2]
	if first == nil then
		return other_uint.zero()
	end
	if second == nil then
		-- A single digit is below `base`, hence representable as a plain number
		return other_uint.from_number(first)
	end
	local mid = math.floor(#self / 2)
	local low = convert_base(slice(self, 1, mid), other_uint)
	local high = convert_base(slice(self, mid + 1, #self), other_uint)
	-- TODO this fastpow could be optimized with memoization, but shouldn't matter asymptotically
	local scale = fastpow(other_uint.from_number(base), mid) -- `base^mid` in the target base
	low:add(high * scale)
	return low
end
--> uint instance of `other_uint`
function uint:convert_base_to(
	other_uint -- "class" table to convert to
)
	local is_zero = self[1] == nil
	if is_zero then
		return other_uint.zero()
	end
	-- Matching bases need no arithmetic, just a copy of the digits
	if other_uint.base == uint.base then
		return other_uint.copy(self)
	end
	return convert_base(self, other_uint)
end
-- Memoize the finished class so later calls with this base return the very same table
rawset(uints, base, uint)
return uint
end