diff --git a/model/model.py b/model/model.py
index d5aed3f..5e4b1ec 100644
--- a/model/model.py
+++ b/model/model.py
@@ -155,8 +155,8 @@ def __init__(self, heads, head_size,hidden_size,RoPE=True):
         self.head_size = head_size
         self.RoPE = RoPE
         self.hidden_size = hidden_size
-        self.linear_1 = nn.Linear(hidden_size,hidden_size * 2,bias=True)
-        self.linear_2 = nn.Linear(hidden_size * 2,heads * 2,bias=True)
+        self.linear_1 = nn.Linear(hidden_size,head_size * 2,bias=True)
+        self.linear_2 = nn.Linear(head_size * 2,heads * 2,bias=True)
 
     def forward(self, inputs, mask=None):
         inputs = self.linear_1(inputs)
@@ -164,8 +164,8 @@ def forward(self, inputs, mask=None):
         # RoPE encoding
         if self.RoPE:
             pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
-            cos_pos = pos[...,1::2].repeat(1,1,self.hidden_size // (self.head_size // 2))
-            sin_pos = pos[...,::2].repeat(1,1,self.hidden_size // (self.head_size // 2))
+            cos_pos = pos[...,1::2].repeat(1,1,2)
+            sin_pos = pos[...,::2].repeat(1,1,2)
             qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
             qw2 = torch.reshape(qw2, qw.shape)
             qw = qw * cos_pos + qw2 * sin_pos
@@ -314,4 +314,4 @@ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False,
         return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
     except ImportError:
         pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
\ No newline at end of file
+    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
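
A quick shape check of why the revised repeat factor is a constant 2 (a minimal sketch, not code from the repository; the concrete sizes and the qw/kw split below are illustrative assumptions): once linear_1 projects hidden_size down to head_size * 2, the query and key halves are each head_size wide, and the sinusoidal embedding contributes head_size // 2 cos components and head_size // 2 sin components, so tiling each half by exactly 2 restores the head_size width. The old factor, hidden_size // (head_size // 2), tracked the previous hidden_size-wide projection and only coincides with 2 when hidden_size == head_size.

# Minimal shape sketch for the revised projection and RoPE tiling.
# The sizes and the qw/kw split here are illustrative assumptions,
# not values taken from model/model.py.
import torch
import torch.nn as nn

heads, head_size, hidden_size = 4, 64, 768
batch, seq_len = 2, 10

linear_1 = nn.Linear(hidden_size, head_size * 2, bias=True)
out = linear_1(torch.randn(batch, seq_len, hidden_size))   # (2, 10, 128)
qw, kw = out[..., :head_size], out[..., head_size:]        # each (2, 10, 64)

# pos[..., 1::2] and pos[..., ::2] are each head_size // 2 wide,
# so a fixed repeat factor of 2 matches the head_size-wide qw/kw.
pos = torch.randn(batch, seq_len, head_size)   # stand-in for the embedding
cos_pos = pos[..., 1::2].repeat(1, 1, 2)       # (2, 10, 64)
sin_pos = pos[..., ::2].repeat(1, 1, 2)        # (2, 10, 64)
assert cos_pos.shape == qw.shape

# The rotation then pairs each feature with its interleaved partner,
# exactly as in the patched forward().
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)  # (2, 10, 32, 2)
qw2 = torch.reshape(qw2, qw.shape)                    # (2, 10, 64)
qw = qw * cos_pos + qw2 * sin_pos                     # rotated query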