In [None]:
#import library
import tensorflow as tf
import numpy as np
import glob
import os
import matplotlib.pyplot as plt
import time
import json
# Mount Google Drive at /content/drive; later cells read their input files
# from paths under this mount point. `google.colab` is imported here, right
# before its only use.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Revise the FlatDataflow function to make it clearer
#!nvidia-smi

def _attention_weights(equation, key_source, query_source, bias_value, dropout_rate=0.4):
  """Return dropout(softmax(einsum(equation, key_source, query_source) + bias_value))."""
  scores = tf.einsum(equation, key_source, query_source)
  scores += bias_value
  weights = tf.nn.softmax(scores, name="attention_weights")
  return tf.nn.dropout(weights, rate=dropout_rate)


def _flat_attention(query, key, value, bias, batch_granularity_level, head_granularity_level,
                    length_granularity_level, batch_level, head_level):
  """Run the chunked attention over rank-4 inputs; return the assembled output (None for an empty batch)."""
  # Requires eager tensors: the shape is read back as concrete Python ints.
  batch_size, source_length, head_num, _ = (int(d) for d in tf.shape(query).numpy())

  batch_outputs = []
  for batch in range(0, batch_size, batch_granularity_level):
    if batch_level:
      # Coarsest granularity: a whole slab of batch entries at once.
      end_batch = min(batch + batch_granularity_level, batch_size)
      weights = _attention_weights("BTNH, BFNH->BNFT", key[batch:end_batch], query[batch:end_batch], bias)
      attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value[batch:end_batch])
    else:
      head_outputs = []
      for head in range(0, head_num, head_granularity_level):
        if head_level:
          # One batch entry, a group of heads at a time.
          end_head = min(head + head_granularity_level, head_num)
          weights = _attention_weights("TNH, FNH->NFT", key[batch, :, head:end_head, :],
                                       query[batch, :, head:end_head, :], bias)
          head_outputs.append(tf.einsum("NFT,TNH->FNH", weights, value[batch, :, head:end_head, :]))
        else:
          # Finest granularity: one batch entry, one head, a block of query rows at a time.
          key_source = key[batch, :, head, :]
          length_weights = []
          for length in range(0, source_length, length_granularity_level):
            end_length = min(length + length_granularity_level, source_length)
            length_weights.append(
                _attention_weights("TH, FH->FT", key_source, query[batch, length:end_length, head, :], bias))
          # The value product is applied once, after all query-row blocks are stitched together.
          weights = tf.concat(length_weights, axis=0)
          head_output = tf.einsum("FT,TH->FH", weights, value[batch, :, head, :])
          # Re-insert the head axis so per-head results can be concatenated on axis 1.
          head_outputs.append(tf.expand_dims(head_output, axis=1))
      # Re-insert the batch axis dropped by indexing with a single batch entry.
      attention_output = tf.expand_dims(tf.concat(head_outputs, axis=1), axis=0)
    batch_outputs.append(attention_output)

  if not batch_outputs:
    return None
  return tf.concat(batch_outputs, axis=0)


def FlatDataflow(loop_num, query, key, value, bias, batch_granularity_level=1, head_granularity_level=8, length_granularity_level=64, batchTrue=False, headTrue=True, lengthTrue=False):
  """Run attention in chunks of a chosen granularity and report GPU memory statistics.

  The attention (score einsum, bias add, softmax, dropout with rate 0.4, value
  einsum) is evaluated chunk by chunk and the chunks are concatenated. The
  chunking level is selected as follows:

  * batch level  if ``batch_granularity_level != 1 or batchTrue``;
  * head level   otherwise, if ``head_granularity_level != 1 or headTrue``;
  * length level otherwise (blocks of ``length_granularity_level`` query rows
    for one head of one batch entry).

  Args:
    loop_num: Iteration number, used only in the log messages.
    query, key, value: Eager tensors laid out as (batch, length, heads, dim).
      Rank-3 inputs get a leading batch axis of size 1; the rank of `key` and
      `value` is assumed to match that of `query`.
    bias: Value added to the attention scores before the softmax.
    batch_granularity_level: Batch entries per chunk.
    head_granularity_level: Heads per chunk.
    length_granularity_level: Query rows per chunk.
    batchTrue: Forces batch-level chunking even when the batch chunk size is 1.
    headTrue: Forces head-level chunking even when the head chunk size is 1.
    lengthTrue: Accepted for interface compatibility; not read.

  Returns:
    A tuple ``(peak_before, current_before, peak_after, current_after,
    stoptime)``. The first four entries are the 'peak' and 'current' values of
    ``tf.config.experimental.get_memory_info('GPU:0')`` sampled before and
    after the computation (all 0 when no GPU is visible). ``stoptime`` is
    ``time.time()`` taken once the output has been assembled.
  """
  # Without a GPU there are no device statistics to query; the computation
  # still runs so callers always receive a 5-tuple.
  has_gpu = bool(tf.config.list_physical_devices('GPU'))
  no_stats = {'peak': 0, 'current': 0}

  memory_before = no_stats
  if has_gpu:
    # The statistics are only read here; the peak counter is not reset by this function.
    memory_before = tf.config.experimental.get_memory_info('GPU:0')
    print("Iteration %d: Before running, reset the memory! Memory peak: %f; Memory current: %f"%(loop_num, memory_before['peak'], memory_before['current']))

  # If the dimension of the input tensor is 3 rather than 4, expand its dimension for the batch
  if len(query.shape) == 3:
    query = query[None, :, :, :]
    key = key[None, :, :, :]
    value = value[None, :, :, :]

  batch_level = batch_granularity_level != 1 or batchTrue
  head_level = head_granularity_level != 1 or headTrue
  output = _flat_attention(query, key, value, bias, batch_granularity_level, head_granularity_level,
                           length_granularity_level, batch_level, head_level)
  stoptime = time.time()

  # Drop the result before sampling memory again; the intermediates were
  # already released when the helper returned.
  print("LOOP")
  del output

  memory_after = no_stats
  if has_gpu:
    memory_after = tf.config.experimental.get_memory_info('GPU:0')
    print("Iteration %d: After running! Memory peak: %f; Memory current: %f"%(loop_num, memory_after['peak'], memory_after['current']))
  return (memory_before['peak'], memory_before['current'], memory_after['peak'], memory_after['current'], stoptime)

In [None]:
# Code block building graph with x-axis as sequence length and y-axis as peak memory usage

def _load_tensors(paths):
  """Parse each serialized float32 tensor file in `paths` and stack the results on a new axis 0."""
  return tf.stack([tf.io.parse_tensor(tf.io.read_file(path), out_type=tf.float32) for path in paths])


def _random_sample(queries, keys, values, file_count, batch_count):
  """Return the query/key/value slices of one randomly chosen (file, batch) entry."""
  fileidx = np.random.randint(file_count)
  batchidx = np.random.randint(batch_count)
  return (queries[fileidx][batchidx, :, :, :],
          keys[fileidx][batchidx, :, :, :],
          values[fileidx][batchidx, :, :, :])


def _append_random_sample(query, key, value, queries, keys, values, file_count, batch_count):
  """Extend query/key/value along axis 0 with one randomly chosen sample."""
  q_new, k_new, v_new = _random_sample(queries, keys, values, file_count, batch_count)
  return (tf.concat((query, q_new), axis=0),
          tf.concat((key, k_new), axis=0),
          tf.concat((value, v_new), axis=0))


# Read in all the files. glob returns paths in arbitrary order, so sort each
# list to make index i refer to the same position in all three lists
# (presumably the same dump, if the file-name suffixes match -- TODO confirm).
query_file = sorted(glob.glob("/content/drive/MyDrive/models/transformer/logging_query*.txt"))
key_file = sorted(glob.glob("/content/drive/MyDrive/models/transformer/logging_key*.txt"))
value_file = sorted(glob.glob("/content/drive/MyDrive/models/transformer/logging_value*.txt"))

# Randomly set a bias value for now
BIAS = 0.02

peak = []
curr = []
running_time = []

#tf.config.experimental.reset_memory_stats('GPU:0')
FILENUM = len(query_file)
BATCHSIZE = 64

# Rows each sampled slice adds to the source-length axis, and the range of
# source lengths to benchmark (assumes every slice has SAMPLE_LENGTH rows -- TODO confirm).
SAMPLE_LENGTH = 64
START_LENGTH = 256
MAX_LENGTH = 12 * 1024

# Fail early with a clear message instead of an opaque error while sampling.
if FILENUM == 0:
  raise FileNotFoundError("No logging_query*.txt files found under /content/drive/MyDrive/models/transformer/")
if min(len(key_file), len(value_file)) < FILENUM:
  raise ValueError("Found %d query files but only %d key and %d value files"
                   % (FILENUM, len(key_file), len(value_file)))

queryin = _load_tensors(query_file)
keyin = _load_tensors(key_file)
valuein = _load_tensors(value_file)

# Set up the parameters
batch = 1
head = 4
length = 1024
batchTrue = False
headTrue = True
lengthTrue = False

# Generate the start tensors by concatenating START_LENGTH // SAMPLE_LENGTH
# randomly picked slices along axis 0 (FlatDataflow adds the batch axis).
query, key, value = _random_sample(queryin, keyin, valuein, FILENUM, BATCHSIZE)
for _ in range(START_LENGTH // SAMPLE_LENGTH - 1):
  query, key, value = _append_random_sample(query, key, value, queryin, keyin, valuein, FILENUM, BATCHSIZE)

# Benchmark, then grow the inputs by one more sample for the next iteration.
first_step = START_LENGTH // SAMPLE_LENGTH
for idx in range(first_step, MAX_LENGTH // SAMPLE_LENGTH, 1):
  start_time = time.time()
  peakOld, currOld, peakCurr, currCurr, stoptime = FlatDataflow(
      idx - first_step, query, key, value, BIAS,
      batch_granularity_level=batch, head_granularity_level=head, length_granularity_level=length,
      batchTrue=batchTrue, headTrue=headTrue, lengthTrue=lengthTrue)
  peak.append(peakCurr)
  curr.append(currCurr)
  running_time.append(stoptime - start_time)
  query, key, value = _append_random_sample(query, key, value, queryin, keyin, valuein, FILENUM, BATCHSIZE)

# Source length used at each recorded iteration.
iteration = np.arange(len(peak)) * SAMPLE_LENGTH + START_LENGTH
plt.plot(iteration, peak,'r',label="Peak")
plt.plot(iteration,curr,'b',label='Current')
plt.legend()
plt.title("Peak Memory Usage")
plt.xlabel("Source Length")
plt.ylabel("Bytes in usage")
plt.show()