-
Notifications
You must be signed in to change notification settings - Fork 1
/
classifierSVM_GA.m
137 lines (121 loc) · 4.66 KB
/
classifierSVM_GA.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
% SVM classification with genetic-algorithm (GA) parameter search.
% Data classification for project experiment 2.
%
% Setup section: resets the session and defines the inputs used by the rest
% of the script -- the listing of raw .mat files, how many consecutive files
% form one recording group, and the candidate feature sets.
clc;
clear;
% Shared training set. GA_SVM is called without arguments, so it presumably
% reads these globals -- confirm in GA_SVM.
global trainData trainLabel

% Raw-data folder (its name says: no normalization preprocessing,
% overlap 100, len 300).
floderPath = 'C:\Users\zhj\Desktop\zhj20161219\无归一化预处理后,overlap为100,len为300';
fullPath = fullfile(floderPath, '*.mat');
dirout = dir(fullPath);
num = numel(dirout);   % number of .mat files found in the folder
repeatNum = 2;         % consecutive files whose name prefixes form one recording group

% Candidate feature sets: each field lists the feature-extraction function
% names that are concatenated column-wise for that set.
FunName = struct( ...
    'a', {{'feature_MAV', 'feature_WL', 'feature_ZC', 'feature_SSC'}}, ...
    'b', {{'feature_RMS', 'feature_AR5'}}, ...
    'c', {{'feature_SE', 'feature_WL', 'feature_CC5', 'feature_AR5'}}, ...
    'd', {{'feature_WT_WL'}}, ...
    'e', {{'feature_DFT_MAV2'}}, ...
    'f', {{'feature_TDPSD'}}, ...
    'g', {{'feature_DFT_MAV2', 'feature_DFT_DASDV', 'feature_WT_LOG', 'feature_WAMP'}});

TotalAcc = [];         % final result of every feature set, appended one row per set
%% Evaluate every feature set in FunName
% For each feature set, every recording group in turn is used for training
% (SVM C/gamma chosen by GA_SVM) and the resulting model is tested on every
% group of the opposite hand. The mean of all test accuracies of a feature
% set is appended to TotalAcc (one row per field of FunName, in field
% order), and the chosen [C, gamma] of each training group is stored in
% C_G.<setKey>_<groupName>.
% Uses floderPath, dirout, num, repeatNum, FunName, TotalAcc and the global
% trainData/trainLabel declared earlier in this script.

% Folder holding cached features (<FUN>-<group>.mat, variable featSaved whose
% last column is the label) and thresholds (<FUN>-<group>-thresh.mat,
% variable thresh).
featDir = 'C:\Users\zhj\Desktop\zhj20161219\特征保存\';
featFile = @(fun, dataName, suffix) [featDir, fun, '-', dataName, suffix, '.mat'];

%% Split the recordings into left-hand and right-hand groups
% The file list does not depend on the feature set, so this is done once.
% Every repeatNum consecutive files form one group, named by concatenating
% the first two characters of each file name; groups alternate
% left, right, left, right, ...
if num == 0 || mod(num, 2*repeatNum) ~= 0
    error('classifierSVM_GA:fileCount', ...
        'Expected a non-zero multiple of %d .mat files in "%s", found %d.', ...
        2*repeatNum, floderPath, num);
end
nGroups = num / repeatNum;
groupNames = cell(nGroups, 1);
for g = 1:nGroups
    name = '';
    for f = (g-1)*repeatNum + (1:repeatNum)
        name = [name, dirout(f).name(1:2)]; %#ok<AGROW>
    end
    groupNames{g} = name;
end
leftDataName = groupNames(1:2:end);
rightDataName = groupNames(2:2:end);
allDataName = [leftDataName; rightDataName];
nLeft = numel(leftDataName);
nAll = numel(allDataName);

featKinds = fieldnames(FunName);
for featIdx = 1:numel(featKinds)
    featKind = featKinds{featIdx};
    funName = FunName.(featKind);

    %% Load (or extract) and normalize the features of every group once
    groupFeat = cell(nAll, 1);
    groupLabel = cell(nAll, 1);
    for g = 1:nAll
        dataName = allDataName{g};
        feat = [];
        label = [];
        rawLoaded = false;
        for k = 1:numel(funName)
            FUN = funName{k};
            cachedFile = featFile(FUN, dataName, '');
            if exist(cachedFile, 'file') % feature already extracted and saved
                cached = load(cachedFile);
                feat = cat(2, feat, cached.featSaved(:, 1:end-1));
                label = cached.featSaved(:, end); % labels are identical for every feature
            else
                % Raw data does not depend on FUN, so load it at most once per group.
                if ~rawLoaded
                    [rawData, rawLabel] = loadData(floderPath, dataName);
                    rawLoaded = true;
                end
                threshData = load(featFile(FUN, dataName, '-thresh'));
                trialFeat = cell(size(rawData, 3), 1);
                for n = 1:numel(trialFeat)
                    trialFeat{n} = feval(FUN, rawData(:,:,n)', threshData.thresh);
                end
                feat = cat(2, feat, vertcat(trialFeat{:}));
                label = rawLabel;
            end
        end
        % NOTE(review): each group is scaled to [0,5] with its own min/max,
        % so test groups are not scaled with the training group's settings
        % (this matches the previous behaviour) -- confirm it is intended.
        groupFeat{g} = mapminmax(real(feat)', 0, 5)';
        groupLabel{g} = label;
    end

    %% Train on each group, test on every group of the opposite hand
    Acc4 = []; % every test accuracy (acc(1) of libsvmpredict) for this feature set
    for iTrain = 1:nAll
        % Assign the global training set before calling GA_SVM.
        trainData = groupFeat{iTrain};
        trainLabel = groupLabel{iTrain};
        best_CG = GA_SVM; % parameter search on the training group
        bestC = best_CG(1);
        bestG = best_CG(2);
        C_G.([featKind, '_', allDataName{iTrain}]) = [bestC, bestG];

        % The model depends only on the training group, so train it once.
        cmd = ['-c ', num2str(bestC), ' -g ', num2str(bestG)];
        M = libsvmtrain(trainLabel, trainData, cmd);

        if iTrain <= nLeft
            testIdx = nLeft+1:nAll; % left-hand group -> test on right-hand groups
        else
            testIdx = 1:nLeft;      % right-hand group -> test on left-hand groups
        end
        for jTest = testIdx
            [~, acc, ~] = libsvmpredict(groupLabel{jTest}, groupFeat{jTest}, M);
            Acc4 = cat(1, Acc4, acc(1));
        end
    end
    AverAcc4 = mean(Acc4);
    TotalAcc = cat(1, TotalAcc, AverAcc4);
end