In [139]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim

In [21]:
model = torch.hub.load('pytorch/vision:v0.6.0', 'squeezenet1_0', pretrained=True)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


In [22]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [95]:
import os
directory = "drive/My Drive/csc420_project/data/"

image_dict = {}
for folder in os.listdir(directory):
  print(folder)
  folder_name = os.path.join(directory, folder)
  images = []
  for filename in os.listdir(folder_name):
    input_image = Image.open(os.path.join(folder_name, filename)).convert('RGB')
    if input_image is not None:
      images.append(input_image)
  image_dict[folder] = images

contempt
anger
fear
surprise
disgust
sadness
happy


In [100]:
for param in model.parameters():
  param.requires_grad = False

model.classifier._modules["1"] = nn.Conv2d(512, 10, kernel_size=(1, 1))

In [105]:
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
input_tensor = preprocess(images[0])
input_batch = input_tensor.unsqueeze(0)

In [115]:
data = torchvision.datasets.ImageFolder(root='drive/My Drive/csc420_project/data/', transform=preprocess)

In [134]:
train_data, test_data = torch.utils.data.random_split(data, [800, len(data) - 800])

In [116]:
data.classes

['anger', 'contempt', 'disgust', 'fear', 'happy', 'sadness', 'surprise']

In [136]:
trainloader = torch.utils.data.DataLoader(train_data, batch_size=8,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=8,
                                         shuffle=False, num_workers=2)

In [148]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.classifier.parameters(), lr=0.01, momentum=0.9)

In [149]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 5 == 0:    # print every 50 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 5))
            running_loss = 0.0

print('Finished Training')

[1,     1] loss: 0.231
[1,     6] loss: 1.279
[1,    11] loss: 1.502
[1,    16] loss: 0.958
[1,    21] loss: 1.166
[1,    26] loss: 1.596
[1,    31] loss: 1.632
[1,    36] loss: 1.355
[1,    41] loss: 1.054
[1,    46] loss: 1.608
[1,    51] loss: 1.645
[1,    56] loss: 1.065
[1,    61] loss: 0.897
[1,    66] loss: 1.433
[1,    71] loss: 0.826
[1,    76] loss: 0.809
[1,    81] loss: 0.876
[1,    86] loss: 1.161
[1,    91] loss: 1.156
[1,    96] loss: 0.853
[2,     1] loss: 0.240
[2,     6] loss: 1.032
[2,    11] loss: 0.822
[2,    16] loss: 0.826
[2,    21] loss: 1.385
[2,    26] loss: 1.259
[2,    31] loss: 1.103
[2,    36] loss: 0.965
[2,    41] loss: 1.320
[2,    46] loss: 0.990
[2,    51] loss: 0.967
[2,    56] loss: 0.948
[2,    61] loss: 0.969
[2,    66] loss: 0.730
[2,    71] loss: 1.073
[2,    76] loss: 1.027
[2,    81] loss: 0.637
[2,    86] loss: 0.850
[2,    91] loss: 0.943
[2,    96] loss: 0.747
Finished Training


In [158]:
total, correct = 0, 0
for i, data in enumerate(testloader, 0):
  inputs, labels = data
  outputs = model(inputs)
  total += len(outputs)
  correct += (outputs.argmax(axis=1)==labels).sum()
  print(outputs.argmax(axis=1)==labels)
# outputs_test = model(inputs)
#         loss = criterion(outputs, labels)

tensor([ True,  True,  True,  True,  True, False,  True,  True])
tensor([ True,  True,  True, False,  True,  True,  True,  True])
tensor([ True,  True,  True,  True, False,  True, False,  True])
tensor([ True,  True,  True,  True,  True,  True,  True, False])
tensor([ True,  True,  True,  True,  True, False, False,  True])
tensor([False,  True,  True, False,  True,  True,  True,  True])
tensor([False,  True, False,  True,  True,  True,  True, False])
tensor([ True,  True,  True, False,  True, False,  True,  True])
tensor([False,  True,  True, False,  True, False, False, False])
tensor([ True,  True,  True, False,  True,  True, False,  True])
tensor([ True,  True, False,  True,  True,  True,  True, False])
tensor([ True,  True, False,  True,  True,  True,  True, False])
tensor([False,  True,  True,  True, False,  True,  True, False])
tensor([False,  True,  True,  True,  True,  True,  True,  True])
tensor([ True,  True,  True,  True, False,  True, False,  True])
tensor([ True,  True,  Tr

In [159]:
correct/total

tensor(0.7225)

In [78]:
torch.nn.functional.softmax(model(input_batch)[0], dim=0)

tensor([5.0820e-06, 1.3278e-04, 1.9385e-06, 6.3734e-06, 9.7049e-06, 1.3537e-04,
        2.4458e-07, 1.5868e-06, 4.3819e-07, 1.9357e-07, 4.1149e-05, 1.7066e-07,
        1.9116e-07, 6.1778e-07, 2.0990e-07, 4.7334e-07, 7.6882e-07, 3.2619e-07,
        1.9090e-07, 6.3272e-07, 9.9615e-07, 6.4790e-06, 1.0136e-06, 7.5075e-07,
        4.7457e-07, 2.8162e-07, 1.0196e-05, 3.3439e-06, 2.3744e-05, 5.1369e-04,
        6.9723e-07, 1.0275e-06, 4.5763e-06, 1.3030e-06, 1.4085e-06, 5.1848e-07,
        8.5952e-06, 1.2628e-07, 6.4958e-06, 5.3273e-07, 9.2076e-07, 1.7081e-06,
        4.1623e-07, 1.4271e-05, 4.3673e-07, 1.3392e-06, 9.8664e-07, 4.5255e-06,
        4.3735e-07, 3.2336e-07, 3.9974e-07, 1.1510e-05, 8.6861e-06, 4.1654e-06,
        5.5755e-07, 1.5891e-07, 8.1985e-07, 1.9281e-07, 3.2173e-07, 7.3376e-06,
        1.7030e-06, 3.0784e-06, 3.1243e-05, 1.5649e-05, 6.7643e-06, 3.7818e-06,
        4.2090e-05, 1.2728e-06, 2.0829e-05, 3.2188e-06, 1.9543e-06, 1.7787e-05,
        4.8485e-06, 6.0464e-05, 5.5780e-