## Deep learning training pipeline Birds Songs

Visualizacion en Wandb.

In [1]:
import timm
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
import torch.nn as nn
from torchvision import transforms
from torch import from_numpy

from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt
import os

from data_models.datamodule import BirdDataModule
from models.timm import TimmModel
from callbacks.callbacks import ImagePredictionLogger

### Load & transform data

In [2]:
data_transforms = transforms.Compose([
                                    #se transforma a tensor automaticamente
                                    # from_numpy,
                                    # transforms.Resize([224,224]),
                                    transforms.Normalize([-28.35204512488038], [9.804596690228543]) #values obtained from training dataset
                                   ])
                                   
target_transforms = transforms.Compose([
                                    # transforms.ToTensor,
                                   ])

In [3]:
dm = BirdDataModule(
    root_data_dir='../data/',
    batch_size=64,
    num_workers=8,
    transforms={
            'transform': data_transforms,
            'target_transform': target_transforms
        },
    seed = 5,
    test_size = 0.2
    )

In [4]:
dm.setup()

In [5]:
dm.coder.classes_

array(['Acrocephalus arundinaceus', 'Acrocephalus melanopogon',
       'Acrocephalus scirpaceus', 'Alcedo atthis', 'Anas platyrhynchos',
       'Anas strepera', 'Ardea purpurea', 'Botaurus stellaris',
       'Charadrius alexandrinus', 'Ciconia ciconia', 'Circus aeruginosus',
       'Coracias garrulus', 'Dendrocopos minor', 'Fulica atra',
       'Gallinula chloropus', 'Himantopus himantopus',
       'Ixobrychus minutus', 'Motacilla flava', 'Porphyrio porphyrio',
       'Tachybaptus ruficollis'], dtype=object)

In [6]:
dataset = dm.dataset('../data/')#, transform = from_numpy)

#### Model definition

In [7]:
# timm.list_models(pretrained = True)

In [8]:
num_classes = 20

timm_model = timm.create_model('mobilenetv3_small_100', pretrained=True, num_classes=num_classes, in_chans = 1) 
# timm_model = timm.create_model('densenet201', pretrained=True, num_classes=num_classes) 

print(timm_model.get_classifier())

Linear(in_features=1024, out_features=20, bias=True)


In [9]:
model = TimmModel(timm_model, num_classes, learning_rate=1e-3)

#### Wandb logger

In [10]:
# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))

In [11]:
wandb_logger = WandbLogger(project="birds")

# Initialize callbacks
callbacks = [
    ImagePredictionLogger(val_samples, coder = dm.coder),
    ModelCheckpoint(dirpath="./checkpoints", monitor="val_loss", filename="bird-{epoch:02d}-{val_loss:.2f}")
]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpgraciae[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Training pipeline

In [12]:
trainer = pl.Trainer(max_epochs=15,
                        logger=wandb_logger,
                        callbacks=callbacks,
                        enable_progress_bar=True)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(model, dm)
trainer.test(model, dm)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name     | Type               | Params
------------------------------------------------
0 | model    | MobileNetV3        | 1.5 M 
1 | accuracy | MulticlassAccuracy | 0     
2 | f_score  | MulticlassF1Score  | 0     
------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
6.152     Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-03.


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Adjusting learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Adjusting learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Adjusting learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


Adjusting learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


In [None]:
wandb.finish()

VBox(children=(Label(value='23.934 MB of 23.934 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
test_acc,▁
test_loss_epoch,▁
test_loss_step,▄▂▃▁▅▂▂▃▄▁▂▂▂▁▂▂▁▂▁▂▇▁▂▁▁▁▁▂▄▁█▁▅▃▃▂▂▃▂▁
train_acc_epoch,▁▄▆▅▆▇▆▇███████
train_acc_step,▆▆▇█▇█▆▇▇███▇▇██▁███▇███████████████████
train_f_score_epoch,▁▄▆▅▆▆▆▇▇██████
train_f_score_step,▆▆▆▇▆▂▇▆▇█▂▇▇▆▇▇▁▆▇▇▇▁▆▇▇▇▂▆▅▇▇▆▁▇▇█▇▁▇▇
train_loss,▃▃▂▁▂▂▂▂▂▂▁▁▂▂▁▁█▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▁▁▁▁▁▁▂▂▄▂▂▂▂▅▂▂▂▂▆▂▂▂▂▇▂▃▃▃█▃▃▃▃▁▁▁

0,1
epoch,15.0
test_acc,0.97817
test_loss_epoch,0.08874
test_loss_step,0.00678
train_acc_epoch,0.99919
train_acc_step,1.0
train_f_score_epoch,0.84487
train_f_score_step,0.8
train_loss,0.0002
trainer/global_step,2625.0
