<a href="https://colab.research.google.com/github/yufengg/advent-of-code-2024/blob/master/AoC_2024_day_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# prompt: python code download advent of code file

import requests
import os
from datetime import datetime

def download_aoc_input(year, day, session_cookie):
  """Downloads the input file for a specific Advent of Code day.

  Args:
    year: The year of the challenge (e.g., 2023).
    day: The day of the challenge (e.g., 1).
    session_cookie: Your session cookie from the Advent of Code website.
  """

  url = f"https://adventofcode.com/{year}/day/{day}/input"
  headers = {
      "User-Agent": "github.com/yourusername/yourrepo (or your email)",  # Replace with your info
      "Cookie": f"session={session_cookie}"
  }

  try:
    response = requests.get(url, headers=headers)
    response.raise_for_status()  # Raise an exception for bad status codes

    # Create a directory for the year if it doesn't exist
    os.makedirs(str(year), exist_ok=True)

    filename = os.path.join(str(year), f"day{day}.txt")
    with open(filename, "w") as file:
      file.write(response.text)
    print(f"Successfully downloaded input for year {year}, day {day} to {filename}")
    return filename

  except requests.exceptions.RequestException as e:
    print(f"Error downloading input: {e}")
  except Exception as e:
    print(f"An unexpected error occurred: {e}")


# Example usage:  Replace with your actual values
year = 2024  #@param {type:"integer"}
day = 1 #@param {type:"integer"}
session_cookie = "53616c7465645f5fd6809f86933785dcec3c4f79b521758a796b0808898e32ff9a09d06dd15282b2a68340b7bf970ab66894fa0bf9271caa521bc5d5fc8df3b9" #@param {type:"string"}

input_filename = download_aoc_input(year, day, session_cookie)

Successfully downloaded input for year 2024, day 1 to 2024/day1.txt


In [2]:
import jax

In [6]:
import jax.numpy as jnp

In [7]:

def read_input(filename, parser=None):
  with open(filename, 'r') as file:
    lines = file.read().splitlines()
    if parser:
      lines = list(map(parser, lines))
  return lines


In [8]:
list(map(int, '39472   15292'.split()))

[39472, 15292]

In [9]:
input_filename = f'{year}/day{day}.txt'
def day1_parser(line):
  return list(map(int, line.split()))

input_lines = read_input(input_filename, day1_parser)
input_matrix = jnp.array(input_lines)

In [20]:
def abs_diff(input_matrix):
  sorted_matrix = jnp.sort(input_matrix, axis=0)
  diffs = sorted_matrix[:,0] - sorted_matrix[:,1]
  return jnp.abs(diffs).sum()


In [21]:
%timeit abs_diff(input_matrix).block_until_ready()

664 µs ± 28.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [22]:
jit_abs_diff = jax.jit(abs_diff)
_ = jit_abs_diff(input_matrix)
%timeit jit_abs_diff(input_matrix).block_until_ready()

459 µs ± 62.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Part 2

Compute count of list 2 as a dict: `{number: count}`
Go through list 1, referencing the dict, use `get(key, default=0)`
Increment the total by number * count for each element.

In [64]:
# @jax.jit
def similarity_score(input_matrix):
  # count up occurrences
  # sorted_matrix = jnp.sort(input_matrix, axis=0)
  list2 = input_matrix[:,1]
  val, cts = jnp.unique(list2, return_counts=True, size=1000)
  list2_dict = dict(zip(val.tolist(), cts.tolist()))

  # def get_score(x):
  #   return x * list2_dict.get(x, 0)
  # v_get_score = jax.vmap(get_score)

  scores = [x * list2_dict.get(x, 0) for x in input_matrix[:,0]]
  # scores = v_get_score(input_matrix[:,0])
  return jnp.sum(scores)

In [65]:
similarity_score(input_matrix)

TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'

In [69]:
list2 = input_matrix[:,1]
val, cts = jnp.unique(list2, return_counts=True)
list2_dict = dict(zip(val.tolist(), cts.tolist()))

scores = jnp.array([x * list2_dict.get(x, 0) for x in input_matrix[:,0].tolist()])


jnp.sum(scores)

Array(27732508, dtype=int32)

In [57]:
list2_dict.get(10318,0)

1

In [53]:
10318 in list2

True