In [6]:
from itertools import cycle, islice
from collections import defaultdict
from dataclasses import dataclass

# Rock shapes, in the order they fall. Each shape is a tuple of rows listed
# bottom-up (row i sits i rows above the shape's base); each row is the tuple of
# occupied column offsets relative to the shape's left edge.
rocks = [
    ((0, 1, 2, 3),),           # horizontal bar
    ((1,), (0, 1, 2), (1,)),   # plus sign
    ((0, 1, 2), (2,), (2,)),   # mirrored L: full base, column on the right
    ((0,), (0,), (0,), (0,)),  # vertical bar
    ((0, 1), (0, 1)),          # 2x2 square
]

# Jet pattern: a line of '<' / '>' characters. Strip surrounding whitespace so a
# trailing newline is not read as an extra jet (any character other than '<' is
# treated as a push to the right by the simulation). `with` closes the file.
with open('puzzle.data', 'r') as puzzle_file:
    jets = puzzle_file.read().strip()


@dataclass
class Position:
    """Mutable location of a falling rock, anchored at its bottom-left cell."""

    x: int  # column added to every offset of the rock's rows
    y: int  # row index of the rock's bottom row


def check(x: int, y: int, rock, rows, width: int = 7) -> bool:
    """Return True if `rock` fits with its bottom-left corner at (x, y).

    Args:
        x: column added to every offset in the rock's rows.
        y: row index of the rock's bottom row.
        rock: tuple of rows listed bottom-up, each a tuple of column offsets.
        rows: mapping from row index to the set of occupied columns.
        width: chamber width; valid columns are 0 .. width - 1.

    Returns:
        False if any cell of the rock lies outside the chamber or overlaps an
        occupied cell in `rows`; True otherwise.
    """
    for line_y, rock_line in enumerate(rock):
        cells = {x + offset for offset in rock_line}
        # Reject any cell outside the chamber, not only one exactly past a wall.
        if any(c < 0 or c >= width for c in cells):
            return False
        # .get (not []) so probing a defaultdict does not insert empty rows.
        if not cells.isdisjoint(rows.get(y + line_y, ())):
            return False

    return True


def simulate(iterations, jet_pattern=None):
    """Drop `iterations` rocks into the 7-wide chamber and return the tower height.

    Rock shapes come from the module-level `rocks` list, used in order and
    repeated. Jets come from `jet_pattern`, or the module-level `jets` when None.

    To handle huge counts, the state after each rock is recorded: the rock
    shape, the jet index and a snapshot of the top 100 rows. When a state
    repeats, as many whole cycles as still fit into the remaining rocks are
    skipped. Their height is added to the result at the end.

    Args:
        iterations: number of rocks to drop.
        jet_pattern: sequence of '<' / '>' characters; any character other
            than '<' pushes right. Defaults to the module-level `jets`.

    Returns:
        The tower height: highest occupied row index + 1, including the
        height of any skipped cycles.
    """
    if jet_pattern is None:
        jet_pattern = jets
    patterns = {}
    rows = defaultdict(set)
    rows[-1] |= set(range(7))  # dummy full row at the bottom acting as the floor
    jet_index = 0
    iteration_index = 0
    skipped_height = 0  # height contributed by cycles skipped via detection
    while iteration_index < iterations:
        rock = rocks[iteration_index % len(rocks)]
        pos = Position(2, max(rows) + 4)

        while pos.y >= 0:
            # Jet push: applied only if the shifted rock still fits.
            x = pos.x - 1 if jet_pattern[jet_index] == '<' else pos.x + 1
            jet_index = (jet_index + 1) % len(jet_pattern)
            if check(x, pos.y, rock, rows):
                pos.x = x

            # Fall by one row; stop when blocked.
            y = pos.y - 1
            if not check(pos.x, y, rock, rows):
                break

            pos.y = y

        for line_y, rock_line in enumerate(rock):
            row_y = pos.y + line_y
            rows[row_y] |= {p + pos.x for p in rock_line}
            if len(rows[row_y]) == 7:
                # A full row blocks everything below it, so those rows can go.
                for r in [r for r in rows if r < row_y]:
                    del rows[r]

        iteration_index += 1

        top = max(rows)
        # NOTE(review): the 100-row snapshot assumes no rock settles deeper
        # than 100 rows below the top -- heuristic, not proven.
        snapshot = tuple(
            tuple(sorted(rows.get(i, ()))) for i in range(top, top - 100, -1)
        )
        pattern = (rock, jet_index, snapshot)

        if pattern in patterns:
            iteration_start, height = patterns[pattern]
            cycle_length = iteration_index - iteration_start
            # Skip only whole cycles that fit in the remaining rocks, so the
            # simulation never overshoots `iterations`.
            cycles = (iterations - iteration_index) // cycle_length
            skipped_height += cycles * (top - height)
            iteration_index += cycles * cycle_length
            patterns.clear()
        else:
            patterns[pattern] = (iteration_index, top)

    return max(rows) + 1 + skipped_height

In [7]:
# Part 1: tower height after 2022 rocks have fallen.
simulate(2022)

3191

In [8]:
# Part 2: tower height after one trillion rocks; feasible only because
# simulate() detects a repeating state and skips whole cycles.
simulate(1000000000000)

1572093023267