In [1]:
import os
import torch
import gc
import logging

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()        
    torch.cuda.ipc_collect()


try:
    %run setup_paths
except:
    %run notebooks/setup_paths
    

logging.basicConfig(
    level=logging.INFO,  # or DEBUG, WARNING, etc.
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)

logging.info(f"current dir: {os.getcwd()}")

2025-08-11 21:15:54,353 - INFO - current dir: c:\Projects\scmsim


In [2]:
%run src/datasets
%run src/models

In [3]:
mlm = MovieLensData(MovieLens1MLoader('ml-1m'))
model = MatrixFactorization(mlm.num_users, mlm.num_items, 50)
model.load_state_dict(torch.load("models/MF20.1.weights", map_location='cpu'))

2025-08-11 21:15:54,746 - INFO - loading ratings


  model.load_state_dict(torch.load("models/MF20.1.weights", map_location='cpu'))


<All keys matched successfully>

In [4]:
probs = model.probablity_matrix()

In [5]:
#(torch.rand(probs.shape) < probs).sum()

In [10]:
%run src/mlsim
csdf = enrich_cause_indexes(pd.read_csv("products/MoviesCausalGPT.csv"), mlm.info)
## we do not support negative effect in simulation
csdf = csdf[csdf["causal_effect"] >= 0]
cmat = build_causal_matrix(csdf, mlm.num_items, factor=0.09)

## Geneate Data Samples

In [21]:
%run src/mlsim
for idx in range(10):
    logging.info(f"generating samples {idx}")
    watched, timestamps = generate_data(probs, cmat)
    df = create_pairs_df(watched, timestamps)
    df.to_csv(f"products/MFSim/samples.{idx}.csv", index=False)


2025-08-11 21:47:56,856 - INFO - generating samples 0
2025-08-11 21:47:58,065 - INFO - [0] - watched:0.98M; added:31733.0
2025-08-11 21:48:03,860 - INFO - Done
2025-08-11 21:48:06,199 - INFO - generating samples 1
2025-08-11 21:48:07,416 - INFO - [0] - watched:0.98M; added:31619.0
2025-08-11 21:48:10,934 - INFO - Done
2025-08-11 21:48:13,338 - INFO - generating samples 2
2025-08-11 21:48:14,620 - INFO - [0] - watched:0.98M; added:31747.0
2025-08-11 21:48:18,981 - INFO - Done
2025-08-11 21:48:21,418 - INFO - generating samples 3
2025-08-11 21:48:22,725 - INFO - [0] - watched:0.98M; added:31696.0
2025-08-11 21:48:28,044 - INFO - Done
2025-08-11 21:48:30,476 - INFO - generating samples 4
2025-08-11 21:48:31,844 - INFO - [0] - watched:0.98M; added:31539.0
2025-08-11 21:48:36,323 - INFO - Done
2025-08-11 21:48:38,654 - INFO - generating samples 5
2025-08-11 21:48:39,959 - INFO - [0] - watched:0.98M; added:31811.0
2025-08-11 21:48:43,543 - INFO - Done
2025-08-11 21:48:46,113 - INFO - generat

## Generate Ground Truth

In [32]:

selected_causes = list(set(csdf[csdf["causal_effect"] > 0]["treatment_idx"]))
len(selected_causes)

451

In [51]:
%run src/mlsim
for idx in [0]:
    gtdf = generate_ground_truth_estimate(probs, cmat, selected_causes)
    gtdf.to_csv(f"products/MFSim/gt.{idx}.csv", index=False)

2025-08-11 22:26:54,263 - INFO - [0] - watched:0.98M; added:31695.0
2025-08-11 22:26:59,609 - INFO - Done
2025-08-11 22:27:01,234 - INFO - [0] - watched:0.99M; added:34182.0
2025-08-11 22:27:06,422 - INFO - Done
2025-08-11 22:27:06,455 - INFO - [0] evaluated cause: 2; max-ate:0.1612582802772522
2025-08-11 22:27:07,945 - INFO - [0] - watched:0.98M; added:31120.0
2025-08-11 22:27:13,996 - INFO - Done
2025-08-11 22:27:15,462 - INFO - [0] - watched:0.99M; added:37855.0
2025-08-11 22:27:20,524 - INFO - Done
2025-08-11 22:27:20,557 - INFO - [1] evaluated cause: 6; max-ate:0.17599338293075562
2025-08-11 22:27:21,996 - INFO - [0] - watched:0.98M; added:31803.0
2025-08-11 22:27:27,006 - INFO - Done
2025-08-11 22:27:28,465 - INFO - [0] - watched:0.98M; added:32256.0
2025-08-11 22:27:33,465 - INFO - Done
2025-08-11 22:27:33,493 - INFO - [2] evaluated cause: 7; max-ate:0.08509933948516846
2025-08-11 22:27:34,956 - INFO - [0] - watched:0.98M; added:31882.0
2025-08-11 22:27:39,967 - INFO - Done
2025

In [56]:
gtdf[(0.1< gtdf["ate"]) & (0.99 > gtdf["ate"]) ]

Unnamed: 0,treatment_idx,resp_idx,ate
1016,2,1017,0.161258
4247,6,296,0.148179
5065,6,1114,0.165397
5130,6,1179,0.175662
5469,6,1518,0.158609
...,...,...,...
1766816,1779,273,0.171358
1768374,1779,1831,0.159603
1775906,1810,1459,0.155960
1779231,1833,832,0.147185


## manual check


In [38]:
csdf[(csdf["causal_effect"] > 0) & csdf["treatment_idx"].isin([2,6,7,10])]

Unnamed: 0,idx,treatment_title,resp_title,sate,causal_effect,explanation,treatment_idx,resp_idx
125,767,Heat (1995),Mission: Impossible (1996),-0.305805,1,Both are high-stakes crime/action films from t...,6,648
155,271,Jumanji (1995),Mission: Impossible (1996),-0.394449,1,Both are mid-90s mainstream adventure films. W...,2,648
257,577,GoldenEye (1995),Mission: Impossible (1996),-0.337714,3,Both are high-profile 90s action/spy films. Wa...,10,648
1091,2888,Jumanji (1995),First Kid (1996),-0.146594,1,Both films are family-friendly '90s movies wit...,2,881
1453,8190,Jumanji (1995),Swiss Family Robinson (1960),-0.111762,2,Both are family adventure films featuring chil...,2,1017
2015,6606,Heat (1995),Breakdown (1997),-0.118072,2,Both are intense 90s thrillers appealing to fa...,6,1518
2237,1634,Heat (1995),Absolute Power (1997),-0.178313,1,Both movies are crime/political thrillers from...,6,1459
2514,8662,Heat (1995),"Funeral, The (1996)",-0.110843,2,Both are crime dramas with ensemble casts. Wat...,6,1114
2720,3052,Heat (1995),"Grifters, The (1990)",-0.145509,2,Both are intense American crime dramas from ad...,6,1179
3542,1197,Heat (1995),"Replacement Killers, The (1998)",0.785049,2,Both are crime/action films targeting similar ...,6,1769


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    

tensor(0.2700)

In [45]:
#711 -> 1595
#respid = 1595

tidx = 10
respidx = 648
print(cmat[tidx-1, respidx-1])
res = []
for itr in range(1):
    logging.info(f"##### : {itr}")
    watched_treatement,_ = generate_data(probs, cmat, intervention={tidx:1})
    watched_control,_ = generate_data(probs, cmat, intervention={tidx:0})
    Y1 = watched_treatement[:,respidx-1].mean()
    Y0 = watched_control[:,respidx-1].mean()
    logging.info(f"res: {(Y1, Y0, Y1-Y0)}")
    ate = Y1-Y0
    res.append(ate)



tensor(0.2700)
2025-08-11 22:16:55,317 - INFO - ##### : 0
2025-08-11 22:16:56,651 - INFO - [0] - watched:0.98M; added:32310.0
2025-08-11 22:16:59,964 - INFO - Done
2025-08-11 22:17:01,209 - INFO - [0] - watched:0.98M; added:31411.0
2025-08-11 22:17:05,399 - INFO - Done
2025-08-11 22:17:05,401 - INFO - res: (tensor(0.5699), tensor(0.4144), tensor(0.1555))
