In [1]:
import os
import pandas as pd
import numpy as np
import polars as pl
import nfl_data_py as nfl

env = "local"

In [2]:
if env == "local":
    os.chdir("/Users/samuel/Documents/GitHub/QB-GPT/")
else:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.chdir("/content/gdrive/MyDrive/NFL_Challenge/NFL-GPT/NFL data")

In [3]:
os.listdir()

['data_models',
 '.DS_Store',
 'app',
 'LICENSE',
 'models',
 'README.md',
 '.gitignore',
 '.gitattributes',
 'data_preprocessing',
 'index',
 '.git',
 'notebooks']

In [4]:
years_to_get = [2017, 2018, 2019, 2020, 2021, 2022]
season_data = pl.from_pandas(nfl.import_pbp_data(years_to_get))

2017 done.
2018 done.
2019 done.
2020 done.
2021 done.
2022 done.
Downcasting floats.


In [5]:
data = (pl.read_parquet("data_preprocessing/0_raw/final_df.parquet").
        with_columns(pl.when(pl.col("team") == "OAK").
                     then(pl.lit("LV")).
                     otherwise(pl.col("team")).
                     alias("team")).
        filter(pl.col("Length") > 10).
        explode("frameId", "x", "y").
        filter(pl.col("frameId") < 102).
        group_by('season_type', 'season', 'gameId', 'qtr', 'down', 'yardline_100', 'playId', 'PlayType', 'team', 'OffDef', 'nflId', 'position', 'Length').
        agg(pl.col("frameId"),
            pl.col("x"),
            pl.col("y")))

In [6]:
data = (data.
        join(season_data.
             select("old_game_id", "play_id", "posteam").
             rename({"old_game_id" : "gameId",
                     "play_id" : "playId"}).
             with_columns([
                pl.col("gameId").cast(pl.Int64),
                pl.col("playId").cast(pl.Int64)
             ]),
             on = ["gameId", "playId"],
             how = "left").
        with_columns(pl.when(pl.col("team") == pl.col("posteam")).
                     then(pl.lit("Offense")).
                     otherwise(pl.lit("Defense")).
                     alias("OffDef")).
        drop("posteam"))

In [7]:
updated_scrim_side = (season_data.
                      select("old_game_id", "play_id", "posteam", "home_team", "away_team", "yardline_100").
                      filter(pl.col("posteam").is_not_null()).
                      with_columns(pl.when(pl.col("posteam") == pl.col("home_team")).
                                  then(100 - pl.col("yardline_100") +10).
                                  otherwise(pl.col("yardline_100") +10).
                                  alias("line_scrimmage")).
                      select("old_game_id", "play_id", "posteam", "home_team", "away_team", "line_scrimmage").
                      melt(id_vars = ["old_game_id", "posteam", "play_id", "line_scrimmage"],
                          value_vars = ["home_team", "away_team"],
                          value_name = "team").
                      with_columns(pl.when(pl.col("posteam") == pl.col("team")).
                                  then(pl.lit("Offense")).
                                  otherwise(pl.lit("Defense")).
                                  alias("OffDef")).
                      drop("posteam", "team").
                      with_columns(pl.when(pl.col("variable") == "home_team").
                                  then(pl.lit("Left")).
                                  otherwise(pl.lit("Right")).
                                  alias("Side")).
                      drop("variable").
                      rename({"old_game_id": "gameId",
                              "play_id" : "playId"}).
                      with_columns([
                          pl.col("gameId").cast(pl.Int64),
                          pl.col("playId").cast(pl.Int64)]))

In [8]:
data = (data.
        drop("yardline_100").
        join(updated_scrim_side,
             on = ["gameId", "playId", "OffDef"],
             how = "left"))

In [9]:
check = (data.
         select("gameId", "OffDef", "playId", "frameId", "x", "line_scrimmage", "Side").
         explode("frameId", "x").
         filter(pl.col("frameId") == 1).
         drop("frameId").
         group_by("gameId", "OffDef", "playId", "line_scrimmage", "Side").
         mean().
         rename({"line_scrimmage" : "line_scrimmage_1"}).
         with_columns((120-pl.col("line_scrimmage_1")).
                      alias("line_scrimmage_2")).
         melt(id_vars= ["gameId", "playId", "OffDef", "Side",	"x"	], value_vars=["line_scrimmage_1", "line_scrimmage_2"], variable_name = "Scrimmage", value_name = "line").
         with_columns((pl.col("x") - pl.col("line")).abs().alias("dist")))

middle_scrimmage = (check.
                    filter(pl.col("line") == 60).
                    filter(pl.col("Scrimmage") == "line_scrimmage_1").
                    select("gameId", "playId", "Scrimmage"))

scrimmage_comparison = (check.
                        filter(pl.col("line") != 60).
                        drop("OffDef", "Side", "x", "line").
                        group_by("gameId", "playId", "Scrimmage").
                        mean().
                        filter(pl.col('dist') == pl.col('dist').min().over(["gameId", "playId"])).
                        drop("dist"))

scrimmage_comparison_simple = (scrimmage_comparison.
                               group_by("gameId", "playId").
                               count().
                               filter(pl.col("count") == 1).
                               drop("count").
                               join(scrimmage_comparison,
                                    on = ["gameId", "playId"],
                                    how = "left"))

scrimmage_comparison_double = (scrimmage_comparison.
                               group_by("gameId", "playId").
                               count().
                               filter(pl.col("count") == 2).
                               drop("count").
                               join(check,
                                    on = ["gameId", "playId"],
                                    how = "left").
                               filter(pl.col("OffDef") == "Offense").
                               filter(pl.col('dist') == pl.col('dist').min().over(["gameId", "playId"])).
                               unique().
                               select("gameId", "playId", "Scrimmage"))

scrimmage_final = pl.concat([middle_scrimmage, scrimmage_comparison_simple, scrimmage_comparison_double])

scrimmage_checked = (scrimmage_final.
                     join(check,
                          on = ["gameId", "playId", "Scrimmage"],
                          how = "left").
                     select("gameId", "playId", "OffDef", "Side", "line").
                     rename({"line" : "line_scrimmage"}))

data = (data.
        drop("line_scrimmage", "Side").
        join(scrimmage_checked,
             on = ["gameId", "playId", "OffDef"],
             how = "left"))

## Positions to movements and starting place

In [10]:
data = (data.
        explode("x", "y", "frameId").
        group_by('season_type', 'season', 'gameId', 'qtr', 'down', 'playId', 'PlayType', 'team', 'OffDef', 'nflId', 'position', 'Length', 'Side','line_scrimmage').
        agg(pl.col("x").
            sort_by("frameId"),
            pl.col("y").
            sort_by("frameId"),
            pl.col("frameId").
            sort_by("frameId")).
        with_columns(pl.lit(26.5).cast(pl.Int64).alias("line_scrimmage_y")).
        rename({"line_scrimmage" : "line_scrimmage_x"}).
        with_columns(pl.col("x").list.first().alias("Starting_x")).
        with_columns(pl.col("y").list.first().alias("Starting_y")).
        with_columns((pl.col("Starting_x") - pl.col("line_scrimmage_x")).alias("Starting_x")).
        with_columns((pl.col("Starting_y") - pl.col("line_scrimmage_y")).alias("Starting_y")).
        with_columns(
            [
                pl.col("x").list.first().alias("first_x"),
                pl.col("y").list.first().alias("first_y")]
            ).
        explode("frameId", "x", "y").
        with_columns(
            [
                (pl.col("x") - pl.col("first_x")).alias("x"),
                (pl.col("y") - pl.col("first_y")).alias("y")]
            ).
        drop("first_x", "first_y").
        filter(pl.col("frameId") != 1).
        with_columns(pl.when(pl.col("Starting_x") > 0).
                    then(pl.col("x") * -1).
                    otherwise(pl.col("x")).
                    alias("x")).
        with_columns(pl.when(pl.col("Starting_x") > 0).
                    then(pl.col("y") * -1).
                    otherwise(pl.col("y")).
                    alias("y")).
        with_columns(pl.when(pl.col("Starting_x") > 0).
                    then(pl.col("Starting_x") * -1).
                    otherwise(pl.col("Starting_x")).
                    alias("Starting_x")).
        with_columns(pl.when(pl.col("Starting_x") > 0).
                    then(pl.col("Starting_y") * -1).
                    otherwise(pl.col("Starting_y")).
                    alias("Starting_y")).
        group_by("season_type", "season", "gameId", "qtr", "down", "playId", "PlayType", "team", "OffDef", "nflId", "position", "Length", "Side", "line_scrimmage_x", "line_scrimmage_y", "Starting_x", "Starting_y").
        agg(pl.col("x"), pl.col("y"), pl.col("frameId")).
        drop("line_scrimmage_y").
        rename({"line_scrimmage_x" : "line_scrimmage"}))

In [11]:
exploded_data = (data.
                explode("frameId", "x",	"y").
                with_columns([
                    (pl.col("x").cast(pl.Int64)).alias("x"),
                    (pl.col("y").cast(pl.Int64)).alias("y"),
                    (pl.col("Starting_x").cast(pl.Int64)).alias("Starting_x"),
                    (pl.col("Starting_y").cast(pl.Int64)).alias("Starting_y")
                    ]).
                with_columns(pl.when(pl.col("x") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("x")).
                             alias("x")).
                with_columns(pl.when(pl.col("y") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("y")).
                             alias("y")).
                with_columns(pl.when(pl.col("Starting_x") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("Starting_x")).
                             alias("Starting_x")).
                with_columns(pl.when(pl.col("Starting_y") == -0).
                             then(pl.lit(0)).
                             otherwise(pl.col("Starting_y")).
                             alias("Starting_y")))

## Zones and time frames

### Zones

In [12]:
min_x = round(exploded_data.select("x").min().to_series().to_list()[0])
max_x = round(exploded_data.select("x").max().to_series().to_list()[0])
min_y = round(exploded_data.select("y").min().to_series().to_list()[0])
max_y = round(exploded_data.select("y").max().to_series().to_list()[0])

print("Max of x is ", max_x)
print("Min of x is ", min_x)
print("Max of y is ", max_y)
print("Min of y is ", min_y)

Max of x is  77
Min of x is  -72
Max of y is  50
Min of y is  -50


### Starting zone

In [13]:
min_start_x = exploded_data.select("Starting_x").min().to_series().to_list()[0]
max_start_x = exploded_data.select("Starting_x").max().to_series().to_list()[0]
min_start_y = exploded_data.select("Starting_y").min().to_series().to_list()[0]
max_start_y = exploded_data.select("Starting_y").max().to_series().to_list()[0]

print("Max of x is ", max_start_x)
print("Min of x is ", min_start_x)
print("Max of y is ", max_start_y)
print("Min of y is ", min_start_y)

Max of x is  0
Min of x is  -76
Max of y is  35
Min of y is  -28


In [14]:
move_step = 1
start_step = 1

moves_index = (pl.DataFrame({"x" : range(min_x, max_x + move_step, move_step),
                             "y" : [list(range(min_y, max_y + move_step, move_step)) for v in range(min_x, max_x + move_step, move_step)]}).
               explode("y").
               join(exploded_data.
                    select("x", "y").
                    unique().
                    with_columns(pl.lit(1).alias("Check")),
                    on = ["x", "y"],
                    how = "left").
               filter(pl.col("Check").is_not_null()).
               drop("Check"))

starts_index = (pl.DataFrame({"Starting_x" : range(min_start_x-start_step, max_start_x + start_step, start_step),
                             "Starting_y" : [list(range(min_start_y, max_start_y + start_step, start_step)) for v in range(min_start_x-start_step, max_start_x + start_step, start_step)]}).
               explode("Starting_y").
               join(exploded_data.
                    select("Starting_x", "Starting_y").
                    unique().
                    with_columns(pl.lit(1).alias("Check")),
                    on = ["Starting_x", "Starting_y"],
                    how = "left").
               filter(pl.col("Check").is_not_null()).
               drop("Check"))

starts_core = (starts_index.
               filter((pl.col("Starting_x") >= -20).and_(pl.col("Starting_x") <= 20)))


zones_max = moves_index.shape[0]
starts_core_max = starts_core.shape[0]
starts_max = starts_core_max+1

scrimmage_max = 99
positions_max= 28

starts_long = (starts_index.
               filter((pl.col("Starting_x") < -20)).
               with_columns(pl.lit(starts_core_max).alias("Start_ID")))

In [15]:
zones_max

10876

In [16]:
starts_core_max

1031

In [17]:
starts_max

1032

In [18]:
moves_index = (moves_index.
               with_columns(pl.arange(0, zones_max).alias("Zone_ID")))


starts_index = pl.concat([(starts_index.
                           filter((pl.col("Starting_x") >= -20).and_(pl.col("Starting_x") <= 20)).
                           with_columns(pl.arange(0, starts_core_max).cast(pl.Int32).alias("Start_ID"))),
                          starts_long])


scrimmage_index = (data.
                   select("line_scrimmage").
                   unique().
                   with_columns(pl.arange(0, scrimmage_max).alias("Scrimmage_ID")))

time_index = (data.
              select("frameId").
              explode("frameId").
              unique().
              sort("frameId").
              with_columns(pl.arange(1, 51).alias("Frame_ID")))

In [19]:
scrimmage_index

line_scrimmage,Scrimmage_ID
f64,i64
68.0,0
79.0,1
94.0,2
76.0,3
60.0,4
87.0,5
88.0,6
41.0,7
64.0,8
54.0,9


In [22]:
moves_index.tail(5)

x,y,Zone_ID
i64,i64,i64
75,-14,10871
75,-1,10872
76,-39,10873
76,-1,10874
77,-39,10875


In [23]:
starts_index

Starting_x,Starting_y,Start_ID
i64,i64,i32
-20,-22,0
-20,-21,1
-20,-20,2
-20,-19,3
-20,-18,4
-20,-17,5
-20,-16,6
-20,-15,7
-20,-14,8
-20,-13,9


In [24]:
scrimmage_index.tail(5)

line_scrimmage,Scrimmage_ID
f64,i64
28.0,94
61.0,95
106.0,96
21.0,97
26.0,98


In [25]:
time_index.tail(5)

frameId,Frame_ID
i64,i64
93,46
95,47
97,48
99,49
101,50


In [26]:
updated_data = (exploded_data.
                join(starts_index,
                     on = ["Starting_x", "Starting_y"],
                     how = "left").
                join(time_index,
                     on = ["frameId"],
                     how = "left").
                join(moves_index,
                     on = ["x", "y"],
                     how = "left").
                join(scrimmage_index,
                     on = "line_scrimmage",
                     how = "left").
                drop("frameId", "x", "y", "line_scrimmage", "Starting_x", "Starting_y").
                group_by("season_type", "season", "gameId", "qtr", "down", "playId", "Scrimmage_ID", "Side", "PlayType", "team", "OffDef", "nflId", "position", "Length", "Start_ID").
                agg([pl.col("Frame_ID"),
                     pl.col("Zone_ID")]).
                with_columns(pl.col("Zone_ID").list.unique().list.lengths().alias("n_unique")).
                filter(pl.col("n_unique") != 1).
                drop("n_unique"))

In [27]:
class_weights = (updated_data.
                 select("Zone_ID").
                 explode("Zone_ID"))

class_weights.write_parquet("models/modeling/class_weights.parquet")

In [28]:
(class_weights.
 filter(pl.col("Zone_ID") != 5517).
 group_by("Zone_ID").
 count().
 sort("count").
 sum())

Zone_ID,count
i64,u32
59132733,28143606


## Play  types

In [29]:
plays_index = (updated_data.
              select("PlayType").
              unique().
              sort("PlayType").
              with_columns(pl.arange(0, 8).alias("PlayType_ID")))

In [30]:
plays_index

PlayType,PlayType_ID
str,i64
"""extra_point""",0
"""field_goal""",1
"""kickoff""",2
"""no_play""",3
"""pass""",4
"""punt""",5
"""qb_spike""",6
"""run""",7


In [31]:
updated_data = (updated_data.
                join(plays_index,
                     on = "PlayType",
                     how = "left").
                drop("PlayType"))

## OffDef

In [32]:
updated_data = (updated_data.
                with_columns(pl.when(pl.col("OffDef") == "Offense").then(1).otherwise(0).alias("OffDef_ID")))

## Position

In [33]:
positions_index = (data.
                   select("position").
                   unique().
                   sort("position").
                   with_columns(pl.arange(0, positions_max).alias("position_ID")))

In [34]:
moves_index = (moves_index.
               with_columns(pl.lit("Moves").alias("Cat")))

starts_index = (starts_index.
                with_columns(pl.lit("Start").alias("Cat")))

scrimmage_index = (scrimmage_index.
                   with_columns(pl.lit("Scrimm").alias("Cat")))

positions_index = (positions_index.
                   with_columns(pl.lit("Pos").alias("Cat")))

In [35]:
starts_index

Starting_x,Starting_y,Start_ID,Cat
i64,i64,i32,str
-20,-22,0,"""Start"""
-20,-21,1,"""Start"""
-20,-20,2,"""Start"""
-20,-19,3,"""Start"""
-20,-18,4,"""Start"""
-20,-17,5,"""Start"""
-20,-16,6,"""Start"""
-20,-15,7,"""Start"""
-20,-14,8,"""Start"""
-20,-13,9,"""Start"""


In [36]:
time_index.write_parquet("index/time_index.parquet")
moves_index.write_parquet("index/moves_index.parquet")
positions_index.write_parquet("index/positions_index.parquet")
plays_index.write_parquet("index/plays_index.parquet")
starts_index.write_parquet("index/starts_index.parquet")
scrimmage_index.write_parquet("index/scrimmage_index.parquet")

In [37]:
starts_index

Starting_x,Starting_y,Start_ID,Cat
i64,i64,i32,str
-20,-22,0,"""Start"""
-20,-21,1,"""Start"""
-20,-20,2,"""Start"""
-20,-19,3,"""Start"""
-20,-18,4,"""Start"""
-20,-17,5,"""Start"""
-20,-16,6,"""Start"""
-20,-15,7,"""Start"""
-20,-14,8,"""Start"""
-20,-13,9,"""Start"""


In [38]:
positions_index

position,position_ID,Cat
str,i64,str
"""C""",0,"""Pos"""
"""CB""",1,"""Pos"""
"""DB""",2,"""Pos"""
"""DE""",3,"""Pos"""
"""DL""",4,"""Pos"""
"""DT""",5,"""Pos"""
"""FB""",6,"""Pos"""
"""FS""",7,"""Pos"""
"""G""",8,"""Pos"""
"""HB""",9,"""Pos"""


In [39]:
scrimmage_index

line_scrimmage,Scrimmage_ID,Cat
f64,i64,str
68.0,0,"""Scrimm"""
79.0,1,"""Scrimm"""
94.0,2,"""Scrimm"""
76.0,3,"""Scrimm"""
60.0,4,"""Scrimm"""
87.0,5,"""Scrimm"""
88.0,6,"""Scrimm"""
41.0,7,"""Scrimm"""
64.0,8,"""Scrimm"""
54.0,9,"""Scrimm"""


In [40]:
starts_index.select("Start_ID").unique()

Start_ID
i32
0
1
2
3
4
5
6
7
8
9


In [41]:
index = pl.concat([(moves_index.
                    rename({"Zone_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))), 
                    (starts_index.
                    rename({"Start_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))), 
                    (scrimmage_index.
                    rename({"Scrimmage_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32))),
                    (positions_index.
                    rename({"position_ID" : "ID"}).
                    select("Cat", "ID").
                    with_columns(pl.col("ID").cast(pl.Int32)))]).unique().sort("ID")



index

Cat,ID
str,i32
"""Moves""",0
"""Pos""",0
"""Scrimm""",0
"""Start""",0
"""Moves""",1
"""Scrimm""",1
"""Start""",1
"""Pos""",1
"""Moves""",2
"""Scrimm""",2


In [42]:
index.write_parquet("index/index.parquet")

In [43]:
updated_data = (updated_data.
                join(positions_index,
                     on = "position",
                     how = "left").
                drop("position"))

## Side ID

In [44]:
updated_data = (updated_data.
                with_columns(pl.when(pl.col("Side") == "Right").
                             then(1).
                             otherwise(0).
                             alias("side_ID")))

In [45]:
updated_data = (updated_data.
                select("gameId", "playId", "PlayType_ID", "OffDef_ID", "side_ID", "nflId", "position_ID", "Scrimmage_ID", "Start_ID", "Frame_ID", "Zone_ID"))

In [46]:
updated_data.write_parquet("data_preprocessing/1_mapped/mapped_df.parquet")

In [47]:
starts_count = (updated_data.
                select("Start_ID", "position_ID").
                group_by("position_ID", "Start_ID").
                count().
                sort("position_ID", "count"))

most_probable_starts = pl.concat([
    x.top_k(8, by='count')
    for x in starts_count.partition_by('position_ID')
    ])

most_probable_starts.write_parquet("models/generating/most_probable_starts.parquet")


In [48]:
pos_count = (updated_data.
                select("OffDef_ID", "position_ID").
                group_by("OffDef_ID", "position_ID").
                count().
                sort("OffDef_ID", "count"))

most_probable_pos = pl.concat([
    x.top_k(15, by='count')
    for x in pos_count.partition_by('OffDef_ID')
    ])

most_probable_pos.write_parquet("models/generating/most_probable_pos.parquet")

In [49]:
most_probable_pos

OffDef_ID,position_ID,count
i32,i64,u32
0,1,114266
0,17,55097
0,7,44549
0,24,38453
0,10,34197
0,3,31819
0,12,26371
0,5,22951
0,2,21287
0,27,16812


In [50]:
pos_count

OffDef_ID,position_ID,count
i32,i64,u32
0,16,1
0,13,1
0,18,5
0,0,61
0,25,96
0,8,134
0,4,312
0,9,356
0,20,578
0,23,648


In [51]:
updated_data

gameId,playId,PlayType_ID,OffDef_ID,side_ID,nflId,position_ID,Scrimmage_ID,Start_ID,Frame_ID,Zone_ID
i64,i64,i64,i32,i32,i64,i64,i64,i32,list[i64],list[i64]
2018110405,3362,4,0,0,42388,14,63,748,"[1, 2, … 37]","[5081, 5081, … 4144]"
2018101402,3988,4,0,1,42496,17,66,900,"[1, 2, … 31]","[5081, 5081, … 4836]"
2018093003,395,4,1,1,41254,27,14,883,"[1, 2, … 27]","[5081, 5081, … 5676]"
2018112201,3935,4,1,1,35524,20,8,800,"[1, 2, … 28]","[5081, 5081, … 4915]"
2018092308,4248,4,1,1,38531,20,68,800,"[1, 2, … 37]","[5081, 5081, … 5086]"
2018110410,1258,4,0,1,44873,24,41,610,"[1, 2, … 26]","[5081, 5081, … 4735]"
2018092310,257,4,0,0,42361,1,82,713,"[1, 2, … 20]","[5081, 5081, … 5094]"
2018122305,147,4,1,0,46104,21,28,744,"[1, 2, … 23]","[5081, 5081, … 5168]"
2018093000,4305,4,1,1,42500,26,47,933,"[1, 2, … 28]","[5081, 5081, … 6422]"
2018093008,894,4,1,0,28955,27,32,934,"[1, 2, … 32]","[5081, 5081, … 6444]"
