Have problems in CK+ Accuracy #18

Closed
juneweng opened this issue Dec 19, 2018 · 3 comments

Comments

@juneweng

Hello, WuJie!
I ran k_fold_train.py with VGG19 but got best_Test_acc: 90.000, while your README reports a VGG19 Test_acc of 94.646%. What can I do to increase the accuracy? Can you help me?
Thank you in advance.

@juneweng
Author

In addition, is Test_acc the mean of best_Test_acc over the 10 folds?

@WuJie1010
Owner

Maybe you can try again, or specifically retrain the folds with low accuracy.
Yes! Test_acc is the mean of best_Test_acc over the 10 folds.
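
To make the averaging concrete, here is a minimal sketch of how a reported Test_acc would follow from the per-fold results. The accuracy values and the helper name `mean_best_test_acc` are hypothetical and only for illustration; they are not taken from k_fold_train.py or from any actual run.

```python
from statistics import mean

# Hypothetical best test accuracy (%) reached in each of the 10 folds.
# These numbers are made up for illustration; they are not real results.
fold_best_acc = [90.0, 95.0, 92.5, 97.5, 100.0, 92.5, 95.0, 90.0, 97.5, 95.0]


def mean_best_test_acc(per_fold_best):
    """Average the best test accuracy of each fold into one Test_acc value."""
    return mean(per_fold_best)


print("Test_acc: %.3f" % mean_best_test_acc(fold_best_acc))
```

Because each fold contributes its own best value, a single fold at 90% is compatible with a higher mean over all folds.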

@yuhao910716

yuhao910716 commented May 17, 2021

```python
import os

# Set before importing numpy/torch so the OpenMP runtime picks it up.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import transforms as transforms  # local transforms module from this repository
from models import *

class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# `path` is the directory you created and `label` is the class you defined.
# I only test the contempt expression here, so I chose 6, which is defined by
# the index into class_names. To change it, edit line 63 of the code.
path = 'CK+48/contempt/'
label = 6
# path = 'CK+48/anger/'
# path = 'CK+48/surprise/'
# path = 'CK+48/happy/'
# path = 'test_imgs/'
# Create the list to store the data and label information
# path = '1_test/'

net = VGG('VGG19')
# ======================================================================
checkpoint = torch.load(os.path.join('trained_model_pt', 'PrivateTest_model.t7'), map_location='cpu')
# How the weights are loaded
net.load_state_dict(checkpoint['net'].state_dict())
# *****************************
net.eval()
all_num = 0
true_num = 0

if __name__ == '__main__':
    print('==> Preparing data..')
    cut_size = 48
    # Convert each 48*48 image into the four corner crops and the center crop
    # of size 44, then test on them. Note that with cut_size = 48 every crop
    # covers the full 48*48 image.
    transform_test = transforms.Compose([
        transforms.TenCrop(cut_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    ])
    print('start', time.time())
    for file in os.listdir(path):
        raw_img = cv2.imread(os.path.join(path, file), cv2.IMREAD_GRAYSCALE)
        raw_img = cv2.resize(raw_img, (48, 48), interpolation=cv2.INTER_CUBIC)
        # Replicate the grayscale channel to build a 3-channel image
        img = raw_img[:, :, np.newaxis]
        img = np.concatenate((img, img, img), axis=2)
        img = Image.fromarray(img)
        inputs = transform_test(img)
        ncrops, c, h, w = inputs.shape
        inputs = inputs.view(-1, c, h, w)
        start = time.time()
        # torch.no_grad() replaces the removed Variable(..., volatile=True)
        with torch.no_grad():
            outputs = net(inputs)
        outputs_avg = outputs.view(1, ncrops, -1).mean(1)  # avg over crops
        score = F.softmax(outputs_avg, dim=1)
        _, predicted = torch.max(outputs_avg, 1)

        if predicted.item() == label:
            true_num += 1
        all_num += 1
        if all_num > 1000:
            break

        # ============================================================
        # Comment out everything below this line to see the effect
        plt.rcParams['figure.figsize'] = (13.5, 5.5)
        axes = plt.subplot(1, 3, 1)
        # print(inputs[0].shape)
        array_05 = inputs[0].numpy().transpose(1, 2, 0)

        plt.imshow(array_05)
        plt.xlabel('Input Image', fontsize=16)
        axes.set_xticks([])
        axes.set_yticks([])
        plt.tight_layout()

        plt.subplots_adjust(left=0.05, bottom=0.2, right=0.95, top=0.9, hspace=0.02, wspace=0.3)
        plt.subplot(1, 3, 2)
        x = list(range(1, len(class_names) + 1))  # the x locations for the groups
        list_data = score.numpy()
        width = 0.4  # the width of the bars
        color_list = ['red', 'orangered', 'darkorange', 'limegreen', 'darkgreen', 'royalblue', 'navy']
        # The third positional argument of plt.bar is the bar width
        plt.bar(x, list_data[0], width, color=color_list)
        plt.title("Classification results ", fontsize=20)
        plt.xlabel(" Expression Category ", fontsize=16)
        plt.ylabel(" Classification Score ", fontsize=16)
        plt.xticks(x, class_names, rotation=45, fontsize=14)
        plt.show()
        # Enable this line to save the corresponding image (call it before plt.show())
        # plt.savefig(os.path.join('images/results/{}.png'.format(os.path.splitext(file)[0])))
        plt.close()
    print("Number of correct predictions:", true_num)
    print("Total number of predictions:", all_num)
```
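
The step in the script that averages the network outputs over the crops before taking the softmax and the argmax can be sketched on its own, without PyTorch. This is only an illustration under stated assumptions: the helper names `average_over_crops`, `softmax`, and `predict` and the logit values are hypothetical and do not come from the repository.

```python
import math


def average_over_crops(crop_logits):
    """Average per-class logits across crops (rows are crops, columns are classes)."""
    n_crops = len(crop_logits)
    return [sum(column) / n_crops for column in zip(*crop_logits)]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(crop_logits):
    """Return (predicted class index, class scores) from per-crop logits."""
    averaged = average_over_crops(crop_logits)
    scores = softmax(averaged)
    predicted = max(range(len(scores)), key=scores.__getitem__)
    return predicted, scores


# Hypothetical logits for 3 crops and 3 classes (made-up values).
crop_logits = [
    [2.0, 0.5, 0.1],
    [1.0, 1.5, 0.2],
    [3.0, 0.0, 0.3],
]
pred, scores = predict(crop_logits)
print("predicted class index:", pred)
```

In this example the second crop on its own would favor class 1, but the average over all crops favors class 0, which is the point of averaging before the argmax.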
