diff --git a/gliner/modeling/span_rep.py b/gliner/modeling/span_rep.py index 677447e..573b8e5 100644 --- a/gliner/modeling/span_rep.py +++ b/gliner/modeling/span_rep.py @@ -75,6 +75,7 @@ def __init__(self, hidden_size, max_width): """ super().__init__() + self.max_width = max_width self.mlp = nn.Linear(hidden_size, hidden_size * max_width) def forward(self, h, *args): @@ -95,7 +96,7 @@ def forward(self, h, *args): span_rep = self.mlp(h) - span_rep = span_rep.view(B, L, -1, D) + span_rep = span_rep.view(B, L, self.max_width, D) return span_rep.relu()