
# Linear Regression with Embeddings



In [21]:

# Fetch the course helper module (hm_helper.py) from the SSAI GitHub repo so
# the `import hm_helper` below works in a fresh Colab runtime. The download
# replaces any local copy, so every run picks up the latest published version.
HELPER_URL = (
    "https://raw.githubusercontent.com/ucla-anderson-SSAI/SSAI/main/"
    "hm_helper.py"
)

import os, importlib, sys, subprocess

def _wget(url, out):
    cmd = ["bash", "-lc", f'wget -q -O "{out}" "{url}"']
    r = subprocess.run(cmd)
    if r.returncode != 0:
        raise RuntimeError(f"Download failed for: {url}")

_wget(HELPER_URL, "hm_helper.py")

# hm_helper.py may have been created after the interpreter started; clear the
# import system's cached directory listings so the new file is found
# (the importlib docs recommend this for dynamically created modules).
importlib.invalidate_caches()

if "hm_helper" in sys.modules:
    # Cell is being re-run: re-execute the freshly downloaded source so later
    # cells see the updated functions.
    hm_helper = importlib.reload(sys.modules["hm_helper"])
else:
    # First run: a plain import executes the module exactly once (an
    # unconditional import followed by reload would execute it twice).
    import hm_helper

print("✅ Helper ready:", hm_helper.__name__)


✅ Helper ready: hm_helper


In [23]:
# ==========================================================
# H&M SALES EMBEDDINGS ABLATION — COLAB VERSION (FULL CELL)
# ==========================================================
# Three models are fit with hm_helper.fit_lasso on the same train/test split:
#   A: temporal features only
#   B: A + image embeddings
#   C: B + embedding interactions with month and age

# --- CONFIG ---
# Sales data, read straight from the course GitHub repo.
CSV_PATH = (
    "https://raw.githubusercontent.com/ucla-anderson-SSAI/SSAI/main/"
    "HandMSales.csv"
)

# Optional product-type filter, e.g. "Hoodies"; leave as None for no filter.
PRODUCT_TYPE_FILTER = None

# Product sampling: these are passed positionally to load_and_prepare.
N_PRODUCTS, SAMPLE, RANDOM_STATE = 500, "random", 42

# Passed as limit_embed_dims to design_matrices for models B and C.
LIMIT_EMBED_DIMS = 8

# True = faster dev runs: models B and C are built with include_embeddings=False.
DEV_SKIP_EMBED = False

# --- IMPORTS ---
from hm_helper import load_and_prepare, design_matrices, fit_lasso

# --- LOAD DATA ---
# Load the sales data and obtain the train/test split from hm_helper.
print("[INFO] Preparing data…")
train, test = load_and_prepare(
    CSV_PATH,
    PRODUCT_TYPE_FILTER,
    N_PRODUCTS,
    SAMPLE,
    RANDOM_STATE,
)

# Sanity-check the split before spending minutes on embeddings and model fits.
# An empty side would make the month summary below and the test metrics
# meaningless. The models are evaluated on later months, so every test month
# must come strictly after every train month; otherwise future information
# leaks into training.
assert len(train) > 0 and len(test) > 0, (
    f"load_and_prepare returned an empty split (train={len(train)}, test={len(test)})"
)
assert train["month_ts"].max() < test["month_ts"].min(), (
    "train and test months overlap — the temporal holdout is not clean"
)

print(f"[INFO] Train months: {train['month_ts'].min().date()} → {train['month_ts'].max().date()} | rows: {len(train)}")
print(f"[INFO] Test  months: {test['month_ts'].min().date()}  → {test['month_ts'].max().date()}  | rows: {len(test)}")

# ==========================================================
# SHARED FEATURE SETTINGS FOR MODELS A / B / C
# ==========================================================
# The three models differ only in whether embeddings are included and whether
# they are interacted with month/age. Keeping the common keyword arguments in
# one place keeps the ablation comparable: a change to e.g. the channel
# encoding or the embedding model now applies to every model at once instead
# of being edited in three copies.
EMBED_MODEL_NAME = "openai/clip-vit-base-patch32"

BASE_FEATURE_KWARGS = dict(
    include_numeric=True,
    include_month_ohe=True,
    include_channel_ohe=True,
    include_meta_ohe=False,
)
EMBED_FEATURE_KWARGS = dict(
    BASE_FEATURE_KWARGS,
    include_embeddings=(not DEV_SKIP_EMBED),
    limit_embed_dims=LIMIT_EMBED_DIMS,
    inter_channel=False,
    model_name=EMBED_MODEL_NAME,
    dev_skip_embed=DEV_SKIP_EMBED,
)

# With DEV_SKIP_EMBED=True, models B and C are built with
# include_embeddings=False; tag their labels so results are not misread.
EMBED_LABEL_TAG = " [embeddings skipped: DEV_SKIP_EMBED=True]" if DEV_SKIP_EMBED else ""


def _check_aligned(name, X_tr, X_te):
    """Fail fast if a design-matrix pair does not line up with Model A's targets.

    Models B and C discard their own targets (``*_``) and are scored against
    ``y_tr_log`` / ``y_te`` from Model A. That is only valid if every design
    matrix has (at least) the same number of train and test rows.
    NOTE(review): assumes X exposes ``.shape`` (numpy / pandas / scipy.sparse).
    """
    assert X_tr.shape[0] == len(y_tr_log), (
        f"{name}: {X_tr.shape[0]} train rows vs {len(y_tr_log)} train targets"
    )
    assert X_te.shape[0] == len(y_te), (
        f"{name}: {X_te.shape[0]} test rows vs {len(y_te)} test targets"
    )


# ==========================================================
# MODEL A — TEMPORAL FEATURES ONLY
# ==========================================================
XA_tr, XA_te, y_tr, y_te, y_tr_log = design_matrices(
    train, test,
    **BASE_FEATURE_KWARGS,
    include_embeddings=False,
    dev_skip_embed=False,
)
_check_aligned("Model A", XA_tr, XA_te)
metrics_A = fit_lasso(XA_tr, y_tr_log, XA_te, y_te, "Model A (Temporal)")

# ==========================================================
# MODEL B — TEMPORAL + EMBEDDINGS
# ==========================================================
XB_tr, XB_te, *_ = design_matrices(
    train, test,
    **EMBED_FEATURE_KWARGS,
    inter_month=False,
    inter_age=False,
)
_check_aligned("Model B", XB_tr, XB_te)
metrics_B = fit_lasso(
    XB_tr, y_tr_log, XB_te, y_te,
    "Model B (Temporal + Embeddings)" + EMBED_LABEL_TAG,
)

# ==========================================================
# MODEL C — TEMPORAL + EMBEDDINGS × SEASONALITY/LIFECYCLE
# ==========================================================
XC_tr, XC_te, *_ = design_matrices(
    train, test,
    **EMBED_FEATURE_KWARGS,
    inter_month=True,
    inter_age=True,
)
_check_aligned("Model C", XC_tr, XC_te)
metrics_C = fit_lasso(
    XC_tr, y_tr_log, XC_te, y_te,
    "Model C (Temporal + Embeddings × Seasonality/Lifecycle)" + EMBED_LABEL_TAG,
)

# ==========================================================
# RESULTS SUMMARY
# ==========================================================
print("\n=== RESULTS (R² / RMSE / MAE) ===")
for label, m in (("A", metrics_A), ("B", metrics_B), ("C", metrics_C)):
    print(f"{label}: R²={m['r2']:.3f}  RMSE={m['rmse']:.3f}  MAE={m['mae']:.3f}")


[INFO] Preparing data…
[START] Loading https://raw.githubusercontent.com/ucla-anderson-SSAI/SSAI/main/HandMSales.csv
(array([ 0.00647684,  0.33084482, -0.01330697,  0.05791098, -0.15775721,
        0.12472432, -0.00963638, -0.00480747, -0.17166381,  0.        ,
       -0.02515709,  0.92013829, -0.11002552, -0.1074554 ,  0.        ,
       -0.21114879, -0.00273071,  0.11841269,  0.37532112, -0.04023559,
        0.        ,  0.01934152,  0.00350349, -0.04791826, -0.03170584,
       -0.08502722,  0.00273279,  0.07180293,  0.07864282, -0.00510369,
       -0.10472226, -0.06084039,  0.01954619,  0.017426  ,  0.03751721,
        0.05509439, -0.        ]), 4.902500711523317, 4.907817444327576, 223)
(array([ 0.00663947,  0.33093531, -0.01332861,  0.05809908, -0.15776836,
        0.12473676, -0.0097373 , -0.00489145, -0.17174075,  0.        ,
       -0.02529216,  0.91978342, -0.1103841 , -0.10788956,  0.        ,
       -0.21135357, -0.00271548,  0.11856975,  0.37654009, -0.04044552,
        0. 

  df = pd.read_csv(CSV_PATH, dtype={"article_id":"string"}, parse_dates=[parse_col])


[DONE ] Loaded in 3.74s
[START] Cleaning
[DONE ] Cleaned in 0.11s
[INFO] Train months: 2018-12-01 → 2020-07-01 | rows: 2858
[INFO] Test  months: 2020-08-01  → 2020-09-01  | rows: 153
(array([-0.        ,  0.10896379,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  0.99555767,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , -0.        ,  0.        , -0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.        , -0.        ,
       -0.        , -0.        , -0.        ,  0.        ,  0.        ,
        0.        , -0.        ]), 0.06431076365652189, 0.464723030995497, 14)
(array([ 0.        ,  0.04903679,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  0.990315  ,  0.        ,  0.        ,  0.        ,
        0.      

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 2 concurrent workers.


(array([ 0.        ,  0.12811368,  0.        ,  0.        , -0.04578227,
        0.02015636,  0.        , -0.        ,  0.        , -0.        ,
        0.        ,  1.00700515,  0.        ,  0.        ,  0.        ,
        0.        , -0.        , -0.        ,  0.00790204,  0.        ,
        0.        , -0.        ,  0.        , -0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.        , -0.        ,
       -0.        , -0.        , -0.        ,  0.        ,  0.        ,
        0.        , -0.        ]), 0.14806813277095898, 0.4143602038207337, 10)
(array([-0.        ,  0.13452233,  0.        ,  0.        , -0.06215983,
        0.01646188,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  1.0220928 ,  0.        ,  0.00389591,  0.00248494,
        0.        ,  0.        , -0.        ,  0.00849453,  0.        ,
        0.        , -0.        ,  0.        , -0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.     

[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   11.5s


(array([-0.0097503 ,  0.17548717, -0.        ,  0.        , -0.11301803,
        0.07182067, -0.        , -0.        , -0.        , -0.        ,
        0.        ,  1.00167969,  0.        ,  0.00611427,  0.        ,
       -0.00121838, -0.        , -0.        ,  0.        ,  0.        ,
        0.        , -0.        ,  0.        , -0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
       -0.        , -0.        , -0.        ,  0.        , -0.        ,
        0.        , -0.        ]), 0.021225790124390187, 0.424105212685365, 14)
(array([-1.19665027e-02,  1.80858022e-01, -0.00000000e+00,  0.00000000e+00,
       -1.15925144e-01,  7.53614038e-02, -0.00000000e+00, -0.00000000e+00,
       -0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  1.00220159e+00,
        0.00000000e+00,  1.01047091e-02,  0.00000000e+00, -1.09048138e-02,
       -0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -0.00000000e+00

[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:   33.4s remaining:   22.2s


(array([ 0.        ,  0.1143149 ,  0.        ,  0.        , -0.01787853,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  1.02137745,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , -0.        , -0.        ,  0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.        , -0.        ,
       -0.        , -0.        , -0.        ,  0.        ,  0.        ,
        0.        , -0.        ]), 0.018226880996508044, 0.4817550786645079, 7)
(array([ 0.        ,  0.11793173,  0.        ,  0.        , -0.03130894,
        0.010051  ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  1.0256188 ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        , -0.        , -0.        ,  0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.     

[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:   48.1s finished


(array([ 0.0077277 ,  0.33154082, -0.01347345,  0.05935783, -0.15784301,
        0.12481996, -0.01041267, -0.00545339, -0.17225559,  0.        ,
       -0.02619597,  0.91740872, -0.11278369, -0.11079483,  0.        ,
       -0.21272389, -0.0026136 ,  0.11962075,  0.38469713, -0.04185032,
        0.        ,  0.01940695,  0.00356879, -0.04824703, -0.03203916,
       -0.08544566,  0.00302161,  0.07199092,  0.07884762, -0.00564026,
       -0.10537991, -0.06160144,  0.02013505,  0.01813466,  0.03775868,
        0.05520713, -0.        ]), 4.83124061849594, 4.907817444327576, 223)
(array([ 0.007753  ,  0.3315549 , -0.01347681,  0.0593871 , -0.15784474,
        0.1248219 , -0.01042837, -0.00546645, -0.17226755,  0.        ,
       -0.02621699,  0.91735351, -0.11283948, -0.11086237,  0.        ,
       -0.21275575, -0.00261123,  0.11964518,  0.38488676, -0.04188298,
        0.        ,  0.01940828,  0.00357011, -0.04825368, -0.0320459 ,
       -0.08545412,  0.00302745,  0.07199473,  0.07885176

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 2 concurrent workers.


(array([ 0.        ,  0.04903679,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  0.990315  ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , -0.        ,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
       -0.        , -0.        , -0.        , -0.        ,  0.        ,
        0.        ,  0.        , -0.        , -0.        , -0.        ,
       -0.        ,  0.        ,  0.        ,  0.        , -0.        ]), 0.0026400528042813676, 0.4143602038207337, 15)
(array([-0.        ,  0.10896379,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  0.99555767,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.   

[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    7.1s


(array([-0.00000000e+00,  1.12579448e-01,  0.00000000e+00,  0.00000000e+00,
       -8.90944820e-04,  0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  9.56655968e-01,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
       -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
       -0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -0.00000000e+00, -0.00000000e+00,
       -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
       -0.00000000e+00]), 0.018867417582669077, 0.424105212685365, 14)
(array([-0.        ,  0.12075797,  0.        ,  0.        , -0.01187548,
        0.        ,  0.       

[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:   21.1s remaining:   14.1s


(array([ 0.        ,  0.32969055, -0.        ,  0.        , -0.14097994,
        0.09823051,  0.00882327, -0.01076733, -0.12681606, -0.01204299,
        0.00479473,  0.98945194,  0.        ,  0.01171296,  0.        ,
       -0.02903071, -0.        , -0.        ,  0.        ,  0.        ,
        0.04841601,  0.        , -0.00259492,  0.02985523, -0.        ,
        0.03158367,  0.01277629,  0.00126875,  0.01389363,  0.        ,
       -0.01900908, -0.        , -0.08677038, -0.04863736,  0.01933486,
        0.0530871 ,  0.05439889, -0.        , -0.05891353, -0.        ,
       -0.        ,  0.        ,  0.        ,  0.04053965, -0.        ]), 0.44168392436426984, 0.4817550786645079, 32)
(array([ 0.        ,  0.34121748, -0.00211857,  0.        , -0.14175219,
        0.09809826,  0.01161981, -0.01270106, -0.13407003, -0.01300375,
        0.00618446,  0.98616599,  0.        ,  0.01296947,  0.        ,
       -0.03076988, -0.        , -0.        ,  0.        ,  0.        ,
        0.04949

[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:   28.8s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 2 concurrent workers.


(array([ 0.        ,  0.04903679,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
        0.        ,  0.990315  ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , -0.        ,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        ,  0.        ,  0.        ,
       -0.        , -0.        , -0.        , -0.        ,  0.        ,
       -0.        ,  0.        ,  0.        , -0.        ,  0.        ,
        0.        , -0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        , -0.        , -0.        , -0.        ,
        0.        ,  0.        ,  0.        , -0.        , -0.        ,
       -0.        , -0.        ,  0.        ,  0.        ,  0.        ,
       -0.        ]), 0.0026400528042813676, 0.4143602038207337, 15)
(array([-0.        ,  0.10896379,  0.        ,  0.        , -0.   

[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   12.1s


(array([-0.00000000e+00,  2.51518046e-01, -0.00000000e+00,  0.00000000e+00,
       -1.24032626e-01,  7.78347450e-02, -0.00000000e+00, -0.00000000e+00,
       -4.47150168e-02, -0.00000000e+00,  0.00000000e+00,  9.74648198e-01,
        0.00000000e+00,  1.64385968e-02,  0.00000000e+00, -2.32084894e-02,
       -0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -0.00000000e+00, -0.00000000e+00,  4.25423029e-02,
       -1.50189305e-03,  3.18524045e-02, -1.78733007e-02,  4.44802622e-03,
        1.40277382e-02,  1.96818348e-02, -2.05972824e-02, -4.08424957e-02,
       -4.16454649e-02,  0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, -0.00000000e+00, -7.23759355e-03,  0.00000000e+00,
        3.39608728e-02, -0.00000000e+00, -1.16260127e-02, -0.00000000e+00,
       -0.00000000e+00, -0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00, 

[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:   31.6s remaining:   21.0s


(array([-0.00000000e+00,  2.45033393e-01, -0.00000000e+00,  0.00000000e+00,
       -1.34575979e-01,  9.46891184e-02,  0.00000000e+00, -6.07223870e-03,
       -7.06212563e-02, -0.00000000e+00,  0.00000000e+00,  1.01413328e+00,
        0.00000000e+00,  4.24900756e-04,  0.00000000e+00, -0.00000000e+00,
       -0.00000000e+00, -1.38177724e-02,  0.00000000e+00,  0.00000000e+00,
        8.65364710e-03,  0.00000000e+00, -0.00000000e+00,  2.53631670e-02,
        0.00000000e+00,  2.28133477e-02,  8.73948956e-03,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00, -8.29971438e-03, -0.00000000e+00,
       -1.39318940e-02, -3.33181880e-03,  8.83565762e-03, -4.98350054e-03,
        1.15078865e-02, -0.00000000e+00,  0.00000000e+00, -0.00000000e+00,
        1.32937503e-03,  0.00000000e+00, -9.82795866e-03,  0.00000000e+00,
        2.39902416e-02, -0.00000000e+00, -9.24522168e-03, -0.00000000e+00,
       -2.86580938e-02, -0.00000000e+00,  0.00000000e+00,  2.23371189e-02,
        2.33120270e-02, 

[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:   43.1s finished
