In [1]:
from dl_toolbox import datamodules
from torchvision import tv_tensors
import torchvision.transforms.v2 as v2

tf_train = v2.Compose([
    v2.RandomCrop(size=(504, 504)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

tf_test = v2.Compose([
    v2.RandomCrop(size=(504, 504)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dm = datamodules.Flair(
    data_path='/data',
    merge='main13',
    bands=[1,2,3],
    sup=1,
    unsup=0,
    train_tf=tf_train,
    test_tf=tf_test,
    batch_tf=None,
    batch_size=4,
    num_workers=6,
    pin_memory=True
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%matplotlib inline

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, Lora, TiffPredsWriter, CalibrationLogger
from dl_toolbox.modules import Segmenter
from functools import partial
from dl_toolbox.losses import CrossEntropy
from dl_toolbox.transforms import Sliding

module = Segmenter(
    num_classes=13,
    backbone='vit_small_patch14_dinov2',
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    loss=CrossEntropy(),
    batch_tf=None,
    metric_ignore_index=None,
    tta=None,
    onehot=False,
    sliding=Sliding(
        nols=512,
        nrows=512,
        width=504,
        height=504,
        step_w=500,
        step_h=500
    )
)

ckpt = ModelCheckpoint(
    dirpath='/tmp',
    filename="epoch_{epoch:03d}",
    save_last=True
)

lora = Lora('encoder', 4, False)
calib = CalibrationLogger(freq=1)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
    callbacks=[ProgressBar(), lora, ckpt, calib]
)

trainer.fit(
    module,
    datamodule=dm
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
Missing logger folder: /d/pfournie/dl_toolbox/dl_toolbox/à ranger/segmenter/lightning_logs


Processing domains


/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /tmp exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                      | Params
-----------------------------------------------------------
0 | backbone     | VisionTransformer         | 22.1 M
1 | decoder      | DecoderLinear             | 5.0 K 
2 | loss         | CrossEntropy              | 0     
3 | val_accuracy | MulticlassAccuracy        | 0     
4 | val_cm       | MulticlassConfusionMatrix | 0     
5 | val_jaccard  | MulticlassJaccardIndex    | 0     
-----------------------------------------------------------
22.1 M    Trainable params
0         Non-trainable params
22.1 M    Total params
88.245    Total estimated model params size (MB)


Training 22061197 params out of 22061197.
Sanity Checking DataLoader 0:   0%|                                                                                                       | 0/1 [00:00<?, ?it/s]torch.Size([4, 1296, 13])
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  0.91it/s]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


torch.Size([4, 1296, 13])
                                                                                                                                                                

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|                                                                                                                            | 0/1 [00:00<?, ?it/s]torch.Size([4, 1296, 13])
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  0.48it/s, v_num=0]
Validation: |                                                                                                                             | 0/? [00:00<?, ?it/s][Atorch.Size([4, 1296, 13])
torch.Size([4, 1296, 13])


/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [6]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from dl_toolbox.callbacks import ProgressBar, Finetuning, Lora, TiffPredsWriter
from dl_toolbox.modules import Segmenter
from functools import partial
from dl_toolbox.losses import CrossEntropy
from dl_toolbox.transforms import Sliding

module = Segmenter(
    num_classes=13,
    backbone='vit_small_patch14_dinov2',
    optimizer=partial(torch.optim.Adam, lr=0.001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    loss=CrossEntropy(),
    batch_tf=None,
    metric_ignore_index=None,
    tta=None,
    onehot=False,
    sliding=Sliding(
        nols=512,
        nrows=512,
        width=504,
        height=504,
        step_w=500,
        step_h=500
    )
)

lora = Lora('encoder', 4, activated=True)

writer = TiffPredsWriter(
    out_path='/tmp/preds',
    base='/data'
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=1,
    limit_predict_batches=100,
    callbacks=[lora, writer]
)

trainer.predict(
    module,
    datamodule=dm,
    ckpt_path='/data/outputs/flair_segmenter/remote_ckpt/2024-05-16_115135/0/checkpoints/last.ckpt',
    return_predictions=False
)

import pandas as pd
df = pd.read_csv(writer.out_path / 'stats.csv', index_col=0)

FileExistsError: [Errno 17] File exists: '/tmp/preds'

In [None]:
sorted_df = df.sort_values('acc')
sorted_df

In [None]:
row = 180
img_path = sorted_df.iloc[[row]]['img_path'].item()
pred_path = sorted_df.iloc[[row]]['pred_path'].item()
msk_path = img_path.replace('img', 'msk').replace('IMG', 'MSK')
print(msk_path)

In [None]:
from dl_toolbox.datasets import Flair
import rasterio
import numpy as np
from dl_toolbox.utils import merge_labels
import torchmetrics.functional.classification as metrics

with rasterio.open(img_path, "r") as f:
    image = f.read(out_dtype=np.uint8, indexes=[1,2,3])

classes = Flair.classes['main13'].value
with rasterio.open(msk_path, "r") as f:
    mask = torch.from_numpy(f.read(out_dtype=np.uint8))
    mask = merge_labels(mask, [list(l.values) for l in classes]) 
    
with rasterio.open(pred_path, "r") as f:
    pred = torch.from_numpy(f.read(out_dtype=np.uint8))

conf_mat = metrics.multiclass_confusion_matrix(
    pred,
    mask,
    len(classes),
    ignore_index=0
)

%matplotlib inline

from dl_toolbox.utils import plot_confusion_matrix

fig = plot_confusion_matrix(
    conf_mat,
    [c.name for c in classes],
    'precision'
)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt 
from dl_toolbox.utils import labels_to_rgb

colors = [(i, c.color) for i, c in enumerate(classes)]
overlay = np.zeros(shape=(*pred.squeeze().shape, 3), dtype=np.uint8)
idx = 3
label_bool = mask.squeeze() == idx
pred_bool = pred.squeeze() == idx

# Correct predictions (Hits) painted with green
overlay[label_bool & pred_bool] = np.array([0, 250, 0], dtype=overlay.dtype)
# Misses painted with red
overlay[label_bool & ~pred_bool] = np.array([250, 0, 0], dtype=overlay.dtype)
# False alarm painted with yellow
overlay[~label_bool & pred_bool] = np.array([250, 250, 0], dtype=overlay.dtype)

zone = np.s_[0:1500, 0:1500, ...]

fig = plt.figure(figsize=(20,20))
ax1 = fig.add_subplot(221)
ax1.imshow(image.transpose(1,2,0)[zone])
ax2 = fig.add_subplot(222)
ax2.imshow(labels_to_rgb(mask.squeeze(), colors)[zone])
ax3 = fig.add_subplot(223)
ax3.imshow(labels_to_rgb(pred.squeeze(), colors)[zone])
ax4 = fig.add_subplot(224)
ax4.imshow(overlay[zone])