In [2]:
import os
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms
from torchmetrics import (
    Accuracy,
    MetricCollection,
    JaccardIndex,
    F1Score,
    MeanSquaredError,
)
import segmentation_models_pytorch as smp
from pytorch_toolbelt.losses import JaccardLoss, BinaryFocalLoss
from torch.utils.data import Dataset, random_split, DataLoader

from pytorch_lightning import Trainer
from typing import Any, Tuple, Optional, Callable, cast

import matplotlib.pyplot as plt
import wandb
import tqdm as notebook_tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [4]:
import segmentation_models_pytorch as smp

In [5]:
pl.seed_everything(27)

Global seed set to 27


27

In [6]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msofstef[0m ([33mtrees[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
%load_ext lab_black

In [9]:
os.chdir("/Users/sofija/Ai4er/mres/tree-segmentation")
os.getcwd()

'/Users/sofija/Ai4er/mres/tree-segmentation'

In [10]:
from src.datasets import TreeSegments
from src.evaluation import BinaryIoU
from src.datamodules import TreeDataModule
from src.models import UNet, SegModel
from src.evaluation import LogPredictionSamplesCallback

In [11]:
dm = TreeDataModule(
    data_dir="data/raw/samples/",
    target_dir="data/segments/",
    test_data_dir="data/test_samples",
    test_target_dir="data/test_segments/",
    batch_size=8,
    num_workers=0,
    drop_last_batch=True,
)

In [12]:
best_path = "lightning_logs/colab_ckpt/epoch=78-step=632-v1.ckpt"

In [13]:
model = SegModel.load_from_checkpoint(checkpoint_path=best_path)

In [14]:
trainer = Trainer()

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
# dropout, seeing whether frozen/unfrozen weights make a difference if this is possible here
sweep_config = {
    "entity": "trees",
    "project": "Trees",
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "encoder_name": {"values": ["resnet18", "resnet34", "resnet50"]},
        # dropout is not directly specified in the smp models but through aux parameters
        # "dropout": {"values": [0.3, 0.4, 0.5]},
        "batch_size": {
            "distribution": "q_log_uniform_values",
            "q": 8,
            "min": 8,
            "max": 64,
        },
        "lr": {
            # a flat distribution between 0.0001 and 0.01
            "distribution": "uniform",
            "min": 1e-4,
            "max": 1e-1,
            # log uniform distribution between exp(min) and exp(max)
            # "distribution": "log_uniform",
            # "min": -9.21,  # exp(-9.21) = 1e-4
            # "max": -4.61,  # exp(-4.61) = 1e-2
        },
        "epochs": {"value": 5},
        "encoder_weights": {"value": "imagenet"},
        "loss": {"value": "dice"},
        "num_classes": {"value": 1},
        "in_channels": {"value": 1},
        "jaccard_average": {"value": "macro"},
    },
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 3,
    },
}

In [22]:
sweep_id = wandb.sweep(sweep_config, project="Trees")

Create sweep with ID: kja7hfs1
Sweep URL: https://wandb.ai/trees/Trees/sweeps/kja7hfs1


Questions about setting sweeps up:

1. Which search strategy should I use? 
2. Should batch size be passed explicitly to datamodule (example doesn't seem to do that)?
3. Should I vary the number of epochs?

In [23]:
def sweep_iteration():
    # set up W&B logger
    wandb.init()  # required to have access to `wandb.config`
    wandb_logger = WandbLogger()

    # setup data
    dm = TreeDataModule(
        data_dir="data/train/samples/",
        target_dir="data/train/segments/",
        test_data_dir="data/test/samples",
        test_target_dir="data/test/segments/",
        batch_size=wandb.config.batch_size,
        drop_last_batch=True,
    )

    # setup model - note how we refer to sweep parameters with wandb.config
    model = SegModel(
        in_channels=wandb.config.in_channels,
        encoder_name=wandb.config.encoder_name,
        encoder_weights=wandb.config.encoder_weights,
        num_classes=wandb.config.num_classes,
        loss=wandb.config.loss,
        ignore_zeros=None,
        lr=wandb.config.lr,
        jaccard_average=wandb.config.jaccard_average,
    )

    # setup Trainer
    trainer = Trainer(
        logger=wandb_logger,
        # callbacks=callbacks,
        max_epochs=wandb.config.epochs,
    )

    # train
    trainer.fit(model, dm)

In [24]:
wandb.agent(sweep_id, function=sweep_iteration, count=5)

[34m[1mwandb[0m: Agent Starting Run: o6xd5bay with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	encoder_name: resnet18
[34m[1mwandb[0m: 	encoder_weights: imagenet
[34m[1mwandb[0m: 	epochs: 100
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	jaccard_average: macro
[34m[1mwandb[0m: 	loss: dice
[34m[1mwandb[0m: 	lr: 0.00981975863692553
[34m[1mwandb[0m: 	num_classes: 1


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params
-------------------------------------------------
0  | net        | Unet             | 14.3 M
1  | loss       | DiceLoss         | 0     
2  | train_acc  | Accuracy         | 0     
3  | val_acc    | Accuracy         | 0     
4  | test_acc   | Accuracy         | 0     
5  | train_jacc | JaccardIndex     | 0     
6  | val_jacc   | JaccardIndex     | 0     
7  | test_jacc  | JaccardIndex     | 0     
8  | train_f1   | F1Score          | 0     
9  | val_f1     | F1Score          | 0     
10 | test_f1    | F1Score          | 0     
11 | rmse       | MeanSquaredError | 0     
-------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.288    Total estimated model params size (MB)


                                                                                                   

  rank_zero_warn(


                                                                                                   

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  80%|████████████████████████▊      | 4/5 [00:07<00:01,  1.98s/it, loss=0.732, v_num=5bay]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                               | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|███████████████████████████████████████| 1/1 [00:00<00:00,  1.76it/s][A
Epoch 0: 100%|███████████████████████████████| 5/5 [00:09<00:00,  1.81s/it, loss=0.732, v_num=5bay][A
Epoch 1:  80%|████████████████████████▊      | 4/5 [00:07<00:01,  1.79s/it, loss=0.628, v_num=5bay][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                               | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|███████████████████████████████████████| 1/1 [00:00<00:00,  1.75it/s][A
Epoch 1:

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_accuracy,▁▅▆▆▇▆▆▇▆▇▇▇▇█▇█▇▇▇▇▇▇▇▇▇▇▇▇████▇███████
train_f1,▁▄▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█████████████████████
train_jaccard,▁▄▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇██████████████████
train_loss,█▄▂▆▂▂▁▁
trainer/global_step,▁▁▁▂▂▂▂▁▂▃▃▃▂▃▄▄▄▂▄▂▅▂▅▅▂▅▂▆▆▆▆▂▇▇▇▇▃███
val_accuracy,██▁▁▄▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val_f1,▂▂▁▂▅▇▇▇▇███████████████████████████████
val_jaccard,▁▁▅▅▆▇▇▇▇███████████████████████████████
val_loss_epoch,▇▇█▇▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
train_accuracy,0.93072
train_f1,0.92351
train_jaccard,0.91693
train_loss,0.07871
trainer/global_step,399.0
val_accuracy,0.86503
val_f1,0.87735
val_jaccard,0.87473
val_loss_epoch,0.1227


[34m[1mwandb[0m: Agent Starting Run: irxmhibj with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	encoder_name: resnet34
[34m[1mwandb[0m: 	encoder_weights: imagenet
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	jaccard_average: macro
[34m[1mwandb[0m: 	loss: dice
[34m[1mwandb[0m: 	lr: 0.002348779611111376
[34m[1mwandb[0m: 	num_classes: 1


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params
-------------------------------------------------
0  | net        | Unet             | 24.4 M
1  | loss       | DiceLoss         | 0     
2  | train_acc  | Accuracy         | 0     
3  | val_acc    | Accuracy         | 0     
4  | test_acc   | Accuracy         | 0     
5  | train_jacc | JaccardIndex     | 0     
6  | val_jacc   | JaccardIndex     | 0     
7  | test_jacc  | JaccardIndex     | 0     
8  | train_f1   | F1Score          | 0     
9  | val_f1     | F1Score          | 0     
10 | test_f1    | F1Score          | 0     
11 | rmse       | MeanSquaredError | 0     
-------------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.720    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Run irxmhibj errored: ZeroDivisionError('float division by zero')
[34m[1mwandb[0m: [32m[41mERROR[0m Run irxmhibj errored: ZeroDivisionError('float division by zero')
[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: tp7f7gkr with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	encoder_name: resnet34
[34m[1mwandb[0m: 	encoder_weights: imagenet
[34m[1mwandb[0m: 	epochs: 100
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	jaccard_average: macro
[34m[1mwandb[0m: 	loss: dice
[34m[1mwandb[0m: 	lr: 0.009950137989085233
[34m[1mwandb[0m: 	num_classes: 1


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params
-------------------------------------------------
0  | net        | Unet             | 24.4 M
1  | loss       | DiceLoss         | 0     
2  | train_acc  | Accuracy         | 0     
3  | val_acc    | Accuracy         | 0     
4  | test_acc   | Accuracy         | 0     
5  | train_jacc | JaccardIndex     | 0     
6  | val_jacc   | JaccardIndex     | 0     
7  | test_jacc  | JaccardIndex     | 0     
8  | train_f1   | F1Score          | 0     
9  | val_f1     | F1Score          | 0     
10 | test_f1    | F1Score          | 0     
11 | rmse       | MeanSquaredError | 0     
-------------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.720    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                                                   

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  82%|████████████████████████▌     | 9/11 [00:10<00:02,  1.15s/it, loss=0.574, v_num=7gkr]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                            | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                               | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:  50%|███████████████████▌                   | 1/2 [00:00<00:00,  3.14it/s][A
Epoch 0:  91%|██████████████████████████▎  | 10/11 [00:10<00:01,  1.09s/it, loss=0.574, v_num=7gkr][A
Validation DataLoader 0: 100%|███████████████████████████████████████| 2/2 [00:01<00:00,  1.68it/s][A
Epoch 0: 100%|█████████████████████████████| 11/11 [00:11<00:00,  1.07s/it, loss=0.574, v_num=7gkr][A
Epoch 1:  82%|████████████████████████▌     | 9/11 [00:11<00:02,  1.32s/it, loss=0.476, v_num=7gkr][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                            | 0/2 [00:00<?, ?it/s][A
Validati

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train_accuracy,▆▅▁▆▂▄▄▅▃▅▇▄▆▅▇▅▆▆▆▇▇▇▇▇▇▆▇▇█▆
train_f1,▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████████
train_jaccard,▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████████
train_loss,█▃▂▄▁
trainer/global_step,▁▁▁▁▁▂▂▁▁▃▃▁▁▂▄▂▂▂▄▄▂▅▅▅▂▂▆▆▆▂▆▇▂▂▂▇▂▂██
val_accuracy,▁▁█▁▆▅▇▇▂▃▃▇▇▂▇▇▆▅▄▇▆▇▃▄▅▇▇▇▆▇
val_f1,▁▁▃▂▄▆██▄▄▅██▃▇██▆▆███▅▆▆███▇█
val_jaccard,▄▄▁▄▃▆██▅▅▆██▅▇██▇▆███▅▆▇███▇█
val_loss_epoch,██▆▇▅▃▁▁▅▅▄▁▁▆▂▁▁▃▃▁▁▁▅▃▃▁▁▁▂▁

0,1
epoch,30.0
train_accuracy,0.84907
train_f1,0.83265
train_jaccard,0.83055
train_loss,0.13185
trainer/global_step,269.0
val_accuracy,0.82878
val_f1,0.80627
val_jaccard,0.8077
val_loss_epoch,0.19464


[34m[1mwandb[0m: Agent Starting Run: tjbo9ubu with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	encoder_name: resnet34
[34m[1mwandb[0m: 	encoder_weights: imagenet
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	jaccard_average: macro
[34m[1mwandb[0m: 	loss: dice
[34m[1mwandb[0m: 	lr: 0.0041666781172457505
[34m[1mwandb[0m: 	num_classes: 1


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params
-------------------------------------------------
0  | net        | Unet             | 24.4 M
1  | loss       | DiceLoss         | 0     
2  | train_acc  | Accuracy         | 0     
3  | val_acc    | Accuracy         | 0     
4  | test_acc   | Accuracy         | 0     
5  | train_jacc | JaccardIndex     | 0     
6  | val_jacc   | JaccardIndex     | 0     
7  | test_jacc  | JaccardIndex     | 0     
8  | train_f1   | F1Score          | 0     
9  | val_f1     | F1Score          | 0     
10 | test_f1    | F1Score          | 0     
11 | rmse       | MeanSquaredError | 0     
-------------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.720    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Run tjbo9ubu errored: ZeroDivisionError('float division by zero')
[34m[1mwandb[0m: [32m[41mERROR[0m Run tjbo9ubu errored: ZeroDivisionError('float division by zero')
[34m[1mwandb[0m: Agent Starting Run: qg4eqlhg with config:
[34m[1mwandb[0m: 	batch_size: 48
[34m[1mwandb[0m: 	encoder_name: resnet50
[34m[1mwandb[0m: 	encoder_weights: imagenet
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	in_channels: 1
[34m[1mwandb[0m: 	jaccard_average: macro
[34m[1mwandb[0m: 	loss: dice
[34m[1mwandb[0m: 	lr: 0.00951808682450227
[34m[1mwandb[0m: 	num_classes: 1


  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params
-------------------------------------------------
0  | net        | Unet             | 32.5 M
1  | loss       | DiceLoss         | 0     
2  | train_acc  | Accuracy         | 0     
3  | val_acc    | Accuracy         | 0     
4  | test_acc   | Accuracy         | 0     
5  | train_jacc | JaccardIndex     | 0     
6  | val_jacc   | JaccardIndex     | 0     
7  | test_jacc  | JaccardIndex     | 0     
8  | train_f1   | F1Score          | 0     
9  | val_f1     | F1Score          | 0     
10 | test_f1    | F1Score          | 0     
11 | rmse       | MeanSquaredError | 0     
-------------------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
130.059   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Run qg4eqlhg errored: ZeroDivisionError('float division by zero')
[34m[1mwandb[0m: [32m[41mERROR[0m Run qg4eqlhg errored: ZeroDivisionError('float division by zero')


In [25]:
wandb.finish()