# Create a simple forward or forward + inverse model to identify actions that need to be learned?

In [200]:
#prepare data
import pickle

with open('../projectFlorian/gym-microrts/experiments/data.pickle', 'rb') as handle:
    b = pickle.load(handle)
print(b.keys())

dict_keys(['obs', 'b_logprobs', 'b_actions', 'b_advantages', 'b_returns', 'b_values', 'b_invalid_action_masks'])


In [313]:
b_actions, b_obs, b_invalid_masks = b["b_actions"], b["obs"], b["b_invalid_action_masks"]
b_actions = np.array(b_actions)
b_obs = np.array(b_obs)
b_invalid_masks = np.array(b_invalid_masks)
print(b_actions.shape, b_obs.shape)

#collapse one dimension
b_obs = b_obs.reshape(-1, *b_obs.shape[-3:])
b_actions = b_actions.reshape(-1, *b_actions.shape[-2:])
b_invalid_masks = b_invalid_masks.reshape(-1, *b_invalid_masks.shape[-2:])
print(b_actions.shape, b_obs.shape, b_invalid_masks.shape)

(3, 2048, 256, 7) (3, 2048, 16, 16, 27)
(6144, 256, 7) (6144, 16, 16, 27) (6144, 256, 78)


In [316]:
#generate "fake next states"
shift = list(range(1,b_obs.shape[0]))
b_next_state = b_obs[shift]
#remove last elemt in actions and obs (no next state)
b_obs = b_obs[:-1] 
b_actions = b_actions[:-1]
b_invalid_masks = b_invalid_masks[:-1]

In [317]:
#create tensors
b_obs_tensor = torch.Tensor(b_obs)
b_actions_tensor = torch.Tensor(b_actions)
b_next_state_tensor = torch.Tensor(b_next_state)
b_invalid_masks_tensor = torch.Tensor(b_invalid_masks)

In [101]:
print(b_obs_tensor.shape, b_next_state_tensor.shape, b_actions_tensor.shape)

torch.Size([6143, 16, 16, 27]) torch.Size([6143, 16, 16, 27]) torch.Size([6143, 256, 7])


In [342]:
import torch
import torch.nn as nn
import numpy as np
from torch import optim


class Transpose(nn.Module):
    def __init__(self, permutation):
        super().__init__()
        self.permutation = permutation

    def forward(self, x):
        return x.permute(self.permutation)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class DynamicsModel(nn.Module):
    def __init__(self, observation_space_shape, num_actions, action_space):
        super(DynamicsModel, self).__init__()
        h, w, c = observation_space_shape
        action_ohe_shape = num_actions[0] * num_actions[1]
        
        c += action_space
        
        num_actions  = num_actions[0]
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )
        
        
        self.feature_size = 64 * 4 * 4
        #simplest inverse model that ouputs a single action
        
#         self.inverse_net = nn.Sequential(
#             nn.Linear(self.feature_size * 2, 128),
#             nn.LeakyReLU(),
#             nn.Linear(128, num_actions)
#         )
        
        #version that predicts actions and parameters
        self.inverse_net = nn.Sequential(
            layer_init(nn.ConvTranspose2d(128, 32, 3, stride=2, padding=1, output_padding=1)),
            nn.ReLU(),
            layer_init(nn.ConvTranspose2d(32, 78, 3, stride=2, padding=1, output_padding=1)),
            Transpose((0, 2, 3, 1)),
        )
        
        self.forward_net = nn.Sequential(
            nn.Linear(self.feature_size + action_ohe_shape, 128),
            nn.LeakyReLU(),
            nn.Linear(128, self.feature_size)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # nn.init.kaiming_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, state, next_state, action, action_ohe, masks):
        
        state_ft = self.encoder(torch.cat((state, masks),3))
        next_state_ft = self.encoder(torch.cat((next_state, masks),3))
        state_ft2 = state_ft.view(-1, self.feature_size)
        next_state_ft2 = next_state_ft.view(-1, self.feature_size)
        
    
        print(state_ft2.shape, action_ohe.shape, masks.shape)
        print(action.shape, masks.shape)
        
        #print("MSSK",masks.shape, state_ft.shape, state.shape, action.shape)
        return self.inverse_net(torch.cat((state_ft, next_state_ft), 1)), self.forward_net(
            torch.cat((state_ft2, action_ohe), 1)), next_state_ft2

    
def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)
    
def train(device, nb_epochs = 5, action_plane_space_sum = 78, action_plane_space_tolist = [6, 4, 4, 4, 4, 7, 49]):
    
    batch_size = 32
    
    #create a fake batch
    state = b_obs_tensor[:batch_size]
    next_state = b_next_state_tensor[:batch_size]
    action = b_actions_tensor[:batch_size].to(dtype=torch.long)
    invalid_masks = b_invalid_masks_tensor[:batch_size]
    
    
    action_ohe = action.flatten(1,2)
    
    
    
    inv_criterion = nn.CrossEntropyLoss()
    fwd_criterion = nn.MSELoss()
    m = DynamicsModel((16, 16, 27), action.shape[1:], 78)
    m = m.to(device)
    optimizer = optim.Adam(m.parameters(), lr=1e-4, eps=1e-5)
    
    inv_losses = []
    fwd_losses = []
    total_losses = []
    
    beta = 0.2
    
    
    
    for epoch in range(nb_epochs):
        #state, action, next_state = torch.rand((32,16, 16, 27)), torch.rand((32,334)), torch.rand((32,16, 16, 27))
        
        invalid_masks = invalid_masks.reshape(state.shape[0], state.shape[1], state.shape[2], -1)
        
        
        pred_logits, pred_phi, phi = m(state, next_state, action, action_ohe, invalid_masks)
        action_oh = torch.zeros((batch_size),dtype=torch.long)
        
        
        
        
        #if envs is available
        #grid_logits = pred_logits.reshape(-1, envs.action_plane_space.nvec.sum())
        #split_logits = torch.split(grid_logits, envs.action_plane_space.nvec.tolist(), dim=1)
        #grid_logits = pred_logits.reshape(-1, action_plane_space_sum)
        split_logits = torch.split(pred_logits, action_plane_space_tolist, dim=3)
        
        #print(grid_logits.shape)
        
        
        
        inv_loss = 0 
        for i in range(1):
            splitted_logit = split_logits[i].view(split_logits[i].shape[0], split_logits[i].shape[1] * split_logits[i].shape[2], split_logits[i].shape[3])
            splitted_logit = splitted_logit.reshape(-1, splitted_logit.shape[-1])
            splitted_actions = action[:, : , i].reshape(-1)
            inv_loss  = inv_loss + inv_criterion(splitted_logit, splitted_actions) * 1.0 / len(split_logits)
            
            print(splitted_logit[:5], splitted_actions[:5])
        
        
        
        
        fwd_loss = fwd_criterion(pred_phi, phi) / 2
        
        total_loss = (1 - beta) * inv_loss + beta * fwd_loss
        optimizer.zero_grad()
        total_loss.backward()
        
        inv_losses.append(inv_loss)
        fwd_losses.append(fwd_loss)
        total_losses.append(total_loss)
        print("Total loss {}".format(total_loss.detach().item()))
        optimizer.step()
    
    
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
train(device , nb_epochs=50)

# ERREUR DANS CODE POUR INV LOSS!!! (maybe how it concatenated next state/ state and masks, or loss computing? , or forward function)
# INV lOSS DOES NOT REALLY decrease!
# Maybe not good to use batch 8XXX ? (wrong use of cross entropy?)
# How about just ouput 256 values? And use Categorical cross entropy on it? (rather than as now?)
# Maybe mask not correctly used, so it is impossible to correctly predict actions?
# PROBABLY LOSS COMPUTATION NOT CORRECT, CHECK PPO
# THE WAY LOSS IS COMPUTED MAKES IT IMPOSSIBLE FOR MODEL TO UNDERSTAND HOW TO UPDATE GRADIENT (8000 elemens  * 6 !!!)
#TEST WITH MORE EXAMPLES!!! (Maybe current examples make it to learn?)


"""TO FIX:
     
WHY RANDOM (STATE , value should not changed?) ??? (true state should not change)
ACTION GENERATION NOT CORRECT (USE MULTI DISTRIBUTION)
USE RANDMOM BATCH of DATA
SEARCH FOR BUGS?
HOW TO HANDLE ACTION? SAME AS NOW ? USE CATEGORICAL DISTRIBITION?

action_oh is wrong (not implemented)!!!

maybe keep like now? (only predit the action type  - not the specified parameter) ? (especially, if 
IF PREDICT PARAMETERS, SHOULD COMPUTE LOSS ONLY FOR SELECTED ACTIONS, PARAMETERS OF OTHER ACTIONS IS NOT SO RELEVANT I THINK)
Anyway, predicted action seems weird ? torch.Size([32, 256]) / SHOULD IT BE ONE VALUE FOR EACH BOARD? SEEMS LKE IT HANDLES LIKE ONE ACTION
FOR EACH ENV, NOT 256 ACTIONS PER ENV



"""

torch.Size([32, 1024]) torch.Size([32, 1792]) torch.Size([32, 16, 16, 78])
torch.Size([32, 256, 7]) torch.Size([32, 16, 16, 78])
tensor([[ 0.0122, -0.0012, -0.0322, -0.0562, -0.0893,  0.0134],
        [ 0.0521,  0.0506,  0.0313, -0.0123, -0.0878,  0.0344],
        [ 0.0268, -0.0384,  0.0085, -0.0253,  0.0007,  0.0146],
        [-0.0059,  0.0071,  0.0426, -0.0680, -0.1612,  0.0420],
        [ 0.0149, -0.0030, -0.0297, -0.0581, -0.0815,  0.0196]],
       grad_fn=<SliceBackward>) tensor([1, 2, 0, 2, 0])
Total loss 1.7865792512893677
torch.Size([32, 1024]) torch.Size([32, 1792]) torch.Size([32, 16, 16, 78])
torch.Size([32, 256, 7]) torch.Size([32, 16, 16, 78])
tensor([[ 0.0124, -0.0013, -0.0306, -0.0559, -0.0868,  0.0118],
        [ 0.0545,  0.0509,  0.0291, -0.0126, -0.0863,  0.0330],
        [ 0.0260, -0.0383,  0.0095, -0.0244,  0.0010,  0.0149],
        [-0.0037,  0.0076,  0.0402, -0.0694, -0.1588,  0.0406],
        [ 0.0146, -0.0036, -0.0284, -0.0578, -0.0795,  0.0181]],
       grad_fn

torch.Size([32, 1024]) torch.Size([32, 1792]) torch.Size([32, 16, 16, 78])
torch.Size([32, 256, 7]) torch.Size([32, 16, 16, 78])
tensor([[ 1.6153e-02, -9.4427e-03, -1.0236e-04, -6.0241e-02, -7.3286e-02,
         -2.2201e-02],
        [ 8.8408e-02,  5.8413e-02,  6.4229e-03,  9.3229e-03, -7.8659e-02,
          2.4168e-02],
        [ 2.6566e-02, -4.5873e-02,  2.1704e-02, -1.6750e-02,  3.0448e-03,
          1.1192e-02],
        [ 3.2499e-02,  2.1698e-02,  1.1818e-02, -6.9049e-02, -1.5261e-01,
          3.1988e-02],
        [ 1.6992e-02, -1.0325e-02, -1.6171e-02, -6.0112e-02, -6.4794e-02,
         -2.4591e-03]], grad_fn=<SliceBackward>) tensor([1, 2, 0, 2, 0])
Total loss 0.3523474931716919
torch.Size([32, 1024]) torch.Size([32, 1792]) torch.Size([32, 16, 16, 78])
torch.Size([32, 256, 7]) torch.Size([32, 16, 16, 78])
tensor([[ 0.0163, -0.0103,  0.0012, -0.0610, -0.0727, -0.0240],
        [ 0.0896,  0.0582,  0.0059,  0.0106, -0.0783,  0.0245],
        [ 0.0270, -0.0467,  0.0215, -0.0174,  0.0

KeyboardInterrupt: 

In [308]:
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

print(input.shape, target.shape)

torch.Size([3, 5]) torch.Size([3])
