<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings.
- Select TPU from the Hardware Accelerator drop-down.

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
cmake_minimum_required(VERSION 3.14)

project(FederalLearning)

# Build switches. BUILD_LITE / SUPPORT_TRAIN / PLATFORM_ARM are forced on here.
# NOTE(review): none of these variables is read later in this file — confirm they are consumed
# by other MindSpore build logic, otherwise setting them has no effect.
option(SUPPORT_GPU "if support gpu" off)
set(BUILD_LITE "on")
set(SUPPORT_TRAIN "on")
set(PLATFORM_ARM "on")

# TOP_DIR is six directory levels above this directory; judging by LITE_DIR below it is the
# MindSpore source root.
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
# Each MS_VERSION_* variable is assigned its own current value, so the version numbers must be
# supplied from outside (e.g. -DMS_VERSION_MAJOR=1 on the cmake command line). They are then
# forwarded to both the C and the C++ compilers as preprocessor macros.
set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
  -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
# The sources use C++17 features.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")

# Headers of this client (top level, linux/ and dataset/).
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/linux)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dataset)


# Headers from the MindSpore source tree.
include_directories(${LITE_DIR}) ## lite include
include_directories(${TOP_DIR}) ## api include
include_directories(${TOP_DIR}/mindspore/core/) ## core include
include_directories(${LITE_DIR}/build) ## flatbuffers

# Sources compiled into both the shared library `fl` and the `test` executable.
set(OP_SRC
        lite_train_jni.cpp
        util.cpp
        bert_train.cpp
        lenet_train.cpp
        dataset/CustomizedTokenizer.cc
            )
# NOTE(review): the result variable `log-lib` is never used; `test` links glog by name instead.
find_library(log-lib glog)

add_library(fl SHARED ${OP_SRC})

# link_directories only applies to targets created after this call, so it affects `test`
# below but not `fl` above.
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib/)

# Installs the shared library into the source tree's lib/ directory.
install(TARGETS fl LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/lib)
# Standalone driver (test_train.cc) that recompiles OP_SRC and links MindSpore Lite and glog.
add_executable(test test_train.cc ${OP_SRC} )
target_link_libraries(test mindspore-lite  glog)

In [None]:
//
// Created by meng on 3/31/21.
//
#include <iostream>
#include "bert_train.h"
#include "lenet_train.h"
#include "util.h"
// Smoke-test driver: trains LeNet on a fixed local dataset, then prints the names of the
// feature maps exposed by the trained session.
// Returns 0 on success and 1 if session creation, training or feature collection fails.
// NOTE(review): the model and data paths below are specific to one developer machine.
int main() {
  std::cout << "----------begin train lenet-------" << std::endl;
  std::string lenet_ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/lenet_train.mindir.ms";
  std::string lenet_data_input =
    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
    "f0049_32_bn_11_train_data.bin";
  std::string lenet_label_input =
    "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/test/resources/data_bin_new/f0049_32/"
    "f0049_32_bn_11_train_label.bin";

  int exit_code = 0;
  // NOTE(review): the meaning of SetLenetInputs' return value is defined in lenet_train.cpp
  // (not visible here); it is currently not checked.
  [[maybe_unused]] auto lenet_input_size = SetLenetInputs(lenet_data_input, lenet_label_input);
  auto session = CreateSession(lenet_ms_file);
  if (session == nullptr) {
    std::cout << "create session failed" << std::endl;
    return 1;
  }
  auto status = TrainLenet(session, lenet_ms_file, 32, 500);
  if (status != 0) {
    std::cout << "train failed" << std::endl;
    exit_code = 1;
  }
  // GetFeatures hands back a new[]-allocated array of pointers that this function must delete[].
  mindspore::session::TrainFeatureParam **feature = nullptr;
  int size = 0;
  status = GetFeatures(session, &feature, &size);
  if (status != 0) {
    std::cout << "get features failed" << std::endl;
    exit_code = 1;
  } else {
    std::cout << "get total features:" << size << std::endl;
    for (int i = 0; i < size; i++) {
      std::cout << "name:" << feature[i]->name << std::endl;
    }
  }
  delete[] feature;  // releases only the pointer array; the elements were not allocated by GetFeatures
  delete session;


//  std::cout << "----------begin train bert-------" << std::endl;
//  std::string vocab_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/vocab.txt";
//  std::string train_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/0.tsv";
//  std::string labels_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/flclient/src/main/native/dataset/label.tsv";
//  auto train_data_size = SetBertInputs(train_file, vocab_file, labels_file);
//  std::cout << "total train data size:" << train_data_size << std::endl;
//
//  std::string ms_file = "/home/meng/zj10/fl/mindspore/mindspore/lite/albert.ms";
//  int epoches = 2;
//  int batch_size = 16;
//  session = CreateSession(ms_file);
//  status = TrainBert(session, ms_file, 16, epoches);
//  if (status != 0) {
//    std::cout << "train failed" << std::endl;
//  }
//  size = 0;
//  status = GetFeatures(session, &feature, &size);
//  std::cout << "get total features:" << size << std::endl;
//  for (int i = 0; i < size; i++) {
//    std::cout << "name:" << feature[i]->name << std::endl;
//  }
//  delete session;
  return exit_code;
}

In [None]:
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "util.h"
#include <cstring>
#include <iostream>
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"

// Returns the first output tensor of `train_session` whose element count equals `size`,
// or nullptr (after logging) when the session is null or no output matches.
// "First" follows the iteration order of the container returned by GetOutputs(), so callers
// should pick a size that is unique among the model's outputs.
mindspore::tensor::MSTensor *SearchOutputsForSize(mindspore::session::TrainSession *train_session, size_t size) {
  if (train_session == nullptr) {
    MS_LOG(ERROR) << "train session is null";
    return nullptr;
  }
  auto outputs = train_session->GetOutputs();
  for (const auto &entry : outputs) {
    auto *tensor = entry.second;
    if (tensor == nullptr) {
      continue;
    }
    // ElementsNum() is signed: reject negative counts before comparing with the unsigned size.
    const auto elements = tensor->ElementsNum();
    if (elements >= 0 && static_cast<size_t>(elements) == size) {
      return tensor;
    }
  }
  MS_LOG(ERROR) << "Model does not have an output tensor with size:" << size;
  return nullptr;
}

// Returns the training loss, taken as element 0 of the session's single-element output tensor.
// Returns the sentinel 10000 when no such output exists or its data buffer is null.
// NOTE(review): assumes the loss tensor holds float32 data — confirm against the model.
float GetLoss(mindspore::session::TrainSession *train_session) {
  constexpr float kMissingLoss = 10000;
  auto loss_tensor = SearchOutputsForSize(train_session, 1);  // the loss is the only single-value output
  if (loss_tensor == nullptr) {
    return kMissingLoss;
  }
  auto loss = reinterpret_cast<float *>(loss_tensor->MutableData());
  if (loss == nullptr) {
    MS_LOG(ERROR) << "loss tensor has no data";
    return kMissingLoss;
  }
  return loss[0];
}
// Builds a TrainSession for the model file `ms_file`. It runs on one CPU thread with no core
// binding and is created with train mode off; callers switch modes with Train()/Eval().
mindspore::session::TrainSession *CreateSession(const std::string &ms_file) {
  mindspore::lite::Context ctx;
  ctx.thread_num_ = 1;
  auto &cpu_info = ctx.device_list_[0].device_info_.cpu_device_info_;
  cpu_info.cpu_bind_mode_ = mindspore::lite::NO_BIND;
  constexpr bool kTrainMode = false;
  return mindspore::session::TrainSession::CreateSession(ms_file, &ctx, kTrainMode);
}

// Runs one forward pass in eval mode and returns the fraction of the batch whose highest-scoring
// class equals the matching entry of `labels`. Ties go to the lowest class index.
// The batch size is dim 0 of input tensor 1. The scores are read as float from the output with
// batch_size * num_of_class elements, laid out row by row (one row of num_of_class per sample).
// Samples without a corresponding label count as misses.
// Returns 0 (after logging) when the session, inputs or score output needed for scoring are unavailable.
float CalculateAccuracy(mindspore::session::TrainSession *session, const std::vector<int> &labels, int num_of_class) {
  if (session == nullptr || num_of_class <= 0) {
    MS_LOG(ERROR) << "invalid session or class number:" << num_of_class;
    return 0.0f;
  }
  session->Eval();
  if (session->RunGraph() != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "run graph for accuracy failed";
    return 0.0f;
  }
  auto inputs = session->GetInputs();
  if (inputs.size() < 2 || inputs[1] == nullptr || inputs[1]->shape().empty()) {
    MS_LOG(ERROR) << "model inputs do not provide a batch dimension";
    return 0.0f;
  }
  const int batch_size = inputs[1]->shape()[0];
  if (batch_size <= 0) {
    MS_LOG(ERROR) << "invalid batch size:" << batch_size;
    return 0.0f;
  }
  auto outputsv = SearchOutputsForSize(session, static_cast<size_t>(batch_size) * num_of_class);
  if (outputsv == nullptr) {
    return 0.0f;
  }
  auto scores = reinterpret_cast<float *>(outputsv->MutableData());
  if (scores == nullptr) {
    MS_LOG(ERROR) << "score tensor has no data";
    return 0.0f;
  }
  float accuracy = 0.0f;
  for (int b = 0; b < batch_size; b++) {
    const float *row = scores + static_cast<size_t>(b) * num_of_class;
    int max_idx = 0;
    for (int c = 1; c < num_of_class; c++) {
      if (row[c] > row[max_idx]) {
        max_idx = c;
      }
    }
    if (static_cast<size_t>(b) < labels.size() && labels[b] == max_idx) {
      accuracy += 1.0f;
    }
  }
  return accuracy / batch_size;
}

// Switches `train_session` to eval mode and passes `size` entries of `new_features`, together
// with `update_ms_file`, to UpdateFeatureMaps.
// Takes ownership of `train_session`: it is deleted whether or not the update succeeded.
// Returns the status reported by UpdateFeatureMaps.
int UpdateFeatures(TrainSession *train_session, const std::string &update_ms_file, TrainFeatureParam *new_features,
                   int size) {
  train_session->Eval();
  const auto ret = train_session->UpdateFeatureMaps(update_ms_file, new_features, size);
  const bool failed = ret != mindspore::lite::RET_OK;
  if (failed) {
    MS_LOG(ERROR) << "update model feature map failed" << update_ms_file;
  }
  delete train_session;  // ownership ends here on every path
  return ret;
}

// Collects the session's feature maps into a newly allocated array of pointers.
// On success, *feature points to a new[]-allocated array of the pointers returned by
// GetFeatureMaps and *size holds its length. The caller must delete[] the array; the elements
// themselves are not allocated here.
// The outputs are reset to nullptr / 0 on entry, so on failure they never hold stale values.
// Returns RET_OK, or RET_ERROR on invalid arguments, GetFeatureMaps failure or allocation failure.
int GetFeatures(TrainSession *train_session, mindspore::session::TrainFeatureParam ***feature,
                int *size) {
  if (train_session == nullptr || feature == nullptr || size == nullptr) {
    MS_LOG(ERROR) << "invalid arguments for get features";
    return mindspore::lite::RET_ERROR;
  }
  *feature = nullptr;
  *size = 0;

  std::vector<mindspore::session::TrainFeatureParam *> new_features;
  auto status = train_session->GetFeatureMaps(&new_features);
  if (status != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "get model feature map failed";
    return mindspore::lite::RET_ERROR;
  }
  const size_t count = new_features.size();
  auto *array = new (std::nothrow) TrainFeatureParam *[count];
  if (array == nullptr) {
    MS_LOG(ERROR) << "create features failed";
    return mindspore::lite::RET_ERROR;
  }
  for (size_t i = 0; i < count; ++i) {
    array[i] = new_features[i];
  }
  *feature = array;
  *size = static_cast<int>(count);
  return mindspore::lite::RET_OK;
}

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MSLITE_FL_UTIL_H
#define MSLITE_FL_UTIL_H

#include <string>
#include <vector>  // std::vector appears in CalculateAccuracy's signature
#include "include/train/train_session.h"

// NOTE(review): using-declarations in a header inject these names into every includer. They are
// kept because other translation units (e.g. bert_train.cpp) rely on the unqualified names.
using mindspore::session::TrainFeatureParam;
using mindspore::session::TrainSession;
using mindspore::tensor::MSTensor;

// Copies the session's feature-map pointers into a new[]-allocated array returned through
// *features, with its length in *size. The caller must delete[] the array; this function does
// not allocate the elements.
int GetFeatures(TrainSession *train_session, TrainFeatureParam ***features, int *size);
// Puts the session in eval mode, applies `size` entries of `new_features` via UpdateFeatureMaps,
// then deletes `train_session` regardless of the outcome.
int UpdateFeatures(TrainSession *train_session, const std::string &update_ms_file, TrainFeatureParam *new_features,
                   int size);
// Creates a TrainSession for `ms_file` using one CPU thread with no core binding.
TrainSession *CreateSession(const std::string &ms_file);
// Returns the first output tensor with exactly `size` elements, or nullptr if none matches.
MSTensor *SearchOutputsForSize(TrainSession *train_session, size_t size);
// Returns the value of the single-element output tensor (the loss), or 10000 if it is missing.
float GetLoss(TrainSession *train_session);
// Evaluates one batch and returns the fraction of samples whose arg-max class matches `labels`.
float CalculateAccuracy(TrainSession *session, const std::vector<int> &labels, int num_of_class);
#endif  // MSLITE_FL_UTIL_H


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MSLITE_FL_BERT_TRAIN_H
#define MSLITE_FL_BERT_TRAIN_H

#include <string>
#include "include/train/train_session.h"

// Loads "<label>\t<text>" lines from `train_file`, tokenizes the text with `vocab_file`, and maps
// each label to its line index in `labels_file`. Keeps only whole batches in module-level buffers.
// Returns the number of samples kept, or -1 on error.
int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file);

// Releases the buffers filled by SetBertInputs().
void FreeBertInput();

// Trains `session` for `epoches` epochs over the loaded data, then saves the model to `ms_file`.
// Returns RET_OK on success.
int TrainBert(mindspore::session::TrainSession *session, const std::string &ms_file, int batch_size, int epoches);

// Evaluates the first loaded batch and returns its accuracy.
float InferBert(mindspore::session::TrainSession *session);

#endif  // MSLITE_FL_BERT_TRAIN_H


In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "bert_train.h"
#include <climits>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "util.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "dataset/CustomizedTokenizer.h"

// Module-level BERT training data, filled by SetBertInputs() and released by FreeBertInput().
// The three per-token buffers hold MAX_SEQ_LENGTH ints per sample; lable_ids holds one per sample.
static int *lable_ids = nullptr;       // class index of each sample (spelling kept for existing uses)
static int *input_ids = nullptr;       // tokenizer input ids
static int *input_mask = nullptr;      // tokenizer attention mask
static int *token_type_ids = nullptr;  // tokenizer token type ids
static int batch_num = 0;              // batches per epoch, recomputed by TrainBert()
static int total_size = 0;             // number of samples recorded by SetBertInputs()

#define MAX_SEQ_LENGTH 32  // tokens per sample
#define BATCH_SIZE 16      // batch size used to decide how many samples are buffered
#define LABEL_CLASS 107    // number of label classes scored by the model

// Copies batch `batch_idx` of the module-level BERT data into the session's input tensors and
// returns that batch's labels.
// The batch size is dim 0 of input 0. For each input, one batch spans that input's element count.
// NOTE(review): the input order (0 = mask, 1 = ids, 2 = token types, 3 = labels) is assumed —
// confirm it against the exported model.
// NOTE(review): there is no bounds check; batch_idx must address rows loaded by SetBertInputs().
std::vector<int> FillBertInputData(mindspore::session::TrainSession *train_session, int batch_idx) {
  auto tensors = train_session->GetInputs();
  const int rows = tensors[0]->shape()[0];

  int *mask_dst = reinterpret_cast<int *>(tensors.at(0)->MutableData());
  int *ids_dst = reinterpret_cast<int *>(tensors.at(1)->MutableData());
  int *type_dst = reinterpret_cast<int *>(tensors.at(2)->MutableData());
  int *label_dst = reinterpret_cast<int *>(tensors.at(3)->MutableData());
  std::vector<int> batch_labels(rows);

  // Start of this batch inside each module-level buffer.
  const int *mask_src = input_mask + batch_idx * tensors[0]->ElementsNum();
  const int *ids_src = input_ids + batch_idx * tensors[1]->ElementsNum();
  const int *type_src = token_type_ids + batch_idx * tensors[2]->ElementsNum();
  const int *label_src = lable_ids + batch_idx * rows;
  constexpr size_t kRowBytes = MAX_SEQ_LENGTH * sizeof(int);

  for (int r = 0; r < rows; ++r) {
    const int offset = r * MAX_SEQ_LENGTH;
    std::memcpy(mask_dst + offset, mask_src + offset, kRowBytes);
    std::memcpy(ids_dst + offset, ids_src + offset, kRowBytes);
    std::memcpy(type_dst + offset, type_src + offset, kRowBytes);
    label_dst[r] = label_src[r];
    batch_labels[r] = label_src[r];
  }
  return batch_labels;
}



// Inference: loads the first stored batch into `session`, evaluates it, then prints and
// returns the resulting accuracy.
float InferBert(TrainSession *session) {
  const float acc = CalculateAccuracy(session, FillBertInputData(session, 0), LABEL_CLASS);
  std::cout << "inference acc is:" << acc << std::endl;
  return acc;
}


// Training: runs `epoches` passes over the data loaded by SetBertInputs(). It prints the loss of
// every batch and the mean loss / training accuracy of every epoch, then saves the model to `ms_file`.
// `batch_size` only determines how many batches form an epoch (total_size / batch_size). The data
// copied per step is sized by the model's own input batch dimension (see FillBertInputData), so the
// two should agree.
// Returns RET_OK on success. Returns RET_ERROR on invalid arguments, too little data for one batch,
// or a failed graph run or save.
int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size, int epoches) {
  if (session == nullptr) {
    std::cout << "train session is null" << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  if (epoches <= 0 || batch_size <= 0) {
    std::cout << "error batch size or epoch!, batch size:" << batch_size << ", epoch:" << epoches << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  batch_num = total_size / batch_size;
  if (batch_num <= 0) {
    // Without a full batch the loop would not run and the epoch means would be 0/0.
    std::cout << "train data less than one batch, total size:" << total_size << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  std::cout << "total epoches :" << epoches << ",batch_num:" << batch_num << std::endl;
  for (int j = 0; j < epoches; ++j) {
    float sum_loss_per_epoch = 0.0f;
    float sum_acc_per_epoch = 0.0f;
    for (int k = 0; k < batch_num; ++k) {
      session->Train();
      auto labels = FillBertInputData(session, k);
      if (session->RunGraph(nullptr, nullptr) != mindspore::lite::RET_OK) {
        std::cout << "run graph failed, epoch:" << j << ", batch:" << k << std::endl;
        return mindspore::lite::RET_ERROR;
      }
      auto loss = GetLoss(session);
      sum_loss_per_epoch += loss;
      std::cout << "batch:" << k << ",loss:" << loss << std::endl;
      // CalculateAccuracy switches the session to eval mode; Train() is re-entered next iteration.
      sum_acc_per_epoch += CalculateAccuracy(session, labels, LABEL_CLASS);
    }
    std::cout << "epoch "
              << "[" << j << "]"
              << ",mean Loss " << sum_loss_per_epoch / batch_num << ",train acc " << sum_acc_per_epoch / batch_num
              << std::endl;
  }
  if (session->SaveToFile(ms_file) != mindspore::lite::RET_OK) {
    std::cout << "save model failed:" << ms_file << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  return mindspore::lite::RET_OK;
}

// Appends every non-empty line of `file` to `*train_data` after removing trailing '\r'
// characters (CRLF line endings). Lines that are empty after stripping are skipped.
// If the file cannot be opened, a message is printed and `*train_data` is left unchanged.
void ReadTxt(const std::string &file, std::vector<std::string> *train_data) {
  if (train_data == nullptr) {
    return;
  }
  std::ifstream fin(file);
  if (!fin.is_open()) {
    std::cout << "open file failed:" << file << std::endl;
    return;
  }
  std::string line;
  // Loop on getline's result (not eof()) so a read error cannot cause an endless loop.
  while (std::getline(fin, line)) {
    const size_t end = line.find_last_not_of('\r');
    if (end == std::string::npos) {
      continue;  // empty line, or a line made only of '\r' characters
    }
    line.erase(end + 1);
    train_data->push_back(line);
  }
}

// Frees every module-level BERT buffer (allocated with new[]) and resets it to nullptr.
// Safe to call repeatedly.
static void ReleaseBertBuffers() {
  delete[] input_ids;
  input_ids = nullptr;
  delete[] input_mask;
  input_mask = nullptr;
  delete[] token_type_ids;
  token_type_ids = nullptr;
  delete[] lable_ids;
  lable_ids = nullptr;
}

// Splits `line` into its tab-separated fields and drops empty fields. This matches strtok(..., "\t")
// without strtok's hidden global state.
static std::vector<std::string> SplitByTab(const std::string &line) {
  std::vector<std::string> fields;
  size_t start = 0;
  while (start < line.size()) {
    size_t end = line.find('\t', start);
    if (end == std::string::npos) {
      end = line.size();
    }
    if (end > start) {
      fields.push_back(line.substr(start, end - start));
    }
    start = end + 1;
  }
  return fields;
}

// Loads BERT training data into the module-level buffers.
// Each line of `train_file` is "<label>\t<text>". The text is tokenized with `vocab_file`
// (lower-cased) into MAX_SEQ_LENGTH id / mask / token-type values. The label is mapped to its
// 0-based line index in `labels_file`; unknown labels are reported and mapped to 0.
// Only whole batches of BATCH_SIZE samples are kept, and total_size is set to that kept count.
// Data from a previous call is released first.
// Returns the number of samples kept, or -1 on error; on error no buffers remain allocated.
int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file) {
  if (train_file.empty()) {
    std::cout << "files empty";
    return -1;
  }
  // Drop buffers from any previous call so repeated loads do not leak.
  ReleaseBertBuffers();
  total_size = 0;

  std::vector<std::string> train_data;
  ReadTxt(train_file, &train_data);
  std::vector<std::string> labels;
  ReadTxt(labels_file, &labels);
  std::map<std::string, int> labels_map;
  for (size_t i = 0; i < labels.size(); ++i) {
    labels_map[labels[i]] = static_cast<int>(i);
  }
  const int total_batch_num = static_cast<int>(train_data.size()) / BATCH_SIZE;
  if (total_batch_num == 0) {
    std::cout << "train data size less than one batch,not support now";
    return -1;
  }
  const int train_size = total_batch_num * BATCH_SIZE;
  input_ids = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
  input_mask = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
  token_type_ids = new (std::nothrow) int[train_size * MAX_SEQ_LENGTH];
  lable_ids = new (std::nothrow) int[train_size];
  if (input_ids == nullptr || input_mask == nullptr || token_type_ids == nullptr || lable_ids == nullptr) {
    std::cout << "malloc bert input buffers failed" << std::endl;
    ReleaseBertBuffers();
    return -1;
  }

  CustomizedTokenizer customized_tokenizer;
  bool do_lower_case = true;
  customized_tokenizer.init(vocab_file, do_lower_case);

  // NOTE(review): assumes tokenize() writes exactly MAX_SEQ_LENGTH values into each array.
  int s_input_ids[MAX_SEQ_LENGTH];
  int s_attention_mask[MAX_SEQ_LENGTH];
  int s_token_type_ids[MAX_SEQ_LENGTH];
  int seq_length = 0;
  // memcpy takes a byte count: each row is MAX_SEQ_LENGTH ints, not MAX_SEQ_LENGTH bytes.
  constexpr size_t kRowBytes = MAX_SEQ_LENGTH * sizeof(int);
  // Samples beyond the last whole batch are dropped.
  for (int i = 0; i < train_size; i++) {
    std::vector<std::string> dataset_tuple = SplitByTab(train_data[i]);
    if (dataset_tuple.size() != 2) {
      std::cout << "train data error,must 2 word" << std::endl;
      if (dataset_tuple.size() < 2) {
        // No text field to tokenize; reading dataset_tuple[1] would be out of range.
        ReleaseBertBuffers();
        return -1;
      }
    }
    customized_tokenizer.tokenize(dataset_tuple[1], s_input_ids, s_attention_mask, s_token_type_ids, seq_length);
    std::memcpy(input_ids + i * MAX_SEQ_LENGTH, s_input_ids, kRowBytes);
    std::memcpy(input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, kRowBytes);
    std::memcpy(token_type_ids + i * MAX_SEQ_LENGTH, s_token_type_ids, kRowBytes);
    auto label_it = labels_map.find(dataset_tuple[0]);
    if (label_it == labels_map.end()) {
      std::cout << "unknown label:" << dataset_tuple[0] << ", use label 0" << std::endl;
      lable_ids[i] = 0;
    } else {
      lable_ids[i] = label_it->second;
    }
  }
  // Record only the rows actually buffered so TrainBert never indexes past them.
  total_size = train_size;
  return train_size;
}

void FreeBertInput() {
  delete input_ids;
  delete input_mask;
  delete token_type_ids;
}