In [1]:
import timm
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
import torch.nn as nn
from torchvision import transforms
from torch import from_numpy

from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt
import os

from data_models.datamodule import BirdDataModule
from models.timm import TimmModel
from callbacks.callbacks import ImagePredictionLogger

  from .autonotebook import tqdm as notebook_tqdm


### Load & transform data

In [2]:
# Normalisation statistics for the single input channel, obtained from the
# training dataset. Samples are already converted to tensors before reaching
# this pipeline, so no ToTensor / from_numpy / Resize step is applied here.
NORM_MEAN = [-28.35204512488038]
NORM_STD = [9.804596690228543]

data_transforms = transforms.Compose([transforms.Normalize(NORM_MEAN, NORM_STD)])

# Labels need no preprocessing: an empty Compose returns its input unchanged.
target_transforms = transforms.Compose([])

In [3]:
# Seed Python, NumPy and torch RNGs (and dataloader workers) once, up front, so
# that the classifier-head initialisation and data shuffling are reproducible
# on Restart & Run All. The same value is handed to the DataModule, which
# presumably uses it for its train/test split (test_size=0.2) — confirm in
# BirdDataModule.
SEED = 5
pl.seed_everything(SEED, workers=True)

dm = BirdDataModule(
    root_data_dir='../data/',
    batch_size=32,
    num_workers=8,
    transforms={
        'transform': data_transforms,
        'target_transform': target_transforms,
    },
    seed=SEED,
    test_size=0.2,
)

In [4]:
# Prepare the DataModule before any dataloader is requested. Later cells rely on
# attributes populated here (e.g. `dm.coder`, presumably the fitted label
# encoder, and the train/validation splits — confirm in BirdDataModule).
dm.setup()

In [5]:
# Species known to the label encoder. Assuming a scikit-learn-style encoder,
# the position of each name in `classes_` is its integer label.
dm.coder.classes_

array(['Acrocephalus arundinaceus', 'Acrocephalus melanopogon',
       'Acrocephalus scirpaceus', 'Alcedo atthis', 'Anas platyrhynchos',
       'Anas strepera', 'Ardea purpurea', 'Botaurus stellaris',
       'Charadrius alexandrinus', 'Ciconia ciconia', 'Circus aeruginosus',
       'Coracias garrulus', 'Dendrocopos minor', 'Fulica atra',
       'Gallinula chloropus', 'Himantopus himantopus',
       'Ixobrychus minutus', 'Motacilla flava', 'Porphyrio porphyrio',
       'Tachybaptus ruficollis'], dtype=object)

In [6]:
# Standalone dataset built through the DataModule's `dataset` factory from the
# same data root. Not referenced by the training cells below — presumably
# kept for manual inspection of samples.
data_root = '../data/'
dataset = dm.dataset(data_root)

### Model definition

In [7]:
# Uncomment to list the timm architectures that ship pretrained weights.
# timm.list_models(pretrained = True)

In [8]:
# Derive the class count from the fitted label encoder instead of hard-coding
# it, so the classifier head always matches the labels produced by the
# DataModule (currently the 20 species listed by `dm.coder.classes_`).
num_classes = len(dm.coder.classes_)
assert num_classes > 0, "label encoder has no classes; did dm.setup() run?"

# Pretrained MobileNetV3-Small from timm. `num_classes` replaces the classifier
# head; `in_chans=1` builds a single-channel stem to match the one-channel
# normalisation above (timm adapts the pretrained first-conv weights).
# Alternative backbone tried previously: 'densenet201'.
timm_model = timm.create_model(
    'mobilenetv3_small_100',
    pretrained=True,
    num_classes=num_classes,
    in_chans=1,
)

print(timm_model.get_classifier())

Linear(in_features=1024, out_features=20, bias=True)


In [9]:
# Wrap the timm backbone in the project's LightningModule.
LEARNING_RATE = 1e-3
model = TimmModel(timm_model, num_classes, learning_rate=LEARNING_RATE)

### Weights & Biases (wandb) logger and callbacks

In [10]:
# One fixed validation batch, required by the custom ImagePredictionLogger
# callback to log image predictions.
val_loader = dm.val_dataloader()
val_samples = next(iter(val_loader))

In [11]:
wandb_logger = WandbLogger(project="birds")

# Callbacks:
#  - ImagePredictionLogger: logs image predictions for the fixed `val_samples`
#    batch.
#  - ModelCheckpoint: keeps the checkpoint with the LOWEST `val_loss`
#    (save_top_k defaults to 1) under ./checkpoints. The directory is shared
#    across runs, hence Lightning's "exists and is not empty" warning.
#
# Optional callbacks, not enabled (each needs its import from
# pytorch_lightning.callbacks):
#   EarlyStopping(monitor="val_loss", min_delta=0.00, patience=3, mode="min")
#   LearningRateMonitor()
# A loss must be monitored with mode="min"; with mode="max" EarlyStopping would
# treat a rising loss as improvement and stop while the model is still learning.
callbacks = [
    ImagePredictionLogger(val_samples),
    ModelCheckpoint(
        dirpath="./checkpoints",
        monitor="val_loss",
        mode="min",
        filename="bird-{epoch:02d}-{val_loss:.2f}",
    ),
]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpgraciae[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Training pipeline

In [12]:
# Train for up to 30 epochs with W&B logging and the callbacks defined earlier.
# accelerator="auto" uses a GPU (or other accelerator) when one is present and
# falls back to CPU otherwise. devices=1 keeps training single-process, which
# is what works reliably inside a notebook kernel.
trainer = pl.Trainer(
    max_epochs=30,
    accelerator="auto",
    devices=1,
    logger=wandb_logger,
    callbacks=callbacks,
    enable_progress_bar=True,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Run training/validation, drawing data from the DataModule passed explicitly
# by keyword.
trainer.fit(model, datamodule=dm)


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name     | Type               | Params
------------------------------------------------
0 | model    | MobileNetV3        | 1.5 M 
1 | accuracy | MulticlassAccuracy | 0     
2 | f_score  | MulticlassF1Score  | 0     
------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
6.152     Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 0:  80%|███████▉  | 348/436 [07:11<01:49,  1.24s/it, loss=0.384, v_num=e8ca]Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 1:  80%|███████▉  | 348/436 [07:35<01:55,  1.31s/it, loss=0.338, v_num=e8ca]Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 2:  80%|███████▉  | 348/436 [07:54<02:00,  1.36s/it, loss=0.232, v_num=e8ca]Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 3:  80%|███████▉  | 348/436 [08:01<02:01,  1.38s/it, loss=0.265, v_num=e8ca]Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 4:  80%|███████▉  | 348/436 [10:40<02:41,  1.84s/it, loss=0.17, v_num=e8ca]  Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 5:  80%|███████▉  | 348/436 [10:56<02:46,  1.89s/it, loss=0.144, v_num=e8ca] Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 6:  80%|███████▉  | 348/436 [11:51<02:59,  2.04s/it, loss=0.18, v_num=e8ca]  Adjusting learning rate of group 0 to 1.0000e-03.
Epoch 7:  80%|███████▉ 

In [None]:
# Mark the W&B run as finished and flush pending uploads. In a notebook the run
# otherwise stays open, and later logging would go into the same run.
wandb.finish()