In [1]:
from nba_api.live.nba.endpoints import playbyplay
from nba_api.stats.static import players
import pandas as pd
from nba_api.stats.static import teams
from nba_api.stats.endpoints import leaguegamefinder
from nba_api.stats.library.parameters import Season, SeasonType, LeagueID
from nba_api.stats.endpoints import playergamelog
from nba_api.stats.endpoints import boxscoresummaryv2
from nba_api.stats.endpoints import gamerotation
from retry import retry
@retry(Exception, tries=3, delay=2, backoff=2)
def get_play_by_play_data(game_id):
    pbp = playbyplay.PlayByPlay(game_id , timeout=500)
    line1 = "{period}:{clock} {player_id} ({description}) ({scoreHome}-{scoreAway}) "
    line2 = " {period}:{clock} ({action_type})"
    actions = pbp.get_dict()['game']['actions']
    
    play_by_play_data = []
    for action in actions:
        player_name = ''
        player = players.find_player_by_id(action.get('personId'))
        if player is not None:

            
            play_by_play_data.append(line1.format(period=action['period'],clock=action['clock'],description=action['description'],player_id=action['personIdsFilter'],scoreHome=action['scoreHome'],scoreAway=action['scoreAway'] ))
        else:
            action_type = action.get('actionType', '')
            
            play_by_play_data.append(line2.format(period=action['period'],clock=action['clock'],action_type=action['actionType'],scoreHome=action['scoreHome'],scoreAway=action['scoreAway']))
    
    return play_by_play_data
#get_play_by_play_data('0022300328')

In [4]:
import requests
import time 

nba_teams = teams.get_teams()
def get_team_avg_stats(team_abbr, n=5):
    # Retrieve the team ID
    team = [t for t in nba_teams if t['abbreviation'] == team_abbr][0]
    team_id = team['id']

    max_retries = 10
    for _ in range(max_retries):
        try:
            # Query for games where the specified team was playing
            gamefinder = leaguegamefinder.LeagueGameFinder(
                team_id_nullable=team_id,
                season_type_nullable=SeasonType.regular,
                timeout=30
            )
            # The first DataFrame of those returned is what we want.
            games = gamefinder.get_data_frames()[0]
            break  # Break out of the loop if successful
        except requests.exceptions.ReadTimeout:
            print(f"Request timed out. Retrying...")
            

    if games.empty:
        raise ValueError(f"Failed after {max_retries} retries.")
        
    
    games["GAME_DATE"] = pd.to_datetime(games["GAME_DATE"])
    

    # Sort games by date
    team_games = games.sort_values(by=['GAME_DATE'], ascending=False)

    # Take the last 'n' games
    last_n_games = team_games.head(n)

    # Calculate average stats for the last 'n' games
    avg_stats = last_n_games[['PTS', 'FG_PCT', 'FG3_PCT', 'FT_PCT', 'REB', 'AST', 'STL', 'BLK', 'TOV']].mean()

    return avg_stats

# Example: Get average stats for the last five games for the Boston Celtics
#celtics_avg_stats = get_team_avg_stats('BOS', n=5)
#print(celtics_avg_stats)


In [5]:
def get_player_avg_stats(player_id, n=5):
    # Query for game logs for the specified player
    max_retries = 10
    for _ in range(max_retries):
        try:
            # Query for game logs for the specified player
            gamelog = playergamelog.PlayerGameLog(player_id=player_id, timeout=10)
            games = gamelog.get_data_frames()[0]
            break  # Break out of the loop if successful
        except requests.exceptions.ReadTimeout:
            print(f"Request timed out. Retrying...")
            

    if games.empty:
        raise ValueError(f"Failed after {max_retries} retries.")
    games["GAME_DATE"] = pd.to_datetime(games["GAME_DATE"])

    # Sort games by date
    games = games.sort_values(by=['GAME_DATE'], ascending=False)

    # Take the last 'n' games
    last_n_games = games.head(n)

    # Calculate average stats for the last 'n' games
    avg_stats = last_n_games[['PTS', 'FG_PCT', 'FG3_PCT', 'FT_PCT', 'REB', 'AST', 'STL', 'BLK', 'TOV']].mean()

    return avg_stats

# Example: Get average stats for the last five games for LeBron James
#lebron_avg_stats = get_player_avg_stats(player_id=2544, n=5)
#print(lebron_avg_stats)


In [5]:

nba_teams = teams.get_teams()
nba_players = players.get_players()
def get_team_and_player_stats(game_id):
    # Get BoxScoreSummaryV2 for inactive players
    boxscore_summary = boxscoresummaryv2.BoxScoreSummaryV2(game_id, timeout=60)
    
    # Access home and visitor team IDs
    data_frames = boxscore_summary.get_data_frames()
    home_team_id = data_frames[0]['HOME_TEAM_ID'].iloc[0]
    visitor_team_id = data_frames[0]['VISITOR_TEAM_ID'].iloc[0]

    
    # Get inactive players for each team
    inactive_players_data = data_frames[3]
    inactive_players_home = inactive_players_data[inactive_players_data['TEAM_ID'] == int(home_team_id)]['PLAYER_ID'].tolist()
    inactive_players_visitor = inactive_players_data[inactive_players_data['TEAM_ID'] == int(visitor_team_id)]['PLAYER_ID'].tolist()

    # Get active players for each team using BoxScoreSummaryV2
    
    
    # Get active players for home and visitor teams


    # Extract relevant information for each team
    gamerotation_data = gamerotation.GameRotation(game_id, timeout=60)
    gamerota = gamerotation_data.get_data_frames()

    # Extract unique player IDs from both columns
    home_team_player_ids = gamerota[1]['PERSON_ID'].unique()
    visitor_team_player_ids = gamerota[0]['PERSON_ID'].unique()

    # Filter out inactive players from active player lists
    active_players_home = [player_id for player_id in home_team_player_ids if player_id not in inactive_players_home]
    active_players_visitor = [player_id for player_id in visitor_team_player_ids if player_id not in inactive_players_visitor]
    
    home_team_info = next(team for team in nba_teams if team['id'] == home_team_id)
    visitor_team_info = next(team for team in nba_teams if team['id'] == visitor_team_id)
    
    home_team_abbr = home_team_info['abbreviation']
    visitor_team_abbr = visitor_team_info['abbreviation']

    # Get team stats for the last five games using LeagueGameFinder
    team_stats_home = get_team_avg_stats(home_team_abbr)
    team_stats_visitor = get_team_avg_stats(visitor_team_abbr)

    # Get player stats for active players using PlayerGameLog

    # Get player stats for active players using PlayerGameLog
    player_stats_home = {player['id'] : get_player_avg_stats(player['id']) for player in nba_players if player['id'] in home_team_player_ids and player['id'] not in inactive_players_home}
    player_stats_visitor = {player['id'] : get_player_avg_stats(player['id']) for player in nba_players if player['id'] in visitor_team_player_ids and player['id'] not in inactive_players_visitor}

    return {
        'home_team': {
            'abbr': home_team_abbr,
            'active_players': active_players_home,
            'team_stats': team_stats_home,
            'player_stats': player_stats_home
        },
        'visitor_team': {
            'abbr': visitor_team_abbr,
            'active_players': active_players_visitor,
            'team_stats': team_stats_visitor,
            'player_stats': player_stats_visitor
        }
    }


# Example: Get team and player stats for a specific game ID
game_id_to_check = '0022300062'
game_stats = get_team_and_player_stats(game_id_to_check)
print(game_stats)



 

{'home_team': {'abbr': 'GSW', 'active_players': [101108, 201939, 202691, 202709, 203952, 203967, 1626172, 1627780, 1630228, 1630541], 'team_stats': PTS        120.6000
FG_PCT       0.4706
FG3_PCT      0.3954
FT_PCT       0.8106
REB         47.6000
AST         27.4000
STL          7.6000
BLK          3.4000
TOV         12.8000
dtype: float64, 'player_stats': {201939: PTS        23.6000
FG_PCT      0.3908
FG3_PCT     0.3586
FT_PCT      0.9500
REB         3.4000
AST         4.6000
STL         0.8000
BLK         0.8000
TOV         2.6000
dtype: float64, 202709: PTS        2.2
FG_PCT     0.3
FG3_PCT    0.0
FT_PCT     0.3
REB        1.4
AST        2.2
STL        0.0
BLK        0.2
TOV        0.0
dtype: float64, 1630228: PTS        14.0000
FG_PCT      0.5858
FG3_PCT     0.3834
FT_PCT      0.4858
REB         5.2000
AST         1.0000
STL         1.0000
BLK         0.2000
TOV         1.0000
dtype: float64, 1626172: PTS        3.8000
FG_PCT     0.3742
FG3_PCT    0.0000
FT_PCT     0.2000
REB     

In [None]:
gamerotation_data = gamerotation.GameRotation('0021001223')
gamerota = gamerotation_data.get_data_frames()
gamerota[0]

In [38]:
active_players_home = game_stats['home_team']['active_players']

# Print or use the active players
print("Active players for the home team:", active_players_home)

Active players for the home team: [101108, 201939, 202691, 202709, 203952, 203967, 1626172, 1627780, 1630228, 1630541]


In [39]:

boxscore_summary = boxscoresummaryv2.BoxScoreSummaryV2('0022300061')
data_frames = boxscore_summary.get_data_frames()
data_frames[0]

Unnamed: 0,GAME_DATE_EST,GAME_SEQUENCE,GAME_ID,GAME_STATUS_ID,GAME_STATUS_TEXT,GAMECODE,HOME_TEAM_ID,VISITOR_TEAM_ID,SEASON,LIVE_PERIOD,LIVE_PC_TIME,NATL_TV_BROADCASTER_ABBREVIATION,LIVE_PERIOD_TIME_BCAST,WH_STATUS
0,2023-10-24T00:00:00,1,22300061,3,Final,20231024/LALDEN,1610612743,1610612747,2023,4,,TNT,Q4 - TNT,1


In [4]:
import json

class NBAResponse:
    def __init__(self, content, status_code, url):
        self.content = content
        self.status_code = status_code
        self.url = url

    def valid_json(self):
        try:
            json.loads(self.content)
            return True
        except json.JSONDecodeError:
            return False

    def __str__(self):
        return f"NBAResponse(content={self.content}, status_code={self.status_code}, url={self.url})"


In [5]:
import requests
from requests.exceptions import Timeout

class NBAHTTP:
    def __init__(self):
        self.nba_response = None
        self.base_url = None
        self.parameters = None
        self.headers = None

    def clean_contents(self, contents):
        # Your implementation for cleaning contents goes here
        pass

    def send_api_request(self, endpoint, parameters, referer=None, proxy=None, headers=None, timeout=None, raise_exception_on_error=False):
        # Set default timeout if not provided
        timeout = timeout or 30

        # Construct the URL
        url = f"{self.base_url}{endpoint}"

        # Set up proxy and headers
        proxies = {'http': proxy, 'https': proxy} if proxy else None
        request_headers = headers or {}

        try:
            # Make the API request with timeout
            response = requests.get(url, params=parameters, headers=request_headers, proxies=proxies, timeout=timeout)

            # Store the response in NBAResponse class
            self.nba_response = NBAResponse(response.text, response.status_code, url)

            # Check if the response is a valid JSON
            if raise_exception_on_error and not self.nba_response.valid_json():
                raise ValueError("Invalid JSON response")

            # Return the response
            
            return self.nba_response

        except Exception as e:
            print(f"Request to {url} timed out after {timeout} seconds.")
            # Handle the timeout error as needed
            # You can retry the request, log the error, or take other actions
            # For now, I'm raising a TimeoutError
            self.nba_response = NBAResponse(str(e), 500, url)
            return self.nba_response

In [6]:
def get_all_game_ids_for_current_season():
    # Query for all games in the current season
    max_retries = 10  # Set the maximum number of retry attempts
    for attempt in range(max_retries):
        try:
            gamefinder = leaguegamefinder.LeagueGameFinder(
                league_id_nullable='00',
                season_type_nullable=SeasonType.regular,
                timeout=30
            )

    # Extract game IDs
            games = gamefinder.get_data_frames()[0]
            games["GAME_DATE"] = pd.to_datetime(games["GAME_DATE"])
            games.sort_values(by=['GAME_DATE'], inplace=False)

            # Get the season ID from the first entry after sorting by GAME_DATE
            season_id = games.iloc[0]['SEASON_ID']

            # Filter games for the current season using the obtained season ID
            current_season_games = games[games['SEASON_ID'] == season_id]

            # Extract unique game IDs
            all_game_ids = current_season_games['GAME_ID'].unique().tolist()

            return all_game_ids
        except requests.exceptions.Timeout:
            print(f"Attempt {attempt + 1} timed out. Retrying...")
    raise TimeoutError("Max retries exceeded")

all_game_ids = get_all_game_ids_for_current_season()
all_game_ids.reverse()
print(all_game_ids)

['0022300061', '0022300062', '0022300067', '0022300074', '0022300071', '0022300065', '0022300063', '0022300064', '0022300070', '0022300072', '0022300068', '0022300073', '0022300069', '0022300066', '0022300076', '0022300075', '0022300082', '0022300085', '0022300080', '0022300078', '0022300081', '0022300083', '0022300087', '0022300079', '0022300077', '0022300086', '0022300084', '0022300094', '0022300093', '0022300089', '0022300090', '0022300091', '0022300092', '0022300088', '0022300098', '0022300096', '0022300097', '0022300100', '0022300095', '0022300099', '0022300103', '0022300104', '0022300111', '0022300109', '0022300106', '0022300107', '0022300105', '0022300102', '0022300101', '0022300110', '0022300108', '0022300112', '0022300114', '0022300113', '0022300121', '0022300125', '0022300120', '0022300124', '0022300116', '0022300126', '0022300123', '0022300118', '0022300117', '0022300115', '0022300122', '0022300119', '0022300127', '0022300129', '0022300130', '0022300128', '0022300131', '0022

In [2]:
all_game_ids = ['0022300061', '0022300062', '0022300067', '0022300074', '0022300071', '0022300065', '0022300063', '0022300064', '0022300070', '0022300072', '0022300068', '0022300073', '0022300069', '0022300066', '0022300076', '0022300075', '0022300082', '0022300085', '0022300080', '0022300078', '0022300081', '0022300083', '0022300087', '0022300079', '0022300077', '0022300086', '0022300084', '0022300094', '0022300093', '0022300089', '0022300090', '0022300091', '0022300092', '0022300088', '0022300098', '0022300096', '0022300097', '0022300100', '0022300095', '0022300099', '0022300103', '0022300104', '0022300111', '0022300109', '0022300106', '0022300107', '0022300105', '0022300102', '0022300101', '0022300110', '0022300108', '0022300112', '0022300114', '0022300113', '0022300121', '0022300125', '0022300120', '0022300124', '0022300116', '0022300126', '0022300123', '0022300118', '0022300117', '0022300115', '0022300122', '0022300119', '0022300127', '0022300129', '0022300130', '0022300128', '0022300131', '0022300005', '0022300006', '0022300001', '0022300007', '0022300004', '0022300003', '0022300002', '0022300137', '0022300136', '0022300134', '0022300132', '0022300138', '0022300139', '0022300135', '0022300133', '0022300142', '0022300143', '0022300140', '0022300141', '0022300144', '0022300151', '0022300148', '0022300145', '0022300152', '0022300155', '0022300149', '0022300150', '0022300146', '0022300147', '0022300156', '0022300154', '0022300153', '0022300168', '0022300165', '0022300160', '0022300164', '0022300162', '0022300169', '0022300166', '0022300158', '0022300167', '0022300161', '0022300157', '0022300163', '0022300170', '0022300159', '0022300172', '0022300171', '0022300014', '0022300013', '0022300009', '0022300008', '0022300016', '0022300011', '0022300010', '0022300012', '0022300015', '0022300175', '0022300173', '0022300176', '0022300174', '0022300185', '0022300180', '0022300177', '0022300179', '0022300184', '0022300178', '0022300187', '0022300182', '0022300186', '0022300181', '0022300183', '0022300191', '0022300188', '0022300189', '0022300190', '0022300023', '0022300022', '0022300019', '0022300020', '0022300018', '0022300026', '0022300021', '0022300024', '0022300017', '0022300025', '0022300195', '0022300199', '0022300198', '0022300194', '0022300192', '0022300193', '0022300197', '0022300196', '0022300201', '0022300200', '0022300032', '0022300029', '0022300027', '0022300035', '0022300037', '0022300030', '0022300036', '0022300034', '0022300031', '0022300028', '0022300033', '0022300203', '0022300205', '0022300204', '0022300206', '0022300207', '0022300202', '0022300215', '0022300216', '0022300212', '0022300209', '0022300208', '0022300211', '0022300213', '0022300214', '0022300210', '0022300223', '0022300218', '0022300219', '0022300222', '0022300224', '0022300221', '0022300220', '0022300217', '0022300042', '0022300039', '0022300040', '0022300038', '0022300041', '0022300234', '0022300235', '0022300237', '0022300228', '0022300226', '0022300236', '0022300232', '0022300230', '0022300225', '0022300238', '0022300233', '0022300229', '0022300227', '0022300231', '0022300045', '0022300050', '0022300047', '0022300051', '0022300044', '0022300049', '0022300046', '0022300043', '0022300048', '0022300052', '0022300239', '0022300243', '0022300241', '0022300242', '0022300240', '0022300244', '0022300245', '0022300248', '0022300251', '0022300246', '0022300249', '0022300250', '0022300252', '0022300247', '0022300257', '0022300253', '0022300256', '0022300255', '0022300254', '0022300060', '0022300059', '0022300053', '0022300055', '0022300054', '0022300058', '0022300057', '0022300056', '0022300259', '0022300263', '0022300261', '0022300262', '0022300260', '0022300258', '0022300264', '0022300265', '0022300271', '0022300267', '0022300270', '0022300273', '0022300266', '0022300268', '0022300272', '0022300269', '0022300279', '0022300278', '0022300274', '0022300275', '0022300277', '0022300276', '0022300289', '0022300285', '0022300282', '0022300283', '0022300288', '0022300287', '0022300291', '0022300281', '0022300290', '0022300284', '0022300286', '0022300280', '0022301204', '0022301202', '0022301201', '0022301203', '0022301206', '0022301215', '0022301208', '0022301207', '0022301212', '0022301205', '0022301214', '0022301209', '0022301213', '0022301210', '0022301211', '0022301230', '0022301229', '0022301225', '0022301228', '0022301220', '0022301218', '0022301221', '0022301223', '0022301222', '0022301224', '0022301216', '0022301217', '0022301226', '0022301219', '0022301227', '0022300298', '0022300302', '0022300294', '0022300301', '0022300295', '0022300303', '0022300304', '0022300299', '0022300300', '0022300297', '0022300296', '0022300292', '0022300293', '0022300305', '0022300308', '0022300306', '0022300309', '0022300307', '0022300315', '0022300313', '0022300314', '0022300317', '0022300311', '0022300318', '0022300312', '0022300310', '0022300316', '0022300323', '0022300319', '0022300325', '0022300320', '0022300321', '0022300324', '0022300322', '0022300333', '0022300330', '0022300327', '0022300329', '0022300332', '0022300326', '0022300331', '0022300328', '0022300340', '0022300337', '0022300335', '0022300338', '0022300343', '0022300339', '0022300342', '0022300336', '0022300334', '0022300341', '0022300345', '0022300344', '0022300347', '0022300346', '0022300348', '0022300354', '0022300351', '0022300356', '0022300357', '0022300349', '0022300352', '0022300353', '0022300350', '0022300355', '0022300359', '0022300358', '0022300360', '0022300362', '0022300361', '0022300363']

In [3]:
all_play_by_play_data = []

for game_id in all_game_ids:
    play_by_play_data = get_play_by_play_data(game_id)
    
    # Combine play-by-play data for each game into a single string
    game_text = ' '.join(play_by_play_data)
    all_play_by_play_data.append(game_text)

In [4]:
all_play_by_play_data[0]

" 1:PT12M00.00S (period) 1:PT11M55.00S [2544, 203076, 203999] (Jump Ball A. Davis vs. N. Jokic: Tip to L. James) (0-0)  1:PT11M42.00S [203076, 1626156] (A. Davis DUNK (2 PTS) (D. Russell 1 AST)) (0-2)  1:PT11M15.00S [203999, 1627750] (N. Jokic 6' driving floating Jump Shot (2 PTS) (J. Murray 1 AST)) (2-2)  1:PT10M57.00S [1627752, 2544] (T. Prince 24' 3PT  (3 PTS) (L. James 1 AST)) (2-5)  1:PT10M40.00S [1627750] (J. Murray driving finger roll Layup (2 PTS)) (4-5)  1:PT10M33.00S [1627752, 2544] (T. Prince 24' 3PT  (6 PTS) (L. James 2 AST)) (4-8)  1:PT10M16.00S [1627750] (J. Murray 26' 3PT step back (5 PTS)) (7-8)  1:PT10M03.00S [1626156] (MISS D. Russell 25' pullup 3PT) (7-8)  1:PT10M01.00S [1630559] (A. Reaves REBOUND (Off:1 Def:0)) (7-8)  1:PT09M49.00S [203076] (MISS A. Davis 18' step back Shot) (7-8)  1:PT09M46.00S [1629008] (M. Porter Jr. REBOUND (Off:0 Def:1)) (7-8)  1:PT09M36.00S [203999, 203076] (MISS N. Jokic Layup - blocked) (7-8)  1:PT09M36.00S [203076] (A. Davis BLOCK (1 BLK))

In [None]:
text_file_path = "/Users/student/combined_text.txt"
with open(text_file_path, "w", encoding="utf-8") as file:
    for game_data in all_play_by_play_data:
        file.write(game_data + '\n')

In [7]:
import sentencepiece as spm

# Assuming all_play_by_play_data is a list of strings

# Train SentencePiece model
spm.SentencePieceTrainer.train(
    input='/home/pelumi/combined_text.txt',
    model_prefix='/home/pelumi/play_byplay_spm.model',
    vocab_size=1000,
    control_symbols='<pad>,<s>,</s>',
    max_sentence_length=50000  # Set a larger value based on your data
)
# Load the trained SentencePiece model
sp = spm.SentencePieceProcessor()
sp.load("/home/pelumi/play_byplay_spm.model")

# Example tokenization



sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: /home/pelumi/combined_text.txt
  input_format: 
  model_prefix: /home/pelumi/play_byplay_spm.model
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 50000
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  control_symbols: <pad>
  control_symbols: <s>
  control_symbols: </s>
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_pi

True

  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv: 
}
denormalizer_spec {}
trainer_interface.cc(351) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(183) LOG(INFO) Loading corpus: /home/pelumi/combined_text.txt
trainer_interface.cc(407) LOG(INFO) Loaded all 311 sentences
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(423) LOG(INFO) Adding meta_piece: <pad>
trainer_interface.cc(428) LOG(INFO) Normalizing sentences...
trainer_interface.cc(537) LOG(INFO) all chars count=10354938
trainer_interface.cc(548) LOG(INFO) Done: 99.9609% characters are covered.
trainer_interfac

In [3]:
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.load("/home/pelumi/play_byplay_spm.model")
tokenized_play_by_play = [sp.encode_as_ids(text) for text in loaded_play_by_play_data]

In [24]:
import time
import requests
from nba_api.stats.endpoints import boxscoresummaryv2, gamerotation
from nba_api.stats.static import teams


def new_get_team_and_player_stats(game_ids):
    # Get BoxScoreSummaryV2 for inactive players
    request_interval = 3  # Adjust as needed, represents seconds between requests
    max_retries = 10
    timeout = 10
    game_counter = 0
    

    print(f"Embeddding game_id;{game_ids}")
    start_func_time = time.time()
    for _ in range(max_retries):
        try:
            boxscore_summary = boxscoresummaryv2.BoxScoreSummaryV2(game_ids, timeout=timeout)
                
            break  # Break out of the loop if successful
        except requests.exceptions.ReadTimeout:
            print(f"Request timed out. Retrying...")

    if not boxscore_summary:
        raise ValueError(f"Failed after {max_retries} retries for game {game_ids}.")
    
    data_frames = boxscore_summary.get_data_frames()
    home_team_id = data_frames[0]['HOME_TEAM_ID'].iloc[0]
    visitor_team_id = data_frames[0]['VISITOR_TEAM_ID'].iloc[0]

        # Get inactive players for each team
    inactive_players_data = data_frames[3]
    inactive_players_home = inactive_players_data[inactive_players_data['TEAM_ID'] == int(home_team_id)]['PLAYER_ID'].tolist()
    inactive_players_visitor = inactive_players_data[inactive_players_data['TEAM_ID'] == int(visitor_team_id)]['PLAYER_ID'].tolist()
    time.sleep(request_interval)

        # Get active players for each team using BoxScoreSummaryV2
   
    for _ in range(max_retries):
        try:
            gamerotation_data = gamerotation.GameRotation(game_ids, timeout=100)
                
            break  # Break out of the loop if successful
        except requests.exceptions.ReadTimeout:
            print(f"Request timed out. Retrying...")

    if not gamerotation_data:
        raise ValueError(f"Failed after {max_retries} retries for game {game_ids}.")
        # Extract unique player IDs from both columns
    gamerota = gamerotation_data.get_data_frames()
    home_team_player_ids = gamerota[1]['PERSON_ID'].unique()
    visitor_team_player_ids = gamerota[0]['PERSON_ID'].unique()

        # Filter out inactive players from active player lists
    active_players_home = [player_id for player_id in home_team_player_ids if player_id not in inactive_players_home]
    active_players_visitor = [player_id for player_id in visitor_team_player_ids if player_id not in inactive_players_visitor]

        # Add a pause between requests
    time.sleep(request_interval)

        # Get team stats for the last five games using LeagueGameFinder
    home_team_info = next(team for team in teams.get_teams() if team['id'] == home_team_id)
    visitor_team_info = next(team for team in teams.get_teams() if team['id'] == visitor_team_id)
    
    home_team_abbr = home_team_info['abbreviation']
    visitor_team_abbr = visitor_team_info['abbreviation']

        # Add a pause between requests
    time.sleep(request_interval)

        # Get team stats for the last five games using LeagueGameFinder
    team_stats_home = get_team_avg_stats(home_team_abbr)
    team_stats_visitor = get_team_avg_stats(visitor_team_abbr)

        # Add a pause between requests
    time.sleep(request_interval)

        # Get player stats for active players using PlayerGameLog
    player_stats_home = {player_id: get_player_avg_stats(player_id) for player_id in active_players_home}
    player_stats_visitor = {player_id: get_player_avg_stats(player_id) for player_id in active_players_visitor}

    time.sleep(request_interval)

        # Format the output for the embedding layer
        

        # Home team identifier (0) and team stats
    elapsed_func_time = time.time() - start_func_time
    
    embedding_input = []
    print(f"Time taken for processing game_ids {game_ids}: {elapsed_func_time:.2f} seconds")

    embedding_input.extend([0, home_team_id] + team_stats_home.values.tolist())

# Home player stats
    for player_id in active_players_home:
        embedding_input.extend([player_id] + player_stats_home[player_id].values.tolist())

# Visitor team identifier (1) and team stats
    embedding_input.extend([1, visitor_team_id] + team_stats_visitor.values.tolist())

# Visitor player stats
    for player_id in active_players_visitor:
        embedding_input.extend([player_id] + player_stats_visitor[player_id].values.tolist())

    return embedding_input



In [25]:
embedding_inputs = []
game_count = 0

for game_id in all_game_ids:
    embedding_inputs.append(new_get_team_and_player_stats(game_id))

    game_count += 1
    print(f"({game_count}/{len(all_game_ids)} games completed)")
    


Embeddding game_id;0022300061
Time taken for processing game_ids 0022300061: 21.69 seconds
(1/393 games completed)
Embeddding game_id;0022300062
Time taken for processing game_ids 0022300062: 20.80 seconds
(2/393 games completed)
Embeddding game_id;0022300067
Time taken for processing game_ids 0022300067: 21.22 seconds
(3/393 games completed)
Embeddding game_id;0022300074
Time taken for processing game_ids 0022300074: 22.66 seconds
(4/393 games completed)
Embeddding game_id;0022300071
Time taken for processing game_ids 0022300071: 19.32 seconds
(5/393 games completed)
Embeddding game_id;0022300065
Time taken for processing game_ids 0022300065: 19.46 seconds
(6/393 games completed)
Embeddding game_id;0022300063
Time taken for processing game_ids 0022300063: 20.62 seconds
(7/393 games completed)
Embeddding game_id;0022300064
Time taken for processing game_ids 0022300064: 21.44 seconds
(8/393 games completed)
Embeddding game_id;0022300070
Time taken for processing game_ids 0022300070: 21.

In [4]:
import pickle
def save_embedding_input(embedding_input, filename='home/pelumi/embedding_input.pkl'):
    with open(filename, 'wb') as file:
        pickle.dump(embedding_input, file)

# Call the function to save the embedding_input
filename = '/home/pelumi/embedding_input.pkl'
save_embedding_input(embedding_inputs, filename)



NameError: name 'embedding_inputs' is not defined

In [28]:
embedding_inputs[0]

[0,
 1610612743,
 122.8,
 0.49879999999999997,
 0.42540000000000006,
 0.8061999999999999,
 47.0,
 29.4,
 7.6,
 4.4,
 10.2,
 202704,
 15.2,
 0.4842000000000001,
 0.3816,
 0.8,
 1.4,
 4.6,
 0.4,
 0.0,
 1.0,
 203484,
 7.6,
 0.48339999999999994,
 0.3,
 0.35,
 1.6,
 1.4,
 0.2,
 0.4,
 0.4,
 203932,
 16.4,
 0.6214000000000001,
 0.2334,
 0.798,
 7.2,
 3.8,
 0.8,
 0.2,
 0.8,
 203999,
 17.4,
 0.5334,
 0.3,
 0.7130000000000001,
 9.2,
 8.8,
 1.6,
 0.4,
 1.4,
 1627750,
 20.6,
 0.5151999999999999,
 0.5599999999999999,
 0.8,
 4.2,
 3.2,
 1.0,
 0.8,
 2.0,
 1629008,
 11.4,
 0.33199999999999996,
 0.33340000000000003,
 0.45,
 7.0,
 1.0,
 0.4,
 0.0,
 0.8,
 1629618,
 1.0,
 0.4,
 0.2,
 0.0,
 0.6,
 0.4,
 0.2,
 0.0,
 0.4,
 1630192,
 7.4,
 0.6,
 0.2,
 0.63,
 4.2,
 0.8,
 0.2,
 1.4,
 0.6,
 1630296,
 0.6,
 0.1,
 0.0,
 0.1,
 0.6,
 1.0,
 0.0,
 0.0,
 0.0,
 1631128,
 9.8,
 0.5468,
 0.2466,
 0.8834,
 2.8,
 1.4,
 0.2,
 0.6,
 1.2,
 1631212,
 9.6,
 0.5184,
 0.41340000000000005,
 0.6334,
 4.2,
 1.8,
 0.6,
 0.8,
 0.6,
 163

In [2]:

import pickle
def load_embedding_input(filename='/home/pelumi/embedding_input.pkl'):
    with open(filename, 'rb') as file:
        embedding_input = pickle.load(file)
    return embedding_input

# Load the saved embedding_input
loaded_embedding_input = load_embedding_input('/home/pelumi/embedding_input.pkl')

In [17]:
loaded_embedding_input[0]

[0,
 1610612743,
 122.8,
 0.49879999999999997,
 0.42540000000000006,
 0.8061999999999999,
 47.0,
 29.4,
 7.6,
 4.4,
 10.2,
 202704,
 15.2,
 0.4842000000000001,
 0.3816,
 0.8,
 1.4,
 4.6,
 0.4,
 0.0,
 1.0,
 203484,
 7.6,
 0.48339999999999994,
 0.3,
 0.35,
 1.6,
 1.4,
 0.2,
 0.4,
 0.4,
 203932,
 16.4,
 0.6214000000000001,
 0.2334,
 0.798,
 7.2,
 3.8,
 0.8,
 0.2,
 0.8,
 203999,
 17.4,
 0.5334,
 0.3,
 0.7130000000000001,
 9.2,
 8.8,
 1.6,
 0.4,
 1.4,
 1627750,
 20.6,
 0.5151999999999999,
 0.5599999999999999,
 0.8,
 4.2,
 3.2,
 1.0,
 0.8,
 2.0,
 1629008,
 11.4,
 0.33199999999999996,
 0.33340000000000003,
 0.45,
 7.0,
 1.0,
 0.4,
 0.0,
 0.8,
 1629618,
 1.0,
 0.4,
 0.2,
 0.0,
 0.6,
 0.4,
 0.2,
 0.0,
 0.4,
 1630192,
 7.4,
 0.6,
 0.2,
 0.63,
 4.2,
 0.8,
 0.2,
 1.4,
 0.6,
 1630296,
 0.6,
 0.1,
 0.0,
 0.1,
 0.6,
 1.0,
 0.0,
 0.0,
 0.0,
 1631128,
 9.8,
 0.5468,
 0.2466,
 0.8834,
 2.8,
 1.4,
 0.2,
 0.6,
 1.2,
 1631212,
 9.6,
 0.5184,
 0.41340000000000005,
 0.6334,
 4.2,
 1.8,
 0.6,
 0.8,
 0.6,
 163

In [1]:
text_file_path = "/home/pelumi/combined_text.txt"
loaded_play_by_play_data = []

with open(text_file_path, "r", encoding="utf-8") as file:
    for line in file:
       loaded_play_by_play_data .append(line.strip())

# Now, loaded_play_by_play_data contains the list of play-by-play data


In [5]:
len(sp.encode_as_ids(loaded_play_by_play_data[0]))

14590

In [6]:
game_firstprt = sp.encode_as_ids(loaded_play_by_play_data[0][:27000])
len(game_firstprt)

11639

In [4]:
eos_token_id=sp.PieceToId('</s>')
start_token_id=sp.PieceToId('<s>')
print(start_token_id)
print(eos_token_id)

1
2


In [6]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    def __init__(self, playbyplay_data, embedding_input):
        self.playbyplay_data = playbyplay_data
        self.embedding_input = embedding_input

    def __len__(self):
        return len(self.playbyplay_data)

    def __getitem__(self, idx):
        playbyplay_tensor = torch.tensor(self.playbyplay_data[idx])

        # Add start token at the beginning
        playbyplay_start = torch.cat([torch.tensor([start_token_id]), playbyplay_tensor], dim=0)

        # Add EOS token at the end and remove the first token
        playbyplay_end = torch.cat([playbyplay_tensor, torch.tensor([start_token_id])], dim=0)

        return {
            'playbyplay_start': playbyplay_start,
            'playbyplay_end': playbyplay_end,
            'embedding_input': torch.tensor(self.embedding_input[idx])
        }

def custom_collate_fn(batch):
    playbyplay_start_batch = [item['playbyplay_start'] for item in batch]
    playbyplay_end_batch = [item['playbyplay_end'] for item in batch]
    embedding_input_batch = [item['embedding_input'] for item in batch]

    # Pad playbyplay_start and playbyplay_end data to the maximum length in the batch
    playbyplay_start_padded = pad_sequence(playbyplay_start_batch, batch_first=True, padding_value=0)
    playbyplay_end_padded = pad_sequence(playbyplay_end_batch, batch_first=True, padding_value=0)
    embedding_input = pad_sequence(embedding_input_batch, batch_first=True, padding_value=0)

    return {
        'playbyplay_start': playbyplay_start_padded,
        'playbyplay_end': playbyplay_end_padded,
        'embedding_input': embedding_input
    }

# Example usage
# Assuming you have playbyplay_data and embedding_input as lists
start_token_id = sp.piece_to_id('<s>')  # Replace with your start token ID
eos_token_id = sp.piece_to_id('</s>')  # Replace with your EOS token ID

dataset = CustomDataset(tokenized_play_by_play, loaded_embedding_input)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate_fn, shuffle=False)



In [7]:
for batch_idx, batch in enumerate(dataloader):
    embedding_input_batch = batch['embedding_input']
    playbyplay_padded_batch = batch['playbyplay_start']
    playbyplay_padded_batch_eos = batch['playbyplay_end']

    # Inspect the first elements
    first_embedding_input = embedding_input_batch[0]
    first_playbyplay_padded = playbyplay_padded_batch[0]
    first_eos = playbyplay_padded_batch_eos[0]

    print("First embedding input:", first_embedding_input)
    print("First play-by-play padded:", first_playbyplay_padded)
    print("First play-by-play padded:", first_eos)

    # Check the lengths of play-by-play elements in the batch
    playbyplay_lengths = [len(playbyplay) for playbyplay in playbyplay_padded_batch]
    print("Lengths of play-by-play elements in the batch:", playbyplay_lengths)

    break

First embedding input: tensor([0.0000e+00, 1.6106e+09, 1.2280e+02, 4.9880e-01, 4.2540e-01, 8.0620e-01,
        4.7000e+01, 2.9400e+01, 7.6000e+00, 4.4000e+00, 1.0200e+01, 2.0270e+05,
        1.5200e+01, 4.8420e-01, 3.8160e-01, 8.0000e-01, 1.4000e+00, 4.6000e+00,
        4.0000e-01, 0.0000e+00, 1.0000e+00, 2.0348e+05, 7.6000e+00, 4.8340e-01,
        3.0000e-01, 3.5000e-01, 1.6000e+00, 1.4000e+00, 2.0000e-01, 4.0000e-01,
        4.0000e-01, 2.0393e+05, 1.6400e+01, 6.2140e-01, 2.3340e-01, 7.9800e-01,
        7.2000e+00, 3.8000e+00, 8.0000e-01, 2.0000e-01, 8.0000e-01, 2.0400e+05,
        1.7400e+01, 5.3340e-01, 3.0000e-01, 7.1300e-01, 9.2000e+00, 8.8000e+00,
        1.6000e+00, 4.0000e-01, 1.4000e+00, 1.6278e+06, 2.0600e+01, 5.1520e-01,
        5.6000e-01, 8.0000e-01, 4.2000e+00, 3.2000e+00, 1.0000e+00, 8.0000e-01,
        2.0000e+00, 1.6290e+06, 1.1400e+01, 3.3200e-01, 3.3340e-01, 4.5000e-01,
        7.0000e+00, 1.0000e+00, 4.0000e-01, 0.0000e+00, 8.0000e-01, 1.6296e+06,
        1.0000e+0

In [13]:
len(first_playbyplay_padded)

17430

In [26]:
first_eos.shape

torch.Size([17430])

In [8]:

def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones((size, size), dtype=torch.bool), diagonal=1)
    return mask

In [9]:
eos_token_id=sp.PieceToId('</s>')
start_token_id=sp.PieceToId('<s>')
print(start_token_id)
print(eos_token_id)

1
2


In [10]:


import torch
from torch import nn
import torch.optim as optim
from torch.optim import SGD
import torch.nn.functional as F
import math
from tqdm import tqdm

class RoPEPositionalEncoding:
    def __init__(self, dim):
        self.dim = dim

    def rotate_every_two(self, x):
        x1 = x[:, :, ::2]
        x2 = x[:, :, 1::2]
        x = torch.stack((-x2, x1), dim=-1)
        return x.reshape(x.shape[0], x.shape[1], -1)
    
    def get_rotation_matrix(self, seq_len):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        positions = torch.arange(seq_len).float()
        sinusoid_inp = torch.einsum("i,j->ij", positions, inv_freq)
        return torch.stack((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1).unsqueeze(0)
    
    def forward(self, x, seq_len):
        rotation_matrix = self.get_rotation_matrix(seq_len)  # [1, seq_len, dim // 2, 2]
        x_rotated = self.rotate_every_two(x)  # [batch_size, seq_len, dim]
        
        # Expand rotation_matrix to match batch_size
        rotation_matrix = rotation_matrix.repeat(x.size(0), 1, 1, 1)  # [batch_size, seq_len, dim // 2, 2]

        # Reshape x_rotated for element-wise multiplication
        x_rotated = x_rotated.view(x.size(0), seq_len, self.dim // 2, 2)

        return (x_rotated * rotation_matrix).view_as(x)
    
class TransformerBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward ,dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.num_heads = num_heads
        self.dim_feedforward = dim_feedforward
         # Added linear layer for embedding input
        self.qkv_proj = nn.Linear(input_dim, 3 * input_dim)
        self.fc_out = nn.Linear(input_dim, input_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, input_dim),
        )

    def scaled_dot_product_attention(self, q, k, v, mask):
        d_k = q.size(-1)
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask, 0.0)
            
        attn_logits = attn_logits / math.sqrt(d_k)

        attention = torch.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

    def forward(self, x, mask):
        
        
        # Embedding input
          # Linear layer for embedding input
         # Expand for element-wise addition
          # Embed playbyplay data
        
        # Combine playbyplay and embedded input
        batch_size, seq_length,_ = x.size()
        rope = RoPEPositionalEncoding(x.size(-1))

        # Linear projection splits Q, K, and V, and then reshape and transpose for multi-head attention
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k = map(lambda t: rope.forward(t, seq_length), (q, k))
        q = q.view(batch_size, seq_length, self.num_heads, x.size(-1) // self.num_heads).transpose(1, 2)
        k = k.view(batch_size, seq_length, self.num_heads, x.size(-1) // self.num_heads).transpose(1, 2)
        v = v.view(batch_size, seq_length, self.num_heads, x.size(-1) // self.num_heads).transpose(1, 2)

        # Scaled dot-product attention
        attn, _ = self.scaled_dot_product_attention(q, k, v, mask=mask)
    
        attn = attn.transpose(1, 2).contiguous().view(batch_size, seq_length, x.size(-1))
        attn = self.dropout(attn)

        # Add & Norm
        x = x + attn
        x = self.norm1(x)

        # Feed-forward network with injected embedding input
        ff_out = self.feed_forward(x)
        ff_out = self.dropout(ff_out)
        
        # Add & Norm
        x = x + ff_out
        x = self.norm2(x)

        return x

    
class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, input_dim, num_heads, dim_feedforward, dropout):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([TransformerBlock(input_dim, num_heads, dim_feedforward, dropout) for _ in range(num_layers)])

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return x

class Transformer(nn.Module):
    def __init__(self, vocab_size, input_dim, num_heads, dim_feedforward, num_layers, embedding_dim ,dropout):
        super(Transformer, self).__init__()
        self.linear_embedding = nn.Linear(embedding_dim, input_dim) 
        self.embed = nn.Embedding(vocab_size, input_dim)
        self.decoder = TransformerDecoder(num_layers, input_dim, num_heads, dim_feedforward,dropout)
        self.fc_out = nn.Linear(input_dim, vocab_size)

    def forward(self, tgt):
        seq_length = tgt.size(1)
        tgt_embedded = self.embed(tgt) * math.sqrt(self.embed.embedding_dim)
        #embedded_input = self.linear_embedding(context)
        #embedded_input = embedded_input.unsqueeze(1).expand(-1, seq_length, -1)
        #combined_input = torch.cat([tgt_embedded, embedded_input], dim=-1)
        
        tgt_mask = create_look_ahead_mask(seq_length)
        dec_output = self.decoder(tgt_embedded, mask=tgt_mask)

        logits = self.fc_out(dec_output)
        return logits


def generate_playbyplay(model, game_section):
    model.eval()  # Set the model to evaluation mode
    eos_token_id=sp.PieceToId('</s>')
    start_token_id=sp.PieceToId('<s>')
    # Tokenize the start token
    
    tgt = torch.cat([torch.tensor([start_token_id]), game_section], dim=0).unsqueeze(0)

    top_k=3
    #pbar = tqdm(total=(250 - len(tgt)))
    with torch.no_grad():
        for _ in range((250 - len(tgt))):
            # Forward pass through the model
            output = model.forward(tgt)

            # Get the logits for the next token
            next_token_logits = output[:, -1, :]
        

            # Get the index of the most likely token
            most_likely_index = torch.argmax(next_token_logits, dim=-1)

            # Extract the most likely token
            most_likely_token = most_likely_index.unsqueeze(0)

            # Add the most likely token to the target sequence
            tgt = torch.cat([tgt, most_likely_token], dim=1)
            #probs = F.softmax(next_token_logits, dim=-1)

                # Select the top-k tokens and their probabilities
            #top_k_probs, top_k_tokens = torch.topk(probs, top_k, dim=-1)
                
                # Normalize the probabilities of the top-k tokens
            #top_k_probs = top_k_probs / torch.sum(top_k_probs, dim=-1, keepdim=True)

                # Sample from the top-k tokens
            #next_token = torch.multinomial(top_k_probs, 1)
            #actual_next_token = top_k_tokens[0][next_token]
            #predicted_token = next_token_logits.argmax(-1).unsqueeze(-1)
            #tgt = torch.cat([tgt, actual_next_token], dim=1)
                        #pbar.update(1)
            # Break if EOS token is generated
            #if actual_next_token.item() == eos_token_id:
 
    #pbar.close()

            
            

    # Close the progress bar
    
    # Convert the output tokens to a list
    predicted_tokens = tgt.squeeze().tolist()

    model.train()  # Set the model back to training mode
    return predicted_tokens






In [11]:
import wandb

wandb.init(project='nba', entity='pelumia23')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpelumia23[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
import os
new_model_save_dir = "/home/pelumi/nba_model_5layers/"
model_save_dir = "/home/pelumi/nba_model/"
os.makedirs(model_save_dir, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, loss_history, file_path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_history': loss_history,
    }, file_path)

# Function to load model state
def load_checkpoint(model, optimizer, file_path):
    checkpoint = torch.load(file_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss_history = checkpoint['loss_history']
    return epoch, loss_history

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

# Assuming you have a DataLoader named 'train_dataloader' and your model is named 'transformer_model'
# Also, specify other training parameters like learning rate, number of epochs, etc.

# Set up your model, DataLoader, optimizer, and criterion
vocab_size = 3679  # Adjust as needed
input_dim = 120  # Adjust as needed
num_heads = 4
dim_feedforward = 386
num_layers = 2
dropout = 0.1
learning_rate = 0.001
epochs = 10
embedding_dim = 322

# Example DataLoader (replace this with your actual DataLoader)


# Initialize model, optimizer, and criterion
transformer_model = Transformer(vocab_size, input_dim, num_heads, dim_feedforward, num_layers, embedding_dim, dropout)
optimizer = optim.Adam(transformer_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Assuming 0 is the padding index
#checkpoint_file = f"model_epoch_{desired_epoch}.pt"  # +1 because the file is named after completing the epoch
#checkpoint_path = os.path.join(model_save_dir, checkpoint_file)
#if os.path.exists(checkpoint_path):
    #epoch_start, loss_history = load_checkpoint(transformer_model, optimizer, checkpoint_path)
    
    

# Define the One Cycle Policy scheduler
    #game_firstprt = torch.tensor(sp.encode_as_ids(loaded_play_by_play_data[0][:10]))
    #plays = generate_playbyplay(transformer_model,game_firstprt)
# Training loop
total_loss = 0 
for epoch in range(epochs):
    transformer_model.train()  # Set the model to training mode
    
    for batch_idx, bacth in enumerate(dataloader):
        context = batch['embedding_input']
        tgt = batch['playbyplay_start']
        target = batch['playbyplay_end']
        optimizer.zero_grad()

        # Forward pass
        


          # Assuming '</s>' is the SentencePiece EOS token
        

        # Calculate loss with masking
        logits = transformer_model(tgt)
        logits = logits.view(-1, vocab_size)
        tgt_expected = target.view(-1)

        mask_target = tgt_expected != 0  # Assuming 0 is the padding index
        mask_output = tgt != 0

        mask_output = mask_output.view(-1)

        output_masked = logits[mask_output]
        target_masked = tgt_expected[mask_target]

        # Calculate loss
        loss = criterion(output_masked, target_masked)
        total_loss += loss.item()
        wandb.log({'batch_loss': loss.item(), 'epoch': epoch})
        loss.backward()
        optimizer.step()

        

    average_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, Loss: {average_loss}')
    plays = model.generate_playbyplay(first_embedding_input)
    print(f"Epoch {epoch}: Sample Translation - {plays}")

In [1]:
!pwd

/home/pelumi


In [12]:
vocab_size = 3679  # Adjust as needed
input_dim = 120  # Adjust as needed
num_heads = 4
dim_feedforward = 386
num_layers = 1
dropout = 0.1
learning_rate = 0.001
epochs = 10
embedding_dim = 282

# Example DataLoader (replace this with your actual DataLoader)


# Initialize model, optimizer, and criterion
transformer_model = Transformer(vocab_size, input_dim, num_heads, dim_feedforward, num_layers, embedding_dim, dropout)
optimizer = optim.Adam(transformer_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=0) 
desired_epoch= 10
#checkpoint_file = f"model_epoch_{desired_epoch}.pt"  # +1 because the file is named after completing the epoch
#checkpoint_path = os.path.join(model_save_dir, checkpoint_file)
#if os.path.exists(checkpoint_path):
    #epoch_start, loss_history = load_checkpoint(transformer_model, optimizer, checkpoint_path)
    
    

# Define the One Cycle Policy scheduler
    #game_firstprt = torch.tensor(sp.encode_as_ids(loaded_play_by_play_data[0][:10]))
    #plays = generate_playbyplay(transformer_model,game_firstprt)


In [13]:
plays = sp.decode(plays)
print(plays)

1:PT12M00.00S (J. PT00S [202695] (J. Reaves 15' driving 1629651, 203076] (J. Claxton PT10M. Claxton REBOff:0 Def:PTS. Dinwiddie) (0-0) 1:1) 2 FT. PTS [1629680] (J. PTS [202695, 1627750] (AL (2 PF)))))) (AL (2pt)) (M36.00S [202695] (AL (2 STL (2 STL (2 PF) (AL (2 PF) (AL (2 STL (2 STL (2pt) (J. Free TOff:0 Def:PTS [1627826period pass out- PT11M. FOUL. FOUL. FOUL. Brogdon 3PT11MIS) (timeout. Simmons REBOff:0 Def:PT11MI. Simons FOULAL. Simons driving finger roll Layup) 3PTS) 3PTS (NOUND (timeout) 3PTS [202331] (timeout) (NOUND (NOUN
