Test if image exists
First, download https://github.com/falloutdurham/beginners-pytorch-deep-learning/tree/master/chapter2 and run download.py in that directory(some of them won't be downloaded successfully, which shouldn't matter).
Then, upload train & test & val folders to Google Drive. Finally, in Colab Notebooks, mount Google Drive in the Files Menu on the left.


Note, in this notebook, in Runtime, select "Run all" is almost always needed, as much of the code relies on the result from previous python running.

Firstly, mount Google Drive:

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# PIL is a widely used python image library
import PIL
from PIL import Image, ImageFile

import torchvision
from torchvision import transforms
from torch.utils import data

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# this is to prevent image file to be too large
ImageFile.LOAD_TRUNCATED_IMAGES = True 

workspace_path = "./drive/MyDrive/image_classification_test/"
train_data_path = workspace_path + "/train/"
val_data_path = workspace_path + "/val/"
test_data_path = workspace_path + "/test/"

# check if this image exists
img = Image.open(val_data_path + "/fish/100_1422.JPG")
print(img.size)

(512, 342)


Run the following to set train_data, test_data and validation_data. The torchvision helps do the image preprocessing here.
Search for knowledge about train set, test set and validation set if needed.

In [3]:
# this function is very important, otherwise quite a few images won't be opened successfully, causing script's runtime error.
def check_image(path):
  try:
    Image.open(path)
    return True
  except:
    return False

img_transforms = transforms.Compose([
    transforms.Resize((64, 64)), # resize image
    transforms.ToTensor(), # store image data in tensor
    transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
    # the above normalization follows distribution of ImageNet dataset
    ])

train_data = torchvision.datasets.ImageFolder(root = train_data_path, transform = img_transforms, is_valid_file=check_image)

val_data = torchvision.datasets.ImageFolder(root = val_data_path, transform = img_transforms, is_valid_file=check_image)

test_data = torchvision.datasets.ImageFolder(root = test_data_path, transform = img_transforms, is_valid_file=check_image)

# load data in a batch
batch_size = 64
train_data_loader = data.DataLoader(train_data, batch_size = batch_size)
val_data_loader = data.DataLoader(val_data, batch_size = batch_size)
test_data_loader = data.DataLoader(test_data, batch_size = batch_size)
# check how many images get loaded
print(len(train_data_loader.dataset))
print(len(val_data_loader.dataset))
print(len(test_data_loader.dataset))


803
110
160


Define the network structure:

In [4]:
class simple_net(nn.Module):
    def __init__(self):
        super(simple_net, self).__init__()
        # search for the definition of nn.Linear for more info
        # 12288 is for 64 * 64 * 3 (image size)
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, 2)
    # search for the definition of forward for more info
    def forward(self, x):
        x = x.view(-1, 12288)
        # search for why relu is used in neural networks
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

my_model = simple_net()
# search for Adam optimizer for more info if needed
optimizer = optim.Adam(my_model.parameters(), lr = 0.001)

# copy the model to device
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
my_model.to(device)

simple_net(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

Note the following code relies on result from previous python running. Click on the "Runtime" and select "Run all".

Train the model and print the loss result per epoch.

In [5]:
def loss_update(model, batch, loss_fn, device, check_result):
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    output = model(inputs)
    loss = loss_fn(output, targets)
    num_current_correct = 0
    if (check_result):
        result = torch.eq(torch.max(F.softmax(output, dim = 1), dim = 1)[1], targets)
        num_current_correct = torch.sum(result).item()
    return loss, num_current_correct

def train(model, optimizer, loss_fn, train_data_loader, val_data_loader, epochs, device):
  for epoch in range(epochs):
    training_loss = 0.0
    valid_loss = 0.0
    # this is to set model in training mode
    model.train()
    check_result = False
    # training process
    for batch in train_data_loader:
      optimizer.zero_grad()
      loss = loss_update(model, batch, loss_fn, device, check_result)[0]
      loss.backward()
      optimizer.step()
      training_loss += loss.data.item() * batch[0].shape[0]
    training_loss /= len(train_data_loader.dataset)

    # this is to set model in evaluation mode
    model.eval()
    check_result = True
    num_correct = 0
    num_examples = 0
    for batch in val_data_loader:
      loss, num_current_correct= loss_update(model, batch, loss_fn, device, check_result)
      valid_loss += loss.data.item() * batch[0].shape[0]
      num_correct += num_current_correct
      num_examples += batch[0].shape[0]
    valid_loss /= len(val_data_loader.dataset)

    print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, num_correct / num_examples))

# modify the number of epochs to check how loss and accuracy changes with more training
train(my_model, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, val_data_loader, 20, device)

Epoch: 0, Training Loss: 3.74, Validation Loss: 2.82, accuracy = 0.40
Epoch: 1, Training Loss: 2.02, Validation Loss: 1.38, accuracy = 0.43
Epoch: 2, Training Loss: 1.28, Validation Loss: 0.82, accuracy = 0.71
Epoch: 3, Training Loss: 0.64, Validation Loss: 0.65, accuracy = 0.70
Epoch: 4, Training Loss: 0.43, Validation Loss: 0.69, accuracy = 0.67
Epoch: 5, Training Loss: 0.40, Validation Loss: 0.63, accuracy = 0.75
Epoch: 6, Training Loss: 0.29, Validation Loss: 0.66, accuracy = 0.69
Epoch: 7, Training Loss: 0.28, Validation Loss: 0.65, accuracy = 0.73
Epoch: 8, Training Loss: 0.24, Validation Loss: 0.63, accuracy = 0.72
Epoch: 9, Training Loss: 0.20, Validation Loss: 0.66, accuracy = 0.70
Epoch: 10, Training Loss: 0.18, Validation Loss: 0.65, accuracy = 0.72
Epoch: 11, Training Loss: 0.17, Validation Loss: 0.66, accuracy = 0.72
Epoch: 12, Training Loss: 0.14, Validation Loss: 0.66, accuracy = 0.73
Epoch: 13, Training Loss: 0.13, Validation Loss: 0.68, accuracy = 0.71
Epoch: 14, Train

After training, run prediction.

In [6]:
labels = ['cat', 'fish']
img = Image.open(val_data_path + "/fish/100_1422.JPG")
img = img_transforms(img).to(device)
img = torch.unsqueeze(img, 0)
my_model.eval()
prediction = F.softmax(my_model(img), dim = 1)
print("prediction = ")
print(prediction)
prediction = prediction.argmax()
print(labels[prediction])

prediction = 
tensor([[9.2736e-06, 9.9999e-01]], grad_fn=<SoftmaxBackward>)
fish


Run the following to save model:

In [7]:
model_path = workspace_path + "/simple_net.pth"
torch.save(my_model, model_path)
# reload
my_model = torch.load(model_path)

# only save the parameters
model_dict_path = workspace_path + "/simple_net_dict.pth"
torch.save(my_model.state_dict(), model_dict_path)
my_model = simple_net()
my_model_state_dict = torch.load(model_dict_path)
my_model.load_state_dict(my_model_state_dict)

<All keys matched successfully>