
AssertionError for AuroraHighRes #106

Closed
@WeatherPredictionEnthusiast

Description

Running the following code (tested with versions 1.6.0 and 1.5.3):

from aurora import AuroraHighRes

model = AuroraHighRes()
model.load_checkpoint()

This raises the following error:

AssertionError                            Traceback (most recent call last)
Cell In[1], line 4
      1 from aurora import AuroraHighRes
      3 model = AuroraHighRes()
----> 4 model.load_checkpoint()

File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:389, in Aurora.load_checkpoint(self, repo, name, strict)
    387 name = name or self.default_checkpoint_name
    388 path = hf_hub_download(repo_id=repo, filename=name)
--> 389 self.load_checkpoint_local(path, strict=strict)

File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:403, in Aurora.load_checkpoint_local(self, path, strict)
    400 device = next(self.parameters()).device
    401 d = torch.load(path, map_location=device, weights_only=True)
--> 403 d = self._adapt_checkpoint(d)
    405 # Check if the history size is compatible and adjust weights if necessary.
    406 current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]

File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/aurora.py:426, in Aurora._adapt_checkpoint(self, d)
    417 def _adapt_checkpoint(self, d: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    418     """Adapt an existing checkpoint to make it compatible with the current version of the model.
    419 
    420     Args:
   (...)    424         dict[str, torch.Tensor]: Adapted checkpoint.
    425     """
--> 426     return _adapt_checkpoint_pretrained(self.patch_size, d)

File /***/conda/envs/***/lib/python3.12/site-packages/aurora/model/compat.py:51, in _adapt_checkpoint_pretrained(patch_size, d)
     48 del d["decoder.surf_head.weight"]
     49 del d["decoder.surf_head.bias"]
---> 51 assert weight.shape[0] == 4 * patch_size**2
     52 assert bias.shape[0] == 4 * patch_size**2
     53 weight = weight.reshape(patch_size**2, 4, -1)

AssertionError: 

Other models load their checkpoints without issues, e.g. AuroraPretrained.
