<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient.model;

import com.huawei.flclient.Common;
import com.mindspore.lite.MSTensor;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * Common base of the BERT ad-classification models.
 *
 * <p>Owns the loaded {@link Feature} list and the direct byte buffers that are handed to the
 * session's input tensors. The members {@code trainSession}, {@code batchSize},
 * {@code numOfClass}, {@code padSize}, {@code trainSampleSize} and {@code batchNum} are not
 * declared here; presumably they are inherited from {@code TrainModel}.
 */
public class AdBert extends TrainModel {
    private static final Logger logger = Logger.getLogger(AdBert.class.toString());

    private static final int NUM_OF_CLASS = 5;

    /** Samples to feed to the model; populated by the subclasses. */
    List<Feature> features;

    /** Number of int elements per sample (tensor element count divided by batch size). */
    private int dataSize;

    private ByteBuffer inputIdBuffer;

    private ByteBuffer tokenIdBuffer;

    private ByteBuffer maskIdBuffer;

    private ByteBuffer labelIdBuffer;

    /**
     * Creates the session for {@code modelPath} and allocates the input buffers.
     *
     * @param modelPath path of the model handed to {@code SessionUtil.initSession}
     * @param trainMod  when true the label buffer is allocated as well
     * @return 0 on success, -1 when the session cannot be created or the batch size is not positive
     */
    @Override
    public int initSessionAndInputs(String modelPath, boolean trainMod) {
        trainSession = SessionUtil.initSession(modelPath);
        if (trainSession == null) {
            logger.severe(Common.addTag("session init failed"));
            return -1;
        }
        // The first input tensor is used for sizing; the original code notes that the
        // labelId, tokenId, inputId and maskId tensors all have the same size.
        MSTensor sizingTensor = trainSession.getInputs().get(0);
        int elementCount = sizingTensor.elementsNum();
        batchSize = sizingTensor.getShape()[0];
        if (batchSize <= 0) {
            logger.severe(Common.addTag("batch size should bigger than 0"));
            return -1;
        }
        dataSize = elementCount / batchSize;

        int byteCount = elementCount * Integer.BYTES;
        inputIdBuffer = ByteBuffer.allocateDirect(byteCount);
        inputIdBuffer.order(ByteOrder.nativeOrder());
        tokenIdBuffer = ByteBuffer.allocateDirect(byteCount);
        tokenIdBuffer.order(ByteOrder.nativeOrder());
        maskIdBuffer = ByteBuffer.allocateDirect(byteCount);
        maskIdBuffer.order(ByteOrder.nativeOrder());
        if (trainMod) {
            labelIdBuffer = ByteBuffer.allocateDirect(byteCount);
            labelIdBuffer.order(ByteOrder.nativeOrder());
        }
        numOfClass = NUM_OF_CLASS;
        return 0;
    }

    /**
     * Copies one batch of features into the buffers and binds them to the session inputs.
     *
     * @param batchIdx zero-based batch index; samples {@code batchIdx * batchSize} onwards are used
     * @param trainMod true to also fill and bind the label tensor (input index 0)
     * @return the {@code labelIds} of the batch when {@code trainMod} is false, otherwise an empty list
     */
    @Override
    public List<Integer> fillModelInput(int batchIdx, boolean trainMod) {
        inputIdBuffer.clear();
        tokenIdBuffer.clear();
        maskIdBuffer.clear();
        if (trainMod) {
            labelIdBuffer.clear();
        }

        List<Integer> batchLabels = new ArrayList<>();
        int firstSample = batchIdx * batchSize;
        for (int offset = 0; offset < batchSize; offset++) {
            Feature sample = features.get(firstSample + offset);
            for (int pos = 0; pos < dataSize; pos++) {
                inputIdBuffer.putInt(sample.inputIds[pos]);
            }
            for (int pos = 0; pos < dataSize; pos++) {
                tokenIdBuffer.putInt(sample.tokenIds[pos]);
            }
            for (int pos = 0; pos < dataSize; pos++) {
                maskIdBuffer.putInt(sample.inputMasks[pos]);
            }
            if (trainMod) {
                // NOTE(review): the label buffer is filled from inputIds, not from labelIds.
                for (int pos = 0; pos < dataSize; pos++) {
                    labelIdBuffer.putInt(sample.inputIds[pos]);
                }
            } else {
                batchLabels.add(sample.labelIds);
            }
        }

        // In train mode the label tensor occupies index 0 and shifts the other three by one.
        List<MSTensor> tensors = trainSession.getInputs();
        int base = trainMod ? 1 : 0;
        MSTensor labelTensor = null;
        if (trainMod) {
            labelTensor = tensors.get(0);
        }
        MSTensor tokenTensor = tensors.get(base);
        MSTensor inputTensor = tensors.get(base + 1);
        MSTensor maskTensor = tensors.get(base + 2);
        if (trainMod) {
            labelTensor.setData(labelIdBuffer);
        }
        tokenTensor.setData(tokenIdBuffer);
        inputTensor.setData(inputIdBuffer);
        maskTensor.setData(maskIdBuffer);
        return batchLabels;
    }

    /**
     * Appends randomly chosen existing samples until the sample count is a multiple of the batch
     * size, then updates {@code padSize}, {@code trainSampleSize} and {@code batchNum}.
     *
     * @return 0 on success, -1 when the batch size is not positive
     */
    @Override
    public int padSamples() {
        if (batchSize <= 0) {
            logger.severe(Common.addTag("batch size should bigger than 0"));
            return -1;
        }
        int originalCount = features.size();
        logger.info(Common.addTag("before pad samples size:" + originalCount));

        int leftover = originalCount % batchSize;
        if (leftover == 0) {
            padSize = 0;
        } else {
            padSize = batchSize - leftover;
        }
        int added = 0;
        while (added < padSize) {
            // Duplicates are drawn only from the samples that existed before padding.
            int pick = (int) (Math.random() * originalCount);
            features.add(features.get(pick));
            added++;
        }

        int paddedCount = features.size();
        trainSampleSize = paddedCount;
        batchNum = paddedCount / batchSize;
        logger.info(Common.addTag("after pad samples size:" + paddedCount));
        return 0;
    }
}


In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient.model;

import com.huawei.flclient.Common;

import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Inference/evaluation variant of {@link AdBert}: loads a data set, runs the graph batch by
 * batch and collects the predicted labels.
 */
public class AdInferBert extends AdBert {
    private static final Logger logger = Logger.getLogger(AdInferBert.class.toString());

    private static AdInferBert adInferBert;

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the single {@code AdInferBert} instance
     */
    public static synchronized AdInferBert getInstance() {
        if (adInferBert == null) {
            adInferBert = new AdInferBert();
        }
        return adInferBert;
    }

    /**
     * Loads the features used for evaluation or inference.
     *
     * @param exampleFile data file to read
     * @param vocabFile   vocabulary file
     * @param idsFile     vocabulary id file
     * @param evalMod     true to parse labelled lines via {@code DataSet.init}, false to read
     *                    unlabelled lines via {@code DataSet.readInferData}
     * @return number of loaded features, or -1 when loading failed
     */
    public int initDataSet(String exampleFile, String vocabFile, String idsFile, boolean evalMod) {
        if (evalMod) {
            features = DataSet.init(exampleFile, vocabFile, idsFile, false);
        } else {
            features = DataSet.readInferData(exampleFile, vocabFile, idsFile, false);
        }
        if (features == null) {
            logger.severe(Common.addTag("features cannot be null"));
            return -1;
        }
        return features.size();
    }

    /**
     * Runs every batch through the graph and gathers the per-batch labels.
     *
     * @return one label per (padded) feature, or an empty array on any failure
     */
    private int[] infer() {
        boolean success = trainSession.eval();
        if (!success) {
            logger.severe(Common.addTag("trainSession switch eval mode failed"));
            return new int[0];
        }
        int[] predictLabels = new int[features.size()];
        for (int j = 0; j < batchNum; j++) {
            fillModelInput(j, false);
            success = trainSession.runGraph();
            if (!success) {
                logger.severe(Common.addTag("run graph failed"));
                return new int[0];
            }
            // getBatchLabel is not declared in this class; presumably inherited from TrainModel.
            int[] batchLabels = getBatchLabel();
            System.arraycopy(batchLabels, 0, predictLabels, j * batchSize, batchSize);
        }
        return predictLabels;
    }

    /**
     * Loads the inference data, initialises the session and returns the predicted labels.
     *
     * @param modelPath path of the model
     * @param dataFile  file with one example per line
     * @param vocabFile vocabulary file
     * @param idsFile   vocabulary id file
     * @return the labels of the original (unpadded) examples, or an empty array on failure
     */
    public int[] inferModel(String modelPath, String dataFile, String vocabFile, String idsFile) {
        logger.info(Common.addTag("Infer model," + modelPath + ",Data file," + dataFile + ",vocab file," + vocabFile + ",idsFile," + idsFile));
        int inferSize = initDataSet(dataFile, vocabFile, idsFile, false);
        // initDataSet reports a load failure as -1; continuing would dereference a null
        // feature list in padSamples(), so negative sizes are rejected as well as zero.
        if (inferSize <= 0) {
            logger.severe(Common.addTag("infer size should more than 0"));
            return new int[0];
        }
        int status = initSessionAndInputs(modelPath, false);
        if (status == -1) {
            logger.severe(Common.addTag("init session and inputs failed"));
            return new int[0];
        }
        status = padSamples();
        if (status == -1) {
            logger.severe(Common.addTag("infer model failed"));
            return new int[0];
        }
        if (batchSize <= 0) {
            logger.severe(Common.addTag("batch size must bigger than 0"));
            return new int[0];
        }
        batchNum = features.size() / batchSize;
        int[] predictLabels = infer();
        if (predictLabels.length == 0) {
            return new int[0];
        }
        // Drop the predictions that belong to the padding samples appended by padSamples().
        return Arrays.copyOfRange(predictLabels, 0, inferSize);
    }
}


In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient.model;

import com.huawei.flclient.Common;

import java.util.logging.Logger;

/**
 * Training variant of {@link AdBert}; it only adds loading of the training data set.
 */
public class AdTrainBert extends AdBert {
    private static final Logger logger = Logger.getLogger(AdTrainBert.class.toString());

    private static AdTrainBert adTrainBert;

    /**
     * Returns the shared instance, creating it on the first call.
     *
     * @return the single {@code AdTrainBert} instance
     */
    public static synchronized AdTrainBert getInstance() {
        if (adTrainBert != null) {
            return adTrainBert;
        }
        adTrainBert = new AdTrainBert();
        return adTrainBert;
    }

    /**
     * Loads the training features through {@code DataSet.init} in train mode.
     *
     * @param dataFile  training data file
     * @param vocabFile vocabulary file
     * @param idsFile   vocabulary id file
     * @return number of loaded features, or -1 when loading failed
     */
    public int initDataSet(String dataFile, String vocabFile, String idsFile) {
        features = DataSet.init(dataFile, vocabFile, idsFile, true);
        if (features != null) {
            return features.size();
        }
        logger.severe(Common.addTag("features cannot be null"));
        return -1;
    }
}




In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient.model;

import com.huawei.flclient.Common;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.logging.Logger;

/**
 * Tokenizer that splits text into Chinese characters, punctuation and whitespace-separated words,
 * breaks the words into word pieces with a greedy longest-match against the vocabulary, and
 * converts the pieces into fixed-length {@link Feature} arrays.
 */
public class CustomTokenizer {
    private static final Logger logger = Logger.getLogger(CustomTokenizer.class.toString());
    private static final String CLS_TOKEN = "[CLS]";
    private static final String SEP_TOKEN = "[SEP]";

    /** Token text to vocabulary id, filled by {@link #init}. */
    private final Map<String, Integer> vocabs = new HashMap<>();
    // NOTE(review): doLowerCase, maxInputChars and NotSplitStrs are stored but never read by any
    // method of this class; in particular splitText does not lower-case its input.
    private Boolean doLowerCase = Boolean.TRUE;
    private int maxInputChars = 100;
    private String[] NotSplitStrs = {"UNK"};
    private String unkToken = "[UNK]";
    /** Length of every produced id/mask array; 16 by default, 256 after a non-train init. */
    private int maxSeqLen = 16;
    private int vocabSize = 11682;
    private final Map<String, Integer> labelMap = buildLabelMap();

    /** Builds the fixed label-name to class-index table. */
    private static Map<String, Integer> buildLabelMap() {
        Map<String, Integer> map = new HashMap<>();
        map.put("beauty", 0);
        map.put("education", 1);
        map.put("hotel", 2);
        map.put("travel", 3);
        map.put("other", 4);
        return map;
    }

    /** Reads all lines of a UTF-8 file; logs {@code failureMsg} and returns null on I/O failure. */
    private static List<String> readLines(String file, String failureMsg) {
        try {
            return Files.readAllLines(Paths.get(file), StandardCharsets.UTF_8);
        } catch (IOException e) {
            logger.severe(Common.addTag(failureMsg + e.getMessage()));
            return null;
        }
    }

    /**
     * Loads the vocabulary: line {@code i} of {@code vocabFile} is the token whose id is the
     * integer on line {@code i} of {@code idsFile}.
     *
     * <p>On a read failure the method logs and returns without touching the vocabulary or the
     * sequence length. Lines without a counterpart in the other file and ids that are not valid
     * integers are logged and skipped instead of aborting with an exception.
     *
     * @param vocabFile   path of the token file
     * @param idsFile     path of the id file
     * @param trainMod    false switches the maximum sequence length to 256
     * @param doLowerCase stored only; see the review note on the field
     */
    public void init(String vocabFile, String idsFile, boolean trainMod, boolean doLowerCase) {
        this.doLowerCase = doLowerCase;
        List<String> vocabLines = readLines(vocabFile, "read vocab file failed,");
        if (vocabLines == null) {
            logger.severe(Common.addTag("vocabLines cannot be null"));
            return;
        }
        List<String> idsLines = readLines(idsFile, "read ids file failed,");
        if (idsLines == null) {
            logger.severe(Common.addTag("idsLines cannot be null"));
            return;
        }
        if (vocabLines.size() != idsLines.size()) {
            logger.warning(Common.addTag("vocab file has " + vocabLines.size() + " lines but ids file has "
                    + idsLines.size() + ", extra lines are ignored"));
        }
        // Only pair up lines that exist in both files, so a longer ids file cannot index past
        // the end of the vocab lines.
        int pairCount = Math.min(vocabLines.size(), idsLines.size());
        for (int i = 0; i < pairCount; ++i) {
            try {
                vocabs.put(vocabLines.get(i), Integer.parseInt(idsLines.get(i).trim()));
            } catch (NumberFormatException e) {
                logger.severe(Common.addTag("invalid id at ids file line " + (i + 1) + "," + e.getMessage()));
            }
        }
        if (!trainMod) {
            maxSeqLen = 256;
        }
    }

    /**
     * Tells whether a character is a CJK ideograph in U+4E00..U+9FA5 or an ASCII punctuation mark.
     *
     * @param trimChar character to classify
     * @return true for a Chinese or punctuation character
     */
    public Boolean isChineseOrPunc(char trimChar) {
        // Chinese character range
        if (trimChar >= '\u4e00' && trimChar <= '\u9fa5') {
            return true;
        }
        // The four ASCII punctuation ranges: ! to /, : to @, [ to ` and { to ~
        return (trimChar >= 33 && trimChar <= 47) || (trimChar >= 58 && trimChar <= 64) || (trimChar >= 91 && trimChar
                <= 96) || (trimChar >= 123 && trimChar <= 126);
    }

    /**
     * Splits text on whitespace, with every Chinese or punctuation character becoming a token of
     * its own.
     *
     * @param text raw text; null is treated like the empty string
     * @return the tokens; an empty array for null or empty input
     */
    public String[] splitText(String text) {
        if (text == null || text.isEmpty()) {
            return new String[0];
        }
        String trimText = text.trim();
        StringBuilder cleanText = new StringBuilder();
        for (int i = 0; i < trimText.length(); i++) {
            char current = trimText.charAt(i);
            if (isChineseOrPunc(current)) {
                // Surround with spaces so the final whitespace split isolates the character.
                cleanText.append(" ").append(current).append(" ");
            } else {
                cleanText.append(current);
            }
        }
        return cleanText.toString().trim().split("\\s+");
    }

    /**
     * Breaks each token into vocabulary word pieces by greedy longest-prefix matching; pieces
     * after the first carry a {@code ##} prefix, e.g. "unaffable" becomes
     * ["un", "##aff", "##able"]. A token that cannot be fully covered is replaced by the unknown
     * token.
     *
     * @param tokens tokens produced by {@link #splitText}
     * @return the word pieces of all tokens, in order
     */
    public List<String> wordPieceTokenize(String[] tokens) {
        List<String> outputTokens = new ArrayList<>();
        for (String token : tokens) {
            List<String> subTokens = new ArrayList<>();
            boolean isBad = false;
            int start = 0;
            while (start < token.length()) {
                int end = token.length();
                String curStr = "";
                // Shrink the candidate from the right until it is found in the vocabulary.
                while (start < end) {
                    String subStr = token.substring(start, end);
                    if (start > 0) {
                        subStr = "##" + subStr;
                    }
                    if (vocabs.containsKey(subStr)) {
                        curStr = subStr;
                        break;
                    }
                    end = end - 1;
                }
                if (curStr.isEmpty()) {
                    isBad = true;
                    break;
                }
                subTokens.add(curStr);
                start = end;
            }
            if (isBad) {
                outputTokens.add(unkToken);
            } else {
                outputTokens.addAll(subTokens);
            }
        }
        return outputTokens;
    }

    /**
     * Truncates the tokens to {@code maxSeqLen - 2}, wraps them in [CLS] and [SEP] and maps them
     * to vocabulary ids. The passed list is not modified.
     *
     * @param tokens   word-piece tokens
     * @param cycTrunc when true, an over-long sequence is cut to a window starting at a random
     *                 index and wrapping around to the front; otherwise the first tokens are kept
     * @return the ids; tokens missing from the vocabulary get the id of "[UNK]" (null if the
     *         vocabulary has no such entry)
     */
    public List<Integer> convertTokensToIds(List<String> tokens, boolean cycTrunc) {
        int seqLen = tokens.size();
        int bodyLen = maxSeqLen - 2;
        // Work on a private copy: the previous implementation inserted into the caller's list
        // through subList views.
        List<String> window = new ArrayList<>(Math.min(seqLen, bodyLen) + 2);
        window.add(CLS_TOKEN);
        if (seqLen <= bodyLen) {
            window.addAll(tokens);
        } else if (cycTrunc) {
            int randIndex = (int) (Math.random() * seqLen);
            if (randIndex > seqLen - bodyLen) {
                // Not enough tokens after randIndex: take the tail, then wrap to the front.
                window.addAll(tokens.subList(randIndex, seqLen));
                window.addAll(tokens.subList(0, randIndex + bodyLen - seqLen));
            } else {
                window.addAll(tokens.subList(randIndex, randIndex + bodyLen));
            }
        } else {
            window.addAll(tokens.subList(0, bodyLen));
        }
        window.add(SEP_TOKEN);

        Integer unkId = vocabs.get("[UNK]");
        List<Integer> ids = new ArrayList<>(window.size());
        for (String token : window) {
            ids.add(vocabs.getOrDefault(token, unkId));
        }
        return ids;
    }

    /**
     * Applies random masking/replacement to {@code feature.inputIds} in place, computing each
     * position as {@code inputIds[i] * masks[i] + replaces[i]}.
     *
     * <p>NOTE(review): the logic is kept exactly as found because the intended behaviour cannot
     * be confirmed from this file. As written, {@code keepFirstUnchange} is tested on every
     * iteration rather than only for position 0, so when it is true every position ends up with
     * mask 1 and replacement 0 and the ids are left unchanged. When it is false, positions that
     * are not selected keep the initial replacement value of 1 and are therefore incremented by
     * one. Both look unintended; please confirm against the reference implementation.
     *
     * @param feature           feature whose first {@code seqLen} ids are processed
     * @param keepFirstUnchange see the review note above
     * @param keepLastUnchange  forces position {@code seqLen - 1} to mask 1 / replacement 0
     */
    public void addRandomMaskAndReplace(Feature feature, boolean keepFirstUnchange, boolean keepLastUnchange) {
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 1);
        int[] replaces = new int[maxSeqLen];
        Arrays.fill(replaces, 1);
        int[] inputIds = feature.inputIds;
        for (int i = 0; i < feature.seqLen; i++) {
            double rand1 = Math.random();
            if (rand1 < 0.15) {
                masks[i] = 0;
                double rand2 = Math.random();
                if (rand2 < 0.8) {
                    replaces[i] = 103; // presumably the [MASK] token id -- TODO confirm
                } else if (rand2 < 0.9) {
                    masks[i] = 1;
                } else {
                    replaces[i] = (int) (Math.random() * vocabSize);
                }
            }
            if (keepFirstUnchange) {
                masks[i] = 1;
                replaces[i] = 0;
            }
            if (keepLastUnchange) {
                masks[feature.seqLen - 1] = 1;
                replaces[feature.seqLen - 1] = 0;
            }
            inputIds[i] = inputIds[i] * masks[i] + replaces[i];
        }
    }

    /**
     * Builds a {@link Feature} of length {@code maxSeqLen} from token ids: ids are zero-padded,
     * the mask is 1 for real tokens and 0 for padding, and all segment ids are 0.
     *
     * @param tokens ids from {@link #convertTokensToIds}; must not exceed {@code maxSeqLen}
     * @param label  label name looked up in the label table; an unknown name yields null
     * @return the assembled feature
     */
    public Feature getFeatures(List<Integer> tokens, String label) {
        // Freshly allocated int arrays are already zero-filled.
        int[] segmentIds = new int[maxSeqLen];
        int[] masks = new int[maxSeqLen];
        Arrays.fill(masks, 0, tokens.size(), 1);
        int[] inputIds = new int[maxSeqLen];
        for (int i = 0; i < tokens.size(); i++) {
            inputIds[i] = tokens.get(i);
        }
        return new Feature(inputIds, masks, segmentIds, labelMap.get(label), tokens.size());
    }

    /**
     * Runs the full pipeline: split, word-piece tokenize and convert to ids.
     *
     * @param text     raw text
     * @param trainMod true enables the cyclic random truncation of over-long sequences
     * @return the token ids including [CLS] and [SEP]
     */
    public List<Integer> tokenize(String text, boolean trainMod) {
        String[] splitTokens = splitText(text);
        List<String> wordPieceTokens = wordPieceTokenize(splitTokens);
        return convertTokensToIds(wordPieceTokens, trainMod);
    }
}



In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient.model;

import com.huawei.flclient.Common;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * Static helpers that read data files and turn their lines into {@link Feature} lists.
 */
public class DataSet {
    private static final Logger logger = Logger.getLogger(DataSet.class.toString());

    /** Returns true when the string is null or has length zero. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Reads a labelled data file. A line is used only when splitting it on {@code >>>} gives
     * exactly two parts (header and example text) and splitting the header on {@code <<<} gives
     * exactly two parts, the second being the label. Other lines are logged and skipped.
     *
     * @param trainFile labelled data file
     * @param vocabFile vocabulary file
     * @param idsFile   vocabulary id file
     * @param trainMod  passed to the tokenizer; when true random masking is applied as well
     * @return the features, or null when an argument is null/empty or the file cannot be read
     */
    public static List<Feature> init(String trainFile, String vocabFile, String idsFile, boolean trainMod) {
        if (isNullOrEmpty(trainFile) || isNullOrEmpty(vocabFile) || isNullOrEmpty(idsFile)) {
            logger.severe(Common.addTag("dataset init failed,trainFile,idsFile,vocabFile cannot be empty"));
            return null;
        }
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, trainMod, true);
        List<String> allLines = readTxtFile(trainFile);
        if (allLines == null) {
            logger.severe(Common.addTag("all lines cannot be null"));
            return null;
        }
        List<Feature> features = new ArrayList<>(allLines.size());
        for (String line : allLines) {
            String[] parts = line.split(">>>");
            if (parts.length != 2) {
                logger.warning(Common.addTag("line may have format problem,need include >>>"));
                continue;
            }
            String[] header = parts[0].split("<<<");
            if (header.length != 2) {
                logger.warning(Common.addTag("line may have format problem,need include <<<"));
                continue;
            }
            // Both parts are validated before either is used, so an example can never be kept
            // without its label (which previously misaligned examples and labels).
            List<Integer> tokens = customTokenizer.tokenize(parts[1], trainMod);
            Feature feature = customTokenizer.getFeatures(tokens, header[1]);
            if (trainMod) {
                customTokenizer.addRandomMaskAndReplace(feature, true, true);
            }
            features.add(feature);
        }
        return features;
    }

    /**
     * Reads an unlabelled data file with one example per line; empty lines are skipped and every
     * feature gets the label "other".
     *
     * @param inferFile data file
     * @param vocabFile vocabulary file
     * @param idsFile   vocabulary id file
     * @param trainMod  passed to {@code tokenize}; the tokenizer itself is always initialised
     *                  in non-train mode
     * @return the features, or null when an argument is null/empty or the file cannot be read
     */
    public static List<Feature> readInferData(String inferFile, String vocabFile, String idsFile, boolean trainMod) {
        if (isNullOrEmpty(inferFile) || isNullOrEmpty(vocabFile) || isNullOrEmpty(idsFile)) {
            logger.severe(Common.addTag("dataset init failed,trainFile,idsFile,vocabFile cannot be empty"));
            return null;
        }
        CustomTokenizer customTokenizer = new CustomTokenizer();
        customTokenizer.init(vocabFile, idsFile, false, true);
        List<String> allLines = readTxtFile(inferFile);
        if (allLines == null) {
            logger.severe(Common.addTag("all lines cannot be null"));
            return null;
        }
        List<Feature> features = new ArrayList<>(allLines.size());
        for (String line : allLines) {
            if (line.isEmpty()) {
                continue;
            }
            List<Integer> tokens = customTokenizer.tokenize(line, trainMod);
            Feature feature = customTokenizer.getFeatures(tokens, "other");
            features.add(feature);
        }
        return features;
    }

    /**
     * Reads a whole file into memory.
     *
     * @param dataFile path of the file
     * @return the file content, or null when reading failed
     */
    public static byte[] readBinFile(String dataFile) {
        Path path = Paths.get(dataFile);
        byte[] data = null;
        try {
            data = Files.readAllBytes(path);
        } catch (IOException e) {
            logger.severe(Common.addTag("read bin file failed," + e.getMessage()));
        }
        return data;
    }

    /**
     * Reads all lines of a UTF-8 text file.
     *
     * @param file path of the file
     * @return the lines, or null when reading failed
     */
    private static List<String> readTxtFile(String file) {
        Path path = Paths.get(file);
        List<String> allLines = null;
        try {
            allLines = Files.readAllLines(path, StandardCharsets.UTF_8);
        } catch (IOException e) {
            logger.severe(Common.addTag("read file failed," + e.getMessage()));
        }
        return allLines;
    }
}

