In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchrl.networks.init as init
import json
import numpy as np
from torch.distributions import Distribution, Normal
from torchrl.utils import get_params

In [2]:
# pf explore
class TanhNormal(Distribution):
    """
    Basically from RLKIT

    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        z = self.normal.sample_n(n)
        if return_pre_tanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """

        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return:
        """
        if pre_tanh_value is None:
            pre_tanh_value = torch.log(
                (1 + value) / (1 - value)
            ) / 2
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (
                self.normal_mean +
                self.normal_std *
                Normal(
                    torch.zeros(self.normal_mean.size()),
                    torch.ones(self.normal_std.size())
                ).sample().to(self.normal_mean.device)
        )

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def entropy(self):
        return self.normal.entropy()

In [3]:
# new 2 MLPBaseAE
class MLPBaseAE(nn.Module):
    def __init__(self, input_shape, hidden_shapes, expert_nums, expert_hidden_shapes, tower_hidden_shapes,
                 attention_shapes, activation_func=F.relu, init_func=init.basic_init, last_activation_func=None,
                 flag=None, expert_id=None, v_id=None, q_id=None, k_id=None, task_nums = None):
        super().__init__()
        self.flag = flag
        self.activation_func = activation_func
        self.fcs = []
        if last_activation_func is not None:
            self.last_activation_func = last_activation_func
        else:
            self.last_activation_func = activation_func

        if flag == 'baseline':
            input_shape = np.prod(input_shape)
            self.output_shape = input_shape
            for i, next_shape in enumerate(hidden_shapes):
                fc = nn.Linear(input_shape, next_shape)
                init_func(fc)
                self.fcs.append(fc)
                # set attr for pytorch to track parameters( device )
                self.__setattr__("baseline_fc{}".format(i), fc)

                input_shape = next_shape
                self.output_shape = next_shape

        elif flag == 'expert':
            input_shape = hidden_shapes[-1]
            self.output_shape = input_shape

            for i, next_shape in enumerate(expert_hidden_shapes):
                fc = nn.Linear(input_shape, next_shape)
                init_func(fc)
                self.fcs.append(fc)
                # set attr for pytorch to track parameters( device )
                self.__setattr__("expert{}_fc{}".format(expert_id, i), fc)

                input_shape = next_shape
                self.output_shape = next_shape

        elif flag == 'tower':
            input_shape = attention_shapes
            self.output_shape = input_shape

            for i, next_shape in enumerate(tower_hidden_shapes):
                fc = nn.Linear(input_shape, next_shape)
                init_func(fc)
                self.fcs.append(fc)
                # set attr for pytorch to track parameters( device )
                self.__setattr__("tower_fc{}".format(i), fc)

                input_shape = next_shape
                self.output_shape = next_shape

        elif flag == 'attention_v':
            input_shape = expert_hidden_shapes[-1]
            self.output_shape = input_shape

            fc = nn.Linear(input_shape, attention_shapes, bias=False)
            # init_func(fc)
            self.fcs.append(fc)
            # set attr for pytorch to track parameters( device )
            self.__setattr__("v{}_fc".format(v_id), fc)

            self.output_shape = attention_shapes

        elif flag == 'attention_k':
            input_shape = expert_hidden_shapes[-1]
            self.output_shape = input_shape

            fc = nn.Linear(input_shape, attention_shapes, bias=False)
            # init_func(fc)
            self.fcs.append(fc)
            # set attr for pytorch to track parameters( device )
            self.__setattr__("k{}_fc".format(k_id), fc)

            self.output_shape = attention_shapes

        elif flag == 'attention_q':
            input_shape = 1
            self.output_shape = input_shape

            fc = nn.Linear(input_shape, attention_shapes, bias=False)
#             init_func(fc)
            self.fcs.append(fc)
            # set attr for pytorch to track parameters( device )
            self.__setattr__("q_fc", fc)

            self.output_shape = attention_shapes

    def forward(self, x):
        out = x
        for fc in self.fcs[:-1]:
            out = fc(out)
            out = self.activation_func(out)
        out = self.fcs[-1](out)
        if 'attention' not in self.flag:
            out = self.last_activation_func(out)
        return out

In [4]:
# 1 NetAE
class NetAE(nn.Module):
    def __init__(
            self, output_shape,
            base_type,
            expert_nums,
            task_nums,
            append_hidden_shapes=[],
            append_hidden_init_func=init.basic_init,
            net_last_init_func=init.uniform_init,
            activation_func=F.relu,
            **kwargs):

        super().__init__()
        self.task_nums = task_nums
        self.n = expert_nums
        # 0 baseline network 2 layers mlp 300 300
        self.base = base_type(activation_func=activation_func, expert_nums=self.n, flag='baseline', **kwargs)
        self.activation_func = activation_func

        # 1 experts networks 3 * 2 layers mlp 400 400 output
        self.mlp_list = nn.ModuleList()

        for i in range(self.n):
            self.mlp_list.append(
                base_type(activation_func=activation_func, expert_nums=self.n, flag='expert', expert_id=i, **kwargs))

            # 2 attention module
        self.attention_v_list = nn.ModuleList()
#         self.attention_q_list = nn.ModuleList()
        self.attention_k_list = nn.ModuleList()
        for i in range(self.n):
            self.attention_v_list.append(
                base_type(activation_func=activation_func, expert_nums=self.n, flag='attention_v', v_id=i, **kwargs))
            self.attention_k_list.append(
                base_type(activation_func=activation_func, expert_nums=self.n, flag='attention_k', k_id=i, **kwargs))
        
        self.attention_q = base_type(activation_func=activation_func, expert_nums=self.n, flag='attention_q', q_id=i, **kwargs) 
        self.activation_func = activation_func
        #         for i in range(self.task_nums):
#             self.attention_q_list.append(
#                 base_type(activation_func=activation_func, expert_nums=self.n, flag='attention_q', q_id=i, task_nums = self.task_nums,**kwargs))
            # 3 tower network 1 layers mlp 100
        self.tower = base_type(activation_func=activation_func, expert_nums=self.n, flag='tower', **kwargs)
        self.activation_func = activation_func

        #         append_input_shape = self.experts.output_shape

        #         self.append_fcs = []
        #         for i, next_shape in enumerate(append_hidden_shapes):
        #             fc = nn.Linear(append_input_shape, next_shape)
        #             append_hidden_init_func(fc)
        #             self.append_fcs.append(fc)
        #             # set attr for pytorch to track parameters( device )
        #             self.__setattr__("append_fc{}".format(i), fc)
        #             append_input_shape = next_shape

        # 4 last network
        self.last = nn.Linear(self.tower.output_shape, output_shape)
        net_last_init_func(self.last)

    def forward(self, x, task_id):
        # 0 baseline network
        out = self.base(x)

        expkqs = []
        vs = []
        
        for i in range(self.n):
            # 1 expert networks
            e = self.mlp_list[i](out)
            # 2 attention module
            v = self.attention_v_list[i](e)
            k = self.attention_k_list[i](e)
#             q = self.attention_q_list[task_id.item()](task_idx)
            q = self.attention_q(task_id)
    
            expkq = torch.sum(torch.mul(k, q),dim=-1)
            expkqs.append(expkq)
            print('expkq',expkq.shape)
            vs.append(v)

        res = torch.zeros(vs[0].shape).to('cuda:0')
        alphas = [i / sum(expkqs) for i in expkqs]

        for i in range(self.n):
            res += vs[i] * alphas[i].unsqueeze(-1)

        # 3 tower network
        out = self.tower(res)

        #         for append_fc in self.append_fcs:
        #             out = append_fc(out)
        #             out = self.activation_func(out)

        # 4 last network
        out = self.last(out)
        return out

In [5]:
import torch

# 定义三维数组和二维数组
arr_3d = torch.ones((2, 2, 5))
arr_2d = torch.ones((2, 2))

# 将二维数组扩展为(2, 2, 1)的三维数组
arr_2d = arr_2d.unsqueeze(2)

# 沿5方向的每一面对应乘上二维数组
result = arr_3d * arr_2d

# 查看结果
print(result.shape)

torch.Size([2, 2, 5])


In [6]:
# 0 GrassianContPolicyAE
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class GuassianContPolicyAE(NetAE):
    def forward(self, x, task_id):
        qs = super().forward(x, task_id)

        #         mean, log_std = x.chunk(2, dim=-1)

        #         log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        #         std = torch.exp(log_std)

        # return mean, std, log_std
        return qs

    def eval_act(self, x):
        with torch.no_grad():
            mean, _, _ = self.forward(x)
        return torch.tanh(mean.squeeze(0)).detach().cpu().numpy()

    def explore(self, x, return_log_probs=False, return_pre_tanh=False):

        mean, std, log_std = self.forward(x)

        dis = TanhNormal(mean, std)

        ent = dis.entropy().sum(-1, keepdim=True)

        dic = {
            "mean": mean,
            "log_std": log_std,
            "ent": ent
        }

        if return_log_probs:
            action, z = dis.rsample(return_pretanh_value=True)
            log_prob = dis.log_prob(
                action,
                pre_tanh_value=z
            )
            log_prob = log_prob.sum(dim=-1, keepdim=True)
            dic["pre_tanh"] = z.squeeze(0)
            dic["log_prob"] = log_prob
        else:
            if return_pre_tanh:
                action, z = dis.rsample(return_pretanh_value=True)
                dic["pre_tanh"] = z.squeeze(0)
            action = dis.rsample(return_pretanh_value=False)

        dic["action"] = action.squeeze(0)
        return dic

    def update(self, obs, actions):
        mean, std, log_std = self.forward(obs)
        dis = TanhNormal(mean, std)

        log_prob = dis.log_prob(actions).sum(-1, keepdim=True)
        ent = dis.entropy().sum(-1, keepdim=True)

        out = {
            "mean": mean,
            "log_std": log_std,
            "log_prob": log_prob,
            "ent": ent
        }
        return out

In [7]:
a = torch.zeros((4,4)).to('cuda:0')

In [8]:
params = get_params('./meta_config/mt10/mtsac_ae.json')
params['net']['base_type'] = MLPBaseAE

obs = torch.rand((19)).float().to('cuda:0')
action = torch.rand((4,)).float().to('cuda:0')
task_id = torch.tensor(([6])).float().to('cuda:0')
pf = GuassianContPolicyAE(
        input_shape=19,
        output_shape=2 * 4,
        **params['net'])

pf.to('cuda:0')
pf.forward(obs,task_id)

expkq torch.Size([])
expkq torch.Size([])
expkq torch.Size([])


tensor([-0.0023, -0.0018, -0.0003,  0.0010,  0.0030,  0.0047,  0.0009, -0.0006],
       device='cuda:0', grad_fn=<AddBackward0>)

In [186]:
torch.rand((3,4))

tensor([[0.6292, 0.9095, 0.0874, 0.3740],
        [0.3330, 0.8461, 0.2225, 0.2676],
        [0.0229, 0.1489, 0.4175, 0.3487]])

In [163]:
import torch

# 定义两个3x2x2的三维矩阵
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
b = torch.tensor([[[2, 2], [2, 2]], [[3, 3], [3, 3]], [[4, 4], [4, 4]]])

# 对应位置元素相乘
c = torch.mul(a, b)

# 输出结果
print(torch.sum(c,dim=0))

tensor([[53, 62],
        [71, 80]])


In [107]:
pf.state_dict().keys()

odict_keys(['base.baseline_fc0.weight', 'base.baseline_fc0.bias', 'base.baseline_fc1.weight', 'base.baseline_fc1.bias', 'mlp_list.0.expert0_fc0.weight', 'mlp_list.0.expert0_fc0.bias', 'mlp_list.0.expert0_fc1.weight', 'mlp_list.0.expert0_fc1.bias', 'mlp_list.1.expert1_fc0.weight', 'mlp_list.1.expert1_fc0.bias', 'mlp_list.1.expert1_fc1.weight', 'mlp_list.1.expert1_fc1.bias', 'mlp_list.2.expert2_fc0.weight', 'mlp_list.2.expert2_fc0.bias', 'mlp_list.2.expert2_fc1.weight', 'mlp_list.2.expert2_fc1.bias', 'attention_v_list.0.v0_fc.weight', 'attention_v_list.1.v1_fc.weight', 'attention_v_list.2.v2_fc.weight', 'attention_q_list.0.q0_fc.weight', 'attention_q_list.1.q1_fc.weight', 'attention_q_list.2.q2_fc.weight', 'attention_q_list.3.q3_fc.weight', 'attention_q_list.4.q4_fc.weight', 'attention_q_list.5.q5_fc.weight', 'attention_q_list.6.q6_fc.weight', 'attention_q_list.7.q7_fc.weight', 'attention_q_list.8.q8_fc.weight', 'attention_q_list.9.q9_fc.weight', 'attention_k_list.0.k0_fc.weight', 'atten

In [None]:
# critics:
class FlattenNetAE(NetAE):
    def forward(self, input,task_id):
        out = torch.cat(input, dim = -1)
        return super().forward(out,task_id)

In [29]:
qf1 = FlattenNetAE(
        input_shape=23,
        output_shape=1,
        **params['net'])
qf1.forward([obs,action],task_id)

tensor([0.0009], grad_fn=<AddBackward0>)

In [34]:
batch ={'task_idxs':torch.tensor([5])}
task_idx_num = batch['task_idxs'].item()
task_idx = torch.zeros((10,))
task_idx[task_idx_num]=1
print(task_idx)

tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])


搞错了，应该用BootstrappedNet更好。我们试试看。首先看看BootstrapNet原版：

In [35]:
class Net(nn.Module):
    def __init__(
            self, output_shape,
            base_type,
            append_hidden_shapes=[],
            append_hidden_init_func=init.basic_init,
            net_last_init_func=init.uniform_init,
            activation_func=F.relu,
            **kwargs):

        super().__init__()

        self.base = base_type(activation_func=activation_func, **kwargs)
        self.activation_func = activation_func
        append_input_shape = self.base.output_shape
        self.append_fcs = []
        for i, next_shape in enumerate(append_hidden_shapes):
            fc = nn.Linear(append_input_shape, next_shape)
            append_hidden_init_func(fc)
            self.append_fcs.append(fc)
            # set attr for pytorch to track parameters( device )
            self.__setattr__("append_fc{}".format(i), fc)
            append_input_shape = next_shape

        self.last = nn.Linear(append_input_shape, output_shape)
        net_last_init_func(self.last)

    def forward(self, x):
        out = self.base(x)

        for append_fc in self.append_fcs:
            out = append_fc(out)
            out = self.activation_func(out)

        out = self.last(out)
        return out

In [37]:
class MLPBase(nn.Module):
    def __init__(self, input_shape, hidden_shapes, activation_func=F.relu, init_func = init.basic_init, last_activation_func = None ):
        super().__init__()
        
        self.activation_func = activation_func
        self.fcs = []
        if last_activation_func is not None:
            self.last_activation_func = last_activation_func
        else:
            self.last_activation_func = activation_func
        input_shape = np.prod(input_shape)

        self.output_shape = input_shape
        for i, next_shape in enumerate( hidden_shapes ):
            fc = nn.Linear(input_shape, next_shape)
            init_func(fc)
            self.fcs.append(fc)
            # set attr for pytorch to track parameters( device )
            self.__setattr__("fc{}".format(i), fc)

            input_shape = next_shape
            self.output_shape = next_shape

    def forward(self, x):

        out = x
        for fc in self.fcs[:-1]:
            out = fc(out)
            out = self.activation_func(out)
        out = self.fcs[-1](out)
        out = self.last_activation_func(out)
        return out

In [58]:
class BootstrappedNet(Net):
    def __init__(self, output_shape, 
                 head_num = 10,
                 **kwargs ):
        self.head_num = head_num
        self.origin_output_shape = output_shape
        output_shape *= self.head_num
        super().__init__(output_shape = output_shape, **kwargs)

    def forward(self, x, idx):
        base_shape = x.shape[:-1]
        print(base_shape)
        out = super().forward(x)# 8*10
        out_shape = base_shape + torch.Size([self.origin_output_shape, self.head_num])
        print(out_shape,torch.Size([self.origin_output_shape, self.head_num]))
        view_idx_shape = base_shape + torch.Size([1, 1])
        print(base_shape)
        expand_idx_shape = base_shape + torch.Size([self.origin_output_shape, 1])
        print(base_shape,torch.Size([self.origin_output_shape, 1]))
        
        out = out.reshape(out_shape)

        idx = idx.view(view_idx_shape)
        idx = idx.expand(expand_idx_shape)

        out = out.gather(-1, idx).squeeze(-1)
        return out

In [126]:
params = get_params('./meta_config/mt10/mtmhsac.json')
params['net']['base_type'] = MLPBase
pf = BootstrappedNet (
        input_shape = 19, 
        output_shape = 2 * 4,
        head_num=10,
        **params['net'] )
obs = torch.rand((19,))

pf2 = Net(input_shape = 19, 
        output_shape = 2 * 4,
        **params['net'] )

pf.forward(obs,torch.tensor([1]))
pf2.forward(obs)

torch.Size([])
torch.Size([8, 10]) torch.Size([8, 10])
torch.Size([])
torch.Size([]) torch.Size([8, 1])
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]])


tensor([ 6.0238e-03, -4.7925e-04,  2.1551e-03,  6.3379e-03, -3.2525e-03,
         2.4140e-03, -6.3365e-06,  1.2156e-03], grad_fn=<AddBackward0>)

In [49]:
base_shape = (19,)
out_shape = base_shape + torch.Size([19, 10])
print(out_shape)
view_idx_shape = base_shape + torch.Size([1, 1])
expand_idx_shape = base_shape + torch.Size([self.origin_output_shape, 1])

(19, 19, 10)


NameError: name 'self' is not defined

## 好的，现在来弄一下我们自己的

In [None]:
class BootstrappedNetAE(NetAE):
    def __init__(self, output_shape, 
                 head_num = 10,
                 **kwargs ):
        self.head_num = head_num
        self.origin_output_shape = output_shape
        output_shape *= self.head_num
        super().__init__(output_shape = output_shape, **kwargs)

    def forward(self, x, idx):
        base_shape = x.shape[:-1]
        print(base_shape)
        out = super().forward(x)# 8*10
        out_shape = base_shape + torch.Size([self.origin_output_shape, self.head_num])
        print(out_shape,torch.Size([self.origin_output_shape, self.head_num]))
        view_idx_shape = base_shape + torch.Size([1, 1])
        print(base_shape)
        expand_idx_shape = base_shape + torch.Size([self.origin_output_shape, 1])
        print(base_shape,torch.Size([self.origin_output_shape, 1]))
        
        out = out.reshape(out_shape)

        idx = idx.view(view_idx_shape)
        idx = idx.expand(expand_idx_shape)

        out = out.gather(-1, idx).squeeze(-1)
        return out

In [125]:
a = torch.rand((1,10))
print(a)
a.view()

tensor([[0.8168, 0.6146, 0.2686, 0.3414, 0.0517, 0.3150, 0.1318, 0.7991, 0.1808,
         0.8374]])


TypeError: view() received an invalid combination of arguments - got (), but expected one of:
 * (torch.dtype dtype)
 * (tuple of ints size)
