
Commit

Update model.py
powerycy committed Jan 26, 2022
1 parent 376cea2 commit 827bcfa
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions model/model.py
@@ -155,17 +155,17 @@ def __init__(self, heads, head_size,hidden_size,RoPE=True):
        self.head_size = head_size
        self.RoPE = RoPE
        self.hidden_size = hidden_size
-       self.linear_1 = nn.Linear(hidden_size,hidden_size * 2,bias=True)
-       self.linear_2 = nn.Linear(hidden_size * 2,heads * 2,bias=True)
+       self.linear_1 = nn.Linear(hidden_size,head_size * 2,bias=True)
+       self.linear_2 = nn.Linear(head_size * 2,heads * 2,bias=True)

    def forward(self, inputs, mask=None):
        inputs = self.linear_1(inputs)
        qw, kw = inputs[..., ::2], inputs[..., 1::2]
        # RoPE positional encoding
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
-           cos_pos = pos[...,1::2].repeat(1,1,self.hidden_size // (self.head_size // 2))
-           sin_pos = pos[...,::2].repeat(1,1,self.hidden_size // (self.head_size // 2))
+           cos_pos = pos[...,1::2].repeat(1,1,2)
+           sin_pos = pos[...,::2].repeat(1,1,2)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
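A quick shape check (not part of the commit; the SinusoidalPositionEmbedding output is stood in by a random tensor of the same shape) showing why the repeat factor collapses to a constant 2 once linear_1 projects to head_size * 2: qw and kw are each head_size wide, while pos[..., 1::2] is head_size // 2 wide, so a single tiling by 2 along the last axis lines them up.

import torch

batch, seq_len, head_size = 2, 10, 64

# After the commit, linear_1 maps hidden_size -> head_size * 2,
# so the de-interleaved qw and kw are each head_size wide.
inputs = torch.randn(batch, seq_len, head_size * 2)
qw, kw = inputs[..., ::2], inputs[..., 1::2]   # each (2, 10, 64)

# Stand-in for SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs):
# any (batch, seq_len, head_size) tensor of interleaved sin/cos pairs.
pos = torch.randn(batch, seq_len, head_size)

cos_pos = pos[..., 1::2].repeat(1, 1, 2)       # head_size // 2 -> head_size
sin_pos = pos[..., ::2].repeat(1, 1, 2)
assert cos_pos.shape == sin_pos.shape == qw.shape == (batch, seq_len, head_size)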
@@ -314,4 +314,4 @@ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False,
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    except ImportError:
        pass
-   return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+   return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
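For context (an assumed usage sketch, not part of the commit): whichever branch the LayerNorm factory above takes, it returns a drop-in nn.Module, so call sites look the same with or without apex installed.

import torch

# Hypothetical call site; LayerNorm is the factory from model/model.py.
# Without apex, the ImportError branch falls through and the stock
# torch.nn.LayerNorm is returned instead of FusedLayerNorm.
norm = LayerNorm(768)
out = norm(torch.randn(2, 10, 768))
print(out.shape)   # torch.Size([2, 10, 768])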
