<a href="https://colab.research.google.com/github/tx1103mark/tweet-sentiment/blob/master/TPUs_in_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TPUs in Colab&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>
In this example, we'll work through training a model to classify images of
flowers on Google's lightning-fast Cloud TPUs. Our model will take as input a photo of a flower and return whether it is a daisy, dandelion, rose, sunflower, or tulip.

We use the Keras framework, new to TPUs in TF 2.1.0. Adapted from [this notebook](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_xception_fine_tuned_best.ipynb) by [Martin Gorner](https://twitter.com/martin_gorner).

#### License

Copyright 2019-2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


---


This is not an official Google product; it is sample code provided for educational purposes.


## Enabling and testing the TPU

First, you'll need to enable TPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select TPU from the Hardware Accelerator drop-down

Next, we'll check that we can connect to the TPU:

# Data processing

In [None]:
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "bert_train.h"
#include "util.h"
#include <cstring>
#include <fstream>
#include <iostream>
#include "include/context.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "dataset/CustomizedTokenizer.h"
#include <climits>

// Fixed dimensions used throughout this file.
#define MAX_SEQ_LENGTH 32  // ints stored per sample in each token buffer
#define BATCH_SIZE 16      // samples per batch; the dataset is padded to a multiple of this
#define LABEL_CLASS 107    // number of class scores read from the model output

// Host-side dataset buffers. They are allocated and filled by SetBertInputs()
// and read batch by batch in FillBertInput().
static int *lable_ids = nullptr;       // one label id per sample
static int *input_ids = nullptr;       // MAX_SEQ_LENGTH token ids per sample
static int *input_mask = nullptr;      // MAX_SEQ_LENGTH attention-mask values per sample
static int *token_type_ids = nullptr;  // MAX_SEQ_LENGTH token-type ids per sample

static int batch_num = 0;   // batches per epoch, computed in TrainBert()
static int total_size = 0;  // number of samples including padding, set by SetBertInputs()

// Seed for rand_r(), used when drawing random padding samples.
unsigned int seed_ = time(nullptr);

// Copies one batch of the preloaded dataset into the session's input tensors.
//
// Tensor 0 receives the attention mask, tensor 1 the token ids and tensor 2 the
// token-type ids. Tensor 3 (labels) is zeroed first and is only filled with the
// batch's label ids when train_mode is true.
//
// train_session: session whose input tensors are written.
// batch_idx:     index of the batch inside the global dataset buffers.
// train_mode:    when true the labels are copied too.
// Returns one label per sample of the batch (all zeros when train_mode is false).
std::vector<int> FillBertInput(mindspore::session::TrainSession *train_session, int batch_idx, bool train_mode = true) {
  auto tensors = train_session->GetInputs();
  // The batch size is taken from the first dimension of the first input tensor.
  const int batch_size = tensors[0]->shape()[0];

  int *dst_mask = reinterpret_cast<int *>(tensors.at(0)->MutableData());
  int *dst_ids = reinterpret_cast<int *>(tensors.at(1)->MutableData());
  int *dst_types = reinterpret_cast<int *>(tensors.at(2)->MutableData());
  int *dst_labels = reinterpret_cast<int *>(tensors.at(3)->MutableData());

  std::vector<int> batch_labels(batch_size);
  std::fill(dst_labels, dst_labels + tensors.at(3)->ElementsNum(), 0);

  const size_t row_bytes = MAX_SEQ_LENGTH * sizeof(int);
  for (int row = 0; row < batch_size; ++row) {
    const int row_offset = row * MAX_SEQ_LENGTH;
    // A batch starts ElementsNum() ints after the previous one in each source buffer.
    const int *src_mask = input_mask + batch_idx * tensors[0]->ElementsNum() + row_offset;
    const int *src_ids = input_ids + batch_idx * tensors[1]->ElementsNum() + row_offset;
    const int *src_types = token_type_ids + batch_idx * tensors[2]->ElementsNum() + row_offset;

    std::memcpy(dst_mask + row_offset, src_mask, row_bytes);
    std::memcpy(dst_ids + row_offset, src_ids, row_bytes);
    std::memcpy(dst_types + row_offset, src_types, row_bytes);

    if (!train_mode) {
      continue;
    }
    const int sample = batch_idx * batch_size + row;
    dst_labels[row] = lable_ids[sample];
    batch_labels[row] = lable_ids[sample];
  }
  return batch_labels;
}

// Loads batch 0 of the preloaded data (labels included) into the session,
// computes the accuracy over LABEL_CLASS classes, prints it and returns it.
float InferBert(TrainSession *session) {
  const int first_batch = 0;
  auto batch_labels = FillBertInput(session, first_batch);
  const auto accuracy = CalculateAccuracy(session, batch_labels, LABEL_CLASS);
  std::cout << "inference acc is:" << accuracy << std::endl;
  return accuracy;
}

// Loads batch 0 of the preloaded data into the session without labels
// (the label tensor stays zeroed) and returns what GetInferResult() reports
// for LABEL_CLASS classes.
std::vector<int> GetBertInferRes(TrainSession *session) {
  const bool copy_labels = false;
  static_cast<void>(FillBertInput(session, 0, copy_labels));
  auto predictions = GetInferResult(session, LABEL_CLASS);
  return predictions;
}

// Classifies one sentence and returns the index of the best-scoring label.
//
// The sentence is tokenized with the vocabulary in vocab_file. Every one of the
// BATCH_SIZE rows of the model inputs is filled with that same sample, so the
// single sentence is replicated to a full batch. The label input is zeroed,
// the graph is run in eval mode, and the arg-max is taken over the first
// LABEL_CLASS scores of the output tensor (i.e. the first row).
//
// session:    session to run; must not be null.
// input_str:  sentence to classify.
// vocab_file: vocabulary file handed to the tokenizer.
// Returns the label index in [0, LABEL_CLASS), or -1 if the session, its input
// tensors, the graph run or the output tensor cannot be used.
int infer(TrainSession *session, const std::string &input_str, const std::string &vocab_file) {
  if (session == nullptr) {
    std::cout << "infer: session is null" << std::endl;
    return -1;
  }

  CustomizedTokenizer customized_tokenizer;
  bool do_lower_case = true;
  customized_tokenizer.init(vocab_file, do_lower_case);
  // Zero-initialised so that positions the tokenizer may leave untouched are defined.
  int s_input_ids[MAX_SEQ_LENGTH] = {0};
  int s_attention_mask[MAX_SEQ_LENGTH] = {0};
  int s_token_type_ids[MAX_SEQ_LENGTH] = {0};
  int len = 0;
  customized_tokenizer.tokenize(input_str, s_input_ids, s_attention_mask, s_token_type_ids, len);

  auto ms_inputs = session->GetInputs();
  if (ms_inputs.size() < 4) {
    std::cout << "infer: model must have 4 inputs, got " << ms_inputs.size() << std::endl;
    return -1;
  }
  // Make sure BATCH_SIZE rows really fit before writing into the tensors.
  const int needed_elements = BATCH_SIZE * MAX_SEQ_LENGTH;
  for (size_t k = 0; k < 3; ++k) {
    if (ms_inputs[k] == nullptr || ms_inputs[k]->ElementsNum() < needed_elements) {
      std::cout << "infer: input tensor " << k << " is missing or smaller than one batch" << std::endl;
      return -1;
    }
  }
  if (ms_inputs[3] == nullptr) {
    std::cout << "infer: label input tensor is missing" << std::endl;
    return -1;
  }

  auto model_input_mask = reinterpret_cast<int *>(ms_inputs.at(0)->MutableData());
  auto model_input_ids = reinterpret_cast<int *>(ms_inputs.at(1)->MutableData());
  auto model_token_id = reinterpret_cast<int *>(ms_inputs.at(2)->MutableData());
  auto model_label_ids = reinterpret_cast<int *>(ms_inputs.at(3)->MutableData());
  if (model_input_mask == nullptr || model_input_ids == nullptr || model_token_id == nullptr ||
      model_label_ids == nullptr) {
    std::cout << "infer: input tensor data is null" << std::endl;
    return -1;
  }

  // Replicate the single sample into every row of the batch.
  const size_t row_bytes = MAX_SEQ_LENGTH * sizeof(int);
  for (int i = 0; i < BATCH_SIZE; i++) {
    std::memcpy(model_input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, row_bytes);
    std::memcpy(model_input_ids + i * MAX_SEQ_LENGTH, s_input_ids, row_bytes);
    std::memcpy(model_token_id + i * MAX_SEQ_LENGTH, s_token_type_ids, row_bytes);
  }
  std::fill(model_label_ids, model_label_ids + ms_inputs.at(3)->ElementsNum(), 0);

  session->Eval();
  if (session->RunGraph() != mindspore::lite::RET_OK) {
    std::cout << "infer: RunGraph failed" << std::endl;
    return -1;
  }

  // NOTE(review): SearchOutputsForSize is assumed to return a null pointer when
  // no output has BATCH_SIZE * LABEL_CLASS elements - confirm against util.h.
  auto outputsv = SearchOutputsForSize(session, BATCH_SIZE * LABEL_CLASS);
  if (outputsv == nullptr) {
    std::cout << "infer: no output tensor with " << BATCH_SIZE * LABEL_CLASS << " elements" << std::endl;
    return -1;
  }
  std::cout << "ouput tensor name:" << outputsv->tensor_name() << std::endl;
  auto scores = reinterpret_cast<float *>(outputsv->MutableData());
  if (scores == nullptr) {
    std::cout << "infer: output tensor data is null" << std::endl;
    return -1;
  }

  // All rows hold the same sample, so the first row is enough. max_element
  // returns the first maximum, matching a strict '>' scan from index 0.
  return static_cast<int>(std::max_element(scores, scores + LABEL_CLASS) - scores);
}

// Trains the session on the data loaded by SetBertInputs() and saves the model.
//
// For every epoch all total_size / batch_size batches are fed through the
// graph; the mean loss and mean training accuracy of the epoch are printed.
// After the last epoch the session is saved to ms_file.
//
// session:    session to train; must not be null.
// ms_file:    path passed to SaveToFile() after training.
// batch_size: samples per batch; must be positive and not larger than the
//             number of loaded samples.
// epoches:    number of passes over the data; must be positive.
// Returns mindspore::lite::RET_OK on success, mindspore::lite::RET_ERROR when
// an argument is invalid or there is not at least one full batch of data.
int TrainBert(TrainSession *session, const std::string &ms_file, int batch_size, int epoches) {
  if (session == nullptr) {
    std::cout << "error: train session is null" << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  if (epoches <= 0) {
    std::cout << "error iterations or epoch!, epoch:" << epoches << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  if (batch_size <= 0) {
    // Guards the integer division below.
    std::cout << "error batch size!, batch_size:" << batch_size << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  batch_num = total_size / batch_size;
  if (batch_num <= 0) {
    // Without a full batch the per-epoch means would be 0/0 and an untrained model would be saved.
    std::cout << "error: no full batch to train on, total_size:" << total_size << ",batch_size:" << batch_size
              << std::endl;
    return mindspore::lite::RET_ERROR;
  }
  std::cout << "total epoches :" << epoches << ",batch_num:" << batch_num << std::endl;
  for (int epoch = 0; epoch < epoches; ++epoch) {
    float sum_loss_per_epoch = 0.0f;
    float sum_acc_per_epoch = 0.0f;
    for (int batch = 0; batch < batch_num; ++batch) {
      // Train mode is re-selected for every batch.
      // NOTE(review): presumably because CalculateAccuracy() may switch the session mode - confirm.
      session->Train();
      auto lables = FillBertInput(session, batch);
      session->RunGraph(nullptr, nullptr);
      sum_loss_per_epoch += GetLoss(session);
      sum_acc_per_epoch += CalculateAccuracy(session, lables, LABEL_CLASS);
    }
    std::cout << "epoch "
              << "[" << epoch << "]"
              << ",mean Loss " << sum_loss_per_epoch / batch_num << ",train acc " << sum_acc_per_epoch / batch_num
              << std::endl;
  }
  // NOTE(review): the result of SaveToFile() is not checked here - confirm its return type and handle failures.
  session->SaveToFile(ms_file);
  return mindspore::lite::RET_OK;
}

// Reads a text file line by line and appends every non-blank line to train_data.
//
// Trailing carriage returns are stripped, so files with CRLF line endings yield
// the same entries as files with LF endings. Lines that are empty, or that
// consist only of carriage returns, are skipped.
//
// file:       path of the text file to read.
// train_data: output vector the lines are appended to; ignored when null.
// If the file cannot be opened nothing is appended.
void ReadTxt(const std::string &file, std::vector<std::string> *train_data) {
  if (train_data == nullptr) {
    return;
  }
  std::ifstream fin(file);
  if (!fin.is_open()) {
    return;
  }
  std::string line;
  // Looping on getline() itself also terminates on read errors, not only on EOF.
  while (std::getline(fin, line)) {
    const size_t last_kept = line.find_last_not_of('\r');
    if (last_kept == std::string::npos) {
      continue;  // empty line, or nothing but '\r'
    }
    line.erase(last_kept + 1);
    train_data->push_back(line);
  }
}

// Set input tensors.
int SetBertInputs(const std::string &train_file, const std::string &vocab_file, const std::string &labels_file) {
  if (train_file.empty()) {
    std::cout << "files empty";
    return -1;
  }
  std::vector<std::string> train_data;
  ReadTxt(train_file, &train_data);
  std::vector<std::string> labels;
  ReadTxt(labels_file, &labels);
  std::map<std::string, int> labels_map;
  for (int i = 0; i < labels.size(); i++) {
    labels_map[labels[i]] = i;
  }
  int train_size = train_data.size();
  int total_batch_num = train_size / BATCH_SIZE;
  int remain_size = train_size % BATCH_SIZE;
  int pad_size = BATCH_SIZE - remain_size;
  if (total_batch_num == 0) {
    std::cout << "train data size less than one batch,need random padding" << std::endl;
  }
  total_size = train_size + pad_size;
  input_ids = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
  input_mask = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
  token_type_ids = new (std::nothrow) int[total_size * MAX_SEQ_LENGTH];
  lable_ids = new (std::nothrow) int[total_size];

  CustomizedTokenizer customized_tokenizer;
  bool do_lower_case = true;
  customized_tokenizer.init(vocab_file, do_lower_case);

  int s_input_ids[MAX_SEQ_LENGTH];
  int s_attention_mask[MAX_SEQ_LENGTH];
  int s_token_type_ids[MAX_SEQ_LENGTH];
  int seq_length;

  for (int i = 0; i < total_size; i++) {
    std::vector<std::string> dataset_tuple(2);
    dataset_tuple.clear();
    int idx = i;
    // less than one batch would  pad random
    if (i >= train_size) {
      idx = rand_r(&seed_) % train_size;
    }
    size_t pos = train_data[idx].find("\t", 0);
    if (pos != std::string::npos) {
      dataset_tuple.push_back(train_data[idx].substr(0, pos));
      dataset_tuple.push_back(train_data[idx].substr(pos + 1, train_data[idx].size()));
    }
    if (dataset_tuple.size() != 2) {
      std::cout << "train data error,must 2 word.idx::" << idx << std::endl;
      return -1;
    }
    customized_tokenizer.tokenize(dataset_tuple[1], s_input_ids, s_attention_mask, s_token_type_ids, seq_length);
    memcpy(input_ids + i * MAX_SEQ_LENGTH, s_input_ids, MAX_SEQ_LENGTH * sizeof(int));
    memcpy(input_mask + i * MAX_SEQ_LENGTH, s_attention_mask, MAX_SEQ_LENGTH * sizeof(int));
    memcpy(token_type_ids + i * MAX_SEQ_LENGTH, s_token_type_ids, MAX_SEQ_LENGTH * sizeof(int));
    std::string key = dataset_tuple[0];
    lable_ids[i] = labels_map[key];
  }
  std::cout << "total train size :" << std::endl << std::endl;
  return total_size;
}

void FreeBertInput() {
  delete input_ids;
  delete input_mask;
  delete token_type_ids;
}

In [None]:
// JNI entry point for com.huawei.flclient.NativeTrain.getInferLabel.
//
// session_ptr: native TrainSession pointer passed from Java as a jlong.
// input_str:   sentence to classify.
// vocab_file:  path of the vocabulary file for the tokenizer.
// Returns the label index reported by infer(), or -1 when the JNI environment,
// the session handle or one of the strings is null.
extern "C" JNIEXPORT jint JNICALL Java_com_huawei_flclient_NativeTrain_getInferLabel(JNIEnv *env, jclass,
                                                                                     jlong session_ptr,
                                                                                     jstring input_str,
                                                                                     jstring vocab_file) {
  // Reject null handles here instead of dereferencing them in native code.
  if (env == nullptr || session_ptr == 0 || input_str == nullptr || vocab_file == nullptr) {
    return -1;
  }
  auto *session = reinterpret_cast<mindspore::session::TrainSession *>(session_ptr);
  // NOTE(review): if JstringToChar returns a heap-allocated char buffer, it is never released here - confirm.
  return infer(session, JstringToChar(env, input_str), JstringToChar(env, vocab_file));
}