In [1]:
import torch
import torch.nn as nn
import numpy as np

### 搭建网络模块

In [2]:
class scaled_dot_product_attention(nn.Module):
    """Dot-product attention over batched query/key/value tensors."""

    def __init__(self, att_dropout=0.0):
        """
        Args:
            att_dropout: dropout probability applied to the attention weights.
        """
        super(scaled_dot_product_attention, self).__init__()
        self.dropout = nn.Dropout(att_dropout)
        # Normalise over the key axis of the [batch, q_length, k_length] score tensor.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None):
        """Compute attention weights from q and k and apply them to v.

        Args:
            q: [batch_size, q_length, dim]
            k: [batch_size, k_length, dim]
            v: [batch_size, k_length, v_dim]
            scale: optional factor multiplied into the raw scores before the
                softmax; ``None`` leaves the scores unscaled.

        Returns:
            (attention, alpha) where attention is [batch_size, q_length, v_dim]
            and alpha is [batch_size, q_length, k_length].
        """
        # Raw scores: dot product of every query with every key.
        alpha = torch.einsum('ijk,ilk->ijl', q, k)
        # Compare against None so that scale=0.0 or a tensor-valued scale is
        # still applied instead of being decided by truthiness.
        if scale is not None:
            alpha = alpha * scale
        alpha = self.softmax(alpha)
        alpha = self.dropout(alpha)
        # Weighted sum of the value vectors for each query position.
        attention = torch.einsum('ijl,ilk->ijk', alpha, v)
        return attention, alpha

In [3]:
class TaskNet(nn.Module):
    """Two-layer ReLU MLP whose output is used as a task's query vector."""

    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim: width of the input feature vector.
            hidden_dim: width of both hidden layers. The query is taken
                directly from the last hidden layer, so query_dim equals
                hidden_dim.
        """
        super().__init__()
        self.query_dim = hidden_dim
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.fc1 = nn.Sequential(*layers)

    def forward(self, X):
        """Return the query vector computed from X."""
        return self.fc1(X)
        

In [4]:
class Expert(nn.Module):
    """Two-layer ReLU MLP whose output is used as an expert's key vector."""

    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim: width of the input feature vector.
            hidden_dim: width of both hidden layers. The key is taken directly
                from the last hidden layer, so key_dim equals hidden_dim.
        """
        super().__init__()
        self.key_dim = hidden_dim
        blocks = []
        # Two Linear+ReLU stages: input_dim -> hidden_dim -> hidden_dim.
        for in_features in (input_dim, hidden_dim):
            blocks.append(nn.Linear(in_features, hidden_dim))
            blocks.append(nn.ReLU())
        self.fc1 = nn.Sequential(*blocks)

    def forward(self, X):
        """Return the key vector computed from X."""
        return self.fc1(X)
        

In [5]:
class Tower(nn.Module):
    """Per-task output head: an MLP narrowing hidden_dim -> 128 -> 64 -> 16 -> output_dim."""

    def __init__(self, hidden_dim, output_dim):
        """
        Args:
            hidden_dim: width of the incoming feature vector.
            output_dim: width of the tower's output.
        """
        super().__init__()
        widths = (hidden_dim, 128, 64, 16)
        layers = []
        # Every stage except the last is followed by a ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Final projection has no activation.
        layers.append(nn.Linear(widths[-1], output_dim))
        self.fc1 = nn.Sequential(*layers)

    def forward(self, X):
        """Return the tower output for X."""
        return self.fc1(X)

In [6]:
# Dummy all-ones feature vector used to exercise the modules defined above.
input_dim = 18
X = torch.ones((input_dim,))

### 拼接网络模块，构建完整网络

In [7]:
class AOE(nn.Module):
    """Multi-task network that mixes shared expert outputs with per-task attention.

    Every task owns a TaskNet (producing its query) and a Tower (producing its
    output). All tasks share a pool of Expert networks whose outputs serve as
    both keys and values.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, task_num, expert_num):
        """
        Args:
            input_dim: width of the input feature vector.
            hidden_dim: width of the query/key/value vectors.
            output_dim: width of each tower's output.
            task_num: number of tasks (one TaskNet and one Tower per task).
            expert_num: number of shared Expert networks.
        """
        super(AOE, self).__init__()
        # model
        self.tasknet_list = nn.ModuleList([TaskNet(input_dim, hidden_dim) for _ in range(task_num)])
        self.expert_list = nn.ModuleList([Expert(input_dim, hidden_dim) for _ in range(expert_num)])
        self.tower_list = nn.ModuleList([Tower(hidden_dim, output_dim) for _ in range(task_num)])
        # Intermediate results of the most recent forward pass, kept for
        # inspection; each is overwritten (never appended to) on every call.
        self.query = []
        self.key = []
        self.value = []
        self.alpha = []
        self.attention = []
        # Normalise each task's scores over the expert axis (the last one).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X):
        """Run all task heads on X.

        Args:
            X: tensor whose last dimension is input_dim; any leading batch
                dimensions are carried through unchanged.

        Returns:
            A list with one tensor per task, each of shape [..., output_dim].
        """
        # [..., task_num, hidden_dim]
        query = torch.stack([task_net(X) for task_net in self.tasknet_list], dim=-2)
        # [..., expert_num, hidden_dim]
        key = torch.stack([expert(X) for expert in self.expert_list], dim=-2)
        # The expert outputs double as values; nothing is modified in place,
        # so no copy is required.
        value = key

        # One matmul scores every task against every expert. Keeping the
        # scores as a regular tensor op leaves the attention weights inside
        # the autograd graph, so they train the task nets and experts too.
        alpha = self.softmax(query @ key.transpose(-1, -2))  # [..., task_num, expert_num]
        attention = alpha @ value  # [..., task_num, hidden_dim]

        # Combine each task's attention result with its own TaskNet output
        # (which is the query itself) before feeding the tower.
        tower_input = query + attention
        tower_output = [
            tower(tower_input[..., task_idx, :])
            for task_idx, tower in enumerate(self.tower_list)
        ]

        self.query, self.key, self.value = query, key, value
        self.alpha, self.attention = alpha, attention
        return tower_output

### Instantiate the model and run a forward pass (the training part is not defined yet)

In [8]:
# Smoke test: build a 4-task / 5-expert model and push one all-ones sample through it.
input_dim = 18
X = torch.ones((input_dim,))
model = AOE(input_dim=18, hidden_dim=64, output_dim=1, task_num=4, expert_num=5)
o = model(X)

In [9]:
# Display the list returned by the model's forward pass (one tensor per tower).
o

[tensor([0.0589], grad_fn=<AddBackward0>),
 tensor([-0.1372], grad_fn=<AddBackward0>),
 tensor([-0.2355], grad_fn=<AddBackward0>),
 tensor([0.1954], grad_fn=<AddBackward0>)]

### 构建 q、k、v 矩阵；这三个矩阵不参与梯度更新，因此应在 `with torch.no_grad():` 下计算

In [10]:
# Build the query matrix:
#   1. fix the number of tasks
#   2. instantiate task_num TaskNets and keep them in a ModuleList
#   3. stack the query vector produced by each network into one matrix
task_num = 4
tasknet_list = nn.ModuleList([TaskNet(18, 64) for _ in range(task_num)])
query = torch.stack([task_net(X) for task_net in tasknet_list])  # shape: torch.Size([4, 64])

In [11]:
# Build the key matrix:
#   1. fix the number of expert networks
#   2. instantiate expert_num Experts
#   3. stack the key vector produced by each expert into one matrix
expert_num = 5
expert_list = nn.ModuleList([Expert(18, 64) for _ in range(expert_num)])
key = torch.stack([expert(X) for expert in expert_list])  # shape: torch.Size([5, 64])

In [19]:
# Build the alpha (attention-weight) matrix:
#   1. each task gets one weight vector over the experts, e.g. (0.2, 0.3, 0.1, ..., 0.1)
#   2. the per-task weight vectors together form one weight matrix
# Assumes query is [task_num, dim] and key is [expert_num, dim], as built in the cells above.
softmax = nn.Softmax(dim=-1)  # normalise over the last axis, i.e. across experts
# A single matmul scores every task against every expert; no_grad keeps alpha
# out of the autograd graph, since it is only inspected here.
with torch.no_grad():
    alpha = softmax(query @ key.T)  # [task_num, expert_num]
alpha


tensor([[0.2252, 0.1835, 0.2419, 0.1672, 0.1822],
        [0.2161, 0.2186, 0.2349, 0.1696, 0.1607],
        [0.2128, 0.2158, 0.1943, 0.2018, 0.1754],
        [0.2130, 0.1962, 0.1929, 0.2198, 0.1780]])

In [13]:
# Build the value matrix as an independent copy of the key matrix.
value = torch.clone(key)

In [14]:
# Compute the attention values:
#   1. alpha[0] holds the weights that task A assigns to each expert network;
#      for task B look at alpha[1], and so on.
#   2. value holds one row per expert network (the expert's output vector).
#   3. The attention value for a task is the alpha-weighted sum of the rows of
#      value, i.e. alpha[task] @ value.
# One entry per task; no reshape is needed because a 1-D weight vector can be
# multiplied with the 2-D value matrix directly, whatever the expert count.
attention = [task_weights @ value for task_weights in alpha]

In [15]:
# Quick check that `+` on two equally shaped tensors adds element-wise.
m1 = torch.tensor([[1, 2], [3, 4]])
m2 = torch.tensor([[1, 2], [3, 4]])
m1 + m2

tensor([[2, 4],
        [6, 8]])