In [None]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Seed NumPy's global RNG once, up front: the cells below build their frames
# with np.random.rand / np.random.randn, so without this every
# Restart & Run All would produce different data and different plots.
SEED = 42
np.random.seed(SEED)

# Business-day calendar used as the row index of every frame.
dates = pd.bdate_range('2002', '2020')
n_dates = len(dates)
# Number of columns (cross-section width) of every frame.
xs_len = 4000

In [None]:
# Boolean mask, one row per date: a cell is True when its uniform draw
# exceeds 0.75.
membership = pd.DataFrame(np.random.rand(n_dates, xs_len) > 0.75, index=dates)
# Number of True cells per date.
membership.sum(axis=1).plot()

# Standard-normal draws, kept only where the mask is True (NaN elsewhere).
weights = pd.DataFrame(np.random.randn(n_dates, xs_len), index=dates).where(membership)
# Number of non-missing weights per date.
weights.count(axis=1).plot()

In [None]:
%%time

def standard_scaler(series):
    """Standardize one Series with scikit-learn's ``StandardScaler``.

    Parameters
    ----------
    series : pd.Series
        Values to scale, e.g. one row handed over by
        ``DataFrame.apply(..., axis=1)``.

    Returns
    -------
    pd.Series
        The scaled values, labelled with the same index as ``series``.
    """
    # StandardScaler expects 2-D input, so present the values as a single
    # feature column of shape (n, 1).
    column = series.to_numpy().reshape(-1, 1)
    scaled = StandardScaler().fit_transform(column)
    # Re-attach the caller's labels. A bare pd.Series(...) would relabel the
    # result 0..n-1, which only lines up when the input index already is a
    # default RangeIndex.
    return pd.Series(scaled.reshape(-1), index=series.index)

# Per-date min / max / mean of the row-wise standardized weights.
(
    weights
    .apply(standard_scaler, axis=1)
    .agg(['min', 'max', 'mean'], axis=1)
    .plot()
)

In [None]:
%%time

def ranks_scaler(weights_frame):
    """Replace each row by its ranks, centred and scaled into [-1, 1].

    Every row is ranked independently (``rank(axis=1)``), shifted by its mean
    rank and divided by the largest absolute deviation from that mean, so the
    most extreme entry of a row lands on -1 or +1.

    Parameters
    ----------
    weights_frame : pd.DataFrame
        Frame whose rows are scaled one by one. Missing values receive no
        rank and stay missing in the output.

    Returns
    -------
    pd.DataFrame
        Frame of the same shape and labels holding the scaled ranks.
    """
    ranks = weights_frame.rank(axis=1)
    centered = ranks.sub(ranks.mean(axis=1), axis=0)
    half_range = centered.abs().max(axis=1)
    # A row with a single valid entry (or only ties) has zero spread. Dividing
    # by 1 there returns 0 for its valid entries instead of the NaN that a
    # 1/0 scale factor would produce. All-missing rows stay all-missing.
    safe_half_range = half_range.where(half_range > 0, 1.0)
    return centered.div(safe_half_range, axis=0)

# Per-date min / max / mean of the rank-scaled weights.
(
    ranks_scaler(weights)
    .agg(['min', 'max', 'mean'], axis=1)
    .plot()
)

In [None]:
# Standardize every date's row and keep the result under the same name,
# because the cells below read `weights`.
weights = weights.apply(standard_scaler, axis='columns')
# Number of non-missing weights per date after scaling.
weights.count(axis='columns').plot()

In [None]:
# Standard-normal draws with the same shape and date index as the weights.
betas = pd.DataFrame(np.random.randn(n_dates, xs_len), index=dates)

In [None]:
%%time

def proj_hyperplane(weights, betas):
    """Remove from ``weights`` its component along ``betas``.

    Computes ``weights - (weights . betas) / (betas . betas) * betas``, so
    the result has a zero dot product with ``betas``.

    Parameters
    ----------
    weights, betas : 1-D array-like of equal length
        Vectors accepted by ``np.dot`` (NumPy arrays or pandas Series).

    Returns
    -------
    Same kind of object as ``weights - betas`` would produce.

    Raises
    ------
    AssertionError
        When the squared norm of ``betas`` is not above 1e-6.
    """
    beta_norm_sq = np.dot(betas, betas)
    # Guard the division below against a (near-)zero direction vector.
    assert beta_norm_sq > 1e-6, 'betas is too close to 0'
    coefficient = np.dot(weights, betas) / beta_norm_sq
    return weights - coefficient * betas

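# Disabled scalar sanity check: project a random weight vector and print its
# dot product with betas, which should come out ~0.
# NOTE: it rebinds the names `weights` and `betas`, so uncommenting it as-is
# overwrites the DataFrames built in the cells above.
# betas = np.random.randn(4)
# weights = np.abs(np.random.randn(4))
# weights /= sum(weights)
# proj = proj_hyperplane(weights, betas)
# print(np.dot(proj, betas))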

# DataFrame.combine pairs the two frames column by column, so both are
# transposed to turn each date into a column; proj_hyperplane then receives
# one date's weights and betas at a time. Missing weights and betas outside
# the membership mask are set to 0 so they do not contribute to the dot
# products, and the mask is re-applied after transposing back.
neutralized = weights.fillna(0).transpose().combine(
    betas.where(membership).fillna(0).transpose(),
    proj_hyperplane,
).transpose().where(membership)
# Number of non-missing neutralized weights per date.
neutralized.count(axis=1).plot()

In [None]:
# Per-date dot product of the neutralized weights with betas (missing cells
# are skipped by the sum).
(neutralized * betas).sum(axis=1).plot()

In [None]:
# Per-date min / max / mean of the neutralized weights.
(
    neutralized
    .agg(['min', 'max', 'mean'], axis=1)
    .plot()
)