In [1]:
import os

import timm
import torch
import terratorch
from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory
from torchgeo.datamodules import LandCoverAIDataModule
from terratorch.tasks import SemanticSegmentationTask
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.15 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
# Seed Python, NumPy and PyTorch (and, via workers=True, the dataloader worker
# processes) before any data or model object is created, so that a
# Restart Kernel -> Run All starts from the same random state every time.
SEED = 42
seed_everything(SEED, workers=True)

# LandCover.ai datamodule rooted at data/landcoverai; download=True is passed
# through so the dataset can be fetched when it is not already on disk.
datamodule = LandCoverAIDataModule(
    root="data/landcoverai",
    batch_size=32,
    num_workers=2,
    download=True,
)
# Prepare and set up the training-stage datasets eagerly so data problems
# surface here rather than inside trainer.fit().
datamodule.prepare_data()
datamodule.setup('fit')

In [4]:
# Arguments handed to the model factory: a pretrained Prithvi ViT backbone with
# a UperNet decoder, restricted to the three RGB bands.
model_args = {
    "backbone": "prithvi_vit_100",
    "decoder": "UperNetDecoder",
    "in_channels": 3,  # one input channel per entry in "bands"
    "num_classes": 5,  # one output class per entry in class_names below
    "bands": [HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE],
    "pretrained": True,
    "num_frames": 1,
    "decoder_channels": 256,
    "head_dropout": 0.1,
    "decoder_scale_modules": True,
}

# Semantic-segmentation task built by "PrithviModelFactory": weighted
# cross-entropy loss, AdamW optimizer, and a backbone that is not frozen.
# NOTE(review): the class weights presumably compensate for class imbalance in
# the dataset - confirm how they were derived.
task = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="PrithviModelFactory",
    loss="ce",
    class_names=["Background", "Building", "Woodland", "Water", "Road"],
    class_weights=[0.02, 0.55, 0.04, 0.14, 0.25],
    ignore_index=-1,
    lr=5e-4,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=False,
)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (('ibm-nasa-geospatial/Prithvi-100M', 'Prithvi_100M.pt'))
  return torch.load(cached_file, map_location='cpu')


In [5]:
# Keep the single best checkpoint according to the task's monitored metric,
# and additionally always save the most recent one.
checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
# Stop training once the monitored metric has failed to improve for 20
# consecutive validation checks (validation runs every epoch, see below).
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir='models', name='log')
torch.set_float32_matmul_precision('high')
trainer = Trainer(
    devices=1,  # Number of GPUs
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        # Previously constructed but never registered, so early stopping was
        # silently inactive and every run went to max_epochs.
        early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=50,
    default_root_dir='models/logs',
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    accelerator='gpu'
)
# The return value of fit() is not used; the call is left as a bare statement
# so IPython's `_` last-output variable is not shadowed.
trainer.fit(model=task, datamodule=datamodule)

/share/home/e2305599/.conda/envs/terratorch/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /share/home/e2305599/.conda/envs/terratorch/lib/pyth ...
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/share/home/e2305599/.conda/envs/terratorch/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py:55: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU av

Output()

INFO: `Trainer.fit` stopped: `max_epochs=50` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


In [6]:
# Evaluate on the test split. ckpt_path="best" makes the trainer reload the
# best checkpoint tracked by its ModelCheckpoint callback, so the reported
# metrics describe the selected model rather than whatever weights were left
# in memory after the final training epoch.
datamodule.setup('test')
res = trainer.test(model=task, datamodule=datamodule, ckpt_path="best")

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()