In [39]:
import os
import pandas as pd
import numpy as np
import polars as pl
import nfl_data_py as nfl

env = "local"

In [40]:
if env == "local":
    os.chdir("/Users/samuel/Documents/GitHub/QB-GPT/")
else:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.chdir("/content/gdrive/MyDrive/NFL_Challenge/NFL-GPT/NFL data")

In [41]:
years_to_get = [2017, 2018, 2019, 2020, 2021, 2022, 2023]
season_data = pl.from_pandas(nfl.import_pbp_data(years_to_get))
rosters = pl.from_pandas(nfl.import_seasonal_rosters(years_to_get))

2017 done.
2018 done.
2019 done.
2020 done.
2021 done.
2022 done.
2023 done.
Downcasting floats.


In [42]:
plays_index = pl.read_parquet("index/plays_index.parquet")
positions_index = pl.read_parquet("index/positions_index.parquet")
OL_df = pl.DataFrame({"position" : ["OL"],
                      "position_ID" : [27],
                      "Cat" : ["Pos"]})

new_positions_index = pl.concat([positions_index, OL_df])

In [43]:
team_index = (season_data.
              select("home_team").
              rename({"home_team" : "team"}).
              unique().
              with_columns(pl.when(pl.col("team") == "OAK").
                           then(pl.lit("LV")).
                           otherwise(pl.col("team")).
                           alias("team")).
              with_columns(pl.arange(0, 32).alias("team_ID")))

In [44]:
yards_index = (pl.DataFrame({"yards_gained" : range(-99, 100), 
                            "yard_ID" : range(0, 199)}).
                with_columns(pl.col("yards_gained").cast(pl.Int32)).
                with_columns(pl.col("yard_ID").cast(pl.Int32)))

In [45]:
season_index = (pl.DataFrame({"season" : [2017, 2018, 2019, 2020, 2021, 2022, 2023], 
                            "season_ID" : [0, 1, 2, 3, 4, 5, 6]}).
                with_columns(pl.col("season").cast(pl.Int64)).
                with_columns(pl.col("season_ID").cast(pl.Int32)))

In [46]:
rosters_index = (rosters.
                 select("season", "team", "position", "player_name", "jersey_number", "player_id").
                 unique().
                 with_columns(pl.col("team").str.replace("OAK", "LV")).
                 with_columns(pl.when(pl.col("jersey_number").is_null()).
                              then(pl.lit(0.0)).
                              otherwise(pl.col("jersey_number")).
                              alias("jersey_number")))

sub_index = (rosters_index.
             select("player_id").
             unique().
             with_columns(pl.arange(0, 7228).alias("player_ID")))

rosters_index = (rosters_index.
                 join(sub_index,
                      on = ["player_id"],
                      how = "left").
                 select("season", "position", "player_id", "player_ID").
                 unique())

rosters_index.sort("player_ID")

season,position,player_id,player_ID
i32,str,str,i64
2017,"""QB""","""00-0026898""",0
2018,"""QB""","""00-0026898""",0
2022,"""DB""","""00-0036147""",1
2023,"""DB""","""00-0036147""",1
2021,"""DB""","""00-0036147""",1
2020,"""DB""","""00-0036147""",1
2020,"""TE""","""00-0035756""",2
2022,"""QB""","""00-0037175""",3
2023,"""QB""","""00-0037175""",3
2018,"""WR""","""00-0031801""",4


In [47]:
spec_data = (season_data.
             select("season", "old_game_id", "play_id", "home_team", "away_team", "posteam", "defteam", "down", "offense_players", "defense_players", "play_type", "yards_gained").
             filter(pl.col("play_type").is_not_null()).
             filter(pl.col("play_type") != "no_play").
             filter(pl.col("offense_players") != "").
             with_columns(pl.when(pl.col("down").is_null()).
                          then(pl.lit(0.0)).
                          otherwise(pl.col("down")).
                          alias("down")).
             filter(pl.col("play_type").
                    is_in(["run", "pass"])).
             with_columns(pl.when(pl.col("yards_gained") > 0).
                          then(pl.lit(1.0)).
                          otherwise(pl.lit(0.0)).
                          alias("Success")).
             melt(id_vars = ["season", "old_game_id", "play_id", "home_team", "away_team", "posteam", "defteam", "down", "play_type", "yards_gained", "Success"], 
                  value_vars = ["offense_players", "defense_players"],
                  variable_name = "team",
                  value_name = "players").
             with_columns(pl.col("players").str.split(";")).
             explode("players").
             rename({"players" : "player_id"}).
             join(rosters.
                  select("season", "player_id", "depth_chart_position").
                  with_columns(pl.col("season").cast(pl.Int64)).
                  rename({"depth_chart_position" : "position"}).
                  unique(),
                  how = "left",
                  on = ["season", "player_id"]).
             with_columns(pl.col("team").str.replace("_players", "")).
             filter(pl.col("position").is_not_null()).
             rename({"team" : "OffDef"}).
             with_columns(pl.when(pl.col("OffDef") == "offense").
                          then(pl.col("posteam")).
                          otherwise(pl.col("defteam")).
                          alias("team")).
             drop("home_team", "away_team", "posteam", "defteam").
             with_columns(pl.when(pl.col("team") == "OAK").
                          then(pl.lit("LV")).
                          otherwise(pl.col("team")).
                          alias("team")).
             join(rosters_index.
                  select("season", "player_id", "player_ID").
                  with_columns(pl.col("season").cast(pl.Int64)),
                  on = ["season", "player_id"],
                  how = "left").
             with_columns(pl.col("player_ID").cumcount().over("player_ID").alias("count")).
             with_columns(pl.when(pl.col("count") < 25).
              then(pl.lit(7228)).
              otherwise(pl.col("player_ID")).
              alias("player_ID")).
             drop("count", "player_id"))

In [48]:
spec_data = (spec_data.
               join(new_positions_index.
                    drop("Cat"), 
                    on = "position",
                    how = "left").
               drop("position").
               group_by("season", "old_game_id", "play_id", "team", "OffDef", "down", "play_type", "yards_gained", "Success").
               agg(pl.col("position_ID"),
                    pl.col("player_ID")).
               with_columns(pl.when(pl.col("OffDef") == "offense").
                              then(pl.lit(1)).
                              otherwise(pl.lit(0)).
                              alias("OffDef_ID")).
               drop("OffDef").
               join(plays_index.
                    rename({"PlayType" : "play_type"}),
                    on = "play_type",
                    how = "left").
               drop("play_type").
               join(team_index,
                    on = "team",
                    how = "left").
               drop("team").
               with_columns(pl.col("down").cast(pl.Int32).alias("down_ID")).
               drop("down").
               rename({"old_game_id" : "gameId",
                         "play_id" : "playId"}).
               with_columns(pl.col("position_ID").list.lengths().alias("Length")).
               filter(pl.col("Length") == 11).
               drop("Length").
               with_columns(pl.col("gameId").cast(pl.Int32)).
               with_columns(pl.col("playId").cast(pl.Int32)).
               with_columns(pl.col("yards_gained").cast(pl.Int32)).
               join(season_index, 
                    on = "season",
                    how = "left").
               drop("season"))

In [70]:
new_data = (spec_data.
            select("gameId", "playId", "OffDef_ID").
            unique().
            group_by("gameId", "playId").
            count().
            filter(pl.col("count") == 2).
            drop("count").
            join(spec_data,
                 on = ["gameId", "playId"],
                 how = "left").
            with_columns(pl.col("player_ID").list.unique().list.lengths().alias("NB_unique")).
            filter(pl.col("NB_unique") == 11).
            drop("NB_unique"))

In [93]:
base_index = (new_data.
              select("gameId", "playId").
              group_by("gameId", "playId").
              count().
              filter(pl.col("count") == 2).
              drop("count"))

In [94]:
new_data = (base_index.
            join(new_data,
                 on = ["gameId", "playId"],
                 how = "left"))

## Negative lineups

In [95]:
def reshape_for_sampling(df, sample_size):
    to_insert = [[i for e in range(11)] for i in range(int(df.shape[0]/11)+1)]
    to_insert = [e for s in to_insert for e in s][:df.shape[0]]
    
    full_random = (df.
                   with_columns(pl.Series(to_insert).alias("ID")).
                   group_by("ID").
                   agg("position_ID", "player_ID", "OffDef_ID", "PlayType_ID", "team_ID", "down_ID", "season_ID").
                   sample(sample_size, with_replacement = True))
    
    return full_random

In [96]:
starter = (new_data.
               explode("position_ID", "player_ID").
               drop("gameId", "playId", "yards_gained", "Success").
               select(pl.all().shuffle()))

full_random = (reshape_for_sampling(starter, 84000).
               with_columns(pl.lit("full_random").alias("scheme")))

In [97]:
full_random_off = (starter.
                   filter(pl.col("OffDef_ID") == 1).
                   unique(subset = "player_ID"))

full_random_def = (starter.
                   filter(pl.col("OffDef_ID") == 0).
                   unique(subset = "player_ID"))

full_random_off = (reshape_for_sampling(full_random_off, 42000).
                   with_columns(pl.lit("full_random_off_def").alias("scheme")))

full_random_def = (reshape_for_sampling(full_random_def, 42000).
                   with_columns(pl.lit("full_random_off_def").alias("scheme")))

full_random_off_def = pl.concat([full_random_off, full_random_def])

In [98]:
full_random_teams = [(starter.filter(pl.col("team_ID") == v).unique(subset = "player_ID")) for v in np.unique(new_data.select("team_ID").to_series().to_list())]

full_random_teams = [reshape_for_sampling(d, 2400) for d in full_random_teams]

full_random_team = (pl.concat(full_random_teams).
                    with_columns(pl.lit("full_random_team").alias("scheme")))

In [99]:
filters = starter.select("team_ID", "season_ID").to_pandas().drop_duplicates().to_dict(orient = "records")

specific_teams = [(starter.
                   filter(pl.col("team_ID") == v["team_ID"]).
                   filter(pl.col("season_ID") == v["season_ID"]).
                   unique(subset = "player_ID")) for v in filters]

specific_teams = [reshape_for_sampling(d, 400) for d in specific_teams]

specific_team = (pl.concat(specific_teams).
                 with_columns(pl.lit("same_team_season").alias("scheme")))

In [100]:
filters = starter.select("season_ID", "OffDef_ID").to_pandas().drop_duplicates().to_dict(orient = "records")

specific_roles = [(starter.
                   filter(pl.col("OffDef_ID") == v["OffDef_ID"]).
                   filter(pl.col("season_ID") == v["season_ID"]).
                   unique(subset = "player_ID")) for v in filters]

specific_roles = [reshape_for_sampling(d, 10000) for d in specific_roles]

specific_role = (pl.concat(specific_roles).
                 with_columns(pl.lit("same_season_off_def").alias("scheme")))

In [101]:
filters = starter.select("OffDef_ID", "team_ID").to_pandas().drop_duplicates().to_dict(orient = "records")

specific_seasons = [(starter.
                   filter(pl.col("OffDef_ID") == v["OffDef_ID"]).
                   filter(pl.col("team_ID") == v["team_ID"]).
                   unique(subset = "player_ID")) for v in filters]

specific_seasons = [reshape_for_sampling(d, 1000) for d in specific_seasons]

specific_season = (pl.concat(specific_roles).
                   with_columns(pl.lit("same_team_off_def").alias("scheme")))

In [102]:
filters = starter.select("OffDef_ID", "team_ID", "season_ID").to_pandas().drop_duplicates().to_dict(orient = "records")

specific_playtypes = [(starter.
                       filter(pl.col("OffDef_ID") == v["OffDef_ID"]).
                       filter(pl.col("team_ID") == v["team_ID"]).
                       filter(pl.col("season_ID") == v["season_ID"]).
                       unique(subset = "player_ID")) for v in filters]

specific_playtypes = [reshape_for_sampling(d, 300) for d in specific_playtypes]

specific_playtype = (pl.concat(specific_playtypes).
                     with_columns(pl.col("PlayType_ID").list.unique().list.lengths().alias("NB_plays")).
                     filter(pl.col("NB_plays") == 2).
                     drop("NB_plays").
                     with_columns(pl.lit("same_team_season_offdef").alias("scheme")))

In [103]:
filters = starter.select("OffDef_ID", "team_ID", "season_ID", "PlayType_ID").to_pandas().drop_duplicates().to_dict(orient = "records")

specific_playtypes_2 = [(starter.
                       filter(pl.col("OffDef_ID") == v["OffDef_ID"]).
                       filter(pl.col("team_ID") == v["team_ID"]).
                       filter(pl.col("season_ID") == v["season_ID"]).
                       filter(pl.col("PlayType_ID") == v["PlayType_ID"]).
                       unique(subset = "player_ID")) for v in filters]

specific_playtypes_2 = [reshape_for_sampling(d, 300) for d in specific_playtypes_2]

specific_playtype_2 = (pl.concat(specific_playtypes_2).
                     with_columns(pl.lit("same_team_season_offdef_playtype").alias("scheme")))

## Dataset creation

In [104]:
negatives = (pl.concat([specific_season, specific_role, specific_team, full_random_off_def, full_random, specific_playtype, specific_playtype_2]).
             with_columns(pl.lit(0).alias("Label")).
             drop("ID").
             with_columns([pl.col("OffDef_ID").list.first().alias("OffDef_ID"),
                           pl.col("PlayType_ID").list.first().alias("PlayType_ID"),
                           pl.col("team_ID").list.first().alias("team_ID"),
                           pl.col("down_ID").list.first().alias("down_ID"),
                           pl.col("season_ID").list.first().alias("season_ID")]))

positives = (new_data.
             drop("gameId", "playId", "yards_gained", "Success").
             with_columns(pl.lit("positives").alias("scheme")).
             with_columns(pl.lit(1).alias("Label")))

len_total = positives.shape[0]+negatives.shape[0]

full_dataset = (pl.concat([positives, negatives]).
                with_columns(pl.Series(range(len_total)).alias("ID")).
                with_columns(pl.col("position_ID").list.lengths().alias("Len")).
                filter(pl.col("Len") == 11).
                drop("Len").
                with_columns(pl.col("player_ID").list.unique().list.lengths().alias("NB_unique")).
                filter(pl.col("NB_unique") == 11).
                drop("NB_unique"))

In [105]:
(full_dataset.
 select("scheme").
 group_by("scheme").
 count()).to_pandas()

Unnamed: 0,scheme,count
0,full_random,80686
1,same_season_off_def,139574
2,same_team_season_offdef_playtype,267349
3,same_team_season,89251
4,full_random_off_def,83873
5,same_team_season_offdef,133435
6,same_team_off_def,139574
7,positives,404782


In [106]:
(full_dataset.
 with_columns(pl.col("player_ID").list.unique().list.lengths().alias("NB_unique")).
 with_columns(pl.col("player_ID").list.lengths().alias("Length")).
 filter(pl.col("NB_unique") != 11))

position_ID,player_ID,OffDef_ID,PlayType_ID,team_ID,down_ID,season_ID,scheme,Label,ID,NB_unique,Length
list[i64],list[i64],i32,i64,i64,i32,i32,str,i32,i64,u32,u32


In [107]:
from sklearn.model_selection import train_test_split

train_test_df = (full_dataset.
                 select("ID", "scheme").
                 unique()).to_pandas()

train, test = train_test_split(train_test_df, test_size= 0.3, stratify = train_test_df["scheme"].to_numpy())

In [108]:
train_test_df_helenos = (new_data.
                         select("gameId", "playId", "Success").
                         unique()).to_pandas()

train_helenos, test_helenos = train_test_split(train_test_df_helenos, test_size= 0.3, stratify = train_test_df_helenos["Success"].to_numpy())

In [109]:
train_data = (pl.from_pandas(train).
              drop("Label").
              join(full_dataset,
                   on = "ID",
                   how = "left"))

test_data = (pl.from_pandas(test).
             drop("Label").
             join(full_dataset,
                  on = "ID",
                  how = "left"))

In [110]:
train_data_helenos = (pl.from_pandas(train_helenos).
                      drop("Success").
                      join(new_data,
                           on = ["gameId", "playId"],
                           how = "left"))

test_data_helenos = (pl.from_pandas(test_helenos).
                     drop("Success").
                     join(new_data,
                          on = ["gameId", "playId"],
                          how = "left"))

In [111]:
train_data_helenos.shape[0] + test_data_helenos.shape[0]

404782

In [113]:
pos_val = 0
scrim_val = 99
start_val = 1032

In [114]:
train_seq_dict = {row["ID"] : 
    {"input_ids" : [10877 for i in range(len(row["position_ID"]))],
     "player_ids": row["player_ID"],
     "position_ids": row["position_ID"],
     "OffDef" : [row["OffDef_ID"] for i in range(len(row["position_ID"]))],
     "token_type_ids" : [0 for i in range(len(row["position_ID"]))],
     "pos_ids" : [pos_val for i in range(len(row["position_ID"]))],
     "team_ID" : [row["team_ID"] for i in range(len(row["position_ID"]))],
     "start_ids" : [start_val for i in range(len(row["position_ID"]))],
     "scrim_ids" : [scrim_val for i in range(len(row["position_ID"]))],
     "attention_mask" : [1 for i in range(len(row["position_ID"]))],
     "spec_token" : [0],
     "PlayType" : [row["PlayType_ID"] for i in range(len(row["position_ID"]))],
     "down_ID" : [row["down_ID"] for i in range(len(row["position_ID"]))],
     "season_ID" : [row["season_ID"] for i in range(len(row["position_ID"]))],
     "label" : row["Label"]} for row in train_data.iter_rows(named=True)}

test_seq_dict = {row["ID"] : 
    {"input_ids" : [10877 for i in range(len(row["position_ID"]))],
     "player_ids": row["player_ID"],
     "position_ids": row["position_ID"],
     "OffDef" : [row["OffDef_ID"] for i in range(len(row["position_ID"]))],
     "token_type_ids" : [0 for i in range(len(row["position_ID"]))],
     "pos_ids" : [pos_val for i in range(len(row["position_ID"]))],
     "team_ID" : [row["team_ID"] for i in range(len(row["position_ID"]))],
     "start_ids" : [start_val for i in range(len(row["position_ID"]))],
     "scrim_ids" : [scrim_val for i in range(len(row["position_ID"]))],
     "attention_mask" : [1 for i in range(len(row["position_ID"]))],
     "spec_token" : [0],
     "PlayType" : [row["PlayType_ID"] for i in range(len(row["position_ID"]))],
     "down_ID" : [row["down_ID"] for i in range(len(row["position_ID"]))],
     "season_ID" : [row["season_ID"] for i in range(len(row["position_ID"]))],
     "label" : row["Label"]} for row in test_data.iter_rows(named=True)}

In [115]:
train_seq_helenos = {str(row["gameId"]) + "_" + str(row["playId"]) + "_" + str(row["OffDef_ID"]) : 
    {"input_ids" : [10877 for i in range(len(row["position_ID"]))],
     "player_ids": row["player_ID"],
     "position_ids": row["position_ID"],
     "OffDef" : [row["OffDef_ID"] for i in range(len(row["position_ID"]))],
     "token_type_ids" : [0 for i in range(len(row["position_ID"]))],
     "pos_ids" : [pos_val for i in range(len(row["position_ID"]))],
     "team_ID" : [row["team_ID"] for i in range(len(row["position_ID"]))],
     "start_ids" : [start_val for i in range(len(row["position_ID"]))],
     "scrim_ids" : [scrim_val for i in range(len(row["position_ID"]))],
     "attention_mask" : [1 for i in range(len(row["position_ID"]))],
     "spec_token" : [0],
     "PlayType" : [row["PlayType_ID"] for i in range(len(row["position_ID"]))],
     "down_ID" : [row["down_ID"] for i in range(len(row["position_ID"]))],
     "season_ID" : [row["season_ID"] for i in range(len(row["position_ID"]))],
     "Success" : row["Success"], 
     "yards_gained" : row["yards_gained"],
     "gameId" : row["gameId"],
     "playId" : row["playId"]} for row in train_data_helenos.iter_rows(named=True)}

test_seq_helenos = {str(row["gameId"]) + "_" + str(row["playId"]) + "_" + str(row["OffDef_ID"]) : 
    {"input_ids" : [10877 for i in range(len(row["position_ID"]))],
     "player_ids": row["player_ID"],
     "position_ids": row["position_ID"],
     "OffDef" : [row["OffDef_ID"] for i in range(len(row["position_ID"]))],
     "token_type_ids" : [0 for i in range(len(row["position_ID"]))],
     "pos_ids" : [pos_val for i in range(len(row["position_ID"]))],
     "team_ID" : [row["team_ID"] for i in range(len(row["position_ID"]))],
     "start_ids" : [start_val for i in range(len(row["position_ID"]))],
     "scrim_ids" : [scrim_val for i in range(len(row["position_ID"]))],
     "attention_mask" : [1 for i in range(len(row["position_ID"]))],
     "spec_token" : [0],
     "PlayType" : [row["PlayType_ID"] for i in range(len(row["position_ID"]))],
     "down_ID" : [row["down_ID"] for i in range(len(row["position_ID"]))],
     "season_ID" : [row["season_ID"] for i in range(len(row["position_ID"]))],
     "Success" : row["Success"], 
     "yards_gained" : row["yards_gained"],
     "gameId" : row["gameId"],
     "playId" : row["playId"]} for row in test_data_helenos.iter_rows(named=True)

In [116]:
train_common_keys = np.unique([str(row["gameId"]) + "_" + str(row["playId"]) for row in train_data_helenos.iter_rows(named=True)])
train_off_keys = [v + "_1" for v in train_common_keys]
train_def_keys = [v + "_0" for v in train_common_keys]

test_common_keys = np.unique([str(row["gameId"]) + "_" + str(row["playId"]) for row in test_data_helenos.iter_rows(named=True)])
test_off_keys = [v + "_1" for v in test_common_keys]
test_def_keys = [v + "_0" for v in test_common_keys]

train_off_seq = [train_seq_helenos[v] for v in train_off_keys]
train_def_seq = [train_seq_helenos[v] for v in train_def_keys]

test_off_seq = [test_seq_helenos[v] for v in test_off_keys]
test_def_seq = [test_seq_helenos[v] for v in test_def_keys]

In [121]:
from tqdm import tqdm

def compile_seq(list_of_trajs):
    merged_dict = {k : [] for k in list_of_trajs[0].keys()}

    with tqdm(total=len(list_of_trajs)) as pbar:
      for d in list_of_trajs:
        for key, value in d.items():
          merged_dict[key] += [value]
        pbar.update(1)
        
    merged_dict = {k: np.array(v) for k,v in merged_dict.items()}
    return merged_dict

In [122]:
train_OFF_helenos = compile_seq(train_off_seq)
train_DEF_helenos = compile_seq(train_def_seq)

test_OFF_helenos = compile_seq(test_off_seq)
test_DEF_helenos = compile_seq(test_def_seq)


train_helenos = {"off" : train_OFF_helenos,
                 "def" : train_DEF_helenos}

test_helenos = {"off" : test_OFF_helenos,
                "def" : test_DEF_helenos}

100%|██████████| 141673/141673 [00:00<00:00, 541741.66it/s]
100%|██████████| 141673/141673 [00:00<00:00, 628451.38it/s]
100%|██████████| 60718/60718 [00:00<00:00, 529530.50it/s]
100%|██████████| 60718/60718 [00:00<00:00, 607983.63it/s]


In [123]:
train = compile_seq(list(train_seq_dict.values()))
test = compile_seq(list(test_seq_dict.values()))

100%|██████████| 936966/936966 [00:01<00:00, 862349.59it/s]
100%|██████████| 401558/401558 [00:00<00:00, 855290.66it/s]


In [124]:
import tensorflow as tf

train_total = tf.data.Dataset.from_tensor_slices(train)
train_labels = tf.data.Dataset.from_tensor_slices(train["label"])

train_dataset = tf.data.Dataset.zip((train_total, train_labels))

test_total = tf.data.Dataset.from_tensor_slices(test)
test_labels = tf.data.Dataset.from_tensor_slices(test["label"])

test_dataset = tf.data.Dataset.zip((test_total, test_labels))

tf.data.Dataset.save(train_dataset, "data_models/StratFormer/train_data")
tf.data.Dataset.save(test_dataset, "data_models/StratFormer/test_data")

In [125]:
train_dataset = tf.data.Dataset.from_tensor_slices(train_helenos)

test_dataset = tf.data.Dataset.from_tensor_slices(test_helenos)

tf.data.Dataset.save(train_dataset, "data_models/Helenos/train_data")
tf.data.Dataset.save(test_dataset, "data_models/Helenos/test_data")

In [127]:
train_helenos.keys()

dict_keys(['off', 'def'])

In [136]:
train_total = tf.data.Dataset.from_tensor_slices(train_helenos)
train_labels = tf.data.Dataset.from_tensor_slices(train_helenos["off"]["yards_gained"].astype("float32"))
train_dataset = tf.data.Dataset.zip((train_total, train_labels))

test_total = tf.data.Dataset.from_tensor_slices(test_helenos)
test_labels = tf.data.Dataset.from_tensor_slices(test_helenos["off"]["yards_gained"].astype("float32"))
test_dataset = tf.data.Dataset.zip((test_total, test_labels))

tf.data.Dataset.save(train_dataset, "data_models/Helenos/train_data_tfp")
tf.data.Dataset.save(test_dataset, "data_models/Helenos/test_data_tfp")