<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
//
// Created by meng on 3/31/21.
//
#include <iostream>
#include "bert_train.h"
#include "lenet_train.h"
#include "util.h"
// Demo driver that exercises the training helpers declared in lenet_train.h,
// bert_train.h and util.h:
//   1. loads LeNet training data, trains, lists the session's features,
//      then loads the test data and prints the inference result;
//   2. loads the BERT/ALBERT data set, trains, and lists the session's features.
// Returns 0 in every path, matching the original driver.
int main() {
  std::cout << "----------begin train lenet-------" << std::endl;
  std::string lenet_ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/lenet_train.mindir.ms";
  std::string lenet_data_input =
    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
    "f0049_32_bn_11_train_data.bin";
  std::string lenet_label_input =
    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
    "f0049_32_bn_11_train_label.bin";

  auto lenet_input_size = SetLenetInputs(lenet_data_input, lenet_label_input);
  std::cout << "total train size:" << lenet_input_size << std::endl;
  auto session = CreateSession(lenet_ms_file);
  // Presumably batch size 32 and 2 epochs (by analogy with the TrainBert call below) -- TODO confirm.
  auto status = TrainLenet(session, lenet_ms_file, 32, 2);
  if (status != 0) {
    std::cout << "train failed" << std::endl;
  }

  // Initialised so that a failing GetFeatures can never leave a garbage pointer behind.
  // NOTE(review): ownership of the returned array is not visible here; confirm whether it must be released.
  mindspore::session::TrainFeatureParam **feature = nullptr;
  int size = 0;
  status = GetFeatures(session, &feature, &size);
  if (status != 0) {
    std::cout << "get feature failed" << std::endl;
  }
  std::cout << "get total features:" << size << std::endl;
  // Only walk the array when the call succeeded and actually produced one.
  if (status == 0 && feature != nullptr) {
    for (int i = 0; i < size; i++) {
      std::cout << "name:" << feature[i]->name << std::endl;
    }
  }

  std::string lenet_test_data = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_data.bin";
  std::string lenet_test_label = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/f0049_32_bn_1_test_label.bin";

  // Swap the inputs over to the test set before running inference.
  (void)SetLenetInputs(lenet_test_data, lenet_test_label);
  std::cout << "cal acc:" << InferLenet(session) << std::endl;
  auto infer_result = GetLenetInferRes(session);
  for (auto infer_label : infer_result) {
    std::cout << "infer_label:" << infer_label << std::endl;
  }
  delete session;

  std::cout << "----------begin train bert-------" << std::endl;
  std::string vocab_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
  std::string train_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/0.tsv";
  std::string labels_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/label.tsv";
  auto train_data_size = SetBertInputs(train_file, vocab_file, labels_file);
  std::cout << "total train data size:" << train_data_size << std::endl;
  if (train_data_size == -1) {
    std::cout << "set bert inputs failed" << std::endl;
    return 0;
  }
  std::string ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert.ms";
  int epoches = 2;
  int batch_size = 16;
  session = CreateSession(ms_file);
  // Pass the named batch size instead of repeating the literal 16.
  status = TrainBert(session, ms_file, batch_size, epoches);
  if (status != 0) {
    std::cout << "train failed" << std::endl;
  }

  // Reset the outputs so results from the LeNet run cannot be re-read on failure.
  feature = nullptr;
  size = 0;
  status = GetFeatures(session, &feature, &size);
  if (status != 0) {
    std::cout << "get feature failed" << std::endl;
  }
  std::cout << "get total features:" << size << std::endl;
  if (status == 0 && feature != nullptr) {
    for (int i = 0; i < size; i++) {
      std::cout << "name:" << feature[i]->name << std::endl;
    }
  }
  delete session;
  return 0;
}

In [None]:
package com.huawei.flclient;


import com.google.flatbuffers.FlatBufferBuilder;
import mindspore.schema.FeatureMap;
import mindspore.schema.FeatureMapList;
import mindspore.schema.RequestUpdateModel;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

/**
 * Singleton wrapper around the {@link NativeTrain} JNI bindings.
 *
 * <p>The instance stores the session handle returned by
 * {@link NativeTrain#createSession(String, long)} and forwards it to every
 * session-based native call. Only {@link #getInstance()} is synchronized; the
 * other methods perform no locking of their own.
 */
public class LiteTrain {
    private static LiteTrain train;

    /** Handle returned by {@code NativeTrain.createSession}; stays 0 until {@link #init(String)} runs. */
    private long sessionPtr;

    private LiteTrain() {
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the single {@code LiteTrain} instance
     */
    public static synchronized LiteTrain getInstance() {
        if (train == null) {
            train = new LiteTrain();
        }
        return train;
    }

    /**
     * Builds a finished {@code RequestUpdateModel} FlatBuffer whose feature-map
     * vector is produced by {@link #getSeralizeFeaturesMap(FlatBufferBuilder)}.
     * The flName, flId, timestamp and iteration fields are all passed as 0.
     *
     * @param modelName currently unused by this method
     * @return the builder after {@code finish} has been called on the root table
     */
    public FlatBufferBuilder FeatureMapBuilder(String modelName) {
        FlatBufferBuilder builder = new FlatBufferBuilder();
        // Nested objects (the feature maps and their vector) are created
        // before the enclosing table is started.
        int[] fmOffsets = getSeralizeFeaturesMap(builder);
        int fmOffset = FeatureMapList.createFeatureMapVector(builder, fmOffsets);
        RequestUpdateModel.startRequestUpdateModel(builder);
        RequestUpdateModel.addFlName(builder, 0);
        RequestUpdateModel.addFlId(builder, 0);
        RequestUpdateModel.addTimestamp(builder, 0);
        RequestUpdateModel.addIteration(builder, 0);
        RequestUpdateModel.addFeatureMap(builder, fmOffset);
        int root = RequestUpdateModel.endRequestUpdateModel(builder);
        builder.finish(root);
        return builder;
    }

    /**
     * Serializes the current features via {@link #FeatureMapBuilder(String)},
     * reads them back, and collects the weights of every feature whose name
     * passes {@link #isKeptFeature(String)} into a map sorted by feature name.
     * Progress is printed to standard output before and after sorting.
     *
     * @param modelName forwarded to {@link #FeatureMapBuilder(String)}
     * @return feature name to weight array, ordered by name
     */
    public Map<String, float[]> deserializeFeatureMap(String modelName) {
        Map<String, float[]> map = new TreeMap<>();
        FlatBufferBuilder featureBuilder = FeatureMapBuilder(modelName);
        ByteBuffer buf = featureBuilder.dataBuffer();
        RequestUpdateModel dataBuf = RequestUpdateModel.getRootAsRequestUpdateModel(buf);
        int featureSize = dataBuf.featureMapLength();
        for (int i = 0; i < featureSize; i++) {
            FeatureMap featureMap = dataBuf.featureMap(i);
            String dataName = featureMap.weightFullname();
            if (!isKeptFeature(dataName)) {
                continue;
            }
            int dataLen = featureMap.dataLength();
            System.out.println("[train] before sort-->feature name: " + dataName + " feature size: " + dataLen);
            // Allocate only for features that are kept.
            float[] weights = new float[dataLen];
            for (int j = 0; j < dataLen; j++) {
                weights[j] = featureMap.data(j);
            }
            map.put(dataName, weights);
        }
        // todo for test
        for (Map.Entry<String, float[]> entry : map.entrySet()) {
            System.out.println("[train] after sort-->feature name: " + entry.getKey() + " feature size: " + entry.getValue().length);
        }
        return map;
    }

    /**
     * Decides whether a feature is collected by {@link #deserializeFeatureMap(String)}:
     * the name must be non-null and must not contain "Default", "nhwc",
     * "moment" or "learning".
     */
    private static boolean isKeptFeature(String name) {
        return name != null
                && !name.contains("Default")
                && !name.contains("nhwc")
                && !name.contains("moment")
                && !name.contains("learning");
    }

    // todo add msconfig
    /**
     * Creates the native session for the given model and stores its handle.
     *
     * @param modelPath path passed to {@code NativeTrain.createSession}
     * @return always 0
     */
    public int init(String modelPath) {
        // NOTE(review): the returned handle is not checked; presumably 0 means
        // session creation failed -- confirm against the native implementation.
        sessionPtr = NativeTrain.createSession(modelPath, 0L);
        return 0;
    }

    /**
     * Forwards to {@code NativeTrain.setInput}.
     *
     * @param fileSet argument passed through unchanged
     * @return the native call's result
     */
    public int setInput(String fileSet) {
        return NativeTrain.setInput(fileSet);
    }

    /**
     * Forwards to {@code NativeTrain.train} using the stored session handle.
     *
     * @param batchSize    passed through unchanged
     * @param epoches      passed through unchanged
     * @param earlyStopMod passed through unchanged
     * @return the native call's result
     */
    public int train(int batchSize, int epoches, int earlyStopMod) {
        return NativeTrain.train(sessionPtr, batchSize, epoches, earlyStopMod);
    }

    /**
     * Forwards to {@code NativeTrain.infer} using the stored session handle.
     *
     * @return the native call's result
     */
    public float infer() {
        return NativeTrain.infer(sessionPtr);
    }

    /**
     * Forwards to {@code NativeTrain.getInferLables} using the stored session handle.
     *
     * @return the native call's result
     */
    public int[] getInferLabels() {
        return NativeTrain.getInferLables(sessionPtr);
    }

    /**
     * Forwards to {@code NativeTrain.getFeaturesMap} using the stored session handle.
     *
     * @return the native call's result
     */
    public Map<String, float[]> getFeaturesMap() {
        return NativeTrain.getFeaturesMap(sessionPtr);
    }

    /**
     * Forwards to {@code NativeTrain.getSeralizeFeaturesMap} using the stored session handle.
     *
     * @param builder builder handed to the native side
     * @return the native call's result; used by {@link #FeatureMapBuilder(String)}
     *         as the offsets of the feature-map vector
     */
    public int[] getSeralizeFeaturesMap(FlatBufferBuilder builder) {
        return NativeTrain.getSeralizeFeaturesMap(sessionPtr, builder);
    }

    /**
     * Forwards to {@code NativeTrain.updateFeatures} using the stored session handle.
     *
     * @param featureMaps passed through unchanged
     * @return the native call's result
     */
    public int updateFeatures(ArrayList<FeatureMap> featureMaps) {
        return NativeTrain.updateFeatures(sessionPtr, featureMaps);
    }

    /**
     * Forwards to {@code NativeTrain.free} using the stored session handle.
     *
     * @return the native call's result
     */
    public int free() {
        // NOTE(review): sessionPtr keeps its old value afterwards; if the native
        // side releases the session, later calls would use a stale handle -- confirm.
        return NativeTrain.free(sessionPtr);
    }

}


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.huawei.flclient;

import com.google.flatbuffers.FlatBufferBuilder;
import mindspore.schema.FeatureMap;
import java.util.ArrayList;
import java.util.Map;

/**
 * JNI entry points implemented by the native library {@code fl}, which is
 * loaded once when this class is initialized. All methods are package-private.
 */
public class NativeTrain {
    static {
        System.loadLibrary("fl");
    }

    /** Native binding; semantics of {@code fileSet} and the return code are defined by the native side. */
    static native int setInput(String fileSet);

    /** Native binding; the returned value is used by callers as the {@code sessionPtr} of the other methods. */
    static native long createSession(String modelPath, long msConfigPtr);

    /** Native binding operating on the session identified by {@code sessionPtr}. */
    static native float infer(long sessionPtr);

    /** Native binding operating on the session identified by {@code sessionPtr}. */
    static native int[] getInferLables(long sessionPtr);

    /** Native binding operating on the session identified by {@code sessionPtr}. */
    static native int train(long sessionPtr, int batch_size, int epoches, int earlyStopMod);

    /** Native binding; receives the builder in addition to the session handle. */
    static native int[] getSeralizeFeaturesMap(long sessionPtr, FlatBufferBuilder builder);

    /** Native binding operating on the session identified by {@code sessionPtr}. */
    static native Map<String, float[]> getFeaturesMap(long sessionPtr);

    /** Native binding; receives the feature maps in addition to the session handle. */
    static native int updateFeatures(long sessionPtr, ArrayList<FeatureMap> featureMaps);

    /** Native binding; presumably releases the session -- TODO confirm against the native code. */
    static native int free(long sessionPtr);
}
