In [1]:
from __future__ import print_function
import argparse
import numpy  as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms

from data_loaders import Plain_Dataset, eval_data_dataloader
from deep_emotion import Deep_Emotion
from generate_data import Generate_data


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import cv2
import os
import math
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
data_path = 'ck_new/'

In [3]:
def Train(epochs,train_loader,val_loader,criterion,optmizer,device):
    '''
    Training Loop
    '''
    print("===================================Start Training===================================")
    for e in range(epochs):
        train_loss = 0
        validation_loss = 0
        train_correct = 0
        val_correct = 0
        # Train the model  #
        net.train()
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optmizer.zero_grad()
            outputs = net(data)
            loss = criterion(outputs,labels)
            loss.backward()
            optmizer.step()
            train_loss += loss.item()
            _, preds = torch.max(outputs,1)
            train_correct += torch.sum(preds == labels.data)

        #validate the model#
        net.eval()
        for data,labels in val_loader:
            data, labels = data.to(device), labels.to(device)
            val_outputs = net(data)
            val_loss = criterion(val_outputs, labels)
            validation_loss += val_loss.item()
            _, val_preds = torch.max(val_outputs,1)
            val_correct += torch.sum(val_preds == labels.data)

        train_loss = train_loss/len(train_dataset)
        train_acc = train_correct.double() / len(train_dataset)
        validation_loss =  validation_loss / len(validation_dataset)
        val_acc = val_correct.double() / len(validation_dataset)
        print('Epoch: {} \tTraining Loss: {:.8f} \tValidation Loss {:.8f} \tTraining Acuuarcy {:.3f}% \tValidation Acuuarcy {:.3f}%'
                                                           .format(e+1, train_loss,validation_loss,train_acc * 100, val_acc*100))

    torch.save(net.state_dict(),'deep_emotion-ckplus-{}-{}-{}.pt'.format(epochs,batchsize,lr))
    print("===================================Training Finished===================================")


## create train,vali,test csv file
70%, 10%, 20%

In [4]:
np.random.seed(100)
with open('ck_plus.csv','r') as csvfile:
    reader = csv.reader(csvfile)
    rows = [row for row in reader]



In [8]:
len(rows)*0.2

65.4

In [10]:
num_each = 9

random_array = np.arange(len(rows))
cnt_an = 0
cnt_di = 0
cnt_fe = 0
cnt_ha = 0
cnt_sa = 0
cnt_su = 0
cnt_ne = 0
loop = True

cnt = 0

st = []
st_lbl = []
idx = []
while loop:
    temp = np.random.choice(random_array)
    item = rows[temp][0]
    lbl = int(rows[temp][1])
    if item not in st:
        
        if lbl == 0 and cnt_an<num_each :
            cnt_an = cnt_an+1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl==1 and cnt_di<num_each:
            cnt_di = cnt_di+1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl==2 and cnt_fe<num_each:
            cnt_fe = cnt_fe+1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl==3 and cnt_ha<num_each:
            cnt_ha = cnt_ha +1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl == 4 and cnt_sa<num_each:
            cnt_sa = cnt_sa +1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl==5 and cnt_su<num_each:
            cnt_su = cnt_su+1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)
        elif lbl==6 and cnt_ne<num_each:
            cnt_ne = cnt_ne+1
            st_lbl.append(lbl)
            st.append(item)
            idx.append(temp)

#         print(len(st))
        if cnt_an==num_each and cnt_di==num_each and cnt_fe==num_each and cnt_ha==num_each and cnt_sa==num_each and cnt_su==num_each and cnt_ne==num_each:
            break
        

### test.csv

In [15]:
with open('test_ckplus.csv', 'w', encoding='UTF8',newline='') as test_ckplus:
    test_writer = csv.writer(test_ckplus)
    for i in range(len(rows)):
        if i in idx:
            test_writer.writerow(rows[i])

### vali.csv

In [None]:
num_val = 32
idx_val = []
cnt_val = 0
t=0
while True:
    t = t+1
    temp = np.random.choice(random_array)
    if temp not in idx:
        if cnt_val>=num_val:
            break
        cnt_val = cnt_val+1
        idx_val.append(temp)
#     print(t,cnt_val)

In [27]:
with open('vali_ckplus.csv', 'w', encoding='UTF8',newline='') as vali_ckplus:
    vali_writer = csv.writer(vali_ckplus)
    for i in range(len(rows)):
        if i in idx_val:
            vali_writer.writerow(rows[i])

### train.csv

In [29]:
with open('train_ckplus.csv', 'w', encoding='UTF8',newline='') as train_ckplus:
    train_writer = csv.writer(train_ckplus)
    for i in range(len(rows)):
        if i not in idx and i not in idx_val:
            train_writer.writerow(rows[i])

## create dataset and dataloader

In [4]:
class Jaffe_Dataset(Dataset):
    def __init__(self,csv_file,img_dir,datatype,transform):
        with open(csv_file,'r') as csvfile:
            rd = csv.reader(csvfile)
            self.data = [row for row in rd]

        self.img_dir = img_dir
        self.transform = transform
        self.datatype = datatype

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
#         img = Image.open(self.img_dir+self.data[idx][0])

        img = cv2.imread(self.img_dir+self.data[idx][0],0)
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#         print(img.shape)
        lables = np.array(int(self.data[idx][1]))
        lables = torch.from_numpy(lables).long()

        if self.transform :
            img = self.transform(img)
        return img,lables


In [5]:
traincsv_file = 'train_ckplus.csv'
validationcsv_file = 'vali_ckplus.csv'
train_img_dir = 'ck_new/'
validation_img_dir = 'ck_new/'

In [6]:
batchsize = 4
transformation= transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
# transformation= transforms.Compose([transforms.ToTensor()])
train_dataset= Jaffe_Dataset(csv_file=traincsv_file, img_dir = train_img_dir, datatype = 'train', transform = transformation)
validation_dataset= Jaffe_Dataset(csv_file=validationcsv_file, img_dir = validation_img_dir, datatype = 'val', transform = transformation)
train_loader= DataLoader(train_dataset,batch_size=batchsize,shuffle = True,num_workers=0)
val_loader=   DataLoader(validation_dataset,batch_size=batchsize,shuffle = True,num_workers=0)

In [7]:
epochs = 50
lr = 0.001

net = Deep_Emotion()
net.to(device)
criterion= nn.CrossEntropyLoss()
optmizer= optim.Adam(net.parameters(),lr= lr)

In [8]:
Train(epochs, train_loader, val_loader, criterion, optmizer, device)

Epoch: 1 	Training Loss: 0.42105620 	Validation Loss 0.42464104 	Training Acuuarcy 39.655% 	Validation Acuuarcy 50.000%
Epoch: 2 	Training Loss: 0.28998634 	Validation Loss 0.30538505 	Training Acuuarcy 63.362% 	Validation Acuuarcy 56.250%
Epoch: 3 	Training Loss: 0.21413986 	Validation Loss 0.27801306 	Training Acuuarcy 72.414% 	Validation Acuuarcy 65.625%
Epoch: 4 	Training Loss: 0.14461304 	Validation Loss 0.28765207 	Training Acuuarcy 82.328% 	Validation Acuuarcy 62.500%
Epoch: 5 	Training Loss: 0.13017361 	Validation Loss 0.20455906 	Training Acuuarcy 83.190% 	Validation Acuuarcy 65.625%
Epoch: 6 	Training Loss: 0.10906308 	Validation Loss 0.23160118 	Training Acuuarcy 87.500% 	Validation Acuuarcy 65.625%
Epoch: 7 	Training Loss: 0.08493465 	Validation Loss 0.20749995 	Training Acuuarcy 89.224% 	Validation Acuuarcy 71.875%
Epoch: 8 	Training Loss: 0.08826700 	Validation Loss 0.21346084 	Training Acuuarcy 86.638% 	Validation Acuuarcy 68.750%
Epoch: 9 	Training Loss: 0.06266985 	Val

In [14]:
testcsv_file = 'test_ckplus.csv'
test_img_dir = 'ck_new/'
test_dataset= Jaffe_Dataset(csv_file=testcsv_file, img_dir = test_img_dir, datatype = 'test', transform = transformation)
test_loader= DataLoader(test_dataset,batch_size=63,shuffle = True,num_workers=0)

total = []
net.eval()
with torch.no_grad():
    for data,labels in test_loader:
        output = net(data)
        pred = F.softmax(output,dim=1)
        result = torch.argmax(pred,1)
        wrong = torch.where(result != labels,torch.tensor([1.]),torch.tensor([0.])) 
        acc = 1-(torch.sum(wrong)/70)
        total.append(acc.item())

    print('Accuracy of the network on the test images: %d %%' % (100 * np.mean(total)))

Accuracy of the network on the test images: 67 %


### confusion matrix

In [10]:
def confusion_matrix(preds, labels, conf_matrix):
#     preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

Emotion_kinds = 7
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)


            
net.eval()
with torch.no_grad():
    for data,labels in test_loader:
        output = net(data)
        pred = F.softmax(output,dim=1)
        result = torch.argmax(pred,1)
        conf_matrix = confusion_matrix(result, labels, conf_matrix)

Emotion=7#这个数值是具体的分类数，大家可以自行修改
labels = ['angry', 'contempt', 'disgust', 'fear', 'happy','sadness','surprise']#每种类别的标签
# 0=anger, 1=contempt, 2=disgust, 3=fear, 4=happy, 5=sadness, 6=surprise
# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值，如果数值超过这个，就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()
