In [1]:
from dl_toolbox import datamodules
from torchvision import tv_tensors
import torchvision.transforms.v2 as v2

tf_train = v2.Compose([
    v2.RandomCrop(size=(504, 504)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

tf_test = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

dm = datamodules.Flair(
    data_path='/data',
    merge='main13',
    bands=[1,2,3],
    sup=1,
    unsup=0,
    train_tf=tf_train,
    test_tf=tf_test,
    batch_size=4,
    num_workers=6,
    pin_memory=True
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, FeatureFt, Lora, TiffPredsWriter
from dl_toolbox.modules import Segmenter
from functools import partial
from dl_toolbox.losses import CrossEntropy
from dl_toolbox.transforms import Sliding

module = Segmenter(
    num_classes=13,
    backbone='vit_small_patch14_dinov2',
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    loss=CrossEntropy(),
    batch_tf=None,
    metric_ignore_index=None,
    tta=None,
    sliding=Sliding(
        nols=512,
        nrows=512,
        width=504,
        height=504,
        step_w=500,
        step_h=500
    )
)

ckpt = ModelCheckpoint(
    dirpath='/tmp',
    filename="epoch_{epoch:03d}",
    save_last=True
)

lora = Lora('encoder', 4)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
    callbacks=[ProgressBar(), lora, ckpt]
)

trainer.fit(
    module,
    datamodule=dm
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


Processing domains


/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /tmp exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                      | Params
-----------------------------------------------------------
0 | encoder      | VisionTransformer         | 22.4 M
1 | decoder      | DecoderLinear             | 5.0 K 
2 | loss         | CrossEntropy              | 0     
3 | val_accuracy | MulticlassAccuracy        | 0     
4 | val_cm       | MulticlassConfusionMatrix | 0     
5 | val_jaccard  | MulticlassJaccardIndex    | 0     
-----------------------------------------------------------
299 K     Trainable params
22.1 M    Non-trainable params
22.4 M    Total params
89.424    Total estimated model params size (MB)


The model will start training with only 299917 trainable parameters out of 22356109.
1 params do not undergo weight decay
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  0.98it/s]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


                                                                                                                                                                

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  0.74it/s, v_num=35]
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  0.45it/s, v_num=35][A

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  0.40it/s, v_num=35]


In [3]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, FeatureFt, Lora, TiffPredsWriter
from dl_toolbox.modules import Segmenter
from functools import partial
from dl_toolbox.losses import CrossEntropy
from dl_toolbox.transforms import Sliding

module = Segmenter(
    num_classes=13,
    backbone='vit_small_patch14_dinov2',
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    loss=CrossEntropy(),
    batch_tf=None,
    metric_ignore_index=None,
    tta=None,
    sliding=Sliding(
        nols=512,
        nrows=512,
        width=504,
        height=504,
        step_w=500,
        step_h=500
    )
)

lora = Lora('encoder', 4)

writer = TiffPredsWriter(
    out_path='/tmp/preds',
    base='/data'
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=1,
    limit_predict_batches=100,
    callbacks=[lora, writer]
)

trainer.predict(
    module,
    datamodule=dm,
    ckpt_path='/data/outputs/flair_segmenter/remote_ckpt/2024-05-16_115135/0/checkpoints/last.ckpt',
    return_predictions=False
)

import pandas as pd
df = pd.read_csv(writer.out_path / 'stats.csv', index_col=0)
df

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Processing domains


Restoring states from the checkpoint path at /data/outputs/flair_segmenter/remote_ckpt/2024-05-16_115135/0/checkpoints/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /data/outputs/flair_segmenter/remote_ckpt/2024-05-16_115135/0/checkpoints/last.ckpt


Predicting DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.93it/s]


Unnamed: 0,img_path,pred_path,conf,acc
0,/data/FLAIR_1/train/D058_2020/Z3_UU/img/IMG_03...,/tmp/preds/FLAIR_1/train/D058_2020/Z3_UU/img/I...,0.877644,tensor(0.8700)
1,/data/FLAIR_1/train/D017_2018/Z17_AA/img/IMG_0...,/tmp/preds/FLAIR_1/train/D017_2018/Z17_AA/img/...,0.761988,tensor(0.8269)
2,/data/FLAIR_1/train/D072_2019/Z6_UU/img/IMG_04...,/tmp/preds/FLAIR_1/train/D072_2019/Z6_UU/img/I...,0.826615,tensor(0.7786)
3,/data/FLAIR_1/train/D074_2020/Z1_NN/img/IMG_05...,/tmp/preds/FLAIR_1/train/D074_2020/Z1_NN/img/I...,0.906497,tensor(0.9382)
4,/data/FLAIR_1/train/D058_2020/Z11_FA/img/IMG_0...,/tmp/preds/FLAIR_1/train/D058_2020/Z11_FA/img/...,0.953491,tensor(0.9882)
...,...,...,...,...
395,/data/FLAIR_1/train/D063_2019/Z17_AF/img/IMG_0...,/tmp/preds/FLAIR_1/train/D063_2019/Z17_AF/img/...,0.787233,tensor(0.8437)
396,/data/FLAIR_1/train/D080_2021/Z1_UA/img/IMG_05...,/tmp/preds/FLAIR_1/train/D080_2021/Z1_UA/img/I...,0.707491,tensor(0.5589)
397,/data/FLAIR_1/train/D058_2020/Z14_AA/img/IMG_0...,/tmp/preds/FLAIR_1/train/D058_2020/Z14_AA/img/...,0.872598,tensor(0.)
398,/data/FLAIR_1/train/D049_2020/Z5_UF/img/IMG_03...,/tmp/preds/FLAIR_1/train/D049_2020/Z5_UF/img/I...,0.775064,tensor(0.2953)
