In [1]:
from env import FrameEnv
from env import DataPath

In [None]:
dataPath = DataPath(
    base="topk-off-policy-correction/data/ml-1m/",
    ratings="ratings.csv",
    embeddings="ml20_pca128.pkl",
    use_cache=False,
)

# env = FrameEnv(
#     path=dataPath
# )

In [None]:
import pickle

path_embedding = "topk-off-policy-correction/data/ml-1m/ml20_pca128.pkl"

movie_embeddings_key_dict = pickle.load(open(path_embedding, "rb"))

In [4]:
import torch

keys = list(sorted(movie_embeddings_key_dict.keys()))
key_to_id = dict(zip(keys, range(len(keys))))
id_to_key = dict(zip(range(len(keys)), keys))

items_embeddings_id_dict = {}
for k in movie_embeddings_key_dict.keys():
    items_embeddings_id_dict[key_to_id[k]] = movie_embeddings_key_dict[k]

items_embeddings_id_dict = torch.stack(
    [items_embeddings_id_dict[i] for i in range(len(items_embeddings_id_dict))]
)

In [5]:
embeddings = items_embeddings_id_dict
key_to_id = key_to_id
id_to_key = id_to_key
num_items = embeddings.shape[0]

In [None]:
import pandas as pd

path_ratings = "topk-off-policy-correction/data/ml-1m/ratings.csv"
ratings = pd.read_csv(
    path_ratings,
    names=[
        "userId",
        "movieId",
        "rating",
        "timestamp",
    ]
)

In [7]:
ratings["rating"] = ratings["rating"].apply(lambda i: 2* (i - 2.5))
ratings["movieId"] = ratings["movieId"].apply(key_to_id.get).fillna(0.0)




In [8]:
users = ratings[["userId", "movieId"]].groupby(["userId"]).size()
frames_size = 10
users = users[users > frames_size].sort_values(ascending=False).index

In [9]:
ratings_grp = (
    ratings.sort_values(by="timestamp")
    .set_index("userId")
    .drop("timestamp", axis=1)
    .groupby("userId")
)

ratings_grp

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x7fe32185d040>

In [10]:
user_dict = {}

def app(x):
    userid = x.index[0]
    user_dict[userid] = {}
    user_dict[userid]['items'] = x["movieId"].values
    user_dict[userid]['ratings'] = x["rating"].values


ratings_grp.apply(app)

In [11]:
user_dict, users

({1: {'items': array([3099., 1661., 1242., 1003., 2255., 1755., 3320., 1181., 2718.,
           257.,  708., 1169.,  902.,  602., 2606., 1877., 1944., 3018.,
           921., 1016., 1878., 1009., 1934.,  148., 1075.,  897., 1259.,
          2711., 1218., 2676.,  653., 2832.,  527., 3027., 2705., 1010.,
          2236., 1172.,  588., 2313., 1495.,  523.,  732.,  589.,  582.,
             0., 2601.,  770., 2209., 2270., 1823., 1515.,   47.]),
   'ratings': array([3., 3., 5., 5., 1., 5., 3., 3., 5., 3., 1., 5., 3., 3., 3., 5., 5.,
          5., 3., 5., 3., 5., 3., 5., 3., 1., 5., 3., 3., 3., 1., 3., 3., 3.,
          3., 5., 1., 1., 3., 3., 3., 5., 1., 5., 3., 5., 1., 3., 3., 5., 3.,
          3., 5.])},
  2: {'items': array([1173., 1191., 1184., 2631., 1265., 2857., 1199., 1169.,  315.,
          2772., 2943., 1187., 1861., 1181., 3008.,  587., 3379.,  511.,
          1789., 1068., 2416.,  108., 2948., 1983., 3060., 1219., 3018.,
          1327., 1171., 1873.,  903., 1869., 1753., 1062.,

In [12]:
from sklearn.model_selection import train_test_split

test_size = 0.05

train_users, test_users = train_test_split(users, test_size=test_size)

len(train_users), len(test_users)

(5738, 302)

In [13]:
train_users = pd.Series(dict([(i, user_dict[i]["items"].shape[0]) for i in users])).sort_values(ascending=False).index

In [14]:
train_users

Index([4169, 1680, 4277, 1941, 1181,  889, 3618, 2063, 1150, 1015,
       ...
       4628, 2584, 3291, 2696, 4178, 5380, 1406, 5525, 2111, 3021],
      dtype='int64', length=6040)

In [15]:
from torch.utils.data import Dataset, DataLoader

class UserDataset(Dataset):
    def __init__(self, users, user_dict):
        self.users = users
        self.user_dict = user_dict
        
    def __len__(self):
        return len(self.users)
    
    def __getitem__(self, idx):
        idx = self.users[idx]
        group = self.user_dict[idx]
        items = group["items"][:]
        rates = group["ratings"][:]
        size = items.shape[0]
        return {"items": items, "rates": rates, "sizes": size, "users": idx}

train_user_dataset = UserDataset(
    train_users,
    user_dict,
)

In [16]:
train_dataloader = DataLoader(
    train_user_dataset,
    batch_size=1,
    shuffle=False,
)

for idx, batch in enumerate(train_dataloader):
    print(idx, batch.keys())
    print(batch["items"], len(batch["items"][0]))
    print(batch["rates"])
    print(batch["sizes"])
    print(batch["users"])
    break

0 dict_keys(['items', 'rates', 'sizes', 'users'])
tensor([[2567., 1240.,  419.,  ...,  490., 1728., 1498.]], dtype=torch.float64) 2314
tensor([[ 3.,  5.,  1.,  ...,  3., -1.,  1.]], dtype=torch.float64)
tensor([2314])
tensor([4169])


In [17]:
import numpy as np

def rolling_window(a, window):
    """numpy를 사용한 시퀀스 frame_size 만큼 window
    
          [2567. 1240.  419. ...  490. 1728. 1498.]
          -> [[2567. 1240.  419. ... 3710. 3741. 3703.]
              [1240.  419. 2543. ... 3741. 3703. 3720.]
              ...
              [2725. 3823. 3791. ... 1380.  490. 1728.]
              [3823. 3791. 3289. ...  490. 1728. 1498.]]

    Args:
        a (_type_): _description_
        window (_type_): _description_

    Returns:
        _type_: _description_
    """
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    res = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
    return res

def prepare_batch_wrapper(batch):
    items_t, ratings_t, sizes_t, users_t = [], [], [], []

    for i in range(len(batch)):
        items_t.append(batch[i]["items"])
        ratings_t.append(batch[i]["rates"])
        sizes_t.append(batch[i]["sizes"])
        users_t.append(batch[i]["users"])

    frame_size = 10
    
    items_t = np.concatenate([rolling_window(i, frame_size + 1) for i in items_t], 0)
    ratings_t = np.concatenate([rolling_window(i, frame_size + 1) for i in ratings_t], 0)
    
    items_t = torch.tensor(items_t)
    users_t = torch.tensor(users_t)
    ratings_t = torch.tensor(ratings_t).float()
    sizes_t = torch.tensor(sizes_t)
    
    # frame_size is 10
    items_emb = embeddings[items_t.long()] # shape: (b, 11, 128)
    b_size = ratings_t.size(0)
    
    
    items = items_emb[:, :-1, :].view(b_size, -1) # start from 0 (b, 10, 128) -> (b, 10*1280)
    next_items = items_emb[:, 1:, :].view(b_size, -1) # start from 1 (b, 10, 128) -> (b, 10*1280)
    ratings = ratings_t[:, :-1] # (b, 10)
    next_ratings = ratings_t[:, 1:]

    state = torch.cat([items, ratings], 1) # (b, 10*1280 + 10)
    next_state = torch.cat([next_items, next_ratings], 1)
    action = items_t[:, -1] # last item's emb
    reward = ratings[:, -1] # last item's rating
    
    done = torch.zeros(b_size) # ?
    done[torch.cumsum(sizes_t-frame_size, dim=0) - 1] = 1 # done은 어디에 쓰지?

    one_hot_action = torch.zeros(b_size, num_items)
    one_hot_action.scatter_(1, action.view(-1, 1).long(), 1)

    batch = {
        "state": state, # items + ratings
        "action": one_hot_action, # last item
        "reward": reward, # last item rating
        "next_state": next_state, # next items + next ratings
        "done": done, # each user history end
        "meta" : {"users": users_t, "sizes": sizes_t}, # userID, hist_length
    }
    
    return batch

train_dataloader = DataLoader(
    train_user_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=prepare_batch_wrapper
)

for idx, batch in enumerate(train_dataloader):
    print(batch["meta"]["users"])
    print(batch["meta"]["sizes"])
    print(batch["state"].shape)
    print(batch["action"].shape)
    break

tensor([4169])
tensor([2314])
torch.Size([2304, 1290])
torch.Size([2304, 27278])


In [18]:
from models import ChooseREINFORCE
from models import DiscreteActor, Critic


rl_params = {
    "reinforce":    ChooseREINFORCE(ChooseREINFORCE.basic_reinforce),
    "gamma":        0.99,
    "min_value":    -10,
    "max_value":    10,
    "policy_step":  10,
    "soft_tau":     0.001,
    "policy_lr":    1e-5,
    "value_lr":     1e-5,
    "actor_weight_init":    54e-2,
    "critic_weight_init":   6e-1,
}

num_items = embeddings.shape[0]

networks = {
    "value_net":            Critic(1290, num_items, 2048, rl_params["critic_weight_init"]),
    "target_value_net":     Critic(1290, num_items, 2048, rl_params["actor_weight_init"]),
    "policy_net":           DiscreteActor(2048, 1290, num_items),
    "target_policy_net":    DiscreteActor(2048, 1290, num_items).eval(),
}

optimizer = {
    "value_optimizer": torch.optim.Adam(networks["value_net"].parameters(),
                                        lr=rl_params["value_lr"]),
    "policy_optimizer": torch.optim.Adam(networks["policy_net"].parameters(),
                                         lr=rl_params["policy_lr"])
}

loss = {
    "tst": {"value": [], "policy": [], "step": []},
    "train": {"value": [], "policy": [], "step": []}
}
debug = {}


In [19]:
learn = True

for idx, batch in enumerate(train_dataloader):
    state = batch["state"]
    action = batch["action"]
    reward = batch["reward"].unsqueeze(1)
    next_state = batch["next_state"]
    done = batch["done"].unsqueeze(1)

    predicted_action, predicted_probs = networks["policy_net"].select_action(state) # 현재 시점에 가장 좋을 만한 액션 선택
    
    expected_reward = networks["value_net"](state, predicted_probs).detach() # 선택한 액션 기반으로 받을 보상 예측
    networks["policy_net"].rewards.append(expected_reward.mean())

    # value loss
    value_loss = None
    
    # -----------------
    # func value update
    with torch.no_grad():
        next_action = networks["target_policy_net"](next_state) # 다음 시점에 가장 좋을 만한 액션 선택
        target_value = networks["target_value_net"](next_state, next_action.detach()) # 다음 시점에 받을 보상 예측
        # temporal difference
        expected_value = reward + (1.0 - done) * rl_params["gamma"] * target_value # 감쇄된 기대 보상
        expected_value = torch.clamp(
            expected_value, rl_params["min_value"], rl_params["max_value"]
        )

    value = networks["value_net"](state, action) # 실제 label 기반 받을 보상 예측
    value_loss = torch.pow(value - expected_value.detach(), 2).mean() # 받을 보상 예측 - 기대 보상 예측 MSE
    
    if learn:
        optimizer["value_optimizer"].zero_grad()
        value_loss.backward(retain_graph=True) # Value 모델 예측 
        optimizer["value_optimizer"].step()
    # -----------------
    # -----------------
    # func policy update
    if idx % 10 == 0 & idx > 0:
        policy_loss = rl_params["reinforce"](
            policy=networks["policy_net"],
            optimizer=optimizer["policy_optimizer"],
            learn=learn,
        )
    # -----------------
        del networks["policy_net"].rewards[:]
        del networks["policy_net"].saved_log_probs[:]

    ## soft update value_network
        for target_param, param in zip(networks["target_value_net"].parameters(), networks["value_net"].parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - rl_params["soft_tau"]) + param.data * rl_params["soft_tau"]
            )
    ## soft update policy_network
        for target_param, param in zip(networks["target_policy_net"].parameters(), networks["policy_net"].parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - rl_params["soft_tau"]) + param.data * rl_params["soft_tau"]
            )


tensor([[  6.5216,  -4.8781,  -2.6142,  ...,   1.0000,   5.0000,  -3.0000],
        [  1.9724,  -3.9022,   1.3081,  ...,   5.0000,  -3.0000,   3.0000],
        [  1.5967,  -7.9507,  -5.2240,  ...,  -3.0000,   3.0000,   5.0000],
        ...,
        [-11.6282,  -2.8539,   4.5772,  ...,   1.0000,  -1.0000,   1.0000],
        [ -9.1619,  11.3900,   0.6219,  ...,  -1.0000,   1.0000,   3.0000],
        [  7.5283,   1.2876,  -1.2489,  ...,   1.0000,   3.0000,  -1.0000]])
ACTION SCORES :  torch.Size([2304, 27278])
ACTION SCORES :  torch.Size([2304, 27278])
tensor([[  8.2706,   0.2508,  -4.5213,  ...,   1.0000,  -3.0000,   5.0000],
        [  5.1512,  -7.0847,   4.8552,  ...,  -3.0000,   5.0000,  -1.0000],
        [-13.5168,  -2.7655,   2.7820,  ...,   5.0000,  -1.0000,   1.0000],
        ...,
        [  6.0839,  -2.1943,  -1.7734,  ...,  -3.0000,   5.0000,   3.0000],
        [  6.9507,   1.4995,  -4.9188,  ...,   5.0000,   3.0000,  -3.0000],
        [  1.6219,  -7.4186,  -0.4721,  ...,   3.00

: 

: 