In [11]:
import timm
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
import torch.nn as nn
from torchvision import transforms
from torch import from_numpy

from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt
import os

from data.datamodule import BirdDataModule
from models.timm import TimmModel
from callbacks.callbacks import ImagePredictionLogger

### Load & transform data

In [12]:
transforms = transforms.Compose([
                                    from_numpy,
                                    #transforms.Resize(224),
                                    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                   ])

In [13]:
dm = BirdDataModule(root_data_dir='../data/', batch_size=16, num_workers=4, transforms={'transform': transforms, 'target_transform': None}, seed = 5, test_size = 0.2)

In [14]:
dm.setup()

#### Model definition

In [15]:
num_classes = 20

timm_model = timm.create_model('densenet201', pretrained=True, num_classes=num_classes) 
# timm_model = timm.create_model('densenet201', pretrained=True, num_classes=num_classes) 

print(timm_model.get_classifier())

Linear(in_features=1920, out_features=20, bias=True)


In [16]:
model = TimmModel(timm_model, num_classes, learning_rate=1e-3)

#### Wandb logger

In [17]:
# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))

In [18]:
wandb_logger = WandbLogger(project="birds")

# Initialize callbacks
callbacks = [
    #EarlyStopping(monitor="val_loss", min_delta=0.00, patience=3, verbose=False, mode="max"),
    #LearningRateMonitor(),
    ImagePredictionLogger(val_samples),
    ModelCheckpoint(dirpath="./checkpoints", monitor="val_loss", filename="bird-{epoch:02d}-{val_loss:.2f}")
]

  rank_zero_warn(


#### Training pipeline

In [19]:
trainer = pl.Trainer(max_epochs=10,
                        logger=wandb_logger,
                        callbacks=callbacks,
                        enable_progress_bar=True)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(model, dm)
# wandb.finish()


  | Name     | Type               | Params
------------------------------------------------
0 | model    | DenseNet           | 18.1 M
1 | accuracy | MulticlassAccuracy | 0     
2 | f_score  | MulticlassF1Score  | 0     
------------------------------------------------
18.1 M    Trainable params
0         Non-trainable params
18.1 M    Total params
72.525    Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-03.
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 16, 224, 224] to have 3 channels, but got 16 channels instead