<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product but sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit → Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data process

In [None]:
package com.huawei.flclient.model;

import com.mindspore.lite.MSTensor;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;
import mindspore.schema.FeatureMap;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Static helpers around a MindSpore Lite {@link TrainSession}: creating the session,
 * reading and replacing its feature (weight) tensors, and locating output tensors.
 */
public class SessionUtil {
    /** Upper bound on how many feature tensors {@link #getFeatures} logs. */
    private static final int MAX_LOGGED_FEATURES = 5;

    /**
     * Copies the float data of each tensor into a map keyed by tensor name.
     *
     * @param tensors tensors to read
     * @return map from {@code tensorName()} to {@code getFloatData()}
     */
    public static Map<String,float[]> convertTensorTofeatures(List<MSTensor> tensors) {
        Map<String,float[]> features = new HashMap<>(tensors.size());
        for (MSTensor mstensor : tensors) {
            features.put(mstensor.tensorName(), mstensor.getFloatData());
        }
        return features;
    }

    /**
     * Returns the session's feature tensors, logging name, element count and first value
     * of at most the first {@value #MAX_LOGGED_FEATURES} of them.
     *
     * @param trainSession an initialized session
     * @return the list returned by {@code trainSession.getFeaturesMap()}
     */
    public static List<MSTensor> getFeatures(TrainSession trainSession) {
        List<MSTensor> featuresMap = trainSession.getFeaturesMap();
        // Bound the loop by the actual list size and guard data[0]: a model with fewer than
        // five features, or an empty tensor, must not make this debug logging throw.
        int logCount = Math.min(MAX_LOGGED_FEATURES, featuresMap.size());
        for (int i = 0; i < logCount; i++) {
            MSTensor feature = featuresMap.get(i);
            float[] data = feature.getFloatData();
            String name = feature.tensorName();
            int elements = feature.elementsNum();
            String firstValue = (data != null && data.length > 0) ? String.valueOf(data[0]) : "<empty>";
            System.out.println("tensorname:" + name + ",len," + elements + "," + firstValue);
        }
        return featuresMap;
    }

    /**
     * Replaces the session's feature tensors with the data carried by {@code featureMaps}.
     *
     * <p>Each entry's bytes are copied into a fresh direct buffer in native byte order and
     * wrapped in an {@link MSTensor} named by {@code weightFullname()}.
     *
     * @param trainSession session to update
     * @param modelName    passed through to {@code TrainSession.updateFeatures}
     * @param featureMaps  new weights
     * @return always 0; whatever {@code updateFeatures} reports is not inspected here
     */
    public static int updateFeatures(TrainSession trainSession, String modelName, List<FeatureMap> featureMaps) {
        List<MSTensor> tensors = new ArrayList<>(featureMaps.size());
        for (FeatureMap newFeature : featureMaps) {
            ByteBuffer by = newFeature.dataAsByteBuffer();
            ByteBuffer newData = ByteBuffer.allocateDirect(by.remaining());
            newData.order(ByteOrder.nativeOrder());
            newData.put(by);
            tensors.add(new MSTensor(newFeature.weightFullname(), newData));
        }
        trainSession.updateFeatures(modelName, tensors);
        return 0;
    }

    /**
     * Creates a train session for the model file at {@code modelPath} and sets its
     * learning rate to 0.01.
     *
     * @param modelPath path of the model file
     * @return the session, or {@code null} if {@code TrainSession.init} reports failure
     */
    public static TrainSession initSession(String modelPath) {
        MSConfig msConfig = new MSConfig();
        // arg 0: DeviceType:DT_CPU -> 0
        // arg 1: ThreadNum -> 1
        // arg 2: cpuBindMode:NO_BIND -> 0
        // arg 3: enable_fp16 -> false
        msConfig.init(0, 1, 0, false);
        TrainSession trainSession = new TrainSession();
        boolean status = trainSession.init(modelPath, msConfig);
        if (!status) {
            System.out.println("init session failed," + modelPath);
            return null;
        }
        trainSession.setLearningRate(0.01f);
        return trainSession;
    }

    /**
     * Finds an output tensor by element count.
     *
     * @param trainSession session whose outputs are searched
     * @param size         required {@code elementsNum()}
     * @return the first matching tensor encountered while iterating
     *         {@code getOutputMapByTensor().values()}, or {@code null} if none matches
     */
    public static  MSTensor searchOutputsForSize(TrainSession trainSession, int size) {
        Map<String, MSTensor> outputs = trainSession.getOutputMapByTensor();
        for (MSTensor tensor : outputs.values()) {
            if (tensor.elementsNum() == size) {
                return tensor;
            }
        }
        System.err.println("can not find output the tensor which element num is " + size);
        return null;
    }
}


In [None]:
package com.huawei.flclient.model;

import com.mindspore.lite.MSTensor;
import com.mindspore.lite.TrainSession;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;

/**
 * Trains the ALBERT ad model on-device with MindSpore Lite.
 *
 * <p>Call {@link #initDataSet} first to tokenize the training data, then {@link #trainModel}
 * to create the session, run the training loop and save the weights back to the model file.
 * Call {@link #free} afterwards to release the native session.
 */
public class AdTrainBert {
    private int batchSize;
    private int batchNum;
    // Number of ints taken from each feature per input tensor (input element count / batchSize).
    private int dataSize;
    private TrainSession trainSession;
    private List<Feature> features;
    private ByteBuffer inputIdBufffer;
    private ByteBuffer tokenIdBufffer;
    private ByteBuffer maskIdBufffer;
    private ByteBuffer labelIdBufffer;

    static {
        System.loadLibrary("mindspore-lite-jni");
    }

    /** Allocates a direct, native-order buffer large enough for {@code count} ints. */
    private static ByteBuffer allocateIntBuffer(int count) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(count * Integer.BYTES);
        buffer.order(ByteOrder.nativeOrder());
        return buffer;
    }

    /**
     * Creates the session and sizes the four input buffers from the model's first input.
     *
     * @return 0 on success, -1 if the batch size is invalid or the session cannot be created
     */
    private int initSessionAndInputs(String modelPath) {
        // Validate before creating the native session so a bad batch size does not leave an
        // unused session behind.
        if (batchSize <= 0) {
            System.out.println("batch size need more than 0");
            return -1;
        }
        trainSession = SessionUtil.initSession(modelPath);
        if (trainSession == null) {
            return -1;
        }
        List<MSTensor> inputs = trainSession.getInputs();
        MSTensor labelIdTensor = inputs.get(0);
        // Assumes labelId, tokenId, inputId and maskId inputs all have this element count.
        int inputLength = labelIdTensor.elementsNum();
        dataSize = inputLength / batchSize;
        inputIdBufffer = allocateIntBuffer(inputLength);
        tokenIdBufffer = allocateIntBuffer(inputLength);
        maskIdBufffer = allocateIntBuffer(inputLength);
        labelIdBufffer = allocateIntBuffer(inputLength);
        return 0;
    }

    /**
     * Packs batch {@code batchIdx} into the input buffers and binds them to the session inputs
     * (input order used here: 0=label, 1=token, 2=input, 3=mask).
     */
    private void fillAdTrainBertInput(int batchIdx) {
        inputIdBufffer.clear();
        tokenIdBufffer.clear();
        maskIdBufffer.clear();
        labelIdBufffer.clear();
        for (int i = 0; i < batchSize; i++) {
            Feature feature = features.get(batchIdx * batchSize + i);
            for (int j = 0; j < dataSize; j++) {
                inputIdBufffer.putInt(feature.inputIds[j]);
                tokenIdBufffer.putInt(feature.tokenIds[j]);
                maskIdBufffer.putInt(feature.inputMasks[j]);
                // NOTE(review): labels are copied from inputIds, which DataSet may already have
                // modified in place via addRandomMaskAndReplace; confirm this is the intended target.
                labelIdBufffer.putInt(feature.inputIds[j]);
            }
        }
        List<MSTensor> inputs = trainSession.getInputs();
        MSTensor labelIdTensor = inputs.get(0);
        MSTensor tokenIdTensor = inputs.get(1);
        MSTensor inputIdTensor = inputs.get(2);
        MSTensor maskIdTensor = inputs.get(3);
        labelIdTensor.setData(labelIdBufffer);
        tokenIdTensor.setData(tokenIdBufffer);
        inputIdTensor.setData(inputIdBufffer);
        maskIdTensor.setData(maskIdBufffer);
    }

    /**
     * Reads the loss from the single-element output tensor.
     *
     * @return the loss, or NaN if no single-element output exists
     */
    private float getLoss() {
        MSTensor tensor = SessionUtil.searchOutputsForSize(trainSession, 1);
        if (tensor == null) {
            System.err.println("cannot find loss tensor");
            return Float.NaN;
        }
        return tensor.getFloatData()[0];
    }

    /**
     * Runs {@code epoches} passes over all full batches, logging per-batch and mean loss.
     *
     * @return 0 on success, -1 as soon as a NaN loss is seen
     */
    private int trainLoop(int epoches) {
        trainSession.train();
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < epoches; i++) {
            float sumLossPerEpoch = 0.0f;
            for (int j = 0; j < batchNum; j++) {
                fillAdTrainBertInput(j);
                trainSession.runGraph();
                float loss = getLoss();
                if (Float.isNaN(loss)) {
                    System.out.println("loss is nan");
                    return -1;
                }
                sumLossPerEpoch += loss;
                System.out.println("------batch:" + j + ",loss:" + loss + "-----------");
            }
            System.out.println("----------epoch:" + i + ",mean loss:" + sumLossPerEpoch / batchNum + "----------");
            long endTime = System.currentTimeMillis();
            System.out.println("total train time " + (endTime - startTime) + "ms");
        }
        return 0;
    }

    /** Returns the current session, or {@code null} if none is active. */
    public TrainSession getSession() {
        return trainSession;
    }

    /**
     * Tokenizes {@code dataFile} in training mode and computes the number of full batches.
     *
     * @return the number of examples, or -1 if {@code batchSize <= 0}
     */
    public int initDataSet(String dataFile, String vocabFile, String idsFile, int batchSize) {
        System.out.println("==========Init dataFile," + dataFile + ",vocabFile," + vocabFile + "=============");
        if (batchSize <= 0) {
            System.out.println("batch size need more than 0");
            return -1;
        }
        features = DataSet.init(dataFile, vocabFile, idsFile, true);
        this.batchSize = batchSize;
        batchNum = features.size() / batchSize;
        return features.size();
    }

    /**
     * Creates the session, trains for {@code epoches} epochs and, if at least one epoch ran,
     * overwrites {@code modelPath} with the trained model.
     *
     * @return 0 on success, -1 if session creation or training failed
     */
    public int trainModel(String modelPath, int epoches) {
        System.out.println("==========Loading Model," + modelPath + " Create Train Session=============");
        int status = initSessionAndInputs(modelPath);
        if (status == -1) {
            System.out.println("init session and inputs failed");
            return -1;
        }
        System.out.println("==========Begin Train Model=============");
        status = trainLoop(epoches);
        if (status == -1) {
            System.out.println("train loop failed");
            return -1;
        }
        if (epoches > 0) {
            trainSession.saveToFile(modelPath);
        }
        return 0;
    }

    /** Releases the native session; safe to call before a session exists or more than once. */
    public void free() {
        if (trainSession != null) {
            trainSession.free();
            trainSession = null;
        }
    }

    /**
     * Demo entry point. Optional positional args override the default paths:
     * {@code dataFile vocabFile idsFile modelPath}.
     */
    public static void main(String[] args) throws IOException {
        AdTrainBert adTrainBert = new AdTrainBert();
        String dataFile = args.length > 0 ? args[0] : "/home/meng/zj10/mindspore/mindspore/lite/101.txt";
        String vocabFile = args.length > 1 ? args[1]
                : "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
        String idsFile = args.length > 2 ? args[2] : "/home/meng/zj10/fl/mindspore/mindspore/lite/vocab_map_ids.txt";
        String modelPath = args.length > 3 ? args[3]
                : "/home/meng/zj10/fl/mindspore/mindspore/lite/albert_ad_train_new.mindir.ms";
        int epoches = 1;
        int batchSize = 16;
        if (adTrainBert.initDataSet(dataFile, vocabFile, idsFile, batchSize) < 0) {
            return;
        }
        try {
            adTrainBert.trainModel(modelPath, epoches);
        } finally {
            adTrainBert.free();
        }
    }
}




In [None]:
package com.huawei.flclient.model;

import com.mindspore.lite.MSTensor;
import com.mindspore.lite.TrainSession;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;

/**
 * Evaluates the ALBERT ad classification model with MindSpore Lite and reports accuracy.
 *
 * <p>Call {@link #inferModel}, then {@link #free} to release the native session.
 */
public class AdInferBert {
    private int batchSize;
    private int batchNum;
    // Number of ints taken from each feature per input tensor (input element count / batchSize).
    private int dataSize;
    private TrainSession trainSession;
    private List<Feature> features;
    private ByteBuffer inputIdBufffer;
    private ByteBuffer tokenIdBufffer;
    private ByteBuffer maskIdBufffer;

    static {
        System.loadLibrary("mindspore-lite-jni");
    }

    /**
     * Tokenizes {@code exampleFile} in inference mode and computes the number of full batches.
     *
     * @return the number of examples, or -1 if {@code batchSize <= 0}
     */
    private int initDataSet(String exampleFile, String vocabFile, String idsFile, int batchSize) {
        if (batchSize <= 0) {
            System.err.println("batch size need more than 0");
            return -1;
        }
        features = DataSet.init(exampleFile, vocabFile, idsFile, false);
        this.batchSize = batchSize;
        batchNum = features.size() / batchSize;
        return features.size();
    }

    /** Allocates a direct, native-order buffer large enough for {@code count} ints. */
    private static ByteBuffer allocateIntBuffer(int count) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(count * Integer.BYTES);
        buffer.order(ByteOrder.nativeOrder());
        return buffer;
    }

    /**
     * Creates the session and sizes the three input buffers from the model's first input.
     *
     * @return 0 on success, -1 if the batch size is invalid or the session cannot be created
     */
    private int initSessionAndInputs(String modelPath) {
        // Validate before creating the native session so nothing is allocated on a bad config.
        if (batchSize <= 0) {
            System.out.println("batch size need more than 0");
            return -1;
        }
        trainSession = SessionUtil.initSession(modelPath);
        if (trainSession == null) {
            return -1;
        }
        List<MSTensor> inputs = trainSession.getInputs();
        MSTensor tokenIdTensor = inputs.get(0);
        // Assumes tokenId, inputId and maskId inputs all have this element count.
        int inputLength = tokenIdTensor.elementsNum();
        dataSize = inputLength / batchSize;
        inputIdBufffer = allocateIntBuffer(inputLength);
        tokenIdBufffer = allocateIntBuffer(inputLength);
        maskIdBufffer = allocateIntBuffer(inputLength);
        return 0;
    }

    /**
     * Packs batch {@code batchIdx} into the input buffers, binds them to the session inputs
     * (input order used here: 0=token, 1=input, 2=mask) and returns the batch's labels.
     */
    private List<Integer> fillAdInferBertInput(int batchIdx) {
        inputIdBufffer.clear();
        tokenIdBufffer.clear();
        maskIdBufffer.clear();
        List<Integer> labels = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            Feature feature = features.get(batchIdx * batchSize + i);
            for (int j = 0; j < dataSize; j++) {
                inputIdBufffer.putInt(feature.inputIds[j]);
                tokenIdBufffer.putInt(feature.tokenIds[j]);
                maskIdBufffer.putInt(feature.inputMasks[j]);
            }
            labels.add(feature.labelIds);
        }
        List<MSTensor> inputs = trainSession.getInputs();
        MSTensor tokenIdTensor = inputs.get(0);
        MSTensor inputIdTensor = inputs.get(1);
        MSTensor maskIdTensor = inputs.get(2);
        tokenIdTensor.setData(tokenIdBufffer);
        inputIdTensor.setData(inputIdBufffer);
        maskIdTensor.setData(maskIdBufffer);
        return labels;
    }

    /**
     * Computes the fraction of samples whose arg-max class score equals the label.
     *
     * <p>Scores are read from the output tensor with {@code batchSize * 5} elements, laid out
     * as 5 consecutive scores per sample; ties resolve to the lowest class index.
     *
     * @return batch accuracy in [0, 1], or NaN if the score tensor is not found
     */
    private float calculateAccracy(List<Integer> labels) {
        int numOfClass = 5;
        MSTensor outputTensor = SessionUtil.searchOutputsForSize(trainSession, batchSize * numOfClass);
        if (outputTensor == null) {
            return Float.NaN;
        }
        float[] scores = outputTensor.getFloatData();
        float accuracy = 0.0f;
        for (int b = 0; b < batchSize; b++) {
            int maxIdx = 0;
            float maxScore = scores[numOfClass * b];
            for (int c = 0; c < numOfClass; c++) {
                if (scores[numOfClass * b + c] > maxScore) {
                    maxScore = scores[numOfClass * b + c];
                    maxIdx = c;
                }
            }
            if (labels.get(b) == maxIdx) {
                accuracy += 1;
            }
        }
        return accuracy / batchSize;
    }

    /**
     * Runs every full batch in eval mode and returns the mean per-batch accuracy.
     *
     * @return mean accuracy, or NaN if any batch's accuracy could not be computed
     */
    private float infer() {
        trainSession.eval();
        float totalAccuracy = 0.0f;
        for (int j = 0; j < batchNum; j++) {
            List<Integer> labels = fillAdInferBertInput(j);
            trainSession.runGraph();
            float curAcc = calculateAccracy(labels);
            if (Float.isNaN(curAcc)) {
                return Float.NaN;
            }
            totalAccuracy += curAcc;
            System.out.println("batch num:" + j + ",acc is:" + curAcc);
        }
        System.out.println("total acc:" + totalAccuracy / batchNum);
        return totalAccuracy / batchNum;
    }

    /** Returns the current session, or {@code null} if none is active. */
    public TrainSession getSession() {
        return trainSession;
    }

    /**
     * Loads the data set and the model, then runs inference.
     *
     * @return mean accuracy, NaN if accuracy could not be computed, or -1 if the data set
     *         or the session could not be initialized
     */
    public float inferModel(String modelPath, String dataFile, String vocabFile, String idsFile, int batchSize)  {
        System.out.println("==========Init dataFile," + dataFile + ",vocabFile," + vocabFile + "=============");
        int inferSize = initDataSet(dataFile, vocabFile, idsFile, batchSize);
        if (inferSize == -1) {
            System.out.println("init dataset failed");
            return -1;
        }
        System.out.println("==========Train size," + inferSize);
        System.out.println("==========Loading Model," + modelPath + " Create Train Session=============");
        // Without this check a failed session init led to an NPE inside infer().
        if (initSessionAndInputs(modelPath) != 0) {
            System.out.println("init session and inputs failed");
            return -1;
        }
        System.out.println("==========Begin Infer Model=============");
        return infer();
    }

    /** Releases the native session; safe to call before a session exists or more than once. */
    public void free() {
        if (trainSession != null) {
            trainSession.free();
            trainSession = null;
        }
    }

    public static void main(String[] args) throws IOException {
        AdInferBert adInferBert = new AdInferBert();
        String dataFile = "/home/meng/zj10/mindspore/mindspore/lite/eval.txt";
        String vocabFile = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
        String idsFile = "/home/meng/zj10/fl/mindspore/mindspore/lite/vocab_map_ids.txt";
        String modelPath = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert_ad_infer_new.mindir.ms";
        int batchSize = 16;
        try {
            adInferBert.inferModel(modelPath, dataFile, vocabFile, idsFile, batchSize);
        } finally {
            adInferBert.free();
        }
    }
}


In [None]:
package com.huawei.flclient.model;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

/**
 * One tokenized example as consumed by the train/infer drivers: parallel per-position
 * arrays (ids, attention mask, segment ids) plus a single class label.
 */
class Feature {
    // Token ids, one per position.
    int[] inputIds;
    // Per-position mask values.
    int[] inputMasks;
    // Per-position segment (token type) ids.
    int[] tokenIds;
    // Class index of this example.
    int labelIds;
    // Number of leading positions that hold real tokens.
    int seqLen;

    /** Stores the given arrays by reference (no copies are made). */
    public Feature(int[] inputIds, int[] inputMasks, int[] tokenIds, int labelIds, int seqLen) {
        this.seqLen = seqLen;
        this.labelIds = labelIds;
        this.tokenIds = tokenIds;
        this.inputMasks = inputMasks;
        this.inputIds = inputIds;
    }
}

/**
 * WordPiece-style tokenizer that turns raw text into {@link Feature}s for the ALBERT ad model.
 *
 * <p>The vocabulary is read from two parallel line-oriented files: line i of the vocab file
 * is a token and line i of the ids file is its integer id.
 */
public class CustomTokenizer {
    private Map<String, Integer> vocabs = new HashMap<>();
    // NOTE(review): set by init() but never applied to the text before vocab lookup;
    // confirm whether tokenize() is supposed to lower-case its input.
    private Boolean doLowerCase = Boolean.TRUE;
    private int maxInputChars = 100;
    private String[] NotSplitStrs = {"UNK"};
    private String unkToken = "[UNK]";
    private int maxSeqLen = 16;
    private int vocabSize = 11682;
    private Map<String, Integer> labelMap = new HashMap<String, Integer>() {{
        put("beauty", 0);
        put("education", 1);
        put("hotel", 2);
        put("travel", 3);
        put("other", 4);
    }};

    /**
     * Loads the vocabulary and selects the maximum sequence length.
     *
     * @param vocabFile   one token per line (UTF-8)
     * @param idsFile     one integer id per line (UTF-8), parallel to {@code vocabFile}
     * @param trainMod    {@code true} keeps max sequence length 16, {@code false} sets it to 256
     * @param doLowerCase stored on the tokenizer
     * @throws IllegalStateException if either file cannot be read
     */
    public void init(String vocabFile, String idsFile, boolean trainMod, boolean doLowerCase) {
        this.doLowerCase = doLowerCase;
        List<String> vocabLines = readUtf8Lines(vocabFile);
        List<String> idsLines = readUtf8Lines(idsFile);
        // The files are parallel; only pair up the lines both of them have instead of
        // failing with IndexOutOfBounds when the vocab file is shorter.
        int count = Math.min(vocabLines.size(), idsLines.size());
        if (vocabLines.size() != idsLines.size()) {
            System.err.println("vocab file has " + vocabLines.size() + " lines but ids file has "
                    + idsLines.size() + " lines; using the first " + count);
        }
        for (int i = 0; i < count; ++i) {
            vocabs.put(vocabLines.get(i), Integer.parseInt(idsLines.get(i).trim()));
        }
        if (!trainMod) {
            maxSeqLen = 256;
        }
    }

    /** Reads every UTF-8 line of {@code file}, failing loudly with the file name on I/O errors. */
    private static List<String> readUtf8Lines(String file) {
        Path path = Paths.get(file);
        try {
            return Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new IllegalStateException("failed to read " + file, e);
        }
    }

    /**
     * Returns whether {@code trimChar} is a CJK ideograph (U+4E00..U+9FA5) or printable
     * ASCII punctuation.
     */
    public Boolean isChineseOrPunc(char trimChar) {
        // is chinese char
        if (trimChar >= '\u4e00' && trimChar <= '\u9fa5') {
            return true;
        }
        // is punctuation char: ! to /, : to @, [ to `, { to ~
        if ((trimChar >= 33 && trimChar <= 47) || (trimChar >= 58 && trimChar <= 64) || (trimChar >= 91 && trimChar <= 96) || (trimChar >= 123 && trimChar <= 126)) {
            return true;
        }
        return false;
    }

    /**
     * Splits text on whitespace after surrounding every CJK ideograph and ASCII punctuation
     * character with spaces, so each of those becomes its own token.
     */
    public String[] splitText(String text) {
        String trimText = text.trim();
        StringBuilder cleanText = new StringBuilder();
        for (int i = 0; i < trimText.length(); i++) {
            if (isChineseOrPunc(trimText.charAt(i))) {
                cleanText.append(" " + trimText.charAt(i) + " ");
            } else {
                cleanText.append(trimText.charAt(i));
            }
        }
        return cleanText.toString().trim().split("\\s+");
    }

    /**
     * Greedy longest-match-first WordPiece split of each token against the vocabulary;
     * non-initial pieces are prefixed with {@code ##}. A token that cannot be fully covered
     * becomes {@code [UNK]}.
     *
     * <p>e.g. input = "unaffable", output = ["un", "##aff", "##able"]
     */
    public List<String> wordPieceTokenize(String[] tokens) {
        List<String> outputTokens = new ArrayList<>();
        for (String token : tokens) {
            List<String> subTokens = new ArrayList<>();
            boolean isBad = false;
            int start = 0;
            while (start < token.length()) {
                int end = token.length();
                String curStr = "";
                while (start < end) {
                    String subStr = token.substring(start, end);
                    if (start > 0) {
                        subStr = "##" + subStr;
                    }
                    if (vocabs.get(subStr) != null) {
                        curStr = subStr;
                        break;
                    }
                    end = end - 1;
                }
                if (curStr.isEmpty()) {
                    isBad = true;
                    break;
                }
                subTokens.add(curStr);
                start = end;
            }
            if (isBad) {
                outputTokens.add(unkToken);
            } else {
                outputTokens.addAll(subTokens);
            }
        }
        return outputTokens;
    }

    /**
     * Truncates {@code tokens} to {@code maxSeqLen - 2}, wraps them in [CLS] ... [SEP] and maps
     * each to its vocabulary id ([UNK]'s id for unknown tokens). The argument is not modified.
     *
     * @param tokens   word pieces
     * @param cycTrunc when truncating, take a window starting at a random index that wraps
     *                 around to the beginning, instead of keeping the head
     * @return ids of length at most {@code maxSeqLen}
     */
    public List<Integer> convertTokensToIds(List<String> tokens, boolean cycTrunc) {
        int seqLen = tokens.size();
        int maxBody = maxSeqLen - 2; // room left after [CLS] and [SEP]
        List<String> body = tokens;
        if (seqLen > maxBody) {
            if (cycTrunc) {
                int randIndex = (int) (Math.random() * seqLen);
                // Build a fresh list: the previous version appended to a subList view, which
                // structurally modified the caller's list.
                body = new ArrayList<>(maxBody);
                if (randIndex > seqLen - maxBody) {
                    // Window runs past the end: take the tail, then wrap to the head.
                    body.addAll(tokens.subList(randIndex, seqLen));
                    body.addAll(tokens.subList(0, randIndex + maxBody - seqLen));
                } else {
                    body.addAll(tokens.subList(randIndex, randIndex + maxBody));
                }
            } else {
                body = tokens.subList(0, maxBody);
            }
        }
        Integer unkId = vocabs.get(unkToken);
        List<Integer> ids = new ArrayList<>(body.size() + 2);
        ids.add(vocabs.getOrDefault("[CLS]", unkId));
        for (String token : body) {
            ids.add(vocabs.getOrDefault(token, unkId));
        }
        ids.add(vocabs.getOrDefault("[SEP]", unkId));
        return ids;
    }

    /**
     * Applies random masking to the first {@code feature.seqLen} entries of
     * {@code feature.inputIds}, in place.
     *
     * <p>Each eligible position is selected with probability 0.15. A selected position becomes
     * id 103 with probability 0.8 (presumably the [MASK] id — confirm against the ids file),
     * is left unchanged with probability 0.1, and otherwise becomes a random id in
     * [0, vocabSize).
     *
     * @param keepFirstUnchange never modify position 0
     * @param keepLastUnchange  never modify position {@code seqLen - 1}
     */
    public void addRandomMaskAndReplace(Feature feature, boolean keepFirstUnchange, boolean keepLastUnchange) {
        // Previously the "keep first" flag was applied to every position (making the method a
        // no-op), and unmasked positions defaulted to id + 1 when it was off.
        int[] inputIds = feature.inputIds;
        int seqLen = feature.seqLen;
        for (int i = 0; i < seqLen; i++) {
            boolean keep = (keepFirstUnchange && i == 0) || (keepLastUnchange && i == seqLen - 1);
            if (keep || Math.random() >= 0.15) {
                continue;
            }
            double rand2 = Math.random();
            if (rand2 < 0.8) {
                inputIds[i] = 103;
            } else if (rand2 >= 0.9) {
                inputIds[i] = (int) (Math.random() * vocabSize);
            }
            // 0.8 <= rand2 < 0.9: token left unchanged.
        }
    }

    /**
     * Builds a {@link Feature} of length {@code maxSeqLen}: ids padded with 0, mask 1 over the
     * real tokens and 0 elsewhere, all-zero segment ids.
     *
     * @param tokens ids from {@link #convertTokensToIds}; at most {@code maxSeqLen} entries
     * @param label  one of the keys of {@code labelMap}
     * @throws IllegalArgumentException if {@code label} is unknown
     */
    public Feature getFeatures(List<Integer> tokens, String label) {
        Integer labelId = labelMap.get(label);
        if (labelId == null) {
            throw new IllegalArgumentException("unknown label: " + label);
        }
        int[] segmentIds = new int[maxSeqLen]; // Java zero-initializes arrays
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 0, tokens.size(), 1); // tokens size can ensure less than masks
        int[] inputIds = new int[maxSeqLen];
        for (int i = 0; i < tokens.size(); i++) {
            inputIds[i] = tokens.get(i);
        }
        return new Feature(inputIds, masks, segmentIds, labelId, tokens.size());
    }

    /** Splits, WordPiece-tokenizes and converts {@code text} to ids; train mode uses cyclic truncation. */
    public List<Integer> tokenize(String text, boolean trainMod) {
        String[] splitTokens = splitText(text);
        List<String> wordPieceTokens = wordPieceTokenize(splitTokens);
        return convertTokensToIds(wordPieceTokens, trainMod); // trainMod need cyclicTrunc
    }

    public static void main(String[] args) throws IOException {
        CustomTokenizer customTokenizer = new CustomTokenizer();
        String line = "<<<other>>>unaffable";
        String[] tokens = line.split(">>>");
        if (tokens.length != 2) {
            System.out.println("Input line ERROR");
            return;
        }

        String vocabFile = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
        String idsFile = "/home/meng/zj10/fl/mindspore/mindspore/lite/vocab_map_ids.txt";
        customTokenizer.init(vocabFile, idsFile, true, true);
        customTokenizer.tokenize(tokens[1], true);
        tokens = tokens[0].split("<<<");
        System.out.println(tokens[0]);
    }
}



In [None]:
package com.huawei.flclient.model;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/**
 * Loads a labelled text file into tokenized {@link Feature}s.
 *
 * <p>Each line is expected to look like {@code <<<label>>>text}; lines that do not split
 * into exactly a label part and a text part are reported and skipped.
 */
public class DataSet {

    /**
     * Reads and tokenizes {@code trainFile}.
     *
     * @param trainFile UTF-8 data file, one example per line
     * @param vocabFile vocabulary tokens, see {@code CustomTokenizer.init}
     * @param idsFile   vocabulary ids, see {@code CustomTokenizer.init}
     * @param trainMod  training mode: short sequences with cyclic truncation and random masking
     * @return one feature per well-formed line, in file order
     * @throws IllegalStateException if {@code trainFile} cannot be read
     */
    public static List<Feature> init(String trainFile, String vocabFile, String idsFile, boolean trainMod) {
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, trainMod, true);

        Path path = Paths.get(trainFile);
        List<String> allLines;
        try {
            allLines = Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            // Previously swallowed, which only turned into an NPE on the loop below.
            throw new IllegalStateException("failed to read data file " + trainFile, e);
        }

        List<String> examples = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        for (String line : allLines) {
            String[] textParts = line.split(">>>");
            if (textParts.length != 2) {
                System.out.println("Input line ERROR");
                continue;
            }
            String[] labelParts = textParts[0].split("<<<");
            if (labelParts.length != 2) {
                System.out.println("Input line ERROR");
                continue;
            }
            // Append only once both halves parsed so examples.get(i) always pairs with labels.get(i).
            examples.add(textParts[1]);
            labels.add(labelParts[1]);
        }

        List<Feature> features = new ArrayList<>(examples.size());
        for (int i = 0; i < examples.size(); i++) {
            List<Integer> tokens = customTokenizer.tokenize(examples.get(i), trainMod);
            Feature feature = customTokenizer.getFeatures(tokens, labels.get(i));
            if (trainMod) {
                customTokenizer.addRandomMaskAndReplace(feature, true, true);
            }
            features.add(feature);
        }
        return features;
    }
}

