In [1]:
import os
import pandas as pd
import numpy as np
import polars as pl
import nfl_data_py as nfl

In [2]:
os.listdir()

['positions_index.parquet',
 'final_df.parquet',
 '.DS_Store',
 'train_play_prediction_binary',
 'scrimmage_index.parquet',
 'starts_index.parquet',
 'index.parquet',
 'train_split.csv',
 'test_play_prediction_binary',
 '4_play_pred.ipynb',
 'moves_index.parquet',
 '2_dataset_mapping.ipynb',
 'time_index.parquet',
 'mapped_df.parquet',
 'class_weights.parquet',
 'test_split.csv',
 'test_play_prediction_categ',
 'train_tokens_NFL_GPT',
 'train_test_split.csv',
 'train_play_prediction_categ',
 '3_sequence_creation.ipynb',
 'plays_index.parquet',
 'test_tokens_NFL_GPT']

In [3]:
years_to_get = [2017, 2018, 2019, 2020, 2021, 2022]
season_data = pl.from_pandas(nfl.import_pbp_data(years_to_get))

2017 done.
2018 done.
2019 done.
2020 done.
2021 done.
2022 done.
Downcasting floats.


In [4]:
data = (pl.read_parquet("final_df.parquet").
        with_columns(pl.when(pl.col("team") == "OAK").
                     then(pl.lit("LV")).
                     otherwise(pl.col("team")).
                     alias("team")).
        filter(pl.col("Length") > 10).
        explode("frameId", "x", "y").
        filter(pl.col("frameId") < 102).
        group_by('season_type', 'season', 'gameId', 'qtr', 'down', 'yardline_100', 'playId', 'PlayType', 'team', 'OffDef', 'nflId', 'position', 'Length').
        agg(pl.col("frameId"),
            pl.col("x"),
            pl.col("y")))

In [5]:
data = (data.
        join(season_data.
             select("old_game_id", "play_id", "posteam").
             rename({"old_game_id" : "gameId",
                     "play_id" : "playId"}).
             with_columns([
                pl.col("gameId").cast(pl.Int64),
                pl.col("playId").cast(pl.Int64)
             ]),
             on = ["gameId", "playId"],
             how = "left").
        with_columns(pl.when(pl.col("team") == pl.col("posteam")).
                     then(pl.lit("Offense")).
                     otherwise(pl.lit("Defense")).
                     alias("OffDef")).
        drop("posteam"))

In [6]:
updated_scrim_side = (season_data.
                      select("old_game_id", "play_id", "posteam", "home_team", "away_team", "yardline_100").
                      filter(pl.col("posteam").is_not_null()).
                      with_columns(pl.when(pl.col("posteam") == pl.col("home_team")).
                                  then(100 - pl.col("yardline_100") +10).
                                  otherwise(pl.col("yardline_100") +10).
                                  alias("line_scrimmage")).
                      select("old_game_id", "play_id", "posteam", "home_team", "away_team", "line_scrimmage").
                      melt(id_vars = ["old_game_id", "posteam", "play_id", "line_scrimmage"],
                          value_vars = ["home_team", "away_team"],
                          value_name = "team").
                      with_columns(pl.when(pl.col("posteam") == pl.col("team")).
                                  then(pl.lit("Offense")).
                                  otherwise(pl.lit("Defense")).
                                  alias("OffDef")).
                      drop("posteam", "team").
                      with_columns(pl.when(pl.col("variable") == "home_team").
                                  then(pl.lit("Left")).
                                  otherwise(pl.lit("Right")).
                                  alias("Side")).
                      drop("variable").
                      rename({"old_game_id": "gameId",
                              "play_id" : "playId"}).
                      with_columns([
                          pl.col("gameId").cast(pl.Int64),
                          pl.col("playId").cast(pl.Int64)]))

In [7]:
data = (data.
        drop("yardline_100").
        join(updated_scrim_side,
             on = ["gameId", "playId", "OffDef"],
             how = "left"))

In [8]:
check = (data.
         select("gameId", "OffDef", "playId", "frameId", "x", "line_scrimmage", "Side").
         explode("frameId", "x").
         filter(pl.col("frameId") == 1).
         drop("frameId").
         group_by("gameId", "OffDef", "playId", "line_scrimmage", "Side").
         mean().
         rename({"line_scrimmage" : "line_scrimmage_1"}).
         with_columns((120-pl.col("line_scrimmage_1")).
                      alias("line_scrimmage_2")).
         melt(id_vars= ["gameId", "playId", "OffDef", "Side",	"x"	], value_vars=["line_scrimmage_1", "line_scrimmage_2"], variable_name = "Scrimmage", value_name = "line").
         with_columns((pl.col("x") - pl.col("line")).abs().alias("dist")))

middle_scrimmage = (check.
                    filter(pl.col("line") == 60).
                    filter(pl.col("Scrimmage") == "line_scrimmage_1").
                    select("gameId", "playId", "Scrimmage"))

scrimmage_comparison = (check.
                        filter(pl.col("line") != 60).
                        drop("OffDef", "Side", "x", "line").
                        group_by("gameId", "playId", "Scrimmage").
                        mean().
                        filter(pl.col('dist') == pl.col('dist').min().over(["gameId", "playId"])).
                        drop("dist"))

scrimmage_comparison_simple = (scrimmage_comparison.
                               group_by("gameId", "playId").
                               count().
                               filter(pl.col("count") == 1).
                               drop("count").
                               join(scrimmage_comparison,
                                    on = ["gameId", "playId"],
                                    how = "left"))

scrimmage_comparison_double = (scrimmage_comparison.
                               group_by("gameId", "playId").
                               count().
                               filter(pl.col("count") == 2).
                               drop("count").
                               join(check,
                                    on = ["gameId", "playId"],
                                    how = "left").
                               filter(pl.col("OffDef") == "Offense").
                               filter(pl.col('dist') == pl.col('dist').min().over(["gameId", "playId"])).
                               unique().
                               select("gameId", "playId", "Scrimmage"))

scrimmage_final = pl.concat([middle_scrimmage, scrimmage_comparison_simple, scrimmage_comparison_double])

scrimmage_checked = (scrimmage_final.
                     join(check,
                          on = ["gameId", "playId", "Scrimmage"],
                          how = "left").
                     select("gameId", "playId", "OffDef", "Side", "line").
                     rename({"line" : "line_scrimmage"}))

data = (data.
        drop("line_scrimmage", "Side").
        join(scrimmage_checked,
             on = ["gameId", "playId", "OffDef"],
             how = "left"))

## Positions to movements and starting place

In [9]:
data = (data.
        explode("x", "y", "frameId").
        group_by('season_type', 'season', 'gameId', 'qtr', 'down', 'playId', 'PlayType', 'team', 'OffDef', 'nflId', 'position', 'Length', 'Side','line_scrimmage').
        agg(pl.col("x").
            sort_by("frameId"),
            pl.col("y").
            sort_by("frameId"),
            pl.col("frameId").
            sort_by("frameId")).
        with_columns(pl.lit(26.5).cast(pl.Int64).alias("line_scrimmage_y")).
        rename({"line_scrimmage" : "line_scrimmage_x"}).
        with_columns(pl.col("x").list.first().alias("Starting_x")).
        with_columns(pl.col("y").list.first().alias("Starting_y")).
        with_columns((pl.col("line_scrimmage_x") - pl.col("Starting_x")).alias("Starting_x")).
        with_columns((pl.col("line_scrimmage_y") - pl.col("Starting_y")).alias("Starting_y")).
        with_columns(
            [
                pl.col("x").list.first().alias("first_x"),
                pl.col("y").list.first().alias("first_y")]
            ).
        explode("frameId", "x", "y").
        with_columns(
            [
                (pl.col("x") - pl.col("first_x")).alias("x"),
                (pl.col("y") - pl.col("first_y")).alias("y")]
            ).
        drop("first_x", "first_y").
        filter(pl.col("frameId") != 1).
        with_columns(pl.when(pl.col("Side") == "Right").
                    then(pl.col("x") * -1).
                    otherwise(pl.col("x")).
                    alias("x")).
        with_columns(pl.when(pl.col("Side") == "Right").
                    then(pl.col("y") * -1).
                    otherwise(pl.col("y")).
                    alias("y")).
        with_columns(pl.when(pl.col("Side") == "Right").
                    then(pl.col("Starting_x") * -1).
                    otherwise(pl.col("Starting_x")).
                    alias("Starting_x")).
        with_columns(pl.when(pl.col("Side") == "Right").
                    then(pl.col("Starting_y") * -1).
                    otherwise(pl.col("Starting_y")).
                    alias("Starting_y")).
        group_by("season_type", "season", "gameId", "qtr", "down", "playId", "PlayType", "team", "OffDef", "nflId", "position", "Length", "Side", "line_scrimmage_x", "line_scrimmage_y", "Starting_x", "Starting_y").
        agg(pl.col("x"), pl.col("y"), pl.col("frameId")).
        drop("line_scrimmage_y").
        rename({"line_scrimmage_x" : "line_scrimmage"}))

In [10]:
exploded_data = (data.
                explode("frameId", "x",	"y").
                with_columns([
                    (pl.col("x").cast(pl.Int64)).alias("x"),
                    (pl.col("y").cast(pl.Int64)).alias("y"),
                    (pl.col("Starting_x").cast(pl.Int64)).alias("Starting_x"),
                    (pl.col("Starting_y").cast(pl.Int64)).alias("Starting_y")
                    ]).
                with_columns(pl.when(pl.col("x") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("x")).
                             alias("x")).
                with_columns(pl.when(pl.col("y") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("y")).
                             alias("y")).
                with_columns(pl.when(pl.col("Starting_x") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("Starting_x")).
                             alias("Starting_x")).
                with_columns(pl.when(pl.col("Starting_y") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("Starting_y")).
                             alias("Starting_y")))

## Zones and time frames

### Zones

In [11]:
min_x = round(exploded_data.select("x").min().to_series().to_list()[0])
max_x = round(exploded_data.select("x").max().to_series().to_list()[0])
min_y = round(exploded_data.select("y").min().to_series().to_list()[0])
max_y = round(exploded_data.select("y").max().to_series().to_list()[0])

print("Max of x is ", max_x)
print("Min of x is ", min_x)
print("Max of y is ", max_y)
print("Min of y is ", min_y)

Max of x is  77
Min of x is  -75
Max of y is  49
Min of y is  -50


### Starting zone

In [12]:
min_start_x = exploded_data.select("Starting_x").min().to_series().to_list()[0]
max_start_x = exploded_data.select("Starting_x").max().to_series().to_list()[0]
min_start_y = exploded_data.select("Starting_y").min().to_series().to_list()[0]
max_start_y = exploded_data.select("Starting_y").max().to_series().to_list()[0]

print("Max of x is ", max_start_x)
print("Min of x is ", min_start_x)
print("Max of y is ", max_start_y)
print("Min of y is ", min_start_y)

Max of x is  75
Min of x is  -76
Max of y is  35
Min of y is  -26


In [21]:
move_step = 1
start_step = 1

moves_index = (pl.DataFrame({"x" : range(min_x, max_x + move_step, move_step),
                             "y" : [list(range(min_y, max_y + move_step, move_step)) for v in range(min_x, max_x + move_step, move_step)]}).
               explode("y").
               join(exploded_data.
                    select("x", "y").
                    unique().
                    with_columns(pl.lit(1).alias("Check")),
                    on = ["x", "y"],
                    how = "left").
               filter(pl.col("Check").is_not_null()).
               drop("Check"))

starts_index = (pl.DataFrame({"Starting_x" : range(min_start_x-start_step, max_start_x + start_step, start_step),
                             "Starting_y" : [list(range(min_start_y, max_start_y + start_step, start_step)) for v in range(min_start_x-start_step, max_start_x + start_step, start_step)]}).
               explode("Starting_y").
               join(exploded_data.
                    select("Starting_x", "Starting_y").
                    unique().
                    with_columns(pl.lit(1).alias("Check")),
                    on = ["Starting_x", "Starting_y"],
                    how = "left").
               filter(pl.col("Check").is_not_null()).
               drop("Check"))

starts_core = (starts_index.
               filter((pl.col("Starting_x") >= -20).and_(pl.col("Starting_x") <= 20)))


zones_max = moves_index.shape[0]
starts_core_max = starts_core.shape[0]
starts_max = starts_core_max+2

scrimmage_max = 99
positions_max= 28

starts_before = (starts_index.
               filter((pl.col("Starting_x") < -20)).
               with_columns(pl.lit(starts_core_max).alias("Start_ID")))
starts_after = (starts_index.
               filter((pl.col("Starting_x") > 20)).
               with_columns(pl.lit(starts_core_max+1).alias("Start_ID")))

In [22]:
zones_max

11164

In [23]:
starts_core_max

1981

In [24]:
starts_max

1983

In [25]:
moves_index = (moves_index.
               with_columns(pl.arange(0, zones_max).alias("Zone_ID")))


starts_index = pl.concat([(starts_index.
                           filter((pl.col("Starting_x") >= -20).and_(pl.col("Starting_x") <= 20)).
                           with_columns(pl.arange(0, starts_core_max).cast(pl.Int32).alias("Start_ID"))),
                          starts_before,
                          starts_after])


scrimmage_index = (data.
                   select("line_scrimmage").
                   unique().
                   with_columns(pl.arange(0, scrimmage_max).alias("Scrimmage_ID")))

time_index = (data.
              select("frameId").
              explode("frameId").
              unique().
              sort("frameId").
              with_columns(pl.arange(1, 51).alias("Frame_ID")))

In [26]:
scrimmage_index

line_scrimmage,Scrimmage_ID
f64,i64
27.0,0
68.0,1
71.0,2
109.0,3
47.0,4
55.0,5
79.0,6
98.0,7
84.0,8
14.0,9


In [27]:
moves_index.head(5)

x,y,Zone_ID
i64,i64,i64
-75,14,0
-73,13,1
-72,0,2
-72,9,3
-72,12,4


In [28]:
starts_index.select("Start_ID").unique().sort("Start_ID").tail(5)

Start_ID
i32
1978
1979
1980
1981
1982


In [29]:
scrimmage_index.tail(5)

line_scrimmage,Scrimmage_ID
f64,i64
60.0,94
89.0,95
38.0,96
46.0,97
101.0,98


In [30]:
time_index.tail(5)

frameId,Frame_ID
i64,i64
93,46
95,47
97,48
99,49
101,50


In [31]:
updated_data = (exploded_data.
                join(starts_index,
                     on = ["Starting_x", "Starting_y"],
                     how = "left").
                join(time_index,
                     on = ["frameId"],
                     how = "left").
                join(moves_index,
                     on = ["x", "y"],
                     how = "left").
                join(scrimmage_index,
                     on = "line_scrimmage",
                     how = "left").
                drop("frameId", "x", "y", "line_scrimmage", "Starting_x", "Starting_y").
                group_by("season_type", "season", "gameId", "qtr", "down", "playId", "Scrimmage_ID", "Side", "PlayType", "team", "OffDef", "nflId", "position", "Length", "Start_ID").
                agg([pl.col("Frame_ID"),
                     pl.col("Zone_ID")]).
                with_columns(pl.col("Zone_ID").list.unique().list.lengths().alias("n_unique")).
                filter(pl.col("n_unique") != 1).
                drop("n_unique"))

In [32]:
class_weights = (updated_data.
                 select("Zone_ID").
                 explode("Zone_ID"))

class_weights.write_parquet("class_weights.parquet")

In [33]:
(class_weights.
 filter(pl.col("Zone_ID") != 5517).
 group_by("Zone_ID").
 count().
 sort("count").
 sum())

Zone_ID,count
i64,u32
62306349,20591244


## Play  types

In [34]:
plays_index = (updated_data.
              select("PlayType").
              unique().
              sort("PlayType").
              with_columns(pl.arange(0, 8).alias("PlayType_ID")))

In [35]:
plays_index

PlayType,PlayType_ID
str,i64
"""extra_point""",0
"""field_goal""",1
"""kickoff""",2
"""no_play""",3
"""pass""",4
"""punt""",5
"""qb_spike""",6
"""run""",7


In [36]:
updated_data = (updated_data.
                join(plays_index,
                     on = "PlayType",
                     how = "left").
                drop("PlayType"))

## OffDef

In [37]:
updated_data = (updated_data.
                with_columns(pl.when(pl.col("OffDef") == "Offense").then(1).otherwise(0).alias("OffDef_ID")))

## Position

In [38]:
positions_index = (data.
                   select("position").
                   unique().
                   sort("position").
                   with_columns(pl.arange(0, positions_max).alias("position_ID")))

In [39]:
moves_index = (moves_index.
               with_columns(pl.lit("Moves").alias("Cat")))

starts_index = (starts_index.
                with_columns(pl.lit("Start").alias("Cat")))

scrimmage_index = (scrimmage_index.
                   with_columns(pl.lit("Scrimm").alias("Cat")))

positions_index = (positions_index.
                   with_columns(pl.lit("Pos").alias("Cat")))

In [40]:
time_index.write_parquet("time_index.parquet")
moves_index.write_parquet("moves_index.parquet")
positions_index.write_parquet("positions_index.parquet")
plays_index.write_parquet("plays_index.parquet")
starts_index.write_parquet("starts_index.parquet")
scrimmage_index.write_parquet("scrimmage_index.parquet")

In [52]:
starts_index

Starting_x,Starting_y,Start_ID,Cat
i64,i64,i32,str
-20,-22,0,"""Start"""
-20,-21,1,"""Start"""
-20,-20,2,"""Start"""
-20,-19,3,"""Start"""
-20,-18,4,"""Start"""
-20,-17,5,"""Start"""
-20,-16,6,"""Start"""
-20,-15,7,"""Start"""
-20,-14,8,"""Start"""
-20,-13,9,"""Start"""


In [51]:
positions_index

position,position_ID,Cat
str,i64,str
"""C""",0,"""Pos"""
"""CB""",1,"""Pos"""
"""DB""",2,"""Pos"""
"""DE""",3,"""Pos"""
"""DL""",4,"""Pos"""
"""DT""",5,"""Pos"""
"""FB""",6,"""Pos"""
"""FS""",7,"""Pos"""
"""G""",8,"""Pos"""
"""HB""",9,"""Pos"""


In [50]:
scrimmage_index

line_scrimmage,Scrimmage_ID,Cat
f64,i64,str
27.0,0,"""Scrimm"""
68.0,1,"""Scrimm"""
71.0,2,"""Scrimm"""
109.0,3,"""Scrimm"""
47.0,4,"""Scrimm"""
55.0,5,"""Scrimm"""
79.0,6,"""Scrimm"""
98.0,7,"""Scrimm"""
84.0,8,"""Scrimm"""
14.0,9,"""Scrimm"""


In [41]:
starts_index.select("Start_ID").unique()

Start_ID
i32
0
1
2
3
4
5
6
7
8
9


In [42]:
index = pl.concat([(moves_index.
                    rename({"Zone_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))), 
                    (starts_index.
                    rename({"Start_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))), 
                    (scrimmage_index.
                    rename({"Scrimmage_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))),
                    (positions_index.
                    rename({"position_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32)))]).unique().sort("ID")



index

Cat,ID
str,i32
"""Pos""",0
"""Moves""",0
"""Start""",0
"""Scrimm""",0
"""Scrimm""",1
"""Start""",1
"""Pos""",1
"""Moves""",1
"""Pos""",2
"""Start""",2


In [43]:
index.write_parquet("index.parquet")

In [44]:
updated_data = (updated_data.
                join(positions_index,
                     on = "position",
                     how = "left").
                drop("position"))

## Side ID

In [45]:
updated_data = (updated_data.
                with_columns(pl.when(pl.col("Side") == "Right").
                             then(1).
                             otherwise(0).
                             alias("side_ID")))

In [46]:
updated_data = (updated_data.
                select("gameId", "playId", "PlayType_ID", "OffDef_ID", "side_ID", "nflId", "position_ID", "Scrimmage_ID", "Start_ID", "Frame_ID", "Zone_ID"))

In [47]:
updated_data.write_parquet("mapped_df.parquet")

In [48]:
(updated_data.
 select("Start_ID").
 filter(pl.col("Start_ID").is_null()))

Start_ID
i32
