In [None]:
# Import Libraries
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import numpy as np
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from logging import Logger

In [2]:
!protoc --version

libprotoc 3.5.1


In [49]:
transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [50]:
from tqdm import tqdm_notebook

In [51]:
import random, os

In [52]:
train_set = datasets.ImageFolder("../data/dataset/train/", transform = transformations )
valid_set = datasets.ImageFolder("../data/dataset/valid/", transform = transformations)

In [53]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=32, shuffle=True)

In [54]:
model = models.densenet161(pretrained=True)
# Turn off training for their parameters
for param in model.parameters():
    param.requires_grad = False

  nn.init.kaiming_normal(m.weight.data)


In [55]:
num_labels = 4

In [56]:
class LinearClassifier(nn.Module):
    def __init__(self, in_feature_number, class_number):
        super(LinearClassifier, self).__init__()

        self.f1 = nn.Linear(in_feature_number, 1024)
        self.f2 = nn.Linear(1024, 512)
        self.f3 = nn.Linear(512, num_labels)
        self.soft_max = nn.LogSoftmax(dim=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.f1(x)
        x = self.relu(x)
        x = self.f2(x)
        x = self.relu(x)
        x = self.f3(x)
        output = self.soft_max(x)
        return output

In [57]:
# Replace default classifier with new classifier
classifier_ = LinearClassifier(in_feature_number=model.classifier.in_features, class_number=num_labels)
model.classifier = classifier_

In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move model to the device specified above
_ = model.to(device)

In [59]:
criterion = nn.NLLLoss()
# Set the optimizer function using torch.optim as optim library
optimizer = optim.Adam(model.classifier.parameters())

In [60]:
PATH_CHECKPOINT = "./models/check_"

In [69]:
from tensorboardX import SummaryWriter
writer = SummaryWriter('runs')

TypeError: __new__() got an unexpected keyword argument 'serialized_options'

In [None]:
epochs = 10
num_of_examples = 5000

for epoch in range(epochs):
    train_loss = 0
    val_loss = 0
    accuracy = 0
    
    # Training the model
    model.train()
    counter = 0
    print(f" -> TRAINING EPOCH {epoch}")
    for inputs, labels in tqdm_notebook(train_loader):
        # Move to device
        inputs, labels = inputs.to(device), labels.to(device)
        # Clear optimizers
        optimizer.zero_grad()
        # Forward pass
        output = model.forward(inputs)
        # Loss
        loss = criterion(output, labels)
        # Calculate gradients (backpropogation)
        loss.backward()
        # Adjust parameters based on gradients
        optimizer.step()
        # Add the loss to the training set's rnning loss
        train_loss += loss.item()*inputs.size(0)
        
        # Print the progress of our training

        counter += 1
        
        if counter >= num_of_examples:
            break
    
        
    # Evaluating the model
    model.eval()
    counter = 0
    # Tell torch not to calculate gradients
    
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH_CHECKPOINT + str(epoch) + ".pt")
    
    print(f"-> EVALUATING EPOCH {epoch}")

    with torch.no_grad():
        for inputs, labels in tqdm_notebook(valid_loader):
            # Move to device
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            output = model.forward(inputs)
            # Calculate Loss
            valloss = criterion(output, labels)
            # Add loss to the validation set's running loss
            val_loss += valloss.item()*inputs.size(0)
            
            # Since our model outputs a LogSoftmax, find the real 
            # percentages by reversing the log function
            output = torch.exp(output)
            # Get the top class of the output
            top_p, top_class = output.topk(1, dim=1)
            # See how many of the classes were correct?
            equals = top_class == labels.view(*top_class.shape)
            # Calculate the mean (get the accuracy for this batch)
            # and add it to the running accuracy for this epoch
            accuracy += torch.sum(equals.type(torch.FloatTensor)).item()
            
            # Print the progress of our evaluation
            counter += 1
            if counter >= num_of_examples:
                break
            
    # Get the average loss for the entire epoch
    train_loss = train_loss/(train_loader.batch_size*(counter))
    valid_loss = val_loss/(valid_loader.batch_size*(counter))
    # Print out the information
    print('Accuracy: ', accuracy/(valid_loader.batch_size*counter))
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(epoch, train_loss, valid_loss))

 -> TRAINING EPOCH 0


HBox(children=(IntProgress(value=0, max=17519), HTML(value='')))

-> EVALUATING EPOCH 0


HBox(children=(IntProgress(value=0, max=7483), HTML(value='')))

Accuracy:  0.865875
Epoch: 0 	Training Loss: 0.483130 	Validation Loss: 0.371526
 -> TRAINING EPOCH 1


HBox(children=(IntProgress(value=0, max=17519), HTML(value='')))

-> EVALUATING EPOCH 1


HBox(children=(IntProgress(value=0, max=7483), HTML(value='')))

In [None]:
for image in valid_loader:
    print(image)
    break

In [65]:
PATH = "./model.pt"

torch.save(model, PATH)


  "type " + obj.__name__ + ". It won't be checked "


In [67]:
print('Accuracy: ', accuracy/(valid_loader.batch_size*counter))


Accuracy:  0.89071875
