In [23]:
import os

# Replace 'your/directory/path' with the path to the directory you want to set as the current working directory
os.chdir('/teamspace/studios/this_studio/prithvi-pytorch')

# To verify that the current working directory has been changed, you can use:
print(os.getcwd())

/teamspace/studios/this_studio/prithvi-pytorch


In [24]:
import lightning
from lightning.pytorch.loggers import WandbLogger
from torchgeo.trainers import SemanticSegmentationTask

from prithvi_pytorch import PrithviUnet
from prithvi_pytorch.datasets import HLSBurnScarsDataModule

from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassJaccardIndex,
)
from torchmetrics.wrappers import ClasswiseWrapper

import wandb

CKPT_PATH = "weights/Prithvi_100M.pt"
CFG_PATH = "weights/Prithvi_100M_config.yaml"

In [25]:
class PrithviSegmentationTask(SemanticSegmentationTask):
    def configure_models(self):
        self.model = PrithviUnet(
            num_classes=self.hparams["num_classes"],
            cfg_path=CFG_PATH,
            ckpt_path=CKPT_PATH,
            in_chans=self.hparams["in_channels"],
            img_size=512,
            n=[2, 5, 8, 11],
            norm=False,
            decoder_channels=[256, 128, 64, 32],
            freeze_encoder=False,
        )

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["num_classes"]
        ignore_index: Optional[int] = self.hparams["ignore_index"]
        metrics = MetricCollection(
            {
                "accuracy": ClasswiseWrapper(
                    MulticlassAccuracy(
                        num_classes=num_classes, ignore_index=ignore_index, average=None
                    ),
                ),
                "jaccard": ClasswiseWrapper(
                    MulticlassJaccardIndex(
                        num_classes=num_classes, ignore_index=ignore_index, average=None
                    ),
                ),
                "f1": ClasswiseWrapper(
                    MulticlassF1Score(
                        num_classes=num_classes, ignore_index=ignore_index, average=None
                    ),
                ),
            }
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

In [26]:
module = PrithviSegmentationTask(
    in_channels=6, num_classes=2, loss="focal", lr=1e-3, patience=10, ignore_index=0
)
datamodule = HLSBurnScarsDataModule(
    root="data/hls_burn_scars",
    batch_size=4,
    num_workers=8,
)

datamodule.setup('fit')
datamodule.setup('test')

In [27]:
len(datamodule.train_dataset)

540

In [28]:
len(datamodule.test_dataset)

264

In [29]:
wandb_logger = WandbLogger(
    project="prithvi", 
    log_model=True, # True or 'all'
    save_dir = "wandb_logs"
)

trainer = lightning.Trainer(
    accelerator="gpu", 
    logger=wandb_logger, 
    max_epochs=20, 
    precision="16-mixed"
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [30]:
trainer.fit(module, datamodule)

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | FocalLoss        | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | PrithviUnet      | 119 M 
---------------------------------------------------
118 M     Trainable params
1.3 M     Non-trainable params
119 M     Total params
479.962   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

ValueError: The `.compute()` return of the metric logged as 'val_accuracy' must be a tensor. Found {'multiclassaccuracy_0': tensor(0., device='cuda:0'), 'multiclassaccuracy_1': tensor(0.0019, device='cuda:0')}

In [None]:
trainer.test(module, datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 7.786973732493152e-09,
  'test_MulticlassAccuracy': 1.0,
  'test_MulticlassJaccardIndex': 1.0}]

In [None]:
wandb.finish()

VBox(children=(Label(value='1107.939 MB of 1107.939 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test_MulticlassAccuracy,▁
test_MulticlassJaccardIndex,▁
test_loss,▁
train_MulticlassAccuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_MulticlassJaccardIndex,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▄▆▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val_MulticlassAccuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_MulticlassJaccardIndex,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,20.0
test_MulticlassAccuracy,1.0
test_MulticlassJaccardIndex,1.0
test_loss,0.0
train_MulticlassAccuracy,1.0
train_MulticlassJaccardIndex,1.0
train_loss,0.0
trainer/global_step,2700.0
val_MulticlassAccuracy,1.0
val_MulticlassJaccardIndex,1.0
