**This notebook performs inference only, on TPU. For training with a custom loop on TPU, please see [this kernel](http://www.kaggle.com/ashusma/training-lyft-tensorflow-tpu-multi-mode).**

The model was trained on TPU with the following parameters:
- image size = [224, 224, 25]
- epochs = 45
- steps per epoch = 1000
- batch_size = 192
- around 8.5 million samples

In [None]:
import pandas as pd
import tensorflow as tf
import numpy as np
from kaggle_datasets import KaggleDatasets
from l5kit.evaluation import write_pred_csv
import time

In [None]:
# Look for a TPU. On Kaggle the TPU_NAME environment variable is always set, so the
# resolver needs no arguments; the except clause treats ValueError as "no TPU found".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU: TensorFlow's default (non-distributed) strategy, for CPU or a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster, initialise it, and distribute work over its replicas.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# The sample submission is only used for its row count (len(sub)) further down.
sub = pd.read_csv('../input/lyft-motion-prediction-autonomous-vehicles/multi_mode_sample_submission.csv')

# Resolve the Kaggle dataset's GCS location and list its test TFRecord shards.
TEST_DATA_DIR = 'lyft-test-tfrecords'
TEST_GCS_PATH = KaggleDatasets().get_gcs_path(TEST_DATA_DIR)
test_files = tf.io.gfile.glob(f"{TEST_GCS_PATH}/test/shard*.tfrecord")

In [None]:
# dtype of the tensor serialized inside each TFRecord feature; read_unlabeled_tfrecord
# passes these to tf.io.parse_tensor to decode the raw bytes.
feature_dtypes = {
    'image': tf.uint8,
    'target_positions': tf.float32,
    'target_availabilities': tf.float32,
    'target_yaws': tf.float32,
    'world_from_agent': tf.float64,
    'history_positions': tf.float32,
    'history_yaws': tf.float32,
    'history_availabilities': tf.float32,
    'world_to_image': tf.float64,
    'track_id': tf.int64,
    'timestamp': tf.int64,
    'centroid': tf.float64,
    'yaw': tf.float64,
    'extent': tf.float32,
}

# Parsing spec for tf.io.parse_single_example: every feature above is stored as one
# scalar tf.string (a serialized tensor), in the same key order.
feature_descriptions = {
    name: tf.io.FixedLenFeature([], tf.string)
    for name in feature_dtypes
}

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
# 64 examples per replica; one global batch is split evenly across all replicas.
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync
# Must match the stored image shape (224 x 224 x 25 per the training notes above).
IMG_SIZE = 224
CHANNEL_DIM = 25
# Number of global batches needed to cover every submission row (ceiling division).
# Uses the built-in int: the `np.int` alias was deprecated in NumPy 1.20 and removed
# in NumPy 1.24, where it raises AttributeError.
TOTAL_STEPS = int(np.ceil(len(sub) / GLOBAL_BATCH_SIZE))

In [None]:
# We need to zero-pad the last, partial batch up to GLOBAL_BATCH_SIZE so that every
# batch has the same size. Otherwise the model throws a shape-mismatch assertion error
# on the last batch, because tf.distribute creates a dummy batch of size 0 when a partial
# global batch cannot be divided evenly between the replicas.

def padded_batch(images, timestamp, track_id, world_from_agent, target, avail):
    """Zero-pad every component of a batch along axis 0 up to GLOBAL_BATCH_SIZE.

    All six components are padded, so the final batch has a consistent leading
    dimension across components. Padding is computed from each tensor's runtime rank
    and keeps its dtype, so no per-feature shape or dtype is hard-coded. Full batches
    get zero padding and are returned unchanged.

    Returns:
        The tuple (images, timestamp, track_id, world_from_agent, target, avail).
    """

    def pad_to_global_batch(tensor):
        missing_rows = GLOBAL_BATCH_SIZE - tf.shape(tensor)[0]
        # tf.pad takes one [before, after] pair per dimension: pad only the end of axis 0.
        leading = tf.expand_dims(tf.stack([0, missing_rows]), axis=0)
        trailing = tf.zeros([tf.rank(tensor) - 1, 2], dtype=tf.int32)
        return tf.pad(tensor, tf.concat([leading, trailing], axis=0))

    return (
        pad_to_global_batch(images),
        pad_to_global_batch(timestamp),
        pad_to_global_batch(track_id),
        pad_to_global_batch(world_from_agent),
        pad_to_global_batch(target),
        pad_to_global_batch(avail),
    )

def read_unlabeled_tfrecord(example):
    """Decode one serialized TFRecord example into the tensors used for inference.

    Each feature holds a serialized tensor. All of them are decoded with the dtypes
    from ``feature_dtypes``, and the ones needed downstream are squeezed or reshaped.

    Returns:
        (image, timestamp, track_id, world_from_agent, target, avail), where image is a
        float32 tensor of shape (IMG_SIZE, IMG_SIZE, CHANNEL_DIM).
    """
    parsed = tf.io.parse_single_example(example, feature_descriptions)
    decoded = {}
    for name, serialized in parsed.items():
        decoded[name] = tf.io.parse_tensor(serialized, feature_dtypes[name])

    # uint8 -> float32; then drop singleton dims and move the first axis last. This is a
    # single example, so the result is (height, width, channel). The [1, 2, 0]
    # transpose assumes the squeezed stored image is channels-first — TODO confirm
    # against the TFRecord writer.
    pixels = tf.image.convert_image_dtype(decoded['image'], dtype=tf.float32)
    image = tf.reshape(
        tf.transpose(tf.squeeze(pixels), [1, 2, 0]),
        (IMG_SIZE, IMG_SIZE, CHANNEL_DIM),
    )

    timestamp, track_id, world_from_agent, target, avail = (
        tf.squeeze(decoded[key])
        for key in ('timestamp', 'track_id', 'world_from_agent',
                    'target_positions', 'target_availabilities')
    )
    return image, timestamp, track_id, world_from_agent, target, avail

def load_dataset(filenames, ordered=True, training=False):
    """Read GZIP-compressed TFRecords, decode them, and batch to GLOBAL_BATCH_SIZE.

    ``ordered`` and ``training`` are accepted but not used by this implementation.
    Files are read without ``num_parallel_reads`` so records keep their file order.

    Returns:
        A tf.data.Dataset of padded global batches (see padded_batch).
    """
    records = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
    examples = records.map(read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    batches = examples.batch(GLOBAL_BATCH_SIZE)
    return batches.map(padded_batch, num_parallel_calls=AUTO)

In [None]:
def get_dataset(files):
    """Build the prefetched test pipeline and distribute it with ``strategy``.

    Each element of the returned distributed dataset is one global batch, split across
    the replicas (GLOBAL_BATCH_SIZE / strategy.num_replicas_in_sync examples each).
    """
    pipeline = load_dataset(files, training=False).prefetch(AUTO)
    return strategy.experimental_distribute_dataset(pipeline)

In [None]:

def transform_points(points, transf_matrix):
    """Apply the linear (rotation) part of per-sample transforms to batched points.

    Args:
        points: rank-4 tensor ``(batch, modes, frames, D)`` with D in {2, 3}.
        transf_matrix: rank-3 tensor ``(batch, D + 1, D + 1)`` of homogeneous matrices.

    Returns:
        float32 tensor shaped like ``points``. Each point p becomes ``R @ p``, with
        ``R = transf_matrix[b, :-1, :-1]``. The translation column is NOT applied, so the
        result is an offset rather than an absolute position.
        NOTE(review): this equals "world position minus agent centroid" only if the
        translation of world_from_agent is the agent centroid — confirm.

    Raises:
        AssertionError: on a rank or dimension mismatch.
    """
    # Trailing axis so the matrix can be transposed into a shape that broadcasts
    # against the rank-4 points tensor.
    transf_matrix = tf.expand_dims(transf_matrix, axis=-1)
    assert len(points.shape) == len(transf_matrix.shape) == 4, (
        f"dimensions mismatch, both points ({points.shape}) and "
        f"transf_matrix ({transf_matrix.shape}) needs to be tensors of rank 4."
    )

    if points.shape[3] not in [2, 3]:
        raise AssertionError(
            f"Points input should be (..., 2) or (..., 3) shape, received {points.shape}"
        )

    assert points.shape[3] == transf_matrix.shape[2] - 1, "points dim should be one less than matrix dim"

    # The matrices are float64: multiply in float64 and return float32.
    points = tf.cast(points, tf.float64)
    # transf_matrix[:, :-1, :-1, :] is (batch, D, D, 1). perm [0, 3, 2, 1] turns it into
    # (batch, 1, D, D) holding R^T, which broadcasts over modes:
    # row-vector p @ R^T == (R @ p)^T.
    points = tf.matmul(points, tf.transpose(transf_matrix[:, :-1, :-1, :], perm=[0, 3, 2, 1]))
    return tf.cast(points, tf.float32)

In [None]:
def modified_resnet50(num_channels=25):
    """Build a ResNet50 backbone that takes ``num_channels`` inputs, seeded with ImageNet weights.

    Every layer copies the ImageNet weights unchanged except ``conv1_conv``, the layer
    that consumes the input. Its 3-channel kernel is expanded to ``num_channels`` input
    channels: the last ``num_channels % 3`` pretrained channels come first, followed by
    the whole kernel tiled ``num_channels // 3`` times. With the default of 25 this
    gives 1 + 8 * 3 channels.

    Args:
        num_channels: number of input channels of the returned model (default 25).

    Returns:
        A headless tf.keras ResNet50 with input shape (None, None, num_channels).

    Raises:
        ValueError: if ``num_channels`` < 1.
    """
    if num_channels < 1:
        raise ValueError(f"num_channels must be >= 1, got {num_channels}")

    # Reference model: 3 input channels, pretrained ImageNet weights.
    pretrained_model = tf.keras.applications.ResNet50(
        include_top=False, weights='imagenet', input_shape=(None, None, 3))
    # Target model: num_channels input channels, randomly initialised.
    modified_model = tf.keras.applications.ResNet50(
        include_top=False, weights=None, input_shape=(None, None, num_channels))

    # Name of the input convolution whose kernel must be reshaped (loop invariant).
    layer_to_modify = 'conv1_conv'

    # Both models share the same architecture, so layers line up one-to-one.
    for pretrained_layer, modified_layer in zip(pretrained_model.layers, modified_model.layers):
        if pretrained_layer.name == layer_to_modify:
            conv_weights = pretrained_layer.get_weights()
            kernel, bias = conv_weights[0], conv_weights[1]  # kernel: (7, 7, 3, 64)
            in_channels = kernel.shape[2]
            n_tiles, remainder = divmod(num_channels, in_channels)
            # Concatenate along the input-channel axis to reach num_channels.
            weights = np.concatenate(
                (kernel[:, :, in_channels - remainder:, :],
                 np.tile(kernel, [1, 1, n_tiles, 1])),
                axis=-2,
            )
            modified_layer.set_weights((weights, bias))
        else:
            modified_layer.set_weights(pretrained_layer.get_weights())

    return modified_model

In [None]:
class LyftModel(tf.keras.Model):
    """Multi-modal trajectory model.

    Architecture: 1x1 conv to 3 channels -> BN -> ReLU -> ImageNet ResNet50 ->
    global average pooling -> dropout -> dense head.

    The dense head emits ``num_modes * future_pred_frames * 2`` coordinates plus
    ``num_modes`` confidence logits per sample.
    """

    def __init__(self, num_modes=3, future_pred_frames=50,):
        super(LyftModel, self).__init__()

        # NOTE(review): keep these layer attributes and their creation order unchanged.
        # get_model() restores .h5 weights into this structure.
        self.conv1 = tf.keras.layers.Conv2D(3, kernel_size=1, use_bias=False, padding="same",)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet',)
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.dropout = tf.keras.layers.Dropout(0.2)

        # Total number of coordinate outputs: modes x frames x (x, y).
        self.future_len = num_modes * future_pred_frames * 2
        self.future_pred_frames = future_pred_frames
        self.num_modes = num_modes

        self.dense1 = tf.keras.layers.Dense(self.future_len + self.num_modes,)

    def call(self, inputs):
        """Return ``(pred, confidence)``.

        pred has shape (batch, num_modes, future_pred_frames, 2). confidence has shape
        (batch, num_modes), softmax-normalised over modes.
        """
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.model(x)
        x = self.gap(x)
        x = self.dropout(x)
        x = self.dense1(x)

        pred, confidence = tf.split(x, num_or_size_splits=[self.future_len, self.num_modes], axis=1)

        assert confidence.shape[-1] == self.num_modes, f'confidence shape got {confidence.shape}'
        # -1 lets TF infer the batch size, so this works whether the batch dimension is
        # static or dynamic (None). The previous `x.shape[0]` was None when dynamic,
        # which made tf.reshape fail.
        pred = tf.reshape(pred, shape=(-1, self.num_modes, self.future_pred_frames, 2))
        confidence = tf.nn.softmax(confidence, axis=1)
        return pred, confidence
        

In [None]:
# Build and restore the model inside strategy.scope() so every model variable is created
# under the distribution strategy (and tracked by the TPU when one is used).

def get_model():
    """Create LyftModel, restore its trained weights, and return it with a transform helper.

    Returns:
        (model, transf_points), where ``transf_points(pred, world_from_agent)`` forwards
        to ``transform_points``.
    """
    with strategy.scope():
        lyft_model = LyftModel()
        # Building creates the variables that load_weights restores into.
        lyft_model.build((GLOBAL_BATCH_SIZE, None, None, CHANNEL_DIM))
        lyft_model.summary()
        lyft_model.load_weights('../input/lyft-resnet-model/epoch 45 and val_loss 23.7635 model.h5')

        def transf_points(pred, world_from_agent):
            return transform_points(pred, world_from_agent)

        return lyft_model, transf_points

In [None]:
# Per-replica prediction step. tf.function compiles it into a TensorFlow graph, and
# strategy.run executes it on every replica.

@tf.function
def test_step(image, world_from_agent):
    """Run the model in inference mode and map its predictions with ``transf_points``.

    Uses the module-level ``model`` and ``transf_points`` created by get_model().
    """
    predictions, confidences = model(image, training=False)
    return transf_points(predictions, world_from_agent), confidences

In [None]:
# Prediction loop.

def _gather_replicas(per_replica_value):
    """Concatenate a (possibly per-replica) value into one NumPy array along axis 0.

    strategy.experimental_local_results handles both a TPU PerReplica value and the
    plain tensor returned by the default CPU/GPU strategy.
    """
    local_values = strategy.experimental_local_results(per_replica_value)
    return np.concatenate([value.numpy() for value in local_values], axis=0)


def model_predict(test_dataset):
    """Predict every batch of ``test_dataset`` and collect the results on the host.

    Args:
        test_dataset: distributed dataset from get_dataset(). Each element is
            (images, timestamp, track_id, world_from_agent, target, avail), with every
            component split across replicas.

    Returns:
        (future_coord_offset, timestamps, track_ids, confs): four lists with one NumPy
        array per global batch, replica results concatenated in replica order. The last
        batch still includes the zero padding added by padded_batch.
    """
    epoch_start_time = time.time()

    future_coord_offset = []
    timestamps = []
    track_ids = []
    confs = []

    print(f'predicting {TOTAL_STEPS} steps')
    for images, timestamp, track_id, world_from_agent, _target, _avail in test_dataset:
        pred, confidence = strategy.run(test_step, args=(images, world_from_agent))

        future_coord_offset.append(_gather_replicas(pred))
        confs.append(_gather_replicas(confidence))
        timestamps.append(_gather_replicas(timestamp))
        track_ids.append(_gather_replicas(track_id))

        print('=', end=' ', flush=True)
    epoch_time = time.time() - epoch_start_time
    print('time: {:0.1f}s'.format(epoch_time))

    return future_coord_offset, timestamps, track_ids, confs

In [None]:
# Build the input pipeline and the model, then predict over the whole test set.
# `model` and `transf_points` must remain module-level names: test_step reads them as globals.
test_dataset = get_dataset(test_files)
model, transf_points = get_model()
(future_coord_offset,
 timestamps,
 track_ids,
 confs) = model_predict(test_dataset)

In [None]:
# Write the submission with l5kit's write_pred_csv. The final batch was zero-padded up to
# GLOBAL_BATCH_SIZE, so each array is trimmed to the first len(sub) rows to drop that padding.

pred_path = 'submission.csv'
test_length = len(sub)


def _merge_batches(batches):
    """Concatenate per-batch arrays along axis 0 and keep only the first test_length rows."""
    return np.concatenate(batches, axis=0)[:test_length]


write_pred_csv(
    pred_path,
    timestamps=_merge_batches(timestamps),
    track_ids=_merge_batches(track_ids),
    coords=_merge_batches(future_coord_offset),
    confs=_merge_batches(confs),
)