In [75]:
import geopandas as gpd

import numpy as np
import pandas as pd
from pandas import IndexSlice as idx
import tensorflow as tf
import sys
import os
import glob

from functools import partial



code_dir = '/cluster/home/kheuto01/code/opioid-overdose-models/perturbations/'
sys.path.append(code_dir)
code_dir = '/cluster/home/kheuto01/code/opioid-overdose-models/diff_bpr'
sys.path.append(code_dir)
from top_k import top_k_idx
#from make_datasets import make_data
from bpr_model import PerturbedBPRModel


code_dir = '/cluster/home/kheuto01/code/opioid-overdose-models/'
sys.path.append(code_dir)
from zinf_gp.metrics import normcdf, fixed_top_X



from perturbations import perturbed
from bpr import bpr_variable_k_no_ties



In [2]:
data_path='/cluster/tufts/hugheslab/datasets/NSF_OD/results_20220606_update/clean_quarter_tract/'

In [113]:
def make_data_quarterly(multiindexed_gdf, first_year, last_year, time_window, feature_cols, train_shape, pred_lag=1):
    """Build one (features, target) example per quarter of each evaluation year.

    For every year in ``[first_year, last_year]`` a single history window is
    sliced from the frame, covering timesteps
    ``first_quarter - time_window`` through ``last_quarter - pred_lag``
    (label-based, inclusive). That same window is reused for every quarter of
    the year; only the extra ``pred_timestep`` column changes, telling the
    model which quarter it is asked to predict.

    Args:
        multiindexed_gdf: frame with a two-level index whose second level is
            named ``'timestep'``, and with ``'year'`` and ``'deaths'`` columns
            plus every column in ``feature_cols``.
        first_year: first evaluation year (inclusive).
        last_year: last evaluation year (inclusive).
        time_window: number of history timesteps expected per example.
        feature_cols: columns copied into the feature window.
        train_shape: ``(S, T, D)`` shape each window is reshaped to, where
            ``D == len(feature_cols) + 1`` because of ``pred_timestep``.
        pred_lag: gap between the end of the history window and the last
            quarter of the evaluation year. NOTE(review): assuming consecutive
            integer timesteps, the window only contains ``time_window`` steps
            when ``pred_lag`` equals the number of quarters in the year, and
            the batch-size assertion below makes the same assumption.

    Returns:
        ``(x, y)`` float32 tensors: ``x`` has shape ``(B, S, T * D)`` and
        ``y`` has shape ``(B, S)``, with one batch entry per evaluated quarter.

    Raises:
        ValueError: if an evaluation year has no rows in ``multiindexed_gdf``.
        AssertionError: if the stacked data does not match the expected shape.
    """
    xs = []
    ys = []

    for eval_year in range(first_year, last_year + 1):
        in_eval_year = multiindexed_gdf['year'] == eval_year
        # Index.sort_values() returns a new Index, so the result must be kept;
        # the original call discarded it and never actually sorted.
        quarters_in_year = multiindexed_gdf[in_eval_year].index.unique(level='timestep').sort_values()
        if len(quarters_in_year) == 0:
            raise ValueError(f'no rows found for evaluation year {eval_year}')

        window_start = quarters_in_year[0] - time_window
        window_end = quarters_in_year[-1] - pred_lag
        window_x_df = multiindexed_gdf.loc[idx[:, window_start:window_end], feature_cols]

        for quarter in quarters_in_year:
            # assign() builds a fresh frame with the extra column appended
            # last, so the shared window is never mutated (no
            # SettingWithCopyWarning, no state leaking between iterations).
            quarter_x_df = window_x_df.assign(pred_timestep=quarter)
            xs.append(quarter_x_df.values.reshape(train_shape))

            # Target: deaths of every geography at the predicted quarter.
            ys.append(multiindexed_gdf.loc[idx[:, quarter], 'deaths'].values)

    x_BSTD = tf.convert_to_tensor(np.stack(xs, axis=0), dtype=tf.float32)
    y_BS = tf.convert_to_tensor(np.stack(ys), dtype=tf.float32)

    B, S, T, D = x_BSTD.shape

    n_years = len(range(first_year, last_year + 1))
    assert B == n_years * pred_lag, 'unexpected number of (year, quarter) examples'
    assert S == train_shape[0], 'unexpected number of geographies'
    assert T == time_window, 'history window length does not match time_window'
    assert D == len(feature_cols) + 1, 'unexpected number of features'

    # Flatten the (time, feature) axes into one feature axis per geography.
    x_BSF_flat = tf.reshape(x_BSTD, (B, S, T * D))

    return x_BSF_flat, y_BS

In [115]:
class PerturbedBPRModel(tf.keras.Model):
    """Small MLP scorer trained through a perturbed top-k selection.

    Each input row is mapped to one scalar score. Training multiplies the
    output of ``perturbed_top_k_func`` with the true outcomes and compares
    that sum to the sum of the true top-k outcomes; evaluation does the same
    with an exact ``tf.math.top_k`` selection on the scores.
    """

    def __init__(self, perturbed_top_k_func, k=100):
        """Store the selection function and create the dense layers.

        ``k`` should match the k baked into ``perturbed_top_k_func``; it is
        needed separately for the exact top-k used in the evaluation step and
        for the true top-k in the loss denominator.
        """
        super().__init__()
        self.perturbed_top_k_func = perturbed_top_k_func
        self.k = k
        self.hidden1 = tf.keras.layers.Dense(100, activation='relu')
        self.hidden2 = tf.keras.layers.Dense(50, activation='relu')
        self.hidden3 = tf.keras.layers.Dense(10, activation='relu')
        self.output_layer = tf.keras.layers.Dense(1, activation=None)

    def call(self, inputs):
        """Run the three hidden layers and the output layer; return scores
        with the trailing size-1 feature axis squeezed away."""
        activations = inputs
        for hidden in (self.hidden1, self.hidden2, self.hidden3):
            activations = hidden(activations)

        scores = self.output_layer(activations)
        return tf.squeeze(scores, axis=-1)

    def train_step(self, data):
        """One optimisation step using the perturbed top-k selection.

        The compiled loss receives the outcome mass captured by the perturbed
        selection as its first argument and the best achievable top-k mass as
        its second argument.
        """
        # `data` is whatever `fit()` yields; here an (inputs, outcomes) pair.
        x, y = data

        with tf.GradientTape() as tape:
            scores = self(x, training=True)
            # presumably soft indicators shaped like `y` — TODO confirm
            soft_selection = self.perturbed_top_k_func(scores)
            true_top_k = tf.math.top_k(y, k=self.k)

            best_possible = tf.reduce_sum(true_top_k.values, axis=-1)
            captured = tf.reduce_sum(soft_selection * y, axis=-1)

            # The loss function itself is configured in `compile()`.
            loss = self.compiled_loss(captured, best_possible, regularization_losses=self.losses)

        weights = self.trainable_variables
        grads = tape.gradient(loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))
        # Compiled metrics include the tracker for the loss value.
        self.compiled_metrics.update_state(y, scores)
        return {metric.name: metric.result() for metric in self.metrics}

    def test_step(self, data):
        """Evaluate with an exact (discrete) top-k decision on the scores."""
        x, y = data
        scores = self(x, training=False)

        # Discrete top-k simulates actually choosing k locations.
        chosen_idx = tf.math.top_k(scores, k=self.k).indices
        true_top_k = tf.math.top_k(y, k=self.k)

        best_possible = tf.reduce_sum(true_top_k.values, axis=-1)
        captured = tf.reduce_sum(tf.gather(y, chosen_idx, batch_dims=-1), axis=-1)

        # Called for its side effect of updating the tracked loss.
        self.compiled_loss(captured, best_possible, regularization_losses=self.losses)

        self.compiled_metrics.update_state(y, scores)
        return {metric.name: metric.result() for metric in self.metrics}

In [121]:
# --- Experiment configuration ------------------------------------------------
epochs = 5000
seed = 360
time_window = 5*4  # 20 timesteps of history per example
first_train_eval_year = 2014
last_train_eval_year = 2018
# Restored from a commented-out line: `model.fit` below reads this name, so a
# fresh kernel raised NameError without it.
# NOTE(review): this is the number of training years (5), not the number of
# training examples (years x quarters) — confirm the intended batch size.
batch_dim_size = last_train_eval_year - first_train_eval_year + 1
first_validation_year = 2019
last_validation_year = 2019
first_test_year = 2020
last_test_year = 2021

tf.random.set_seed(seed)


timestep_col = 'timestep'
geography_col = 'geoid'
outcome_col = 'deaths'

x_idx_cols = [geography_col, 'lat', 'lon', timestep_col,
              'theme_1_pc', 'theme_2_pc', 'theme_3_pc', 'theme_4_pc',
              'svi_pctile', 'year',
              'neighbor_t', 'deaths']
y_idx_cols = [geography_col, timestep_col, outcome_col]
features_only = ['lat', 'lon', timestep_col,
                 'theme_1_pc', 'theme_2_pc', 'theme_3_pc', 'theme_4_pc',
                 'svi_pctile',
                 'neighbor_t', 'deaths']
#features_only = ['deaths']

# --- Load and index the data -------------------------------------------------
# Restored from a commented-out line: `data_gdf` was otherwise only available
# as leftover kernel state, so Restart & Run All failed on the next statement.
data_gdf = gpd.read_file(data_path)

multiindexed_gdf = data_gdf.set_index([geography_col, timestep_col])
# Keep the timestep available as a regular column too, since it is one of the
# entries in `features_only`.
multiindexed_gdf[timestep_col] = multiindexed_gdf.index.get_level_values(timestep_col)
num_geoids = len(data_gdf[geography_col].unique())

# +1 feature for the `pred_timestep` column added by make_data_quarterly.
train_shape = (num_geoids, time_window, len(features_only)+1)

# --- Build train / validation / test tensors ---------------------------------
train_x_BSF_flat, train_y_BS = make_data_quarterly(multiindexed_gdf, first_train_eval_year, last_train_eval_year,
                                                  time_window, features_only, train_shape, pred_lag=4)

valid_x_BSF_flat, valid_y_BS = make_data_quarterly(multiindexed_gdf, first_validation_year, last_validation_year,
                                         time_window, features_only, train_shape, pred_lag=4)

test_x_BSF_flat, test_y_BS = make_data_quarterly(multiindexed_gdf, first_test_year, last_test_year,
                                       time_window, features_only, train_shape, pred_lag=4)

# Normalisation statistics come from the training split only and are then
# applied unchanged to the validation and test splits.
norm_layer = tf.keras.layers.Normalization()
norm_layer.adapt(train_x_BSF_flat)
train_x_BSF_flat = norm_layer(train_x_BSF_flat)
valid_x_BSF_flat = norm_layer(valid_x_BSF_flat)
test_x_BSF_flat = norm_layer(test_x_BSF_flat)

top_100_idx_func = partial(top_k_idx, k=100)

In [144]:
# Hyperparameters for the perturbed top-k relaxation and the optimizer.
# Passed to `perturbed(...)` as `num_samples`.
perturbation_samples = 37
# Passed to `perturbed(...)` as `sigma` (presumably the noise scale — confirm
# against the `perturbations` module).
noise=0.4
# Passed to the Adam optimizer.
learning_rate = 0.0001


In [145]:
# Wrap the hard top-100 index function with normal noise so the selection can
# be trained through, then hand the wrapped function to the model.
perturbed_top_100 = perturbed(
    top_100_idx_func,
    num_samples=perturbation_samples,
    sigma=noise,
    noise='normal',
    batched=True,
)

# The model's default k (100) is relied on to match the k of `top_100_idx_func`.
model = PerturbedBPRModel(perturbed_top_100)

In [146]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Compile the model
def weird_loss(a, b):
    """Return the negated ratio ``-a / b``.

    The model's train/test steps pass the outcome sum captured by the
    selected top-k as ``a`` and the best achievable top-k sum as ``b``, so
    minimising this value maximises the captured fraction.
    """
    negated_numerator = -a
    return negated_numerator / b

model.compile(optimizer=optimizer, loss=weird_loss)


In [None]:

# Train with the model's custom train_step (perturbed top-k); the validation
# pass goes through test_step, which uses an exact top-k on the scores.
# NOTE(review): `batch_dim_size` must be defined by an earlier cell before
# this cell runs.
model.fit(train_x_BSF_flat, train_y_BS, epochs=epochs, batch_size=batch_dim_size,
          validation_data=(valid_x_BSF_flat, valid_y_BS))

Epoch 1/5000
Epoch 2/5000
Epoch 3/5000
Epoch 4/5000
Epoch 5/5000
Epoch 6/5000
Epoch 7/5000
Epoch 8/5000
Epoch 9/5000
Epoch 10/5000
Epoch 11/5000
Epoch 12/5000
Epoch 13/5000
Epoch 14/5000
Epoch 15/5000
Epoch 16/5000
Epoch 17/5000
Epoch 18/5000
Epoch 19/5000
Epoch 20/5000
Epoch 21/5000
Epoch 22/5000
Epoch 23/5000
Epoch 24/5000
Epoch 25/5000
Epoch 26/5000
Epoch 27/5000
Epoch 28/5000
Epoch 29/5000
Epoch 30/5000
Epoch 31/5000
Epoch 32/5000
Epoch 33/5000
Epoch 34/5000
Epoch 35/5000
Epoch 36/5000
Epoch 37/5000
Epoch 38/5000
Epoch 39/5000
Epoch 40/5000
Epoch 41/5000
Epoch 42/5000
Epoch 43/5000
Epoch 44/5000
Epoch 45/5000
Epoch 46/5000
Epoch 47/5000
Epoch 48/5000
Epoch 49/5000
Epoch 50/5000
Epoch 51/5000
Epoch 52/5000
Epoch 53/5000
Epoch 54/5000
Epoch 55/5000
Epoch 56/5000
Epoch 57/5000
Epoch 58/5000
Epoch 59/5000
Epoch 60/5000
Epoch 61/5000
Epoch 62/5000
Epoch 63/5000
Epoch 64/5000
Epoch 65/5000
Epoch 66/5000
Epoch 67/5000
Epoch 68/5000
Epoch 69/5000
Epoch 70/5000
Epoch 71/5000
Epoch 72/5000
E

Epoch 81/5000
Epoch 82/5000
Epoch 83/5000
Epoch 84/5000
Epoch 85/5000
Epoch 86/5000
Epoch 87/5000
Epoch 88/5000
Epoch 89/5000
Epoch 90/5000
Epoch 91/5000
Epoch 92/5000
Epoch 93/5000
Epoch 94/5000
Epoch 95/5000
Epoch 96/5000
Epoch 97/5000
Epoch 98/5000
Epoch 99/5000
Epoch 100/5000
Epoch 101/5000
Epoch 102/5000
Epoch 103/5000
Epoch 104/5000
Epoch 105/5000
Epoch 106/5000
Epoch 107/5000
Epoch 108/5000
Epoch 109/5000
Epoch 110/5000
Epoch 111/5000
Epoch 112/5000
Epoch 113/5000
Epoch 114/5000
Epoch 115/5000
Epoch 116/5000
Epoch 117/5000
Epoch 118/5000
Epoch 119/5000
Epoch 120/5000
Epoch 121/5000
Epoch 122/5000
Epoch 123/5000
Epoch 124/5000
Epoch 125/5000
Epoch 126/5000
Epoch 127/5000
Epoch 128/5000
Epoch 129/5000
Epoch 130/5000
Epoch 131/5000
Epoch 132/5000
Epoch 133/5000
Epoch 134/5000
Epoch 135/5000
Epoch 136/5000
Epoch 137/5000
Epoch 138/5000
Epoch 139/5000
Epoch 140/5000
Epoch 141/5000
Epoch 142/5000
Epoch 143/5000
Epoch 144/5000
Epoch 145/5000
Epoch 146/5000
Epoch 147/5000
Epoch 148/5000

Epoch 160/5000
Epoch 161/5000
Epoch 162/5000
Epoch 163/5000
Epoch 164/5000
Epoch 165/5000
Epoch 166/5000
Epoch 167/5000
Epoch 168/5000
Epoch 169/5000
Epoch 170/5000
Epoch 171/5000
Epoch 172/5000
Epoch 173/5000
Epoch 174/5000
Epoch 175/5000
Epoch 176/5000
Epoch 177/5000
Epoch 178/5000
Epoch 179/5000
Epoch 180/5000
Epoch 181/5000
Epoch 182/5000
Epoch 183/5000
Epoch 184/5000
Epoch 185/5000
Epoch 186/5000
Epoch 187/5000
Epoch 188/5000
Epoch 189/5000
Epoch 190/5000
Epoch 191/5000
Epoch 192/5000
Epoch 193/5000
Epoch 194/5000
Epoch 195/5000
Epoch 196/5000
Epoch 197/5000
Epoch 198/5000
Epoch 199/5000
Epoch 200/5000
Epoch 201/5000
Epoch 202/5000
Epoch 203/5000
Epoch 204/5000
Epoch 205/5000
Epoch 206/5000
Epoch 207/5000
Epoch 208/5000
Epoch 209/5000
Epoch 210/5000
Epoch 211/5000
Epoch 212/5000
Epoch 213/5000
Epoch 214/5000
Epoch 215/5000
Epoch 216/5000
Epoch 217/5000
Epoch 218/5000
Epoch 219/5000
Epoch 220/5000
Epoch 221/5000
Epoch 222/5000
Epoch 223/5000
Epoch 224/5000
Epoch 225/5000
Epoch 226/

Epoch 238/5000
Epoch 239/5000
Epoch 240/5000
Epoch 241/5000
Epoch 242/5000
Epoch 243/5000
Epoch 244/5000
Epoch 245/5000
Epoch 246/5000
Epoch 247/5000
Epoch 248/5000
Epoch 249/5000
Epoch 250/5000
Epoch 251/5000
Epoch 252/5000
Epoch 253/5000
Epoch 254/5000
Epoch 255/5000
Epoch 256/5000
Epoch 257/5000
Epoch 258/5000
Epoch 259/5000
Epoch 260/5000
Epoch 261/5000
Epoch 262/5000
Epoch 263/5000
Epoch 264/5000
Epoch 265/5000
Epoch 266/5000
Epoch 267/5000
Epoch 268/5000
Epoch 269/5000
Epoch 270/5000
Epoch 271/5000
Epoch 272/5000
Epoch 273/5000
Epoch 274/5000
Epoch 275/5000
Epoch 276/5000
Epoch 277/5000
Epoch 278/5000
Epoch 279/5000
Epoch 280/5000
Epoch 281/5000
Epoch 282/5000
Epoch 283/5000
Epoch 284/5000
Epoch 285/5000
Epoch 286/5000
Epoch 287/5000
Epoch 288/5000
Epoch 289/5000
Epoch 290/5000
Epoch 291/5000
Epoch 292/5000
Epoch 293/5000
Epoch 294/5000
Epoch 295/5000
Epoch 296/5000
Epoch 297/5000
Epoch 298/5000
Epoch 299/5000
Epoch 300/5000
Epoch 301/5000
Epoch 302/5000
Epoch 303/5000
Epoch 304/

Epoch 317/5000
Epoch 318/5000
Epoch 319/5000
Epoch 320/5000
Epoch 321/5000
Epoch 322/5000
Epoch 323/5000
Epoch 324/5000
Epoch 325/5000
Epoch 326/5000
Epoch 327/5000
Epoch 328/5000
Epoch 329/5000
Epoch 330/5000
Epoch 331/5000
Epoch 332/5000
Epoch 333/5000
Epoch 334/5000
Epoch 335/5000
Epoch 336/5000
Epoch 337/5000
Epoch 338/5000
Epoch 339/5000
Epoch 340/5000
Epoch 341/5000
Epoch 342/5000
Epoch 343/5000
Epoch 344/5000
Epoch 345/5000
Epoch 346/5000
Epoch 347/5000
Epoch 348/5000
Epoch 349/5000
Epoch 350/5000
Epoch 351/5000
Epoch 352/5000
Epoch 353/5000
Epoch 354/5000
Epoch 355/5000
Epoch 356/5000
Epoch 357/5000
Epoch 358/5000
Epoch 359/5000
Epoch 360/5000
Epoch 361/5000
Epoch 362/5000
Epoch 363/5000
Epoch 364/5000
Epoch 365/5000
Epoch 366/5000
Epoch 367/5000
Epoch 368/5000
Epoch 369/5000
Epoch 370/5000
Epoch 371/5000
Epoch 372/5000
Epoch 373/5000
Epoch 374/5000
Epoch 375/5000
Epoch 376/5000
Epoch 377/5000
Epoch 378/5000
Epoch 379/5000
Epoch 380/5000
Epoch 381/5000
Epoch 382/5000
Epoch 383/

Epoch 396/5000
Epoch 397/5000
Epoch 398/5000
Epoch 399/5000
Epoch 400/5000
Epoch 401/5000
Epoch 402/5000
Epoch 403/5000
Epoch 404/5000
Epoch 405/5000
Epoch 406/5000
Epoch 407/5000
Epoch 408/5000
Epoch 409/5000
Epoch 410/5000
Epoch 411/5000
Epoch 412/5000
Epoch 413/5000
Epoch 414/5000
Epoch 415/5000
Epoch 416/5000
Epoch 417/5000
Epoch 418/5000
Epoch 419/5000
Epoch 420/5000
Epoch 421/5000
Epoch 422/5000
Epoch 423/5000
Epoch 424/5000
Epoch 425/5000
Epoch 426/5000
Epoch 427/5000
Epoch 428/5000
Epoch 429/5000
Epoch 430/5000
Epoch 431/5000
Epoch 432/5000
Epoch 433/5000
Epoch 434/5000
Epoch 435/5000
Epoch 436/5000
Epoch 437/5000
Epoch 438/5000
Epoch 439/5000
Epoch 440/5000
Epoch 441/5000
Epoch 442/5000
Epoch 443/5000
Epoch 444/5000
Epoch 445/5000
Epoch 446/5000
Epoch 447/5000
Epoch 448/5000
Epoch 449/5000
Epoch 450/5000
Epoch 451/5000
Epoch 452/5000
Epoch 453/5000
Epoch 454/5000
Epoch 455/5000
Epoch 456/5000
Epoch 457/5000
Epoch 458/5000
Epoch 459/5000
Epoch 460/5000
Epoch 461/5000
Epoch 462/

Epoch 475/5000
Epoch 476/5000
Epoch 477/5000
Epoch 478/5000
Epoch 479/5000
Epoch 480/5000
Epoch 481/5000
Epoch 482/5000
Epoch 483/5000
Epoch 484/5000
Epoch 485/5000
Epoch 486/5000
Epoch 487/5000
Epoch 488/5000
Epoch 489/5000
Epoch 490/5000
Epoch 491/5000
Epoch 492/5000
Epoch 493/5000
Epoch 494/5000
Epoch 495/5000
Epoch 496/5000
Epoch 497/5000
Epoch 498/5000
Epoch 499/5000
Epoch 500/5000
Epoch 501/5000
Epoch 502/5000
Epoch 503/5000
Epoch 504/5000
Epoch 505/5000
Epoch 506/5000
Epoch 507/5000
Epoch 508/5000
Epoch 509/5000
Epoch 510/5000
Epoch 511/5000
Epoch 512/5000
Epoch 513/5000
Epoch 514/5000
Epoch 515/5000
Epoch 516/5000
Epoch 517/5000
Epoch 518/5000
Epoch 519/5000
Epoch 520/5000
Epoch 521/5000
Epoch 522/5000
Epoch 523/5000
Epoch 524/5000
Epoch 525/5000
Epoch 526/5000
Epoch 527/5000
Epoch 528/5000
Epoch 529/5000
Epoch 530/5000
Epoch 531/5000
Epoch 532/5000
Epoch 533/5000
Epoch 534/5000
Epoch 535/5000
Epoch 536/5000
Epoch 537/5000
Epoch 538/5000
Epoch 539/5000
Epoch 540/5000
Epoch 541/

Epoch 554/5000
Epoch 555/5000
Epoch 556/5000
Epoch 557/5000
Epoch 558/5000
Epoch 559/5000
Epoch 560/5000
Epoch 561/5000
Epoch 562/5000
Epoch 563/5000
Epoch 564/5000
Epoch 565/5000
Epoch 566/5000
Epoch 567/5000
Epoch 568/5000
Epoch 569/5000
Epoch 570/5000
Epoch 571/5000
Epoch 572/5000
Epoch 573/5000
Epoch 574/5000
Epoch 575/5000
Epoch 576/5000
Epoch 577/5000
Epoch 578/5000
Epoch 579/5000
Epoch 580/5000
Epoch 581/5000
Epoch 582/5000
Epoch 583/5000
Epoch 584/5000
Epoch 585/5000
Epoch 586/5000
Epoch 587/5000
Epoch 588/5000
Epoch 589/5000
Epoch 590/5000
Epoch 591/5000
Epoch 592/5000
Epoch 593/5000
Epoch 594/5000
Epoch 595/5000
Epoch 596/5000
Epoch 597/5000
Epoch 598/5000
Epoch 599/5000
Epoch 600/5000
Epoch 601/5000
Epoch 602/5000
Epoch 603/5000
Epoch 604/5000
Epoch 605/5000
Epoch 606/5000
Epoch 607/5000
Epoch 608/5000
Epoch 609/5000
Epoch 610/5000
Epoch 611/5000
Epoch 612/5000
Epoch 613/5000
Epoch 614/5000
Epoch 615/5000
Epoch 616/5000
Epoch 617/5000
Epoch 618/5000
Epoch 619/5000
Epoch 620/

Epoch 633/5000
Epoch 634/5000
Epoch 635/5000
Epoch 636/5000
Epoch 637/5000
Epoch 638/5000
Epoch 639/5000
Epoch 640/5000
Epoch 641/5000
Epoch 642/5000
Epoch 643/5000
Epoch 644/5000
Epoch 645/5000
Epoch 646/5000
Epoch 647/5000
Epoch 648/5000
Epoch 649/5000
Epoch 650/5000
Epoch 651/5000
Epoch 652/5000
Epoch 653/5000
Epoch 654/5000
Epoch 655/5000
Epoch 656/5000
Epoch 657/5000
Epoch 658/5000
Epoch 659/5000
Epoch 660/5000
Epoch 661/5000
Epoch 662/5000
Epoch 663/5000
Epoch 664/5000
Epoch 665/5000
Epoch 666/5000
Epoch 667/5000
Epoch 668/5000
Epoch 669/5000
Epoch 670/5000
Epoch 671/5000
Epoch 672/5000
Epoch 673/5000
Epoch 674/5000
Epoch 675/5000
Epoch 676/5000
Epoch 677/5000
Epoch 678/5000
Epoch 679/5000
Epoch 680/5000
Epoch 681/5000
Epoch 682/5000
Epoch 683/5000
Epoch 684/5000
Epoch 685/5000
Epoch 686/5000
Epoch 687/5000
Epoch 688/5000
Epoch 689/5000
Epoch 690/5000
Epoch 691/5000
Epoch 692/5000
Epoch 693/5000
Epoch 694/5000
Epoch 695/5000
Epoch 696/5000
Epoch 697/5000
Epoch 698/5000
Epoch 699/

Epoch 712/5000
Epoch 713/5000
Epoch 714/5000
Epoch 715/5000
Epoch 716/5000
Epoch 717/5000
Epoch 718/5000
Epoch 719/5000
Epoch 720/5000
Epoch 721/5000
Epoch 722/5000
Epoch 723/5000
Epoch 724/5000
Epoch 725/5000
Epoch 726/5000
Epoch 727/5000
Epoch 728/5000
Epoch 729/5000
Epoch 730/5000
Epoch 731/5000
Epoch 732/5000
Epoch 733/5000
Epoch 734/5000
Epoch 735/5000
Epoch 736/5000
Epoch 737/5000
Epoch 738/5000
Epoch 739/5000
Epoch 740/5000
Epoch 741/5000
Epoch 742/5000
Epoch 743/5000
Epoch 744/5000
Epoch 745/5000
Epoch 746/5000
Epoch 747/5000
Epoch 748/5000
Epoch 749/5000
Epoch 750/5000
Epoch 751/5000
Epoch 752/5000
Epoch 753/5000
Epoch 754/5000
Epoch 755/5000
Epoch 756/5000
Epoch 757/5000
Epoch 758/5000
Epoch 759/5000
Epoch 760/5000
Epoch 761/5000
Epoch 762/5000
Epoch 763/5000
Epoch 764/5000
Epoch 765/5000
Epoch 766/5000
Epoch 767/5000
Epoch 768/5000
Epoch 769/5000
Epoch 770/5000
Epoch 771/5000
Epoch 772/5000
Epoch 773/5000
Epoch 774/5000
Epoch 775/5000
Epoch 776/5000
Epoch 777/5000
Epoch 778/

In [139]:
model.trainable_variables

[<tf.Variable 'perturbed_bpr_model_11/dense_30/kernel:0' shape=(220, 100) dtype=float32, numpy=
 array([[ 0.00719592, -0.04645859, -0.01844875, ..., -0.04718753,
          0.03084159,  0.14573094],
        [-0.11836787, -0.12660244,  0.02698829, ..., -0.08388493,
         -0.14309686,  0.19262102],
        [-0.08026059, -0.28874218,  0.10318363, ..., -0.05022891,
          0.02977351,  0.3901074 ],
        ...,
        [ 0.23114634,  0.08450731, -0.03968769, ..., -0.20577903,
         -0.14238843, -0.15449598],
        [-0.13269791, -0.06289726,  0.2431028 , ...,  0.16756043,
         -0.20101394,  0.24970037],
        [-0.16971661, -0.25981182,  0.2715672 , ...,  0.07830626,
          0.02652317,  0.31739092]], dtype=float32)>,
 <tf.Variable 'perturbed_bpr_model_11/dense_30/bias:0' shape=(100,) dtype=float32, numpy=
 array([ 0.06196941,  0.161347  ,  0.26962522,  0.08456077, -0.12389129,
        -0.18943824, -0.16642849, -0.06824225, -0.01711683, -0.04803375,
        -0.12577438,  0.1

In [82]:
multiindexed_gdf

Unnamed: 0_level_0,Unnamed: 1_level_0,year,quarter,deaths,month,STATEFP,COUNTYFP,TRACTCE,NAME,NAMELSAD,MTFCC,...,TRACTCE_y,month_sinc,season,season_sin,qtr_since_,year_since,self_t-1,neighbor_t,geometry,timestep
geoid,timestep,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
25001010100,0.0,2000,1.0,0.0,1,25,001,10100,101,Census Tract 101,G5020,...,,0.0,jan-jun,0.0,0.0,0.0,0.0,0.000,"POLYGON ((-70.25001 42.06410, -70.24959 42.065...",0.0
25001010100,1.0,2000,2.0,0.0,4,25,001,10100,101,Census Tract 101,G5020,...,,3.0,jan-jun,0.0,1.0,0.0,0.0,0.000,"POLYGON ((-70.25001 42.06410, -70.24959 42.065...",1.0
25001010100,2.0,2000,3.0,0.0,7,25,001,10100,101,Census Tract 101,G5020,...,,6.0,jul-dec,1.0,2.0,0.0,0.0,0.000,"POLYGON ((-70.25001 42.06410, -70.24959 42.065...",2.0
25001010100,3.0,2000,4.0,0.0,10,25,001,10100,101,Census Tract 101,G5020,...,,9.0,jul-dec,1.0,3.0,0.0,0.0,0.000,"POLYGON ((-70.25001 42.06410, -70.24959 42.065...",3.0
25001010100,4.0,2001,1.0,3.0,1,25,001,10100,101,Census Tract 101,G5020,...,,12.0,jan-jun,2.0,4.0,1.0,0.0,0.000,"POLYGON ((-70.25001 42.06410, -70.24959 42.065...",4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25027761402,83.0,2020,4.0,0.0,10,25,027,761402,7614.02,Census Tract 7614.02,G5020,...,,249.0,jul-dec,41.0,83.0,20.0,0.0,0.125,"POLYGON ((-71.63921 42.53096, -71.63906 42.531...",83.0
25027761402,84.0,2021,1.0,1.0,1,25,027,761402,7614.02,Census Tract 7614.02,G5020,...,,252.0,jan-jun,42.0,84.0,21.0,0.0,0.125,"POLYGON ((-71.63921 42.53096, -71.63906 42.531...",84.0
25027761402,85.0,2021,2.0,0.0,4,25,027,761402,7614.02,Census Tract 7614.02,G5020,...,,255.0,jan-jun,42.0,85.0,21.0,0.0,0.000,"POLYGON ((-71.63921 42.53096, -71.63906 42.531...",85.0
25027761402,86.0,2021,3.0,0.0,7,25,027,761402,7614.02,Census Tract 7614.02,G5020,...,,258.0,jul-dec,43.0,86.0,21.0,0.0,0.000,"POLYGON ((-71.63921 42.53096, -71.63906 42.531...",86.0


In [122]:
train_x_BSF_flat

<tf.Tensor: shape=(20, 1620, 220), dtype=float32, numpy=
array([[[-0.7479397 ,  1.8678857 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.647509  ],
        [-1.2349257 ,  2.1685455 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.647509  ],
        [-0.9121866 ,  2.0892835 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.647509  ],
        ...,
        [-0.10952955, -0.57149386, -1.4142135 , ..., -0.5008921 ,
         -0.45583406, -1.647509  ],
        [ 0.77810806, -0.37315318, -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.647509  ],
        [ 0.9257272 , -0.39440045, -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.647509  ]],

       [[-0.7479397 ,  1.8678857 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.474087  ],
        [-1.2349257 ,  2.1685455 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.474087  ],
        [-0.9121866 ,  2.0892835 , -1.4142135 , ..., -1.0031025 ,
         -0.45583406, -1.474087  ],
        ..