In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import game_dataset
from vae_model import NBA_AE
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
import numpy as np
import importlib

In [4]:
class LSTMPredict(nn.Module):
    """LSTM that predicts the next state as the current state plus a learned delta.

    The input sequence is passed through an ``nn.LSTM``; a linear layer maps every
    LSTM output back to ``state_dim`` and the result is added to the input state
    (residual prediction).

    The recurrent state is kept on the instance in ``self.hidden`` and is carried
    over from one ``forward`` call to the next. Callers that start a new,
    unrelated sequence must reset it with ``model.hidden = model.init_hidden()``.

    Args:
        state_dim: size of one state vector (input and output width).
        hidden_dim: width of the LSTM hidden and cell states.
        sizes: accepted for interface compatibility; not used by this class.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, state_dim, hidden_dim, sizes, num_layers = 1,):
        super(LSTMPredict, self).__init__()
        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.num_layers = num_layers

        # nn.LSTM is used with its default layout: (seq_len, batch, feature)
        self.lstm = nn.LSTM(state_dim, hidden_dim, num_layers)
        # hidden-to-delta projection
        self.h2d = nn.Linear(hidden_dim, state_dim)
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size = 1):
        """Return a fresh all-zero ``(h_0, c_0)`` pair for the LSTM.

        Args:
            batch_size: batch dimension of the sequences that will be fed next.
                Defaults to 1, which is what the original implementation
                hard-coded.

        Returns:
            Tuple of two tensors, each of shape
            ``(num_layers, batch_size, hidden_dim)``.
        """
        # Follow the parameters' dtype/device so the state stays usable after
        # the module has been moved or cast (e.g. .to(device), .double()).
        reference = self.h2d.weight
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape, dtype = reference.dtype, device = reference.device),
                torch.zeros(shape, dtype = reference.dtype, device = reference.device))

    def forward(self, state):
        """Predict the next state for every timestep of ``state``.

        Args:
            state: tensor of shape ``(seq_len, batch, state_dim)``; ``batch``
                must match the batch size of ``self.hidden``.

        Returns:
            Tensor with the same shape as ``state``: ``state + delta``.
        """
        # self.hidden is overwritten with the final (h_n, c_n) so consecutive
        # calls continue the same sequence.
        out, self.hidden = self.lstm(state, self.hidden)
        delta = self.h2d(out)
        next_state = torch.add(delta, state)
        return next_state

In [3]:
# LOAD TRAINING DATA
# Reload the module first so edits to game_dataset.py are picked up without restarting the kernel.
importlib.reload(game_dataset)
X = game_dataset.GameDataset('../data/unzipped/',0,20)
# FIND EVENT INDICES
# A new event starts at k + 1 whenever the "next state" of sample k and the
# "initial state" of sample k + 1 differ in their first coordinate.
# NOTE: only start indices are stored; the training loops pair consecutive
# entries, so the segment after the last start index is never visited.
changeidx = [0] + [
    k + 1
    for k in range(len(X) - 1)
    if (X[k][1] - X[k + 1][0])[0] != 0.
]
print('number of loaded events: ', len(changeidx))

available games:  24
number of loaded events:  4715


New variables:

            X:  stores all loaded game data in a list. All events are concatenated together. 
                As demonstrated below, each element in the list is composed of a tuple of two arrays. 
                The first array is the initial state, the second array is the next state. 
                X[0] = (s1,s2), X[1] = (s2,s3)
                    ** A BIT REDUNDANT, MAY WANT TO CHANGE IN THE FUTURE FOR MEMORY EFFICIENCY
                    
                Each state consists of 23 doubles, in feet: the xy coordinates of all 10 players on the court, 
                followed by the xyz coordinates of the ball. 
                The court dimensions are 94 (x) by 50 (y) feet. 
                [home1x, home1y, home2x ..... away1x, away1y, ..... ballx, bally, ballz]
                    
                    
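    changeidx:  stores the indices of X at which a new event begins. If we have X[n] = (s39, s40) and X[n+1] = (t1, 
                t2), where s and t are different events, then n+1 is appended (the boundary is detected because the 
                next state of X[n] does not match the initial state of X[n+1])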

In [5]:
# CREATE LSTM INSTANCE
model = LSTMPredict(
    state_dim=23,      # one state = 23 values (see the description of X)
    hidden_dim=64,
    num_layers=5,
    sizes=(128, 128),
)
# Loss between predicted and true next state
loss_function = nn.MSELoss()
optimizer = optim.Adadelta(model.parameters(), weight_decay=0)

In [None]:
# TRAIN
# One optimizer step per event: every event is fed to the LSTM as a single sequence.
for epoch in range(5):
    event_num = 0
    # Consecutive entries of changeidx delimit one event as [first, stop).
    # NOTE: the segment after the last entry of changeidx is never visited by this pairing.
    for first, stop in zip(changeidx[:-1], changeidx[1:]):
        # Collect the (state, next state) pairs of this event. `stop` itself is
        # excluded because X[stop] already belongs to the following event.
        event_range = range(first, stop)
        inputs = [X[k][0] for k in event_range]
        targets = [X[k][1] for k in event_range]
        # Insert a batch axis of size 1 -> (timesteps, 1, state_dim)
        inputs = torch.tensor(np.expand_dims(inputs, 1)).float()
        targets = torch.tensor(np.expand_dims(targets, 1)).float()

        # Reset gradients and the recurrent state so events do not leak into each other
        model.zero_grad()
        model.hidden = model.init_hidden()

        # Teacher-forced forward pass: the LSTM only ever sees true states, so
        # its outputs are never fed back for multi-step prediction.
        outputs = model(inputs)

        # Loss over every timestep of the event, then one parameter update
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        if event_num % 500 == 0:
            print('finished event {}!'.format(event_num))
        event_num += 1
print('done!')

finished event 0!
finished event 500!
finished event 1000!
finished event 1500!
finished event 2000!
finished event 2500!
finished event 3000!
finished event 3500!
finished event 4000!
finished event 4500!
finished event 0!
finished event 500!


In [None]:
# TRAIN ON SMALL INTERVALS
# Each event is cut into sub-events of `interval` timesteps; one optimizer step per sub-event.
interval = 5

for epoch in range(5):
    event_num = 0
    for idx1, idx2 in zip(changeidx[:-1], changeidx[1:]): # loops events
        # Sub-event start indices: every `interval`-th sample counted from the start of the event
        boundaries = list(range(idx1, idx2, interval))

        # NOTE(review): pairing consecutive boundaries means the samples from the
        # last boundary up to idx2 are never trained on - confirm this is intended.
        for sub_start, sub_stop in zip(boundaries[:-1], boundaries[1:]): # each sub-event is treated like an event
            # Construct data for this sub-event, shape (interval, 1, state_dim)
            sub_range = range(sub_start, sub_stop)
            inputs = [X[k][0] for k in sub_range]
            targets = [X[k][1] for k in sub_range]
            inputs = torch.tensor(np.expand_dims(inputs, 1)).float()
            targets = torch.tensor(np.expand_dims(targets, 1)).float()

            # Reset gradients and the recurrent state for each sub-event
            model.zero_grad()
            model.hidden = model.init_hidden()

            # Forward pass over the sub-event.
            # PROBLEM: this doesn't feed outputs back into the input, so the model
            # is still not predicting multiple steps forward.
            outputs = model(inputs)

            # Only the final timestep of the sub-event contributes to the loss
            loss = loss_function(outputs[-1], targets[-1])
            loss.backward()
            optimizer.step()

        if event_num % 500 == 0:
            print('finished event {}!'.format(event_num))
        event_num += 1
print('done!')

PROBLEM: ONLY THE PARAMETERS (state_dict) ARE SAVED, SO THE MODEL MUST BE RE-CREATED WITH THE SAME ARCHITECTURE BEFORE RELOADING

In [None]:
# SAVE MODEL
# Only the parameters are written, not the module object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, "./models/model4")

In [None]:
# LOAD MODEL
# `model` must already exist with the architecture the weights were saved from.
saved_weights = torch.load("./models/model4")
model.load_state_dict(saved_weights)

In [None]:
# LOAD EVALUATION DATA
# Reload the module first so edits to game_dataset.py are picked up without restarting the kernel.
importlib.reload(game_dataset)
Y = game_dataset.GameDataset('../data/unzipped/',20,23)
# FIND EVENT INDICES
# A new event starts at k + 1 whenever the "next state" of sample k and the
# "initial state" of sample k + 1 differ in their first coordinate.
changeidy = [0] + [
    k + 1
    for k in range(len(Y) - 1)
    if (Y[k][1] - Y[k + 1][0])[0] != 0.
]
print('number of loaded events: ', len(changeidy))

In [None]:
# how much initial observation of each entity do we get?
# NOTE: getPredictHorizon / plotTrajectory are defined in the PLOTTING FUNCTIONS cell further down; run that cell first.
horizon = [150, 150, 150]  # [home, away, ball] must be at least 1
predicted_and_real = getPredictHorizon(horizon, event=200)
plotTrajectory(predicted_and_real, horizon, speed=1, delay=1, savename='home')

In [None]:
# how much initial observation of each entity do we get?
# NOTE: getPredictHorizon / plotTrajectory are defined in the PLOTTING FUNCTIONS cell further down; run that cell first.
horizon = [10000, 100, 10000]  # [home, away, ball] must be at least 1
predicted_and_real = getPredictHorizon(horizon, event=10)
plotTrajectory(predicted_and_real, horizon, speed=1, delay=1, savename='away')

In [None]:
# how much initial observation of each entity do we get?
# NOTE: getPredictHorizon / plotTrajectory are defined in the PLOTTING FUNCTIONS cell further down; run that cell first.
horizon = [10000, 10000, 250]  # [home, away, ball] must be at least 1
predicted_and_real = getPredictHorizon(horizon, event=16)
plotTrajectory(predicted_and_real, horizon, speed=1, delay=1, savename='ball')

In [None]:
# PLOTTING FUNCTIONS

%matplotlib notebook

def getTrajectory(event, dataset = 'train'):
    """Print the length of one event and return its data.

    Args:
        event: position of the event in the start-index list of the chosen dataset.
        dataset: 'train' reads from X / changeidx, 'eval' reads from Y / changeidy.

    Returns:
        ``data[start:stop][0]`` for the selected dataset, or None when
        ``dataset`` is neither 'train' nor 'eval'.
    """
    if dataset not in ('train', 'eval'):
        return None

    # Globals are looked up lazily so 'train' works before the evaluation data is loaded.
    if dataset == 'train':
        data, starts = X, changeidx
    else:
        data, starts = Y, changeidy

    start, stop = starts[event], starts[event+1]
    print('event {} length: {}'.format(event, stop - start))
    # NOTE(review): what [start:stop][0] yields depends on how GameDataset
    # handles slices (first sample vs. all initial states) - confirm.
    return data[start:stop][0]
    
def _plotState(ax, state, alpha = None):
    """Draw one 23-value state on ``ax`` and return its [home, away, ball] artists."""
    # Layout: indices 0-9 home (x, y interleaved), 10-19 away, 20-22 ball x, y, z
    home, = ax.plot(state[0:10:2], state[1:10:2], color = '#00a4ff', marker = 'D',
                    linestyle = '', markersize = 10, alpha = alpha)
    away, = ax.plot(state[10:20:2], state[11:20:2], color = '#ff00f3', marker = 'D',
                    linestyle = '', markersize = 10, alpha = alpha)
    # The ball's z coordinate is shown through the marker size
    ball, = ax.plot([state[20]], [state[21]], color = '#ff8600', marker = 'o',
                    linestyle = '', markersize = 10 + state[22], alpha = alpha)
    return [home, away, ball]


def plotTrajectory(trajectories, horizon, speed=1., delay = 0.2, savename = None):
    """Animate a predicted trajectory on top of the (faded) real one.

    Args:
        trajectories: pair ``(predicted, real)``; each is a sequence of
            23-value states. If ``real`` is empty it is replaced by states
            filled with -1, which lie outside the plotted court area.
        horizon: ``[home, away, ball]`` step counts; the markers H / A / B are
            shown on frames whose index is greater than the respective entry.
        speed: the frame interval passed to the animation is ``10 / speed``.
        delay: forwarded as ``repeat_delay`` to ``ArtistAnimation``.
        savename: if not None, the animation is saved to ``savename + '.mp4'``.

    Returns:
        The ``matplotlib.animation.ArtistAnimation``; keep a reference to it,
        otherwise the animation may be garbage collected.
    """
    # Same effect as the `%matplotlib notebook` magic inside IPython, but keeps
    # this function valid Python when get_ipython is not available.
    try:
        get_ipython().run_line_magic('matplotlib', 'notebook')
    except NameError:
        pass

    # Work on local names: `trajectories` may be a tuple and must not be mutated.
    predicted, real = trajectories[0], trajectories[1]
    if len(real) == 0:
        # if real trajectory is not given, do not show
        real = -np.ones((len(predicted), 23))

    # A single explicit Axes: calling plt.axes() repeatedly can create a new
    # Axes on current Matplotlib instead of returning the one plotted on.
    fig, ax = plt.subplots(figsize = (94/10,50/10))
    backdrop = plt.imread('../court.png')
    ax.imshow(backdrop, extent = (0,94,0,50))

    frames = []
    for step, (pred_state, real_state) in enumerate(zip(predicted, real)):
        frame = []
        # prediction
        frame.extend(_plotState(ax, pred_state))
        # real (faded); the ball size follows the real ball height
        frame.extend(_plotState(ax, real_state, alpha = 0.3))
        # horizon markers
        for limit, xpos, label in zip(horizon, (91, 86.3, 82), ('$H$', '$A$', '$B$')):
            if step > limit:
                mark, = ax.plot([xpos], [47], color = 'white', marker = label,
                                linestyle = '', markersize = 20)
                frame.append(mark)
        # timestep counter
        counter, = ax.plot([10], [2], color = 'white', marker = '$tsteps: '+str(step)+'$',
                           linestyle = '', markersize = 60)
        frame.append(counter)
        frames.append(frame)

    ax.set_xlim(0,94)
    ax.set_ylim(0,50)
    ax.set_aspect('equal')
    ani = animation.ArtistAnimation(fig, frames, interval=10 / speed, blit = True, repeat_delay = delay)
    if savename is not None:
        ani.save(savename + '.mp4', fps = 30, bitrate = 1000)
    return ani

def getPredictHorizon(horizon, event, dataset = 'evaluation'):
    """Roll the model through one event, feeding back its own predictions.

    For the first ``horizon[k]`` steps the model sees the true state of entity
    group ``k``; from then on that part of the input is replaced by the
    model's previous prediction.

    Args:
        horizon: ``[home, away, ball]`` number of observed steps per group
            (state slices 0:10, 10:20 and 20:23). Every entry must be at
            least 1 because no prediction exists before the first step.
        event: position of the event in the dataset's start-index list. The
            last event is supported and ends at the end of the dataset.
        dataset: 'evaluation' (alias 'eval') uses Y / changeidy;
            'test' (alias 'train') uses X / changeidx.

    Returns:
        ``(pred_trajectory, real_trajectory)``: two lists with one entry per
        step, the flattened model output and the true next state.

    Raises:
        ValueError: for an unknown ``dataset`` or a horizon entry below 1.
    """
    if dataset in ('evaluation', 'eval'):
        Z, changeid = Y, changeidy
    elif dataset in ('test', 'train'):
        Z, changeid = X, changeidx
    else:
        raise ValueError("not a dataset: {!r}".format(dataset))

    if any(h < 1 for h in horizon[:3]):
        raise ValueError("every horizon entry must be at least 1, got {}".format(list(horizon)))

    id1 = changeid[event]
    # The last event has no following start index; it runs to the end of the data.
    id2 = changeid[event+1] if event + 1 < len(changeid) else len(Z)
    steps = id2 - id1
    print('event {} length: {}'.format(event, steps))

    model.hidden = model.init_hidden() # initialize hidden state
    pred_trajectory = []
    real_trajectory = []
    prediction = None
    # Inference only: without no_grad the autograd graph would keep growing
    # through the hidden state that is carried from step to step.
    with torch.no_grad():
        for n, i in enumerate(range(id1, id2)):
            sample = Z[i]
            # Copy before editing: writing into the dataset's own array would
            # corrupt the data (and the recorded real trajectory, if the
            # dataset hands out shared arrays).
            observation = np.array(sample[0], dtype = float)
            # replace observation with predictions once past each horizon
            if n >= horizon[0]:
                observation[0:10] = prediction[0:10]
            if n >= horizon[1]:
                observation[10:20] = prediction[10:20]
            if n >= horizon[2]:
                observation[20:23] = prediction[20:23]
            # shape (seq_len = 1, batch = 1, state_dim)
            model_input = torch.tensor(observation, dtype = torch.float32).view(1, 1, -1)
            # predict
            prediction = model(model_input).detach().cpu().numpy().flatten()
            # save prediction and target
            pred_trajectory.append(prediction)
            real_trajectory.append(np.array(sample[1]))

    return pred_trajectory, real_trajectory