In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap

from functional import partial

import pandas as pd
from src.data import prime_fn
from einops import rearrange

In [18]:
def primes_fn(d):
    prime_iter = prime_fn()
    primes     = []
    prime      = next(prime_iter)
    while prime < 10 ** d:
        if prime >= 10 ** (d - 1):
            primes.append(prime)
        prime = next(prime_iter)
    return jnp.array(primes).astype(jnp.int32)

def data_fn(d):
    primes  = primes_fn(d)                              # (n_primes,  )
    context = context_fn(d)                             # (10 ** d , 4)
    y = vmap(lambda row: target_fn(primes, row))(context)  # (10 ** d, 4)
    x = vmap(token_fn)(context)                         # (10 ** d, 4)
    return x, y 

def token_fn(row):
    # [1001, 1003, 1007, 1009] -> [1, 0, 0, 1, 1, 0, 0, 3, 1, 0, 0, 7, 1, 0, 0, 9]
    return jnp.array(rearrange(row, 'a b -> b a')).flatten()

def target_fn(primes, row):
    # return vector of 4 booleans, indicating if the row contains the prime numbers
    return jnp.any(row[None, :] == primes[:, None], axis=0).astype(jnp.int32)

def context_fn(d):
    context = jnp.arange(10 ** (d - 1), 10 ** d).reshape(-1, 10)
    context = context[:, [1, 3, 7, 9]]
    return context

In [22]:
x, y = jit(partial(data_fn, 4))()
x

Array([[1001, 1003, 1007, 1009],
       [1011, 1013, 1017, 1019],
       [1021, 1023, 1027, 1029],
       ...,
       [9971, 9973, 9977, 9979],
       [9981, 9983, 9987, 9989],
       [9991, 9993, 9997, 9999]], dtype=int32)