<a href="https://colab.research.google.com/github/wy-go/google-research/blob/master/spreadsheet_coder/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data
Load data from [[UI]](https://console.cloud.google.com/storage/browser/spreadsheet_coder)


## Copy files from GCS with gsutil

In [1]:
from google.colab import auth
auth.authenticate_user()

# https://cloud.google.com/resource-manager/docs/creating-managing-projects
project_id = 'sheetcoder'
!gcloud config set project {project_id}

Updated property [core/project].


In [2]:
%cd /content/
!ls
!rm -rf google-research 
# Clone the entire repo.
!git clone https://github.com/wy-go/google-research.git

/content
adc.json  google-research  sample_data
Cloning into 'google-research'...
remote: Enumerating objects: 36716, done.[K
remote: Counting objects: 100% (1404/1404), done.[K
remote: Compressing objects: 100% (888/888), done.[K
remote: Total 36716 (delta 606), reused 1116 (delta 491), pack-reused 35312[K
Receiving objects: 100% (36716/36716), 296.98 MiB | 29.04 MiB/s, done.
Resolving deltas: 100% (19746/19746), done.
Checking out files: 100% (12421/12421), done.


In [3]:
# Download the file from a given Google Cloud Storage bucket.
%cd google-research/spreadsheet_coder
!gsutil -m cp -r \
  "gs://spreadsheet_coder/enron/" .
!ls

/content/google-research/spreadsheet_coder
Copying gs://spreadsheet_coder/enron/formulas_5_10.tf_record...
Copying gs://spreadsheet_coder/enron/formulas_0_5.tf_record...
Copying gs://spreadsheet_coder/enron/formulas_20_25.tf_record...
Copying gs://spreadsheet_coder/enron/formulas_15_20.tf_record...
Copying gs://spreadsheet_coder/enron/formulas_10_15.tf_record...
Copying gs://spreadsheet_coder/enron/formulas_25_30.tf_record...
\ [6/6 files][  3.9 GiB/  3.9 GiB] 100% Done  52.6 MiB/s ETA 00:00:00           
Operation completed over 6 objects/3.9 GiB.                                      
bert_modeling.py	model.py
constants.py		model_utils.py
enron			README.md
experiments.ipynb	spreadsheetcoder_model_architecture.png
mobilebert_modeling.py


## Read the TFRecord files

In [4]:
import tensorflow as tf
print(tf.__version__)

2.6.0


In [5]:
import os

data_dir = 'enron/'
tfrecord_files = [data_dir + filename for filename in os.listdir(data_dir)]
raw_dataset = tf.data.TFRecordDataset(tfrecord_files)

# Show the first 3 records.
for raw_record in raw_dataset.take(3):
  print(repr(raw_record))

<tf.Tensor: shape=(), dtype=string, numpy=b'\n\xc0\x1b\n\x99\x19\n\x0ccontext_data\x12\x88\x19\n\x85\x19\n\x82\x19$R[0]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], startTokenIndex=4, workbookRangeId=null}], computedValue=doub:0.0$$R[1]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], startTokenIndex=4, workbookRangeId=null}], computedValue=doub:-2392.0$$R[2]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], startTokenIndex=4, workbookRangeId=null}], computedValue=doub:-995.0$$R[3]C[0]$fo

In [6]:
# Create a description of the features.
feature_description = {
    'table_id': tf.io.FixedLenFeature([], tf.int64),
    'doc_id': tf.io.FixedLenFeature([], tf.string),
    'record_index': tf.io.FixedLenFeature([], tf.int64),
    'col_index': tf.io.FixedLenFeature([], tf.int64),
    'formula': tf.io.FixedLenFeature([], tf.string),
    'formula_token_list': tf.io.FixedLenFeature([], tf.string),
    'ranges': tf.io.FixedLenFeature([], tf.string),
    'computed_value': tf.io.FixedLenFeature([], tf.string),
    'header': tf.io.FixedLenFeature([], tf.string),
    'context_header': tf.io.FixedLenFeature([], tf.string),
    'context_data': tf.io.FixedLenFeature([], tf.string)
}

def _parse_example(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  feature_dict = tf.io.parse_single_example(example_proto, feature_description)
  return feature_dict

parsed_dataset = raw_dataset.map(_parse_example)

parsed_dataset

<MapDataset shapes: {col_index: (), computed_value: (), context_data: (), context_header: (), doc_id: (), formula: (), formula_token_list: (), header: (), ranges: (), record_index: (), table_id: ()}, types: {col_index: tf.int64, computed_value: tf.string, context_data: tf.string, context_header: tf.string, doc_id: tf.string, formula: tf.string, formula_token_list: tf.string, header: tf.string, ranges: tf.string, record_index: tf.int64, table_id: tf.int64}>

In [7]:
# Show the first 3 observations in datasets.
for parsed_record in parsed_dataset.take(10):
  print(parsed_record)

{'col_index': <tf.Tensor: shape=(), dtype=int64, numpy=25>, 'computed_value': <tf.Tensor: shape=(), dtype=string, numpy=b'doub:0.0'>, 'context_data': <tf.Tensor: shape=(), dtype=string, numpy=b'$R[0]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], startTokenIndex=4, workbookRangeId=null}], computedValue=doub:0.0$$R[1]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], startTokenIndex=4, workbookRangeId=null}], computedValue=doub:-2392.0$$R[2]C[0]$formula=RANGE UPLUS RANGE - FORMULA_START, formulaTokenList==+ RANGE - RANGE, ranges=[R1C1FormulaRange{relativeRange=R[0]C[-12], startTokenIndex=2, workbookRangeId=null}, R1C1FormulaRange{relativeRange=R[0]C[-2], s

  <table border="1">
  <caption><i>Table 1. </i>Sample Data</caption>
    <thead>
      <tr>
        <th></th>
        <th>table_id</th>
        <th>doc_id</th>
        <th>record_index</th>
        <th>col_index</th>
        <th>formula</th>
        <th>formula_token_list</th>
        <th>ranges</th>
        <th>computed_value</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>1</th>
        <td>2</td>
        <td>b''</td>
        <td>1</td>
        <td>19</td>
        <td>b'RANGE UPLUS FORMULA_START'</td>
        <td>b"b'= +  RANGE'"</td>
        <td>b'G278457581!R[3]C[13]'</td>
        <td>b'doub:83.41666666666667'</td>
      </tr>
      <tr align=center>
        <th align=left>2</th>
        <td>2</td>
        <td>b''</td>
        <td>1</td>
        <td>20</td>
        <td>b'RANGE UPLUS FORMULA_START'</td>
        <td>b"b'= +  RANGE'"</td>
        <td>b'G278457581!R[4]C[12]'</td>
        <td>b'doub:54.220833333333346'</td>
      </tr>
      <tr align=center>
        <th align=left>3</th>
        <td>2</td>
        <td>b''</td>
        <td>1</td>
        <td>21</td>
        <td>b'RANGE UPLUS FORMULA_START'</td>
        <td>b"b'= +  RANGE'"</td>
        <td>b'G278457581!R[3]C[-17]'</td>
        <td>b'doub:49.215833333333336'</td>
      </tr>
    </tbody>
  </table>

<table border="1">
    <thead>
      <tr>
        <th></th>
        <th>header</th>
        <th>context_header</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>1</th>
        <td>b'Contract Cost'</td>
        <td>b'$C[-10]$New/Used$C[-9]$Delivery Date$C[-8]$Owner$C[-7]$DASH\nApproval\nDate$C[-6]$Notice to Proceed Given<br/>$C[-5]$Status of Financing$C[-4]$Controlled By$C[-3]$Originator / Developer$C[-2]$Project Manager$C[-1]$Project Name<br/>$C[0]$Contract Cost$C[1]$Contract Paid to Date\n(Scheduled)$C[2]$\nCancel-lation Payment$C[3]$Comments$C[4]$Status Update'</td>
      </tr>
      <tr align=center>
        <th align=left>2</th>
        <td>b'Contract Paid to Date\n(Scheduled)'</td>
        <td>b'$C[-10]$Delivery Date$C[-9]$Owner$C[-8]$DASH\nApproval\nDate$C[-7]$Notice to Proceed Given$C[-6]$Status of Financing<br/>$C[-5]$Controlled By$C[-4]$Originator / Developer$C[-3]$Project Manager$C[-2]$Project Name$C[-1]$Contract Cost<br/>$C[0]$Contract Paid to Date\n(Scheduled)$C[1]$\nCancel-lation Payment$C[2]$Comments$C[3]$Status Update'</td>
      </tr>
      <tr align=center>
        <th align=left>3</th>
        <td>b'\nCancel-lation Payment'</td>
        <td>b'$C[-10]$Owner$C[-9]$DASH\nApproval\nDate$C[-8]$Notice to Proceed Given$C[-7]$Status of Financing$C[-6]$Controlled By<br/>$C[-5]$Originator / Developer$C[-4]$Project Manager$C[-3]$Project Name$C[-2]$Contract Cost$C[-1]$Contract Paid to Date\n(Scheduled)<br/>$C[0]$\nCancel-lation Payment$C[1]$Comments$C[2]$Status Update'</td>
      </tr>
    </tbody>
<table>

<table border="1">
    <thead>
      <tr>
        <th></th>
        <th>context_data</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>1</th>
        <td>b'$R[-1]C[-10]$userEnteredValue=empty$<br/>$R[-1]C[-9]$userEnteredValue=empty$<br/>$R[-1]C[-8]$userEnteredValue=empty$<br/>$R[-1]C[-7]$userEnteredValue=empty$<br/>$R[-1]C[-6]$userEnteredValue=empty$<br/>$R[-1]C[-5]$userEnteredValue=empty$<br/>$R[-1]C[-4]$userEnteredValue=empty$<br/>$R[-1]C[-3]$userEnteredValue=empty$<br/>$R[-1]C[-2]$userEnteredValue=empty$<br/>$R[-1]C[-1]$userEnteredValue=empty$<br/>$R[-1]C[0]$userEnteredValue=empty$<br/>$R[-1]C[1]$userEnteredValue=empty$<br/>$R[-1]C[2]$userEnteredValue=empty$<br/>$R[-1]C[3]$userEnteredValue=empty$<br/>$R[-1]C[4]$userEnteredValue=empty$<br/>$R[0]C[-10]$userEnteredValue=str:New$<br/>$R[0]C[-9]$userEnteredValue=doub:37165.0$<br/>$R[0]C[-8]$userEnteredValue=str:Whitewing$<br/>$R[0]C[-7]$userEnteredValue=str:$4.5MM DASHed$<br/>$R[0]C[-6]$userEnteredValue=str:N$<br/>$R[0]C[-5]$userEnteredValue=str:N/A$<br/>$R[0]C[-4]$userEnteredValue=str:EWS$<br/>$R[0]C[-3]$userEnteredValue=str:John Chappell$<br/>$R[0]C[-2]$userEnteredValue=str:Stephen Heck$<br/>$R[0]C[-1]$userEnteredValue=str:Sale in Process$<br/>$R[0]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[3]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:83.41666666666667$<br/>$R[0]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[3]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:54.220833333333346$<br/>$R[0]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[4]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:49.215833333333336$<br/>$R[0]C[3]$userEnteredValue=empty$<br/>$R[0]C[4]$userEnteredValue=str:DASH in progress.$<br/>$R[1]C[-10]$userEnteredValue=str:New$<br/>$R[1]C[-9]$userEnteredValue=doub:37196.0$<br/>$R[1]C[-8]$userEnteredValue=str:Whitewing$<br/>$R[1]C[-7]$userEnteredValue=str:$4.5MM DASHed$<br/>$R[1]C[-6]$userEnteredValue=str:N$<br/>$R[1]C[-5]$userEnteredValue=str:N/A$<br/>$R[1]C[-4]$userEnteredValue=str:EWS$<br/>$R[1]C[-3]$userEnteredValue=str:John Chappell$<br/>$R[1]C[-2]$userEnteredValue=str:Stephen Heck$<br/>$R[1]C[-1]$userEnteredValue=str:Sale in Process$<br/>$R[1]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[10]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:83.41666666666667$<br/>$R[1]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[10]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:54.220833333333346$<br/>$R[1]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[11]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:49.215833333333336$<br/>$R[1]C[3]$userEnteredValue=empty$<br/>$R[1]C[4]$userEnteredValue=str:DASH in progress.$<br/>$R[2]C[-10]$userEnteredValue=str:New$<br/>$R[2]C[-9]$userEnteredValue=doub:37226.0$<br/>$R[2]C[-8]$userEnteredValue=str:Whitewing$<br/>$R[2]C[-7]$userEnteredValue=str:$4.5MM DASHed$<br/>$R[2]C[-6]$userEnteredValue=str:N$<br/>$R[2]C[-5]$userEnteredValue=str:N/A$<br/>$R[2]C[-4]$userEnteredValue=str:EWS$<br/>$R[2]C[-3]$userEnteredValue=str:John Chappell$<br/>$R[2]C[-2]$userEnteredValue=str:Stephen Heck$<br/>$R[2]C[-1]$userEnteredValue=str:Sale in Process$<br/>$R[2]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[17]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:83.41666666666667$<br/>$R[2]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[17]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:54.220833333333346$<br/>$R[2]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[18]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:49.215833333333336$<br/>$R[2]C[3]$userEnteredValue=empty$<br/>$R[2]C[4]$userEnteredValue=str:DASH in progress.$<br/>$R[3]C[-10]$userEnteredValue=str:New$<br/>$R[3]C[-9]$userEnteredValue=doub:37135.0$<br/>$R[3]C[-8]$userEnteredValue=str:ENE B/S$<br/>$R[3]C[-7]$userEnteredValue=str:Approved$<br/>$R[3]C[-6]$userEnteredValue=str:N$<br/>$R[3]C[-5]$userEnteredValue=str:N/A$<br/>$R[3]C[-4]$userEnteredValue=str:EWS$<br/>$R[3]C[-3]$userEnteredValue=str:Dick Westfahl$<br/>$R[3]C[-2]$userEnteredValue=empty$<br/>$R[3]C[-1]$userEnteredValue=str:NEPCO / NESCO - Goldendale (EECC)$<br/>$R[3]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[32]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:36.24736$<br/>$R[3]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[32]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:35.613031199999995$<br/>$R[3]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[33]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:35.7036496$<br/>$R[3]C[3]$userEnteredValue=empty$<br/>$R[3]C[4]$userEnteredValue=str:Contract in the works, possible buyer$<br/>$R[4]C[-10]$userEnteredValue=str:Used$<br/>$R[4]C[-9]$userEnteredValue=str:Delivered$<br/>$R[4]C[-8]$userEnteredValue=str:West LB$<br/>$R[4]C[-7]$userEnteredValue=str:Analyzing$<br/>$R[4]C[-6]$userEnteredValue=str:N$<br/>$R[4]C[-5]$userEnteredValue=str:N/A$<br/>$R[4]C[-4]$userEnteredValue=str:EA$<br/>$R[4]C[-3]$userEnteredValue=empty$<br/>$R[4]C[-2]$userEnteredValue=empty$<br/>$R[4]C[-1]$userEnteredValue=str:Sale in Process$<br/>$R[4]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[111]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[4]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[111]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[4]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[112]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[4]C[3]$userEnteredValue=empty$<br/>$R[4]C[4]$userEnteredValue=str:3 Potential buyers in Canada, no legitimate offer as of 1/25/01.$<br/>$R[5]C[-10]$userEnteredValue=str:Used$<br/>$R[5]C[-9]$userEnteredValue=str:Delivered$<br/>$R[5]C[-8]$userEnteredValue=str:West LB$<br/>$R[5]C[-7]$userEnteredValue=str:Analyzing$<br/>$R[5]C[-6]$userEnteredValue=str:N$<br/>$R[5]C[-5]$userEnteredValue=str:N/A$<br/>$R[5]C[-4]$userEnteredValue=str:EA$<br/>$R[5]C[-3]$userEnteredValue=empty$<br/>$R[5]C[-2]$userEnteredValue=empty$<br/>$R[5]C[-1]$userEnteredValue=str:Sale in Process$<br/>$R[5]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[118]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[5]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[118]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[5]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[119]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:17.25$<br/>$R[5]C[3]$userEnteredValue=empty$<br/>$R[5]C[4]$userEnteredValue=str:3 Potential buyers in Canada, no legitimate offer as of 1/25/01.$<br/>$R[6]C[-10]$userEnteredValue=str:New$<br/>$R[6]C[-9]$userEnteredValue=str:being cleaned$<br/>$R[6]C[-8]$userEnteredValue=str:ENE B/S$<br/>$R[6]C[-7]$userEnteredValue=str:Analyzing$<br/>$R[6]C[-6]$userEnteredValue=str:N$<br/>$R[6]C[-5]$userEnteredValue=str:N/A$<br/>$R[6]C[-4]$userEnteredValue=str:EA$<br/>$R[6]C[-3]$userEnteredValue=str:David Fairley, Mathew Gimble$<br/>$R[6]C[-2]$userEnteredValue=empty$<br/>$R[6]C[-1]$userEnteredValue=str:Purchaser Identified$<br/>$R[6]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[21]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:24.506$<br/>$R[6]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[21]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:24.506000000000007$<br/>$R[6]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[22]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:24.506$<br/>$R[6]C[3]$userEnteredValue=str:CALME purchased turbine from ENA; turbine has not yet cleared customs;  generator incurred salt water damage while unloading$<br/>$R[6]C[4]$userEnteredValue=empty$<br/>$R[7]C[-10]$userEnteredValue=empty$<br/>$R[7]C[-9]$userEnteredValue=empty$<br/>$R[7]C[-8]$userEnteredValue=str:E-Next Generation$<br/>$R[7]C[-7]$userEnteredValue=str:$16.5MM on 2/16/01$<br/>$R[7]C[-6]$userEnteredValue=str:N$<br/>$R[7]C[-5]$userEnteredValue=str:N/A$<br/>$R[7]C[-4]$userEnteredValue=str:EA$<br/>$R[7]C[-3]$userEnteredValue=str:Jake Thomas/Laura Wente$<br/>$R[7]C[-2]$userEnteredValue=empty$<br/>$R[7]C[-1]$userEnteredValue=str:Columbia$<br/>$R[7]C[0]$formula=RANGE FORMULA_START, formulaTokenList== RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[100]C[-17], startTokenIndex=1, workbookRangeId=null}], computedValue=doub:39.2$<br/>$R[7]C[1]$formula=RANGE FORMULA_START, formulaTokenList== RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[100]C[13], startTokenIndex=1, workbookRangeId=null}], computedValue=doub:6.272$<br/>$R[7]C[2]$formula=RANGE FORMULA_START, formulaTokenList== RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[101]C[12], startTokenIndex=1, workbookRangeId=null}], computedValue=doub:4.704$<br/>$R[7]C[3]$userEnteredValue=empty$<br/>$R[7]C[4]$userEnteredValue=empty$<br/>$R[8]C[-10]$userEnteredValue=str:New$<br/>$R[8]C[-9]$userEnteredValue=doub:37408.0$<br/>$R[8]C[-8]$userEnteredValue=str:West LB$<br/>$R[8]C[-7]$userEnteredValue=str:$2.5MM on 1/31/01$<br/>$R[8]C[-6]$userEnteredValue=str:N$<br/>$R[8]C[-5]$userEnteredValue=str:N/A$<br/>$R[8]C[-4]$userEnteredValue=str:EA$<br/>$R[8]C[-3]$userEnteredValue=empty$<br/>$R[8]C[-2]$userEnteredValue=empty$<br/>$R[8]C[-1]$userEnteredValue=str:Fort Pierce$<br/>$R[8]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[139]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:43.618$<br/>$R[8]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[139]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:34.894400000000005$<br/>$R[8]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[140]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:43.618$<br/>$R[8]C[3]$userEnteredValue=empty$<br/>$R[8]C[4]$userEnteredValue=empty$<br/>$R[9]C[-10]$userEnteredValue=str:New$<br/>$R[9]C[-9]$userEnteredValue=doub:37591.0$<br/>$R[9]C[-8]$userEnteredValue=empty$<br/>$R[9]C[-7]$userEnteredValue=str:$4.2MM on 7/24/00$<br/>$R[9]C[-6]$userEnteredValue=str:N$<br/>$R[9]C[-5]$userEnteredValue=str:N/A$<br/>$R[9]C[-4]$userEnteredValue=str:EA$<br/>$R[9]C[-3]$userEnteredValue=str:Maurice Gilbert$<br/>$R[9]C[-2]$userEnteredValue=empty$<br/>$R[9]C[-1]$userEnteredValue=str:Las Vegas CoGen II$<br/>$R[9]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[34]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:15.769725$<br/>$R[9]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[34]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:3.94243125$<br/>$R[9]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[35]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:3.153945$<br/>$R[9]C[3]$userEnteredValue=empty$<br/>$R[9]C[4]$userEnteredValue=empty$<br/>$R[10]C[-10]$userEnteredValue=str:New$<br/>$R[10]C[-9]$userEnteredValue=doub:37257.0$<br/>$R[10]C[-8]$userEnteredValue=empty$<br/>$R[10]C[-7]$userEnteredValue=str:$4.2MM on 7/24/00$<br/>$R[10]C[-6]$userEnteredValue=str:N$<br/>$R[10]C[-5]$userEnteredValue=str:N/A$<br/>$R[10]C[-4]$userEnteredValue=str:EA$<br/>$R[10]C[-3]$userEnteredValue=str:Maurice Gilbert$<br/>$R[10]C[-2]$userEnteredValue=empty$<br/>$R[10]C[-1]$userEnteredValue=str:Las Vegas CoGen II$<br/>$R[10]C[0]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[41]C[-17], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:15.769725$<br/>$R[10]C[1]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[41]C[13], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:3.94243125$<br/>$R[10]C[2]$formula=RANGE UPLUS FORMULA_START, formulaTokenList==+ RANGE, ranges=[R1C1FormulaRange{relativeRange=G278457581!R[42]C[12], startTokenIndex=2, workbookRangeId=null}], computedValue=doub:3.153945$<br/>$R[10]C[3]$userEnteredValue=empty$<br/>$R[10]C[4]$userEnteredValue=empty$'
      </tr>
    </tbody>
<table>

## Data preprocess
```
record_index: tf.int64
col_index: tf.int64
formula: tf.string, '<SKETCH> $ENDSKETCH$ $R$ R[] C[] $SEP$ R[] C[] $ENDR$ EOF'
context_header: tf.string
context_data: tf.string, [21, 21]
```

In [8]:
import re
import numpy as np
import tensorflow_datasets as tfds
from constants import SPECIAL_TOKEN_LIST

D = 10
print(SPECIAL_TOKEN_LIST)

['PAD', 'EOF', 'UNK', 'GO', 'empty', '$RANGESPLIT$', '$ENDRANGE$', '$ENDFORMULASKETCH$', 'R[0]', 'R[1]', 'R[2]', 'R[3]', 'R[4]', 'R[5]', 'R[6]', 'R[7]', 'R[8]', 'R[9]', 'R[10]', 'R[-1]', 'R[-2]', 'R[-3]', 'R[-4]', 'R[-5]', 'R[-6]', 'R[-7]', 'R[-8]', 'R[-9]', 'R[-10]', 'C[0]', 'C[1]', 'C[2]', 'C[3]', 'C[4]', 'C[5]', 'C[6]', 'C[7]', 'C[8]', 'C[9]', 'C[10]', 'C[-1]', 'C[-2]', 'C[-3]', 'C[-4]', 'C[-5]', 'C[-6]', 'C[-7]', 'C[-8]', 'C[-9]', 'C[-10]', 'str:', 'doub:', 'bool:', 'quotedstr:', 'err:', 'image:', 'sparkchart:']


In [37]:
# Complete brackets, eliminate Gxxxx
def format_range(range_instance):
  range_eliminate = re.sub('G\d*', '', range_instance)
  range_eliminate = re.sub('"', '', range_instance)
  str_list = list(range_eliminate)
  for i in range(len(str_list) + 8):
    if i >= len(str_list):
      break
    if (str_list[i] == 'R' or str_list[i] == 'C') and str_list[i + 1] != '[':
      str_list.insert(i + 1, '[')
    if (str_list[i] == 'C' or str_list[i] == ':') and str_list[i - 1] != ']':
        str_list.insert(i, ']')
  if str_list[-1] != ']':
    str_list += ']'

  format_string = ''.join(str_list)
  range_values = np.array([int(range_string[1:-1]) for range_string in re.findall('\[.*?\]', format_string)])
  exceed = not ((-D <= range_values) & (range_values < 11)).all()

  return format_string, exceed

a, b = format_range('G34545R-11C5:R[3]C[9')
print(a, b)

G34545R[-11]C[5]:R[3]C[9] True


In [38]:
# 
def format_example(example):
  # print(example)

  formula_token_list = str(example['formula_token_list'].numpy())
  ranges = str(example['ranges'].numpy())
  context_header = str(example['context_header'].numpy())
  context_data = str(example['context_data'].numpy())


  sketch = re.search("b'(.*)'", formula_token_list).group(1)
  sketch = sketch.replace('=', 'GO', 1)
  range_simple = re.search("b'(.*)'", ranges)
  if range_simple:
    range_list = range_simple.group(1).split(' ')
  else:
    range_list = re.findall('!(.*?) ', ranges) + ranges.split("!")[-1:]
  exceed = False
  format_ranges = []
  for ra in range_list:
    format_ra, exceed = format_range(ra)
    format_ranges.append(format_ra)
    if exceed:
      break
  if exceed:
    return None

  formula = sketch + ' $ENDFORMULASKETCH$ '
  for (i, range_i) in enumerate(format_ranges):
    if(i):
      formula += ' $RANGESPLIT$ '
    formula += range_i
  formula += ' $ENDRANGE$ EOF'
  
  tokenizer = tfds.deprecated.text.Tokenizer(
      alphanum_only=False, reserved_tokens=SPECIAL_TOKEN_LIST)
  formula_tokens = tokenizer.tokenize(formula.replace(' ', ''))
  formula_doc.append(formula_tokens)
  
  cells = context_header[3:-1].split('$')

  if len(cells):
    start_col_index = int(cells[0][2:-1])
    for i in range(len(cells)):
      if i >= len(cells):
        break
      while i % 2 == 0 and cells[i] != "C[" + str(start_col_index + i // 2) + "]":
        part_of_former = cells[i]
        cells.pop(i)
        if i > 0:
          cells[i - 1] += '$' + part_of_former
        if i >= len(cells):
          break
      if i >= len(cells):
        break
        
    end_col_index = int(cells[-2][2:-1])
  else:
    start_col_index = 10
    end_col_index = 10

  context_header = ''
  if start_col_index > -D:
    context_header += 'empty [SEP] ' * (start_col_index + D)

  for i in range(end_col_index - start_col_index + 1):
    if cells[2 * i + 1]:
      context_header += cells[2 * i + 1]
    else:
      context_header += 'empty'
    if i != end_col_index - start_col_index:
      context_header += ' [SEP] '

  if end_col_index < D:
    context_header += ' [SEP] empty' * (D - end_col_index)
  else:
    context_header += 'empty'


  cells = context_data[3:-1].split('$')
    
  context_data = ['empty'] * 21 * 21
  cell_num = len(cells) // 3
  last_index = 0

  for i in range(cell_num):

    if len(cells) <= 3 * i:
      break;
    cell_index = re.search("R\[(.*)\]C\[(.*)\]", cells[3 * i])
    while not cell_index:
      part_of_last = cells[3 * i - 1]
      cells.pop(3 * i - 1)
      context_data[last_index] += '$' + part_of_last
      if 3 * i >= len(cells):
        break
      cell_index = re.search("R\[(.*)\]C\[(.*)\]", cells[3 * i])
    if not cell_index:
      break

    row_index = int(cell_index.group(1))
    col_index = int(cell_index.group(2))
    context_index = (row_index + 10) * 21 + col_index + 10
    last_index = context_index

    value = cells[3 * i + 1]
    value = re.search("Value=(.*)", value).group(1)

    context_data[context_index] = value

  formula = tf.convert_to_tensor(np.array(formula),
                                         dtype=tf.string)
  context_header = tf.convert_to_tensor(np.array(context_header),
                                         dtype=tf.string)
  context_data = tf.convert_to_tensor(np.array(context_data).reshape(21, 21),
                                         dtype=tf.string)
  return {'record_index': example['record_index'], 
          'col_index': example['col_index'],
          'formula': formula, 
          'context_header': context_header,
          'context_data': context_data
          }

format_dataset = []
formula_doc = []
for parsed_data in parsed_dataset:
  format_data = format_example(parsed_data)
  if format_data:
    format_dataset.append(format_data)


In [40]:
# Show the first observations in datasets.
for format_record in format_dataset[:0]:
  print(repr(format_record))

## Statistics

### From paper

  <table border="1">
  <caption><i>Table 1. </i>Dataset</caption>
    <thead>
      <tr>
        <th></th>
        <th>Train</th>
        <th>Validation</th>
        <th>Test</th>
        <th>Total</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>#Samples</th>
        <td>178K</td>
        <td>41K</td>
        <td>33K</td>
        <td>252K</td>
      </tr>
    </tbody>
  </table>

  <br />

  <table border="1">
  <caption><i>Table 2. </i>Sketch Length (excluding ENDSKETCH).</caption>
    <thead>
      <tr>
        <th align=left>Length</th>
        <th>2</th>
        <th>3</th>
        <th>4-5</th>
        <th>6-7</th>
        <th>8+</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>Distribution</th>
        <td>55%</td>
        <td>18%</td>
        <td>13%</td>
        <td>9%</td>
        <td>5%</td>
      </tr>
    </tbody>
  </table>

  <br />
  
  <table border="1">
  <caption><i>Table 3. </i>Spreadsheet Functions & Operators</caption>
    <thead>
      <tr>
        <th></th>
        <th>Spreadsheet Functions</th>
        <th>Type</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>1</th>
        <td>ADD(+)</td>
        <td>Operator</td>
      </tr>
      <tr align=center>
        <th align=left>2</th>
        <td>MINUS(-)</td>
        <td>Operator</td>
      </tr>
      <tr align=center>
        <th align=left>3</th>
        <td>MULPTIPLY(*)</td>
        <td>Operator</td>
      </tr>
      <tr align=center>
        <th align=left>4</th>
        <td>DIV(/)</td>
        <td>Operator</td>
      </tr>
      <tr align=center>
        <th align=left>5</th>
        <td>UPLUS</td>
        <td>Operator</td>
      </tr>
      <tr align=center>
        <th align=left>6</th>
        <td>UMINUS</td>
        <td>Operator</td>
      </tr>      
      <tr align=center>
        <th align=left>7</th>
        <td>SUM</td>
        <td>Math</td>
      </tr>
      <tr align=center>
        <th align=left>8</th>
        <td>ABS</td>
        <td>Math</td>
      </tr>
      <tr align=center>
        <th align=left>9</th>
        <td>LN</td>
        <td>Math</td>
      </tr>
      <tr align=center>
        <th align=left>10</th>
        <td>AVERAGE</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>11</th>
        <td>MIN</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>12</th>
        <td>MAX</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>13</th>
        <td>COUNT</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>14</th>
        <td>COUNTA</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>15</th>
        <td>STDEV</td>
        <td>Statistical</td>
      </tr>
      <tr align=center>
        <th align=left>16</th>
        <td>DAY</td>
        <td>Date</td>
      </tr>
      <tr align=center>
        <th align=left>17</th>
        <td>WEEKDAY</td>
        <td>Date</td>
      </tr>
    </tbody>
  </table>

<br />

### From data

In [41]:
TOTAL_SIZE = len(format_dataset)
print(TOTAL_SIZE)

172908


## Construct output vocabulary

In [43]:
from collections import defaultdict

class Vocab:
  def __init__(self, tokens=None):
    self.idx_to_token = list()
    self.token_to_idx = dict()

    if tokens is not None:
      if "UNK" not in tokens:
        tokens = tokens + ["UNK"]
      for token in tokens:
        self.idx_to_token.append(token)
        self.token_to_idx[token] = len(self.idx_to_token) - 1
      self.unk = self.token_to_idx['UNK']

  @classmethod
  def build(cls, text, min_freq=1, reserved_tokens=None):
    token_freqs = defaultdict(int)
    for sentence in text:
      for token in sentence:
        token_freqs[token] += 1
    token_list = reserved_tokens if reserved_tokens else []
    uniq_tokens = set(token_list)
    for token, freq in token_freqs.items():
      if freq >= min_freq and (token not in uniq_tokens):
        token_list += [token]
        uniq_tokens.add(token)
    return cls(tokens=token_list)
  
  def __len__(self):
    return len(self.idx_to_token)

  def __getitem__(self, token):
    return self.token_to_idx
  
  def convert_tokens_to_ids(self, tokens):
    return [self[token] for token in tokens]

  def convert_ids_to_tokens(self, indices):
    return [self.idx_to_token[index] for index in indices]

In [44]:
# Filter out tokens that appear less than 10 times in the training set to construct the output formula token vocabulary.
output_vocab = Vocab.build(formula_doc, min_freq=10, reserved_tokens=SPECIAL_TOKEN_LIST)
VOCAB_SIZE = len(output_vocab)

print(output_vocab.idx_to_token)
print(VOCAB_SIZE)

['PAD', 'EOF', 'UNK', 'GO', 'empty', '$RANGESPLIT$', '$ENDRANGE$', '$ENDFORMULASKETCH$', 'R[0]', 'R[1]', 'R[2]', 'R[3]', 'R[4]', 'R[5]', 'R[6]', 'R[7]', 'R[8]', 'R[9]', 'R[10]', 'R[-1]', 'R[-2]', 'R[-3]', 'R[-4]', 'R[-5]', 'R[-6]', 'R[-7]', 'R[-8]', 'R[-9]', 'R[-10]', 'C[0]', 'C[1]', 'C[2]', 'C[3]', 'C[4]', 'C[5]', 'C[6]', 'C[7]', 'C[8]', 'C[9]', 'C[10]', 'C[-1]', 'C[-2]', 'C[-3]', 'C[-4]', 'C[-5]', 'C[-6]', 'C[-7]', 'C[-8]', 'C[-9]', 'C[-10]', 'str:', 'doub:', 'bool:', 'quotedstr:', 'err:', 'image:', 'sparkchart:', '+', 'RANGE', '-', 'C', '[-', '12', ']', 'SUM', '(', ')', 'R', '33', ':', ']:', '14', '11', '31', '"]', '*', '13', '/', '[', '24', ']"]', ')/', '92', '91', '95', '93', '96', '94', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '15', '107', '16', '108', '17', '109', '18', '110', '19', '111', '20', '112', '21', '113', '22', '114', '23', '115', '116', '25', '117', '26', '118', '27', '119', '28', '120', '29', '121', '30', '122', '123', '32', '124', '125', '3

## Split dataset

In [1]:
TRAIN_SIZE = 130000
VALID_SIZE = 25000

BATCH_SIZE = 64

dataset = tf.data.Dataset.from_tensor_slices(format_dataset)
train_data = dataset.take(TRAIN_SIZE).batch(BATCH_SIZE)
valid_data = dataset.skip(TRAIN_SIZE).take(VALID_SIZE).batch(BATCH_SIZE)
test_data = dataset.skip(TRAIN_SIZE + VALID_SIZE).batch(BATCH_SIZE)

NameError: ignored

# Model![model architecture](https://raw.github.com/wy-go/google-research/master/spreadsheet_coder/spreadsheetcoder_model_architecture.png)


In [45]:
# Install deps
!pip install -q -U tensor2tensor

from model import create_model
from model_utils import get_train_op
from bert_modeling import BertConfig

NUM_ENCODER_LAYERS = 1 # not provided in paper
NUM_DECODER_LAYERS = 1 
EMBEDDING_SIZE = 512 # not provided in paper
HIDDEN_SIZE = 512
DROPOUT_RATE = 0.1
BEAM_SIZE = 64

MAX_CELL_CONTEXT_LENGTH = 128
FORMULA_LENGTH = 30 # not provided in paper

BERT_CONFIG = BertConfig(hidden_size=512, hidden_act='gelu', 
                         initializer_range=0.02, vocab_size=30522, 
                         hidden_dropout_prob=0.1, num_attention_heads=8,
                         type_vocab_size=2, max_position_embeddings=512,
                         num_hidden_layers=8, intermediate_size=2048,
                         attention_probs_dropout_prob=0.1)  # BERT-Medium

formula_placeholder = tf.compat.v1.placeholder(name='formula', dtype=tf.int32)    # [batch_size, formula_length]
row_cell_context_placeholder = tf.compat.v1.placeholder(name='row_cell_context', dtype=tf.int32) 
col_cell_context_placeholder = tf.compat.v1.placeholder(name='col_cell_context', dtype=tf.int32) 
row_context_mask_placeholder = tf.compat.v1.placeholder(name='row_context_mask', dtype=tf.float32) 
col_context_mask_placeholder = tf.compat.v1.placeholder(name='col_context_mask', dtype=tf.float32) 
row_context_segment_ids_placeholder = tf.compat.v1.placeholder(name='row_context_segment_ids', dtype=tf.int32) 
col_context_segment_ids_placeholder = tf.compat.v1.placeholder(name='col_context_segment_ids', dtype=tf.int32) 
row_cell_indices_placeholder = tf.compat.v1.placeholder(name='row_cell_indices', dtype=tf.int32) 
col_cell_indices_placeholder = tf.compat.v1.placeholder(name='col_cell_indices', dtype=tf.int32) 
row_context_mask_per_cell_placeholder = tf.compat.v1.placeholder(name='row_context_mask_per_cell', dtype=tf.float32) 
col_context_mask_per_cell_placeholder = tf.compat.v1.placeholder(name='col_context_mask_per_cell', dtype=tf.float32) 
row_context_segment_ids_per_cell_placeholder = tf.compat.v1.placeholder(name='row_context_segment_ids_per_cell', dtype=tf.int32) 
col_context_segment_ids_per_cell_placeholder = tf.compat.v1.placeholder(name='col_context_segment_ids_per_cell', dtype=tf.int32) 
record_index_placeholder = tf.compat.v1.placeholder(name='record_index', dtype=tf.int32) 
column_index_placeholder = tf.compat.v1.placeholder(name='column_index', dtype=tf.int32) 

SyntaxError: ignored

bert_modeling:
- get_shape_list()
- BertModel()
- gelu()

model_utils:
- configure_tpu(flags): Configures the TPU from the command line flags.
- class AdamWeightDecayOptimizer(): apply_gradients(grads_and_vars, global_step=None, name=None)
- get_train_op(flags, total_loss, ema=None, tvars=None): Generates the training operation.
- construct_scalar_host_call(monitor_dict, model_dir, prefix="", reduce_fn=None): Construct host calls to monitor training progress on TPUs.
- build_lstm(num_units)
- print_tensors(**tensors): Host call function to print Tensors from the TPU during training.
- get_assignment_map_from_checkpoint(vars_to_restore, init_checkpoint, bert_prefix=""): Compute the union of the current variables and checkpoint variables.

# Experiment

  <table border="1">
    <caption><i>Table 4. </i>Hyper-parameters</caption>
    <thead>
      <tr align=center>
        <th></th>
        <th></th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>#Encoder layer</th>
        <td></td>
      </tr>
      <tr align=center>
        <th align=left>#Decoder layer</th>
        <td>1</td>
      </tr>
      <tr align=center>
        <th align=left>Embedding size</th>
        <td></td>
      </tr>
      <tr align=center>
        <th align=left>Hidden size</th>
        <td>512</td>
      </tr>
      <tr align=center>
        <th align=left>Dropout rate</th>
        <td>0.1</td>
      </tr>
      <tr align=center>
        <th align=left>Initial lr</th>
        <td>5e-5</td>
      </tr>
      <tr align=center>
        <th align=left>Batch size</th>
        <td>64</td>
      </tr>
      <tr align=center>
        <th align=left>Gradient clipping norm</th>
        <td>1.0</td>
      </tr>
      <tr align=center>
        <th align=left>Optimizer</th>
        <td>Adam</td>
      </tr>
      <tr align=center>
        <th align=left>#Minibatch updates</th>
        <td>200K</td>
      </tr>
      <tr align=center>
        <th align=left>Beam size</th>
        <td>64</td>
      </tr>
    </tbody>
  </table>

## Full model

In [None]:
y_train = create_model(num_encoder_layers=NUM_ENCODER_LAYERS,
                      num_decoder_layers=NUM_DECODER_LAYERS,
                      embedding_size=EMBEDDING_SIZE,
                      hidden_size=HIDDEN_SIZE,
                      dropout_rate=DROPOUT_RATE,
                      is_traing=True,
                      formula=formula_placeholder, 
                      row_cell_context=row_cell_context_placeholder,
                      col_cell_context=col_cell_context_placeholder,
                      row_context_mask=row_context_mask_placeholder, 
                      col_context_mask=col_context_mask_placeholder,
                      row_context_segment_ids=row_context_segment_ids_placeholder,
                      col_context_segment_ids=col_context_segment_ids_placeholder,
                      row_cell_indices=row_cell_indices_placeholder,
                      col_cell_indices=col_cell_indices_placeholder,
                      row_context_mask_per_cell=row_context_mask_per_cell_placeholder,
                      col_context_mask_per_cell=col_context_mask_per_cell_placeholder,
                      row_context_segment_ids_per_cell=row_context_segment_ids_per_cell_placeholder,
                      col_context_segment_ids_per_cell=col_context_segment_ids_per_cell_placeholder,
                      exclude_headers=False,
                      max_cell_context_length=MAX_CELL_LENGTH, #
                      num_rows=21,
                      record_index=record_index_placeholder,
                      column_index=column_index_placeholder,
                      layer_norm=True,
                      cell_position_encoding=, #
                      cell_context_encoding=True,
                      use_bert=True,
                      use_mobilebert=False,
                      per_row_encoding=, #
                      max_pooling=, #
                      use_cnn=True,
                      use_pointer_network=False,
                      two_stage_encoding=True,
                      conv_type=, #
                      grid_type='both',
                      skip_connection=, #
                      bert_config=bert_config,
                      unused_tensors_to_print=True, # unused
                      formula_length=FORMULA_LENGTH,
                      formula_prefix_length=0,
                      vocab_size=VOCAB_SIZE,
                      beam_size=BEAM_SIZE,
                      use_tpu=True,
                      use_one_hot_embeddings=False
                      )

y_test, y_p = create_model(num_encoder_layers=NUM_ENCODER_LAYERS,
                      num_decoder_layers=NUM_DECODER_LAYERS,
                      embedding_size=EMBEDDING_SIZE,
                      hidden_size=HIDDEN_SIZE,
                      dropout_rate=DROPOUT_RATE,
                      is_traing=True,
                      formula=formula_placeholder, 
                      row_cell_context=row_cell_context_placeholder,
                      col_cell_context=col_cell_context_placeholder,
                      row_context_mask=row_context_mask_placeholder, 
                      col_context_mask=col_context_mask_placeholder,
                      row_context_segment_ids=row_context_segment_ids_placeholder,
                      col_context_segment_ids=col_context_segment_ids_placeholder,
                      row_cell_indices=row_cell_indices_placeholder,
                      col_cell_indices=col_cell_indices_placeholder,
                      row_context_mask_per_cell=row_context_mask_per_cell_placeholder,
                      col_context_mask_per_cell=col_context_mask_per_cell_placeholder,
                      row_context_segment_ids_per_cell=row_context_segment_ids_per_cell_placeholder,
                      col_context_segment_ids_per_cell=col_context_segment_ids_per_cell_placeholder,
                      exclude_headers=False,
                      max_cell_context_length=MAX_CELL_LENGTH, #
                      num_rows=21,
                      record_index=record_index_placeholder,
                      column_index=column_index_placeholder,
                      layer_norm=True,
                      cell_position_encoding=, #
                      cell_context_encoding=True,
                      use_bert=True,
                      use_mobilebert=False,
                      per_row_encoding=, #
                      max_pooling=, #
                      use_cnn=True,
                      use_pointer_network=False,
                      two_stage_encoding=True,
                      conv_type=, #
                      grid_type='both',
                      skip_connection=, #
                      bert_config=bert_config,
                      unused_tensors_to_print=True, # unused
                      formula_length=FORMULA_LENGTH,
                      formula_prefix_length=0,
                      vocab_size=VOCAB_SIZE,
                      beam_size=BEAM_SIZE,
                      use_tpu=True,
                      use_one_hot_embeddings=False
                      )


loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=formula_placeholder, y_pred=y_train)

top1_value = tf.gather(y_test, tf.argmax(y_p, axis=-1, output_type=tf.int32), axis=1)

count = y_test.get_shape()[0]
top1_accuracy = tf.reduce_sum(tf.equal(tf.reduce_sum(tf.cast(tf.equal(formula_placeholder, top1_value), tf.int32), 1), tf.fill([count], y_test.get_shape()[-1])))/count

import model_utils

INIT_LR = 5e-5
CLIP = 1.0
NUM_EPOCH = 72

train_op, learning_rate, gnorm = model_utils.get_train_op(flags, loss)
init_op = tf.global_variables_initializer()

iterator = tf.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)

next_elem = iterator.get_next()

train_init_op = iterator.make_initializer(train_data)

with tf.Session() as sess:
  #Initialize variables
  sess.run(tf.global_variables_initializer())
  sess.run(train_init_op)
  while True:
    try:
      print(sess.run(loss, next_elem))
    except tf.errors.OutOfRangeError:
      print('End of training datset.')
      break





In [None]:
# Test

test_init_op = iterator.make_initializer(test_data)

with tf.Session() as sess:
  #Initialize variables
  sess.run(tf.global_variables_initializer())
  sess.run(test_init_op)
  while True:
    try:
      print(sess.run(top1_accuracy, next_elem))
    except tf.errors.OutOfRangeError:
      print('End of test datset.')
      break

## Different encoder architecture
- — Column-based BERT
- — Row-based BERT
- — Convolution layers

### — Column-based BERT

### — Row-based BERT

### — Convolution layers

## Different decoding architecture
Same predictor for both the sketch and ranges, with a single output vocabulary

### — Two-stage decoding

## Different model initialization
Randomly initialize BERT encoders

### — Pretraining

## Previous approaches
Randomly initialize BERT encoders

### Row-based RobustFill

### Column-based RobustFill

## Baseline
LSTM decoder only

### No context


# Results

  <table border="1">
  <caption><i>Table 5. </i>Full model formula accuracy from paper.</caption>
    <thead>
      <tr>
        <th align=left>Dataset</th>
        <th>Top-1</th>
        <th>Top-5</th>
        <th>Top-10</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>Enron</th>
        <td>29.8%</td>
        <td>41.8%</td>
        <td>48.5%</td>
      </tr>
      <tr align=center>
        <th align=left>Google Sheets</th>
        <td>42.51%</td>
        <td>54.41%</td>
        <td>58.57%</td>
      </tr>
    </tbody>
  </table>

  <br/>

  <table border="1">
    <caption><i>Table 6. </i>Formula accuracy on the Google Sheets test set.</caption>
    <thead>
      <tr>
        <th align=left>Approach</th>
        <th>Top-1</th>
        <th>Top-5</th>
        <th>Top-10</th>
      </tr>
    </thead>
    <tbody>
      <tr align=center>
        <th align=left>Full model</th>
        <td><b>42.51%</b></td>
        <td><b>54.41%</b></td>
        <td><b>58.57%</b></td>
      </tr>
      <tr align=center>
        <th align=left>— Column-based BERT</th>
        <td>39.42%</td>
        <td>51.68%</td>
        <td>56.50%</td>
      </tr>
      <tr align=center>
        <th align=left>— Row-based BERT</th>
        <td>20.37%</td>
        <td>40.87%</td>
        <td>48.37%</td>
      </tr>
      <tr align=center>
        <th align=left>— Convolution layers</th>
        <td>38.43%</td>
        <td>51.31%</td>
        <td>55.87%</td>
      </tr>
      <tr align=center>
        <th align=left>— Two-stage decoding</th>
        <td>41.12%</td>
        <td>53.57%</td>
        <td>57.95%</td>
      </tr>
      <tr align=center>
        <th align=left>— Pretraining</th>
        <td>31.51%</td>
        <td>42.64%</td>
        <td>49.77%</td>
      </tr>
      <tr align=center>
        <th align=left>Row-based RobustFill</th>
        <td>31.14%</td>
        <td>40.09%</td>
        <td>47.10%</td>
      </tr>
      <tr align=center>
        <th align=left>Column-based RobustFill</th>
        <td>20.65%</td>
        <td>39.69%</td>
        <td>46.96%</td>
      </tr>
      <tr align=center>
        <th align=left>No context</th>
        <td>10.56%</td>
        <td>23.27%</td>
        <td>31.96%</td>
      </tr>
    </tbody>
  </table>