In [None]:
## Set up Google Colab environment
try:
    from google.colab import drive
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    import os
    # Link your Goolge Drive to Goolge Colab
    drive.mount('/content/gdrive')
    %cd '/content/gdrive/My Drive'
    # Create a folder to store the project
    if 'Project' in os.listdir():
        %cd '/content/gdrive/My Drive/Project'
    else:
        ! mkdir '/content/gdrive/My Drive/Project/'
        %cd '/content/gdrive/My Drive/Project/'
    # Clone GitHub repo to the folder and change working directory to the repo
    if 'demand_asset_pricing' not in os.listdir():
        ! git clone https://github.com/hanxuh-hub/demand_asset_pricing.git
    %cd '/content/gdrive/My Drive/Project/demand_asset_pricing'

## Load packages
import pickle
import numpy as np
import pandas as pd
from numba import jit
print('#----------Successfully Loaded Python Packages----------#')

In [None]:
def _build_data_from_panel(mdata):
    """Reshape a holding-level panel into per-stock / per-manager arrays.

    Parameters
    ----------
    mdata : 2-D array
        One row per (manager, stock) pair. Column mapping used here:
        0 -> manager id (vmgrno), 1 -> stock id (vpermno), 2 -> vLNme,
        3 -> vaum, 4 -> mweight, 5 -> voutweight, 6 -> vbME, 7 -> mlatent.

    Returns
    -------
    dict
        Same keys as ``load_data``. Ids are sorted ascending; matrices are
        (iNstocks, iNmgr) with 0 for pairs absent from the panel; per-manager
        and per-stock vectors are (n, 1) and taken from each id's first row.

    Raises
    ------
    ValueError
        If the same (manager, stock) pair appears in more than one row.
    """
    # Sorted unique ids, the row of each id's first occurrence (manager- and
    # stock-level columns are read from that row), and, for every row, the
    # position of its id in the sorted list.
    mgr_ids, mgr_first, mgr_pos = np.unique(
        mdata[:, 0], return_index=True, return_inverse=True)
    stock_ids, stock_first, stock_pos = np.unique(
        mdata[:, 1], return_index=True, return_inverse=True)
    mgr_pos = mgr_pos.reshape(-1)
    stock_pos = stock_pos.reshape(-1)

    iNmgr = len(mgr_ids)
    iNstocks = len(stock_ids)

    # Each matrix cell must come from exactly one row; otherwise it is ambiguous.
    pair_key = mgr_pos * iNstocks + stock_pos
    if np.unique(pair_key).size != pair_key.size:
        raise ValueError('duplicate (manager, stock) rows in the input panel')

    # Organize data - rows: stocks; columns: managers
    mweight = np.zeros((iNstocks, iNmgr))
    mlatent = np.zeros((iNstocks, iNmgr))
    mweight[stock_pos, mgr_pos] = mdata[:, 4]
    mlatent[stock_pos, mgr_pos] = mdata[:, 7]

    def first_rows(rows, col):
        # Column `col` at the given rows, as a float (n, 1) column vector
        return mdata[rows, col].astype(float).reshape(-1, 1)

    return {
        'iNmgr': iNmgr,
        'iNstocks': iNstocks,
        'mlatent': mlatent,
        'mweight': mweight,
        'vaum': first_rows(mgr_first, 3),
        'vbME': first_rows(mgr_first, 6),
        'vLNme': first_rows(stock_first, 2),
        'vmgrno': mgr_ids.reshape(-1, 1),
        'voutweight': first_rows(mgr_first, 5),
        'vpermno': stock_ids.reshape(-1, 1),
    }


def load_data(pkl_path='DataCounterfactual.pkl', npy_path='DataCounterfactual.npy'):
    """Load the counterfactual input data, building and caching it on first use.

    If ``pkl_path`` exists, its pickled dictionary is returned unchanged.
    Otherwise the raw panel in ``npy_path`` is reshaped (see
    ``_build_data_from_panel``), written to ``pkl_path`` and returned.

    Parameters
    ----------
    pkl_path : str
        Cache file holding the cleaned data dictionary.
    npy_path : str
        Raw holding-level panel saved with ``np.save``.

    Returns
    -------
    dict
        Keys: iNmgr, iNstocks, mlatent, mweight, vaum, vbME, vLNme, vmgrno,
        voutweight, vpermno.

    Raises
    ------
    FileNotFoundError
        If neither ``pkl_path`` nor ``npy_path`` exists.
    ValueError
        If the panel contains duplicate (manager, stock) rows.
    """
    # Fast path: reuse the dictionary cached by a previous run.
    # NOTE: pickle.load can execute arbitrary code -- only point pkl_path at a
    # cache this notebook wrote itself.
    try:
        with open(pkl_path, 'rb') as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        pass  # No cache yet: rebuild from the raw panel below.

    data = _build_data_from_panel(np.load(npy_path))

    # Export cleaned data to local drive so later runs take the fast path
    with open(pkl_path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return data

In [None]:
@jit
def one_step_update(vflow,mlatentLOOP,vLNmeLOOP,istep):
    """One iteration of the counterfactual price solver.

    Revalues each manager's AUM at the current log market equity iterate,
    recomputes portfolio weights from latent demand, aggregates demand per
    stock, and moves log market equity toward log demand by a damped step.

    Reads module-level globals vaum, mweight, vLNme, voutweight, vbME,
    iNstocks and iNmgr. NOTE: numba treats global arrays as compile-time
    constants, so their values are captured when this function is first
    compiled; later re-assignments are not seen without recompiling.

    Parameters
    ----------
    vflow : (iNmgr, 1) array
        AUM added to each manager (built by solve_for_price).
    mlatentLOOP : (iNstocks, iNmgr) array
        Latent demand used in this iteration.
    vLNmeLOOP : (iNstocks, 1) array
        Current log market equity iterate.
    istep : int
        Iteration counter.

    Returns
    -------
    vLNmeLOOP : (iNstocks, 1) array
        Updated log market equity.
    istep : int
        istep + 1.
    dgap : float
        max |log(demand) - vLNmeLOOP|, measured on the input iterate
        (i.e. before this step's update).
    """
    # update AUM
    # Holdings are revalued by exp(vLNmeLOOP - vLNme) relative to the baseline
    # vLNme, the voutweight share keeps its original value, then vflow is added.

    vaumLOOP      = vaum * (mweight.T @ np.exp(vLNmeLOOP - vLNme) + voutweight) + vflow

    # Update demand
    # Unnormalised weights: latent demand times exp(vLNmeLOOP * vbME'), where the
    # outer product gives a (stocks x managers) matrix.

    mweightD      = mlatentLOOP * np.exp(vLNmeLOOP @ vbME.T)
    # Divide each manager's column by (1 + its column sum).
    mweightLOOP   = mweightD / (np.ones((iNstocks,1)) @ (1 + np.sum(mweightD,0).reshape(1,-1)))

    # Demand per stock: sum over managers of weight * that manager's AUM.
    vdemand       = np.sum(mweightLOOP * (np.ones((iNstocks,1)) @ vaumLOOP.T),1).reshape(-1,1)

    # Per stock: sum over managers of vbME * AUM * w * (1 - w), divided by the
    # stock's total demand. NOTE(review): appears to be d log(demand) / d log(ME)
    # holding other stocks fixed -- confirm.
    vDdemand      = np.sum((np.ones((iNstocks,1)) @ (vbME * vaumLOOP).T) \
                         * mweightLOOP * (1 - mweightLOOP) / (vdemand @ np.ones((1,iNmgr))),1).reshape(-1,1)

    # Clip positive derivatives to 0, so the step multiplier 1/(1 - temp) lies
    # in (0, 1] and damps the update below.
    temp          = vDdemand.reshape(-1)
    temp[temp>0]  = 0       
    vDdemand      = (1./(1 - temp)).reshape(-1,1)
    
    
    # Check convergence
    # Gap between log demand and the current log market equity; its max
    # absolute value is the convergence measure tested by the caller.

    vgap          = np.log(vdemand) - vLNmeLOOP
    dgap          = np.max(np.abs(vgap))

    # Move log market equity toward log demand by the damped step.
    vLNmeLOOP     = vLNmeLOOP + vDdemand * vgap
    istep         = istep + 1     
    
    return vLNmeLOOP, istep, dgap


@jit
def solve_for_price(i):
    """Solve for log market equity after removing manager ``i``'s latent demand.

    Manager ``i``'s latent-demand column is zeroed. Its AUM is redistributed
    to the remaining managers in proportion to their AUM. Manager ``i`` is
    excluded from that redistribution, as is every manager whose id
    (vmgrno) equals 0, labelled "HH" here. one_step_update is then iterated
    until the largest absolute log gap is <= 1e-4 or 999 steps have been taken.

    Reads module-level globals vLNme, mlatent, iNstocks, vmgrno and vaum
    (numba captures global arrays as constants at first compile).

    Parameters
    ----------
    i : int
        Column index of the manager to remove.

    Returns
    -------
    1-D array of length iNstocks
        Counterfactual log market equity. It is returned as-is even if the
        step cap is reached before the tolerance is met.
    """
    # Initialize market caps at their observed values
    vLNmeLOOP        = vLNme.copy()

    # Set latent demand of manager i to zero
    mlatentLOOP      = mlatent.copy()
    mlatentLOOP[:,i] = np.zeros(iNstocks)

    # Compute the flow to other managers (adjust for HH)
    iHH              = np.where(vmgrno[:,0] == 0)[0]
    # Summing keeps the denominator a scalar for zero, one or several id-0
    # managers; using vaum[iHH] directly only broadcast for exactly one match.
    daumHH           = np.sum(vaum[iHH])
    daum_i           = vaum[i,0]

    vS               = vaum / (np.sum(vaum) - daum_i - daumHH)
    vS[i]            = 0
    vS[iHH]          = 0
    vflow            = daum_i * vS

    # Main loop: Solve for prices
    istep   = 1
    dgap    = 1.0

    while istep < 1e3 and dgap > 1e-4:
        vLNmeLOOP, istep, dgap = one_step_update(vflow,mlatentLOOP,vLNmeLOOP,istep)

    return vLNmeLOOP.reshape(-1)


def counterfactuals(n_managers=None):
    """Compute counterfactual log market equity for the removal of each manager.

    Column ``i`` of the result is solve_for_price(i) for i = 1, ..., stop - 1,
    where stop is iNmgr by default, or min(n_managers, iNmgr) when given.
    Column 0 and columns >= stop stay NaN. NOTE(review): index 0 is presumably
    the household sector (id 0 sorts first when ids are nonnegative) -- confirm.

    Reads module-level globals iNstocks and iNmgr.

    Parameters
    ----------
    n_managers : int or None
        Optional cap on the manager loop, for quick test runs. For example,
        2 computes only column 1.

    Returns
    -------
    (iNstocks, iNmgr) array
        Counterfactual log market equity; rows: stocks, columns: removed manager.
    """
    # np.NaN was removed in NumPy 2.0; np.nan works on every version
    mLNme_cf = np.full((iNstocks, iNmgr), np.nan)

    stop = iNmgr if n_managers is None else min(int(n_managers), iNmgr)

    for i in range(1, stop):
        # Progress indicator: fraction of the requested loop completed
        print(np.round((i + 1) / stop, 4))
        mLNme_cf[:, i] = solve_for_price(i)

    return mLNme_cf


In [None]:
# Load data (the raw panel is parsed and cached to DataCounterfactual.pkl on first run)
data = load_data()

# Set global variables read by the solver functions.
# NOTE: numba freezes global arrays into a @jit function when it is first
# compiled. If this cell is re-run with different data after the solver has
# run, restart the kernel before solving again.
iNmgr      = data['iNmgr']
iNstocks   = data['iNstocks']
mlatent    = data['mlatent']
mweight    = data['mweight']
vaum       = data['vaum']
vbME       = data['vbME']
vLNme      = data['vLNme']
vmgrno     = data['vmgrno']
voutweight = data['voutweight']
vpermno    = data['vpermno']

# Sanity-check layouts the solver's matrix algebra depends on; a mismatch
# would otherwise surface as an obscure error inside numba compilation.
assert mlatent.shape == (iNstocks, iNmgr), "mlatent must be (iNstocks, iNmgr)"
assert mweight.shape == (iNstocks, iNmgr), "mweight must be (iNstocks, iNmgr)"
assert vLNme.shape == (iNstocks, 1) and vpermno.shape == (iNstocks, 1), "stock vectors must be (iNstocks, 1)"
assert all(a.shape == (iNmgr, 1) for a in (vaum, vbME, vmgrno, voutweight)), "manager vectors must be (iNmgr, 1)"

In [None]:
# Run solver: column i of mLNme_cf holds the counterfactual log market equity
# when manager i's latent demand is removed.
# For a quick test, restrict the manager loop in counterfactuals() to the first
# two managers; then only the 2nd column (index 1) of mLNme_cf is computed.
mLNme_cf = counterfactuals()