In [1]:
import os

# The relative paths used later in this notebook ("weights/...", "data/...") are
# resolved against the current working directory, so switch to the repository
# checkout first. Set the PRITHVI_PYTORCH_ROOT environment variable to use a
# checkout in a different location; the default is the original studio path.
PROJECT_ROOT = os.environ.get(
    "PRITHVI_PYTORCH_ROOT", "/teamspace/studios/this_studio/prithvi-pytorch"
)
os.chdir(PROJECT_ROOT)

# Print the resulting working directory to confirm the switch took effect.
print(os.getcwd())

/teamspace/studios/this_studio/prithvi-pytorch


In [4]:
import lightning
from torchgeo.trainers import SemanticSegmentationTask

from prithvi_pytorch import PrithviUnet
from prithvi_pytorch.datasets import HLSBurnScarsDataModule

# Locations of the pretrained Prithvi-100M artifacts, relative to the current
# working directory.
CFG_PATH = "weights/Prithvi_100M_config.yaml"  # model configuration (YAML)
CKPT_PATH = "weights/Prithvi_100M.pt"  # pretrained checkpoint

In [5]:
class PrithviSegmentationTask(SemanticSegmentationTask):
    """Semantic segmentation task whose network is a ``PrithviUnet``.

    Overrides ``configure_models`` so that ``self.model`` is a ``PrithviUnet``
    created from the module-level ``CFG_PATH`` / ``CKPT_PATH`` files.
    """

    def configure_models(self):
        """Build the ``PrithviUnet`` and store it on ``self.model``.

        The number of classes and input channels are read from the task's
        hyperparameters (``num_classes`` and ``in_channels``); every other
        setting is fixed here.
        """
        hparams = self.hparams
        unet_kwargs = {
            "num_classes": hparams["num_classes"],
            "cfg_path": CFG_PATH,
            "ckpt_path": CKPT_PATH,
            "in_chans": hparams["in_channels"],
            "img_size": 512,
            # NOTE(review): presumably the encoder block indices whose features
            # feed the decoder -- confirm against PrithviUnet.
            "n": [2, 5, 8, 11],
            "norm": False,
            "decoder_channels": [256, 128, 64, 32],
            # NOTE(review): presumably leaves the encoder trainable -- confirm.
            "freeze_encoder": False,
        }
        self.model = PrithviUnet(**unet_kwargs)

In [6]:
# Seed the random number generators before anything is constructed so that
# repeated "Restart & Run All" executions start from the same random state.
# `workers=True` asks Lightning to derive dataloader-worker seeds from SEED.
SEED = 42
lightning.seed_everything(SEED, workers=True)

# Segmentation task: 6 input channels, 2 output classes, focal loss.
# NOTE(review): ignore_index=0 together with num_classes=2 presumably excludes
# class 0 from the loss/metrics -- confirm this matches the label encoding
# produced by HLSBurnScarsDataModule.
module = PrithviSegmentationTask(
    in_channels=6, num_classes=2, loss="focal", lr=1e-3, patience=10, ignore_index=0
)

# Data module for the burn-scars dataset; `root` is relative to the current
# working directory.
datamodule = HLSBurnScarsDataModule(
    root="data/hls_burn_scars",
    batch_size=4,
    num_workers=8,
)

In [4]:
# Trainer settings gathered in one mapping so they are easy to scan and tweak.
trainer_options = {
    "accelerator": "gpu",
    "logger": True,
    "max_epochs": 20,
    "precision": "16-mixed",  # 16-bit automatic mixed precision
}
trainer = lightning.Trainer(**trainer_options)

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Train the segmentation task on the batches supplied by the data module.
trainer.fit(module, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | FocalLoss        | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | PrithviUnet      | 119 M 
---------------------------------------------------
118 M     Trainable params
1.3 M     Non-trainable params
119 M     Total params
479.962   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

/root/miniconda3/envs/torchenv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Evaluate on the data module's test split; no model is passed, so the trainer
# relies on the one attached by the preceding `fit` call.
test_results = trainer.test(datamodule=datamodule)
test_results