
🛰 **Remote Sensing Visual Question Answering**



1. We pull all the necessary code from our repository:

In [None]:
!git clone https://github.com/woosterheert/rsvqa

Cloning into 'rsvqa'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 31 (delta 8), reused 21 (delta 4), pack-reused 0 (from 0)[K
Receiving objects: 100% (31/31), 16.66 KiB | 5.55 MiB/s, done.
Resolving deltas: 100% (8/8), done.


In [None]:
# `!cd` runs in a throwaway subshell, so it cannot change the directory for the next `!` command;
# point git at the clone explicitly instead.
!git -C /content/rsvqa pull


2. We connect our notebook to our Google Cloud Storage bucket:

In [None]:
%cd /content/rsvqa

/content/rsvqa


In [None]:
from google.colab import auth
auth.authenticate_user()  # authenticate this Colab session so gcsfuse can read the bucket

# Install gcsfuse from Google's apt repository
!echo "deb https://packages.cloud.google.com/apt gcsfuse-`lsb_release -c -s` main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
!apt -qq update && apt -qq install gcsfuse

# Mount the bucket at /mnt/<bucket name>
bucket_name = "ou-genai-data"
local_path = f"/mnt/{bucket_name}"

!mkdir -p {local_path}
!gcsfuse --implicit-dirs {bucket_name} {local_path}

deb https://packages.cloud.google.com/apt gcsfuse-jammy main
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1022  100  1022    0     0  13029      0 --:--:-- --:--:-- --:--:-- 13102
OK
53 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mhttps://packages.cloud.google.com/apt/dists/gcsfuse-jammy/InRelease: Key is stored in legacy trusted.gpg keyring (/etc/apt/trusted.gpg), see the DEPRECATION section in apt-key(8) for details.[0m
[1;33mW: [0mSkipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)[0m
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 53 not upgraded.
Need to get 14.6 MB of archives.
After this operation, 0 B of additional disk space will be use
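The install log above is cut off before `gcsfuse` reports anything, so it does not confirm that the bucket was actually mounted. A minimal sketch of a quick check using only the standard library; the helper name `missing_files` is hypothetical, and the file list simply mirrors the paths that later cells read from `/mnt/ou-genai-data`.

```python
from pathlib import Path

def missing_files(root, relative_paths):
    """Return the entries of relative_paths that do not exist under root."""
    root = Path(root)
    return [p for p in relative_paths if not (root / p).exists()]

# Files and folders that the later cells read from the mounted bucket
REQUIRED = [
    "Prithvi_EO_V1_100M.pt",
    "prithvi_config.yaml",
    "questions_and_answers_binary.csv",
    "6d_data",
]

missing = missing_files("/mnt/ou-genai-data", REQUIRED)
print("mount looks complete" if not missing else f"missing: {missing}")
```

An empty list means every path the notebook relies on is visible through the mount.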

3. We install the necessary packages:

In [None]:
!pip install rasterio
!pip install pytorch-lightning

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.4.3


In [None]:
from external.prithvi_mae import PrithviViT
from utils.data_utils import RSVQADataset
from transformers import BertModel, BertTokenizer
from torch.utils.data import DataLoader
import pandas as pd
import yaml
from models.dual_encoder import rsvqa_pl
import torch
import pytorch_lightning as pl

4. Load the model configuration and all the necessary data:

In [None]:
weights_path = "/mnt/ou-genai-data/Prithvi_EO_V1_100M.pt"
model_cfg_path = "/mnt/ou-genai-data/prithvi_config.yaml"
with open(model_cfg_path) as f:
    model_config = yaml.safe_load(f)

model_args, train_args = model_config["model_args"], model_config["train_params"]
# Prithvi was pretrained on multi-frame time series; here each sample is a single image
model_args["num_frames"] = 1
model_args["encoder_only"] = True

In [None]:
df = pd.read_csv('/mnt/ou-genai-data/questions_and_answers_binary.csv', index_col=0)

In [None]:
# Keep only the first N rows of each split (subsets, not the full data)
df_train = df.query("split == 'train'").iloc[:100000]
df_val = df.query("split == 'validation'").iloc[:1000]
df_test = df.query("split == 'test'").iloc[:100]
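Because the CSV holds binary (yes/no) questions, a useful reference point for the validation metric is the accuracy of always predicting the most frequent answer. A minimal sketch; it assumes the answer column is called `answer`, which this notebook does not show, and the helper name `majority_baseline` is hypothetical.

```python
from collections import Counter

def majority_baseline(answers):
    """Return (most common answer, accuracy of always predicting it)."""
    counts = Counter(answers)
    if not counts:
        raise ValueError("no answers given")
    label, hits = counts.most_common(1)[0]
    return label, hits / sum(counts.values())

# Hypothetical usage; the column name "answer" is an assumption:
# label, acc = majority_baseline(df_val["answer"])
# print(f"always predicting {label!r} gives {acc:.1%} accuracy")
```

A trained model that does not clearly beat this number on `df_val` has likely learned only the answer prior.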

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [None]:
train_ds = RSVQADataset(df_train, train_args, '/mnt/ou-genai-data/6d_data', tokenizer)
val_ds = RSVQADataset(df_val, train_args, '/mnt/ou-genai-data/6d_data', tokenizer)
test_ds = RSVQADataset(df_test, train_args, '/mnt/ou-genai-data/6d_data', tokenizer)

In [None]:
# Shuffle only the training data (DataLoader defaults to shuffle=False)
train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_ds, batch_size=8, num_workers=4)
test_dataloader = DataLoader(test_ds, batch_size=8, num_workers=4)
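With `batch_size=8`, the number of batches per epoch follows directly from the subset sizes chosen above. A quick arithmetic check, assuming each dataset yields exactly one sample per DataFrame row and that each split contains at least as many rows as requested (the slices silently return fewer otherwise). The helper name `batches_per_epoch` is hypothetical.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch (the last one may be partial)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

for name, n in [("train", 100_000), ("validation", 1_000), ("test", 100)]:
    print(f"{name}: {batches_per_epoch(n, 8)} batches")

# With max_epochs=10 (set in the trainer cell) and no gradient accumulation,
# this gives 10 * 12,500 = 125,000 training steps.
print("total training steps:", 10 * batches_per_epoch(100_000, 8))
```

The test split ends with a partial batch of 4 samples, since 100 is not a multiple of 8.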

5. Create the model:

In [None]:
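# The .pt file holds the full MAE (encoder.* and decoder.* keys); keep only the encoder weights, without the prefix
checkpoint = torch.load(weights_path, map_location="cpu", weights_only=True)
encoder_state = {k.removeprefix("encoder."): v for k, v in checkpoint.items() if k.startswith("encoder.")}
# pos_embed was built for the pretraining num_frames; its shape no longer fits num_frames = 1
del encoder_state["pos_embed"]
vision_encoder = PrithviViT(**model_args)
missing, unexpected = vision_encoder.load_state_dict(encoder_state, strict=False)
print(f"missing keys: {missing}\nunexpected keys: {unexpected}")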

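The positional embedding stored in the checkpoint has to be discarded because its shape depends on the number of frames: Prithvi-EO V1 was pretrained on three-frame inputs, while this notebook sets `num_frames = 1`. The sketch below works out the token counts, assuming the values of the public Prithvi-100M config (224 × 224 images, 16 × 16 patches, tubelet size 1, embedding dimension 768, plus one CLS token); check `prithvi_config.yaml` for the values actually used. The helper `pos_embed_shape` is hypothetical.

```python
def pos_embed_shape(img_size=224, patch_size=16, num_frames=3, tubelet_size=1, embed_dim=768):
    """Shape of a ViT-style positional embedding with one extra CLS token."""
    patches_per_frame = (img_size // patch_size) ** 2
    num_patches = patches_per_frame * (num_frames // tubelet_size)
    return (1, num_patches + 1, embed_dim)

print(pos_embed_shape(num_frames=3))  # (1, 589, 768): pretraining checkpoint
print(pos_embed_shape(num_frames=1))  # (1, 197, 768): model built in this notebook
```

Loading the 589-token tensor into a 197-token parameter would fail with a size mismatch, which is why the key is removed before `load_state_dict`.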


In [None]:
text_encoder = BertModel.from_pretrained('bert-base-uncased')


In [None]:
rsvqa_model = rsvqa_pl(vision_encoder, text_encoder)
trainer = pl.Trainer(max_epochs=10)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


6. Train the model:

In [None]:
trainer.fit(model=rsvqa_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type       | Params | Mode 
------------------------------------------------------------
0 | vision_encoder       | PrithviViT | 86.2 M | train
1 | text_encoder         | BertModel  | 109 M  | eval 
2 | fusion_layer         | Sequential | 196 K  | train
3 | classification_layer | Sequential | 129    | train
------------------------------------------------------------
196 K     Trainable params
195 M     Non-trainable params
195 M     Total param
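The summary shows that only the fusion and classification layers are trained (196 K parameters), while both pretrained encoders stay frozen. The reported sizes are consistent with, for example, a fusion layer that maps the concatenated 768-dimensional image and text embeddings to 128 units, followed by a single-logit classifier. That layout is a guess for illustration, not taken from `models/dual_encoder.py`; the helper `linear_params` is hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: a weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)

# Hypothetical head: concat(768-d image embedding, 768-d text embedding) -> 128 -> 1 logit
fusion = linear_params(768 + 768, 128)
classifier = linear_params(128, 1)
print(fusion, classifier, fusion + classifier)  # 196736 129 196865
```

Activation and dropout layers inside a `Sequential` add no parameters, so the counts depend only on the linear layers.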
