Models not converging in image classification problem #1989

Closed
ecm200 opened this issue May 6, 2021 · 12 comments

ecm200 commented May 6, 2021

The problem

I am trying to train a Torchvision GoogLeNet model on a fine-grained classification problem using Ignite. I have successfully trained this model, and additional models from Torchvision and various other libraries (e.g. pytorchcv, timm), on this dataset using a simple training script that I developed.

However, when training with Ignite using all the same parameters as my simple training script, I am getting non-convergence (it appears the network is not training at all) and I haven't managed to figure out why yet. It does seem that somehow the model is not being updated, but I have triple-checked my implementation against the examples I followed and, as far as I can see, it should be working.

Results from simple training script

Here are the metrics during training from the simple training script I developed; they show convergence during the initial epochs:

Epoch 0/39
----------
train Loss: 5.2044 Acc: 0.0252
test Loss: 4.8963 Acc: 0.1293

Epoch 1/39
----------
train Loss: 4.7696 Acc: 0.1009
test Loss: 4.1853 Acc: 0.2579

Epoch 2/39
----------
train Loss: 4.2409 Acc: 0.1832
test Loss: 3.5428 Acc: 0.3428

Epoch 3/39
----------
train Loss: 3.7754 Acc: 0.2579
test Loss: 3.0363 Acc: 0.4237

Epoch 4/39
----------
train Loss: 3.4005 Acc: 0.3143
test Loss: 2.5909 Acc: 0.4924

Results from Ignite

Here are example outputs of the epoch metrics during training for both the training and validation datasets. The initial training loss is similar for both approaches, but by the 2nd epoch the simple training script shows the network has learnt something, as the loss has been reduced. However, in the Ignite version there is no improvement in loss over 5 epochs.

Train - Epoch: 0001  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Valid - Epoch: 0001  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Train - Epoch: 0002  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Valid - Epoch: 0002  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Train - Epoch: 0003  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Valid - Epoch: 0003  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Train - Epoch: 0004  Accuracy: 0.00 Precision: 0.00 Recall: 0.00 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Valid - Epoch: 0004  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Train - Epoch: 0005  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.02 Loss: 5.33
Valid - Epoch: 0005  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33

Ignite Training Script

This is the script I have developed, using the CIFAR-10 example as a guide:

#!/usr/bin/env python
# coding: utf-8


from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from torchvision import models

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, Engine
from ignite.metrics import Accuracy, Loss, RunningAverage, ConfusionMatrix, Recall, Fbeta, Precision, TopKCategoricalAccuracy
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.contrib.handlers.param_scheduler import LRScheduler
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, WeightsHistHandler, GradsHistHandler
from ignite.contrib.engines import common

import timm

import os, shutil

import pandas as pd
import matplotlib.pylab as plt
import numpy as np

# Local modules
from cub_tools.transforms import makeDefaultTransforms
from cub_tools.data import create_dataloaders

#################################################
## Script runtime options

# Model settings
model_name = 'googlenet'
model_args = {'pretrained' : True}
model_func = models.googlenet

# Directory settings
root_dir = '/home/edmorris/projects/image_classification/caltech_birds'
data_dir = os.path.join(root_dir,'data/images')
working_dir = os.path.join(root_dir,'models/classification', 'ignite_'+model_name)
clean_up = True

# Training parameters
criterion = None # default is nn.CrossEntropyLoss
optimizer = None # default is optim.SGD
scheduler = None # default is StepLR
batch_size = 64
num_workers = 4
num_epochs = 40
early_stopping_patience = 5

# Image parameters
img_crop_size=224
img_resize=256
#################################################

def initialize(model_func, model_args, criterion=None, optimizer=None, scheduler=None, is_torchvision=True, num_classes=200):
    
     # Setup the model object
    model = model_func(**model_args)
    if is_torchvision:
        # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Setup loss criterion and optimizer
    if (optimizer == None):
        optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

    if criterion == None:
        criterion = nn.CrossEntropyLoss()

    # Setup learning rate scheduler
    if scheduler == None:
        scheduler = StepLR(optimizer=optimizer,step_size=7, gamma=0.1)

    return model, optimizer, criterion, scheduler


def create_trainer(model, optimizer, criterion, lr_scheduler):

     # Define any training logic for iteration update
    def train_step(engine, batch):

        x, y = batch[0].to(device), batch[1].to(device)

        model.train()
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        return loss.item()

    # Define trainer engine
    trainer = Engine(train_step)

    return trainer

def create_evaluator(model, metrics, tag='val'):

    # Evaluation step function
    @torch.no_grad()
    def evaluate_step(engine: Engine, batch):
        model.eval()
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = model(x)
        return y_pred, y

    # Create the evaluator object
    evaluator = Engine(evaluate_step)

    # Attach the metrics
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    return evaluator

def evaluate_model():
    epoch = trainer.state.epoch
    train_state = train_evaluator.run(train_loader)
    tr_accuracy = train_state.metrics['accuracy']
    tr_precision = train_state.metrics['precision']
    tr_recall = train_state.metrics['recall']
    tr_f1 = train_state.metrics['f1']
    tr_topKCatAcc = train_state.metrics['topKCatAcc']
    tr_loss = train_state.metrics['loss']
    print("Train - Epoch: {:0>4}  Accuracy: {:.2f} Precision: {:.2f} Recall: {:.2f} F1-score: {:.2f} TopKCatAcc: {:.2f} Loss: {:.2f}"
          .format(epoch, tr_accuracy, tr_precision, tr_recall, tr_f1, tr_topKCatAcc, tr_loss))

    val_state = evaluator.run(val_loader)
    val_accuracy = val_state.metrics['accuracy']
    val_precision = val_state.metrics['precision']
    val_recall = val_state.metrics['recall']
    val_f1 = val_state.metrics['f1']
    val_topKCatAcc = val_state.metrics['topKCatAcc']
    val_loss = val_state.metrics['loss']
    print("Valid - Epoch: {:0>4}  Accuracy: {:.2f} Precision: {:.2f} Recall: {:.2f} F1-score: {:.2f} TopKCatAcc: {:.2f} Loss: {:.2f}"
          .format(epoch, val_accuracy, val_precision, val_recall, val_f1, val_topKCatAcc, val_loss))


print('')
print('***********************************************')
print('**                                           **')
print('**         CUB 200 DATASET TRAINING          **')
print('**        --------------------------         **')
print('**                                           **')
print('**    Image Classification of 200 North      **')
print('**          American Bird Species            **')
print('**                                           **')
print('**   PyTorch Ignite Based Training Script    **')
print('**                                           **')
print('**            Ed Morris (c) 2021             **')
print('**                                           **')
print('***********************************************')
print('')
print('[INFO] Model and Directories')
print('[PARAMS] Model Name:: {}'.format(model_name))
print('[PARAMS] Model Dir:: {}'.format(working_dir))
print('[PARAMS] Data Dir:: {}'.format(data_dir))
print('')
print('[INFO] Training Parameters')
print('[PARAMS] Batch Size:: {}'.format(batch_size))
print('[PARAMS] Number of image processing CPUs:: {}'.format(num_workers))
print('[PARAMS] Number of epochs to train:: {}'.format(num_epochs))
print('')
print('[INFO] Image Settings')
print('[PARAMS] Image Size:: {}'.format(img_crop_size))
print('')

## SETUP DIRS
# Clean up the output directory by removing it if desired
if clean_up:
    shutil.rmtree(working_dir)
# Create the output directory for results
os.makedirs(working_dir, exist_ok=True)

## SETUP DEVICE
# Setup the device to run the computations
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## SETUP DATALOADERS
print('[INFO] Beginning setup of training for model {}'.format(model_name))
print('[INFO] Setting up dataloaders for train and test sets.')
# Get data transforms
data_transforms = makeDefaultTransforms(img_crop_size=img_crop_size, img_resize=img_resize)
# Make train and test loaders
train_loader, val_loader = create_dataloaders(data_transforms=data_transforms, data_dir=data_dir, batch_size=batch_size, num_workers=num_workers)

## SETUP MODEL OBJECTS
# Get the model, optimizer, loss criterion and learning rate scheduler objects
print('[INFO] Getting model {} from library and setting up loss criterion, optimizer and learning rate scheduler...'.format(model_name), end='')
model, optimizer, criterion, lr_scheduler = initialize(
    model_func=model_func, 
    model_args=model_args, 
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler, 
    is_torchvision=True, 
    num_classes=200
    )
print('done')
# send model to the device for training
print('[INFO] Sending model {} to device {}...'.format(model_name, device), end='')
model = model.to(device)
print('done')

## SETUP TRAINER AND EVALUATOR
# Setup model trainer and evaluator
print('[INFO] Creating Ignite training, evaluation objects and logging...', end='')
trainer = create_trainer(model=model, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler)
metrics = {
    'accuracy':Accuracy(),
    'recall':Recall(average=True),
    'precision':Precision(average=True),
    'f1':Fbeta(beta=1),
    'topKCatAcc':TopKCategoricalAccuracy(k=5),
    'loss':Loss(criterion)
}
evaluator = create_evaluator(model, metrics=metrics)
train_evaluator = create_evaluator(model, metrics=metrics, tag='train')
# Add validation logging
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_model)
# Add TB logging
evaluators = {"training": train_evaluator, "validation": evaluator}
tb_logger = common.setup_tb_logging(
    output_path=os.path.join(working_dir,'tb_logs'), 
    trainer=trainer,
    optimizers=optimizer,
    evaluators=evaluators
    )
print('done')

## TRAIN
# Train the model
print('[INFO] Executing model training...')
trainer.run(train_loader, max_epochs=num_epochs)
print('[INFO] Model training is complete.')

It has one additional dependency which is a dataloader creation function:

import os

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

def create_dataloaders(data_transforms, data_dir, batch_size, num_workers, shuffle=None, test_batch_size=2):

    batch_size = {'train' : batch_size, 'test' : test_batch_size}

    if shuffle == None:
        shuffle = {'train' : True, 'test' : False}

    # Setup data loaders with augmentation transforms
    image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                    for x in ['train', 'test']}
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size[x],
                                 shuffle=shuffle[x], num_workers=num_workers)
                for x in ['train', 'test']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
    class_names = image_datasets['train'].classes

    print('***********************************************')
    print('**            DATASET SUMMARY                **')
    print('***********************************************')
    for dataset in dataset_sizes.keys():
        print(dataset,' size:: ', dataset_sizes[dataset],' images')
    print('Number of classes:: ', len(class_names))
    print('***********************************************')
    print('[INFO] Created data loaders.')

    return dataloaders['train'], dataloaders['test']

ecm200 added the question label May 6, 2021

ecm200 (Author) commented May 6, 2021

After some more investigation, I noted that the following had been added to the training step in the CIFAR-10 example.

scaler = GradScaler(enabled=with_amp)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

This was compared to my original example of:

loss.backward()
optimizer.step()
lr_scheduler.step()
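
For context, in the CIFAR-10 example those scaler calls sit inside an autocast block; a minimal sketch of that pattern (assuming torch.cuda.amp and a with_amp flag, not my exact script) looks like this:

from torch.cuda.amp import autocast, GradScaler

with_amp = True  # hypothetical flag, as used in the CIFAR-10 example
scaler = GradScaler(enabled=with_amp)

def train_step(engine, batch):
    model.train()
    x, y = batch[0].to(device), batch[1].to(device)
    optimizer.zero_grad()
    with autocast(enabled=with_amp):
        y_pred = model(x)              # forward pass in mixed precision when enabled
        loss = criterion(y_pred, y)
    scaler.scale(loss).backward()      # scale the loss to avoid FP16 underflow, then backprop
    scaler.step(optimizer)             # unscale the gradients and step the optimizer
    scaler.update()                    # adjust the scale factor for the next iteration
    return loss.item()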

It appears the gradient scaling is doing the trick.
This is strange, because in my original simple training script, I am not using gradient scaling at all, and the model trained perfectly well.

import copy
import os
import time

import torch

# Note: save_pickle is a small local helper (not shown) that pickles the history dict to disk.

def train_model(model, criterion, optimizer, scheduler, device, dataloaders, dataset_sizes, 
                num_epochs=25, return_history=False, log_history=True, working_dir='output'):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    history = {'epoch' : [], 'train_loss' : [], 'test_loss' : [], 'train_acc' : [], 'test_acc' : []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            history['epoch'].append(epoch)
            history[phase+'_loss'].append(epoch_loss)
            history[phase+'_acc'].append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        
        if log_history:
            save_pickle(history,os.path.join(working_dir,'model_history.pkl'))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    print('Returning object of best model.')
    model.load_state_dict(best_model_wts)
    
    if return_history:
        return model, history
    else:
        return model

vfdev-5 (Collaborator) commented May 6, 2021

Thanks for the report @ecm200 !
I can try to reproduce the issue on colab with cifar10.
However, this part of the code may be a bit problematic:

def initialize(model_func, model_args, criterion=None, optimizer=None, scheduler=None, is_torchvision=True, num_classes=200):
    
     # Setup the model object
    model = model_func(**model_args)
    if is_torchvision:
        # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Setup loss criterion and optimizer
    if (optimizer == None):
        optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

    if criterion == None:
        criterion = nn.CrossEntropyLoss()

    # Setup learning rate scheduler
    if scheduler == None:
        scheduler = StepLR(optimizer=optimizer,step_size=7, gamma=0.1)

    return model, optimizer, criterion, scheduler

optimizer = None

model, optimizer, criterion, lr_scheduler = initialize(
    model_func=model_func, 
    model_args=model_args, 
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler, 
    is_torchvision=True, 
    num_classes=200
    )
model = model.to(device)

In Python we usually compare with None as var is None. When we build the optimizer, it should take the parameters of the final model (already sent to the device, etc.). Currently, I think PyTorch can still work the other way round, but it is not recommended:

model = ...
- optimizer = optim.SGD(model.parameters())
model = model.to("cuda")
+ optimizer = optim.SGD(model.parameters())

vfdev-5 (Collaborator) commented May 6, 2021

#1989 (comment)

Actually, the grad scaler and AMP are optional and not usually required.

ecm200 (Author) commented May 6, 2021

#1989 (comment)

Actually, the grad scaler and AMP are optional and not usually required.

Hi @vfdev-5,

Thanks very much for your observation.

I think I just found something else I missed, which was a difference between my simple script and the Ignite examples.

In the model update phase of the forward training step, the batch inference, the loss calculation, and the backward propagation of the loss all take place inside the with torch.set_grad_enabled(True) context. This was not in my original training step for the Ignite training function.

# Iterate over data.
for inputs, labels in dataloaders[phase]:
    inputs = inputs.to(device)
    labels = labels.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward
    # track history if only in train
    with torch.set_grad_enabled(phase == 'train'):
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # backward + optimize only if in training phase
        if phase == 'train':
            loss.backward()
            optimizer.step()

I think this means that unless the backward propagation is done inside that context, the model parameters cannot be updated, which would tally with the fact that training up to this point did not appear to be updating the model parameters, as the loss was not reducing.

Thanks for the report @ecm200 !
I can try to reproduce the issue on colab with cifar10.
However, this part of the code may be a bit problematic: [...]
In Python we usually compare with None as var is None. When we build the optimizer, it should take the parameters of the final model (already sent to the device, etc.).

@vfdev-5 Good spot sir!

I checked in my original simple script and I have indeed pushed the model to the compute device before constructing the optimizer, whereas here I did not. I have fixed that by simply adding the .to(device) call in the initialize function.

def initialize(device, model_func, model_args, criterion=None, optimizer=None, scheduler=None, is_torchvision=True, num_classes=200):
    
     # Setup the model object
    model = model_func(**model_args)
    if is_torchvision:
        # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    # Setup loss criterion and optimizer
    if (optimizer == None):
        optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

    if criterion == None:
        criterion = nn.CrossEntropyLoss()

    # Setup learning rate scheduler
    if scheduler == None:
        scheduler = StepLR(optimizer=optimizer,step_size=7, gamma=0.1)

    return model, optimizer, criterion, scheduler

vfdev-5 (Collaborator) commented May 6, 2021

Well, the with torch.set_grad_enabled(True) context is not required either, IMO, as the model usually contains parameters that require gradients, and by default gradients are enabled and updated...

Let me try to reproduce the issue here with CIFAR10: https://colab.research.google.com/drive/12tSCyNEvkJXYvEOt82dnCu-KZLWfEQVL?usp=sharing

ecm200 (Author) commented May 6, 2021

@vfdev-5 ok that would be great.
I implemented the model-to-device change, to make sure it occurs before the optimizer is constructed.

I have just tried running the model again without the GradScaler, and I have the same issue again.
If I turn the GradScaler back on, I start to see convergence.

So without GradScaler I get this behaviour:

[INFO] Executing model training...
Train - Epoch: 0001  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.32
Valid - Epoch: 0001  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Train - Epoch: 0002  Accuracy: 0.00 Precision: 0.00 Recall: 0.00 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33
Valid - Epoch: 0002  Accuracy: 0.01 Precision: 0.00 Recall: 0.01 F1-score: 0.00 TopKCatAcc: 0.03 Loss: 5.33

And with GradScaler, I get this behaviour:

Train - Epoch: 0001  Accuracy: 0.02 Precision: 0.01 Recall: 0.02 F1-score: 0.01 TopKCatAcc: 0.07 Loss: 5.23
Valid - Epoch: 0001  Accuracy: 0.02 Precision: 0.01 Recall: 0.02 F1-score: 0.01 TopKCatAcc: 0.07 Loss: 5.22
Train - Epoch: 0002  Accuracy: 0.05 Precision: 0.06 Recall: 0.05 F1-score: 0.04 TopKCatAcc: 0.17 Loss: 5.11
Valid - Epoch: 0002  Accuracy: 0.06 Precision: 0.05 Recall: 0.05 F1-score: 0.04 TopKCatAcc: 0.17 Loss: 5.10
Train - Epoch: 0003  Accuracy: 0.12 Precision: 0.11 Recall: 0.12 F1-score: 0.10 TopKCatAcc: 0.31 Loss: 4.99
Valid - Epoch: 0003  Accuracy: 0.12 Precision: 0.12 Recall: 0.12 F1-score: 0.10 TopKCatAcc: 0.32 Loss: 4.96
Train - Epoch: 0004  Accuracy: 0.17 Precision: 0.18 Recall: 0.17 F1-score: 0.15 TopKCatAcc: 0.42 Loss: 4.84
Valid - Epoch: 0004  Accuracy: 0.17 Precision: 0.17 Recall: 0.17 F1-score: 0.14 TopKCatAcc: 0.43 Loss: 4.80
Train - Epoch: 0005  Accuracy: 0.22 Precision: 0.25 Recall: 0.22 F1-score: 0.19 TopKCatAcc: 0.50 Loss: 4.67
Valid - Epoch: 0005  Accuracy: 0.21 Precision: 0.21 Recall: 0.21 F1-score: 0.18 TopKCatAcc: 0.50 Loss: 4.61

For reference, these are the updated initialize and create_trainer functions I am now using:

def initialize(device, model_func, model_args, criterion=None, optimizer=None, scheduler=None, is_torchvision=True, num_classes=200):
    
     # Setup the model object
    model = model_func(**model_args)
    if is_torchvision:
        # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    # Setup loss criterion and optimizer
    if (optimizer == None):
        optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

    if criterion == None:
        criterion = nn.CrossEntropyLoss()

    # Setup learning rate scheduler
    if scheduler == None:
        scheduler = StepLR(optimizer=optimizer,step_size=7, gamma=0.1)

    return model, optimizer, criterion, scheduler


def create_trainer(model, optimizer, criterion, lr_scheduler):

     # Define any training logic for iteration update
    def train_step(engine, batch):

        # Get the images and labels for this batch
        x, y = batch[0].to(device), batch[1].to(device)

        # Set the model into training mode
        model.train()

        # Zero parameter gradients
        optimizer.zero_grad()

        # Update the model
        if with_grad_scale:
            with autocast(enabled=with_amp):
                y_pred = model(x)
                loss = criterion(y_pred, y)
            scaler = GradScaler(enabled=with_amp)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.set_grad_enabled(True):
                y_pred = model(x)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
            lr_scheduler.step()

        return loss.item()

    # Define trainer engine
    trainer = Engine(train_step)

    return trainer

vfdev-5 (Collaborator) commented May 6, 2021

OK, I see: the issue you have is with the LR scheduler, which is called by Ignite every iteration instead of every epoch (as was probably initially intended). This means that the LR dropped to ~0 and thus no training happened.

You can do the following with the LR scheduler:

trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())
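
A minimal sketch of the corrected setup, using the names from your script (the train step no longer steps the scheduler; the handler does it once per epoch):

def train_step(engine, batch):
    x, y = batch[0].to(device), batch[1].to(device)
    model.train()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # no lr_scheduler.step() here any more
    return loss.item()

trainer = Engine(train_step)

# StepLR now decays the LR every step_size epochs, as originally intended
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())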

ecm200 (Author) commented May 6, 2021

OK, I see: the issue you have is with the LR scheduler, which is called by Ignite every iteration instead of every epoch (as was probably initially intended). This means that the LR dropped to ~0 and thus no training happened.

You can do the following with the LR scheduler:

trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())

That would do it!
Doh!
Thank you so much.

So, to be clear, take out the lr_scheduler.step() call from the train_step, and replace it with the event handler above?

vfdev-5 (Collaborator) commented May 6, 2021

So, to be clear, take out the lr_scheduler.step() call from the train_step, and replace it with the event handler above?

Yes, exactly

ecm200 (Author) commented May 6, 2021

So, to be clear, take out the lr_scheduler.step() call from the train_step, and replace it with the event handler above?

Yes, exactly

Who'd have thought that updating the learning rate every iteration instead of every epoch would cause such an issue...
Well, of course it would, given the learning rate would become vanishingly small pretty quickly!

Thanks again, great spot! I've been looking for that for about a day and you nailed it.

vfdev-5 (Collaborator) commented May 6, 2021

@ecm200 updating the LR per iteration can also be a way to train models with warm-up and decay, like here: https://pytorch.org/ignite/contrib/handlers.html#piecewise-linear-scheduler

But, yes, it can be tricky to check everything in a single run. Maybe we can think of providing a debug mode that would dump all the important things to check, like loss, grads, LR, etc.
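
In the meantime, something like this (a minimal sketch using the trainer and optimizer from your script) would have made the vanishing LR visible straight away:

# Log the current LR and the last batch loss at the end of every epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_lr_and_loss(engine):
    current_lr = optimizer.param_groups[0]['lr']
    print('Epoch: {:0>4}  LR: {:.6f}  Last batch loss: {:.4f}'.format(
        engine.state.epoch, current_lr, engine.state.output))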

Thanks again, great spot! I've been looking for that for about a day and you nailed it.

Glad that could help.

ecm200 closed this as completed May 6, 2021

ecm200 (Author) commented May 6, 2021

@vfdev-5
A debug mode sounds like a great idea.
I think in the end it was just a matter of getting used to the idea of event handlers and making sure things happen at the correct time.
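
For anyone else landing here, the thing to internalise is which event a handler is attached to; a minimal sketch with the objects from my script:

# The attached event controls how often a handler fires
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())      # once per epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_model)            # metrics every epoch
trainer.add_event_handler(Events.COMPLETED, lambda _: print('[INFO] Training run finished.'))  # once, at the end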

It is happily training away now with that modification:

[INFO] Executing model training...
Epoch: 0001  TrAcc: 0.016 ValAcc: 0.020 TrPrec: 0.007 ValPrec: 0.026 TrRec: 0.016 ValRec: 0.020 TrF1: 0.009 ValF1: 0.013 TrTopK: 0.069 ValTopK: 0.078 TrLoss: 5.226 ValLoss: 5.225
Epoch: 0002  TrAcc: 0.054 ValAcc: 0.065 TrPrec: 0.055 ValPrec: 0.070 TrRec: 0.054 ValRec: 0.067 TrF1: 0.042 ValF1: 0.053 TrTopK: 0.167 ValTopK: 0.183 TrLoss: 5.116 ValLoss: 5.106
Epoch: 0003  TrAcc: 0.113 ValAcc: 0.119 TrPrec: 0.117 ValPrec: 0.117 TrRec: 0.113 ValRec: 0.122 TrF1: 0.093 ValF1: 0.097 TrTopK: 0.310 ValTopK: 0.318 TrLoss: 4.992 ValLoss: 4.976
Epoch: 0004  TrAcc: 0.177 ValAcc: 0.176 TrPrec: 0.176 ValPrec: 0.179 TrRec: 0.177 ValRec: 0.180 TrF1: 0.151 ValF1: 0.147 TrTopK: 0.418 ValTopK: 0.429 TrLoss: 4.847 ValLoss: 4.814
Epoch: 0005  TrAcc: 0.219 ValAcc: 0.213 TrPrec: 0.228 ValPrec: 0.229 TrRec: 0.219 ValRec: 0.216 TrF1: 0.191 ValF1: 0.184 TrTopK: 0.485 ValTopK: 0.508 TrLoss: 4.683 ValLoss: 4.628
Epoch: 0006  TrAcc: 0.244 ValAcc: 0.245 TrPrec: 0.300 ValPrec: 0.279 TrRec: 0.244 ValRec: 0.248 TrF1: 0.217 ValF1: 0.213 TrTopK: 0.542 ValTopK: 0.566 TrLoss: 4.503 ValLoss: 4.425
Epoch: 0007  TrAcc: 0.283 ValAcc: 0.277 TrPrec: 0.333 ValPrec: 0.351 TrRec: 0.283 ValRec: 0.279 TrF1: 0.257 ValF1: 0.252 TrTopK: 0.586 ValTopK: 0.606 TrLoss: 4.326 ValLoss: 4.221
Epoch: 0008  TrAcc: 0.286 ValAcc: 0.284 TrPrec: 0.338 ValPrec: 0.342 TrRec: 0.286 ValRec: 0.286 TrF1: 0.260 ValF1: 0.258 TrTopK: 0.593 ValTopK: 0.616 TrLoss: 4.316 ValLoss: 4.216
Epoch: 0009  TrAcc: 0.281 ValAcc: 0.282 TrPrec: 0.358 ValPrec: 0.346 TrRec: 0.281 ValRec: 0.284 TrF1: 0.258 ValF1: 0.254 TrTopK: 0.590 ValTopK: 0.615 TrLoss: 4.286 ValLoss: 4.182
Epoch: 0010  TrAcc: 0.285 ValAcc: 0.290 TrPrec: 0.333 ValPrec: 0.349 TrRec: 0.285 ValRec: 0.293 TrF1: 0.257 ValF1: 0.262 TrTopK: 0.592 ValTopK: 0.621 TrLoss: 4.281 ValLoss: 4.174
Epoch: 0011  TrAcc: 0.289 ValAcc: 0.292 TrPrec: 0.355 ValPrec: 0.347 TrRec: 0.289 ValRec: 0.295 TrF1: 0.264 ValF1: 0.264 TrTopK: 0.588 ValTopK: 0.627 TrLoss: 4.261 ValLoss: 4.144
