diff --git a/src/empyrial/main.py b/src/empyrial/main.py index 4d68fa18..8b132c09 100644 --- a/src/empyrial/main.py +++ b/src/empyrial/main.py @@ -122,19 +122,25 @@ def get_returns(stocks, wts, start_date, end_date=TODAY): if len(stocks) > 1: assets = yf.download(stocks, start=start_date, end=end_date, progress=False)["Adj Close"] assets = assets.filter(stocks) - ret_data = assets.pct_change()[1:] - returns = (ret_data * wts).sum(axis=1) + initial_alloc = wts/assets.iloc[0] + if initial_alloc.isna().any(): + raise ValueError("Some stock is not available at initial state!") + portfolio_value = (assets * initial_alloc).sum(axis=1) + returns = portfolio_value.pct_change()[1:] return returns else: df = yf.download(stocks, start=start_date, end=end_date, progress=False)["Adj Close"] df = pd.DataFrame(df) - returns = df.pct_change() + returns = df.pct_change()[1:] return returns def get_returns_from_data(data, wts): - ret_data = data.pct_change()[1:] - returns = (ret_data * wts).sum(axis=1) + initial_alloc = wts/data.iloc[0] + if initial_alloc.isna().any(): + raise ValueError("Some stock is not available at initial state!") + portfolio_value = (data * initial_alloc).sum(axis=1) + returns = portfolio_value.pct_change()[1:] return returns