In [1]:
import pandas as pd
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader
import torch
from torch.autograd import Variable
from tqdm.notebook import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score
import numpy as np
import matplotlib.pyplot as plt
pd.set_option('display.max_colwidth', None)

In [2]:
df = pd.read_csv("dataset/label.csv", sep=',')
df = df.sample(frac=1).reset_index(drop=True)
label_encoder = LabelEncoder()
df['emotion_encoded'] = label_encoder.fit_transform(df['emotion'])
display(df.head(10))

Unnamed: 0,image,emotion,emotion_encoded
0,122.jpg,Disgust,2
1,105.jpg,Contempt,1
2,131.jpg,Fear,3
3,107.jpg,Fear,3
4,126.jpg,Sad,6
5,81.jpg,Contempt,1
6,98.jpg,Disgust,2
7,58.jpg,Disgust,2
8,146.jpg,Disgust,2
9,85.jpg,Neutral,5


In [3]:
class extractImageFeatureResNetDataSet():
    def __init__(self, data):
        self.data = data
        self.scaler = transforms.Resize([224, 224])
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        image_name = self.data.iloc[idx]['image']
        img_loc = 'dataset/images/'+str(image_name)
        img = Image.open(img_loc)
        t_img = self.normalize(self.to_tensor(self.scaler(img)))
        return t_img, self.data.iloc[idx]['emotion_encoded'].item()

In [4]:
train_ImageDataset_ResNet = extractImageFeatureResNetDataSet(df)
train_ImageDataloader_ResNet = DataLoader(train_ImageDataset_ResNet, batch_size = 1, shuffle=False)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
resnet18.eval()
list(resnet18._modules)



['conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1',
 'layer2',
 'layer3',
 'layer4',
 'avgpool',
 'fc']

In [6]:
resNet18Layer4 = resnet18._modules.get('layer4').to(device)

In [7]:
def get_vector(t_img):
    t_img = Variable(t_img)
    my_embedding = torch.zeros(1, 512, 7, 7)
    def copy_data(m, i, o):
        my_embedding.copy_(o.data)
    h = resNet18Layer4.register_forward_hook(copy_data)
    resnet18(t_img)
    h.remove()
    return my_embedding

In [8]:
xs = []
for t_img, label_encoder in tqdm(train_ImageDataloader_ResNet):
    t_img = t_img.to(device)
    embdg = get_vector(t_img)
    dim = embdg.shape[1]
    seq_length = embdg.view(dim, -1).size(1)
    xs.append(embdg.squeeze().view(-1).detach().cpu().numpy())

  0%|          | 0/152 [00:00<?, ?it/s]

In [9]:
df['embeddings'] = xs

In [10]:
df

Unnamed: 0,image,emotion,emotion_encoded,embeddings
0,122.jpg,Disgust,2,"[0.0, 0.2553425, 0.19864154, 1.4733634, 0.9519596, 0.64513445, 0.31432152, 0.0, 1.395471, 1.0335015, 2.1121247, 1.3558153, 1.0306091, 0.6030865, 0.0, 1.1495869, 1.2043129, 1.8133128, 0.8334929, 0.8014201, 0.5314569, 0.0, 1.7906526, 0.8401391, 1.9037502, 0.0, 0.64398617, 0.2831356, 0.2781834, 1.8679556, 0.68791753, 2.6219616, 0.0, 0.0, 0.0, 0.0, 1.5116826, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16934592, 1.1216514, 1.1230497, 0.4130748, 0.0, 0.0, 0.0, 0.0, 0.8568609, 0.8319073, 0.0, 0.0, 0.0, 0.0, 0.0, 0.78946096, 1.2648194, 0.057054024, 0.07261681, 0.014452144, 0.0, 1.6580561, 1.9901056, 3.2788842, 1.3810068, 1.7355158, 1.3162445, 0.0, 0.56609887, 0.3828603, 1.5526278, 1.3986729, 3.0466795, 2.4032652, 0.0, 0.0, 0.0, 0.0, 0.8846738, 2.908171, 2.4544382, 0.0, 0.0, 0.0, 0.0, 0.22588436, 2.0542476, 2.1666732, 0.0, 0.0, ...]"
1,105.jpg,Contempt,1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1372497, 0.66143113, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.821706, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6852348, 0.6642686, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.41258502, 0.0, 0.0, 0.0, 0.0, 0.13903026, 0.055395205, 0.0, 0.0, 0.0, 0.0, 0.34272125, 0.7604307, 0.18252398, 0.0, 0.0, 0.24835087, 1.0976923, 1.1847737, 0.9900833, 0.1071143, 0.0, 0.0, 0.0, 0.91196114, 1.5358677, 3.0140834, 1.6674527, 0.0, 0.0, 0.23000321, 1.0817962, 1.761068, 2.7192724, 2.0237162, 0.24020718, 0.0, ...]"
2,131.jpg,Fear,3,"[0.0, 0.0, 0.57628125, 1.3390694, 1.0388707, 0.23910704, 0.0, 0.0, 0.0, 0.8647506, 1.7062124, 1.0452168, 0.04049769, 0.0, 0.0, 0.0, 0.1163572, 0.6256913, 0.14081556, 0.0, 0.0, 0.6272202, 0.90286297, 1.4630734, 1.0048512, 0.22691488, 0.0, 0.0, 2.6054792, 3.5703852, 2.625505, 1.1426862, 0.0, 0.0, 0.0, 1.2313861, 1.9026924, 1.6610072, 0.82031226, 0.0, 0.0, 0.0, 0.0, 0.0, 0.099371016, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06426788, 0.20180102, 0.0, 0.0, 0.0, 0.0, 0.19471233, 0.9190455, 0.857705, 0.4372949, 0.0, 0.0, 0.0, 0.4506044, 0.8394079, 0.9464472, 1.0785191, 0.19091985, 0.0, 0.0, 0.39973402, 0.75909495, 1.2885472, 2.1306508, 1.2125717, 0.0, 0.0, 0.0, 0.49136138, 1.3755836, 2.8731916, 1.6558971, 0.0, 0.0, 0.0, 0.0, 0.4794854, 2.10242, 0.9361596, 0.0, 0.0, 0.42649, 0.0, 0.24889906, 1.1857839, 0.64455503, 0.0, 0.0, ...]"
3,107.jpg,Fear,3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13191104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.39914662, 1.1512022, 0.8845013, 0.0, 0.0, 0.0, 0.0, 0.21804394, 0.7788117, 0.6572619, 0.45837012, 1.4333589, 0.41429436, 0.06109178, 0.29737237, 0.6015885, 0.14818351, 0.47677273, 1.0350032, 0.31789604, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16857332, 0.0, 0.0, 0.0, 0.0, 0.0, 0.104557775, 0.7923274, 0.15014245, 0.0, 0.0, 0.0, 0.0, 0.2542536, 2.7655382, 2.247298, 0.0, 0.0, 0.0, 0.0, 1.090465, 5.7420154, 5.536825, 0.39055932, 0.44017512, 0.0, 0.0, 1.5265279, 5.049629, 4.7963743, 0.0, 0.0, ...]"
4,126.jpg,Sad,6,"[0.254521, 0.7833906, 1.4329958, 2.4234712, 2.7830353, 2.0779808, 1.0589365, 0.49415642, 2.0764759, 3.5792384, 4.9451, 4.6842184, 3.4089239, 1.8507133, 0.9441617, 2.2387955, 2.8039715, 3.621048, 3.0185494, 2.9218638, 1.8529799, 1.5351516, 2.9521337, 2.9954863, 4.2604485, 2.2905214, 2.131907, 1.4202769, 2.2275245, 2.5743165, 1.085864, 1.1563818, 0.45282376, 0.13452536, 0.0, 1.3674661, 0.8854226, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04021536, 0.4376132, 0.475554, 0.77974194, 1.3260603, 1.3540992, 0.75375, 0.6243999, 0.21602385, 0.3264669, 0.58598334, 1.2946577, 1.6551384, 0.75511575, 0.0, 0.0, 0.53667456, 0.9533034, 1.6146662, 2.235755, 1.1854101, 0.5106351, 0.6472553, 0.9686109, 1.4936934, 2.0117297, 3.3599954, 1.9740862, 0.2939383, 0.058594458, 0.0, 0.0, 0.0, 2.029226, 1.1539093, 0.5787791, 0.0, 0.0, 0.0, 0.0, 1.8619833, 1.2418151, 0.59064305, 0.133744, 0.0, 0.0, 0.0, 0.7966119, 1.2048721, 0.0, 0.2962505, ...]"
...,...,...,...,...
147,140.jpg,Happy,4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.52219605, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08106333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33772808, 0.38514653, 0.0, 0.0, 0.0, 0.0, 0.0, 0.36529714, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.74965847, 0.7066292, 0.0, 0.0, 0.0, 0.0, 0.0, 0.74649537, 1.0915289, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05954508, 0.47157985, 0.0, 0.0, 0.0, 0.2946463, 0.83879554, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09142997, 0.19708014, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1266229, 0.0, 0.1290232, ...]"
148,22.jpg,Sad,6,"[0.0, 0.12043937, 0.09774172, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.35054097, 0.0, 0.0, 0.0, 1.9294115, 2.096951, 0.91868204, 0.99560654, 0.011301444, 0.0, 0.0, 2.1402333, 2.4741557, 1.8237984, 2.4529355, 1.1543633, 1.1300974, 0.0, 3.1261575, 2.4418476, 0.0, 0.0, 0.0, 0.0, 0.06236291, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2971871, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.63149184, 1.0340137, 0.059428643, 0.0, 0.0, 0.0, 0.5046215, 0.64221454, 2.1903098, 1.3018103, 0.0, 0.5620924, 2.2208946, 4.281111, 2.9443147, 3.111346, 1.4421717, 0.0, 0.6676522, 3.5121164, 6.3839617, 4.5562325, 2.596778, 1.4630884, 0.0, 0.0, 2.7821255, 3.920694, 3.274433, 1.3435962, 0.26117226, 0.0, 0.0, 0.883193, 1.3706087, 1.067751, 0.0, 0.0, ...]"
149,36.jpg,Happy,4,"[0.0, 0.0, 0.27574044, 0.8872288, 0.5183063, 0.1680931, 0.0, 0.0, 0.0, 0.32873583, 0.77120376, 0.06164374, 0.017812328, 0.0, 0.0, 0.06686219, 0.0018352144, 0.16036412, 0.0, 0.57466644, 0.12664662, 1.149339, 0.5791175, 0.046364803, 0.0, 0.0, 0.2311505, 0.0, 1.7511935, 0.681923, 0.34840092, 0.0, 1.1699722, 2.7043588, 2.2648842, 1.1945368, 0.073707804, 0.0, 0.0, 0.13433076, 1.7979662, 2.6155114, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4234977, 1.4383757, 0.0, 0.0, 0.0, 0.0, 0.0, 0.027975226, 0.05252131, 0.0, 0.0, 0.0, 0.11179694, 0.17730598, 0.44666794, 0.15049669, 0.0, 0.0, 0.0, 0.0, 0.018611535, 0.14689958, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08942916, 0.6744919, 0.0, 0.5484563, 0.0, 0.0, 0.0, 0.0, 0.66108954, 0.20812981, 4.1413918, 2.9597166, 0.40110478, 0.0, 0.0, 0.9673716, 1.6068832, 4.554695, 4.2051773, 1.2517941, 0.0, 0.0, 0.0, 1.3950381, 0.035232432, 0.14210023, ...]"
150,34.jpg,Disgust,2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06861038, 0.0, 0.0, 0.0, 0.16194466, 0.069424435, 0.0, 1.9770504, 1.390541, 1.4879557, 3.0581143, 4.3385024, 5.278859, 3.1015756, 2.5319312, 0.90231156, 0.8358623, 2.2009654, 3.8435068, 5.880035, 4.4053445, 1.2679435, 0.062279947, 0.41081202, 1.6320806, 3.7388797, 5.0445933, 3.7897153, 0.0, 0.0, 0.0, 0.039169613, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0615579, 0.5420773, 0.4438106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25820604, 0.0, 0.0, 0.0, 0.104720734, 0.0, 0.0, 0.0, 0.0, 0.86825275, 0.31625432, 1.1938372, 0.382185, 0.0, 0.0, 0.0, 1.4284322, 0.93152314, 1.5557094, 1.498081, 0.27840212, 0.0, 0.0, 0.65432495, 1.4282649, 0.60199577, 0.76820666, ...]"


In [11]:
X_train, X_test, y_train, y_test = train_test_split(df['embeddings'], df['emotion_encoded'], test_size=0.2)

In [12]:
rf_classifier = RandomForestClassifier(n_estimators=200, random_state=42)
rf_classifier.fit(np.stack(X_train), np.stack(y_train))

0,1,2
,n_estimators,200
,criterion,'gini'
,max_depth,
,min_samples_split,2
,min_samples_leaf,1
,min_weight_fraction_leaf,0.0
,max_features,'sqrt'
,max_leaf_nodes,
,min_impurity_decrease,0.0
,bootstrap,True


In [15]:
y_pred_test = rf_classifier.predict(np.stack(X_test))
accuracy_score(y_test, y_pred_test)

0.0

In [16]:
ss = np.stack(X_test)
print(type(ss[0]))

<class 'numpy.ndarray'>


In [17]:
%matplotlib inline
test_id = "124.jpg"
img = np.asarray(Image.open('dataset/images/{}'.format(test_id)))
plt.imshow(img)
img_embedding = df[df['image']==test_id]['embeddings'].to_numpy()
predict = rf_classifier.predict(np.stack(img_embedding))
actual = df[df['image']==test_id]['emotion_encoded'].to_numpy()
id2label = {4:"happy",6:"sad",7:"suprised",5:"neutral",1:"contempt",2:"disgust",3:"fear"}


print("Predicted: {}".format(id2label[predict[0]]))
print("Actual: {}".format(id2label[actual[0]]))

Predicted: happy
Actual: happy
