In [6]:
import torch 
import torch.nn as nn

import torch 
import torch.nn as nn

class Attention(nn.Module): 
    def __init__(self, dmodel, dk): 
        super(Attention, self).__init__() 
        self.dk = dk # 键的维度
        self.W = nn.Linear(dmodel, dk) #查询的线性层
        self.V = nn.Linear(dk, dk)  # 调整输入维度 值的线性层
        self.a = nn.Linear(dmodel, 1) #注意力的线性层

    def forward(self, Q, K, V):
        a = self.a(Q)
        a = torch.tanh(a + self.W(Q) + K)  # 确保维度匹配
        a = self.V(a)
        a = torch.softmax(a, dim=-1)
        return a * V
    
# 使用 Attention 模块
attention = Attention(dmodel=64, dk=32) 
Q = torch.randn(1, 1, 64)  #(1,1,64) 1个样本 1个查询 64个维度 1*1*64
K = torch.randn(1, 32, 32)  #(1,32,32) 1个样本 32个键 32个维度 1*32*32
V = torch.randn(1, 32, 32)

output = attention(Q, K, V) 
print(output.shape) 
#print(output) 
print(Q)
print(K[0])


# torch.Size([1, 1, 32]) ```


torch.Size([1, 32, 32])
tensor([[[-0.7812,  0.1937,  0.6831,  0.3020, -1.2855,  0.0128,  0.4752,
          -0.6468, -1.7165, -0.4823,  2.3699, -2.2589, -1.0785,  1.7208,
           0.0644, -0.3116,  0.4442, -0.9582,  0.9313,  0.1171, -0.9809,
           0.9163, -0.3754,  0.0775, -0.4352,  0.1542, -0.5990, -1.6279,
          -0.2050, -1.2688,  0.2543,  1.2175,  0.2484, -1.2445,  1.2138,
          -2.5827,  2.4274, -0.3891,  0.6663,  0.2754, -2.1325, -1.4978,
          -0.3794, -1.5681,  0.9710,  0.7007,  0.4299,  0.1714, -1.0776,
          -0.9400, -0.4812, -0.7455,  1.3695, -0.4955, -0.8966,  0.1539,
           1.9773, -1.5391,  1.3548, -1.4211,  0.0226,  1.5352, -1.3459,
           0.1566]]])
tensor([[-0.5906, -2.0971, -0.1380,  ..., -1.6176, -0.3098, -0.2979],
        [ 0.3817, -0.4182,  0.0311,  ..., -0.1159, -0.1287, -0.0273],
        [ 0.0532,  0.0124,  1.1253,  ..., -1.0223,  0.0213,  1.3701],
        ...,
        [-1.0162, -0.8913,  1.8390,  ...,  1.0211,  0.5831,  1.8436],
    

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super(Attention, self).__init__()
        self.query = nn.Linear(in_dim, hidden_dim) #查询
        self.key = nn.Linear(in_dim, hidden_dim)
        self.value = nn.Linear(in_dim, hidden_dim)
    
    def forward(self, x):
        '''
        公式为：Attention(Q, K, V) = softmax(Q*K^T/sqrt(d_k)) * V
        Q: 查询
        K: 键
        V: 值
        d_k: 键的维度 这里用hidden_dim表示 即:K.size(-1)
        对张量 K 进行转置（transpose）。具体来说，这个操作会将张量 K 的第0维和第1维进行交换
        '''
        Q = self.query(x) 
        K = self.key(x)
        V = self.value(x)
        attn_scores = torch.matmul(Q, K.transpose(0, 1)) / torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attended_values = torch.matmul(attn_probs, V)
        return attended_values

class MLPNetworkWithAttention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim_1=256, hidden_dim_2=128, attention_dim=64):
        super(MLPNetworkWithAttention, self).__init__()
        self.attention = Attention(in_dim, attention_dim)
        self.fc1 = torch.nn.Linear(attention_dim, hidden_dim_1)
        self.fc2 = torch.nn.Linear(hidden_dim_1, hidden_dim_2)
        self.fc3 = torch.nn.Linear(hidden_dim_2, out_dim)
        
        # 根据计算增益
        gain1 = nn.init.calculate_gain('relu')
        # Xavier均匀分布初始化
        torch.nn.init.xavier_uniform_(self.fc1.weight, gain=gain1)
        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain1)
        torch.nn.init.xavier_uniform_(self.fc3.weight, gain=gain1)
        # 初始化参数
        self.fc1.bias.data.fill_(0.01)
        self.fc2.bias.data.fill_(0.01)
    
    def forward(self, x): #这里的x是状态加动作
        x = self.attention(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# 创建带有注意力机制的critic
global_obs_dim = 128  # 示例输入维度
critic = MLPNetworkWithAttention(global_obs_dim, 1)
#print(critic)
query = nn.Linear(128, 1)
x = torch.randn(12, 128)
Q = query(x)
print(Q.transpose(0, 1).size())
print(Q.size()) #.size表示维度
print(Q.size(-1))
an = Attention(128, 64)
output = an(x)
print(output)

torch.Size([1, 12])
torch.Size([12, 1])
1
tensor([[ 3.7421e-01,  4.7305e-02,  2.4503e-03, -1.7902e-03,  6.8952e-02,
          1.5911e-01, -1.1148e-01,  3.2580e-01,  2.3623e-02,  1.6507e-01,
          2.8500e-02,  1.7523e-02,  1.3803e-01,  2.0690e-01,  2.4201e-02,
         -1.1153e-01,  8.0610e-03,  1.5245e-01, -7.2237e-02,  8.8497e-02,
         -1.5001e-01, -1.4662e-01,  4.8354e-01, -6.8436e-02,  2.4036e-01,
          1.2587e-01, -2.9741e-01, -1.9121e-02, -3.4182e-01, -3.6344e-02,
         -1.1900e-01,  1.6724e-01,  3.6068e-01, -7.9976e-02,  1.6924e-01,
         -5.1975e-02, -1.9620e-01, -2.6547e-03,  1.1183e-01, -3.5658e-02,
          4.1525e-01,  2.4581e-01, -4.7811e-01,  1.6771e-01,  1.8765e-01,
          3.5561e-02,  2.0565e-01,  1.7503e-01,  6.9023e-02, -2.1510e-01,
         -3.2191e-01, -2.7941e-01,  1.9586e-01,  1.6022e-01,  9.0823e-02,
         -1.2894e-01, -2.6922e-01,  9.0844e-02,  1.4083e-01, -2.0077e-01,
          5.7030e-02,  8.1929e-02, -6.1835e-02,  1.0544e-01],
        

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiAgentAttention(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super(MultiAgentAttention, self).__init__()
        self.query = nn.Linear(in_dim, hidden_dim)
        self.key = nn.Linear(in_dim, hidden_dim)
        self.value = nn.Linear(in_dim, hidden_dim)
    
    def forward(self, agent_states):
        """
        Inputs:
            agent_states (list of tensors): List of states for each agent
        Outputs:
            attended_values (tensor): Attention-weighted values for each agent
        """
        Q = [self.query(state) for state in agent_states]
        K = [self.key(state) for state in agent_states]
        V = [self.value(state) for state in agent_states]
        
        Q = torch.stack(Q)
        K = torch.stack(K)
        V = torch.stack(V)
        
        attn_scores = torch.matmul(Q, K.transpose(0, 1)) / torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attended_values = torch.matmul(attn_probs, V)
        
        return attended_values

class MultiAgentMLPNetworkWithAttention(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim_1=256, hidden_dim_2=128, attention_dim=64, ):
        super(MultiAgentMLPNetworkWithAttention, self).__init__()
        self.attention = MultiAgentAttention(in_dim, attention_dim)
        self.fc1 = torch.nn.Linear(attention_dim, hidden_dim_1)
        self.fc2 = torch.nn.Linear(hidden_dim_1, hidden_dim_2)
        self.fc3 = torch.nn.Linear(hidden_dim_2, out_dim)
        
        # 根据计算增益
        gain1 = nn.init.calculate_gain('relu')
        # Xavier均匀分布初始化
        torch.nn.init.xavier_uniform_(self.fc1.weight, gain=gain1)
        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain1)
        torch.nn.init.xavier_uniform_(self.fc3.weight, gain=gain1)
        # 初始化参数
        self.fc1.bias.data.fill_(0.01)
        self.fc2.bias.data.fill_(0.01)
    
    def forward(self, agent_states):
        """
        Inputs:
            agent_states (list of tensors): List of states for each agent
        Outputs:
            outputs (list of tensors): Outputs for each agent
        """
        attended_values = self.attention(agent_states)
        outputs = []
        for value in attended_values:
            x = F.relu(self.fc1(value))
            x = F.relu(self.fc2(x))
            output = self.fc3(x)
            outputs.append(output)
        return outputs


In [21]:
import torch

# 假设 state_list 包含两个张量，每个张量的形状为 (batch_size, state_dim)
state_list = [
    torch.tensor([[1, 2, 3], [4, 5, 6]]),  # 形状为 (2, 3)
    torch.tensor([[7, 8, 9], [10, 11, 12]])  # 形状为 (2, 3)
]

# 假设 act_list 包含两个张量，每个张量的形状为 (batch_size, action_dim)
act_list = [
    torch.tensor([[13, 14], [15, 16]]),  # 形状为 (2, 2)
    torch.tensor([[17, 18], [19, 20]])  # 形状为 (2, 2)
]
x = torch.cat(state_list + act_list, 1) #按列拼接
print(x)


tensor([[ 1,  2,  3,  7,  8,  9, 13, 14, 17, 18],
        [ 4,  5,  6, 10, 11, 12, 15, 16, 19, 20]])
