In [68]:
import numpy as np
import torch

import os
from einops import rearrange

from itertools import combinations


from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from matplotlib.colors import ListedColormap

from logging import Logger
import yaml

from src.dataset.pairedmodalitycrops import PairedModalityLC

from src.model.prithvi_encoder import Prithvi_Encoder
from src.model.dofa_encoder import DOFA_Encoder

# YAML files that hold the constructor arguments for each encoder; they are
# parsed with yaml.safe_load in the embedding cells further down.
prithvi_config = './configs/prithvi.yaml'
dofa_config = './configs/dofa.yaml'

# Root directory under which per-encoder embeddings and the label array are saved.
embed_dir = './embeddings/pairedmodalitycrops/'

---

Load dataset and calculate statistics for normalization

---

In [69]:
# Constructor arguments for the paired-modality dataset, gathered in one
# mapping so they can be inspected or tweaked before the dataset is built.
# Every normalization statistic is passed as None here.
dataset_options = dict(
    split='all',
    dataset_name='PairedModalityCrops',
    multi_modal=True,
    multi_temporal=False,
    root_path='./data/pairedmodalitycrops/',
    classes=None,
    img_size=224,
    ignore_index=-1,
    num_classes=254,
    bands=None,
    distribution=None,
    data_max=None,
    data_mean=None,
    data_min=None,
    data_std=None,
    download_url=None,
    auto_download=False,
)
dataset = PairedModalityLC(**dataset_options)

# Modality names exposed by the dataset; they are used as keys into each
# sample's 'image' mapping when the data is collected.
modalities = dataset.modalities

# One list per modality plus one for the label rasters. All lists are filled
# in dataset order, so index i refers to the same sample in each of them.
dset = {name: [] for name in ('hls', 'l8', 'l9', 's2', 's1', 'target')}
for sample in dataset:
    images = sample['image']
    for name in modalities:
        dset[name].append(images[name])
    dset['target'].append(sample['target'])


# Normalization statistics per modality. Each modality's samples are stacked
# along a new leading dimension and reduced over dims 0, 2 and 3, which leaves
# one value per index of dim 1 (presumably the band axis -- TODO confirm).
means, stds = {}, {}
for name in modalities:
    stacked = torch.stack(dset[name])
    means[name] = stacked.mean(dim=(0, 2, 3))
    stds[name] = stacked.std(dim=(0, 2, 3))

---

1a. Generate Prithvi Embeddings

---

In [72]:
import yaml

# --- Load the Prithvi encoder configuration ---------------------------------
# A YAML parse error is allowed to propagate: printing it and carrying on
# would leave `config` undefined (or stale from an earlier cell run) and fail
# later with a far less helpful error.
with open(prithvi_config) as stream:
    config = yaml.safe_load(stream)

config.pop('_target_')
config['num_frames'] = 1

print(config)

# --- Build the encoder and load its pretrained weights -----------------------
encoder = Prithvi_Encoder(
    encoder_weights=config['encoder_weights'],
    input_size=config['input_size'],
    input_bands=config['input_bands'],
    embed_dim=config['embed_dim'],
    output_layers=config['output_layers'],
    output_dim=config['output_dim'],
    patch_size=config['patch_size'],
    tubelet_size=config['tubelet_size'],
    num_heads=config['num_heads'],
    mlp_ratio=config['mlp_ratio'],
    depth=config['depth'],
    in_chans=config['in_chans'],
    download_url=config['download_url']
)

encoder.initialize_weights()
encoder.load_encoder_weights(logger=Logger('encoder'))

# Freeze the weights and switch to inference mode. Disabling gradients alone
# does not turn off train-time behaviour such as dropout, so eval() is needed
# for reproducible embeddings (assumes the encoder is a torch.nn.Module, which
# its use of .parameters() suggests).
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

# --- Run every sample of every optical modality through the encoder ----------
optical_modalities = ('hls', 's2', 'l9', 'l8')

# Channel subset fed to the encoder for modalities whose stack has more
# channels than the encoder takes.
# NOTE(review): verify that channels 1..6 of the s2 stack really are the bands
# listed in config['input_bands'] (B2, B3, B4, B8A, B11, B12).
channel_subsets = {'s2': slice(1, 7)}

embeds = {name: [] for name in optical_modalities}
cls_tokens = {name: [] for name in optical_modalities}
targets = []

patch = config['patch_size']

with torch.no_grad():
    for i in range(len(dset['hls'])):
        for name in optical_modalities:
            # Add a batch dim at position 0 and a singleton dim at position 2
            # (presumably the temporal axis -- TODO confirm).
            image = dset[name][i].unsqueeze(0).unsqueeze(2)

            # Standardize per channel (dim 1) with the dataset-wide statistics.
            mean = means[name][None, :, None, None, None]
            std = stds[name][None, :, None, None, None]
            image = (image - mean) / std

            if name in channel_subsets:
                image = image[:, channel_subsets[name], :, :, :]

            cls_out, embed_out = encoder.forward(image={'optical': image}, return_cls=True)

            # Keep only the last of the returned output layers.
            embeds[name].append(embed_out[-1].detach().numpy())
            cls_tokens[name].append(cls_out[-1].detach().numpy())

        # Split the label raster into patches aligned with the encoder's token
        # grid: h/w index the patch (like the h/w axes of the embeddings) and
        # p/q index the pixel inside the patch. In einops the first name inside
        # a parenthesised group is the slow-varying one, so the patch index
        # must come first: '(h p)', not '(p h)'. The previous '(p h) (q w)'
        # pattern produced the same output shape but cut the raster into a
        # patch_size x patch_size grid of blocks that do not line up with the
        # embedding tokens.
        targets.append(
            rearrange(dset['target'][i], '(h p) (w q) -> p q h w', p=patch, q=patch)
        )

# --- Collate per-sample outputs into single arrays ---------------------------
for name in optical_modalities:
    embeds[name] = np.concatenate(embeds[name], axis=0)
    cls_tokens[name] = np.concatenate(cls_tokens[name], axis=0)
targets = np.stack(targets, axis=0)

# Flattened views with one row per spatial token and the embedding dim last.
hls_embeds = rearrange(embeds['hls'], 'n d h w -> (n h w) d')
s2_embeds = rearrange(embeds['s2'], 'n d h w -> (n h w) d')
l9_embeds = rearrange(embeds['l9'], 'n d h w -> (n h w) d')
l8_embeds = rearrange(embeds['l8'], 'n d h w -> (n h w) d')

# --- Persist embeddings, CLS tokens and labels -------------------------------
prithvi_save_dir = os.path.join(embed_dir, 'prithvi')
os.makedirs(prithvi_save_dir, exist_ok=True)

for key in embeds.keys():
    embed_fname = f'{key}_embeds.npy'
    cls_fname = f'{key}_cls_tokens.npy'
    np.save(os.path.join(prithvi_save_dir, embed_fname), embeds[key])
    np.save(os.path.join(prithvi_save_dir, cls_fname), cls_tokens[key])

# Labels are encoder-independent, so they live next to the per-encoder folders.
np.save(os.path.join(embed_dir, 'labels.npy'), targets)

{'encoder_weights': './pretrained_models/Prithvi_100M.pt', 'download_url': 'https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M/resolve/main/Prithvi_100M.pt?download=true', 'embed_dim': 768, 'input_size': 224, 'in_chans': 6, 'patch_size': 16, 'num_heads': 12, 'depth': 12, 'mlp_ratio': 4, 'tubelet_size': 1, 'num_frames': 1, 'input_bands': {'optical': ['B2', 'B3', 'B4', 'B8A', 'B11', 'B12']}, 'output_layers': [3, 5, 7, 11], 'output_dim': 768}


Incompatible parameters:
pos_embed: expected torch.Size([1, 197, 768]) but found torch.Size([1, 589, 768])


---

1b. Generate DOFA Embeddings

---

In [73]:
import yaml

# --- Load the DOFA encoder configuration -------------------------------------
# A YAML parse error is allowed to propagate: printing it and carrying on
# would silently reuse whatever `config` an earlier cell left in the kernel.
with open(dofa_config) as stream:
    config = yaml.safe_load(stream)

config.pop('_target_')
config['num_frames'] = 1

print(config)

# The YAML only carries a '${dataset.bands}' placeholder for input_bands, so
# the six optical bands are set explicitly. This must be an ordered list, not
# a set: a set of strings has no defined iteration order (and the order can
# change between interpreter runs), so any per-band lookup such as the
# wavelength list could be matched to the wrong channels.
config['input_bands'] = {
    'optical': ['B2', 'B3', 'B4', 'B8A', 'B11', 'B12'],
}

# --- Build the encoder and load its pretrained weights -----------------------
encoder = DOFA_Encoder(
    encoder_weights=config['encoder_weights'],
    input_size=config['input_size'],
    input_bands=config['input_bands'],
    embed_dim=config['embed_dim'],
    output_layers=config['output_layers'],
    output_dim=config['output_dim'],
    patch_size=config['patch_size'],
    num_heads=config['num_heads'],
    mlp_ratio=config['mlp_ratio'],
    depth=config['depth'],
    wave_list=config['wave_list'],
    download_url=config['download_url']
)

encoder.load_encoder_weights(logger=Logger('encoder'))

# Freeze the weights and switch to inference mode. Disabling gradients alone
# does not turn off train-time behaviour such as dropout, so eval() is needed
# for reproducible embeddings (assumes the encoder is a torch.nn.Module, which
# its use of .parameters() suggests).
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

# --- Run every sample of every optical modality through the encoder ----------
optical_modalities = ('hls', 's2', 'l9', 'l8')

# Channel subset fed to the encoder for modalities whose stack has more
# channels than the six configured bands.
# NOTE(review): verify that channels 1..6 of the s2 stack really are
# B2, B3, B4, B8A, B11, B12, in that order.
channel_subsets = {'s2': slice(1, 7)}

embeds = {name: [] for name in optical_modalities}
cls_tokens = {name: [] for name in optical_modalities}

with torch.no_grad():
    for i in range(len(dset['hls'])):
        for name in optical_modalities:
            # Add a batch dim at position 0 and a singleton dim at position 2
            # (presumably the temporal axis -- TODO confirm).
            image = dset[name][i].unsqueeze(0).unsqueeze(2)

            # Standardize per channel (dim 1) with the dataset-wide statistics.
            mean = means[name][None, :, None, None, None]
            std = stds[name][None, :, None, None, None]
            image = (image - mean) / std

            if name in channel_subsets:
                image = image[:, channel_subsets[name], :, :, :]

            cls_out, embed_out = encoder.forward(image={'optical': image}, return_cls=True)

            # Keep only the last of the returned output layers.
            embeds[name].append(embed_out[-1].detach().numpy())
            cls_tokens[name].append(cls_out[-1].detach().numpy())

# --- Collate per-sample outputs into single arrays ---------------------------
for name in optical_modalities:
    embeds[name] = np.concatenate(embeds[name], axis=0)
    cls_tokens[name] = np.concatenate(cls_tokens[name], axis=0)

# Flattened views with one row per spatial token and the embedding dim last.
hls_embeds = rearrange(embeds['hls'], 'n d h w -> (n h w) d')
s2_embeds = rearrange(embeds['s2'], 'n d h w -> (n h w) d')
l9_embeds = rearrange(embeds['l9'], 'n d h w -> (n h w) d')
l8_embeds = rearrange(embeds['l8'], 'n d h w -> (n h w) d')

# --- Persist embeddings and CLS tokens ---------------------------------------
dofa_save_dir = os.path.join(embed_dir, 'dofa')
os.makedirs(dofa_save_dir, exist_ok=True)

for key in embeds.keys():
    embed_fname = f'{key}_embeds.npy'
    cls_fname = f'{key}_cls_tokens.npy'
    np.save(os.path.join(dofa_save_dir, embed_fname), embeds[key])
    np.save(os.path.join(dofa_save_dir, cls_fname), cls_tokens[key])

{'encoder_weights': './pretrained_models/DOFA_ViT_base_e100.pth', 'download_url': 'https://huggingface.co/XShadow/DOFA/resolve/main/DOFA_ViT_base_e100.pth', 'embed_dim': 768, 'input_size': 224, 'patch_size': 16, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4, 'use_norm': False, 'input_bands': '${dataset.bands}', 'wave_list': {'optical': {'B1': 0.44, 'B2': 0.49, 'B3': 0.56, 'B4': 0.665, 'B5': 0.705, 'B6': 0.74, 'B7': 0.783, 'B8': 0.832, 'B8A': 0.864, 'B9': 0.945, 'B10': 1.373, 'B11': 1.61, 'B12': 2.2}, 'sar': {'VV': 3.75, 'VH': 3.75, 'ASC_VV': 3.75, 'ASC_VH': 3.75, 'DSC_VV': 3.75, 'DSC_VH': 3.75, 'VV-VH': 3.75}}, 'output_layers': [3, 5, 7, 11], 'output_dim': 768, 'num_frames': 1}
