In [None]:
import glob
import time
import os
import numpy as np
import tensorflow as tf
from google.colab import drive
from google.colab import auth
# Authenticate the current Colab user with Google before anything else runs.
# NOTE(review): presumably required so the gcsfuse mount in a later cell can
# read the storage bucket - confirm.
auth.authenticate_user()

In [None]:
# Initialize TPU
# Resolve the TPU attached to this runtime, connect to it, initialise the TPU
# system and build the distribution strategy stored in `tpu_strategy`.
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError as err:
  # Raise RuntimeError rather than a bare BaseException: BaseException slips
  # past ordinary `except Exception` handlers. Chaining with `from err` keeps
  # the resolver's original message in the traceback.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
# NOTE(review): this is the experimental TPUStrategy endpoint; consider
# tf.distribute.TPUStrategy if the installed TensorFlow provides it.
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

In [None]:
# Mount the Google Cloud Storage bucket
# Steps: register the gcsfuse apt repository, import its signing key, refresh
# the package index, install gcsfuse, then mount the bucket `flatdataflow` at
# /content. Later cells read the logged tensors from /content/FILES/.
# NOTE(review): the repository line is pinned to `gcsfuse-bionic` and uses
# plain http - confirm both still match the runtime image.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse
!gcsfuse --implicit-dirs flatdataflow /content

2022/03/03 04:52:51.588918 Start gcsfuse/0.40.0 (Go version go1.17.6) for app "" using mount point: /content
2022/03/03 04:52:51.605851 Opening GCS connection...
2022/03/03 04:52:51.815035 Mounting file system "flatdataflow"...
2022/03/03 04:52:51.853941 File system has been successfully mounted.


In [None]:
# Incorporate the platform feature
# NOTE: the `running_platform` argument is accepted but not read by the body.
def FlatDataflow(loop_num, query, key, value, bias, batch_granularity_level=1, head_granularity_level=8, length_granularity_level=64, running_platform='CPU'):
  """Run one attention pass at a chosen tiling granularity and return its finish time.

  For every tile the function computes ``softmax(einsum(key, query) + bias)``,
  applies ``tf.nn.dropout`` with ``rate=0.4`` and multiplies the result with
  ``value``. The tile is selected by the granularity arguments:

  * ``batch_granularity_level != 1``: tiles of that many batch entries, with
    all heads and all query rows processed in one einsum.
  * ``batch_granularity_level == 1`` and ``head_granularity_level != 1``: one
    batch entry at a time, tiles of ``head_granularity_level`` heads.
  * both equal to 1: one batch entry and one head at a time, with the query
    rows split into tiles of ``length_granularity_level`` rows.

  The assembled attention output is built up in a local variable but is not
  returned; only a timestamp is.

  Args:
    loop_num: Not read by the function body.
    query: Tensor that is unpacked as (batch_size, source_length, head_num,
      dim). If it has rank 3, a leading axis of size 1 is added to query, key
      and value.
    key: Tensor contracted against ``query`` in the einsum equations.
      Presumably laid out like ``query`` - confirm against the caller.
    value: Tensor multiplied with the attention weights. Presumably laid out
      like ``key`` - confirm against the caller.
    bias: Value added to the attention scores before the softmax.
    batch_granularity_level: Batch tile size; 1 selects the finer-grained
      head/length tiling.
    head_granularity_level: Head tile size; 1 selects length tiling.
    length_granularity_level: Number of query rows per tile in length tiling.
    running_platform: Not read by the function body.

  Returns:
    The value of ``time.time()`` sampled after the last tile was processed.

  Note:
    ``tf.shape(query).numpy()`` is called, so the function needs eager
    execution. The dropout calls make the computed values non-deterministic.
  """
  # If the dimension of the input tensor is 3 rather than 4, expand its dimension for the batch
  if (len(query.shape) == 3):
    query = query[None, :, :, :]
    key = key[None, :, :, :]
    value = value[None, :, :, :]  
  # `dim` is unpacked here but not used below.
  batch_size, source_length, head_num, dim = tf.shape(query).numpy()
  # Set a fixed bias value here
  bias_value = bias
  for batch in tf.range(0, batch_size, batch_granularity_level):
    # The lowest granularity is batch level.
    if (batch_granularity_level != 1):
      # Clamp the tile end so the last tile does not run past the batch axis.
      end_batch = batch + batch_granularity_level if batch + batch_granularity_level <= batch_size else batch_size
      query_source = tf.gather(query[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
      key_source = tf.gather(key[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
      value_source = tf.gather(value[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
      # Attention scores, laid out as (batch, head, query row, key row).
      result = tf.einsum("BTNH, BFNH->BNFT", key_source, query_source)
      result += bias_value
      # No axis is given, so tf.nn.softmax uses its default axis.
      result = tf.nn.softmax(result, name="attention_weights")
      result = tf.nn.dropout(result, rate=0.4)
      attention_output = tf.einsum("BNFT,BTNH->BFNH", result, value_source)
    else:
      for head in tf.range(0, head_num, head_granularity_level):
        # The lowest granularity is head level.
        if (head_granularity_level != 1):
          # Clamp the tile end so the last tile does not run past the head axis.
          end_head = head + head_granularity_level if head + head_granularity_level <= head_num else head_num
          # After indexing the batch entry, the head axis is axis 1.
          query_source = tf.gather(query[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
          key_source = tf.gather(key[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
          value_source = tf.gather(value[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
          result = tf.einsum("TNH, FNH->NFT", key_source, query_source)
          result += bias_value
          result = tf.nn.softmax(result, name="attention_weights")
          result = tf.nn.dropout(result, rate=0.4)
          logit = tf.einsum("NFT,TNH->FNH", result, value_source)
          # The first head tile starts the accumulator; later tiles are
          # appended along the head axis.
          if head == 0:
            attention_output = logit
          else:
            attention_output = tf.concat([attention_output, logit], axis=1)
        else:
          #Lowest granularity is length level
          for length in tf.range(0, source_length, length_granularity_level):
            # Clamp the tile end so the last tile does not run past the query rows.
            end_length = length + length_granularity_level if length + length_granularity_level <= source_length else source_length
            # Only the query rows are tiled; the key keeps its full length.
            query_source = tf.gather(query[batch, :, head, :], indices=tf.range(length, end_length), axis=0)
            key_source = key[batch, :, head, :]
            result = tf.einsum("TH, FH->FT", key_source, query_source)
            result += bias_value
            result = tf.nn.softmax(result, name="attention_weights")
            result = tf.nn.dropout(result, rate=0.4)
            # Stack the per-tile weights back together along the query-row axis.
            if length == 0:
              lengthOutput = result
            else:
              lengthOutput = tf.concat([lengthOutput, result], axis=0)
          # The value product is done once per head, on the re-assembled weights.
          value_source = value[batch, :, head, :]
          lengthRes = tf.einsum("FT,TH->FH", lengthOutput, value_source)
          # Re-insert the head axis so heads can be concatenated on axis 1.
          lengthRes = tf.expand_dims(lengthRes, axis=1)
          if (head == 0):
            attention_output = lengthRes
          else:
            attention_output = tf.concat([attention_output, lengthRes], axis=1)
      # Re-insert the batch axis so per-batch results can be concatenated on axis 0.
      attention_output = tf.expand_dims(attention_output, axis=0)
    # `output` accumulates the per-tile results; it is not read after the loop.
    if (batch == 0):
      output = attention_output
    else:
      output = tf.concat([output, attention_output], axis=0)
  # Timestamp taken once every tile has been issued; this is the return value.
  stoptime = time.time()
  return stoptime

In [None]:
# Read in data
def _read_logged_tensor(name, idx):
  """Load one serialized tensor from the mounted log directory.

  Args:
    name: Which tensor to load: "query", "key" or "value".
    idx: Numeric suffix of the log file.

  Returns:
    The tf.float32 tensor parsed from
    /content/FILES/logging_<name><idx>.txt.
  """
  path = "/content/FILES/logging_" + name + str(idx) + ".txt"
  # The file handle gets its own name so it never shadows the parsed tensor.
  with open(path, "rb") as tensor_file:
    serialized = tensor_file.read()
  return tf.io.parse_tensor(serialized, out_type=tf.float32)


# One list entry per log file, in file-index order (3..11).
queryin = []
keyin = []
valuein = []
for idx in range(3, 12):
  queryin.append(_read_logged_tensor("query", idx))
  keyin.append(_read_logged_tensor("key", idx))
  valuein.append(_read_logged_tensor("value", idx))

In [None]:
# Randomly set a bias value for now
BIAS = 0.02

running_time = []
FILENUM = 9
BATCHSIZE = 64
# Stack the per-file tensors so a file can be picked by its leading index.
queryin = tf.stack(queryin)
keyin = tf.stack(keyin)
valuein = tf.stack(valuein)
print("Files Successfully Loaded!")

# Set up the parameters (granularity levels forwarded to FlatDataflow)
batch = 1
head = 4
length = 1

# Sequence-length schedule of the benchmark.
# NOTE(review): assumes every sampled slice contributes SAMPLE_LENGTH rows on
# the length axis - confirm against the logged tensors.
SAMPLE_LENGTH = 64
START_LENGTH = 256
MAX_LENGTH = 16 * 1024


def _sample_qkv():
  """Return the (query, key, value) slices at one random (file, batch) index."""
  fileidx = np.random.randint(FILENUM)
  batchidx = np.random.randint(BATCHSIZE)
  return (queryin[fileidx][batchidx, :, :, :],
          keyin[fileidx][batchidx, :, :, :],
          valuein[fileidx][batchidx, :, :, :])


def _append_random_sample(query, key, value):
  """Append one random sample to query/key/value along the length axis.

  The axis is addressed as -3 in both growth loops. The warm-up loop already
  used -3 while the timed loop used 0; the two only agree for rank-3 inputs,
  so using -3 everywhere keeps the length axis growing for either rank.
  """
  new_query, new_key, new_value = _sample_qkv()
  return (tf.concat((query, new_query), axis=-3),
          tf.concat((key, new_key), axis=-3),
          tf.concat((value, new_value), axis=-3))


# Generate the start matrix: START_LENGTH rows built from random samples.
query, key, value = _sample_qkv()
for _ in range(START_LENGTH // SAMPLE_LENGTH - 1):
  query, key, value = _append_random_sample(query, key, value)

# Time one FlatDataflow pass per sequence length, growing the inputs by one
# sample after each pass. Sampling happens outside the timed region, since
# FlatDataflow returns its own stop timestamp.
for idx in range(START_LENGTH // SAMPLE_LENGTH, MAX_LENGTH // SAMPLE_LENGTH, 1):
  start_time = time.time()
  stoptime = FlatDataflow(idx - START_LENGTH // SAMPLE_LENGTH, query, key, value, BIAS, batch_granularity_level=batch, head_granularity_level=head, length_granularity_level=length)
  running_time.append(stoptime - start_time)
  query, key, value = _append_random_sample(query, key, value)
  print("LOOP %d" % (idx) )

print("Finished!")


In [None]:
# Code block --TO BE FIXED
# Load every logged (query, key, value) triple; one tuple per log file.
dataset_list = []
for idx in range(3, 12):
  tensors = []
  for name in ("query", "key", "value"):
    path = "/content/FILES/logging_" + name + str(idx) + ".txt"
    with open(path, "rb") as tensor_file:
      serialized = tensor_file.read()
    tensors.append(tf.io.parse_tensor(serialized, out_type=tf.float32))
  # Add each (query, key, value) tuple to the list
  dataset_list.append(tuple(tensors))

#Form the tf dataset
# tf.data keeps tuples as structure but converts a Python list into a single
# tensor. Passing `dataset_list` directly would therefore yield one packed
# tensor per file with query/key/value on its leading axis. Stacking each
# component across files and passing a tuple yields (query, key, value)
# elements instead. This requires every file to hold equally shaped tensors.
queries, keys, values = (tf.stack(component) for component in zip(*dataset_list))
dataset = tf.data.Dataset.from_tensor_slices((queries, keys, values))

#Distribute training code
# NOTE(review): an earlier cell sets up a TPU strategy, while this cell
# targets two GPUs - confirm which hardware this cell is meant for.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def replica_fn(input):
  """Per-replica step: run FlatDataflow on one (query, key, value) element.

  Not wrapped in tf.function: FlatDataflow calls `.numpy()` on a tensor and
  reads `time.time()`, both of which need eager execution.

  Args:
    input: One dataset element that unpacks into query, key and value.

  Returns:
    The timestamp returned by FlatDataflow.
  """
  #Decouple the tuple
  query, key, value = input
  # FlatDataflow takes (loop_num, query, key, value, bias, ...). loop_num is
  # passed as 0 and the notebook-level BIAS constant supplies the bias.
  stoptime = FlatDataflow(0, query, key, value, BIAS)
  return stoptime

# Run `replica_fn` on every element of the `tf.distribute.DistributedDataset`
# and keep the values returned by `strategy.run` in iteration order.
result = [strategy.run(replica_fn, args=(element,)) for element in dist_dataset]