In [1]:
import torch
from torch import nn
from torch.autograd import Variable
import torch.utils.data as Data
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [None]:
class Dynamic_LSTM(nn.Module):
    """LSTM whose recurrent state is blended with a past state picked by a sampled policy.

    At every timestep t >= 1 a small policy network (``fc1`` -> ``fc2`` -> softmax)
    looks at ``[x_t, h_{t-1}]`` and samples an index into a sliding window holding the
    ``n_actions`` most recent (c, h) pairs (zero-padded at the start of the sequence).
    The selected pair is mixed with the current state as
    ``lamda * selected + (1 - lamda) * current`` before the usual LSTM gate update.
    The sampled actions and their log-probabilities are recorded in
    ``self.agent_action`` / ``self.agent_prob`` so the caller can build a
    REINFORCE-style loss.

    Args:
        n_actions: size K of the window of past states the policy chooses from.
        n_units: width of the policy's hidden layer.
        n_input: feature dimension of each input timestep.
        n_hidden: LSTM hidden / cell size.
        n_output: number of output classes.
        lamda: weight of the policy-selected state in the mix (1 = use it exclusively).
        dropout: dropout probability applied to the candidate cell gate (0 disables it).
    """

    def __init__(self, n_actions, n_units, n_input, n_hidden, n_output, lamda, dropout=0.3):
        super(Dynamic_LSTM, self).__init__()

        # hyperparameters
        self.n_actions = n_actions  # window size K of past states
        self.n_units = n_units  # hidden units of the policy MLP
        self.n_input = n_input  # input feature size
        self.n_hidden = n_hidden  # LSTM hidden size
        self.n_output = n_output  # output dim
        self.lamda = lamda
        # Keep the numeric rate under its own name: ``self.dropout`` is the
        # nn.Dropout module below, and comparing a Module with 0 is never true.
        self.dropout_rate = dropout

        # Filled by forward(): one entry per timestep t >= 1.
        self.agent_action = []  # sampled actions, each of shape (batch, 1)
        self.agent_prob = []  # log-probabilities of those actions, each of shape (batch,)

        # layers
        self.fc1 = nn.Linear(self.n_hidden + self.n_input, self.n_units)
        self.fc2 = nn.Linear(self.n_units, self.n_actions)
        self.x2h = nn.Linear(self.n_input, 4 * self.n_hidden)
        self.h2h = nn.Linear(self.n_hidden, 4 * self.n_hidden)
        self.output = nn.Linear(self.n_hidden, self.n_output)
        self.sigmoid = nn.Sigmoid()
        # Explicit dim: the implicit-dim form is deprecated; for the 2-D tensors
        # used here it resolved to dim=1 as well.
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(p=dropout)
        self.tanh = nn.Tanh()

    def choose_action(self, observation, cur_time):
        """Sample one window index per batch row from the policy.

        Args:
            observation: tensor of shape (batch, n_input + n_hidden). It is detached,
                so the policy layers are trained only through the recorded log-probs.
            cur_time: current timestep. Samples taken at ``cur_time == 0`` are
                returned but not recorded.

        Returns:
            Long tensor of sampled indices, shape (batch, 1).
        """
        observation = observation.detach()
        # NOTE(review): there is no activation between fc1 and fc2, so the policy is
        # effectively a single linear map — confirm this is intended.
        logits = self.fc2(self.fc1(observation))
        probs = self.softmax(logits)
        m = torch.distributions.Categorical(probs)
        actions = m.sample()
        if cur_time != 0:
            self.agent_action.append(actions.unsqueeze(-1))
            self.agent_prob.append(m.log_prob(actions))
        return actions.unsqueeze(-1)

    def forward(self, input):
        """Run the policy-augmented LSTM over a sequence.

        Args:
            input: tensor of shape (batch, time_step, n_input).

        Returns:
            Tensor of shape (batch, n_output): softmax probabilities computed from
            the final hidden state.
        """
        batch_size, time_step, feature_dim = input.shape
        assert feature_dim == self.n_input, (
            'expected feature_dim %d, got %d' % (self.n_input, feature_dim))

        self.agent_action = []
        self.agent_prob = []
        # Allocate state like ``input`` so the module also works on GPU / other dtypes.
        cur_h = input.new_zeros(batch_size, self.n_hidden)
        cur_c = input.new_zeros(batch_size, self.n_hidden)
        # Sliding windows of the last n_actions states; index -1 is the most recent.
        observed_c = input.new_zeros(self.n_actions, batch_size, self.n_hidden)
        observed_h = input.new_zeros(self.n_actions, batch_size, self.n_hidden)
        batch_index = torch.arange(batch_size, device=input.device)

        for cur_time in range(time_step):
            x_t = input[:, cur_time, :]
            if cur_time == 0:
                # No history yet: the "selected" state is just the (zero) initial state.
                action_c = cur_c
                action_h = cur_h
            else:
                observed_c = torch.cat((observed_c[1:], cur_c.unsqueeze(0)), 0)
                observed_h = torch.cat((observed_h[1:], cur_h.unsqueeze(0)), 0)
                observation = torch.cat((x_t, cur_h), 1)
                actions = self.choose_action(observation, cur_time).squeeze(-1)
                # Pick window slot actions[b] for every batch row b in one indexing op.
                action_c = observed_c[actions, batch_index]
                action_h = observed_h[actions, batch_index]

            weighted_c = self.lamda * action_c + (1 - self.lamda) * cur_c
            weighted_h = self.lamda * action_h + (1 - self.lamda) * cur_h

            # gates is already (batch, 4 * n_hidden); squeezing it would drop the
            # batch axis when batch_size == 1 and break chunk(4, 1).
            gates = self.x2h(x_t) + self.h2h(weighted_h)
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

            ingate = self.sigmoid(ingate)
            forgetgate = self.sigmoid(forgetgate)
            cellgate = self.tanh(cellgate)
            if self.dropout_rate != 0:
                cellgate = self.dropout(cellgate)
            outgate = self.sigmoid(outgate)

            cur_c = weighted_c * forgetgate + ingate * cellgate
            cur_h = outgate * self.tanh(cur_c)

        return self.softmax(self.output(cur_h))

In [None]:
DATA_SEED = 0  # fixed so Restart & Run All regenerates the same dataset
N_TRAIN, N_TEST = 100000, 10000
N_DIGITS = 9  # random digits per example


def make_lookup_rows(n_rows, rng, n_digits=N_DIGITS):
    """Generate ``n_rows`` examples of the digit-lookup task.

    Each row is ``[d_0, ..., d_{n_digits-1}, k, d_k]``: ``n_digits`` random digits
    in [0, 10), a query index ``k`` in [0, n_digits), and the label ``d_k``
    (the digit at position ``k``).

    Args:
        n_rows: number of examples to generate.
        rng: ``numpy.random.Generator`` used for sampling.
        n_digits: number of random digits per row.

    Returns:
        List of 1-D int64 arrays of length ``n_digits + 2``.
    """
    digits = rng.integers(0, 10, size=(n_rows, n_digits))
    query = rng.integers(0, n_digits, size=(n_rows, 1))
    label = np.take_along_axis(digits, query, axis=1)
    return list(np.concatenate([digits, query, label], axis=1))


data_rng = np.random.default_rng(DATA_SEED)
# In each row the first 10 numbers (9 digits + query index) are inputs; the last one is the label.
all_train = make_lookup_rows(N_TRAIN, data_rng)
all_test = make_lookup_rows(N_TEST, data_rng)

In [None]:
onehot = torch.eye(10)


def to_onehot_dataset(rows):
    """One-hot encode integer rows and split off the last position as the label.

    Args:
        rows: list (or 2-D array) of equal-length integer rows with values in [0, 10).

    Returns:
        (data, label, dataset): ``data`` is a float32 array of shape
        (n_rows, row_len - 1, 10), ``label`` is (n_rows, 10), and ``dataset`` is a
        TensorDataset over float32 copies of both.
    """
    # Stack into one int array first: indexing with a list of ndarrays is very slow.
    encoded = onehot[torch.as_tensor(np.stack(rows), dtype=torch.long)].numpy()
    data = encoded[:, :-1, :]
    label = encoded[:, -1, :]
    dataset = Data.TensorDataset(torch.tensor(data, dtype=torch.float32),
                                 torch.tensor(label, dtype=torch.float32))
    return data, label, dataset


# New names instead of rebinding all_train/all_test keep this cell idempotent
# (re-running it would otherwise try to one-hot encode already-encoded floats).
train_data, train_label, train_dataset = to_onehot_dataset(all_train)
test_data, test_label, test_dataset = to_onehot_dataset(all_test)

In [None]:
torch.manual_seed(0)  # reproducible weight init, shuffling and policy sampling

RL_WEIGHT = 0.3  # weight of the REINFORCE term in the total loss
PATIENCE = 5  # stop after this many consecutive epochs without test-accuracy gain
N_EPOCHS = 200
LOG_EVERY = 200  # print training diagnostics every LOG_EVERY batches

model = Dynamic_LSTM(n_actions=10, n_units=10, n_input=10,
                     n_hidden=30, n_output=10, lamda=1, dropout=0)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100)


def compute_losses(net, output, target):
    """Return (loss_task, loss_RL, correct_pred, acts) for one batch.

    ``loss_task`` is the mean of ``-log(output) * target`` (one-hot cross-entropy,
    averaged over batch and classes). ``loss_RL`` is REINFORCE: every action the
    policy sampled during ``net``'s last forward pass gets reward +1 if the final
    prediction was correct and -1 otherwise, weighted by its log-probability.
    ``acts`` has shape (batch, time_step - 1).
    """
    loss_task = torch.mean(-torch.log(output + 1e-7) * target)
    correct_pred = torch.argmax(output, 1).eq(torch.argmax(target, 1))
    act_log_prob = torch.stack(net.agent_prob).permute(1, 0)
    # squeeze(-1) only: a bare squeeze() would also drop the batch axis when batch == 1.
    acts = torch.stack(net.agent_action).squeeze(-1).permute(1, 0)
    rewards = ((correct_pred.float() - 0.5) * 2).unsqueeze(-1)
    loss_RL = torch.mean(-act_log_prob * rewards)
    return loss_task, loss_RL, correct_pred, acts


best_acc = 0
best_state = None
train_loss = []
count = 0  # consecutive epochs without test-accuracy improvement

for epoch in range(N_EPOCHS):
    print('****************  RL is beginning  *******************')
    model.train()
    cur_loss = []
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss_task, loss_RL, correct_pred, acts = compute_losses(model, output, target)
        loss_total = loss_task + loss_RL * RL_WEIGHT
        loss_total.backward()
        optimizer.step()
        cur_loss.append(loss_total.item())

        if batch_idx % LOG_EVERY == 0:
            accuracy = correct_pred.float().mean().item()
            print('the %d epoch the %d time accuracy is %f, loss is %f' %
                  (epoch, batch_idx, accuracy, loss_total.item()))
            print('loss_RL:', loss_RL.item())
            print('loss_Task:', loss_task.item())
            print(acts[:5, -1])
            print(torch.argmax(target[:5], dim=1))
            print()
    train_loss.append(float(np.mean(cur_loss)))

    print('###############  TESTING  ####################')
    model.eval()
    test_loss = []
    test_acc = []
    # no_grad: evaluation must not build autograd graphs (storing graph-carrying
    # loss tensors made memory grow every batch).
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss_task, loss_RL, correct_pred, _ = compute_losses(model, output, target)
            # Same objective (sign and weight) as training so the losses are comparable.
            test_loss.append((loss_task + loss_RL * RL_WEIGHT).item())
            test_acc.append(correct_pred.float().mean().item())
    cur_acc = sum(test_acc) / len(test_acc)
    print('the TEST accuracy is %f, loss is %f' %
          (cur_acc, sum(test_loss) / len(test_loss)))

    # NOTE(review): model selection uses the test set; a separate validation split
    # would give an unbiased final test accuracy.
    if cur_acc > best_acc:
        best_acc = cur_acc
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        print('===============================================>>>> SAVE MODEL')
        count = 0
    else:
        count += 1
    if count >= PATIENCE:
        print('--------------------------------------------->>>>  EARLY STOP!!!')
        break

# Restore the weights that achieved the best test accuracy.
if best_state is not None:
    model.load_state_dict(best_state)