In [93]:
class Group:
    """One group of identical units belonging to an army.

    Attributes:
        id: Identifier used to refer to the group in target maps.
        army: Name of the owning army ('immune' or 'infection' in this file).
        n_units: Number of units currently in the group.
        hit_points: Hit points of each individual unit.
        immune_to: Damage types that deal no damage to this group.
        weak_to: Damage types that deal double damage to this group.
        damage: Attack damage dealt by each unit.
        damage_type: Damage type of this group's attack.
        initiative: Ordering value; higher initiative acts first.
        effective_power: Cached ``n_units * damage``. It is a plain attribute,
            so any code that changes ``n_units`` or ``damage`` must
            recompute it.
    """

    def __init__(self, id, army, n_units, hit_points, immune_to, weak_to, damage, damage_type, initiative):
        self.id = id
        self.army = army
        self.n_units = n_units
        self.hit_points = hit_points
        self.immune_to = immune_to
        self.weak_to = weak_to
        self.damage = damage
        self.damage_type = damage_type
        self.initiative = initiative
        self.effective_power = self.n_units * self.damage

    def __repr__(self):
        """Return every field as one comma-separated line, for debugging."""
        # Join outside the f-string: reusing the enclosing quote character
        # inside a replacement field is only legal from Python 3.12 on.
        immune = ', '.join(self.immune_to)
        weak = ', '.join(self.weak_to)
        return (
            f'{self.id}: {self.army}, {self.n_units}, {self.hit_points}, '
            f'{immune}, {weak}, {self.damage}, {self.damage_type}, '
            f'{self.initiative}, {self.effective_power}'
        )


def calc_damage(attacker, target):
    """Return the damage `attacker` would deal to `target`.

    The result is 0 when the target is immune to the attacker's damage
    type, twice the attacker's effective power when the target is weak to
    it, and the plain effective power otherwise. Immunity is checked first,
    so it wins if a type appears in both lists.
    """
    kind = attacker.damage_type
    if kind in target.immune_to:
        return 0
    power = attacker.effective_power
    return power * 2 if kind in target.weak_to else power

def choose_targets(groups):
    """Pick at most one enemy target for every group.

    Groups choose in descending order of (effective power, initiative).
    Each chooser takes the not-yet-claimed enemy group with the highest
    (damage it would receive, effective power, initiative); a chooser whose
    best option would take no damage selects nothing.

    Returns:
        A dict mapping attacker id -> target id. Groups without a target
        have no entry.
    """
    selection_order = sorted(
        groups,
        key=lambda g: (g.effective_power, g.initiative),
        reverse=True,
    )
    assignments = {}
    for chooser in selection_order:
        enemy_army = 'infection' if chooser.army == 'immune' else 'immune'
        already_claimed = assignments.values()
        candidates = [
            g for g in groups
            if g.army == enemy_army and g.id not in already_claimed
        ]
        if not candidates:
            continue
        # max() returns the first maximal element, which is the same group a
        # stable descending sort would place at index 0.
        best = max(
            candidates,
            key=lambda g: (calc_damage(chooser, g), g.effective_power, g.initiative),
        )
        if calc_damage(chooser, best) > 0:
            assignments[chooser.id] = best.id
    return assignments

def attack(groups, target_map):
    """Resolve one round of attacks and return the surviving groups.

    Groups that have an entry in `target_map` act in descending initiative
    order. Each attack removes ``damage // hit_points`` whole units from
    the target (never dropping below zero) and refreshes the target's
    cached effective power. Groups are mutated in place; the returned list
    contains only those that still have units.
    """
    by_initiative = sorted(groups, key=lambda g: g.initiative, reverse=True)
    for striker in (g for g in by_initiative if g.id in target_map):
        wanted_id = target_map[striker.id]
        for victim in (g for g in groups if g.id == wanted_id):
            dealt = calc_damage(striker, victim)
            if dealt <= 0:
                # Also covers strikers wiped out earlier this round: their
                # effective power is already 0.
                continue
            victim.n_units = max(victim.n_units - dealt // victim.hit_points, 0)
            victim.effective_power = victim.n_units * victim.damage
    return [g for g in groups if g.n_units > 0]

def init_groups(boost=0, test_case=False):
    """Build a fresh list of `Group` objects for one simulation run.

    Args:
        boost: Amount added to the per-unit damage of every 'immune' group.
        test_case: When true, return the small four-group scenario instead
            of the full twenty-group one.

    Returns:
        A new list of groups whose ids are their positions in the roster.
    """
    # Each row: (army, n_units, hit_points, immune_to, weak_to,
    #            damage, damage_type, initiative)
    if test_case:
        roster = [
            ('immune', 17, 5390, [], ['radiation', 'bludgeoning'], 4507, 'fire', 2),
            ('immune', 989, 1274, ['fire'], ['bludgeoning', 'slashing'], 25, 'slashing', 3),
            ('infection', 801, 4706, [], ['radiation'], 116, 'bludgeoning', 1),
            ('infection', 4485, 2961, ['radiation'], ['fire', 'cold'], 12, 'slashing', 4),
        ]
    else:
        roster = [
            ('immune', 4082, 2910, [], [], 5, 'cold', 15),
            ('immune', 2820, 9661, ['slashing'], [], 27, 'cold', 8),
            ('immune', 4004, 4885, [], ['slashing'], 10, 'bludgeoning', 13),
            ('immune', 480, 7219, [], ['bludgeoning'], 134, 'radiation', 18),
            ('immune', 8734, 4421, ['bludgeoning'], [], 5, 'slashing', 14),
            ('immune', 516, 2410, [], ['slashing'], 46, 'bludgeoning', 5),
            ('immune', 2437, 11267, [], ['slashing'], 38, 'fire', 17),
            ('immune', 1815, 7239, [], ['cold'], 33, 'slashing', 10),
            ('immune', 4941, 10117, ['bludgeoning'], [], 20, 'fire', 9),
            ('immune', 617, 7816, [], ['bludgeoning', 'slashing'], 120, 'bludgeoning', 4),
            ('infection', 2877, 20620, [], ['radiation', 'bludgeoning'], 13, 'cold', 11),
            ('infection', 1164, 51797, ['fire'], [], 63, 'fire', 7),
            ('infection', 160, 31039, ['bludgeoning'], ['radiation'], 317, 'bludgeoning', 2),
            ('infection', 779, 24870, ['radiation', 'bludgeoning'], ['slashing'], 59, 'slashing', 12),
            ('infection', 1461, 28000, ['radiation'], ['bludgeoning'], 37, 'slashing', 16),
            ('infection', 1060, 48827, [], [], 73, 'slashing', 3),
            ('infection', 4422, 38291, [], [], 14, 'slashing', 1),
            ('infection', 4111, 14339, ['fire', 'bludgeoning', 'cold'], [], 6, 'radiation', 20),
            ('infection', 4040, 49799, ['bludgeoning', 'cold'], ['slashing', 'fire'], 24, 'fire', 19),
            ('infection', 2198, 41195, [], ['radiation'], 36, 'slashing', 6),
        ]
    groups = [Group(index, *fields) for index, fields in enumerate(roster)]
    for boosted in (g for g in groups if g.army == 'immune'):
        boosted.damage += boost
        # effective_power is cached on the group, so refresh it after the boost.
        boosted.effective_power = boosted.n_units * boosted.damage
    return groups

In [95]:
# Fight the unboosted battle until one army is wiped out.
groups = init_groups(test_case=False)
while True:
    units_before = sum(g.n_units for g in groups)
    target_map = choose_targets(groups)
    groups = attack(groups, target_map)
    if not (any(g.army == 'immune' for g in groups) and any(g.army == 'infection' for g in groups)):
        break
    if sum(g.n_units for g in groups) == units_before:
        # Nobody lost a unit, so every later round would be identical:
        # stop instead of looping forever on a stalemate.
        break

In [96]:
# Total units left across all surviving groups.
sum(g.n_units for g in groups)

19974

In [112]:
# Find the smallest damage boost with which the immune army wins outright.
for boost in range(10000):
    if boost % 100 == 0:
        print(f'Boost = {boost}')
    groups = init_groups(boost=boost)
    win = False
    while True:
        units_before = sum(g.n_units for g in groups)
        target_map = choose_targets(groups)
        groups = attack(groups, target_map)
        if not any(g.army == 'immune' for g in groups):
            break
        if not any(g.army == 'infection' for g in groups):
            print(f'Win for boost = {boost}!')
            win = True
            break
        if sum(g.n_units for g in groups) == units_before:
            # Stalemate: no unit died this round, so the state can never
            # change again. Count it as "not a win" and try the next boost.
            # This replaces the previous hard-coded skip of boost 42.
            break
    if win:
        break

Boost = 0
Win for boost = 43!


In [111]:
# Total units left across all surviving groups.
sum(g.n_units for g in groups)

4606