In [2]:
import numpy as np
import torch
from tqdm import tqdm
import os


In [61]:
'''
choose between train and val
'''
filename = 'val'
# filename = 'train'

In [41]:
def f(a,b): # range of a,b is {1,2,...,n}
    assert len(a) == len(b)
    return [a[b[i]-1] for i in range(len(b))]  # f(a,b)[i] = a[b[i]]

def permute(list_of_permutations,partial_sum=True):
    """
    list_of_permutations: each permutation in the list is over {1,2,..,n} where n is some integer. 
    If partial_sum = True, then answer is a list of permutations containing all the intermediate results.
    If partial_sum = False, then answer is a list containing only one permutation which is final answer.
    """
    answer = [list(list_of_permutations[0])]
    for p in list_of_permutations[1:]:
        if partial_sum:
            answer.append( f(answer[-1], list(p)))
        else:
            answer[-1] = f(answer[-1], list(p))
    return answer

In [62]:
rng = np.random.default_rng(1234)
block_size = 128
p = 5
C = 9
N = 1000000
left_bracket_id = p+1
right_bracket_id = p+2
eq_id = p+3
dtype = np.int8

random_list = rng.permuted(np.tile(np.arange(1,p+1), N * C).reshape(N*C,p), axis=1).reshape(N,C,p)
input_length_list = np.random.randint(low = 1,high =C+1, size = (N,), dtype = dtype)
token_id_list = []
for i in tqdm(range(N)):
    current_id = []
    current_input_tokens = random_list[i]
    for j in range(input_length_list[i]):
        current_id += [left_bracket_id]+ list(current_input_tokens[j])+ [right_bracket_id]
    answer = [[left_bracket_id] + perm + [right_bracket_id] for perm in permute(random_list[i][:input_length_list[i]],partial_sum=True)]
    token_id_list.append(current_id + [eq_id] + sum(answer,[]))
# print( token_id_list)

100%|██████████| 1000000/1000000 [00:20<00:00, 48719.18it/s]


In [44]:
token_id_list[0][:-1]

[5, -2]

In [63]:
arr = np.memmap(filename + '.bin', dtype=dtype, mode='w+', shape=(block_size * len(token_id_list)))
arr_label = np.memmap(filename + '_label.bin', dtype=dtype, mode='w+', shape=(block_size * len(token_id_list)))
print(f"writing {filename}...")
for jj, example in tqdm(enumerate(token_id_list)):
    arr[jj * block_size : jj * block_size + len(example)] = example
    arr[jj * block_size + len(example) + 1 : (jj + 1) * block_size] = 0
    arr_label[jj * block_size : (1+jj) * block_size ] = -1
    arr_label[jj * block_size + (p+2)*input_length_list[jj] :jj * block_size + len(example)-1] = example[input_length_list[jj]*(p+2)+1:]
    # pad to block_size
arr.flush()
arr_label.flush()

writing val...


1000000it [00:16, 59991.60it/s]


In [64]:
print(arr[:64])
print(arr_label[:64])
print(arr[64:128])
print(arr_label[64:128])

[6 1 5 2 4 3 7 6 2 4 3 1 5 7 6 1 2 4 3 5 7 6 2 1 5 4 3 7 6 2 1 3 4 5 7 6 5
 1 4 2 3 7 6 1 3 2 4 5 7 6 2 4 3 5 1 7 6 3 5 2 4 1 7 8]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  6]
[6 1 5 2 4 3 7 6 5 4 2 1 3 7 6 5 4 1 2 3 7 6 4 5 3 2 1 7 6 5 4 3 2 1 7 6 1
 5 2 4 3 7 6 1 2 5 4 3 7 6 2 4 5 3 1 7 6 5 1 4 3 2 7 0]
[ 1  5  2  4  3  7  6  5  4  2  1  3  7  6  5  4  1  2  3  7  6  4  5  3
  2  1  7  6  5  4  3  2  1  7  6  1  5  2  4  3  7  6  1  2  5  4  3  7
  6  2  4  5  3  1  7  6  5  1  4  3  2  7 -1 -1]


In [67]:
offset=384
print(arr[+offset:64+offset])
print(arr_label[+offset:64+offset])
print(arr[64+offset:offset+offset])
print(arr_label[64+offset:offset+offset])

[6 2 1 4 5 3 7 6 1 3 5 2 4 7 6 2 1 3 5 4 7 8 6 2 1 4 5 3 7 6 2 4 3 1 5 7 6
 4 2 3 5 1 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  6  2  1
  4  5  3  7  6  2  4  3  1  5  7  6  4  2  3  5  1  7 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 3 4 2 1 5 7 6 1 4
 3 5 2 7 6 4 1 2 5 3 7 6 3 4 1 5 2 7 6 3 2 4 5 1 7 6 3 1 4 2 5 7 6 3 2 5 4
 1 7 6 4 2 1 5 3 7 6 4 3 2 5 1 7 8 6 3 4 2 1 5 7 6 3 1 2 5 4 7 6 5 3 1 4 2
 7 6 1 4 5 2 3 7 6 5 4 2 3 1 7 6 2 5 3 4 1 7 6 3 5 1 4 2 7 6 4 5 3 2 1 7 6
 2 3 5 1 4 7 0 6 5 3 2 1 4 7 6 3 5 1 2 4 7 6 2 4 5 1 3 7 6 3 1 2 5 4 7 6 2
 1 4 3 5 7 6 4 3 5 2 1 7 6 2 5 1 4 3 7 6 2 4 3 1 5 7 8 6 5 3 2 1 4 7 6 2 4
 5 3 1 7 6 4 3 1 2 5 7 6 1 4 3 5 2 7 6 4 1 5 3 2 7 6 3 5 2 1 4 7 6 5 4 3 1
 2 7 6 4 1 3 5 2 7 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[-1 -1 -1 -1 -1 -1 -1 -

In [None]:
print(arr[+128:64+128])
print(arr_label[+128:64+128])
print(arr[64+128:128+128])
print(arr_label[64+128:128+128])

In [16]:
a = [1,2]
a.append([2])
a

[1, 2, [2]]

In [18]:
[[2],[3]].join()

AttributeError: 'list' object has no attribute 'join'

In [19]:
sum([[1],[2,3]], [])

[1, 2, 3]