# Applying Transfer Learning on Pre-trained SqueezeNet Model for Plant Disease Detection Dataset

References:

* https://towardsdatascience.com/a-beginners-tutorial-on-building-an-ai-image-classifier-using-pytorch-6f85cb69cba7

* https://pytorch.org/hub/pytorch_vision_squeezenet/
* https://pytorch.org/docs/stable/torchvision/models.html


In [2]:
import torch

In [6]:
from torchvision import models, transforms, datasets

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [48]:
import time

In [10]:
transformations = transforms.Compose([
    #transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [8]:
import os

new_dataset_path = "D:\\Machine Learning Datasets\\Tomato Plant Disease Classification Dataset - Simplified"
trainset_path = os.path.join(new_dataset_path,'train')
valset_path = os.path.join(new_dataset_path,'val')
testset_path = os.path.join(new_dataset_path,'test')

In [11]:
train_set = datasets.ImageFolder(trainset_path, transform = transformations)
val_set = datasets.ImageFolder(valset_path, transform = transformations)

In [12]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size =32, shuffle=True)

In [27]:
import torch.nn as nn
import torch.optim as optim

In [36]:
# Get pretrained model using torchvision.models as models library
model = models.squeezenet1_1(pretrained=True)

In [37]:
# Turn off training for their parameters
for param in model.parameters():
    param.requires_grad = False

In [38]:
model.eval()

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

In [26]:
model.classifier

Sequential(
  (0): Dropout(p=0.5, inplace=False)
  (1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
  (2): ReLU(inplace=True)
  (3): AdaptiveAvgPool2d(output_size=(1, 1))
)

In [29]:
num_classes = 8

In [39]:
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Conv2d(512, num_classes, kernel_size=1),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(13)
)
# model.forward = lambda x: model.classifier(model.features(x)).view(x.size(0), num_classes)

In [40]:
model.num_classes = num_classes

In [54]:
model.to(device)

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

Reference:

* https://discuss.pytorch.org/t/fine-tuning-squeezenet/3855/7
* https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

In [46]:
# Set the error function using torch.nn as nn library
criterion = nn.CrossEntropyLoss()
# Set the optimizer function using torch.optim as optim library
# optimizer = optim.Adam(model.classifier.parameters())
optimizer = optim.SGD(model.classifier.parameters(), lr=0.001, momentum=0.9)

In [47]:
def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    lr = init_lr * (0.1**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

In [58]:
dset_sizes = {} 
dset_sizes['train'] = len(train_set)
dset_sizes['val'] = len(val_set)

In [59]:
dset_sizes

{'train': 10416, 'val': 578}

In [63]:
import copy

In [64]:
def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
    since = time.time()

    best_model = model
    best_acc = 0.0

    # run for given number of epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                optimizer = lr_scheduler(optimizer, epoch)
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0

            if phase == 'train':
                phase_ = train_loader
            else:
                phase_ = val_loader

            # Run through all data in mini batches
            for data in phase_:
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                inputs, labels = inputs.to(device), labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)

                # calculating the loss
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics for printing
                #running_loss += loss.data[0]
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dset_sizes[phase]
            epoch_acc = running_corrects.double() / dset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # save model if it performed better than
            # any other previous model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model = copy.deepcopy(model)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    return best_model

In [65]:
model_conv = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=10)

Epoch 0/9
----------
LR is set to 0.001
train Loss: 7.0746 Acc: 0.0688
val Loss: 7.2094 Acc: 0.1817

Epoch 1/9
----------
train Loss: 7.1793 Acc: 0.1066
val Loss: 7.2093 Acc: 0.1834

Epoch 2/9
----------
train Loss: 7.1412 Acc: 0.1022
val Loss: 7.2093 Acc: 0.1834

Epoch 3/9
----------
train Loss: 7.1155 Acc: 0.0741
val Loss: 7.2093 Acc: 0.1799

Epoch 4/9
----------
train Loss: 7.1018 Acc: 0.0853
val Loss: 7.2094 Acc: 0.1782

Epoch 5/9
----------
train Loss: 7.1159 Acc: 0.0878
val Loss: 7.2093 Acc: 0.1817

Epoch 6/9
----------
train Loss: 7.2042 Acc: 0.1256
val Loss: 7.2094 Acc: 0.1817

Epoch 7/9
----------
LR is set to 0.0001
train Loss: 7.1875 Acc: 0.1024
val Loss: 7.2095 Acc: 0.1713

Epoch 8/9
----------
train Loss: 7.0767 Acc: 0.0559
val Loss: 7.2029 Acc: 0.1436

Epoch 9/9
----------
train Loss: 6.9441 Acc: 0.0526
val Loss: 7.2046 Acc: 0.0882

Training complete in 12m 44s
Best val Acc: 0.183391


In [44]:
# epochs = 10
# for epoch in range(epochs):
#     train_loss = 0
#     val_loss = 0
#     accuracy = 0
    
#     # Training the model
#     model.train()
#     counter = 0
#     for inputs, labels in train_loader:
#         # Move to device
#         inputs, labels = inputs.to(device), labels.to(device)
#         # Clear optimizers
#         optimizer.zero_grad()
#         # Forward pass
#         output = model.forward(inputs)
#         # Loss
#         loss = criterion(output, labels)
#         # Calculate gradients (backpropogation)
#         loss.backward()
#         # Adjust parameters based on gradients
#         optimizer.step()
#         # Add the loss to the training set's rnning loss
#         train_loss += loss.item()*inputs.size(0)
        
#         # Print the progress of our training
#         counter += 1
#         print(counter, "/", len(train_loader))
        
#     # Evaluating the model
#     model.eval()
#     counter = 0
#     # Tell torch not to calculate gradients
#     with torch.no_grad():
#         for inputs, labels in val_loader:
#             # Move to device
#             inputs, labels = inputs.to(device), labels.to(device)
#             # Forward pass
#             output = model.forward(inputs)
#             # Calculate Loss
#             valloss = criterion(output, labels)
#             # Add loss to the validation set's running loss
#             val_loss += valloss.item()*inputs.size(0)
            
#             # Since our model outputs a LogSoftmax, find the real 
#             # percentages by reversing the log function
#             output = torch.exp(output)
#             # Get the top class of the output
#             top_p, top_class = output.topk(1, dim=1)
#             # See how many of the classes were correct?
#             equals = top_class == labels.view(*top_class.shape)
#             # Calculate the mean (get the accuracy for this batch)
#             # and add it to the running accuracy for this epoch
#             accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
            
#             # Print the progress of our evaluation
#             counter += 1
#             print(counter, "/", len(val_loader))
    
#     # Get the average loss for the entire epoch
#     train_loss = train_loss/len(train_loader.dataset)
#     valid_loss = val_loss/len(val_loader.dataset)
#     # Print out the information
#     print('Accuracy: ', accuracy/len(val_loader))
#     print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(epoch, train_loss, valid_loss))

1 / 326
2 / 326
3 / 326
4 / 326
5 / 326
6 / 326
7 / 326
8 / 326
9 / 326
10 / 326
11 / 326
12 / 326
13 / 326
14 / 326
15 / 326
16 / 326
17 / 326
18 / 326
19 / 326
20 / 326
21 / 326
22 / 326
23 / 326
24 / 326
25 / 326
26 / 326
27 / 326
28 / 326
29 / 326
30 / 326
31 / 326
32 / 326
33 / 326
34 / 326
35 / 326
36 / 326
37 / 326
38 / 326
39 / 326
40 / 326
41 / 326
42 / 326
43 / 326
44 / 326
45 / 326
46 / 326
47 / 326
48 / 326
49 / 326
50 / 326
51 / 326
52 / 326
53 / 326
54 / 326
55 / 326
56 / 326
57 / 326
58 / 326
59 / 326
60 / 326
61 / 326
62 / 326
63 / 326
64 / 326
65 / 326
66 / 326
67 / 326
68 / 326
69 / 326
70 / 326
71 / 326
72 / 326
73 / 326
74 / 326
75 / 326
76 / 326
77 / 326
78 / 326
79 / 326
80 / 326
81 / 326
82 / 326
83 / 326
84 / 326
85 / 326
86 / 326
87 / 326
88 / 326
89 / 326
90 / 326
91 / 326
92 / 326
93 / 326
94 / 326
95 / 326
96 / 326
97 / 326
98 / 326
99 / 326
100 / 326
101 / 326
102 / 326
103 / 326
104 / 326
105 / 326
106 / 326
107 / 326
108 / 326
109 / 326
110 / 326
111 / 32

KeyboardInterrupt: 