In [None]:
# imports
from tracr.rasp import rasp
from tracr.compiler import compiling
from tracr.compiler import lib

import jax
import numpy as np
import matplotlib.pyplot as plt
jax.config.update('jax_default_matmul_precision', 'float32')

from typing import List, Sequence

In [None]:
#@title Plotting functions
def tidy_label(label, value_width=5):
  """Formats a residual label as `name:value` with a right-aligned value.

  Args:
    label: Label text. Everything after the first ':' (if any) is treated as
      the value; a label without ':' gets an empty value.
    value_width: Minimum width the value is right-aligned to.

  Returns:
    The name, a ':' and the value left-padded to `value_width` characters.
  """
  # partition() splits on the first ':' only, so a value that itself contains
  # ':' no longer breaks the two-way unpacking that split(':') required.
  name, _, value = label.partition(':')
  return name + f":{value:>{value_width}}"


def add_residual_ticks(model, value_width=5, x=False, y=True):
  """Puts the model's residual labels on the ticks of the current axes.

  Args:
    model: Object exposing `residual_labels`, a sequence of label strings.
    value_width: Passed to `tidy_label` to align the value part of each label.
    x: If true, label the x axis (labels rotated by 90 degrees).
    y: If true, label the y axis.
  """
  if not (x or y):
    return

  # Ticks sit at half-integer positions, one per residual label.
  tick_positions = np.arange(len(model.residual_labels)) + 0.5
  tick_labels = [
      tidy_label(name, value_width=value_width)
      for name in model.residual_labels
  ]
  shared_style = dict(family='monospace', fontsize=20)

  if y:
    plt.yticks(tick_positions, tick_labels, **shared_style)
  if x:
    plt.xticks(tick_positions, tick_labels, rotation=90, **shared_style)


def plot_computation_trace(model,
                           input_labels,
                           residuals_or_outputs,
                           add_input_layer=False,
                           figsize=(12, 9)):
  """Draws one heatmap per entry of `residuals_or_outputs`, side by side.

  Args:
    model: Model whose `residual_labels` are used for the y-axis ticks.
    input_labels: Labels for the x axis, one per input position.
    residuals_or_outputs: Sequence of arrays, one per panel. Only element 0
      along the leading axis of each array is plotted (transposed).
    add_input_layer: If true, the first panel is titled 'Input' and the
      Attn/MLP numbering starts at the second panel.
    figsize: Figure size passed to `plt.subplots`.
  """
  # squeeze=False always returns a 2-D array of axes, so a single panel works
  # too (the default returns a bare Axes object that cannot be zipped over).
  _, axes_grid = plt.subplots(
      nrows=1,
      ncols=len(residuals_or_outputs),
      figsize=figsize,
      sharey=True,
      squeeze=False)
  axes = axes_grid[0]
  # default=0 keeps this well-defined when there are no input labels.
  value_width = max(map(len, map(str, input_labels)), default=0) + 1

  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):
    plt.sca(ax)
    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)
    if i == 0:
      # The y axis is shared, so only the first panel needs residual labels.
      add_residual_ticks(model, value_width=value_width)
    plt.xticks(
        np.arange(len(input_labels))+0.5,
        input_labels,
        rotation=90,
        fontsize=20,
    )
    if add_input_layer and i == 0:
      title = 'Input'
    else:
      # Panels alternate Attn, MLP, Attn, MLP, ... and are numbered in pairs.
      layer_no = i - 1 if add_input_layer else i
      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'
      title = f'{layer_type} {layer_no // 2 + 1}'
    plt.title(title, fontsize=20)


def plot_residuals_and_input(model, inputs, figsize=(12, 9)):
  """Runs `model` on `inputs` and plots the embeddings plus every residual.

  The input embeddings are prepended as an extra leading panel, which
  `plot_computation_trace` titles 'Input'.
  """
  result = model.apply(inputs)
  # Add a leading axis so the embeddings stack with the per-layer residuals.
  embeddings = result.input_embeddings[None, ...]
  stacked = np.concatenate([embeddings, result.residuals], axis=0)
  plot_computation_trace(
      model,
      inputs,
      stacked,
      add_input_layer=True,
      figsize=figsize)


def plot_layer_outputs(model, inputs, figsize=(12, 9)):
  """Runs `model` on `inputs` and plots what each layer outputs.

  Unlike `plot_residuals_and_input`, no 'Input' panel is added.
  """
  per_layer = model.apply(inputs).layer_outputs
  plot_computation_trace(
      model,
      inputs,
      per_layer,
      add_input_layer=False,
      figsize=figsize)


In [None]:
# compute average of a column
def average_numerical_column() -> rasp.SOp:
    """Returns a numerical SOp holding the mean of the tokens at every position.

    Every position attends to every position, so aggregating a numerical copy
    of the tokens gives the average of the whole sequence.
    """
    # Comparison.TRUE selects every key for every query. The previous selector
    # compared tokens against squared tokens with LT, which only behaves as
    # "select all" when every token is below every square (it fails for
    # tokens such as 0 or 1).
    select_all = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    # Identity map so the aggregated values carry a numerical annotation.
    identity_numerical = rasp.numerical(rasp.Map(lambda x: x, rasp.tokens))
    return rasp.numerical(rasp.Aggregate(select_all, identity_numerical, default=0))


average = average_numerical_column()
average([31, 22, 22, 25, 28])

In [None]:
# I need to add numerical values
# Compile the `average` program into a model. `bos` and `model` are reused by
# the following cell, so they are kept as notebook-level names.
bos = "BOS"
model = compiling.compile_rasp_to_model(
    program=average,
    # Every token the compiled model is allowed to see.
    vocab={5, 10, 15},
    max_seq_len=5,
    compiler_bos=bos,
)

In [None]:
# Run the compiled model on a BOS-prefixed sequence drawn from the vocab and
# display the decoded output.
out = model.apply([bos, 5, 10, 15])
out.decoded

In [None]:
salary = [62000, 69000, 80000, 60000, 75000]

def select_employee_where_salary_above_x_below_y(x: rasp.SOp, y: rasp.SOp) -> rasp.SOp:
    """Returns an SOp that is true where `x < token < y` holds position-wise.

    The earlier draft called `rasp.Select` without a predicate (a TypeError)
    and fell off the end without returning anything.

    Args:
        x: SOp giving the exclusive lower bound at each position.
        y: SOp giving the exclusive upper bound at each position.

    Returns:
        SOp combining the two comparisons with `&`.
    """
    # above
    is_above = rasp.SequenceMap(lambda token, lower: token > lower, rasp.tokens, x)
    # below
    is_below = rasp.SequenceMap(lambda token, upper: token < upper, rasp.tokens, y)
    return is_above & is_below

In [None]:
def select_employee_where_salary_above_x_below_y2(x: int, y: int) -> rasp.SOp:
    """Returns an SOp that is True where `x < int(token) < y`.

    Tokens are converted with `int`, so numeric strings are accepted.
    """
    def in_open_range(z):
        return x < int(z) < y

    return rasp.Map(in_open_range, rasp.tokens)

select = select_employee_where_salary_above_x_below_y2(65000, 70000)

select([62000, 69000, 80000, 60000, 75000])

In [None]:
# Tokens are strings here; the `select` program converts them with int().
vocab = {'62000', '69000', '80000', '60000', '75000'}
# The salary example has five rows, so the model must accept five tokens
# (was 3). Assumes max_seq_len does not count the BOS token - TODO confirm.
max_seq_len = 5

assembled_model = compiling.compile_rasp_to_model(
      program=select,
      vocab=vocab,
      max_seq_len=max_seq_len,
      causal=False,
      compiler_bos="bos",
      compiler_pad="pad",
      mlp_exactness=100)

In [None]:
# The model was compiled with a vocabulary of strings, so the salaries must be
# passed as strings too (ints are not in the vocab).
assembled_model.apply(["bos", "62000", "69000", "80000", "60000", "75000"]).decoded

In [None]:
# model it by combining the values
def select_employee_where_salary_above_x_below_y_with_bonus(x: int, y: int) -> rasp.SOp:
    """Returns an SOp that is True where `x < token < y`.

    Tokens are compared as-is (no conversion); no bonus is applied yet.
    """
    def strictly_between(z):
        return x < z < y

    return rasp.Map(strictly_between, rasp.tokens)
    

In [None]:
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Shifts `sop` so that position q reads the value at position q - offset.

  Positions with no source position get None.
  """
  def is_offset_apart(k, q):
    return q == k + offset

  shift_selector = rasp.Select(rasp.indices, rasp.indices, is_offset_apart)
  shifted = rasp.Aggregate(shift_selector, sop, default=None)
  return shifted.named(f"shift_by({offset})")

shift = shift_by(2, rasp.tokens)

def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
  """Builds an SOp that is True where `pattern` ends in `sop`.

  The first len(pattern) - 1 positions of the result are None-padded.

  detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

  Args:
    sop: the SOp to search in.
    pattern: the sequence of values to search for.

  Returns:
    an SOp marking the last position of every occurrence of the pattern.

  Raises:
    ValueError: if `pattern` is empty.
  """

  if len(pattern) < 1:
    raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")

  # aligned[i] tests the i'th pattern element counted from the end, shifted by
  # i positions so that all tests for one occurrence line up at its last
  # position.
  aligned = []
  for distance_from_end, expected in enumerate(reversed(pattern)):
    match = sop == expected
    if distance_from_end:
      match = shift_by(distance_from_end, match)
    aligned.append(match)

  # AND everything together, starting from the most-shifted detector.
  combined = aligned[-1]
  for earlier in aligned[-2::-1]:
    combined = combined & earlier

  return combined.named(f"detect_pattern({pattern})")

# Modeling JOIN

In [None]:
def join_and_calculate_bonuses() -> rasp.SOp:
    """Adds a fixed bonus to each row for every token that references it.

    A token equal to a position's index counts as one reference to that
    position. In the example below the first five tokens are salaries and the
    last five are such references.

    Returns:
        SOp of `token + references * BONUS` at every position.
    """
    BONUS = 5000

    # For query position q, select the keys whose token equals q.
    select_by_foreign_keys = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)
    bonus_counts = rasp.SelectorWidth(select_by_foreign_keys)

    salaries_with_bonuses = rasp.SequenceMap(lambda x, y: x + y * BONUS, rasp.tokens, bonus_counts)

    # Return the joined result; previously the intermediate count was returned
    # and the salary computation above was discarded.
    return salaries_with_bonuses

join = join_and_calculate_bonuses()

join([60000, 70000, 80000, 60000, 75000, 2, 1, 2, 0, 4])

In [None]:
# Vocabulary and length limit for the compiled JOIN model.
vocab = {1, 60000, 70000}
max_seq_len = 4

assembled_model = compiling.compile_rasp_to_model(
    program=join,
    vocab=vocab,
    max_seq_len=max_seq_len,
    causal=False,
    compiler_bos="bos",
    compiler_pad="pad",
    mlp_exactness=100,
)

In [None]:
# Use only tokens from the compiled vocab ({60000, 70000, 1}) and stay within
# max_seq_len=4; the previous input used 80000 and 2 and was six tokens long.
assembled_model.apply(["bos", 60000, 70000, 1, 1]).decoded

In [None]:
#@title Plot residual stream
# Plots the input embeddings and the residual stream after each layer of the
# most recently compiled `assembled_model`.
plot_residuals_and_input(
  model=assembled_model,
  inputs=["bos", 60000, 70000, 1, 1],
  figsize=(50, 45)
)

In [None]:
#@title Plot layer outputs
# Plots what each layer of `assembled_model` outputs for the same input as the
# residual-stream plot.
plot_layer_outputs(
  model=assembled_model,
  inputs = ["bos", 60000, 70000, 1, 1],
  figsize=(8, 9)
)

### Filtering and Sorting by Categorical Variables
input: ['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales']

SELECT 
    EmployeeID,
    Name,
    Department,
    Salary
FROM 
    [EmployeeDB].[dbo].[Employees]
WHERE 
    Department = 'Sales'

output: ['James', 'John', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales']

In [None]:
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Shifts `sop` left by `offset`, then subtracts 5 from positive values.

  Position q reads the value at position q + offset; positions with no source
  get -1. Values greater than 0 are then reduced by 5, everything else is left
  unchanged. The `> 0` test assumes numeric values.
  """
  def is_offset_ahead(k, q):
    return q == k - offset

  def subtract_five_if_positive(x):
    return x - 5 if x > 0 else x

  shift_selector = rasp.Select(rasp.indices, rasp.indices, is_offset_ahead)
  shifted = rasp.Aggregate(shift_selector, sop, default=-1)
  adjusted = rasp.Map(subtract_five_if_positive, shifted)
  return adjusted.named(f"shift_by({offset})")

### Working Where CLAUSE

In [None]:
def where_clause(row_lenght: int, department: str, /, sop: rasp.SOp) -> rasp.SOp:
  """Fetches `sop` values for the rows whose department token matches.

  At every position whose token equals `department`, the result holds the
  `sop` value found `row_lenght` positions earlier; all other positions hold
  'X'.

  Args:
    row_lenght: Distance between a department token and the value it refers to.
    department: Department token to filter on.
    sop: SOp the selected values are read from.

  Returns:
    The filtered SOp.
  """
  # Own index where the token is the requested department, -1 elsewhere.
  # (Previously compared against a hard-coded 'Sales', ignoring `department`.)
  department_indeces = rasp.SequenceMap(lambda x, y: y if x == department else -1, rasp.tokens, rasp.indices)

  # A matching query at index q selects the key at index q - row_lenght;
  # a query of -1 selects nothing and therefore gets the default.
  select_off_by_offset = rasp.Select(rasp.indices, department_indeces,
                                     lambda k, q: q == k + row_lenght)
  out = rasp.Aggregate(select_off_by_offset, sop, default='X')

  return out

In [None]:
# Build the filter for 'Sales' with a distance of 5 between a department token
# and the token it refers to, then evaluate it on the example row.
shift = where_clause(5, 'Sales', rasp.tokens)
shift(['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales'])

In [None]:
def where_clause(row_lenght: int, department: str, /, sop: rasp.SOp) -> rasp.SOp:
  """Fetches `sop` values for the rows whose department token matches.

  At every position whose token equals `department`, the result holds the
  `sop` value found `row_lenght` positions earlier; all other positions hold
  'X'.

  Args:
    row_lenght: Distance between a department token and the value it refers to.
    department: Department token to filter on.
    sop: SOp the selected values are read from.

  Returns:
    The filtered SOp.
  """
  # Own index where the token is the requested department, -1 elsewhere.
  # (Previously compared against a hard-coded 'Sales', ignoring `department`.)
  department_indeces = rasp.SequenceMap(lambda x, y: y if x == department else -1, rasp.tokens, rasp.indices)

  # A matching query at index q selects the key at index q - row_lenght;
  # a query of -1 selects nothing and therefore gets the default.
  select_off_by_offset = rasp.Select(rasp.indices, department_indeces,
                                     lambda k, q: q == k + row_lenght)
  out = rasp.Aggregate(select_off_by_offset, sop, default='X')

  return out

shift = where_clause(5, 'Sales', rasp.tokens)
shift(['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales'])

In [None]:
def where_clause(row_lenght: int, department: str, /, sop: rasp.SOp) -> rasp.SOp:
  """Fetches `sop` values for the rows whose department token matches.

  Where the token equals `department`, the result holds the `sop` value found
  `row_lenght` positions earlier; every other position holds 'X'.
  """
  def own_index_if_match(x, y):
    return y if x == department else -1

  def is_row_apart(k, q):
    return q == k + row_lenght

  matched_index = rasp.SequenceMap(own_index_if_match, rasp.tokens, rasp.indices)
  row_selector = rasp.Select(rasp.indices, matched_index, is_row_apart)
  return rasp.Aggregate(row_selector, sop, default='X')

where = where_clause(5, 'I', rasp.tokens)
where(['F', 'E', 'H', 'A', 'D', 'I', 'G', 'C', 'B', 'I'])

In [None]:
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Keeps `sop` where the token is 'I' and fills other positions with "A".

  `offset` is only used in the label of the returned SOp.
  """
  def own_index_if_target(x, y):
    return y if x == 'I' else -1

  matched_index = rasp.SequenceMap(own_index_if_target, rasp.tokens, rasp.indices)
  # A matching query selects the key at its own index; -1 selects nothing.
  same_position = rasp.Select(rasp.indices, matched_index, rasp.Comparison.EQ)
  kept = rasp.Aggregate(same_position, sop, default="A")
  return kept.named(f"shift_by({offset})")

In [None]:
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Shifts `sop` so that position q reads the value at position q - offset.

  Positions with no source position get None.
  """
  def is_offset_apart(k, q):
    return q == k + offset

  shift_selector = rasp.Select(rasp.indices, rasp.indices, is_offset_apart)
  shifted = rasp.Aggregate(shift_selector, sop, default=None)
  return shifted.named(f"shift_by({offset})")

In [None]:
# NOTE(review): `input` and `filter` shadow Python builtins; `filter` is kept
# because a later compile cell passes it as the program.
input = ['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales']


def filter_by_department_and_order_by_name() -> rasp.SOp:
    """Returns an SOp that is True where the token is 'Sales'.

    Only the department test exists so far; selecting the matching names and
    ordering them is not implemented.
    """
    # The unused length constant and select-all selector were dropped, and the
    # `True if ... else False` wrapper was reduced to the comparison itself.
    is_sales = rasp.Map(lambda x: x == 'Sales', rasp.tokens)
    return is_sales

filter = filter_by_department_and_order_by_name()

filter(['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales'])

In [None]:
def map_words_to_letters(input_data):
    """Encodes each word as a letter based on its rank among the unique words.

    The unique words are sorted; the first becomes 'A', the second 'B', and so
    on. The result has one letter per entry of `input_data`, in input order.
    """
    letter_for = {}
    for rank, word in enumerate(sorted(set(input_data))):
        letter_for[word] = chr(ord('A') + rank)
    return [letter_for[word] for word in input_data]

def map_letters_to_words(input_letters, vocabulary=None):
    """Decodes letters produced by `map_words_to_letters` back into words.

    Args:
        input_letters: Letters to decode.
        vocabulary: Words the letter encoding was built from. When omitted, the
            notebook-level `input_data` is used, which was the only behaviour
            before this parameter existed.

    Returns:
        The word for each letter, in input order.

    Raises:
        KeyError: if a letter does not correspond to any word.
    """
    source_words = input_data if vocabulary is None else vocabulary
    # Same scheme as the encoder: sorted unique words map to 'A', 'B', ...
    letter_to_word = {
        chr(65 + idx): word for idx, word in enumerate(sorted(set(source_words)))
    }
    return [letter_to_word[letter] for letter in input_letters]

# Example usage
# `input_data` must stay a notebook-level name: map_letters_to_words reads it
# as a global to rebuild the letter-to-word table.
input_data = ['John', 'Jane', 'Mike', 'Emily', 'James', 'Sales', 'Marketing', 'IT', 'HR', 'Sales']
letters = map_words_to_letters(input_data)
print("Mapped Letters:", letters)
# Round trip: decoding the letters should give the words back.
words = map_letters_to_words(letters)
print("Mapped Words:", words)

In [None]:
from typing import List
from tracr.rasp import rasp



def filter_by_department_and_order_by_name(sop: rasp.SOp) -> rasp.SOp:
    """Returns an SOp that is 1 where `sop` equals the letter code 'F', else 0.

    'F' is the code `map_words_to_letters` assigns to 'John' in the example
    data. Department filtering and ordering are not implemented yet.

    Args:
        sop: SOp holding the letter-encoded rows.

    Returns:
        SOp labelled "is_john".
    """
    john = 'F'

    # Previously this SOp carried the copy-pasted label "is_department" and the
    # function ignored `sop`; the unused department mask, index mapping and
    # selector were removed.
    is_john = rasp.Map(lambda x: 1 if x == john else 0, sop).named("is_john")

    return is_john


filter = filter_by_department_and_order_by_name(rasp.tokens)
filter('FEHADIGCBI')


In [None]:
def make_hist() -> rasp.SOp:
  """Builds an SOp counting how often each position's token occurs.

   (As implemented in the RASP paper.)

  Example usage:
    hist = make_hist()
    hist("abac")
    >> [2, 1, 2, 1]
  """
  # Each position selects every position holding the same token; the width of
  # that selection is the token's frequency.
  equal_tokens = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
  equal_tokens = equal_tokens.named("same_tok")
  histogram = rasp.SelectorWidth(equal_tokens)
  return histogram.named("hist")


hist = make_hist()
hist('FEHADIGCBI')

In [None]:
# The `filter` program defined above works on the letter codes produced by
# map_words_to_letters ('A'..'I' for the example rows), so the vocabulary must
# contain those letters rather than the original words.
vocab = set('ABCDEFGHI')
max_seq_len = 10

assembled_model = compiling.compile_rasp_to_model(
      program=filter,
      vocab=vocab,
      max_seq_len=max_seq_len,
      causal=False,
      compiler_bos="bos",
      compiler_pad="pad",
      mlp_exactness=100)

In [None]:
def make_length() -> rasp.SOp:
  """Builds the `length` SOp from the selector-width primitive.

  Example usage:
    length = make_length()
    length("abcdefg")
    >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

  Returns:
    length: SOp that holds the sequence length at every position.
  """
  # Selecting everything makes the selector width equal the sequence length.
  select_everything = rasp.Select(rasp.tokens, rasp.tokens,
                                  rasp.Comparison.TRUE)
  select_everything = select_everything.named("all_true_selector")
  return rasp.SelectorWidth(select_everything).named("length")


length = make_length()

def make_reverse(sop: rasp.SOp) -> rasp.SOp:
  """Create an SOp that reverses a sequence, using length primitive.

  Relies on the notebook-level `length` SOp.

  Example usage:
    reverse = make_reverse(rasp.tokens)
    reverse("Hello")
    >> ['o', 'l', 'l', 'e', 'H']

  Args:
    sop: an SOp

  Returns:
    reverse : SOp that reverses the input sequence.
  """
  # Mirror index of every position: length - index - 1.
  opp_idx = (length - rasp.indices).named("opp_idx")
  opp_idx = (opp_idx - 1).named("opp_idx-1")
  reverse_selector = rasp.Select(rasp.indices, opp_idx,
                                 rasp.Comparison.EQ).named("reverse_selector")
  # Return the aggregated values; previously the selector itself was returned,
  # so calling the result gave a selection matrix, not the reversed sequence.
  return rasp.Aggregate(reverse_selector, sop).named("reverse")


reverse = make_reverse(rasp.tokens)
reverse("Hello")

## SELECT DISTINCT
SELECT COUNT(DISTINCT Department) AS NumberOfDistinctDepartments
FROM Employees;

In [None]:
def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Sorts `vals` in ascending order of `keys`.

  Only supports unique keys.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
  """
  # The number of keys smaller than a position's key is its sorted position.
  smaller_keys = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  sorted_position = rasp.SelectorWidth(smaller_keys).named("target_pos")
  # Each output position pulls the value whose sorted position equals it.
  pull_by_position = rasp.Select(sorted_position, rasp.indices,
                                 rasp.Comparison.EQ)
  return rasp.Aggregate(pull_by_position, vals).named("sort")


def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
  """Sorts `vals` in ascending order of `keys`; keys may repeat.

  The implementation differs from the RASP paper, as it avoids using
  compositions of selectors to break ties. Instead, it uses the arguments
  max_seq_len and min_key to ensure the keys are unique.

  Note that this approach only works for numerical keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]
    sort([2, 4, 1, 2])
    >> [1, 2, 2, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
    max_seq_len: Maximum sequence length (used to ensure keys are unique)
    min_key: Minimum key value (used to ensure keys are unique)

  Returns:
    Output SOp of sort program.
  """
  def break_ties(x, i):
    # Adds an index-dependent fraction of min_key to each key.
    return x + min_key * i / max_seq_len

  distinct_keys = rasp.SequenceMap(break_ties, keys, rasp.indices)
  return make_sort_unique(vals, distinct_keys)

In [None]:
def map_departments_to_numbers(departments):
    """Encodes each department as an integer in order of first appearance.

    The first distinct department becomes 0, the next new one 1, and so on;
    repeated departments reuse their number.
    """
    code_for = {}
    # setdefault stores the next free number (the current dict size) the first
    # time a department is seen and returns the stored number afterwards.
    return [code_for.setdefault(name, len(code_for)) for name in departments]

# Example usage
# 'Sales' appears twice, so the first and last entries get the same number.
departments = ['Sales', 'Marketing', 'IT', 'HR', 'Sales']
mapped_numbers = map_departments_to_numbers(departments)
print(mapped_numbers)

In [None]:
def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Sorts `vals` in ascending order of `keys`.

  Only supports unique keys.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
  """
  # The number of keys smaller than a position's key is its sorted position.
  smaller_keys = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  sorted_position = rasp.SelectorWidth(smaller_keys).named("target_pos")
  # Each output position pulls the value whose sorted position equals it.
  pull_by_position = rasp.Select(sorted_position, rasp.indices,
                                 rasp.Comparison.EQ)
  return rasp.Aggregate(pull_by_position, vals).named("sort")


def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
  """Sorts `vals` in ascending order of `keys`; keys may repeat.

  The implementation differs from the RASP paper, as it avoids using
  compositions of selectors to break ties. Instead, it uses the arguments
  max_seq_len and min_key to ensure the keys are unique.

  Note that this approach only works for numerical keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]
    sort([2, 4, 1, 2])
    >> [1, 2, 2, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
    max_seq_len: Maximum sequence length (used to ensure keys are unique)
    min_key: Minimum key value (used to ensure keys are unique)

  Returns:
    Output SOp of sort program.
  """
  def break_ties(x, i):
    # Adds an index-dependent fraction of min_key to each key.
    return x + min_key * i / max_seq_len

  distinct_keys = rasp.SequenceMap(break_ties, keys, rasp.indices)
  return make_sort_unique(vals, distinct_keys)

def make_length() -> rasp.SOp:
  """Builds the `length` SOp from the selector-width primitive.

  Example usage:
    length = make_length()
    length("abcdefg")
    >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

  Returns:
    length: SOp that holds the sequence length at every position.
  """
  # Selecting everything makes the selector width equal the sequence length.
  select_everything = rasp.Select(rasp.tokens, rasp.tokens,
                                  rasp.Comparison.TRUE)
  select_everything = select_everything.named("all_true_selector")
  return rasp.SelectorWidth(select_everything).named("length")

def make_reverse(sop: rasp.SOp) -> rasp.SOp:
  """Builds an SOp that reverses `sop`, using the notebook-level `length`.

  Example usage:
    reverse = make_reverse(rasp.tokens)
    reverse("Hello")
    >> ['o', 'l', 'l', 'e', 'H']

  Args:
    sop: an SOp

  Returns:
    reverse : SOp that reverses the input sequence.
  """
  # Mirror index of every position: length - index - 1, built in two steps.
  distance_from_end = (length - rasp.indices).named("opp_idx")
  mirror_index = (distance_from_end - 1).named("opp_idx-1")
  pick_mirror = rasp.Select(rasp.indices, mirror_index, rasp.Comparison.EQ)
  pick_mirror = pick_mirror.named("reverse_selector")
  reversed_sop = rasp.Aggregate(pick_mirror, sop)
  return reversed_sop.named("reverse")

In [None]:
sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
reverse = make_reverse(rasp.tokens)

# Compose at the SOp level: rasp.SequenceMap combines two SOps, not the lists
# obtained by evaluating them. This also avoids rebinding the builtins
# `sorted` and `reversed`, which other cells in this notebook call.
reverse_of_sorted = make_reverse(sort)
sorted_plus_reversed = rasp.SequenceMap(lambda x, y: x + y, sort, reverse_of_sorted)
sorted_plus_reversed([1, 2, 3, 4, 1])

In [None]:
def shift_by_left(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Shifts `sop` left: position q reads the value at position q + offset.

  Positions with no source position get 0.
  """
  def is_offset_ahead(k, q):
    return q == k - offset

  shift_selector = rasp.Select(rasp.indices, rasp.indices, is_offset_ahead)
  shifted = rasp.Aggregate(shift_selector, sop, default=0)
  return shifted.named(f"shift_by({offset})")

def shift_by_right(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Shifts `sop` right: position q reads the value at position q - offset.

  Positions with no source position get 0.
  """
  def is_offset_behind(k, q):
    return q == k + offset

  shift_selector = rasp.Select(rasp.indices, rasp.indices, is_offset_behind)
  shifted = rasp.Aggregate(shift_selector, sop, default=0)
  return shifted.named(f"shift_by({offset})")


def select_distinct(max_seq_len=5, min_key=1) -> rasp.SOp:
    """Sorts the tokens, then sums shifted copies of the sorted sequence.

    The tokens are sorted, the sorted sequence is shifted left by
    `max_seq_len - 1`, and that shifted copy is repeatedly shifted right by
    one and summed, `max_seq_len - 1` times.

    NOTE(review): for an input of exactly `max_seq_len` tokens this appears to
    copy the last sorted element to every position rather than select distinct
    values - confirm the intended result.

    Args:
        max_seq_len: Sequence length used for sorting and for the shift
            offsets (previously hard-coded to 5 regardless of this argument).
        min_key: Minimum key passed to `make_sort` (previously hard-coded to 1).

    Returns:
        The summed SOp.
    """
    # Local name avoids shadowing the builtin `sorted`.
    sorted_tokens = make_sort(rasp.tokens, rasp.tokens,
                              max_seq_len=max_seq_len, min_key=min_key)
    zeros = shift_by_left(max_seq_len, sorted_tokens)
    unique_value = shift_by_left(max_seq_len - 1, sorted_tokens)
    out = rasp.SequenceMap(lambda x, y: x + y, zeros, unique_value)

    shifts = max_seq_len - 1
    current_shift = unique_value
    for _ in range(shifts):
      current_shift = shift_by_right(1, current_shift)
      out = rasp.SequenceMap(lambda x, y: x + y, out, current_shift)

    return out

dist = select_distinct(max_seq_len=5, min_key=1)

In [None]:
from tracr.rasp import rasp

def make_hist() -> rasp.SOp:
  """Builds an SOp counting how often each position's token occurs."""
  # Each position selects every position holding the same token; the width of
  # that selection is the token's frequency.
  equal_tokens = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
  equal_tokens = equal_tokens.named("same_tok")
  return rasp.SelectorWidth(equal_tokens).named("hist")

def make_count_distinct() -> rasp.SOp:
  """Counts the number of distinct elements in the input list.

  A token occurring c times contributes 1/c at each of its c positions, so the
  contributions sum to the number of distinct tokens. The sum is obtained as
  mean(1 / hist) * length.

  The earlier version tested `hist >= 1`, which is true at every position, and
  therefore returned the sequence length.

  NOTE(review): verified by reasoning about RASP evaluation only; whether this
  program compiles to a model has not been checked.

  Returns:
    SOp holding the distinct-token count at every position.
  """
  hist = make_hist().named("hist")
  # hist is at least 1 everywhere (each token matches itself), so no division
  # by zero can occur.
  inverse_frequency = rasp.numerical(
      rasp.Map(lambda count: 1 / count, hist)).named("inverse_frequency")
  select_all = rasp.Select(rasp.tokens, rasp.tokens,
                           rasp.Comparison.TRUE).named("select_all")
  mean_inverse_frequency = rasp.numerical(
      rasp.Aggregate(select_all, inverse_frequency,
                     default=0)).named("mean_inverse_frequency")
  seq_length = rasp.SelectorWidth(select_all).named("seq_length")
  # round() removes floating-point error from the mean * length product.
  count_unique = rasp.SequenceMap(lambda mean, n: round(mean * n),
                                  mean_inverse_frequency,
                                  seq_length).named("count_unique")
  return count_unique

count_distinct = make_count_distinct()

# Example usage:
result = count_distinct([0, 1, 2, 3, 0])
print(result)  # Output: [4, 4, 4, 4, 4]
