In [4]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from tensorboardX import SummaryWriter
import tqdm

# Pick the legacy typed-tensor constructors from the torch.cuda namespace when
# a GPU is visible, otherwise from the plain (CPU) torch namespace.
use_cuda = torch.cuda.is_available()
FloatTensor, LongTensor, IntTensor, ByteTensor = (
    getattr(torch.cuda if use_cuda else torch, type_name)
    for type_name in ("FloatTensor", "LongTensor", "IntTensor", "ByteTensor")
)
# Short alias for the float constructor.
Tensor = FloatTensor

def weights_initialize(module):
    """Re-initialise a plain ``nn.Linear`` layer for use in a ReLU network.

    Meant to be passed to ``nn.Module.apply``. Only modules whose exact type is
    ``nn.Linear`` are touched; everything else (BatchNorm, ReLU, ...) is left
    as-is.

    The weight gets Xavier-uniform init scaled by the ReLU gain. The bias, when
    the layer has one, is set to the constant 0.01.
    """
    if type(module) is nn.Linear:
        nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))
        # Layers built with bias=False have ``module.bias is None``; the
        # previous unconditional ``module.bias.data.fill_`` crashed on them.
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.01)
        
class _TransModel(nn.Module):
    """Fully connected network: input_len -> 512 -> 128 -> output_len.

    The two hidden stages are Linear -> BatchNorm1d -> ReLU; the head is a
    single Linear with no activation. Every stage is passed through
    ``weights_initialize`` right after it is built.
    """

    def __init__(self, input_len, output_len):
        super().__init__()

        def hidden_stage(n_in, n_out):
            # One Linear/BatchNorm/ReLU stage, re-initialised immediately.
            stage = nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
            )
            stage.apply(weights_initialize)
            return stage

        self.fc1 = hidden_stage(input_len, 512)
        self.fc2 = hidden_stage(512, 128)

        # The head stays wrapped in a Sequential so parameter names remain
        # "output_layer.0.weight" / "output_layer.0.bias".
        self.output_layer = nn.Sequential(nn.Linear(128, output_len))
        self.output_layer.apply(weights_initialize)

    def forward(self, input):
        hidden = self.fc2(self.fc1(input))
        return self.output_layer(hidden)

    
class TransModel():
    """Training wrapper around ``_TransModel``.

    Owns the network (wrapped in ``nn.DataParallel``), an Adam optimiser, an
    MSE loss and a TensorBoard ``SummaryWriter``. The writer records the
    training loss after every ``fit`` call.

    Args:
        input_len: width of each input vector.
        ouput_len: width of each output vector. The spelling is kept so keyword
            callers keep working.
        learning_rate: Adam learning rate.
        log_dir: directory for the training-loss TensorBoard events. Defaults to
            the previously hard-coded ``'trans_summary/'``.
    """

    def __init__(self, input_len, ouput_len, learning_rate = 0.0001,
                 log_dir='trans_summary/'):
        self.model = _TransModel(input_len, ouput_len)

        if use_cuda:
            print("Using GPU")
            self.model = self.model.cuda()
        else:
            print("Using CPU")
        # Wrapped unconditionally (also on CPU); code elsewhere reaches the
        # wrapper through ``self.model`` (e.g. ``model.model.eval()``).
        self.model = nn.DataParallel(self.model)
        self.optimizer = Adam(self.model.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss(reduction='mean')

        self.summary = SummaryWriter(log_dir=log_dir)
        # Number of optimiser steps taken; used as the TensorBoard global step.
        self.steps = 0

    def predict(self, input, steps, learning):
        """Run a forward pass and squeeze dimension 1 (only drops it if size 1).

        ``steps`` and ``learning`` are accepted for call compatibility but are
        not used.
        """
        return self.model(input).squeeze(1)

    def predict_batch(self, input):
        """Run a forward pass over a batch and return the raw network output."""
        return self.model(input)

    def fit(self, state, target_state):
        """Take one Adam step on the MSE between ``state`` and ``target_state``.

        Despite its name, ``state`` is the model's prediction: the training loop
        passes the output of ``predict_batch``. It must still be attached to the
        autograd graph so ``backward`` reaches the parameters.
        """
        loss = self.loss_fn(state, target_state)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.steps += 1
        self.summary.add_scalar(tag="loss/train_Loss",
                                scalar_value=float(loss),
                                global_step=self.steps)
        

In [5]:
# Fixed seed so the shuffle, and hence the train/test split, is reproducible.
# If cells run in order, it also fixes model init and per-epoch batch order.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# On newer torch versions that default to weights_only=True, loading numpy
# arrays may require weights_only=False. TODO confirm for the installed version.
data = torch.load('random_v_random.pt')
np.set_printoptions(suppress=True)

l = len(data)

# Keep only entries 4 and 9 of each sample's target vector (in place).
for sample in data:
    sample[1] = [sample[1][4], sample[1][9]]

print(data[0][0], data[0][1])

np.random.shuffle(data)

# 50/50 split. Each row pairs a feature array with a 2-element list, which is
# ragged, so dtype=object is required: NumPy >= 1.24 raises ValueError
# otherwise.
split = l // 2
train_data = np.array(data[:split], dtype=object)
test_data = np.array(data[split:], dtype=object)
print(train_data.shape, test_data.shape)

batch_size = 64
summary_test = SummaryWriter(log_dir = 'test_summary/')

[   0.    1.    0.    1. 2000.    3.    0.    0.    0. 2000.    0.    0.
    0.    0.    0.    0.    0.] [2000.0, 2000.0]
(6744, 2) (6745, 2)


In [6]:
# Size the network from the first sample: its input vector and its target vector.
trans_model = TransModel(
    len(data[0][0]),  # input width
    len(data[0][1]),  # output width
)

Using CPU


In [7]:
def evaluation(model, data, epoch):
    """Evaluate ``model`` on a held-out split and log the results.

    Args:
        model: wrapper exposing ``model.model`` (the ``nn.Module``, switched to
            eval mode for the forward pass) and ``predict_batch``.
        data: 2-D object array, one row per sample. Column 0 holds the input
            vectors and column 1 the target vectors; each column must be
            stackable with ``np.stack``.
        epoch: global step for TensorBoard. Also throttles the sample dump in
            ``test_loss.txt`` (written when ``epoch % 1000 == 0``).

    Returns:
        The mean-squared-error test loss as a Python float.

    Side effects: logs loss and accuracy to the module-level ``summary_test``
    writer and appends one line to ``test_loss.txt``.
    """
    state_action = torch.from_numpy(np.stack(data[:, 0])).type(FloatTensor)
    next_state_reward = torch.from_numpy(np.stack(data[:, 1])).type(FloatTensor)

    criterion = nn.MSELoss(reduction='mean')
    model.model.eval()
    try:
        # Evaluation never backpropagates, so skip building an autograd graph.
        with torch.no_grad():
            outputs = model.predict_batch(state_action)
    finally:
        # Restore training mode even if the forward pass raises.
        model.model.train()

    loss_value = criterion(outputs, next_state_reward).item()

    # Fraction of target entries whose rounded prediction equals the rounded
    # target, over every output column. Dividing by the number of compared
    # entries fixes the old hard-coded 8 * N divisor, which capped the metric
    # at 0.25 for 2 target columns.
    matches = torch.round(outputs) == torch.round(next_state_reward)
    accuracy = matches.float().mean().item()

    summary_test.add_scalar(tag="loss/test_Loss",
                            scalar_value=loss_value,
                            global_step=epoch)
    summary_test.add_scalar(tag="acc/accuracy",
                            scalar_value=float(accuracy),
                            global_step=epoch)
    with open("test_loss.txt", "a+") as f:
        f.write("loss:" + str(loss_value) + ", ")
        f.write("acc:" + str(accuracy) + "\n")
        if epoch % 1000 == 0:
            f.write("output:" + str(outputs[0:2]) + "\n")
            f.write("ground true:" + str(next_state_reward[0:2]) + "\n")
    return loss_value

In [9]:
# Materialise the training split as float tensors once, outside the epoch loop.
state_action = torch.from_numpy(np.stack(train_data[:, 0])).type(FloatTensor)
next_state_reward = torch.from_numpy(np.stack(train_data[:, 1])).type(FloatTensor)
print(state_action.size(), next_state_reward.size())

# NOTE(review): the saved progress bar below shows a 10-epoch run, so the
# recorded output is stale relative to this value.
NUM_EPOCHS = 10000
num_samples = state_action.shape[0]

for epoch in tqdm.tqdm(range(NUM_EPOCHS)):
    # New random sample order every epoch.
    perm = np.random.permutation(num_samples)
    train_x = state_action[perm]
    train_y = next_state_reward[perm]
    # range(0, n, batch_size) never yields an empty slice. The old
    # `n // batch_size + 1` loop produced an empty last batch whenever n was a
    # multiple of batch_size.
    for start in range(0, num_samples, batch_size):
        end = min(start + batch_size, num_samples)
        if end - start < 2:
            # BatchNorm1d cannot compute batch statistics from a single sample
            # in training mode, so skip a trailing batch of size 1.
            continue
        inputs, ground_true = train_x[start:end], train_y[start:end]
        outputs = trans_model.predict_batch(inputs)
        trans_model.fit(outputs, ground_true)
    evaluation(trans_model, test_data, epoch)


  0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([6744, 17]) torch.Size([6744, 2])



 10%|█         | 1/10 [00:00<00:07,  1.13it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.16it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.19it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.25it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.26it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.27it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.23it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.24it/s][A
 90%|█████████ | 9/10 [00:07<00:00,  1.25it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.27it/s][A