In [1]:
#include <iostream>
#include <vector>
#include <sstream>
#include <fstream>
#include <cmath>
#include <cstdlib>

using namespace std;

In [2]:
// One data sample: its feature values plus an integer class label.
struct Data {
    vector<double> features;  // feature values, in column order
    int label;                // class label stored alongside the features

    // Features are taken by const reference so that constructing a sample
    // copies the vector only once (into the member).
    Data(const vector<double> &f, int l) : features(f), label(l)
    {}
};
// Model parameters of the classifier.
struct Param {
    // Weight vector; wtSet[i] is multiplied with Data::features[i] when scoring.
    vector<double> wtSet;
};


// Logistic-regression classifier.
// Construction loads the training file, train() fits the weights, and
// predict() scores the test file and writes one 0/1 prediction per line.
class LR {
public:
    // Runs maxIterTimes full-batch weight updates over the training set.
    void train();
    // Loads the test file, thresholds each sigmoid output and stores the predictions.
    void predict();
    // Reads featuresNum weights from weightParamFile; returns 0 on success, -1 on error.
    int loadModel();
    // Writes the current weights to weightParamFile; returns a status code (0 on success).
    int storeModel();
    // Stores the three file paths and immediately runs init() to load the training data.
    LR(string trainFile, string testFile, string predictOutFile);

private:
    vector<Data> trainDataSet;   // samples parsed from trainFile
    vector<Data> testDataSet;    // samples parsed from testFile (label stored as 0)
    vector<int> predictVec;      // predictions produced by predict()
    Param param;                 // current weight vector
    string trainFile;            // path of the training data
    string testFile;             // path of the test data
    string predictOutFile;       // path the predictions are written to
    string weightParamFile = "modelweight.txt";  // path used by loadModel()/storeModel()

private:
    // Loads the training data, derives featuresNum from the first row and resets the weights.
    bool init();
    // Parses trainFile into trainDataSet; the last value on each line is the label.
    bool loadTrainData();
    // Parses testFile into testDataSet, reading up to featuresNum values per line.
    bool loadTestData();
    // Writes one prediction per line to predictOutFile.
    int storePredict(vector<int> &predict);
    // Appends featuresNum copies of wtInitV to the weight vector.
    void initParam();
    // Weighted sum of the sample's features (there is no separate bias term).
    double wxbCalc(const Data &data);
    // Logistic function 1 / (1 + e^(-wxb)).
    double sigmoidCalc(const double wxb);
    // Mean cross-entropy of the current weights over the training set.
    double lossCal();
    // Mean of (label - sigmoid) * feature[index] over dataSet, i.e. the update direction for one weight.
    double gradientSlope(const vector<Data> &dataSet, int index, const vector<double> &sigmoidVec);

private:
    int featuresNum;                       // number of feature columns (from the first training row)
    const double wtInitV = 1.0;            // initial value of every weight
    const double stepSize = 0.1;           // factor applied to each gradient step
    const int maxIterTimes = 3000;         // number of training iterations
    const double predictTrueThresh = 0.5;  // sigmoid output at or above this predicts 1
    const int train_show_step = 10;        // print the weights every this many iterations
};

In [3]:
// Remembers the three file locations, starts with zero known features and
// loads the training set right away. The result of init() is not inspected.
LR::LR(string trainF, string testF, string predictOutF)
    : trainFile(trainF),
      testFile(testF),
      predictOutFile(predictOutF),
      featuresNum(0)
{
    init();
}

In [4]:
// Parses trainFile into trainDataSet.
//
// Every processed line holds numeric values separated by a single delimiter
// character; the last value is the (integer) label and the values before it
// are the features. Lines no longer than featuresNum characters and
// whitespace-only lines are skipped.
//
// Returns true on success and false (after printing the offending line
// number) when a line has a non-numeric token, ends with a delimiter, or has
// a different number of features than the first accepted row.
// Terminates the process if the file cannot be opened.
bool LR::loadTrainData()
{
    ifstream infile(trainFile.c_str());
    if (!infile) {
        cout << "打开训练文件失败" << endl;
        // Still fatal, but reported with a failure exit status.
        exit(EXIT_FAILURE);
    }

    string line;
    size_t lineNo = 0;
    // Looping on getline() itself stops exactly at end of input. Testing the
    // stream before reading would process the final line a second time when
    // the file has no trailing newline, because a failed getline() leaves
    // `line` holding its previous contents.
    while (getline(infile, line)) {
        ++lineNo;
        if (line.size() <= static_cast<size_t>(featuresNum)) {
            continue;
        }
        // A blank line (for example a lone '\r' from CRLF input) is not a sample.
        if (line.find_first_not_of(" \t\r\n") == string::npos) {
            continue;
        }

        stringstream sin(line);
        vector<double> values;
        double dataV;
        char delim;
        bool wellFormed = true;

        while (sin) {
            // peek() returns an int; compare it with eof() directly instead
            // of narrowing to char, which is wrong where char is unsigned.
            // Reaching EOF here means the line ended with a delimiter.
            if (sin.peek() == stringstream::traits_type::eof() || !(sin >> dataV)) {
                wellFormed = false;
                break;
            }
            values.push_back(dataV);
            // Consume the delimiter; at the end of the line this extraction
            // fails, which is what terminates the loop.
            sin >> delim;
        }

        if (wellFormed) {
            const int label = static_cast<int>(values.back());
            values.pop_back();
            // Every row must have as many features as the first one, otherwise
            // the weight vector and the rows would disagree in length.
            if (!trainDataSet.empty() &&
                values.size() != trainDataSet.front().features.size()) {
                wellFormed = false;
            } else {
                trainDataSet.push_back(Data(values, label));
            }
        }

        if (!wellFormed) {
            cout << "训练文件数据格式不正确，出错行为" << lineNo << "行" << endl;
            return false;
        }
    }
    // infile is closed by its destructor on every return path.
    return true;
}

In [5]:
// Appends one weight, set to wtInitV, for each of the featuresNum features.
void LR::initParam()
{
    for (int remaining = featuresNum; remaining > 0; --remaining) {
        param.wtSet.push_back(wtInitV);
    }
}

In [6]:
// (Re)loads the training data and resets the weights.
//
// Returns false when the training file could not be parsed or contained no
// samples; in that case featuresNum and the weights are left untouched.
bool LR::init()
{
    trainDataSet.clear();
    if (!loadTrainData()) {
        return false;
    }
    // A file with no usable lines loads "successfully" but leaves nothing to
    // index; reading trainDataSet[0] then would be undefined behaviour.
    if (trainDataSet.empty()) {
        cout << "training data set is empty" << endl;
        return false;
    }

    // The first row defines how many feature columns the model has.
    featuresNum = static_cast<int>(trainDataSet[0].features.size());
    param.wtSet.clear();
    initParam();
    return true;
}

In [7]:
// Returns the weighted sum of the sample's features: sum of wtSet[i] * features[i].
//
// The sum stops at the shorter of the two vectors, so a sample that carries
// fewer features than there are weights is never read out of bounds; its
// missing features simply contribute nothing.
double LR::wxbCalc(const Data &data)
{
    const size_t weightCount = param.wtSet.size();
    const size_t featureCount = data.features.size();
    const size_t limit = weightCount < featureCount ? weightCount : featureCount;

    double mulSum = 0.0;
    for (size_t i = 0; i < limit; ++i) {
        mulSum += param.wtSet[i] * data.features[i];
    }
    return mulSum;
}

In [8]:
// Logistic function: maps wxb to 1 / (1 + e^(-wxb)).
inline double LR::sigmoidCalc(const double wxb)
{
    return 1 / (1 + exp(-1 * wxb));
}

In [9]:
// Returns the mean cross-entropy loss of the current weights over the
// training set: -(1/N) * sum( y*log(p) + (1-y)*log(1-p) ), p = sigmoid(w.x).
// Returns 0 for an empty training set.
double LR::lossCal()
{
    if (trainDataSet.empty()) {
        return 0.0;  // nothing to average; avoids 0/0
    }

    double lossV = 0.0;
    for (size_t i = 0; i < trainDataSet.size(); ++i) {
        const int label = trainDataSet[i].label;
        // Evaluate the model once per sample and reuse the probability.
        const double prob = sigmoidCalc(wxbCalc(trainDataSet[i]));

        // Skip a term whose coefficient is zero. Once a prediction saturates
        // to exactly 0 or 1 the skipped logarithm is -inf, and 0 * -inf would
        // turn the whole loss into NaN.
        if (label != 0) {
            lossV -= label * log(prob);
        }
        if (label != 1) {
            lossV -= (1 - label) * log(1 - prob);
        }
    }
    return lossV / trainDataSet.size();
}

In [10]:
double LR::gradientSlope(const vector<Data> &dataSet, int index, const vector<double> &sigmoidVec)
{
    double gsV = 0.0L;
    int i;
    double sigv, label;
    for (i = 0; i < dataSet.size(); i++) {
        sigv = sigmoidVec[i];
        label = dataSet[i].label;
        gsV += (label - sigv) * (dataSet[i].features[index]);
    }

    gsV = gsV / dataSet.size();
    return gsV;
}

In [11]:
void LR::train()
{
    double sigmoidVal;
    double wxbVal;
    int i, j;

    for (i = 0; i < maxIterTimes; i++) {
        vector<double> sigmoidVec;

        for (j = 0; j < trainDataSet.size(); j++) {
            wxbVal = wxbCalc(trainDataSet[j]);
            sigmoidVal = sigmoidCalc(wxbVal);
            sigmoidVec.push_back(sigmoidVal);
        }

        for (j = 0; j < param.wtSet.size(); j++) {
            param.wtSet[j] += stepSize * gradientSlope(trainDataSet, j, sigmoidVec);
        }

        if (i % train_show_step == 0) {
            cout << "iter " << i << ". updated weight value is : ";
            for (j = 0; j < param.wtSet.size(); j++) {
                cout << param.wtSet[j] << "  ";
            }
            cout << endl;
        }
    }
}

In [12]:
void LR::predict()
{
    double sigVal;
    int predictVal;

    loadTestData();
    for (int j = 0; j < testDataSet.size(); j++) {
        sigVal = sigmoidCalc(wxbCalc(testDataSet[j]));
        predictVal = sigVal >= predictTrueThresh ? 1 : 0;
        predictVec.push_back(predictVal);
    }

    storePredict(predictVec);
}

In [13]:
// Replaces the weights with featuresNum values read from the first line of
// weightParamFile (whitespace-separated, as written by storeModel()).
//
// Returns 0 on success and -1 when the line holds fewer than featuresNum
// numbers; the current weights are left unchanged in that case.
// Terminates the process if the file cannot be opened.
int LR::loadModel()
{
    ifstream fin(weightParamFile.c_str());
    if (!fin) {
        cout << "打开模型参数文件失败" << endl;
        // Still fatal, but reported with a failure exit status.
        exit(EXIT_FAILURE);
    }

    string line;
    getline(fin, line);
    stringstream sin(line);

    vector<double> wtTmp;
    double weight;
    for (int i = 0; i < featuresNum; i++) {
        // Test the extraction itself. Peeking for EOF first is not enough:
        // the stored line ends with a separator, so the peek sees a space and
        // the failed read that follows would go unnoticed.
        if (!(sin >> weight)) {
            cout << "模型参数数量少于特征数量，退出" << endl;
            return -1;
        }
        wtTmp.push_back(weight);
    }

    param.wtSet.swap(wtTmp);
    // fin is closed by its destructor on every return path.
    return 0;
}

In [14]:
// Writes the first featuresNum weights to weightParamFile on one line,
// each followed by a space.
//
// Returns 0 on success and -1 when there are fewer weights than features,
// the file cannot be opened, or writing fails.
int LR::storeModel()
{
    // Check before opening so that a broken model never truncates an existing
    // file, and so that the loop below cannot read past the end of wtSet.
    if (param.wtSet.size() < static_cast<size_t>(featuresNum)) {
        cout << "wtSet size is " << param.wtSet.size() << endl;
        return -1;
    }

    ofstream fout(weightParamFile.c_str());
    if (!fout.is_open()) {
        cout << "打开模型参数文件失败" << endl;
        return -1;
    }

    // 17 significant digits are enough to reproduce an IEEE-754 double
    // exactly, so loadModel() restores the same weights that were trained;
    // the default of 6 digits does not.
    fout.precision(17);
    for (int i = 0; i < featuresNum; i++) {
        fout << param.wtSet[i] << " ";
    }
    fout.close();
    return fout ? 0 : -1;
}

In [15]:
bool LR::loadTestData()
{
    ifstream infile(testFile.c_str());
    string lineTitle;

    if (!infile) {
        cout << "打开测试文件失败" << endl;
        exit(0);
    }

    while (infile) {
        vector<double> feature;
        string line;
        getline(infile, line);
        if (line.size() > featuresNum) {
            stringstream sin(line);
            double dataV;
            int i;
            char ch;
            i = 0;
            while (i < featuresNum && sin) {
                char c = sin.peek();
                if (int(c) != -1) {
                    sin >> dataV;
                    feature.push_back(dataV);
                    sin >> ch;
                    i++;
                } else {
                    cout << "测试文件数据格式不正确" << endl;
                    return false;
                }
            }
            testDataSet.push_back(Data(feature, 0));
        }
    }

    infile.close();
    return true;
}

In [16]:
// Reads one integer per line from awFile and appends it to awVec.
//
// Empty and whitespace-only lines are skipped. Returns true on success and
// false, after printing a message, when a line does not start with an
// integer. Terminates the process if the file cannot be opened.
bool loadAnswerData(string awFile, vector<int> &awVec)
{
    ifstream infile(awFile.c_str());
    if (!infile) {
        cout << "打开答案文件失败" << endl;
        // Still fatal, but reported with a failure exit status.
        exit(EXIT_FAILURE);
    }

    string line;
    while (getline(infile, line)) {
        // A line holding only whitespace (for example the '\r' of a blank
        // CRLF line) carries no answer and must not be stored as a value.
        if (line.find_first_not_of(" \t\r\n") == string::npos) {
            continue;
        }

        stringstream sin(line);
        int aw;
        // Only store values that were actually parsed.
        if (!(sin >> aw)) {
            cout << "答案文件数据格式不正确" << endl;
            return false;
        }
        awVec.push_back(aw);
    }
    // infile is closed by its destructor on every return path.
    return true;
}

In [17]:
// Writes the predictions to predictOutFile, one value per line.
//
// Returns 0 on success and -1 when the file cannot be opened or writing fails.
int LR::storePredict(vector<int> &predict)
{
    ofstream fout(predictOutFile.c_str());
    if (!fout.is_open()) {
        cout << "打开预测结果文件失败" << endl;
        return -1;  // nothing can be written; do not report success
    }

    for (size_t i = 0; i < predict.size(); ++i) {
        // '\n' produces the same file contents as endl without forcing a
        // flush after every single line.
        fout << predict[i] << '\n';
    }
    fout.close();
    return fout ? 0 : -1;
}

// Trains the model on the fixed training file, stores the weights and writes
// predictions for the fixed test file. When built with TEST defined, the
// predictions are additionally compared with the answer file and the
// accuracy is printed.
int main(int argc, char *argv[])
{
    // Command-line arguments are not used; all paths are fixed below.
    (void)argc;
    (void)argv;

    string trainFile = "../data/train_data.txt";
    string testFile = "../data/test_data.txt";
    string predictFile = "../projects/student/result.txt";

    LR logist(trainFile, testFile, predictFile);

    cout << "ready to train model" << endl;
    logist.train();

    cout << "training ends, ready to store the model" << endl;
    logist.storeModel();

#ifdef TEST
    // Evaluation-only state lives inside the guard so that a regular build
    // has no unused variables.
    string answerFile = "../projects/student/answer.txt";
    vector<int> answerVec;
    cout << "ready to load answer data" << endl;
    loadAnswerData(answerFile, answerVec);
#endif

    cout << "let's have a prediction test" << endl;
    logist.predict();

#ifdef TEST
    // Read the predictions back from the result file that predict() wrote.
    vector<int> predictVec;
    loadAnswerData(predictFile, predictVec);
    cout << "test data set size is " << predictVec.size() << endl;

    // Only positions present in both vectors can be compared.
    const size_t compared = predictVec.size() < answerVec.size() ? predictVec.size() : answerVec.size();
    int correctCount = 0;
    for (size_t j = 0; j < compared; ++j) {
        if (answerVec[j] == predictVec[j]) {
            correctCount++;
        }
    }
    // Report a size mismatch once rather than once per surplus prediction.
    if (predictVec.size() > answerVec.size()) {
        cout << "answer size less than the real predicted value" << endl;
    }

    // An empty answer set would make the ratio 0/0 (NaN); report 0 instead.
    double accurate = answerVec.empty() ? 0.0 : ((double)correctCount) / answerVec.size();
    cout << "the prediction accuracy is " << accurate << endl;
#endif

    return 0;
}

In [None]:
// Notebook-cell copy of the body of main(): the function wrapper and the
// TEST guards are commented out so the statements run directly at cell scope.
// The variables declared here are read by the evaluation cells further down.
//int main(int argc, char *argv[])
//{
    vector<int> answerVec;
    vector<int> predictVec;
    int correctCount;
    double accurate;
    string trainFile = "../data/train_data.txt";
    string testFile = "../data/test_data.txt";
    string predictFile = "../projects/student/result.txt";

    string answerFile = "../projects/student/answer.txt";

    // Constructing the model loads the training file.
    LR logist(trainFile, testFile, predictFile);

    cout << "ready to train model" << endl;
    logist.train();

    cout << "training ends, ready to store the model" << endl;
    logist.storeModel();

//#ifdef TEST
    // The expected labels are loaded unconditionally in this cell.
    cout << "ready to load answer data" << endl;
    loadAnswerData(answerFile, answerVec);
//#endif

    // predict() loads the test file and writes the predictions to predictFile.
    cout << "let's have a prediction test" << endl;
    logist.predict();
ready to train model
iter 0. updated weight value is : 0.967873  0.968336  0.968356  0.967541  0.968286  0.968274  0.968886  0.968412  0.969361  0.968498  0.968197  0.968495  0.968504  0.968776  0.968453  0.968952  0.968098  0.968281  0.968818  0.968143  0.968228  0.968872  0.968542  0.968298  0.969315  0.968767  0.968364  0.968758  0.968303  0.968146  0.967578  0.968261  0.968827  0.968515  0.968633  0.968942  0.968571  0.968723  0.968415  0.96833  0.968069  0.968358  0.96785  0.968302  0.968066  0.968722  0.968461  0.968979  0.969069  0.968253  0.968669  0.968712  0.967869  0.968634  0.968692  0.96851  0.968098  0.968337  0.968466  0.96826  0.968629  0.968561  0.968098  0.967853  0.968061  0.968548  0.96873  0.96796  0.968395  0.969225  0.968769  0.968127  0.968709  0.968344  0.96887  0.968487  0.967775  0.968414  0.968565  0.967704  0.968784  0.968553  0.968253  0.968637  0.968212  0.968335  0.967992  0.969407  0.969122  0.968168  0.968237  0.968632  0.969195  0.967549  0.968856  0.

iter 10. updated weight value is : 0.646601  0.651691  0.651911  0.642953  0.651143  0.651009  0.657746  0.652527  0.662966  0.653479  0.650163  0.653449  0.653544  0.656531  0.652983  0.658468  0.64908  0.651088  0.657  0.649576  0.650503  0.657597  0.653967  0.651283  0.662461  0.656432  0.652004  0.656338  0.65133  0.649604  0.643362  0.650874  0.657095  0.653669  0.654964  0.658361  0.654277  0.65595  0.652564  0.651625  0.648754  0.651935  0.646348  0.651322  0.648727  0.655938  0.653074  0.658764  0.659759  0.650785  0.655354  0.655835  0.646557  0.654973  0.65561  0.653608  0.649075  0.651712  0.653127  0.650864  0.654916  0.654171  0.649077  0.646387  0.648671  0.654027  0.656034  0.647562  0.652346  0.661476  0.656458  0.649401  0.655802  0.651789  0.657575  0.653357  0.64552  0.652557  0.654213  0.644741  0.656619  0.654078  0.650782  0.65501  0.650331  0.651684  0.647914  0.663478  0.66034  0.649846  0.650603  0.654948  0.661144  0.643039  0.657415  0.651506  0.655257  0.649

iter 20. updated weight value is : 0.325328  0.335046  0.335466  0.318364  0.334  0.333744  0.346606  0.336642  0.356571  0.33846  0.33213  0.338402  0.338585  0.344287  0.337513  0.347984  0.330063  0.333896  0.345181  0.331009  0.332778  0.346322  0.339391  0.334268  0.355607  0.344098  0.335645  0.343918  0.334358  0.331061  0.319145  0.333487  0.345364  0.338822  0.341295  0.34778  0.339984  0.343177  0.336714  0.33492  0.32944  0.335512  0.324845  0.334342  0.329388  0.343155  0.337687  0.34855  0.350449  0.333316  0.34204  0.342959  0.325244  0.341312  0.342529  0.338706  0.330052  0.335087  0.337788  0.333468  0.341204  0.339782  0.330056  0.32492  0.32928  0.339507  0.343338  0.327165  0.336296  0.353726  0.344147  0.330674  0.342895  0.335233  0.34628  0.338226  0.323265  0.3367  0.339862  0.321778  0.344454  0.339604  0.333311  0.341383  0.332451  0.335034  0.327835  0.357549  0.351558  0.331524  0.33297  0.341265  0.353093  0.318529  0.345974  0.334693  0.341854  0.330275  0

iter 30. updated weight value is : 0.00405603  0.0184017  0.0190213  -0.00622435  0.0168571  0.0164793  0.0354664  0.0207569  0.0501767  0.0234407  0.0140969  0.0233551  0.0236248  0.0320432  0.0220434  0.0374996  0.011045  0.0167033  0.0333623  0.0124423  0.0150533  0.0350467  0.0248152  0.0172531  0.0487526  0.0317631  0.0192848  0.0314984  0.0173853  0.012519  -0.00507115  0.0161003  0.0336327  0.0239751  0.0276265  0.0371989  0.0256909  0.0304041  0.0208635  0.0182149  0.0101251  0.0190895  0.00334303  0.017362  0.0100487  0.0303715  0.0223003  0.0383354  0.0411386  0.015848  0.0287262  0.0300817  0.00393203  0.0276509  0.029447  0.0238046  0.0110295  0.0184617  0.0224483  0.0160716  0.0274913  0.0253918  0.0110357  0.00345346  0.00989023  0.0249865  0.0306416  0.00676698  0.0202469  0.0459766  0.0318367  0.0119479  0.0299879  0.0186768  0.0349843  0.0230962  0.00101066  0.0208425  0.02551  -0.00118414  0.0322897  0.0251298  0.0158399  0.0277555  0.0145705  0.0183831  0.00775665  0

iter 40. updated weight value is : -0.022247  -0.00870085  -0.00754894  -0.0346827  -0.0125006  -0.0107809  0.0143233  -0.00516524  0.0311179  -0.000141535  -0.00996363  -0.0043291  -0.00160106  0.00992375  -0.00426745  0.0164813  -0.0160612  -0.00740718  0.0120604  -0.0174122  -0.0101406  0.010776  -0.000714359  -0.00958364  0.0314476  0.0102432  -0.00700216  0.0120064  -0.00783683  -0.0141059  -0.0351713  -0.010434  0.0114229  0.0025271  0.00329757  0.0167657  0.00321904  0.0041336  -0.00622621  -0.007336  -0.0178073  -0.00748641  -0.0276759  -0.0112338  -0.0193514  0.00444324  -0.00422857  0.0163254  0.0193408  -0.0103825  0.00618393  0.0058171  -0.0209536  0.00293969  0.00385361  -0.00299332  -0.0175665  -0.00469712  -0.000510363  -0.0133143  0.00591804  -0.000972825  -0.0164315  -0.0266789  -0.01965  0.00378877  0.00928682  -0.0217377  -0.004999  0.0226082  0.00922001  -0.0191877  0.00878543  -0.00612918  0.0156857  -0.00294665  -0.0275325  -0.00595046  0.00206587  -0.0294863  0.0

iter 50. updated weight value is : -0.0105926  0.00143795  0.00317693  -0.0249961  -0.0048162  -0.000720823  0.0304136  0.00621838  0.0489156  0.0137372  0.00377174  0.00494597  0.0104364  0.0250764  0.00659648  0.0326418  -0.00561885  0.00616928  0.028043  -0.0100818  0.00225923  0.0233648  0.0109106  0.000908209  0.0512736  0.0260569  0.00405106  0.0301098  0.00445876  -0.00317169  -0.0273912  0.000466839  0.026395  0.0187626  0.0161673  0.0336088  0.0182164  0.0146997  0.00383582  0.00454161  -0.00823099  0.00320919  -0.0212796  -0.00270462  -0.011427  0.0153845  0.00638885  0.0313267  0.0344593  0.000830233  0.0210151  0.0186694  -0.0077088  0.0153702  0.0152168  0.00726323  -0.00875637  0.00986046  0.0141236  -0.00564781  0.0218828  0.00969714  -0.0063566  -0.0192865  -0.011887  0.020295  0.0253469  -0.0126814  0.00713201  0.035768  0.023826  -0.0132782  0.0250602  0.00655927  0.0338504  0.00816681  -0.0182511  0.00443292  0.016016  -0.0198445  0.0187266  0.0198472  -0.00304813  0

iter 60. updated weight value is : -0.0410857  -0.0297746  -0.0275041  -0.0576198  -0.0382346  -0.0320806  0.00507463  -0.0238175  0.0256716  -0.0140178  -0.0244884  -0.0268174  -0.0189142  -0.00121199  -0.0238366  0.00744958  -0.0368767  -0.0221238  0.00252351  -0.0439807  -0.0271302  -0.00503549  -0.0187506  -0.0300512  0.0297387  0.000333612  -0.0263291  0.00636183  -0.0249268  -0.0339386  -0.0616292  -0.030183  8.29367e-06  -0.00692245  -0.0122918  0.00899843  -0.00846932  -0.0156503  -0.0273433  -0.0251786  -0.0402783  -0.0274949  -0.0563761  -0.0353654  -0.0449111  -0.014644  -0.0242497  0.00516007  0.00851337  -0.0295555  -0.00569629  -0.00970261  -0.0368243  -0.0134718  -0.0144776  -0.0236287  -0.0414361  -0.017507  -0.0130138  -0.0391103  -0.00389297  -0.020773  -0.037932  -0.0535089  -0.0455192  -0.00510905  -0.000215427  -0.0453169  -0.0222814  0.00831735  -0.00295156  -0.0484393  -0.000344714  -0.0224302  0.0102948  -0.0220131  -0.0509315  -0.0264706  -0.0115889  -0.052302 

In [None]:
//#ifdef TEST
    // Read the predictions back from the result file (one integer per line)
    // and reset the counter used by the comparison loop in the next cell.
    loadAnswerData(predictFile, predictVec);
    cout << "test data set size is " << predictVec.size() << endl;
    correctCount = 0;

In [None]:
    for (int j = 0; j < predictVec.size(); j++) {
        if (j < answerVec.size()) {
            if (answerVec[j] == predictVec[j]) {
                correctCount++;
            }
        } else {
            cout << "answer size less than the real predicted value" << endl;
        }
    }

In [None]:
    // Accuracy is the share of answers that were predicted correctly. An
    // empty answer set would make the ratio 0/0 (NaN), so report 0 instead.
    accurate = answerVec.empty() ? 0.0 : ((double)correctCount) / answerVec.size();
    cout << "the prediction accuracy is " << accurate << endl;
//#endif

//    return 0;
//}