In [1]:
import mesa
import numpy as np
from scipy.stats import expon
from mesa.visualization.utils import update_counter
from mesa.visualization import SolaraViz, make_plot_component
import solara 
from matplotlib.figure import Figure

### Helper Functions

In [2]:
#Helper function for churn

def calculate_churn(model): 
    higher_out = 0
    higher_in = 0
    middle_in = 0
    middle_out = 0
    lower_out = 0
    lower_in = 0

    for agent in model.agents: 
        if agent.previous != agent.bracket: 
            if agent.previous == "Middle":
                middle_out += 1
                if agent.bracket == "Upper":
                    higher_in += 1
                elif agent.bracket == "Lower":
                    lower_in += 1
            if agent.previous == "Upper":
                higher_out += 1
                if agent.bracket == "Middle": 
                    middle_in += 1
                elif agent.bracket == "Lower":
                    lower_in += 1
            if agent.previous == "Lower":
                lower_out+=1
                if agent.bracket == "Upper":
                    higher_in += 1
                elif agent.bracket == "Middle":
                    middle_in += 1
    totals = model.agents.groupby("bracket").count()
    if "Upper" not in totals:
        totals["Upper"] = 0
    if "Middle" not in totals:
        totals["Middle"]=0
    if "Lower" not in totals: 
        totals["Lower"]=0
    return [higher_out, higher_in, middle_in, middle_out, lower_out, lower_in, totals]
                    
                
def number_to_words(n):
    # Define thresholds and corresponding words, adding "Trillion"
    thresholds = [
        (1_000_000_000_000, 'Trillion'),
        (1_000_000_000, 'Billion'),
        (1_000_000, 'Million'),
        (1_000, 'Thousand')
    ]
    
    # Loop through thresholds to determine appropriate scale
    for threshold, word in thresholds:
        if n >= threshold:
            value = n / threshold
            return f"{int(value)} {word}"
    
    # If less than 1000, just return the number as a string
    return str(n)   
    

### Visuals

In [5]:
@solara.component
def Histogram(model):
    update_counter.get() # This is required to update the counter
    # Note: you must initialize a figure using this method instead of
    # plt.figure(), for thread safety purpose
    fig = Figure()
    ax = fig.subplots()
    wealth_vals = [agent.wealth for agent in model.agents]
    # Note: you have to use Matplotlib's OOP API instead of plt.hist
    # because plt.hist is not thread-safe.
    ax.hist(wealth_vals, bins=10)
    solara.FigureMatplotlib(fig)

@solara.component
def Churn(model): 
    update_counter.get()

    fig = Figure()
    ax = fig.subplots()
    ax.axis('off')
    churn_details = calculate_churn(model)
    ax.text(0.2, 0.8, f"Upper Class: {churn_details[6]['Upper']} \nUpper In: {churn_details[1]} \nUpper Out: {churn_details[0]}", 
         va='center', ha='left', color="green", fontsize=15)
    ax.text(0.2,0.5,f"Middle Class {churn_details[6]['Middle']} \nMiddle In: {churn_details[2]} \nMiddle Out: {churn_details[3]}", 
             va='center', ha='left', color="orange", fontsize=15)
    ax.text(0.2, 0.2, f"Lower Class {churn_details[6]['Lower']} \nLower In: {churn_details[5]} \nLower Out: {churn_details[4]}", 
             va='center', ha='left', color="red", fontsize=15)
    solara.FigureMatplotlib(fig)
    
def compute_gini(model):
    agent_wealths = [abs(float(agent.wealth)) for agent in model.agents]
    x = sorted(agent_wealths)
    N = model.population
    B = sum(xi * (N - i) for i, xi in enumerate(x)) / (N * sum(x))
    return 1 + (1 / N) - 2 * B

@solara.component
def Wealth(model): 
    update_counter.get()
    fig = Figure()
    ax = fig.subplots()
    ax.axis('off')
    total_wealth = sum([agent.wealth for agent in model.agents])
    words = number_to_words(int(total_wealth))
    ax.text(0.0,0.8,f"Total Wealth Using {model.policy} Policy", 
             va='center', ha='left', color="black", fontsize=20)
    ax.text(0.5,0.5,f"{words}", 
             va='center', ha='left', color="black", fontsize=20)
    ax.text(0.0,0.2,f"{len(model.agents), sum([agent.wealth for agent in model.agents])}", 
             va='center', ha='left', color="black", fontsize=20)
    solara.FigureMatplotlib(fig)

def total_wealth(model): 
    return sum([agent.wealth for agent in model.agents])

### Agent

In [6]:
class WealthAgent(mesa.Agent):
    
    def __init__(self,model, proportion,innovation,party_elite):
        super().__init__(model)
        self.wealth=1
        self.party_elite = party_elite
        self.bracket = "Middle"
        self.previous = "Middle"
        self.W = proportion
        self.I = innovation
        self.ioriginal = innovation
        self.decay = 0
        
        
    def exchange(self): 

        return self.random.choice(self.model.agents)
    
    
    def step(self):
        
        """
                                PAYDAY
        """
        self.previous = self.bracket
        count = 0
        #increase wealth by proportion - payday
        self.wealth += (self.W*self.wealth)
        
        
        if self.wealth > model.survival_cost and self.wealth > 0: 
            self.wealth -= model.survival_cost
        else: 
            self.wealth -= self.wealth
        
        if model.policy=="fascist": 
            party_elites = self.model.agents.select(lambda a: a.party_elite==True)
            #pay tax to the party_elite
            party_elite = self.random.choice(party_elites)
            party_elite.wealth += self.wealth*.05
            self.wealth -= self.wealth*.05
                        
        exchange_agent = self.random.choice(self.model.agents)
        
        if self.wealth >= 0 and exchange_agent is not None and exchange_agent is not self:
            exchange_agent.wealth += (exchange_agent.W*self.wealth)
            self.wealth -= (exchange_agent.W*self.wealth)      

        #calculate bracket
        if self.wealth < model.brackets[0]:
            self.bracket = "Lower"
        elif self.wealth > model.brackets[1]:
            self.bracket = "Upper"
        else: 
            self.bracket = "Middle"
        
        '''
        
                            INNOVATION
      
           
        if self.model.innovation==True: 
            if self.wealth > self.model.total*model.threshold and self.I > 1.0: 
                #increase payday by innovation
                self.W*=self.I
                #Value of innovation decreases over time
                self.I-=self.decay #starts at 0
                #increase decay for next step 
                self.decay+=0.01
            else: 
                self.decay = 0 
                self.I = self.ioriginal
        '''  
       

### Model

In [16]:
def calc_brackets(model): 
        most = max([agent.wealth for agent in model.agents])
        most += most*.1
        #print(most)
        return [int(most/3), int((most/3)*2)]
    
class WealthModel(mesa.Model): 
    
    def __init__(self, policy="fascist", population=200, seed=None):
        super().__init__(seed=seed)
        self.policy = policy
        self.population = population
        self.party_elite = None
        self.total = 200
        self.survival_cost = 0
        self.each_wealth = 0
        self.brackets = [0.75,1.25]
        
    
        #self.schedule = mesa.time.RandomActivation(self)
        self.datacollector = mesa.DataCollector(model_reporters = {"Gini": compute_gini,"Total": total_wealth },
                                               agent_reporters={"Wealth":"wealth", "Bracket":"bracket","Pay":"W" })
        
        
        mean = 0.2      # mean of the distribution
        sigma = 0.05     # original sigma, so variance is 2*sigma^2
        variance = 2 * sigma**2

        # Generate 200 data points from the Gaussian distribution
        payday_array = self.rng.normal(mean, np.sqrt(variance), self.population)
        
        innovation_array = self.rng.normal(loc=1.05,
                                          scale=0.01,
                                          size=self.population)
        # round array to two decimals
        payday_array = np.around(payday_array, decimals=2)
        party_elite_cut= np.percentile(payday_array, 95)
        
        innovation_array = np.around(innovation_array, decimals=2)
        
        for idx in range(self.population):
            party_elite = False
            if payday_array[idx] >= party_elite_cut: 
                party_elite = True
            WealthAgent(self, float(payday_array[idx]), float(innovation_array[idx]), party_elite)
    
    def step(self):
        self.brackets = calc_brackets(self)
        self.total = total_wealth(self)
        print(self.total)
        exp_scale = np.mean([agent.wealth for agent in self.agents])
        self.survival_cost = expon.ppf(0.1, scale=exp_scale)
        
        # party_elites can only receive from subordinates but never give money
        if self.policy == "fascist": 
            subordinates = self.agents.select(lambda a: a.party_elite==False)
            subordinates.shuffle_do("step")
        # Divide the wealth equally among all agents at the beginning of the time step
        elif self.policy == "communist": 
            each_wealth = self.total/self.population
            for agent in self.agents: 
                agent.wealth = each_wealth
            self.agents.shuffle_do("step")
        else: 
             self.agents.shuffle_do("step")
        
        self.datacollector.collect(self)
       
       
        
       

### Headless Run

In [17]:
model = WealthModel(policy="fascist")
for step in range(100):
    model.step()
    #print(compute_gini(model))
    
output = model.datacollector.get_agent_vars_dataframe()
output.to_csv("inequality_output.csv")

output2 = model.datacollector.get_model_vars_dataframe()
output2.to_csv("model_output.csv")

200
221.70274396692386
244.86931684354323
267.9975723081733
291.982903560039
314.5394573917128
338.2200375654276
361.0193504149487
382.66234362931675
405.7503780504264
428.2296475673336
451.19449972944983
474.0205513093531
496.50471013376085
517.6092600163968
539.263716633236
560.4425518434967
581.9983173706839
599.530472707043
621.317008297952
643.4753401473371
660.1841015325612
677.5116154557816
696.847070435903
717.0459529156584
733.8557705513296
751.6041510730765
769.3838129986887
785.1468723401748
800.9401276391934
815.8967047901134
833.5145090916708
848.1905967212904
863.9308035250276
875.1915879712382
887.0712902233347
898.6018097165756
907.8600450728535
918.1725464926684
927.2181000951646
936.746734850991
942.9938438389602
945.5707637812637
949.3876166966093
948.5454388357125
945.3847916083594
945.0848507785565
941.6141981105617
938.8015032578335
937.379469282065
931.4721001317811
929.4297530763141
927.0773678487107
926.0181433494744
925.9499108613787
925.3229770544801
923.0761

### Dashboard

In [18]:
model = WealthModel()
model_params = {
    "policy": {
        "type": "Select",
        "value": "fascist",
        "values": ["econophysics", "fascist","communist"],
        "text": "Select Policy"
    }
}

wealth_plot = make_plot_component("Gini")

dash = SolaraViz(
    model, 
    components=[Churn, wealth_plot, Wealth, Histogram],
    model_params=model_params,
)

dash

In [12]:
model.steps

10