# Example of using SpecialtyInsurance Simulator


In [None]:
### 1. Set Simulation Parameters

import os
from logger.arguments import get_arguments

# Load the default argument groups and the random seed from logger.arguments
(
    sim_args,
    manager_args,
    broker_args,
    syndicate_args,
    reinsurancefirm_args,
    shareholder_args,
    risk_args,
    seed,
) = get_arguments()

# Override the defaults used by this example
sim_args["max_time"] = 1000  # Simulation time span, in days

# Quote competition handled by the manager
manager_args["lead_top_k"] = 3  # Number of syndicates competing for the lead quote
manager_args["follow_top_k"] = 2  # Number of syndicates following the lead strategy

# Population of the insurance market
broker_args["num_brokers"] = 100  # Number of brokers
syndicate_args["num_syndicates"] = 20  # Number of syndicates
shareholder_args["num_shareholders"] = 1  # Number of shareholders

# Risks offered to the market
risk_args["num_risks"] = 10_000  # Number of risks
risk_args["num_categories"] = 4  # Number of risk categories

# Number of risk models loaded to all syndicates
num_risk_models = 1

# No reinsurance mechanism is included at this stage
with_reinsurance = False


In [None]:
### 2. Generate Catastrophes

from environment.risk_generator import RiskGenerator

# Build the risk generator from the parameters chosen above, then let it
# produce the catastrophes, the broker risks and the risk model configurations
risk_generator = RiskGenerator(num_risk_models, sim_args, broker_args, risk_args, seed)
(
    catastrophes,
    catastrophe_damage,
    broker_risks,
    fair_market_premium,
    risk_model_configs,
) = risk_generator.generate_risks()

# Show what was generated
print(catastrophes)
print(broker_risks)


In [None]:
### 3. Generate Insurance Market

from environment.market_generator import MarketGenerator
from logger import logger

# Build the market generator, then create the lists of brokers, syndicates,
# reinsurance firms and shareholders
market_generator = MarketGenerator(
    with_reinsurance,
    num_risk_models,
    sim_args,
    broker_args,
    syndicate_args,
    reinsurancefirm_args,
    shareholder_args,
    risk_model_configs,
)
brokers, syndicates, reinsurancefirms, shareholders = market_generator.generate_agents()

# Print the data of every broker, then of every syndicate
for broker in brokers:
    print(broker.data())
for syndicate in syndicates:
    print(syndicate.data())


In [None]:
### 4. Input risk from broker

from environment.event_generator import EventGenerator

current_time = 0

# Every event list comes from its own, freshly constructed EventGenerator
broker_risk_events = EventGenerator(risk_model_configs).generate_risk_events(
    sim_args, broker_risks
)
catastrophe_events = EventGenerator(risk_model_configs).generate_catastrophe_events(
    catastrophes
)
attritional_loss_events = EventGenerator(risk_model_configs).generate_attritional_loss_events(
    sim_args, broker_risks
)
broker_premium_events = EventGenerator(risk_model_configs).generate_premium_events(sim_args)
broker_claim_events = EventGenerator(risk_model_configs).generate_claim_events(sim_args)

# Print one line per broker risk event
for event in broker_risk_events:
    print(
        "risk_id:", event.risk_id,
        "broker_id:", event.broker_id,
        "risk_start_time:", event.risk_start_time,
        "risk_end_time:", event.risk_end_time,
        "risk_factor:", event.risk_factor,
        "risk_category:", event.risk_category,
        "risk_value:", event.risk_value,
    )


In [None]:
from __future__ import annotations
import warnings
from environment.event.add_catastrophe import AddCatastropheEvent
from environment.event.add_attritionalloss import AddAttritionalLossEvent
from environment.event.add_risk import AddRiskEvent
from environment.event.add_premium import AddPremiumEvent
from environment.event.add_claim import AddClaimEvent
import numpy as np
from environment.market import NoReinsurance_RiskOne, NoReinsurance_RiskFour, Reinsurance_RiskOne, Reinsurance_RiskFour
from manager.event_handler import EventHandler

class MarketManager:
    """
    Manage and evolve the market.
    """

    def __init__(
        self,
        maxstep,
        sim_args,
        manager_args,
        brokers,
        syndicates,
        reinsurancefirms,
        shareholders,
        catastrophes,
        fair_market_premium,
        risk_model_configs,
        with_reinsurance,
        num_risk_models,
        catastrophe_events,
        attritional_loss_events,
        broker_risk_events,
        broker_premium_events,
        broker_claim_events,
        event_handler,
        logger=None,
        time=0,
    ):
        """
        Store the simulation inputs, build the market and register the logger.

        Every argument except ``time`` is kept on the instance under the same
        name. The market is always a ``NoReinsurance_RiskOne``;
        ``with_reinsurance`` and ``num_risk_models`` are stored but are not
        used here to choose a market type.

        Parameters
        ----------
        maxstep, manager_args, brokers, syndicates, shareholders, catastrophes,
        risk_model_configs, catastrophe_events, attritional_loss_events,
        broker_risk_events, broker_premium_events, broker_claim_events
            Forwarded, together with ``time``, to the market constructor.
        sim_args, reinsurancefirms, fair_market_premium, with_reinsurance,
        num_risk_models, event_handler
            Stored on the instance only.
        logger: optional
            When given, its ``_store_metadata`` is called once with the market
            time, the market's agents and the event handler.
        time: optional
            Initial market time handed to the market constructor.
        """
        # Simulation configuration
        self.maxstep = maxstep
        self.sim_args = sim_args
        self.manager_args = manager_args
        self.with_reinsurance = with_reinsurance
        self.num_risk_models = num_risk_models
        self.risk_model_configs = risk_model_configs
        self.fair_market_premium = fair_market_premium

        # Agents taking part in the market
        self.brokers = brokers
        self.syndicates = syndicates
        self.reinsurancefirms = reinsurancefirms
        self.shareholders = shareholders

        # Catastrophes and the pre-generated event lists
        self.catastrophes = catastrophes
        self.catastrophe_events = catastrophe_events
        self.attritional_loss_events = attritional_loss_events
        self.broker_risk_events = broker_risk_events
        self.broker_premium_events = broker_premium_events
        self.broker_claim_events = broker_claim_events
        self.event_handler = event_handler

        self.market = NoReinsurance_RiskOne(
            time,
            maxstep,
            manager_args,
            brokers,
            syndicates,
            shareholders,
            catastrophes,
            risk_model_configs,
            catastrophe_events,
            attritional_loss_events,
            broker_risk_events,
            broker_premium_events,
            broker_claim_events,
        )

        self.min_step_time = 1  # Day Event

        # Actions accepted by receive_actions and waiting to be applied
        self.actions_to_apply = []
        # For logging keep track of all Actions ever received and whether they were accepted or refused by the manager
        self.actions_accepted = {}
        self.actions_refused = {}

        # Logging
        self.logger = logger
        if logger is None:
            return
        # NOTE(review): reinsurancefirms is not passed to the market above, yet
        # market.reinsurancefirms is read here - confirm the market defines it.
        logger._store_metadata(
            self.market.time,
            self.market.brokers,
            self.market.syndicates,
            self.market.reinsurancefirms,
            self.market.shareholders,
            event_handler,
        )

    def evolve_action_market(self, starting_broker_risk):
        """
        Evolve the syndicate, broker, risk in the market for step_time [day].

        Parameters
        ----------
        starting_broker_risk: AddRiskEvent
            The current risk event.
        step_time: float
            Amount of time in days to evolve the Market for.
        """

        # Update the status of brokers and syndicates in the market
        num_risk = len(self.actions_to_apply)
        for num in range(num_risk):
            broker_id = starting_broker_risk[num].broker_id
            risks = {"risk_id": starting_broker_risk[num].risk_id,
                "risk_start_time": starting_broker_risk[num].risk_start_time,
                "risk_end_time": starting_broker_risk[num].risk_end_time+self.sim_args["mean_contract_runtime"],
                "risk_factor": starting_broker_risk[num].risk_factor,
                "risk_category": starting_broker_risk[num].risk_category,
                "risk_value": starting_broker_risk[num].risk_value}
            if len(self.actions_to_apply[num]) > 0:
                lead_syndicate_id = self.actions_to_apply[num][0].syndicate
                lead_line_size = 0.5
                lead_syndicate_premium = self.actions_to_apply[num][0].premium * lead_line_size
                premium = lead_syndicate_premium
                follow_syndicates_id = [None for i in range(len(self.market.syndicates))]
                follow_syndicates_premium = [None for i in range(len(self.market.syndicates))]
                follow_line_sizes = 0.1
                for i in range(1,len(self.actions_to_apply[num])):
                    follow_syndicates_id[i-1] = self.actions_to_apply[num][i].syndicate
                    follow_syndicates_premium[i-1] = self.actions_to_apply[num][i].premium * follow_line_sizes
                    premium += follow_syndicates_premium[i-1]
                self.market.brokers[int(broker_id)].add_contract(risks, lead_syndicate_id, lead_line_size, lead_syndicate_premium, follow_syndicates_id, follow_line_sizes, follow_syndicates_premium, premium)
                self.market.syndicates[int(lead_syndicate_id)].add_leader(risks, lead_line_size, lead_syndicate_premium)
                self.market.syndicates[int(lead_syndicate_id)].add_contract(risks, broker_id, lead_syndicate_premium)
                for sy in range(len(follow_syndicates_id)):
                    if follow_syndicates_id[sy] != None:
                        self.market.syndicates[int(follow_syndicates_id[sy])].add_follower(risks, follow_line_sizes, self.actions_to_apply[num][1+sy].premium)
                        self.market.syndicates[int(follow_syndicates_id[sy])].add_contract(risks, broker_id, self.actions_to_apply[num][1+sy].premium)
                    else:
                        self.market.brokers[int(broker_id)].not_underwritten_risk(risks)

    def run_attritional_loss(self, starting_attritional_loss):
        """
        Update market with attritional loss event

        Parameters
        ----------
        starting_attritional_loss: AddAttritionalLossEvent
            The current attritional loss event
        """
        for i in range(len(self.market.syndicates)):
            self.market.syndicates[i].current_capital -= starting_attritional_loss.risk_value * 0.001
            self.market.syndicates[i].profits_losses -= starting_attritional_loss.risk_value * 0.001
            self.market.syndicates[i].market_permanency(starting_attritional_loss.risk_start_time)

    def run_broker_premium(self, starting_broker_premium):
        """
        Update market with premium event

        Parameters
        ----------
        starting_broker_premium: AddPremiumEvent
            The current premium event
        """
        for broker_id in range(len(self.market.brokers)):
            affected_contract = []
            for num in range(len(self.market.brokers[broker_id].underwritten_contracts)):
                if self.market.brokers[broker_id].underwritten_contracts[num]["risk_end_time"] >= starting_broker_premium.risk_start_time:
                    affected_contract.append(self.market.brokers[broker_id].underwritten_contracts[num])
            for num in range(len(affected_contract)):
                print(2)
                premium, lead_syndicate_premium, follow_syndicates_premium, lead_syndicate_id, follow_syndicates_id, risk_category = self.market.brokers[broker_id].pay_premium(affected_contract[num])
                self.market.syndicates[int(lead_syndicate_id)].receive_premium(lead_syndicate_premium*1000, risk_category)
                for follow_id in range(len(follow_syndicates_id)):
                    if follow_syndicates_id[follow_id] != None:
                        self.market.syndicates[int(follow_syndicates_id[follow_id])].receive_premium(follow_syndicates_premium[follow_id]*1000, risk_category)
    
    def run_broker_claim(self, starting_broker_claim):
        """
        Update market with claim event

        Parameters
        ----------
        starting_broker_claim: AddClaimEvent
            The current claim event
        """
        for broker_id in range(len(self.market.brokers)):
            affected_contract = []
            for num in range(len(self.market.brokers[broker_id].underwritten_contracts)):
                if self.market.brokers[broker_id].underwritten_contracts[num]["risk_end_time"] == starting_broker_claim.risk_start_time+1:
                    affected_contract.append(self.market.brokers[broker_id].underwritten_contracts[num])
            for num in range(len(affected_contract)):
                claim, lead_syndicate_id, follow_syndicates_id, risk_category, lead_claim_value, follow_claim_values = self.market.brokers[broker_id].end_contract_ask_claim(affected_contract[num])
                if self.market.syndicates[int(lead_syndicate_id)].current_capital >= lead_claim_value:
                    # TODO: now pay claim according to broker id, can add other mechanism in the future
                    self.market.syndicates[int(lead_syndicate_id)].pay_claim(broker_id, risk_category, lead_claim_value)
                    self.market.syndicates[int(lead_syndicate_id)].current_capital -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].current_capital_category[risk_category] -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].profits_losses -= lead_claim_value
                    self.market.brokers[broker_id].receive_claim(lead_syndicate_id, risk_category, lead_claim_value, lead_claim_value)
                else:
                    self.market.syndicates[int(lead_syndicate_id)].pay_claim(broker_id, risk_category, lead_claim_value)
                    self.market.syndicates[int(lead_syndicate_id)].current_capital -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].current_capital_category[risk_category] -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].profits_losses -= lead_claim_value
                    self.market.brokers[broker_id].receive_claim(lead_syndicate_id, risk_category, lead_claim_value, self.market.syndicates[int(lead_syndicate_id)].current_capital)
                    self.market.syndicates[int(lead_syndicate_id)].bankrupt() 
                    self.market.syndicates[int(lead_syndicate_id)].excess_capital = 0
                for follow_num in range(len(follow_syndicates_id)): 
                    if follow_syndicates_id[follow_num] != None: 
                        if self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital >= follow_claim_values[follow_num]:
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].pay_claim(broker_id, risk_category, follow_claim_values[follow_num])
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital_category[risk_category] -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].profits_losses -= follow_claim_values[follow_num]
                            self.market.brokers[broker_id].receive_claim(follow_syndicates_id[follow_num], risk_category, follow_claim_values[follow_num], follow_claim_values[follow_num])
                        else:
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].pay_claim(broker_id, risk_category, follow_claim_values[follow_num])
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital_category[risk_category] -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].profits_losses -= follow_claim_values[follow_num]
                            self.market.brokers[broker_id].receive_claim(follow_syndicates_id[follow_num], risk_category, follow_claim_values[follow_num], self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital)
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].bankrupt()
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].excess_capital = 0

    def run_catastrophe(self, starting_catastrophe):
        """
        Update market with catastrophe event

        Parameters
        ----------
        starting_catastrophe: AddCatastropheEvent
            The current catastrophe event
        """
        # Catastrophe will influce the broker claim
        for broker_id in range(len(self.market.brokers)):
            affected_contract = []
            for num in range(len(self.market.brokers[broker_id].underwritten_contracts)):
                if int(self.market.brokers[broker_id].underwritten_contracts[num]["risk_category"]) == int(starting_catastrophe.catastrophe_category):
                    affected_contract.append(self.market.brokers[broker_id].underwritten_contracts[num])
            for num in range(len(affected_contract)):
                claim, lead_syndicate_id, follow_syndicates_id, catastrophe_category, lead_claim_value, follow_claim_values = self.market.brokers[broker_id].end_contract_ask_claim(affected_contract[num])
                if self.market.syndicates[int(lead_syndicate_id)].current_capital >= lead_claim_value:
                    # TODO: now pay claim according to broker id, can add other mechanism in the future
                    self.market.syndicates[int(lead_syndicate_id)].pay_claim(broker_id, catastrophe_category, lead_claim_value)
                    self.market.syndicates[int(lead_syndicate_id)].current_capital -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].current_capital_category[catastrophe_category] -= lead_claim_value
                    self.market.brokers[broker_id].receive_claim(lead_syndicate_id, catastrophe_category, lead_claim_value, lead_claim_value)
                    #expected_profit, acceptable_by_category, cash_left_by_categ, var_per_risk_per_categ, self.excess_capital = self.riskmodel.evaluate(underwritten_risks, self.cash)
                else:
                    self.market.syndicates[int(lead_syndicate_id)].pay_claim(broker_id, catastrophe_category, lead_claim_value)
                    self.market.syndicates[int(lead_syndicate_id)].current_capital -= lead_claim_value
                    self.market.syndicates[int(lead_syndicate_id)].current_capital_category[catastrophe_category] -= lead_claim_value
                    self.market.brokers[broker_id].receive_claim(lead_syndicate_id, catastrophe_category, lead_claim_value, self.market.syndicates[int(lead_syndicate_id)].current_capital)
                    self.market.syndicates[int(lead_syndicate_id)].bankrupt() 
                    self.market.syndicates[int(lead_syndicate_id)].excess_capital = 0
                for follow_num in range(len(follow_syndicates_id)): 
                    if follow_syndicates_id[follow_num] != None: 
                        if self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital >= follow_claim_values[follow_num]:
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].pay_claim(broker_id, catastrophe_category, follow_claim_values[follow_num])
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital_category[catastrophe_category] -= follow_claim_values[follow_num]
                            self.market.brokers[broker_id].receive_claim(follow_syndicates_id[follow_num], catastrophe_category, follow_claim_values[follow_num], follow_claim_values[follow_num])
                        else:
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].pay_claim(broker_id, catastrophe_category, follow_claim_values[follow_num])
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital -= follow_claim_values[follow_num]
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital_category[catastrophe_category] -= follow_claim_values[follow_num]
                            self.market.brokers[broker_id].receive_claim(follow_syndicates_id[follow_num], catastrophe_category, follow_claim_values[follow_num], self.market.syndicates[int(follow_syndicates_id[follow_num])].current_capital)
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].bankrupt()
                            self.market.syndicates[int(follow_syndicates_id[follow_num])].excess_capital = 0

    def evolve(self, step_time):
        """
        Advance the market by step_time and apply the events completed meanwhile.

        The ids of all pending events are recorded first and the event handler
        is then moved forward. Every event that was pending before and is
        completed afterwards is applied, one event type after another, in this
        order: broker risks, attritional losses, broker premiums, broker
        claims, catastrophes. Within a type, start times are processed in
        ascending order and each distinct start time is processed once.

        ``self.market.time`` is set to each broker-risk start time while those
        are processed and to ``time + step_time`` at the end.

        Parameters
        ----------
        step_time: float
            Amount of time in days to evolve the Market for.

        Notes
        -----
        For attritional loss, premium, claim and catastrophe events, the last
        event in the corresponding list with a matching start time is applied;
        ``None`` is passed on when no event matches.
        """
        handler = self.event_handler

        # The time the market will have after being evolved
        market_end_time = self.market.time + step_time

        # Ids of the events still pending BEFORE the handler moves forward
        upcoming_catastrophe = self._pending_event_ids(
            handler.upcoming_catastrophe, AddCatastropheEvent, "catastrophe_id"
        )
        upcoming_attritional_loss = self._pending_event_ids(
            handler.upcoming_attritional_loss, AddAttritionalLossEvent, "risk_id"
        )
        upcoming_broker_risk = self._pending_event_ids(
            handler.upcoming_broker_risk, AddRiskEvent, "risk_id"
        )
        upcoming_broker_premium = self._pending_event_ids(
            handler.upcoming_broker_premium, AddPremiumEvent, "risk_id"
        )
        upcoming_broker_claim = self._pending_event_ids(
            handler.upcoming_broker_claim, AddClaimEvent, "risk_id"
        )

        # Enact the events
        handler.forward(self.market, step_time)

        # Broker risks: place all risks sharing a start time together
        for start_time in self._completed_start_times(
            handler.completed_broker_risk, AddRiskEvent, "risk_id", "risk_start_time", upcoming_broker_risk
        ):
            # Move along the market's time
            self.market.time = start_time
            # NOTE(review): the risks selected are those starting two days
            # before the completed start time; the reason for the offset is
            # not visible here - confirm against the event handler.
            starting_broker_risk = [
                event for event in self.broker_risk_events if event.risk_start_time == start_time - 2
            ]
            self.evolve_action_market(starting_broker_risk)
            # Accepted actions are consumed by the first start time processed
            self.actions_to_apply = []

        # Attritional losses
        for start_time in self._completed_start_times(
            handler.completed_attritional_loss,
            AddAttritionalLossEvent,
            "risk_id",
            "risk_start_time",
            upcoming_attritional_loss,
        ):
            self.run_attritional_loss(
                self._last_event_starting_at(self.attritional_loss_events, "risk_start_time", start_time)
            )

        # Broker premiums
        for start_time in self._completed_start_times(
            handler.completed_broker_premium, AddPremiumEvent, "risk_id", "risk_start_time", upcoming_broker_premium
        ):
            self.run_broker_premium(
                self._last_event_starting_at(self.broker_premium_events, "risk_start_time", start_time)
            )

        # Broker claims
        for start_time in self._completed_start_times(
            handler.completed_broker_claim, AddClaimEvent, "risk_id", "risk_start_time", upcoming_broker_claim
        ):
            self.run_broker_claim(
                self._last_event_starting_at(self.broker_claim_events, "risk_start_time", start_time)
            )

        # Catastrophes
        for start_time in self._completed_start_times(
            handler.completed_catastrophe,
            AddCatastropheEvent,
            "catastrophe_id",
            "catastrophe_start_time",
            upcoming_catastrophe,
        ):
            self.run_catastrophe(
                self._last_event_starting_at(self.catastrophe_events, "catastrophe_start_time", start_time)
            )

        self.market.time = market_end_time

    @staticmethod
    def _pending_event_ids(events, event_type, id_attr):
        """Return the ids of the events of event_type held in the events mapping."""
        return [getattr(event, id_attr) for event in events.values() if isinstance(event, event_type)]

    @staticmethod
    def _completed_start_times(completed_events, event_type, id_attr, time_attr, pending_ids):
        """Return the sorted, distinct start times of completed events whose id was pending."""
        pending = frozenset(pending_ids)  # O(1) membership instead of scanning the list
        start_time_by_id = {
            getattr(event, id_attr): getattr(event, time_attr)
            for event in completed_events.values()
            if isinstance(event, event_type) and getattr(event, id_attr) in pending
        }
        start_times = [
            start_time_by_id[event_id]
            for event_id in pending_ids
            if start_time_by_id.get(event_id) is not None
        ]
        # np.unique already returns its result sorted
        return np.unique(np.array(start_times))

    @staticmethod
    def _last_event_starting_at(events, time_attr, start_time):
        """Return the last event in events whose time_attr equals start_time, or None."""
        match = None
        for event in events:
            if getattr(event, time_attr) == start_time:
                match = event
        return match

    def receive_actions(self, actions):

        # Choose the leader and save its action, the first syndicate with the highest line size wins 
        # TODO: will add selection algorithm in the future
        num_risk = len(actions)
        accept_actions = [[] for x in range(num_risk)]
        min_premium = [0 for x in range(num_risk)]
        lead_syndicate_id = [0 for x in range(num_risk)]
        sum_line_size = [0 for x in range(num_risk)]
        # Accept the quote
        syndicate_list = [[] for x in range(num_risk)]
        lead_line_size = 0.5
        follow_line_size = 0.1
        for num in range(num_risk):
            # Find the leader
            for sy in range(len(self.market.syndicates)):
                if actions[num][sy].premium != 0:
                    min_premium[num] = actions[num][sy].premium
                    lead_syndicate_id[num] = sy
            for sy_new in range(len(self.market.syndicates)):
                if (actions[num][sy_new].premium != 0) and (actions[num][sy_new].premium < min_premium[num]):
                    min_premium[num] = actions[num][sy_new].premium
                    lead_syndicate_id[num] = sy_new
            syndicate_list[num].append(lead_syndicate_id[num])
            accept_actions[num].append(actions[num][lead_syndicate_id[num]])
            sum_line_size[num] += lead_line_size
            # Sort the premium
            premium_sort = []
            for sy in range(len(self.market.syndicates)):
                premium_sort.append(actions[num][sy].premium)
            premium_sort = np.array(premium_sort)
            premium_sort = np.sort(premium_sort)
            # Assign line size to the rest syndicates, min premium win
            rest_line_size = 1 - sum_line_size[num]
            for p in range(len(premium_sort)):
                for sy in range(len(self.market.syndicates)):
                    if (rest_line_size > 0) and (sy not in syndicate_list[num]) and (actions[num][sy].premium == premium_sort[p]) and (premium_sort[p] != 0):
                        rest_line_size -= follow_line_size
                        accept_actions[num].append(actions[num][sy])
                        syndicate_list[num].append(sy)
                    
        # Save Actions to issue
        self.actions_to_apply = accept_actions

In [None]:
### 5. Create Multi-agent Environment to get access to the market performance

import gym
import numpy as np
import scipy
from environment.event_generator import EventGenerator
from manager.ai_model.action import Action
from manager import EventHandler
#MarketManager
from logger import logger
from environment.risk_model import RiskModel
from environment.environment import SpecialtyInsuranceMarketEnv

class MultiAgentBasedModel(SpecialtyInsuranceMarketEnv):

    def __init__(self, sim_args, manager_args, broker_args, syndicate_args, reinsurancefirm_args, shareholder_args, risk_args, 
                 brokers, syndicates, reinsurancefirms, shareholders, catastrophes, broker_risks, fair_market_premium,
                 risk_model_configs, with_reinsurance, num_risk_models, logger, dt = 1):
        """Store the configuration, declare the per-syndicate spaces and reset the market.

        Every argument is kept on the instance under the same name and is also
        forwarded to the parent environment constructor. ``dt`` is the step size
        handed to ``self.mm.evolve``; it is now forwarded to the parent as well
        (previously the parent always received a hard-coded ``1``).

        The constructor finishes by calling ``self.reset()``, so the event lists,
        event handler and market manager exist as soon as the object is built.
        """
        self.sim_args = sim_args
        self.maxstep = self.sim_args["max_time"]
        self.manager_args = manager_args
        self.broker_args = broker_args
        self.syndicate_args = syndicate_args
        self.reinsurancefirm_args = reinsurancefirm_args
        self.shareholder_args = shareholder_args
        self.risk_args = risk_args
        self.brokers = brokers
        self.syndicates = syndicates
        self.reinsurancefirms = reinsurancefirms
        self.shareholders = shareholders
        self.catastrophes = catastrophes
        self.broker_risks = broker_risks
        self.fair_market_premium = fair_market_premium
        self.initial_catastrophes = catastrophes
        self.risk_model_configs = risk_model_configs
        self.with_reinsurance = with_reinsurance
        self.num_risk_models = num_risk_models
        self.logger = logger
        self.dt = dt
        # Both are created in reset()
        self.mm = None
        self.event_handler = None

        # Active syndicate list
        self.syndicate_active_list = []
        # Initialise events, actions, and states
        self.attritional_loss_events = []
        self.catastrophe_events = []
        self.broker_risk_events = []
        self.broker_premium_events = []
        self.broker_claim_events = []
        self.action_map_dict = {}
        self.state_encoder_dict = {}

        # Define Action Space, Define Observation Space (one entry per syndicate)
        self.n = len(self.syndicates)
        self.agents = {syndicate.syndicate_id for syndicate in self.syndicates}
        self._agent_ids = set(self.agents)
        self.dones = set()
        self._spaces_in_preferred_format = True
        # state_encoder() emits one capital value per risk category, so the box
        # is sized from the configuration instead of a literal 4.
        num_categories = self.risk_args["num_categories"]
        self.observation_space = gym.spaces.Dict({
            syndicate.syndicate_id: gym.spaces.Box(low=np.full(num_categories, -30000000),
                                                   high=np.full(num_categories, 30000000),
                                                   dtype=np.float32)
            for syndicate in self.syndicates
        })
        self.action_space = gym.spaces.Dict({
            syndicate.syndicate_id: gym.spaces.Box(0.5, 0.9, dtype=np.float32)
            for syndicate in self.syndicates
        })

        super(MultiAgentBasedModel, self).__init__(sim_args = self.sim_args, 
                                                   manager_args = self.manager_args,
                                                   broker_args = self.broker_args, 
                                                   syndicate_args = self.syndicate_args, 
                                                   reinsurancefirm_args = self.reinsurancefirm_args, 
                                                   shareholder_args = self.shareholder_args, 
                                                   risk_args = self.risk_args, 
                                                   brokers = self.brokers, 
                                                   syndicates = self.syndicates, 
                                                   reinsurancefirms = self.reinsurancefirms, 
                                                   shareholders = self.shareholders, 
                                                   catastrophes = self.catastrophes, 
                                                   broker_risks = self.broker_risks,
                                                   fair_market_premium = self.fair_market_premium,
                                                   risk_model_configs = self.risk_model_configs, 
                                                   with_reinsurance = self.with_reinsurance, 
                                                   num_risk_models = self.num_risk_models,
                                                   logger = self.logger,
                                                   dt = self.dt)

        # Log data
        self.cumulative_bankruptcies = 0
        self.cumulative_market_exits = 0
        self.cumulative_unrecovered_claims = 0.0
        self.cumulative_claims = 0.0
        self.total_cash = 0.0
        self.total_excess_capital = 0.0
        self.total_profitslosses =  0.0
        self.total_contracts = 0.0
        self.operational_syndicates = 0.0
        # Reset the environment
        self.reset()

    def reset(self, seed = None, options = None):
        """Rebuild events and the market manager, then return the initial observations.

        Returns a pair ``(state_encoder_dict, info_dict)``, both keyed by
        syndicate id. ``options`` is accepted but not used here.
        """
        super().reset(seed = seed)

        # Regenerate every event stream; a fresh EventGenerator is built for each one.
        # TODO: broker generate risk according to poisson distribution
        self.catastrophe_events = EventGenerator(self.risk_model_configs).generate_catastrophe_events(self.catastrophes)
        self.attritional_loss_events = EventGenerator(self.risk_model_configs).generate_attritional_loss_events(self.sim_args, self.broker_risks)
        self.broker_risk_events = EventGenerator(self.risk_model_configs).generate_risk_events(self.sim_args, self.broker_risks)
        self.broker_premium_events = EventGenerator(self.risk_model_configs).generate_premium_events(self.sim_args)
        self.broker_claim_events = EventGenerator(self.risk_model_configs).generate_claim_events(self.sim_args)

        # Event handler and market manager are recreated from the new events
        self.event_handler = EventHandler(self.maxstep, self.catastrophe_events, self.attritional_loss_events,
                                          self.broker_risk_events, self.broker_premium_events, self.broker_claim_events)
        self.mm = MarketManager(self.maxstep, self.sim_args, self.manager_args, self.brokers, self.syndicates,
                                self.reinsurancefirms, self.shareholders, self.catastrophes, self.fair_market_premium,
                                self.risk_model_configs, self.with_reinsurance, self.num_risk_models,
                                self.catastrophe_events, self.attritional_loss_events, self.broker_risk_events,
                                self.broker_premium_events, self.broker_claim_events, self.event_handler)
        self.mm.evolve(self.dt)

        # Ids of the syndicates whose status flag is True right after the first evolve
        self.syndicate_active_list = [
            syndicate.syndicate_id
            for syndicate in self.mm.market.syndicates
            if syndicate.status == True
        ]

        # Initial action map (premium 0 on the first risk event) and encoded state per syndicate
        info_dict = {}
        for syndicate in self.mm.market.syndicates:
            agent_id = syndicate.syndicate_id
            self.action_map_dict[agent_id] = self.action_map_creator(syndicate, 0, self.broker_risk_events[0])
            self.state_encoder_dict[agent_id] = self.state_encoder(agent_id)
            info_dict[agent_id] = None

        # Clear broker books
        for broker in self.mm.market.brokers:
            broker.underwritten_contracts = []
            broker.not_underwritten_risks = []
            broker.not_paid_claims = []

        # Restore syndicate capital, split evenly across the risk categories
        num_categories = self.risk_args["num_categories"]
        for syndicate in self.mm.market.syndicates:
            syndicate.current_hold_contracts = []
            syndicate.current_capital = self.syndicate_args["initial_capital"]
            syndicate.current_capital_category = [
                self.syndicate_args["initial_capital"] / num_categories for _ in range(num_categories)
            ]

        # Reset the profit/loss bookkeeping of every syndicate
        for syndicate in self.mm.market.syndicates:
            syndicate.reset_pl()

        # Cumulative log counters start from zero again
        self.cumulative_bankruptcies = 0
        self.cumulative_market_exits = 0
        self.cumulative_unrecovered_claims = 0.0
        self.cumulative_claims = 0.0

        # step() looks for risks starting at timestep + 1, i.e. time 0 on the first call
        self.timestep = -1
        self.step_track = 0

        return self.state_encoder_dict, info_dict
    
    def adjust_market_premium(self, capital):
        """Recompute, store and return the market premium for a given market capital.

        Arguments
            capital: Type float. The total capital (cash) available in the insurance market (insurance only).
        Returns the new value of ``self.market_premium``.

        The premium starts at ``fair_market_premium * upper_premium_limit`` and falls
        linearly as ``capital`` grows, but never below
        ``fair_market_premium * lower_premium_limit``.
        """
        limits = self.syndicate_args
        # Normalisation: initial capital x mean damage of the first risk model x number of risks
        capital_scale = (limits["initial_capital"]
                         * self.risk_model_configs[0]["damage_distribution"].mean()
                         * self.risk_args["num_risks"])
        unclamped = self.fair_market_premium * (limits["upper_premium_limit"]
                                                - limits["premium_sensitivity"] * capital / capital_scale)
        premium_floor = self.fair_market_premium * limits["lower_premium_limit"]

        self.market_premium = premium_floor if unclamped < premium_floor else unclamped
        return self.market_premium

    def get_mean(self,x):
        """Return the arithmetic mean of the values in ``x`` (``x`` must be non-empty)."""
        total = sum(x)
        count = len(x)
        return total / count
    
    def get_mean_std(self, x):
        """Return ``(mean, std)`` of ``x``, where std divides by ``len(x)`` (population form)."""
        centre = self.get_mean(x)
        squared_deviations = [(value - centre) ** 2 for value in x]
        spread = np.sqrt(sum(squared_deviations) / len(x))
        return centre, spread

    def balanced_portfolio(self, syndicate_id, risk, cash_left_by_categ, var_per_risk):
        """Decide whether accepting ``risk`` keeps the syndicate's reserves balanced.

        Returns ``(accepted, cash_left_by_categ)``. ``cash_left_by_categ`` is updated
        in place in both cases: with the reserves that include the new risk when it
        is accepted, and with the current reserves otherwise.

        The risk is accepted when the standard deviation of the cash reserved per
        category goes down, or when the reserves stay well balanced, i.e.
        ``std_post * total_reserved / capital <= balance_ratio * mean``. Scaling by
        ``total_reserved / capital`` relaxes the balance condition while little of
        the capital is reserved.
        """
        syndicate = self.mm.market.syndicates[syndicate_id]

        # Cash already reserved per category, before the new risk
        reserved_now = syndicate.current_capital - cash_left_by_categ
        _, spread_before = self.get_mean_std(reserved_now)

        # Cash that would be reserved per category if the new risk were accepted
        reserved_if_accepted = np.copy(reserved_now)
        reserved_if_accepted[risk.risk_category] += var_per_risk[risk.risk_category]
        mean_after, spread_after = self.get_mean_std(reserved_if_accepted)
        total_after = sum(reserved_if_accepted)

        accepted = False
        if (spread_after * total_after / syndicate.current_capital) <= (syndicate.balance_ratio * mean_after) or spread_after < spread_before:
            accepted = True

        reference = reserved_if_accepted if accepted else reserved_now
        for categ in range(len(cash_left_by_categ)):
            cash_left_by_categ[categ] = syndicate.current_capital - reference[categ]

        return accepted, cash_left_by_categ


    def process_newrisks_insurer(self, new_risks, syndicate_id, acceptable_by_category, var_per_risk_per_categ, cash_left_by_categ, time):
        """Decide, for one syndicate, which of ``new_risks`` it is willing to underwrite.

        Returns a list of booleans with exactly one entry per element of
        ``new_risks``, in the same order, so ``result[k]`` is the decision for
        ``new_risks[k]``. (The previous version appended decisions in category
        order and omitted skipped risks, so positions did not match the risks and
        the list could be shorter than ``new_risks``.)

        Risks are visited category by category. A risk is accepted when the
        category still has capacity (``acceptable_by_category[categ] > 0``) and
        ``balanced_portfolio`` approves it. Each acceptance consumes one unit of
        capacity, which is checked before every risk. ``acceptable_by_category``
        is mutated in place. ``time`` is not used by this method.
        """
        accept = [False] * len(new_risks)
        for categ_id in range(len(acceptable_by_category)):
            for position, risk_to_insure in enumerate(new_risks):
                if risk_to_insure.risk_category != categ_id:
                    continue
                if acceptable_by_category[categ_id] <= 0:
                    # No capacity left in this category: the remaining risks stay refused
                    break
                condition, cash_left_by_categ = self.balanced_portfolio(syndicate_id, risk_to_insure, cash_left_by_categ, var_per_risk_per_categ)
                if condition:
                    accept[position] = True
                    acceptable_by_category[categ_id] -= 1  # TODO: allow different values per risk (i.e. sum over value (and reinsurance_share) or exposure instead of counting)

        return accept
    
    def get_actions(self, time):
        """Build the rule-based premium quotes for every risk starting at ``time``.

        Returns a list with one dict per new risk, mapping syndicate id to the
        quoted premium. A syndicate holding no contracts quotes the market premium
        on every new risk. Otherwise it quotes the market premium on the risks it
        accepts and ``0`` on the ones it refuses. The quoted premium is
        ``adjust_market_premium(total market capital) * 1000``.
        """
        new_risks = [event for event in self.broker_risk_events if event.risk_start_time == time]
        action_dict = [{} for _ in new_risks]

        syndicates = self.mm.market.syndicates
        for index, syndicate in enumerate(syndicates):
            agent_id = syndicate.syndicate_id
            (_, acceptable_by_category, cash_left_by_categ,
             var_per_risk_per_categ, self.excess_capital) = syndicate.riskmodel.evaluate(
                syndicate.current_hold_contracts, syndicate.current_capital)
            accept = self.process_newrisks_insurer(new_risks, index, acceptable_by_category,
                                                   var_per_risk_per_categ, cash_left_by_categ, time)

            if len(syndicate.current_hold_contracts) == 0:
                # Syndicates compete for the leadership, they will all cover 0.5
                total_capital = sum([member.current_capital for member in syndicates])
                quote = self.adjust_market_premium(capital=total_capital) * 1000
                for quotes in action_dict:
                    quotes[agent_id] = quote
                continue

            for num, quotes in enumerate(action_dict):
                if not accept[num]:
                    quotes[agent_id] = 0
                    continue
                total_capital = sum([member.current_capital for member in syndicates])
                quotes[agent_id] = self.adjust_market_premium(capital=total_capital) * 1000
        return action_dict
        
    def step(self, action_dict):
        """Apply the quotes for the risks of the next time step and advance the market.

        ``action_dict`` is a list with one dict per risk whose ``risk_start_time``
        equals ``self.timestep + 1`` (same order as ``self.broker_risk_events``);
        each dict maps syndicate id to the quoted premium.

        Returns ``(obs_dict, reward_dict, terminated_dict, flag_dict, info_dict)``.
        The per-syndicate entries cover the syndicates that appear in
        ``action_dict``; ``terminated_dict`` and ``flag_dict`` also carry the
        ``"__all__"`` key.

        Fixes relative to the previous version:
        * a step without new risks left ``terminated_dict`` empty and therefore
          reported ``"__all__": True``, ending the episode early; in that case
          the time limit used by ``check_termination`` decides instead;
        * reward, observation and termination are computed once per syndicate
          instead of once per (risk, syndicate) pair. The reward still receives
          the syndicate's last quote of the step.
        """
        obs_dict, reward_dict, terminated_dict, info_dict = {}, {}, {}, {}
        flag_dict = {}

        # Risks that start at the step about to be simulated
        upcoming_time = self.timestep + 1
        new_risk = [event for event in self.broker_risk_events if event.risk_start_time == upcoming_time]

        # Turn every quote into an action object, grouped per risk
        parsed_actions = [[] for _ in new_risk]
        latest_action = {}  # syndicate id -> last quote seen, in first-seen order
        for l, risk in enumerate(new_risk):
            for syndicate_id, action in action_dict[l].items():
                # update action map
                self.action_map = self.action_map_creator(self.mm.market.syndicates[int(syndicate_id)], action, risk)
                parsed_actions[l].append(self.action_map)
                latest_action[syndicate_id] = action

        self.send_action2env(parsed_actions)

        # Evolve the market
        self.mm.evolve(self.dt)

        self.timestep += 1

        # Compute rewards and get next observation
        for syndicate_id, action in latest_action.items():
            reward_dict[syndicate_id] = self.compute_reward(action, syndicate_id)
            obs_dict[syndicate_id] = self.state_encoder(syndicate_id)
            info_dict[syndicate_id] = {}
            flag_dict[syndicate_id] = False
            terminated_dict[syndicate_id] = self.check_termination(syndicate_id)
            if terminated_dict[syndicate_id]:
                self.dones.add(syndicate_id)
        # Update plot
        self.draw2file(self.mm.market)

        # All done termination check
        if terminated_dict:
            all_terminated = not any(flag is False for flag in terminated_dict.values())
        else:
            # Nobody acted this step: fall back to the time limit
            all_terminated = bool(self.timestep >= self.maxstep - 1)

        terminated_dict["__all__"] = all_terminated
        flag_dict["__all__"] = all_terminated

        return obs_dict, reward_dict, terminated_dict, flag_dict, info_dict

    def check_termination(self, syndicate_id):
        """Return True once the step counter has reached the last simulated step.

        Only the time limit (``self.timestep >= self.maxstep - 1``) is checked.
        ``syndicate_id`` is kept for interface compatibility but is not consulted;
        the unused market/syndicate lookups of the previous version were removed.
        TODO: also terminate a syndicate that exits the market or goes bankrupt.

        The result is a plain ``bool`` because ``step`` tests it with ``is False``.
        """
        return bool(self.timestep >= self.maxstep - 1)

    def compute_reward(self, action, syndicate_id):
        """Return the scalar reward of one syndicate for the current time step.

        The reward is the sum of four terms:
        * for every broker, each (risk, underwritten contract) pair adds +1 when
          the risk ids match and -1 otherwise;
        * each entry of the syndicate's ``paid_claim`` adds +1 when its
          ``"status"`` is True and -1 otherwise;
        * the capital change ``current_capital - initial_capital``;
        * a penalty of -10000 when that change is negative.

        Nothing is accumulated once ``self.timestep`` exceeds ``self.maxstep``
        (the reward is then 0.0). ``action`` is currently not used.

        Fix: the claim loop read ``market.syndicate[syndicate_id]`` (missing "s",
        no int conversion) and raised AttributeError as soon as a syndicate had a
        paid claim.
        """
        market = self.mm.market
        r = [0.0] * 4

        if self.timestep > self.maxstep:
            return r[0] + r[1] + r[2] + r[3]

        syndicate = market.syndicates[int(syndicate_id)]

        # For each insurable risk being accepted +1 or refused -1
        for broker in market.brokers:
            for risk in broker.risks:
                for contract in broker.underwritten_contracts:
                    if risk["risk_id"] == contract["risk_id"]:
                        r[0] += 1
                    else:
                        r[0] -= 1

        # For each claim being paid +1 or refused -1
        for claim in syndicate.paid_claim:
            if claim["status"] == True:
                r[1] += 1
            else:
                r[1] -= 1

        # Profit and Bankruptcy
        capital_change = syndicate.current_capital - syndicate.initial_capital
        r[2] += capital_change
        if capital_change < 0:
            r[3] -= 10000

        # Sum reward
        return r[0] + r[1] + r[2] + r[3]

    def send_action2env(self, parsed_actions):
        """Hand the parsed actions to the market manager; do nothing when there are none."""
        if len(parsed_actions) == 0:
            return
        # Apply action
        self.mm.receive_actions(actions=parsed_actions)
    
    def state_encoder(self, syndicate_id):
        """Return the observation of one syndicate as a list.

        The observation is the syndicate's ``current_capital_category``: one
        capital value per risk category, in category order.
        """
        # Risk features of the upcoming risk (category, value, factor) are not part
        # of the observation for now: for the AI version the observation size has
        # to stay fixed.
        syndicate = self.mm.market.syndicates[int(syndicate_id)]
        return [capital for capital in syndicate.current_capital_category]

    def action_map_creator(self, syndicate, premium, new_risk):
        """Wrap a premium quote of ``syndicate`` for ``new_risk`` in an ``Action`` object."""
        return Action(syndicate.syndicate_id, premium, new_risk.risk_id, new_risk.broker_id)
    
    def save_data(self):
        """Collect statistics about the current market state and pass them to ``self.logger``.

        Market totals are stored on the instance, the cumulative counters are
        advanced, and one dict with all values is sent to
        ``self.logger.record_data``.
        """
        syndicates = self.mm.market.syndicates

        # Market-wide totals
        self.total_cash = sum([syndicate.current_capital for syndicate in syndicates])
        self.total_excess_capital = sum([syndicate.excess_capital for syndicate in syndicates])
        self.total_profitslosses = sum([syndicate.profits_losses for syndicate in syndicates])
        self.total_contracts = sum([len(syndicate.current_hold_contracts) for syndicate in syndicates])
        self.operational_syndicates = sum([syndicate.status for syndicate in syndicates])

        # Agent-level data: (capital, id, status) per syndicate
        syndicates_data = [(syndicate.current_capital, syndicate.syndicate_id, syndicate.status)
                           for syndicate in syndicates]

        # Cumulative counters are advanced on every call
        for syndicate in syndicates:
            if syndicate.status == False:
                self.cumulative_bankruptcies += 1
            # A syndicate whose cash is below its permanency limit counts as a market exit
            if syndicate.current_capital < syndicate.capital_permanency_limit:
                self.cumulative_market_exits += 1  # TODO: update the syndicates list because of market exit
            for contract in syndicate.current_hold_contracts:
                if contract["pay"] == False:
                    self.cumulative_unrecovered_claims += contract["risk_value"]
                elif contract["pay"] == True:
                    self.cumulative_claims += contract["risk_value"]

        current_log = {
            'total_cash': self.total_cash,
            'total_excess_capital': self.total_excess_capital,
            'total_profits_losses': self.total_profitslosses,
            'total_contracts': self.total_contracts,
            'total_operational': self.operational_syndicates,
            'market_premium': self.market_premium,  # TODO: Oxford pricing has a fair premium and adjusts it
            'cumulative_bankruptcies': self.cumulative_bankruptcies,
            'cumulative_market_exits': self.cumulative_market_exits,
            'cumulative_unrecovered_claims': self.cumulative_unrecovered_claims,
            'cumulative_claims': self.cumulative_claims,
            'insurance_firms_cash': syndicates_data,
            'individual_contracts': [len(syndicate.current_hold_contracts) for syndicate in syndicates],
        }

        # Call to Logger object
        self.logger.record_data(current_log)

    def obtain_log(self, requested_logs=None):
        """Return the data recorded so far by delegating to ``self.logger.obtain_log``.

        This is the way to get the data generated by the model back to the caller
        (e.g. from a cloud run). ``requested_logs`` is passed through unchanged.
        """
        recorded = self.logger.obtain_log(requested_logs)
        return recorded
        

In [None]:
### 6. Register environment and train the model

import gymnasium as gym
import numpy as np
import ray
from ray.tune.registry import register_env
from ray import air, tune
from ray.rllib.algorithms.ppo import PPO
from ipywidgets import IntProgress
from gym.spaces import Box
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.examples.policy.random_policy import RandomPolicy

# Logger shared by the environments created in this notebook
# NOTE(review): the key "num_riskmodels" is not among the risk_args keys set in cell 1
# (num_risk_models is a separate variable there) - confirm get_arguments() provides it.
log = logger.Logger(risk_args["num_riskmodels"], catastrophes, catastrophe_damage, brokers, syndicates)

# Keyword arguments for MultiAgentBasedModel, also used as the RLlib env_config
insurance_args = dict(
    sim_args=sim_args,
    manager_args=manager_args,
    broker_args=broker_args,
    syndicate_args=syndicate_args,
    reinsurancefirm_args=reinsurancefirm_args,
    shareholder_args=shareholder_args,
    risk_args=risk_args,
    brokers=brokers,
    syndicates=syndicates,
    reinsurancefirms=reinsurancefirms,
    shareholders=shareholders,
    catastrophes=catastrophes,
    broker_risks=broker_risks,
    fair_market_premium=fair_market_premium,
    risk_model_configs=risk_model_configs,
    with_reinsurance=with_reinsurance,
    num_risk_models=num_risk_models,
    logger=log,
)

def env_creator(env_config):
    """Environment factory for ``register_env``: build the model from the config dict."""
    environment = MultiAgentBasedModel(**env_config)
    return environment

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    """Map an agent id to a policy name: ``"main"`` followed by the last element of the id.

    For example agent0 -> main0 and agent1 -> main1. ``episode``, ``worker`` and
    any extra keyword arguments are ignored.
    """
    suffix = agent_id[-1]
    return "main{}".format(suffix)

def ppo_trainer_creator(insurance_args):
    """Build a PPO trainer for the registered "SpecialtyInsuranceMarket-validation" environment.

    Three policies are declared ("main0", "main1" and a random one); only "main0"
    is trained. ``insurance_args`` is passed to the environment as ``env_config``.
    """
    # NOTE(review): these boxes have 6 dimensions, while MultiAgentBasedModel declares
    # a 4-dimensional observation box per syndicate - confirm which one is intended.
    def make_observation_space():
        return gym.spaces.Box(low=np.array([-10000000,-10000000,-30000000,-30000000,-30000000,-30000000]), 
                              high=np.array([10000000,10000000,30000000,30000000,30000000,30000000]), dtype = np.float32)

    def make_action_space():
        return gym.spaces.Box(0.5, 0.9, dtype = np.float32)

    policies = {
        # The Policy we are actually learning.
        "main0": PolicySpec(observation_space=make_observation_space(), action_space=make_action_space()),
        "main1": PolicySpec(observation_space=make_observation_space(), action_space=make_action_space()),
        "random": PolicySpec(policy_class=RandomPolicy),
    }

    config = {
        "env": "SpecialtyInsuranceMarket-validation",
        "framework": "tf",
        "multi_agent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
            "policies_to_train": ["main0"],
        },
        "observation_space": make_observation_space(),
        "action_space": make_action_space(),
        "env_config": insurance_args,
        "evaluation_interval": 2,
        "evaluation_duration": 20,
    }

    return PPO(config=config)

# Folder in which the trained agents are recorded
top_dir = "".join(["noreinsurance_", "_model_", str(num_risk_models)])

# Make the environment available to RLlib under its registered name
register_env("SpecialtyInsuranceMarket-validation", env_creator)

# The number of training iterations for the RL agent
num_training = 10

# Build the PPO trainer from the shared environment arguments
trainer = ppo_trainer_creator(insurance_args)
# Number of training iterations

"""
for n in range(num_training):
    # Create a path to store the trained agent for each iteration
    model_filepath = f"{top_dir}/{str(n)}/saved_models"
        
    num_episode = 10

    # A training iteration includes parallel sample collection by the environment workers 
    # as well as loss calculation on the collected batch and a model update.

    bar = IntProgress(min=0, max=num_episode)
    display(bar)
    list_mean_rewards = []
    list_min_rewards = []
    list_max_rewards = []
    list_train_step = []

    for i in range(num_episode):
        trainer.train()     
        print("Progress:", i+1, "/", num_episode, end="\r")
        bar.value += 1
        if (i+1) % 2 == 0:
            list_mean_rewards.append(trainer.evaluation_metrics["evaluation"]["episode_reward_mean"])
            list_min_rewards.append(trainer.evaluation_metrics["evaluation"]["episode_reward_min"])
            list_max_rewards.append(trainer.evaluation_metrics["evaluation"]["episode_reward_max"])
            list_train_step.append(i+1)
        if i % 10 == 0:
            trainer.save(model_filepath)
"""  
    

# Can be used for game model
# Roll out one episode with the rule-based quotes from env.get_actions() instead of a trained policy.
env = MultiAgentBasedModel(**insurance_args)
    
total_steps = 0
terminated_dict = {"__all__": False}
    
obs_dict, info_dict = env.reset()

# total_steps is the time whose risks are quoted; it is incremented before env.step() is called.
# The loop ends once step() reports "__all__" as terminated.
while not terminated_dict["__all__"]:        
    action_dict = env.get_actions(total_steps)  
    total_steps += 1
    
    obs_dict, reward_dict, terminated_dict, flag_dict, info_dict = env.step(action_dict)


In [None]:

### 7. Test the trained model performance

def trainer_restore(self, top_dir, n):
    """Restore ``self.trainer`` from the checkpoint written for training iteration ``n``.

    The checkpoint path is
    ``<top_dir>/<n-1>/saved_models/checkpoint_<n zero-padded to 6 digits>/rllib_checkpoint.json``.

    The zero padding is now produced by a format spec instead of three
    hand-written branches, so values above 999 no longer fail with an
    UnboundLocalError.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"checkpoint number must be non-negative, got {n}")

    checkpoint_file = os.path.join(
        top_dir,
        str(n - 1),
        "saved_models",
        f"checkpoint_{n:06d}",
        "rllib_checkpoint.json",
    )
    self.trainer.restore(checkpoint_file)

# Keyword arguments for the validation environment.
# MultiAgentBasedModel.__init__ requires ``logger`` (it has no default), so the
# Logger instance ``log`` created in the training cell is passed here as well;
# without it MultiAgentBasedModel(**insurance_args) raises a TypeError.
insurance_args = {"sim_args": sim_args,
                "manager_args": manager_args,
                "broker_args": broker_args,
                "syndicate_args": syndicate_args,
                "reinsurancefirm_args": reinsurancefirm_args,
                "shareholder_args": shareholder_args,
                "risk_args": risk_args,
                "brokers": brokers,
                "syndicates": syndicates,
                "reinsurancefirms": reinsurancefirms,
                "shareholders": shareholders,
                "catastrophes": catastrophes,
                "broker_risks": broker_risks,
                "fair_market_premium": fair_market_premium,
                "risk_model_configs": risk_model_configs,
                "with_reinsurance": with_reinsurance,
                "num_risk_models": num_risk_models,
                "logger": log}

# Number of evaluation episodes and the rewards gathered in each of them:
# all_rewards[episode][agent_id] is the list of that agent's rewards, step by step.
validation_episodes = 1
all_rewards = {}

for epi in range(validation_episodes):
    env = MultiAgentBasedModel(**insurance_args)

    total_steps = 0
    terminated_dict = {"__all__": False}
    all_rewards[epi] = {}

    obs_dict, info_dict = env.reset()

    while not terminated_dict["__all__"]:
        # Progress marker every 20 steps
        if not total_steps % 20:
            print(".", end="")

        # NOTE(review): env.step() indexes its argument per new risk (action_dict[l].items()),
        # whereas the trainer is asked for actions per observation key - confirm the formats match.
        action_dict = trainer.compute_actions(obs_dict)
        total_steps += 1

        obs_dict, reward_dict, terminated_dict, flag_dict, info_dict = env.step(action_dict)
        for k, v in reward_dict.items():
            all_rewards[epi].setdefault(k, []).append(v)




In [None]:
### Main function: run the simulation. Two syndicates will be chosen to compete for the leader position.
from manager.ai_model.runner import AIRunner
from manager.game_model.runner import GameRunner

# Runner selection: 0 uses AIRunner, 1 uses GameRunner
model = 0
runner_arguments = (sim_args, manager_args, brokers, syndicates, reinsurancefirms, shareholders,
                    catastrophes, risk_model_configs, with_reinsurance, num_risk_models)
if model == 0:
    runner = AIRunner(*runner_arguments)
elif model == 1:
    runner = GameRunner(*runner_arguments)
runner.run()
