In [1]:
import glob2
import time
import os
import tensorflow as tf
import numpy as np

In [8]:
# Revise the FlatDataflow function to make it clearer

def FlatDataflow(loop_num, query, key, value, bias, batch_granularity_level=1, head_granularity_level=8, length_granularity_level=64, headTrue = False):
  """Run one tiled attention pass and return the time at which it finished.

  Every tile goes through the same steps: key/query logits, add `bias`,
  softmax, dropout (rate 0.4), then a weighted sum over `value`. The tile
  size is selected as follows:

  * `batch_granularity_level != 1`: tiles of `batch_granularity_level`
    batch rows.
  * otherwise, `head_granularity_level != 1 or headTrue`: tiles of
    `head_granularity_level` heads inside each batch row.
  * otherwise: tiles of `length_granularity_level` query rows inside each
    (batch row, head) pair; the product with `value` is taken once per head
    after all row tiles have been concatenated.

  Args:
    loop_num: Not used by this function.
    query, key, value: Tensors whose shape is unpacked as
      (batch, length, heads, dim). If `query` is rank 3, a leading batch axis
      of size 1 is added to all three (only `query`'s rank is checked).
    bias: Value added to the attention logits before the softmax.
    batch_granularity_level: Batch rows per tile; 1 selects the finer levels.
    head_granularity_level: Heads per tile at head level.
    length_granularity_level: Query rows per tile at length level.
    headTrue: Forces head-level tiling even when head_granularity_level is 1.

  Returns:
    `(0, 0, 0, 0, stoptime)`: four placeholder zeros followed by `time.time()`
    taken after the last tile. The assembled attention output is built but
    not returned. If TensorFlow lists no CPU device, the body is skipped and
    the function returns None.
  """
  if (tf.config.list_physical_devices('CPU')):
    # If the dimension of the input tensor is 3 rather than 4, expand its dimension for the batch
    if (len(query.shape) == 3):
      query = query[None, :, :, :]
      key = key[None, :, :, :]
      value = value[None, :, :, :]
    # NOTE(review): .numpy() and the Python `if` tests on tensors below
    # presumably rely on eager execution -- confirm before wrapping in tf.function.
    batch_size, source_length, head_num, _ = tf.shape(query).numpy()
    # Set a fixed bias value here
    bias_value = bias
    for batch in tf.range(0, batch_size, batch_granularity_level):
      # The lowest granularity is batch level.
      if (batch_granularity_level != 1):
        # Clamp the final tile so it never runs past the batch axis.
        # (Fixed: the bound check previously used an undefined name.)
        end_batch = batch + batch_granularity_level if batch + batch_granularity_level <= batch_size else batch_size
        query_source = tf.gather(query[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
        key_source = tf.gather(key[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
        value_source = tf.gather(value[:, :, :, :], indices=tf.range(batch, end_batch), axis=0)
        result = tf.einsum("BTNH, BFNH->BNFT", key_source, query_source)
        result += bias_value
        result = tf.nn.softmax(result, name="attention_weights")
        result = tf.nn.dropout(result, rate=0.4)
        attention_output = tf.einsum("BNFT,BTNH->BFNH", result, value_source)
      else:
        for head in tf.range(0, head_num, head_granularity_level):
          # The lowest granularity is head level.
          if (head_granularity_level != 1 or headTrue):
            end_head = head + head_granularity_level if head + head_granularity_level <= head_num else head_num
            # After indexing the batch row the head axis is axis 1.
            query_source = tf.gather(query[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
            key_source = tf.gather(key[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
            value_source = tf.gather(value[batch, :, :, :], indices=tf.range(head, end_head), axis=1)
            result = tf.einsum("TNH, FNH->NFT", key_source, query_source)
            result += bias_value
            result = tf.nn.softmax(result, name="attention_weights")
            result = tf.nn.dropout(result, rate=0.4)
            logit = tf.einsum("NFT,TNH->FNH", result, value_source)
            # Stitch the head tiles back together along the head axis.
            if head == 0:
              attention_output = logit
            else:
              attention_output = tf.concat([attention_output, logit], axis=1)
          else:
            #Lowest granularity is length level
            for length in tf.range(0, source_length, length_granularity_level):
              end_length = length + length_granularity_level if length + length_granularity_level <= source_length else source_length
              # Only the query rows are tiled; each tile sees the full key.
              query_source = tf.gather(query[batch, :, head, :], indices=tf.range(length, end_length), axis=0)
              key_source = key[batch, :, head, :]
              result = tf.einsum("TH, FH->FT", key_source, query_source)
              result += bias_value
              result = tf.nn.softmax(result, name="attention_weights")
              result = tf.nn.dropout(result, rate=0.4)
              # Accumulate the attention weights along the query-row axis.
              if length == 0:
                lengthOutput = result
              else:
                lengthOutput = tf.concat([lengthOutput, result], axis=0)
            value_source = value[batch, :, head, :]
            lengthRes = tf.einsum("FT,TH->FH", lengthOutput, value_source)
            # Re-insert the head axis so per-head results can be concatenated.
            lengthRes = tf.expand_dims(lengthRes, axis=1)
            if (head == 0):
              attention_output = lengthRes
            else:
              attention_output = tf.concat([attention_output, lengthRes], axis=1)
        # Re-insert the batch axis dropped by indexing with `batch`.
        attention_output = tf.expand_dims(attention_output, axis=0)
      if (batch == 0):
        output = attention_output
      else:
        output = tf.concat([output, attention_output], axis=0)
    stoptime = time.time()
    return (0, 0, 0, 0, stoptime)

In [9]:
# Benchmark cell: time FlatDataflow while the input sequence grows by one
# randomly chosen sample per iteration. Only the running time of each call is
# recorded; the first four values returned by FlatDataflow are unpacked but
# not used in this cell.

# Directory holding the serialized query/key/value tensors.
DATA_DIR = "/Users/zeyuchen/Desktop/FLAT/FILES"

# Read in all the files. Glob results come back in arbitrary order, so each
# listing is sorted to keep the three lists aligned and the run reproducible.
query_file = sorted(glob2.glob(DATA_DIR + "/logging_query*.txt"))
key_file = sorted(glob2.glob(DATA_DIR + "/logging_key*.txt"))
value_file = sorted(glob2.glob(DATA_DIR + "/logging_value*.txt"))

# Randomly set a bias value for now
BIAS = 0.02

running_time = []

# One file index is used for all three stacks below, so it must be valid for
# the shortest listing.
FILENUM = min(len(query_file), len(key_file), len(value_file))
if FILENUM == 0:
    raise FileNotFoundError("No logging_query/key/value files found in " + DATA_DIR)
# NOTE(review): assumes every file holds at least 64 samples on its first
# axis -- confirm against the files.
BATCHSIZE = 64


def _load_tensors(paths):
    """Parse each file in `paths` as a serialized float32 tensor and stack them."""
    return tf.stack([
        tf.io.parse_tensor(tf.io.read_file(path), out_type=tf.float32)
        for path in paths
    ])


queryin = _load_tensors(query_file)
keyin = _load_tensors(key_file)
valuein = _load_tensors(value_file)
print("Files Successfully Loaded!")

# Set up the parameters
batch = 1
head = 1
# Length-level tiling is not exercised here: headTrue=True below forces
# head-level tiling.
length = -1

fileidx = np.random.randint(FILENUM)
batchidx = np.random.randint(BATCHSIZE)
query = queryin[fileidx][batchidx, :, :, :]
key = keyin[fileidx][batchidx, :, :, :]
value = valuein[fileidx][batchidx, :, :, :]

# Generate start matrix with shape 1 * 256 * 16 * 64
# (presumably each sample contributes 64 rows, so 256 // 64 samples in total;
# the leading batch axis is added inside FlatDataflow).
# Randomly pick a file number
for i in range(256 // 64 - 1):
  fileidx = np.random.randint(FILENUM)
  batchidx = np.random.randint(BATCHSIZE)
  query = tf.concat((query, queryin[fileidx][batchidx, :, :, :]), axis=0)
  key = tf.concat((key, keyin[fileidx][batchidx, :, :, :]), axis=0)
  value = tf.concat((value, valuein[fileidx][batchidx, :, :, :]), axis=0)

for idx in range(256 // 64, 8 * 1024 // 64, 1):
  start_time = time.time()
  peakOld, currOld, peakCurr, currCurr, stoptime = FlatDataflow(idx-256//64, query, key, value, BIAS, batch_granularity_level=batch, 
                                                      head_granularity_level=head, length_granularity_level=length, headTrue = True)
  running_time.append(stoptime - start_time)

  # Grow the sequence by one more random sample before the next measurement.
  fileidx = np.random.randint(FILENUM)
  batchidx = np.random.randint(BATCHSIZE)

  query = tf.concat((query, queryin[fileidx][batchidx, :, :, :]), axis=0)
  key = tf.concat((key, keyin[fileidx][batchidx, :, :, :]), axis=0)
  value = tf.concat((value, valuein[fileidx][batchidx, :, :, :]), axis=0)
  print("LOOP %d" % (idx) )

print("Finished!")


Files Successfully Loaded!
LOOP 4
LOOP 5
LOOP 6
LOOP 7
LOOP 8
LOOP 9
LOOP 10
LOOP 11
LOOP 12
LOOP 13
LOOP 14
LOOP 15
LOOP 16
LOOP 17
LOOP 18
LOOP 19
LOOP 20
LOOP 21
LOOP 22
LOOP 23
LOOP 24
LOOP 25
LOOP 26
LOOP 27
LOOP 28
LOOP 29
LOOP 30
LOOP 31
LOOP 32
LOOP 33
LOOP 34
LOOP 35
LOOP 36
LOOP 37
LOOP 38
LOOP 39
LOOP 40
LOOP 41
LOOP 42
LOOP 43
LOOP 44
LOOP 45
LOOP 46
LOOP 47
LOOP 48
LOOP 49
LOOP 50
LOOP 51
LOOP 52
LOOP 53
LOOP 54
LOOP 55
LOOP 56
LOOP 57
LOOP 58
LOOP 59
LOOP 60
LOOP 61
LOOP 62
LOOP 63
LOOP 64
LOOP 65
LOOP 66
LOOP 67
LOOP 68
LOOP 69
LOOP 70
LOOP 71
LOOP 72
LOOP 73
LOOP 74
LOOP 75
LOOP 76
LOOP 77
LOOP 78
LOOP 79
LOOP 80
LOOP 81
LOOP 82
LOOP 83
LOOP 84
LOOP 85
LOOP 86
LOOP 87
LOOP 88
LOOP 89
LOOP 90
LOOP 91
LOOP 92
LOOP 93
LOOP 94
LOOP 95
LOOP 96
LOOP 97
LOOP 98
LOOP 99
LOOP 100
LOOP 101
LOOP 102
LOOP 103
LOOP 104
LOOP 105
LOOP 106
LOOP 107
LOOP 108
LOOP 109
LOOP 110
LOOP 111
LOOP 112
LOOP 113
LOOP 114
LOOP 115
LOOP 116
LOOP 117
LOOP 118
LOOP 119
LOOP 120
LOOP 121
LOOP 122
LOOP

In [10]:
# Elapsed seconds (stoptime - start_time) recorded for each benchmark iteration.
running_time

[0.12507200241088867,
 0.04843282699584961,
 0.058074951171875,
 0.06334114074707031,
 0.06549930572509766,
 0.07078385353088379,
 0.07886791229248047,
 0.08131289482116699,
 0.09952092170715332,
 0.10601091384887695,
 0.12760496139526367,
 0.14111065864562988,
 0.15697979927062988,
 0.16658425331115723,
 0.21337080001831055,
 0.24250292778015137,
 0.25846290588378906,
 0.28644609451293945,
 0.31878232955932617,
 0.3431839942932129,
 0.3785891532897949,
 0.40302276611328125,
 0.478057861328125,
 0.5046799182891846,
 0.46741318702697754,
 0.5038840770721436,
 0.5495600700378418,
 0.5925040245056152,
 0.7093391418457031,
 0.7467348575592041,
 0.7971432209014893,
 0.981464147567749,
 0.8937628269195557,
 0.8569889068603516,
 0.9147021770477295,
 0.9575960636138916,
 1.0245511531829834,
 1.0661780834197998,
 1.0868008136749268,
 1.2870101928710938,
 1.3939650058746338,
 1.505539894104004,
 1.5384180545806885,
 1.6107709407806396,
 1.730478286743164,
 1.7697160243988037,
 1.7710950374603271