In [1]:
import tensorflow as tf
# Fixes bad convolution
# Turn on on-demand memory growth for every GPU TensorFlow reports, instead of
# letting it reserve memory up front. This has to run before any op touches
# the GPUs.
# NOTE(review): "bad convolution" presumably refers to convolution/cuDNN
# start-up failures seen when all GPU memory is pre-allocated -- confirm.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
def gen_matrix(ingredients):
    """Build a row-by-token count matrix from a padded string matrix.

    Args:
        ingredients: 2-D string tensor with one row per food. Rows are padded
            with ``''``; those padding entries are dropped from the result.

    Returns:
        An int32 tensor of shape (rows, number of distinct non-blank tokens).
        Entry ``[r, c]`` is the number of times token ``c`` occurs in row
        ``r``. Columns follow the order in which tokens first appear.
    """
    flat_tokens = tf.reshape(ingredients, (-1,))
    # The indices will help map to columns in the matrix
    unique_ingredients, ingredients_indices = tf.unique(flat_tokens)
    total_ingredients = tf.shape(unique_ingredients)[0]

    # One-hot every token at once -> (rows, tokens_per_row, total_ingredients),
    # then collapse the per-row token axis to get the counts per row.
    indices_per_food = tf.reshape(ingredients_indices, tf.shape(ingredients))
    food_to_ingredient_map = tf.reduce_sum(
        tf.one_hot(indices_per_food, total_ingredients, dtype=tf.int32),
        axis=1,
    )

    # Get rid of blank ingredients. The mask is simply all-True when the input
    # contains no padding, so no explicit lookup of the blank column is needed.
    return tf.boolean_mask(food_to_ingredient_map, unique_ingredients != '', axis=1)

@tf.function
def solve(s):
    """Sum the occurrences of ingredients that are not a candidate for any allergen.

    Args:
        s: scalar string tensor holding newline-separated lines of the form
            ``"ing1 ing2 ... (contains allergen1, allergen2)"``.

    Returns:
        Scalar int32 tensor: the total number of appearances, across all
        lines, of ingredients that are not present in every line listing
        some allergen.
    """
    s = tf.strings.split(s, '\n')
    s = tf.strings.split(s, ' (')
    # Raw string: '\)' is not a valid Python escape, so a plain literal only
    # works by accident and triggers a warning on newer interpreters.
    s = tf.strings.regex_replace(s, r'contains |,|\)', '')
    s = tf.strings.split(s, ' ').to_tensor()

    # After the splits, index 0 on axis 1 holds the ingredient words of a line
    # and index 1 holds its allergen words (both padded with '').
    ingredients = s[:,0]
    allergens = s[:,1]

    food_to_ingredient_map = gen_matrix(ingredients)
    food_to_allergen_map = gen_matrix(allergens)

    allergen_count = tf.shape(food_to_allergen_map)[1]

    # Row i: product over the foods listing allergen i of their ingredient
    # rows, i.e. non-zero only for ingredients present in all of those foods.
    common_ingredients_stack = tf.map_fn(
        lambda i: tf.reduce_prod(food_to_ingredient_map[food_to_allergen_map[:,i]==1], axis=0),
        tf.range(allergen_count)
    )

    # Ingredients whose column is zero for every allergen cannot carry any of them.
    never_candidate = tf.reduce_sum(common_ingredients_stack, axis=0) == 0
    return tf.reduce_sum(
        tf.boolean_mask(food_to_ingredient_map, never_candidate, axis=1)
    )
        
    
# Time the `solve` defined in this cell on the contents of 'day21.txt' (path is
# relative to the kernel's working directory). %timeit only reports timings,
# so the value returned by solve is not displayed here.
%timeit solve(tf.io.read_file('day21.txt'))

3.25 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
def gen_matrix(ingredients):
    """Build a row-by-token count matrix plus the token name for each column.

    Note: this redefinition shadows the earlier ``gen_matrix`` in this
    notebook, which returned only the matrix.

    Args:
        ingredients: 2-D string tensor with one row per food. Rows are padded
            with ``''``; those padding entries are dropped from the result.

    Returns:
        A tuple ``(matrix, names)``:
          * ``matrix`` -- int32 tensor of shape (rows, number of distinct
            non-blank tokens); entry ``[r, c]`` is how many times token ``c``
            occurs in row ``r``.
          * ``names`` -- 1-D string tensor with the token for each column, in
            order of first appearance.
    """
    flat_tokens = tf.reshape(ingredients, (-1,))
    # The indices will help map to columns in the matrix
    unique_ingredients, ingredients_indices = tf.unique(flat_tokens)
    total_ingredients = tf.shape(unique_ingredients)[0]

    # One-hot every token at once -> (rows, tokens_per_row, total_ingredients),
    # then collapse the per-row token axis to get the counts per row.
    indices_per_food = tf.reshape(ingredients_indices, tf.shape(ingredients))
    food_to_ingredient_map = tf.reduce_sum(
        tf.one_hot(indices_per_food, total_ingredients, dtype=tf.int32),
        axis=1,
    )

    # Get rid of blank ingredients; the same mask selects matrix columns and
    # names so the two stay aligned. It is all-True when there is no padding.
    is_real_token = unique_ingredients != ''
    return (
        tf.boolean_mask(food_to_ingredient_map, is_real_token, axis=1),
        tf.boolean_mask(unique_ingredients, is_real_token),
    )

def sort_matches(matches):
    """Join the second column of ``matches``, ordered by the first column.

    Args:
        matches: indexable 2-D collection with a ``shape`` attribute. Each
            ``matches[i]`` is a pair of cells whose ``.numpy()`` yields UTF-8
            encoded bytes (here: allergen name, ingredient name).

    Returns:
        A ``str`` with the decoded second-column values separated by commas,
        sorted by the decoded first-column value of their row.
    """
    def _text(cell):
        # Cells hold byte strings; decode them for comparison and output.
        return cell.numpy().decode("utf-8")

    row_order = list(range(matches.shape[0]))
    row_order.sort(key=lambda row: _text(matches[row][0]))
    return ','.join(_text(matches[row][1]) for row in row_order)
    
@tf.function
def solve(s):
    """Pair each allergen with its ingredient and list the ingredients by allergen name.

    Args:
        s: scalar string tensor holding newline-separated lines of the form
            ``"ing1 ing2 ... (contains allergen1, allergen2)"``.

    Returns:
        The result of ``tf.py_function`` with ``Tout=[tf.string]``: a
        one-element list whose string tensor is the comma-joined ingredient
        names, ordered by the name of the allergen each one was matched to.
    """
    s = tf.strings.split(s, '\n')
    s = tf.strings.split(s, ' (')
    # Raw string: '\)' is not a valid Python escape, so a plain literal only
    # works by accident and triggers a warning on newer interpreters.
    s = tf.strings.regex_replace(s, r'contains |,|\)', '')
    s = tf.strings.split(s, ' ').to_tensor()

    # After the splits, index 0 on axis 1 holds the ingredient words of a line
    # and index 1 holds its allergen words (both padded with '').
    ingredients = s[:,0]
    allergens = s[:,1]

    food_to_ingredient_map, ingredient_names = gen_matrix(ingredients)
    food_to_allergen_map, allergen_names = gen_matrix(allergens)

    allergen_count = tf.shape(food_to_allergen_map)[1]
    # One flag per ingredient column; cleared once the ingredient is assigned.
    unmatched_ingredients = tf.fill(tf.shape(food_to_ingredient_map)[1:], True)

    # Collects [allergen, ingredient] pairs in the order they are resolved.
    ta = tf.TensorArray(tf.string, size=allergen_count)
    ta_idx = 0

    # Keep sweeping over the allergens until a full sweep resolves nothing new.
    n_matched = tf.constant(-1, tf.int64)
    while n_matched != tf.math.count_nonzero(~unmatched_ingredients):
        n_matched = tf.math.count_nonzero(~unmatched_ingredients)
        # Loop over the different allergens
        for i in tf.range(allergen_count):
            # Non-zero only for ingredients present in every food listing allergen i.
            common_ingredients = tf.reduce_prod(food_to_ingredient_map[food_to_allergen_map[:,i]==1], axis=0)

            # If there's only a single common ingredient left for the allergen,
            # it must be the one carrying it. An allergen cannot be resolved
            # twice: once its ingredient is cleared the candidate sum is 0.
            if tf.reduce_sum(common_ingredients[unmatched_ingredients]) == 1:
                common_ingredient_index = tf.where((common_ingredients==1)&unmatched_ingredients)[0]

                ta = ta.write(ta_idx, [allergen_names[i], ingredient_names[common_ingredient_index[0]]])
                ta_idx += 1

                unmatched_ingredients = tf.tensor_scatter_nd_update(
                    unmatched_ingredients,
                    [common_ingredient_index],
                    [False]
                )

    # Sorting and joining Python strings is done eagerly in sort_matches.
    return tf.py_function(
        sort_matches,
        [ta.stack()],
        [tf.string]
    )
    
# Time the `solve` defined in this cell on the contents of 'day21.txt' (path is
# relative to the kernel's working directory). %timeit only reports timings,
# so the value returned by solve is not displayed here.
%timeit solve(tf.io.read_file('day21.txt'))

11.2 ms ± 812 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
