# Challenge 1: count how many times the digits 1, 4, 7 and 8 appear in the output values

In [1]:
# Read the puzzle input, one display entry per line. A trailing newline leaves an
# empty final entry, which the loops below filter out by length.
with open("input.txt", "r") as input_file:
    lines = input_file.read().split("\n")

In [20]:
class Display:
    """One entry of the seven-segment puzzle.

    An input line has the form ``"<ten patterns> | <output patterns>"``. Each
    pattern is a string of segment letters ``a``-``g`` whose wiring is
    scrambled per display. On construction the ten patterns are decoded into
    ``self.mapper``, a dict from segment mask (see ``letters_to_hex``) to the
    digit 0-9 it represents. The output helpers use this dict.
    """

    # Bit assigned to each segment letter: "a" -> 0x40, "b" -> 0x20, ... "g" -> 0x01.
    SEGMENT_BITS = {letter: 0x40 >> i for i, letter in enumerate("abcdefg")}

    def __init__(self, input_line) -> None:
        self.input, self.outputs = self.parse(input_line)
        self.mapper = None  # filled in by decode_lines()
        self.decode_lines()

    def parse(self, input_line):
        """Split an input line into its signal patterns and its output patterns.

        Splitting on arbitrary whitespace means repeated spaces, tabs or a
        trailing carriage return never produce empty patterns. An empty
        pattern would otherwise be mistaken for a digit during decoding.
        """
        inputs, outputs = input_line.split("|")
        return inputs.split(), outputs.split()

    def letters_to_hex(self, letters):
        """Return the 7-bit mask of the segments named in ``letters``.

        ``"a"`` maps to 0x40 down to ``"g"`` at 0x01, so set relations between
        patterns become bitwise operations on masks. Characters outside
        ``a``-``g`` are ignored; order and repetition do not matter.
        """
        mask = 0x00
        for letter in letters:
            mask |= self.SEGMENT_BITS.get(letter, 0)
        return mask

    def decode_lines(self):
        """Deduce the digit of every input pattern and store it in ``self.mapper``.

        Each step uses only digits identified in earlier steps:
          1, 7, 4, 8 -- the only patterns with 2, 3, 4 and 7 segments.
          6 -- the six-segment pattern that does not contain all of 1.
          3 -- the five-segment pattern that contains all of 1.
          0 -- the remaining pattern containing both segments of 8 ^ 3.
          9 -- the last remaining six-segment pattern.
          2 -- the remaining pattern containing segment 8 ^ 9; 5 is the last one.

        Raises:
            ValueError: if some step does not find exactly one candidate pattern.
        """
        remaining = [(len(pattern), self.letters_to_hex(pattern)) for pattern in self.input]
        combinations = {}

        def take(digit, predicate):
            # Remove the single (length, mask) candidate matching predicate and
            # file it under digit; zero or several matches means bad input.
            matches = [candidate for candidate in remaining if predicate(*candidate)]
            if len(matches) != 1:
                raise ValueError(
                    f"cannot identify digit {digit}: "
                    f"{len(matches)} candidate patterns in {self.input}"
                )
            remaining.remove(matches[0])
            combinations[digit] = matches[0][1]
            return matches[0][1]

        # 1, 7, 4 and 8 are the only digits with 2, 3, 4 and 7 segments respectively.
        for digit, length in ((1, 2), (7, 3), (4, 4), (8, 7)):
            take(digit, lambda n, _mask, length=length: n == length)
        one, eight = combinations[1], combinations[8]

        # Of the six-segment digits (0, 6, 9), only 6 does not contain all of 1.
        take(6, lambda n, mask: n == 6 and (mask & one) != one)
        # Of the five-segment digits (2, 3, 5), only 3 contains all of 1.
        three = take(3, lambda n, mask: n == 5 and (mask & one) == one)
        # 8 ^ 3 leaves segments "b" and "e"; of the remaining 0, 2, 5, 9 only 0 has both.
        be = eight ^ three
        take(0, lambda n, mask: (mask & be) == be)
        # 9 is now the only six-segment pattern left.
        nine = take(9, lambda n, mask: n == 6)
        # 8 ^ 9 is segment "e": 2 lights it, 5 does not.
        e = eight ^ nine
        take(2, lambda n, mask: (mask & e) == e)
        take(5, lambda n, mask: True)

        self.mapper = {mask: digit for digit, mask in combinations.items()}

    def _count_outputs_with_length(self, length):
        """Return how many output patterns light exactly ``length`` segments."""
        return sum(1 for pattern in self.outputs if len(pattern) == length)

    def count_1s(self):
        """Count outputs showing 1 (the only digit with 2 segments)."""
        return self._count_outputs_with_length(2)

    def count_4s(self):
        """Count outputs showing 4 (the only digit with 4 segments)."""
        return self._count_outputs_with_length(4)

    def count_7s(self):
        """Count outputs showing 7 (the only digit with 3 segments)."""
        return self._count_outputs_with_length(3)

    def count_8s(self):
        """Count outputs showing 8 (the only digit with 7 segments)."""
        return self._count_outputs_with_length(7)

    def _decoded_outputs(self):
        """Return the digit shown by each output pattern, in order.

        Raises KeyError if an output pattern is not one of the decoded inputs.
        """
        return [self.mapper[self.letters_to_hex(pattern)] for pattern in self.outputs]

    def output_sum(self):
        """Return the sum of the individual output digits."""
        return sum(self._decoded_outputs())

    def output_to_num(self):
        """Return the output digits read as one decimal number (5, 3, 5, 3 -> 5353)."""
        return int("".join(str(digit) for digit in self._decoded_outputs()))

# Part 1: digits 1, 4, 7 and 8 are recognised in the outputs by segment count alone.
cnt = 0
for line in (text for text in lines if len(text) > 10):  # skip blank/short lines
    display = Display(line)
    cnt += sum(
        (display.count_1s(), display.count_4s(), display.count_7s(), display.count_8s())
    )
print(cnt)


303


# Challenge 2: decode every four-digit output value and sum them

In [21]:
# Part 2: decode each display's output value and add them all up.
cnt = sum(
    Display(line).output_to_num()
    for line in lines
    if len(line) > 10  # skip blank/short lines such as the trailing empty one
)
print(cnt)

961734
