In [1]:
import torch
from src.model_vit import vit_base_patch16
import rasterio
from pyproj import Transformer
from datetime import date
import numpy as np
import os
from torchvision.datasets.utils import download_url

  from .autonotebook import tqdm as notebook_tqdm


### Create and load a pretrained Copernicus-FM model

In [None]:
# download weights
!wget https://huggingface.co/wangyi111/Copernicus-FM/resolve/main/CopernicusFM_ViT_base_varlang_e100.pth -P ./weights

--2025-03-13 12:05:01--  https://huggingface.co/wangyi111/Copernicus-FM/resolve/main/CopernicusFM_ViT_base_varlang_e100.pth
Resolving huggingface.co (huggingface.co)... 3.160.150.2, 3.160.150.119, 3.160.150.7, ...
Connecting to huggingface.co (huggingface.co)|3.160.150.2|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/5d/56/5d5698bc57b0453934b47e33f6ad19062a8419378967ef8a9a20b5400e0d4db0/539c5dd95cdf5b95fac1c4540929eaeb24b53a694a3421535ef3322a51644397?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27CopernicusFM_ViT_base_varlang_e100.pth%3B+filename%3D%22CopernicusFM_ViT_base_varlang_e100.pth%22%3B&Expires=1741867502&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MTg2NzUwMn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzVkLzU2LzVkNTY5OGJjNTdiMDQ1MzkzNGI0N2UzM2Y2YWQxOTA2MmE4NDE5Mzc4OTY3ZWY4YTlhMjBiNTQwMGUwZDRkYjAvNTM5YzVkZDk1Y2RmNWI5NWZhYzFjNDU0MDkyOWV

In [2]:
# create model
model = vit_base_patch16(num_classes=10, global_pool=False)

# load pre-trained weights
path = './weights/CopernicusFM_ViT_base_varlang_e100.pth'
check_point = torch.load(path)
if 'model' in check_point:
    state_dict = check_point['model']
else:
    state_dict = check_point
msg = model.load_state_dict(state_dict, strict=False)
print(msg)

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['mask_token'])


### Encode a image from any spectral or non-spectral sensor

Spectral input: any stack of spectral bands

In [None]:
# load an example Sentinel-2 image
img_path = 'assets/20201002T221611_20201002T221610_T60HWB.tif'
with rasterio.open(img_path) as src:
    img = src.read((2,3,4,8)) # take R,G,B,NIR bands for example
    img = img.astype(np.float32) / 10000.0 # normalize to [0,1], just for demonstration here, we recommend z-score normalization for S1/2 in practice
    # get meta: geocoordinates
    cx,cy = src.xy(src.height // 2, src.width // 2)
    if src.crs.to_string() != 'EPSG:4326':
        crs_transformer = Transformer.from_crs(src.crs, 'epsg:4326', always_xy=True)
        lon, lat = crs_transformer.transform(cx,cy)
    else:
        lon, lat = cx, cy
# get meta: time
img_fname = os.path.basename(img_path)
date_str = img_fname.split('_')[1][:8]
date_obj = date(int(date_str[:4]), int(date_str[4:6]), int(date_str[6:8]))
reference_date = date(1970, 1, 1)
delta = (date_obj - reference_date).days
# get meta: patch area
patch_area = (16*10/1000)**2 # patchsize 16 pix, gsd 10m

# metadata tensor
meta = np.array([lon, lat, delta, patch_area]).astype(np.float32)
meta = torch.from_numpy(meta)

img = torch.from_numpy(img).unsqueeze(0) # add batch dimension, [1, C, H, W]
meta = meta.unsqueeze(0) # add batch dimension, [1, 4]
key = 'any' # not used
wvs = [490, 560, 665, 842] # wavelength: B,G,R,NIR (Sentinel 2)
bws = [65, 35, 30, 115] # bandwidth: B,G,R,NIR (Sentinel 2)
language_embed = None # N/A
kernel_size = 16 # expected patch size
input_mode = 'spectral'

print('Encoding a spectral image with shape {}, and expected patch size {}.'.format(img.shape, kernel_size))
logit, embed = model(img, meta, wvs, bws, language_embed, input_mode, kernel_size)
print(logit.shape, embed.shape)

Encoding a spectral image with shape torch.Size([1, 4, 264, 264]), and expected patch size 16.
torch.Size([1, 10]) torch.Size([1, 768])


Non-spectral input: any image with a variable name

In [None]:
# example 1: pre-defined variable in Copernicus-FM (s5p_no2, s5p_co, s5p_o3, s5p_so2, dem)
var_name = 'Sentinel 5P Nitrogen Dioxide' # to index the predefined language embedding
img = torch.randn(1, 1, 56, 56)
meta = torch.full((1, 4), float('nan')) # [lon, lat, delta_time, patch_token_area], assume unknown
wvs = None #[0] # not used
bws = None #[0] # not used
kernel_size = 4 # expected patch size
input_mode = 'variable'

var_embed_fpath = './weights/var_embed_llama3.2_1B.pt'
if not os.path.exists(var_embed_fpath):
    url = 'https://huggingface.co/wangyi111/Copernicus-FM/resolve/main/varname_embed/varname_embed_llama3.2_1B.pt'
    download_url(url, './weights/', filename='var_embed_llama3.2_1B.pt')
language_embed = torch.load('./weights/var_embed_llama3.2_1B.pt') # 2048   
language_embed = language_embed[var_name]


print('Encoding a variable image with name "{}", shape {}, and expected patch size {}.'.format(var_name, img.shape, kernel_size))
logit, embed = model(img, meta, wvs, bws, language_embed, input_mode, kernel_size)
print(logit.shape, embed.shape)

Encoding a predefined variable image with name "Sentinel 5P Nitrogen Dioxide", shape torch.Size([1, 1, 56, 56]), and expected patch size 4.


  v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)


torch.Size([1, 10]) torch.Size([1, 768])


In [None]:
# example 2: new variable
varname = 'temperature' # this is a new variable name
img = torch.randn(1, 1, 112, 112)
meta = torch.full((1, 4), float('nan')) # meta unavailable
wvs = None #[0] # not used
bws = None #[0] # not used
kernel_size = 8 # expected patch size
input_mode = 'variable'

# get varname embedding from a pre-trained language model (e.g. Llama 3.2 1B)
language_embed = torch.randn(2048)


print('Encoding a new variable image with name "{}", shape {}, and expected patch size {}.'.format(varname, img.shape, kernel_size))
logit, embed = model(img, meta, wvs, bws, language_embed, input_mode, kernel_size)
print(logit.shape, embed.shape)

Encoding a new variable image with name "temperature", shape torch.Size([1, 1, 112, 112]), and expected patch size 8.


  v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)


torch.Size([1, 10]) torch.Size([1, 768])
