In [1]:
import os

# Replace 'your/directory/path' with the path to the directory you want to set as the current working directory
os.chdir('/teamspace/studios/this_studio/prithvi-pytorch')

# To verify that the current working directory has been changed, you can use:
print(os.getcwd())

/teamspace/studios/this_studio/prithvi-pytorch


In [2]:
import lightning
import kornia.augmentation as K
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask
from torchgeo.transforms import AugmentationSequential

from prithvi_pytorch import PrithviViT
from prithvi_pytorch.model import BANDS

CKPT_PATH = "weights/Prithvi_100M.pt"
CFG_PATH = "weights/Prithvi_100M_config.yaml"

In [3]:
os.listdir('weights')

['Prithvi_100M.pt', 'Prithvi_100M_config.yaml']

In [4]:
class CustomEuroSATDataModule(EuroSATDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.Resize(size=(224, 224)),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["image"],
        )
        self.val_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.Resize(size=(224, 224)),
            data_keys=["image"],
        )
        self.test_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.Resize(size=(224, 224)),
            data_keys=["image"],
        )


class PrithviClassificationTask(ClassificationTask):
    def configure_models(self):
        self.model = PrithviViT(
            num_classes=self.hparams["num_classes"],
            cfg_path=CFG_PATH,
            ckpt_path=CKPT_PATH,
            in_chans=self.hparams["in_channels"],
            img_size=224,
            freeze_encoder=True,
        )

In [5]:
module = PrithviClassificationTask(
    in_channels=6, num_classes=10, loss="ce", lr=1e-3, patience=10
)
datamodule = CustomEuroSATDataModule(
    root="data/eurosat", batch_size=64, num_workers=8, bands=BANDS
)

In [5]:
trainer = lightning.Trainer(
    accelerator="gpu", 
    logger=True, 
    max_epochs=20, 
    precision="16-mixed"
)

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
trainer.fit(model=module, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | PrithviViT       | 112 M 
---------------------------------------------------
7.7 K     Trainable params
112 M     Non-trainable params
112 M     Total params
451.597   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [8]:
trainer.test(datamodule=datamodule)

Restoring states from the checkpoint path at /workspace/storage/github/prithvi-pytorch/lightning_logs/version_2/checkpoints/epoch=19-step=5080.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at /workspace/storage/github/prithvi-pytorch/lightning_logs/version_2/checkpoints/epoch=19-step=5080.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.32776209712028503,
  'test_AverageAccuracy': 0.890811026096344,
  'test_F1Score': 0.8937036991119385,
  'test_JaccardIndex': 0.8079135417938232,
  'test_OverallAccuracy': 0.8937036991119385}]

```
[{'test_loss': 0.32776209712028503,
  'test_AverageAccuracy': 0.890811026096344,
  'test_F1Score': 0.8937036991119385,
  'test_JaccardIndex': 0.8079135417938232,
  'test_OverallAccuracy': 0.8937036991119385}]
```