# Imports

In [None]:
import sys

# Detect where the notebook is running from the modules the host pre-loads.
# The checks form ONE if/elif/else chain: with two independent `if` statements
# the trailing `else` belonged only to the Kaggle test, so a Colab session was
# immediately overwritten with 'locally'.
device: str
if 'google.colab' in sys.modules:
    device = 'colab'
elif 'kaggle_web_client' in sys.modules:
    device = 'kaggle'
else:
    device = 'locally'

In [None]:
%%time
# Install the packages each hosted environment is missing. `device` is set by
# the environment-detection cell above; nothing extra is installed for a local run
# except flopco-keras at the bottom.
if device == 'colab':
    %pip install -q tensorflow_text
    %pip install -q tqdm
    %pip install -q wandb --upgrade
elif device == 'kaggle':
    # Kaggle gets an explicitly pinned stack (TensorFlow 2.9.1 / tensorflow_text 2.9.0).
    %pip install -q google-cloud-bigquery-storage
    %pip install -q numpy==1.19.0
    %pip install -q tensorflow==2.9.1
    %pip install -q absl-py==0.9
    %pip install -q matplotlib==3.1.1
    %pip install -q protobuf==3.11.2
    %pip install -q tensorflow_text==2.9.0
    %pip install -q tqdm
    %pip install -q wandb --upgrade

# Installed on every environment (provides the FlopCoKeras FLOP counter used later).
# NOTE(review): unpinned and not quiet, unlike the installs above - consider pinning a version.
%pip install flopco-keras

In [None]:
# Standard library:
from typing import Optional, List, Set, Dict, Tuple
import datetime
import os
import random
import statistics
import math
import time
import sys
# Third-party libraries:
import wandb
import tensorflow as tf
import tensorflow_text as tf_text
import pandas as pd
import seaborn as sns
import numpy as np
import tqdm
import matplotlib.pyplot as plt
# My code:
from setransformer import SeTransformer

<IPython.core.display.Javascript object>

In [None]:
# Record the interpreter and library versions so the run can be reproduced.
print(f"Python version: {sys.version}")
print(f"Tensorflow version: {tf.__version__}")
print(f"tf text version: {tf_text.__version__}")

Python version: 3.7.13 (default, Apr 24 2022, 01:04:09) 
[GCC 7.5.0]
Tensorflow version: 2.9.1
tf text version: 2.9.0


In [None]:
# Show the attached NVIDIA GPU (if any) via the shell.
print('GPU info:')
!nvidia-smi

GPU info:
Thu Jul  7 13:15:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------

# Settings

In [None]:
# Seed TensorFlow and Python's `random` module (the latter drives the batch shuffle
# in the train/val/test split) so runs are repeatable.
tf.random.set_seed(0)
random.seed(0)
# Optional: switch Keras to half precision (currently disabled).
# tf.keras.backend.set_floatx('float16')

## Define a strategy - Accelerator optimization 

In [None]:
%%capture
# disable printing
# Connect to the TPU runtime and build the distribution strategy under which the
# model is created and trained.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    using_tpu: bool = True
except ValueError as error:
    # This cell runs under %%capture, so a print() here would be swallowed, and
    # exit() shuts the kernel down without any visible explanation. Raising keeps
    # the kernel alive and surfaces the reason in the cell output.
    raise RuntimeError(
        "You must connect to a TPU in order to train a model. The models dont fit in a colab GPU"
    ) from error

# Data loading

In [None]:
# Resolve where the articles CSV lives for this environment, then read it once.
if device == 'colab':
    # On Colab the data sits on Google Drive, which has to be mounted first.
    from google.colab import drive
    drive.mount('/drive')
    csv_path = '/drive/MyDrive/final_project/wikipedia_articles.csv'
else:
    # Any other environment reads the copy relative to the working directory.
    csv_path = 'wiki_data/articles.csv'
df: pd.DataFrame = pd.read_csv(csv_path)
print(df.shape)

Drive already mounted at /drive; to attempt to forcibly remount, call drive.mount("/drive", force_remount=True).
(30279, 2)


In [None]:
# Keep only the article text and report how many articles there are and how long
# (in characters) the longest and shortest ones are.
# NOTE: `df` is rebound from the full frame to its 'text' column, so this cell
# cannot be re-run without re-running the loading cell first.
df: pd.Series = df['text']
data_list: List[str] = df.to_list()
DATA_SIZE = len(data_list)
print(f"There are {DATA_SIZE} data points")
string_lengths: List[int] = list(map(len, data_list))
max_string_len, min_string_len = max(string_lengths), min(string_lengths)
print(f"The length of the longest text IN CHARACTERS is: {max_string_len}")
print(f"The length of the shortest text IN CHARACTERS is: {min_string_len}")

There are 30279 data points
The length of the longest text IN CHARACTERS is: 141803
The length of the shortest text IN CHARACTERS is: 816


## Creating the vocabulary

In [None]:
# %%time
# Build (or load from cache) the WordPiece vocabulary shared by all models.
bert_tokenizer_params: dict = dict(lower_case=True)
VOCAB_SIZE: int = 8192  # Target vocabulary size; always the same for all models

if device == 'colab':  # If notebook is ran on colab
    path = '/drive/MyDrive/final_project/vocab.txt'
else:  # If notebook is ran on my laptop
    path = 'C:/yoni/final_project/model/vocab.txt'


reserved_tokens: List[str] = ["[PAD]", "[UNK]", "[START]", "[END]", "[MASK]"]

if os.path.exists(path):
    # Cached vocabulary: whitespace-separated tokens, as written by the branch below.
    with open(path, 'r') as f:
        vocab: List[str] = f.read().split()
else:
    vocab_args: dict = dict(
        # The target vocabulary size
        vocab_size=VOCAB_SIZE,
        # Reserved tokens that must be included in the vocabulary
        reserved_tokens=reserved_tokens,
        # Arguments for `tf_text.BertTokenizer`
        bert_tokenizer_params=bert_tokenizer_params,
        # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
        learn_params={},
    )
    tensor_list: list = [tf.convert_to_tensor(data_point) for data_point in data_list]
    data_set: tf.data.Dataset = tf.data.Dataset.from_tensor_slices(tensor_list)
    vocab: List[str] = tf_text.bert_vocab_from_dataset.bert_vocab_from_dataset(
        data_set,
        **vocab_args,)
    # Write the cache to the SAME `path` that is checked and read above. The file
    # used to be written to a hard-coded Windows path, so on Colab the cache was
    # never found and the vocabulary was re-learned on every run.
    with open(path, 'w') as f:
        for token in vocab:
            f.write(token + ' ')

In [None]:
# Sanity-check the vocabulary: element type, the first entries (the reserved
# tokens come first) and the number of tokens.
print(f"the type of the items in vocab: {type(vocab[0])}")
print(f"the first 15 items in vocab: {vocab[:15]}")
print(f" the length of vocab: {len(vocab)}")

the type of the items in vocab: <class 'str'>
the first 15 items in vocab: ['[PAD]', '[UNK]', '[START]', '[END]', '[MASK]', "'", ',', '.', '0', '1', '2', '3', '4', '5', '6']
 the length of vocab: 7882


In [None]:
# Number of tokens actually present in the vocabulary (passed to the model later).
vocab_size = len(vocab)
# One scalar tf.string tensor per vocabulary token; these become the lookup-table keys.
tensor_vocab: List[tf.Tensor] = list(
    map(lambda token_key: tf.convert_to_tensor(token_key, dtype=tf.string), vocab)
)
print(f" the type of the items in tensor_vocab is: {type(tensor_vocab[0])}")
print(f" the data type of the tensors in tensor_vocab is: {tensor_vocab[0].dtype}")

 the type of the items in tensor_vocab is: <class 'tensorflow.python.framework.ops.EagerTensor'>
 the data type of the tensors in tensor_vocab is: <dtype: 'string'>


## Creating the tokenizer

In [None]:
# Map every vocabulary token to its index in `vocab` (0 .. len(vocab)-1).
# Tokens that are not in the table fall into the single out-of-vocabulary bucket.
lookup_table = tf.lookup.StaticVocabularyTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tensor_vocab,
        key_dtype=tf.string,
        values=tf.range(tf.size(vocab, out_type=tf.int64), dtype=tf.int64),
        value_dtype=tf.int64),
    num_oov_buckets=1
)
# BERT-style tokenizer backed by the table; uses the same params as the vocabulary learner.
tokenizer = tf_text.BertTokenizer(lookup_table, **bert_tokenizer_params)

## Tokenizing the data

In [None]:
# Earlier version that derived the [START]/[END] ids from `reserved_tokens`
# (kept for reference; replaced by the hard-coded constants below):
# START: int = tf.argmax(tf.constant(reserved_tokens) == "[START]")  # The value of the start token
# END: int = tf.argmax(tf.constant(reserved_tokens) == "[END]")  # The value of the end token
# starts = tf.cast(tf.Variable([START]), dtype = tf.int32)  # Tensor of shape [1] and dtype int
# ends = tf.cast(tf.Variable([END]), dtype = tf.int32)  # Tensor of shape [1] and dtype int
# Ids 2 and 3 are the positions of "[START]" and "[END]" in `reserved_tokens`.
starts = tf.constant([2], dtype=tf.int32)
ends = tf.constant([3], dtype=tf.int32)
# Id of the padding token, looked up by position in `reserved_tokens`.
pad_int: int = int(tf.argmax(tf.constant(reserved_tokens) == "[PAD]"))
# The same id as a shape-[1] int32 tensor, used for comparisons inside tf.functions.
pad_ten: tf.Tensor = tf.constant([pad_int], dtype=tf.int32)

In [None]:
def tokenize_string(text: str) -> tf.Tensor:
    """Tokenizes one text into a 1-D int32 tensor wrapped in [START] ... [END].

    Args:
        text: The raw text to tokenize.

    Returns:
        A 1-D int32 tensor: the start token id, every wordpiece id of the text
        in order, then the end token id.
    """
    # The result is indexed as [batch, words]; take the only batch entry, which
    # leaves one ragged row of wordpiece ids per word.
    ragged: tf.RaggedTensor = tokenizer.tokenize(text)[0, :]
    # Flatten words x wordpieces into a single sequence. The previous
    # `to_tensor(shape=[None, 1])` truncated every word to its FIRST wordpiece,
    # silently dropping all continuation pieces.
    flat: tf.Tensor = ragged.merge_dims(-2, -1)
    typed: tf.Tensor = tf.cast(flat, tf.int32)
    edited: tf.Tensor = tf.concat([starts, typed, ends], axis=0)
    return edited

In [None]:
# Tokenize every article; this is the slow step of the preprocessing.
tokenized_data: List[tf.Tensor] = [tokenize_string(data_point) for data_point in data_list] 

# NOTE(review): no progress bar is used here; wrap `data_list` in tqdm.tqdm(...) if one is wanted.

In [None]:
# Sanity-check the tokenization: number of texts, length of the first one and its first ids.
print(len(tokenized_data))
print(tokenized_data[0].shape)
print(tokenized_data[0][:10])

30279
(670,)
tf.Tensor([   2 1011 7670   57   18 6423  617   33   61   44], shape=(10,), dtype=int32)


### Chunk texts that are too long

In [None]:
max_seq_len: int = 256
def chunk_tensor(tensor: tf.Tensor, max_len: int = max_seq_len) -> List[tf.Tensor]:
    """Splits a 1-D tensor into consecutive chunks of at most `max_len` elements.

    Every element is kept: the last chunk holds the remainder and may be shorter
    than `max_len` (it is padded later). Previously the remainder was dropped, so
    the tail of every text - and every text shorter than `max_len` - was lost.

    Args:
        tensor: 1-D tensor to split.
        max_len: Maximum chunk length.

    Returns:
        The chunks in order; an empty list for an empty tensor.
    """
    return [tensor[start:start + max_len] for start in range(0, tensor.shape[0], max_len)]

In [None]:
# Split every tokenized text into chunks of at most max_seq_len tokens and
# collect all chunks in one flat list.
# NOTE: the loop variables `chunk` and `chunks` are deleted by name in the
# "Clear memory" cell, so this must stay an explicit nested loop.
chunked_data: List[tf.Tensor] = []
for tensor in tokenized_data:
    chunks = chunk_tensor(tensor, max_seq_len)
    for chunk in chunks:
        chunked_data.append(chunk)
# DATA_SIZE is rebound here: it now counts chunks, not articles.
DATA_SIZE: int = len(chunked_data)
print(DATA_SIZE)
print(chunked_data[0].shape)

336056
(256,)


## Padding

In [None]:
def pad(tensor: tf.Tensor, pad_int: int) -> tf.Tensor:
    """Right-pads a 1-D token tensor with the padding token up to `max_seq_len`.

    Args:
        tensor: 1-D tensor of token ids, at most `max_seq_len` long.
        pad_int: Id of the padding token used as the fill value.

    Returns:
        A 1-D tensor of length `max_seq_len`.
    """
    missing: int = max_seq_len - tensor.shape[0]
    # Pad only at the end. `pad_int` is the fill VALUE; it used to be passed as the
    # number of leading pad positions, which only worked because the id happens to be 0.
    padded: tf.Tensor = tf.pad(tensor=tensor, paddings=[[0, missing]], mode='CONSTANT', constant_values=pad_int)
    return padded

In [None]:
# Sort BEFORE padding so that every batch holds texts of similar length. Sorting
# `chunked_data` after `padded_data` had already been built had no effect on the
# data that is actually batched.
chunked_data.sort(key = lambda t: t.shape[0])
padded_data: List[tf.Tensor] = [pad(text, pad_int) for text in chunked_data]

## Train test val split

In [None]:
batch_size: int = 128

def list_to_dataset(tokenized_list: List[tf.Tensor]) -> tf.data.Dataset:
    """Builds a tf.data.Dataset from fully preprocessed token tensors.

    The tensors are sliced into individual elements and grouped into batches
    of `batch_size` (the final batch may be smaller).
    """
    return tf.data.Dataset.from_tensor_slices(tokenized_list).batch(batch_size)

batched_data_ten = list_to_dataset(padded_data)
# Keep only full batches, and do it BEFORE shuffling: after the shuffle the short
# batch can be anywhere in the list, so checking only the last element was unreliable.
batched_data_list = [batch for batch in batched_data_ten if batch.shape[0] == batch_size]
random.shuffle(batched_data_list)
data_size = len(batched_data_list)
train_size: int = int(data_size * 0.8)
val_test_size: int = int(data_size * 0.1)  # Both validation and test get 10% of the data
list_train_set: List[tf.Tensor] = batched_data_list[:train_size]
list_val_set: List[tf.Tensor] = batched_data_list[train_size:(train_size + val_test_size)]
# Slice to the end: indexing with a plain integer returned ONE batch tensor, not
# a list of batches, so the test loop iterated over single rows.
list_test_set: List[tf.Tensor] = batched_data_list[(train_size + val_test_size):]

## Clear memory

In [None]:
# Free the preprocessing intermediates; only the train/val/test batch lists, the
# tokenizer, `vocab_size`, `pad_int`/`pad_ten` and the size constants are kept.
# This makes the preprocessing cells above non-re-runnable without a restart.
del batched_data_list, train_size, val_test_size, data_size
del padded_data, chunked_data, tokenized_data, data_list, df
del lookup_table, reserved_tokens
del bert_tokenizer_params, ends, starts, vocab, tensor_vocab
# Loop variables leaked by the chunking cell.
del chunk, chunks

# Training the model

## Hyper-Parameters

In [None]:
# Number of tokens in one "set": the width of every target slice in train()/validate().
set_size: int = 2
# Learning rate handed to the Adam optimizer.
learning_rate: float = 0.01

# Number of sets predicted per sequence; one less than the number of sets that
# fit in a sequence because the first set is never predicted.
num_sets: int = (max_seq_len // set_size) - 1 # Because we dont predict the first set
# number of sets in each sequence

# Architecture settings, passed by name to SeTransformer when the model is created.
num_blocks: int = 8
d_model: int = 256
dff: int = 512
num_heads: int = 16
dropout_rate: float = 0.1

## Create the model

In [None]:
# Build the model, loss and optimizer inside the distribution strategy's scope.
with strategy.scope():
    model = SeTransformer(
        num_blocks=num_blocks,
        d_model=d_model,
        num_heads=num_heads,
        dff=dff,
        vocab_size=vocab_size,
        max_len=max_seq_len,
        rate=dropout_rate,
        pad_int=pad_int,
        using_tpu=using_tpu)
    
    # from_logits=False: the loss treats the model output as probabilities.
    # NOTE(review): confirm SeTransformer really ends in a softmax.
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=tf.keras.backend.epsilon())
    # Random token ids shaped like one input batch / one target set.
    # NOTE(review): these are never passed to the model in this notebook; they are
    # only deleted again in the next cell.
    temp_input = tf.random.uniform((batch_size, max_seq_len), dtype=tf.int32, minval=5, maxval=6999)
    temp_target = tf.random.uniform((batch_size, set_size), dtype=tf.int32, minval=5, maxval=6999)
    model.compile(optimizer=optimizer, loss=loss_func)

param_count: int = model.count_params()
print(f"The model has {param_count:,} = {param_count * (10**-6):,}M trainable parameters")
# NOTE(review): FlopCoKeras is not imported in any cell visible here (the
# flopco-keras package is only pip-installed); add its import to the imports
# cell, otherwise this line raises NameError on a fresh kernel.
stats = FlopCoKeras(model)
flops_per_call: int = stats.total_flops
macs_per_call: int = stats.total_macs

# Running counters of executed train/validation steps. They did not exist yet at
# this point (and were referenced here under a different spelling), which raised
# NameError. The `*_calles` spelling matches train(), validate(), check_point()
# and train_loop().
train_step_calles: int = 0
val_step_calles: int = 0

# (add-multiplies per forward pass) * (2 FLOPs/add-multiply) * (3 for forward and backward pass) * (number of examples in dataset) 
# Nothing has been trained yet, so this starts at 0.
training_flops: float  = macs_per_call * 2 * flops_per_call / macs_per_call * (3 * train_step_calles + val_step_calles)
print(f"FLOPs per call: {flops_per_call:,} = {(flops_per_call * (10 ** -6)):,}M")
print(f"MACs per call: {macs_per_call:,} = {(macs_per_call * (10 ** -6)):,}M")

del temp_input, temp_target

The model has 4,084,480 trainable parameters


## Weights and Biases

In [None]:
# Authenticate with Weights & Biases. wandb.login() uses the WANDB_API_KEY
# environment variable or previously stored credentials, and prompts otherwise.
# SECURITY: an API key used to be written in this cell in plain text. Never store
# keys in the notebook; the exposed key must be revoked and replaced.
wandb.login()

In [None]:
# Start a Weights & Biases run named after today's date and record the
# hyper-parameters and parameter count with it.
# NOTE(review): the project is called "pytorch-intro" although this notebook uses TensorFlow.
run = wandb.init(
    project="pytorch-intro",
    entity="yoniteam",
    name=datetime.datetime.today().strftime(f"run from %d/%m/%Y"),
    settings=wandb.Settings(start_method="thread"),
    config = {"set size": set_size,
              "batch size": batch_size,
              "learning rate": learning_rate,
              "max seq len": max_seq_len,
              "num blocks": num_blocks,
              "model dimention": d_model,
              "dff": dff,
              "num heads": num_heads,
              "dropout rate": dropout_rate,
              "params": param_count
              })
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33myonikremer[0m ([33myoniteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Training helper functions

In [None]:
@tf.function
def contains_pad(inp: tf.Tensor):
    """Returns a boolean scalar tensor: True if `inp` holds at least one padding token."""
    pad_hits = tf.math.count_nonzero(tf.math.equal(inp, pad_ten))
    return pad_hits > 0

### Train

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[batch_size, None], dtype=tf.int32),
                              tf.TensorSpec(shape=[batch_size, set_size], dtype=tf.int32)))
def train_step(inp: tf.Tensor, outp: tf.Tensor) -> tf.Tensor:
    """Runs one optimization step and returns the mean loss.

    Args:
        inp: int32 token ids, shape [batch_size, any length].
        outp: int32 target token ids, shape [batch_size, set_size].
    """
    weights = model.trainable_weights
    with tf.GradientTape() as tape:
        predictions: tf.Tensor = model([inp, outp], training=True)
        loss: tf.Tensor = loss_func(y_true=outp, y_pred=predictions)
    # One gradient per trainable weight, in the same order as `weights`.
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return tf.math.reduce_mean(loss)

In [None]:
def train(batch: tf.Tensor) -> tf.Tensor:
    """Trains on one batch, set by set, and returns the mean loss over the sets.

    This is deliberately a plain (eager) Python function that drives the compiled
    `train_step`. The loop depends on Python-level control flow: a data-dependent
    `break`, Python integer slicing and a Python step counter. Inside a
    `tf.function` the loop is unrolled at trace time, `tf.get_static_value`
    returns None for symbolic tensors (so the `break` never fired), and the
    counter update raised UnboundLocalError.

    Args:
        batch: int32 token ids of shape [batch_size, max_seq_len].

    Returns:
        Scalar tensor: the mean of the per-set losses, or NaN if no set was trained.
    """
    global train_step_calles
    set_losses: List[tf.Tensor] = []
    for i in range(num_sets):
        # The input is of size set_size-TAKE_TO_ACCOUNT
        already_predicted: int = i * (set_size + 1)
        start_from: int = max(0, already_predicted - max_seq_len)
        inp: tf.Tensor = batch[:, start_from:(i + 1) * set_size]
        # Stop as soon as the input window of EVERY row already contains padding.
        rows_with_pad = tf.reduce_any(tf.equal(inp, pad_ten), axis=1)
        if bool(tf.reduce_all(rows_with_pad)):
            break
        outp: tf.Tensor = batch[:, (i + 1) * set_size:(i + 2) * set_size]
        # NOTE(review): confirm whether train_step should be dispatched through
        # strategy.run when training under the TPU strategy.
        set_losses.append(train_step(inp, outp))
    # Tolerate the counter not having been initialised yet.
    train_step_calles = globals().get('train_step_calles', 0) + len(set_losses)
    if not set_losses:
        # Same result as averaging an empty slice of losses.
        return tf.constant(float('nan'), dtype=tf.keras.backend.floatx())
    return tf.math.reduce_mean(tf.stack(set_losses))
    

### Validate

In [None]:
@tf.function(input_signature=(tf.TensorSpec(shape=[batch_size, None], dtype=tf.int32),
                              tf.TensorSpec(shape=[batch_size, set_size], dtype=tf.int32)))
def val_step(inp: tf.Tensor, outp: tf.Tensor) -> tf.Tensor:
    """Evaluates the model on one (input, target set) pair without updating weights.

    Args:
        inp: int32 token ids, shape [batch_size, any length].
        outp: int32 target token ids, shape [batch_size, set_size].

    Returns:
        The mean loss as a scalar tensor.
    """
    predictions = model([inp, outp], training=False)
    return tf.math.reduce_mean(loss_func(y_true=outp, y_pred=predictions))

In [None]:
def validate(batch: tf.Tensor) -> tf.Tensor:
    """Evaluates one batch, set by set, and returns the mean loss over the sets.

    This is deliberately a plain (eager) Python function that drives the compiled
    `val_step`. The loop depends on Python-level control flow: a data-dependent
    `break`, Python integer slicing and a Python step counter. Inside a
    `tf.function` the loop is unrolled at trace time, `tf.get_static_value`
    returns None for symbolic tensors (so the `break` never fired), and the
    counter update raised UnboundLocalError.

    Args:
        batch: int32 token ids of shape [batch_size, max_seq_len].

    Returns:
        Scalar tensor: the mean of the per-set losses, or NaN if no set was evaluated.
    """
    global val_step_calles
    set_losses: List[tf.Tensor] = []
    for i in range(num_sets):
        # The input is of size set_size-TAKE_TO_ACCOUNT
        already_predicted: int = i * (set_size + 1)
        start_from: int = max(0, already_predicted - max_seq_len)
        inp: tf.Tensor = batch[:, start_from:(i + 1) * set_size]
        # Stop as soon as the input window of EVERY row already contains padding.
        rows_with_pad = tf.reduce_any(tf.equal(inp, pad_ten), axis=1)
        if bool(tf.reduce_all(rows_with_pad)):
            break
        outp: tf.Tensor = batch[:, (i + 1) * set_size:(i + 2) * set_size]
        set_losses.append(val_step(inp, outp))
    # Tolerate the counter not having been initialised yet.
    val_step_calles = globals().get('val_step_calles', 0) + len(set_losses)
    if not set_losses:
        # Same result as averaging an empty slice of losses.
        return tf.constant(float('nan'), dtype=tf.keras.backend.floatx())
    return tf.math.reduce_mean(tf.stack(set_losses))

## Checkpoints

In [None]:
# Every run saves into its own directory, named after the start time (MMDD-HHMM).
date: str = datetime.datetime.now().strftime('%m%d-%H%M')
if device == 'colab':
    folder_path: str = "/drive/MyDrive/final_project/checkpoints/"
else:
    folder_path: str = "C:/yoni/final_project/model/checkpoints/"
check_points_path = f"{folder_path}{date}"
# makedirs creates the checkpoints folder, any missing parents and the per-run
# directory in one call, and is a no-op when they already exist. os.mkdir failed
# when the parent of `folder_path` did not exist.
os.makedirs(check_points_path, exist_ok=True)

In [None]:
def check_point(folder_path: str, model: SeTransformer, val_loss: float, train_loss: float, test_loss = None):
    """Saves the model at the end of each epoch"""
    # (add-multiplies per forward pass) * (2 FLOPs/add-multiply) * 
    # * (3 for forward and backward pass) * (number of examples in dataset) 
    num_ops: float  = macs_per_call * 2 * flops_per_call / macs_per_call * (3 * train_step_calles + val_step_calles)
    peta_ops: float = num_ops * (10 ** (-15))
    tf.keras.models.save_model(model = model, filepath = folder_path, save_format='tf', overwrite=True)
    artifact = wandb.Artifact('new_artifact', type='my_model', description = f"the model after {num_ops:,} operations")
    artifact.add_dir(f'after_{training_flops:,}_ops/')
    run.log_artifact(artifact)
    print("Saved checkpoint")
    %notify("Saved checkpoint")

In [None]:
# Placeholder losses so the names exist before training; train_loop()'s return
# values overwrite train_loss and val_loss.
train_loss, val_loss = float('inf'), float('inf')
# NOTE(review): best_val_loss is not read anywhere in the visible cells.
best_val_loss = float('inf')

In [None]:
def loss_to_prob(loss: float) -> float:
    """Converts a loss into the probability exp(-loss)."""
    negated_loss = -loss
    return math.exp(negated_loss)

## The actual training loop!

In [None]:
def train_loop():
    """Trains until validation loss stops improving or a sanity check trips.

    Per training batch it runs `train`; on every 8th batch it also runs
    `validate`, logs both losses to Weights & Biases and then either saves an
    hourly checkpoint, aborts on suspiciously low training loss, or aborts when
    the model is still no better than random after 30 minutes. After each epoch
    it stops if the mean validation loss did not decrease, otherwise it saves a
    checkpoint.

    Returns:
        (train_loss, val_loss) as floats: the latest per-batch values on an
        early abort, the last per-epoch means when validation loss increased.
    """
    # `global` only declares names: "global x: int = 0" is a SyntaxError, so the
    # declaration and the initial values are separate statements.
    global epoch, batch_num, val_step_calles, train_step_calles
    epochs: int = 1000000  # Train until the cloud disconnects or the model stops improving
    per_epoch_train_loss: List[float] = []
    per_epoch_val_loss: List[float] = []
    print(f"number of train batches per epoch: {len(list_train_set)}")
    last_save_time = time.time()
    epoch = 0
    batch_num = 0
    val_step_calles = 0
    train_step_calles = 0
    # If every token gets probability 1/vocab_size the loss is
    # -ln(1/vocab_size) = ln(vocab_size); at or above this the predictions are random.
    random_guess_loss: float = math.log(vocab_size)
    for epoch in range(epochs):
        print(f"epoch number: {epoch}")
        per_batch_train_loss: List[float] = []
        per_batch_val_loss: List[float] = []
        for batch_num in tqdm.tqdm(range(len(list_train_set))):  # tqdm is a progress bar
            # Plain floats are stored so statistics.mean, wandb.log and the
            # comparisons below never operate on tensors.
            train_loss: float = float(train(list_train_set[batch_num]))
            per_batch_train_loss.append(train_loss)
            if batch_num % 8 == 0:  # 8 is number of training batches/number of val batches
                # because training set is 80% of the data and val set is 10%.
                # The modulo guards against running one past the end of the
                # validation list when the split sizes are rounded down.
                next_val_batch: tf.Tensor = list_val_set[(batch_num // 8) % len(list_val_set)]
                val_loss: float = float(validate(next_val_batch))
                per_batch_val_loss.append(val_loss)
                wandb.log({"epoch": epoch, "batch": batch_num, "per batch train loss": train_loss, 
                        "per batch val loss": val_loss})
                seconds_since_save: float = time.time() - last_save_time
                if seconds_since_save > 3600.0 and val_loss < random_guess_loss:
                    # If the last save is more than a hour (3600 sec) ago
                    # and if the predictions are not random.
                    # The current batch losses are passed: the per-epoch lists
                    # are still empty during the first epoch.
                    check_point(check_points_path, model, val_loss, train_loss)
                    last_save_time = time.time()
                elif train_loss < 0.01:
                    title: str = "Over fitting or data leak"
                    message = f"Training loss is {train_loss} and val loss is {val_loss} in the latest batch"
                    wandb.alert(title=title, text=message)
                    print(title)
                    print(message)
                    return train_loss, val_loss
                elif seconds_since_save > 1800.0 and val_loss >= random_guess_loss:
                    # if after 30 mins of training, the model predictions are still random
                    title: str = "Under fitting"
                    message = f"training loss is {train_loss} and val loss is {val_loss} in the latest batch"
                    wandb.alert(title=title, text=message)
                    print(title)
                    print(message)
                    return train_loss, val_loss
        per_epoch_train_loss.append(statistics.mean(per_batch_train_loss))
        per_epoch_val_loss.append(statistics.mean(per_batch_val_loss))
        print(f"train_loss: {per_epoch_train_loss[-1]}")
        print(f"prob of right ans train: {loss_to_prob(per_epoch_train_loss[-1])}")
        print(f"val_loss: {per_epoch_val_loss[-1]}")
        print(f"prob of right ans val: {loss_to_prob(per_epoch_val_loss[-1])}")
        if len(per_epoch_val_loss) > 1:
            if per_epoch_val_loss[-1] >= per_epoch_val_loss[-2]:
                print("Validation loss increased. Stopped training")
                return per_epoch_train_loss[-1], per_epoch_val_loss[-1]
        # check_point() prints "Saved checkpoint" itself.
        check_point(check_points_path, model, per_epoch_val_loss[-1], per_epoch_train_loss[-1])
        last_save_time = time.time()

In [None]:
# Run the training loop inside the distribution strategy's scope and keep the
# final losses for the test/checkpoint cell at the end of the notebook.
with strategy.scope():
    train_loss, val_loss = train_loop()

number of train batches per epoch: 8401
epoch number: 0


  0%|          | 0/8401 [00:00<?, ?it/s]

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: name 'fscope' is not defined
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: name 'fscope' is not defined


## Reproduce the error (minimal example)

In [None]:
# Stand-alone TPU setup for the minimal reproduction below.
# Note that this rebinds the notebook-wide `resolver` and `strategy`.
import tensorflow as tf
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

In [None]:
# Tiny three-layer dense model, optimizer and loss for the minimal reproduction,
# created inside the strategy scope like the real model.
with strategy.scope():
    exm_model = tf.keras.Sequential([
        tf.keras.layers.Dense(2, activation="relu", name="layer1"),
        tf.keras.layers.Dense(3, activation="relu", name="layer2"),
        tf.keras.layers.Dense(1, name="layer3"),
    ])
    exm_optimizer = tf.keras.optimizers.Adam()
    mse = tf.keras.losses.MeanSquaredError()

@tf.function(input_signature=(tf.TensorSpec(shape=[3,3], dtype=tf.int32),
                              tf.TensorSpec(shape=[3], dtype=tf.int32)))
def exm_train_step(inp: tf.Tensor, outp: tf.Tensor) -> tf.Tensor:
    """Runs one optimization step of the example model and returns the mean loss.

    Args:
        inp: int32 tensor of shape [3, 3].
        outp: int32 tensor of shape [3].
    """
    with tf.GradientTape() as tape:
        pred: tf.Tensor = exm_model(inp) 
        loss_val: tf.Tensor = mse(outp, pred)
    # Differentiate and update the EXAMPLE model's weights. The notebook-wide
    # `model` was used here before; the tape never saw its weights, and the
    # reproduction would not have been self-contained.
    grads = tape.gradient(loss_val, exm_model.trainable_weights)
    exm_optimizer.apply_gradients(zip(grads, exm_model.trainable_weights))
    return tf.math.reduce_mean(loss_val)


# Random integer inputs and targets matching exm_train_step's input signature.
exm_inp = tf.random.uniform(shape=[3, 3], dtype=tf.int32, maxval=100, minval=0)
exm_out = tf.random.uniform(shape=[3], dtype=tf.int32, maxval=100, minval=0)


# Run one step to trigger the behaviour being reproduced.
exm_train_step(exm_inp, exm_out)

## After training

In [None]:
# Mean loss over the held-out test batches. Each per-batch loss is converted to a
# plain float first: statistics.mean is meant for Python numbers, not tensors.
test_loss = statistics.mean([float(validate(test_batch)) for test_batch in tqdm.tqdm(list_test_set)])
print(f"Test loss: {test_loss}")
check_point(check_points_path, model, val_loss, train_loss, test_loss)