In [151]:
from dataclasses import dataclass

@dataclass
class Record:
  pattern: str
  counts: list[int]

In [152]:
def get_input(filename: str):
  records = []
  with open(filename) as f:
    for line in f.read().splitlines():
      pattern, counts_str = line.split()
      records.append(Record(pattern, list(map(int, counts_str.split(',')))))
  return records

In [162]:
from collections import deque
from functools import cache

def solve_part1(filename: str):
  records = get_input(filename)
  return sum([number_of_record_variations_part1(record) for record in records])

def solve_part2(filename: str):
  records = get_input(filename)
  answers = []
  for record in records:
    print(record)
    answers.append(number_of_record_variations_part2(record))
  print(answers)
  return sum(answers)

def number_of_record_variations_part1(record: Record):
  res = 0
  for possible_string in gen_possible_string(record.pattern):
    if is_string_valid(possible_string, record.counts):
      res += 1
  return res

def is_string_valid(str: str, counts: list[int]):
  return counts == list(map(lambda v: len(v), filter(lambda v: len(v) != 0, str.split('.'))))

def gen_possible_string(pattern: str):
  if '?' not in pattern:
    yield pattern
    return
  i = pattern.index('?')
  for possible_pattern in gen_possible_string(pattern[:i] + '.' + pattern[i + 1:]):
    yield possible_pattern
  for possible_pattern in gen_possible_string(pattern[:i] + '#' + pattern[i + 1:]):
    yield possible_pattern

def number_of_record_variations_part2(record: Record):
  pattern = '?'.join([record.pattern] * 5)
  counts = record.counts  * 5
  important_fields = list(filter(lambda v: len(v) > 0, pattern.split('.')))
  return process_fields(tuple(important_fields), tuple(counts))

@cache
def process_fields(fields, counts):
  if len(fields) == 0 and len(counts) == 0:
    return 1
  if len(fields) == 0:
    return 0
  if len(counts) == 0:
    if '#' in ''.join(fields):
      return 0
    return 1
  res = 0
  if (fields[0][0] == '#' and len(fields[0]) >= counts[0]) or len(fields[0]) >= counts[0]:
    res += s_with_hash(fields, counts)
  if fields[0][0] == '?':
    left = fields[0][1:]
    res += process_fields(((left,) if left != '' else tuple()) + fields[1:], counts)
  return res
  
def s_with_hash(fields: list[str], counts: list[int]):
  field = fields[0]
  left = field[counts[0]:]
  if len(left) > 0 and left[0] == '#':
    return 0
  left = left[1:]
  new_fields = fields[1:] if left == '' else (left,) + fields[1:]
  return process_fields(new_fields, counts[1:])

In [165]:
# print(solve_part1('test-input.txt'))
# print(solve_part1('input.txt'))

# print(number_of_record_variations_part2(Record(pattern='?????.?#??.#????#?', counts=[1, 1, 1, 3, 2, 4])))

print(solve_part2('test-input.txt'))
print(solve_part2('input.txt'))

Record(pattern='???.###', counts=[1, 1, 3])
Record(pattern='.??..??...?##.', counts=[1, 1, 3])
Record(pattern='?#?#?#?#?#?#?#?', counts=[1, 3, 1, 6])
Record(pattern='????.#...#...', counts=[4, 1, 1])
Record(pattern='????.######..#####.', counts=[1, 6, 5])
Record(pattern='?###????????', counts=[3, 2, 1])
[1, 16384, 1, 16, 2500, 506250]
525152
