In [1]:
from prithvi.prithvi_mae import PrithviMAE
from prithvi.prithvi_mae import PrithviViT
from transformers import BertModel, BertTokenizer
import os
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import yaml
from rasterio.enums import Resampling
import torch.nn as nn
import torch
from torch import optim
import lightning as pl
from rasterio.plot import reshape_as_image
import pandas as pd
import ast
from datasets import Dataset
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset

## Settings

In [2]:
# Pretrained Prithvi weights and the YAML config that describes the model / training data.
weights_path = "prithvi/Prithvi_EO_V1_100M.pt"
model_cfg_path = "./prithvi/config.yaml"

with open(model_cfg_path) as cfg_file:
    model_config = yaml.safe_load(cfg_file)

model_args = model_config["model_args"]
train_args = model_config["train_params"]
# Overrides forwarded to PrithviViT(**model_args) below: single-frame input, encoder only.
model_args.update(num_frames=1, encoder_only=True)

## Data preprocessing and loading

In [3]:
# Per-band normalisation statistics from the Prithvi training config, reshaped to
# (bands, 1, 1) so they broadcast against (bands, rows, cols) raster arrays.
means, stds = (
    np.array(train_args[stat_key]).reshape(-1, 1, 1)
    for stat_key in ("data_mean", "data_std")
)

NO_DATA = -9999  # sentinel pixel value the raster loaders look for
NO_DATA_FLOAT = 0.0001  # value the raster loaders write over NO_DATA pixels
PERCENTILES = (0.1, 99.9)  # NOTE(review): not referenced in the cells shown here

In [4]:
def preprocess_image(image, mean=None, std=None):
    """Normalise a raster per band and add a singleton frame axis.

    Args:
        image: array whose first axis is the band axis, e.g. the
            (bands, rows, cols) array returned by the raster loaders.
        mean: per-band means broadcastable against ``image``; defaults to the
            module-level ``means`` from the Prithvi config.
        std: per-band standard deviations broadcastable against ``image``;
            defaults to the module-level ``stds``.

    Returns:
        float32 tensor of shape (bands, 1, rows, cols).
    """
    # Resolve defaults lazily so explicit statistics do not need the globals.
    mean = means if mean is None else mean
    std = stds if std is None else std
    # axis=1 inserts a length-1 frame axis; presumably the time dimension the
    # encoder expects with num_frames=1 (set in the settings cell).
    normalized = np.expand_dims((image - mean) / std, axis=1)
    return torch.from_numpy(normalized).to(torch.float32)

In [5]:
def load_raster(path, crop=(112, 112)):
    """Read all bands of the raster at ``path``, replacing NO_DATA pixels.

    Pixels equal to ``NO_DATA`` become ``NO_DATA_FLOAT``. When ``crop`` is truthy,
    only the last ``crop[0]`` rows and last ``crop[1]`` columns are kept (the
    bottom-right corner of the tile).
    """
    with rasterio.open(path) as dataset:
        bands = dataset.read()
        bands = np.where(bands == NO_DATA, NO_DATA_FLOAT, bands)
        if not crop:
            return bands
        n_rows, n_cols = crop[0], crop[1]
        return bands[:, -n_rows:, -n_cols:]

def binary_data_gen_112(df, data_dir = '6d_data'):
    """Yield (image, input_ids, attention_mask, label) samples from ``df``.

    Images are read with ``load_raster`` (cropped, default 112x112) and normalised
    with ``preprocess_image``. Questions are tokenised to 128 tokens with the
    global ``tokenizer``, which is defined in a later cell. That works because
    the generator body only runs when iterated. The label is a 1-element float32
    tensor built from ``binary_answer`` (assumed to be a scalar 0/1; TODO confirm).
    """
    for _, record in df.iterrows():
        raster_path = os.path.join(data_dir, record.tile_name, record.patch_name)
        image_tensor = preprocess_image(load_raster(raster_path))

        encoded = tokenizer(
            record.question,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        label = torch.tensor(
            np.expand_dims(np.array(record.binary_answer), axis=0),
            dtype=torch.float32,
        )

        # return_tensors="pt" adds a batch axis of 1; drop it so DataLoader can batch.
        yield (
            image_tensor,
            encoded['input_ids'].squeeze(0),
            encoded['attention_mask'].squeeze(0),
            label,
        )

In [6]:
def load_raster_224(path):
    """Read every band of the raster at ``path`` resampled to 224x224 (bilinear).

    Pixels equal to ``NO_DATA`` after resampling are replaced by ``NO_DATA_FLOAT``.
    NOTE(review): bilinear resampling may blend NO_DATA pixels with their
    neighbours before the equality check; confirm how rasterio/GDAL treat nodata here.
    """
    with rasterio.open(path) as dataset:
        target_shape = (dataset.count, 224, 224)
        resampled = dataset.read(out_shape=target_shape, resampling=Resampling.bilinear)
    return np.where(resampled == NO_DATA, NO_DATA_FLOAT, resampled)

def binary_data_gen(df, data_dir = '6d_data'):
    """Yield (image, input_ids, attention_mask, label) samples from ``df``.

    Images are read with ``load_raster_224`` (bilinearly resampled to 224x224) and
    normalised with ``preprocess_image``. Questions are tokenised to 128 tokens
    with the global ``tokenizer``, which is defined in a later cell. That works
    because the generator body only runs when iterated. The label is a 1-element
    float32 tensor built from ``binary_answer`` (assumed to be a scalar 0/1;
    TODO confirm).
    """
    for _, record in df.iterrows():
        raster_path = os.path.join(data_dir, record.tile_name, record.patch_name)
        image_tensor = preprocess_image(load_raster_224(raster_path))

        encoded = tokenizer(
            record.question,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        label = torch.tensor(
            np.expand_dims(np.array(record.binary_answer), axis=0),
            dtype=torch.float32,
        )

        # return_tensors="pt" adds a batch axis of 1; drop it so DataLoader can batch.
        yield (
            image_tensor,
            encoded['input_ids'].squeeze(0),
            encoded['attention_mask'].squeeze(0),
            label,
        )

In [7]:
class IterDataset(IterableDataset):
    """Iterable dataset that streams samples from ``generator(df)``.

    ``generator`` is called afresh on every iteration, so the dataset can be
    iterated once per epoch. It is expected to yield exactly one sample per row
    of ``df``, as the ``binary_data_gen*`` generators in this notebook do.
    """

    def __init__(self, generator, df):
        self.generator = generator
        self.df = df

    def __len__(self):
        # One sample per row; lets DataLoader/Lightning report batch totals.
        return len(self.df)

    def __iter__(self):
        from torch.utils.data import get_worker_info

        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: stream the whole frame.
            return self.generator(self.df)
        # Multi-process loading gives every worker its own copy of this dataset;
        # hand each worker a disjoint stride of rows so samples are not duplicated.
        shard = self.df.iloc[worker_info.id::worker_info.num_workers]
        return self.generator(shard)

In [8]:
SHUFFLE_SEED = 42  # fixes which rows end up in the capped splits taken below

df = pd.read_csv('binary_questions_and_answers.csv', index_col=0, converters={"generic_question": ast.literal_eval})
# Missing values become the string 'None'. The shuffle is seeded so re-running
# the notebook selects the same rows when each split is truncated.
df = df.replace({np.nan: 'None'}).sample(frac=1, random_state=SHUFFLE_SEED)

In [16]:
# Keep only the first 100 rows of each split (from the shuffled frame) for quick runs.
df_train = df.query("split == 'train'").iloc[:100]
df_val = df.query("split == 'validation'").iloc[:100]
df_test = df.query("split == 'test'").iloc[:100]

In [17]:
# One streaming dataset per split, all backed by the 224x224 sample generator.
train_ds, val_ds, test_ds = (
    IterDataset(binary_data_gen, split_df) for split_df in (df_train, df_val, df_test)
)

In [25]:
BATCH_SIZE = 4

# num_workers stays at the default (0): an IterableDataset is copied into every
# worker process, so raising it is only safe if the dataset shards its rows.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE) for split_ds in (train_ds, val_ds, test_ds)
)

## Model definition

#### Vision encoder

In [26]:
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(weights_path, map_location="cpu")

# The checkpoint is a full MAE state dict whose keys carry an "encoder." or
# "decoder." prefix (e.g. "encoder.pos_embed"). PrithviViT is the encoder alone,
# so its own keys have no prefix. Loading the raw dict with strict=False matches
# nothing and silently leaves the frozen encoder randomly initialised.
ENCODER_PREFIX = "encoder."
encoder_state = {
    key[len(ENCODER_PREFIX):]: tensor
    for key, tensor in checkpoint.items()
    if key.startswith(ENCODER_PREFIX)
}
# Drop the stored position embedding, as the upstream Prithvi demo does. It was
# presumably saved for the pretraining frame count and image size, while
# model_args sets num_frames=1. Size mismatches raise even with strict=False,
# so the encoder keeps the embedding it built itself.
encoder_state.pop("pos_embed", None)

vision_encoder = PrithviViT(**model_args)
load_result = vision_encoder.load_state_dict(encoder_state, strict=False)
# Fail loudly rather than train a head on top of random image features.
unloaded_keys = set(load_result.missing_keys) - {"pos_embed"}
assert not unloaded_keys, f"encoder weights missing from checkpoint: {sorted(unloaded_keys)}"
assert not load_result.unexpected_keys, f"unused checkpoint keys: {load_result.unexpected_keys}"

#### Text encoder

In [27]:
# The model and the tokenizer must come from the same pretrained checkpoint.
BERT_CHECKPOINT = 'bert-base-uncased'
text_encoder = BertModel.from_pretrained(BERT_CHECKPOINT)
# Used lazily by the binary_data_gen* generators defined in earlier cells.
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)

#### Multimodal model

In [34]:
class rsvqa_pl(pl.LightningModule):
    """Binary remote-sensing VQA model.

    Combines a frozen image encoder (Prithvi ViT) and a frozen text encoder (BERT).
    Their CLS embeddings are concatenated and passed through a small trainable
    fusion + classification head that outputs a probability.

    Args:
        vision_encoder: image encoder whose forward returns a 3-tuple whose first
            element is a (batch, tokens, 768) embedding with the CLS token at
            index 0 (assumed from how ``forward`` indexes it; TODO confirm for
            other encoders).
        text_encoder: BERT-style encoder whose output has ``last_hidden_state``
            with 768-dim features.
    """

    def __init__(self, vision_encoder, text_encoder):
        super().__init__()

        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        # Both encoders are frozen; only the fusion and classification heads train.
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        # 768-dim image CLS + 768-dim text CLS -> 128 hidden units.
        self.fusion_layer = nn.Sequential(nn.Linear(768+768, 128), nn.ReLU())
        # Sigmoid output: forward returns probabilities, as binary_cross_entropy expects.
        self.classification_layer = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, normalised_image, text_tokens, attention_mask):
        """Return the predicted probability, shape (batch, 1), that the label is 1."""
        # Use this module's own encoders rather than notebook globals, so the model
        # works with whatever encoders it was constructed with.
        # The Prithvi encoder returns (latent, mask, ids_restore), i.e. it applies
        # MAE-style random patch masking. Its upstream default is mask_ratio=0.75
        # (confirm against the vendored prithvi_mae.py). Keep every patch so the
        # image features are complete and deterministic.
        img_embedding, _, _ = self.vision_encoder(normalised_image, mask_ratio=0.0)
        img_cls = img_embedding[:, 0, :]
        txt_embedding = self.text_encoder(input_ids=text_tokens, attention_mask=attention_mask)
        txt_cls = txt_embedding.last_hidden_state[:, 0, :]
        fused_embedding = torch.cat([img_cls, txt_cls], dim=1)
        fused_projection = self.fusion_layer(fused_embedding)
        return self.classification_layer(fused_projection)

    def _shared_step(self, batch):
        """Unpack a batch and return (prediction, labels, BCE loss)."""
        images, input_ids, attention_mask, labels = batch
        prediction = self(images, input_ids, attention_mask)
        loss = nn.functional.binary_cross_entropy(prediction, labels)
        return prediction, labels, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        prediction, labels, loss = self._shared_step(batch)
        # Threshold probabilities at 0.5 to get hard yes/no predictions.
        preds = prediction > 0.5
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('acc', acc, prog_bar=True)

    def configure_optimizers(self):
        # Only the heads are trainable; the frozen encoders never receive gradients.
        trainable = [param for param in self.parameters() if param.requires_grad]
        optimizer = optim.Adam(trainable, lr=1e-3)
        return optimizer

## Training

In [35]:
# Seed Python/NumPy/torch RNGs (and DataLoader workers) before building the model,
# so the randomly initialised heads and the training run are reproducible.
pl.seed_everything(42, workers=True)
rsvqa_model = rsvqa_pl(vision_encoder, text_encoder)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model=rsvqa_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name                 | Type       | Params | Mode 
------------------------------------------------------------
0 | vision_encoder       | PrithviViT | 86.2 M | eval 
1 | text_encoder         | BertModel  | 109 M  | eval 
2 | fusion_layer         | Sequential | 196 K  | train
3 | classification_layer | Sequential | 129    | train
------------------------------------------------------------
196 K     Trainable params
195 M     Non-trainable params
195 M     Total params
783.665   Total estimated model params size (MB)
6         Modules in train mode
486       Modules in eval mode


Sanity Checking: |                                                                                        | 0/…

/Users/wouter/miniconda3/envs/genai/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                                                               | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

Validation: |                                                                                             | 0/…

`Trainer.fit` stopped: `max_epochs=10` reached.


In [36]:
tensorboard --logdir .

SyntaxError: invalid syntax (1888371297.py, line 1)