In [None]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from IPython.display import clear_output
import numpy as np
import random
import time as timer
import pdb
import gensim
from itertools import count
from torch.distributions import Categorical

# import eventime as Tlink


# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [2]:
dct_inputs = torch.randint(0, 100, (50, 1, 3), dtype=torch.long)
time_inputs = torch.randint(0, 100, (50, 1, 2, 3), dtype=torch.long)
targets = torch.randint(0, 2, (50, 3), dtype=torch.long)
print(dct_inputs.size(), time_inputs.size(), targets.size())
print(dct_inputs.is_cuda)

torch.Size([50, 1, 3]) torch.Size([50, 1, 2, 3]) torch.Size([50, 3])
False


In [3]:
# NOTE(review): `Tlink` is expected to come from `import eventime as Tlink`,
# which is commented out in the imports cell -- uncomment it before running.
dataset = Tlink.MultipleDatasets(dct_inputs, time_inputs, targets)

# `select_action` calls `.item()` on the sampled action, which presumably
# only works when each batch holds a single example -- keep this at 1.
BATCH_SIZE = 1
loader = Data.DataLoader(
    dataset = dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = 1,
    # Pinned host memory only speeds up host->GPU copies. On a CPU-only run
    # (the logged device was `cpu`) it brings no benefit, and recent PyTorch
    # versions warn about it.
    pin_memory = torch.cuda.is_available(),
)

In [5]:
# Discount factor used when turning per-step rewards into returns.
GAMMA = 0.99

# Layer sizes handed to Tlink.DqnInferrer (presumably embedding / hidden
# widths for the dct and time encoders -- confirm against the model).
EMBEDDING_DIM, DCT_HIDDEN_DIM, TIME_HIDDEN_DIM = 64, 60, 50

# Same upper bound the synthetic token ids are drawn from.
VOCAB_SIZE = 100
# One action per row of the update-strategy table in `step`.
ACTION_SIZE = 8

# Number of passes over the DataLoader in the training cell.
EPOCH_NUM = 50

In [6]:
def select_action(state_i, dpath_input):
    """Sample one action from the current policy for the given state.

    The policy's output for ``(state_i, dpath_input)`` is wrapped in a
    Categorical distribution (passed as its ``probs``). The log-probability
    of the sampled action is appended to ``policy.saved_log_probs`` for the
    later REINFORCE update. Relies on the notebook global ``policy``.

    Returns:
        The sampled action as a Python number. ``.item()`` is used, so the
        sample must contain exactly one element.
    """
    action_dist = Categorical(policy(state_i, dpath_input))
    chosen = action_dist.sample()
    policy.saved_log_probs.append(action_dist.log_prob(chosen))
    return chosen.item()

In [7]:
def finish_episode():
    """Apply one REINFORCE policy-gradient update for the episode just played.

    Steps:
    - Read the per-step rewards in ``policy.rewards`` and the matching
      log-probabilities in ``policy.saved_log_probs`` (appended by
      ``select_action``).
    - Turn the rewards into discounted returns with ``GAMMA`` and
      standardise them.
    - Minimise ``sum(-log_prob * return)`` with one ``optimizer`` step.
    - Empty both buffers.

    Relies on the notebook globals ``policy``, ``optimizer``, ``GAMMA``,
    ``eps`` and ``device``.

    Returns:
        The summed policy loss as a detached 0-d tensor. ``backward`` has
        already consumed the graph. Detaching keeps callers that accumulate
        this value (e.g. ``total_loss += ...``) from holding every episode's
        graph alive.
    """
    if not policy.saved_log_probs or not policy.rewards:
        # Nothing to learn from: just reset the episode buffers.
        del policy.rewards[:]
        del policy.saved_log_probs[:]
        return torch.zeros((), device=device)

    # Discounted returns, accumulated back to front: G_t = r_t + GAMMA * G_{t+1}.
    returns = []
    running = 0.0
    for r in reversed(policy.rewards):
        running = float(r) + GAMMA * running
        returns.append(running)
    returns.reverse()
    returns = torch.tensor(returns, device=device)

    # Standardise returns to reduce gradient variance. The unbiased std() of
    # a single value is NaN, which would poison the loss and the parameters,
    # so a one-step episode keeps its raw return.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    # stack() (unlike cat()) also accepts 0-d log-probs; the summed loss is
    # the same either way.
    policy_loss = torch.stack(
        [-log_prob * ret for log_prob, ret in zip(policy.saved_log_probs, returns)]
    ).sum()

    optimizer.zero_grad()
    # NOTE(review): retain_graph=True is kept from the original code. It is
    # only needed if part of this graph is backpropagated through again later,
    # e.g. model state carried across episodes without detach(). The logged
    # epoch time grows from ~3 s to ~370 s, which is consistent with such a
    # graph growing over time -- worth checking in Tlink.DqnInferrer.
    policy_loss.backward(retain_graph=True)
    optimizer.step()

    del policy.rewards[:]
    del policy.saved_log_probs[:]
    return policy_loss.detach()

In [8]:
def step(state_i, action, anchor, target):
    """Apply one update strategy to the anchor and score it against the target.

    Args:
        state_i: Index of the current state. State 0 writes 0s into the
            selected anchor positions; every later state writes 1s.
        action: Row index into the 8-row ``update_strategies`` table below.
        anchor: Current prediction, a long tensor with one entry per column
            of ``update_strategies`` (3).
        target: Gold labels, compared elementwise to the updated anchor with
            broadcasting (e.g. shape ``(1, 3)`` from a batch-size-1 loader).

    Returns:
        ``(state_i + 1, reward, new_anchor)``. ``reward`` is a 0-d float
        tensor: the fraction of compared positions where the new anchor
        equals the target.
    """
    dev = anchor.device
    # Each row is a 0/1 mask over the anchor positions; 1 means "overwrite
    # this position with the current state's value", 0 means "keep it".
    update_strategies = torch.tensor([[0, 0, 0],
                                      [1, 0, 0],
                                      [0, 1, 0],
                                      [0, 0, 1],
                                      [1, 1, 0],
                                      [0, 1, 1],
                                      [1, 0, 1],
                                      [1, 1, 1]], dtype=torch.long, device=dev)
    mask = update_strategies[action]
    timex = torch.full_like(mask, 0 if state_i == 0 else 1)
    anchor = mask * timex + (1 - mask) * anchor
    # Average over the broadcast comparison so the reward stays in [0, 1]
    # even when target has several rows. Dividing by anchor.numel() instead
    # would exceed 1 in that case.
    reward = torch.eq(anchor, target).float().mean()
    return state_i + 1, reward, anchor

In [10]:
# NOTE(review): `Tlink` comes from `import eventime as Tlink`, which is
# commented out in the imports cell -- uncomment it before running.
policy = Tlink.DqnInferrer(EMBEDDING_DIM, DCT_HIDDEN_DIM, TIME_HIDDEN_DIM, VOCAB_SIZE, ACTION_SIZE, BATCH_SIZE).to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
# Added to the return std in finish_episode to avoid division by zero.
eps = np.finfo(np.float32).eps.item()

for epoch in range(EPOCH_NUM):
    start_time = timer.time()
    # Plain Python floats, so no autograd graph is kept alive across episodes.
    total_loss = 0.0
    total_reward = 0.0
    num_episodes = 0

    # The batch is bound to `time_batch`, not `time_inputs`, so this loop does
    # not overwrite the global `time_inputs` tensor built by the data cell.
    for dct_input, time_batch, target in loader:
        dct_input = dct_input.to(device=device)
        time_batch = time_batch.to(device=device)
        target = target.to(device=device)

        ## dct state
        state_i = 0
        dct_action = select_action(state_i, dct_input)
        anchor = torch.tensor([-1, -1, -1], dtype=torch.long, device=device)
        state_i, reward, anchor = step(state_i, dct_action, anchor, target)
        policy.rewards.append(reward)

        ## time states: one decision per slice along dim 1 of the batch
        for time_idx in range(time_batch.size(1)):
            time_input = time_batch[:, time_idx, :, :]
            time_action = select_action(state_i, time_input)
            state_i, reward, anchor = step(state_i, time_action, anchor, target)
            policy.rewards.append(reward)

        # Only the reward after the episode's final step is tracked.
        total_reward += float(reward)
        total_loss += float(finish_episode())
        num_episodes += 1

    # Average over the episodes actually seen instead of a hard-coded 50.
    n = max(num_episodes, 1)
    print('Iter:', epoch, ', loss: %.4f' % (total_loss / n), ', reward: %.4f' % (total_reward / n), ", %.2f seconds" % (timer.time() - start_time))

Iter: 0 , loss: 0.0113 , reward: 0.3467 , 3.32 seconds


Iter: 1 , loss: -0.1033 , reward: 0.3533 , 10.7 seconds


Iter: 2 , loss: -0.1803 , reward: 0.4133 , 15.6 seconds


Iter: 3 , loss: -0.1843 , reward: 0.4733 , 21.1 seconds


Iter: 4 , loss: -0.0857 , reward: 0.3667 , 26.8 seconds


Iter: 5 , loss: -0.0588 , reward: 0.4133 , 33.1 seconds


Iter: 6 , loss: -0.1993 , reward: 0.3800 , 39.8 seconds


Iter: 7 , loss: -0.1919 , reward: 0.4267 , 47.1 seconds


Iter: 8 , loss: -0.1926 , reward: 0.4667 , 52.0 seconds


Iter: 9 , loss: -0.1702 , reward: 0.4333 , 58.9 seconds


Iter: 10 , loss: -0.2217 , reward: 0.3867 , 73.7 seconds


Iter: 11 , loss: -0.0888 , reward: 0.3733 , 79.6 seconds


Iter: 12 , loss: -0.2986 , reward: 0.4400 , 91.7 seconds


Iter: 13 , loss: -0.2864 , reward: 0.5000 , 103. seconds


Iter: 14 , loss: -0.3266 , reward: 0.4067 , 124. seconds


Iter: 15 , loss: -0.4152 , reward: 0.4467 , 166. seconds


Iter: 16 , loss: -0.2532 , reward: 0.3333 , 131. seconds


Iter: 17 , loss: -0.2882 , reward: 0.4533 , 132. seconds


Iter: 18 , loss: -0.4488 , reward: 0.4800 , 120. seconds


Iter: 19 , loss: -0.4537 , reward: 0.4467 , 132. seconds


Iter: 20 , loss: -0.1848 , reward: 0.4067 , 125. seconds


Iter: 21 , loss: -0.2145 , reward: 0.3933 , 133. seconds


Iter: 22 , loss: -0.3986 , reward: 0.4267 , 147. seconds


Iter: 23 , loss: -0.1957 , reward: 0.5000 , 158. seconds


Iter: 24 , loss: -0.3219 , reward: 0.4533 , 162. seconds


Iter: 25 , loss: -0.2891 , reward: 0.4400 , 169. seconds


Iter: 26 , loss: -0.1917 , reward: 0.4133 , 206. seconds


Iter: 27 , loss: -0.2425 , reward: 0.4133 , 208. seconds


Iter: 28 , loss: -0.2045 , reward: 0.4000 , 208. seconds


Iter: 29 , loss: -0.2265 , reward: 0.4267 , 216. seconds


Iter: 30 , loss: -0.3524 , reward: 0.3867 , 221. seconds


Iter: 31 , loss: -0.2425 , reward: 0.3800 , 235. seconds


Iter: 32 , loss: -0.4200 , reward: 0.4867 , 231. seconds


Iter: 33 , loss: -0.2184 , reward: 0.4533 , 240. seconds


Iter: 34 , loss: -0.2513 , reward: 0.4133 , 249. seconds


Iter: 35 , loss: -0.4097 , reward: 0.4400 , 257. seconds


Iter: 36 , loss: -0.2778 , reward: 0.4333 , 268. seconds


Iter: 37 , loss: -0.5463 , reward: 0.4200 , 257. seconds


Iter: 38 , loss: -0.3539 , reward: 0.4733 , 282. seconds


Iter: 39 , loss: -0.2312 , reward: 0.4067 , 287. seconds


Iter: 40 , loss: -0.2554 , reward: 0.4133 , 299. seconds


Iter: 41 , loss: -0.3933 , reward: 0.4200 , 294. seconds


Iter: 42 , loss: -0.1829 , reward: 0.4467 , 300. seconds


Iter: 43 , loss: -0.1371 , reward: 0.4133 , 308. seconds


Iter: 44 , loss: -0.2186 , reward: 0.4533 , 317. seconds


Iter: 45 , loss: -0.5117 , reward: 0.4667 , 327. seconds


Iter: 46 , loss: -0.3183 , reward: 0.5200 , 353. seconds


Iter: 47 , loss: -0.3742 , reward: 0.4667 , 357. seconds


Iter: 48 , loss: -0.5219 , reward: 0.4467 , 361. seconds


Iter: 49 , loss: -0.3861 , reward: 0.4800 , 370. seconds
