In [13]:
## python count.py --fn cora --range 21


import numpy as np
from scipy.special import comb, factorial
from time import time
import argparse


parser = argparse.ArgumentParser(
	description='Compute the number of data points in each region',
)
# Dataset name; the main block maps it to the dimensionality d.
parser.add_argument("--fn", type=str, default='cora', help="dataset")
# Upper (exclusive) bound on the L0 perturbation size r that is swept.
parser.add_argument("--range", type=int, default=21, help="range of certified perturbation size")
# Each coordinate takes K + 1 values (the total space is (K+1)**d), so K = 1 is binary.
parser.add_argument("--K", type=int, default=1, help="binary data")

# Parse an explicit empty list so the notebook kernel's own argv is ignored.
args = parser.parse_args([])



# Module-level memo tables shared by my_comb / my_powe.
global_comb = {}
global_powe = {}

def my_comb(d, m):
	"""Return the exact binomial coefficient C(d, m), memoized in ``global_comb``.

	Delegates to ``scipy.special.comb(..., exact=True)``, so the result is a
	Python int.
	"""
	key = (d, m)
	try:
		return global_comb[key]
	except KeyError:
		value = comb(d, m, exact=True)
		global_comb[key] = value
		return value

def my_powe(k, p):
	"""Return ``k ** p``, memoized in ``global_powe`` keyed by ``(k, p)``."""
	cached = global_powe.get((k, p))
	if cached is None:
		# First request for this (k, p): compute it once and store it.
		cached = global_powe[(k, p)] = k ** p
	return cached


def get_count(d, m, n, r, K):
	"""Count points z with ||z - x||_0 = m and ||z - x'||_0 = n, given ||x - x'||_0 = r.

	Every one of the d coordinates takes one of K + 1 values, so a coordinate
	where z differs from a given reference value has K alternatives.

	Decomposition used when r > 0:
	  * Of the d - r coordinates where x == x', z differs from both on i of
	    them: C(d - r, i) * K**i ways.
	  * Of the r coordinates where x != x', z equals x' on a of them, equals x
	    on b of them, and equals neither on j of them (K - 1 choices each):
	    C(r, a) * C(r - a, j) * (K - 1)**j ways, with a + b + j = r.
	  Then m = i + a + j and n = i + b + j.

	Args:
		d: number of coordinates.
		m: L0 distance between z and x.
		n: L0 distance between z and x'.
		r: L0 distance between x and x'.
		K: number of alternative values per coordinate (K = 1 means binary).

	Returns:
		The exact number of such z (a Python int when K is an int), or 0 when
		(m, n, r) is infeasible.
	"""
	if r == 0 and m == 0 and n == 0:
		return 1
	# early stopping: infeasible triples. |m - n| <= r is the triangle inequality.
	if (r == 0 and m != n) or min(m, n) < 0 or max(m, n) > d or m + n < r or abs(m - n) > r:
		return 0

	if r == 0:
		return my_comb(d, m) * my_powe(K, m)

	c = 0
	# With x = m - i (= a + j) and y = n - i (= b + j), we have j = x + y - r,
	# a = r - y and b = r - x. Non-negativity of every part gives exact
	# integer bounds on i:
	#   a >= 0  <=>  i >= n - r
	#   b >= 0  <=>  i >= m - r
	#   j >= 0  <=>  i <= (m + n - r) // 2
	#   i <= d - r, and x, y >= 0  <=>  i <= min(m, n)
	lo = max(0, n - r, m - r)
	hi = min(m, n, d - r, (m + n - r) // 2)
	for i in range(lo, hi + 1):
		x = m - i
		y = n - i
		j = x + y - r
		# C(r, a) * C(r - a, j) * (K - 1)**j, with a = x - j.
		tmp = my_powe(K - 1, j) * my_comb(r, x - j) * my_comb(r - x + j, j)
		# For K = 1, (K - 1)**j vanishes unless j == 0; skip the remaining factors.
		if tmp != 0:
			tmp *= my_comb(d - r, i) * my_powe(K, i)
			c += tmp

	return c



if __name__ == "__main__":

	# Dimensionality d for each supported dataset (presumably the node count of
	# each graph; TODO confirm).
	dataset_dims = {
		'cora': 2708,
		'citeseer': 3327,
		'pubmed': 19717,
		'ogbn-arxiv': 169343,
	}
	# Fail early with a clear message instead of a NameError on global_d below.
	if args.fn not in dataset_dims:
		raise ValueError('unknown dataset {!r}; expected one of {}'.format(args.fn, sorted(dataset_dims)))
	global_d = dataset_dims[args.fn]

	K = args.K
	r_range = [0, args.range]
	m_range = [0, global_d+1]

	print('fn =', args.fn, 'Range of L0 norm =', r_range, 'm_range =', m_range, 'global_d:', global_d, 'data type:', K)

	# Size of the whole (K+1)**d space; the running total ttl is reported as a
	# fraction of it (it should approach 1 once every (m, n) has been visited).
	real_ttl = (K+1)**global_d

	for r in range(r_range[0], r_range[1]):
		ttl = 0
		complete_cnt = []
		for m in range(m_range[0], m_range[1]):
			start = time()
			# Only n in [m, m + r] is enumerated: |m - n| <= r, and pairs with
			# n < m are covered by symmetry below.
			for n in range(m, min(m+r, global_d)+1):
				c = get_count(global_d, m, n, r, K)
				if c != 0:
					complete_cnt.append(((m, n), c))
					ttl += c
					# symmetric between d, m, n, r and d, n, m, r
					if n > m:
						ttl += c

			if m % 100 == 0:
				print('r = {}, m = {:10d}/{:10d}, ttl ratio = {:.4f}, # of dict = {}'.format(r, m, m_range[1], ttl / real_ttl, len(complete_cnt)))
				print(args.fn, len(global_powe), len(global_comb), time() - start)

		# TODO: persist the per-r results, e.g.
		# np.save('list_counts/{}/complete_count_{}'.format(args.fn, r), complete_cnt)

		# NOTE(review): this break stops the sweep after the first r (r = 0);
		# remove it to process every r in range(r_range[0], r_range[1]).
		break

fn = cora Range of L0 norm = [0, 21] m_range = [0, 89251] global_d: 89250 data type: 1
r = 0, m =          0/     89251, ttl ratio = 0.0000, # of dict = 1
cora 0 0 1.4066696166992188e-05
r = 0, m =        100/     89251, ttl ratio = 0.0000, # of dict = 101
cora 100 100 2.1219253540039062e-05
r = 0, m =        200/     89251, ttl ratio = 0.0000, # of dict = 201
cora 200 200 4.9591064453125e-05
r = 0, m =        300/     89251, ttl ratio = 0.0000, # of dict = 301
cora 300 300 9.226799011230469e-05
r = 0, m =        400/     89251, ttl ratio = 0.0000, # of dict = 401
cora 400 400 0.00011754035949707031
r = 0, m =        500/     89251, ttl ratio = 0.0000, # of dict = 501
cora 500 500 0.00017380714416503906
r = 0, m =        600/     89251, ttl ratio = 0.0000, # of dict = 601
cora 600 600 0.00023365020751953125
r = 0, m =        700/     89251, ttl ratio = 0.0000, # of dict = 701
cora 700 700 0.00031757354736328125
r = 0, m =        800/     89251, ttl ratio = 0.0000, # of dict = 801
cora 

KeyboardInterrupt: 

In [10]:
complete_cnt

[((0, 0), 1),
 ((1, 1), 2708),
 ((2, 2), 3665278),
 ((3, 3), 3306080756),
 ((4, 4), 2235737111245),
 ((5, 5), 1209086629761296),
 ((6, 6), 544693526707463848),
 ((7, 7), 210251701309081045328),
 ((8, 8), 70986230654478487928866),
 ((9, 9), 21295869196343546378659800),
 ((10, 10), 5747755096093123167600280020),
 ((11, 11), 1409767568114476936925959590360),
 ((12, 12), 316845260933728691574109417933410),
 ((13, 13), 65708832575179427114138383903728720),
 ((14, 14), 12648950270722039719471638901467778600),
 ((15, 15), 2271751468621678333617106346703613036560),
 ((16, 16), 382364169062386234526929211979551869216005),
 ((17, 17), 60548490771526102549793731685232566584087380),
 ((18, 18), 9051999370343152331194162886942268704321063310),
 ((19, 19), 1281572542432793672153278850835510674453876858100),
 ((20, 20), 172307428330089109221008341494834410180323743571545),
 ((21, 21), 22055350826251405980289067711338804503081439177157760),
 ((22, 22), 2693760348642614903138032951834880349989992139501