Description
Running the following code (tested with versions 1.6.0 and 1.5.3):
from aurora import AuroraHighRes
model = AuroraHighRes()
model.load_checkpoint()
This results in the following error:
AssertionError Traceback (most recent call last)
Cell In[1], line 4
1 from aurora import AuroraHighRes
3 model = AuroraHighRes()
----> 4 model.load_checkpoint()
File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:389, in Aurora.load_checkpoint(self, repo, name, strict)
387 name = name or self.default_checkpoint_name
388 path = hf_hub_download(repo_id=repo, filename=name)
--> 389 self.load_checkpoint_local(path, strict=strict)
File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:403, in Aurora.load_checkpoint_local(self, path, strict)
400 device = next(self.parameters()).device
401 d = torch.load(path, map_location=device, weights_only=True)
--> 403 d = self._adapt_checkpoint(d)
405 # Check if the history size is compatible and adjust weights if necessary.
406 current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:426, in Aurora._adapt_checkpoint(self, d)
417 def _adapt_checkpoint(self, d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
418 """Adapt an existing checkpoint to make it compatible with the current version of the model.
419
420 Args:
(...) 424 dict[str, torch.Tensor]: Adapted checkpoint.
425 """
--> 426 return _adapt_checkpoint_pretrained(self.patch_size, d)
File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/compat.py:51, in _adapt_checkpoint_pretrained(patch_size, d)
48 del d["decoder.surf_head.weight"]
49 del d["decoder.surf_head.bias"]
---> 51 assert weight.shape[0] == 4 * patch_size**2
52 assert bias.shape[0] == 4 * patch_size**2
53 weight = weight.reshape(patch_size**2, 4, -1)
AssertionError:
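The failing assertion in compat.py requires the first dimension of the old-format `decoder.surf_head.weight` to equal `4 * patch_size**2`, after which the weight is reshaped to `(patch_size**2, 4, -1)`. The factor 4 is presumably the number of surface variables. Below is a minimal sketch, in plain Python, of that shape check and its inverse, which could help compare a checkpoint's head shape against the model's `patch_size`. The helper names `expected_rows` and `implied_patch_size` and the `n_surf_vars` parameter are hypothetical and are not part of the aurora package.

```python
import math
from typing import Optional


def expected_rows(patch_size: int, n_surf_vars: int = 4) -> int:
    # Hypothetical helper mirroring the check in compat.py:
    # weight.shape[0] == 4 * patch_size**2
    return n_surf_vars * patch_size**2


def implied_patch_size(rows: int, n_surf_vars: int = 4) -> Optional[int]:
    # Hypothetical helper inverting rows == n_surf_vars * p**2.
    # Returns None if `rows` cannot be written in that form.
    if rows % n_surf_vars:
        return None
    p = math.isqrt(rows // n_surf_vars)
    return p if n_surf_vars * p * p == rows else None
```

For example, a surface head with 400 rows only passes the check with `patch_size=10`, and one with 64 rows only with `patch_size=4`. Feeding the real `weight.shape[0]` from the downloaded checkpoint into `implied_patch_size` would show which patch size that checkpoint was built for.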
However, load_checkpoint() works without issues for other models, e.g. AuroraPretrained.