In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.nn import Parameter
from torch.autograd import Variable
from torch.sparse import FloatTensor as STensor
from torch.cuda.sparse import FloatTensor as CudaSTensor
from torch.utils.data import Dataset

from datareader import ShotsDataset

In [11]:
# Training settings
def options():
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                      help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                      help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
                      help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                      help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                      help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
                      help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
                      help='random seed (default: 1)')
  parser.add_argument('--vis-scalar-freq', type=int, default=10, metavar='N',
                      help='how many batches to wait before logging training status')
  args = parser.parse_args(["--lr","1e-3", "--no-cuda"])
  args.cuda = not args.no_cuda and torch.cuda.is_available()
  return args

In [12]:
if __name__ == "__main__":
    args = options()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

In [13]:
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    ShotsDataset("/cs/ml/datasets/bball/v1/bball_tracking",res_bh=10, res_def=2),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    ShotsDataset("/cs/ml/datasets/bball/v1/bball_tracking", res_bh=10, res_def=2),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [14]:
def idx_to_one_hot(idx, batch_size, feat_dim, cuda=False):
    T = torch.cuda.LongTensor if cuda else torch.LongTensor
    ST = CudaSTensor if cuda else STensor

    batch_idx = T([i for i in range(batch_size)])
    feat_idx = idx.view(batch_size,)

    my_stack = torch.stack([batch_idx, feat_idx]) # indices must be nDim * nEntries
    ones = torch.ones(batch_size)
    if cuda:
        ones = ones.type(torch.cuda.FloatTensor)
    y = ST(my_stack, ones, torch.Size([batch_size, feat_dim])).to_dense()

    return y

def idx_to_multi_hot(idx, batch_size, feat_dim, cuda=False):
    y_ = idx_to_one_hot(idx, batch_size, feat_dim, cuda=cuda)
    y = torch.sum(y_, 1)
    return y

In [44]:
class TensorModel(nn.Module):
    def __init__(self, args, dims, test=False):
        super(TensorModel, self).__init__()
        self.args = args

        self.n_classes = 2

        self.dK = dims["K"]

        # Scaling factor = 10
        self.dA,self.dB,self.dC = dims["A"], dims["B"], dims["C"]

        T = torch.cuda if self.args.cuda else torch

        self.f_bh = T.FloatTensor(args.batch_size, self.dB)
        self.f_def = T.FloatTensor(args.batch_size, self.dC)

        self._F = torch.FloatTensor(self.dA,self.dB,self.dC).zero_()
        self._F = Parameter(self._F, requires_grad=True)

        self._A = torch.FloatTensor(self.dA,self.dK).zero_()
        self._A = Parameter(self._A, requires_grad=True)
        self._B = torch.FloatTensor(self.dB,self.dK).zero_()
        self._B = Parameter(self._B, requires_grad=True)
        self._C = torch.FloatTensor(self.dC,self.dK).zero_()
        self._C = Parameter(self._C, requires_grad=True)

        self.init_random()

    def init_random(self):
        self._F.data.normal_(std=0.1)
        self._A.data.normal_(std=0.1)
        self._B.data.normal_(std=0.1)
        self._C.data.normal_(std=0.1)
        
        print(self._F.data)

    def forward(self, x):
        self.f_bh.zero_()
        idx = torch.unsqueeze(x[:,0].data, 1) % self.dB
        self.f_bh.scatter_(1, idx, 1)

        self.f_def.zero_()
        for i in range(5):
          idx = torch.unsqueeze(x[:,i].data, 1) % self.dC
          self.f_def.scatter_(1, idx, 1)

        v_f_bh = Variable(self.f_bh, requires_grad=False)
        v_f_def = Variable(self.f_def, requires_grad=False)

        # f_bh [B x dB] -> [B x 1 x 1 x dB]
        v_f_bh = torch.unsqueeze(v_f_bh, 1)
        v_f_bh = torch.unsqueeze(v_f_bh, 2)

        # f_bh * F --> [B x dA x 1 x dC]
        x = torch.matmul(v_f_bh, self._F)
        x = torch.squeeze(x)

        # (f_bh * F) * f_def *
        # [B x dA x dC] * [B x dC x 1] = [B x dA x 1]
        v_f_def = torch.unsqueeze(v_f_def, 2)

        x = torch.bmm(x, v_f_def)
        x = torch.squeeze(x)

        # output is [B x dA]
        return F.sigmoid(x)

In [45]:
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):        
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        
        idA = target[:,0]
        output = model(data)
        
        conj_output = 1 - output
        output = torch.cat([output, conj_output], 1)
                            
        label_01 = target[:,2]
        loss = F.nll_loss(output, label_01)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

In [46]:
def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        
        idA = target[:,0]
        label_01 = target[:,2]
        output = torch.gather(output, 1, idA)
        
        conj_output = 1 - output
        output = torch.cat([output, conj_output], 1)
        
        test_loss += F.nll_loss(output, label_01, size_average=False).data[0] # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [47]:
if __name__ == "__main__":
    dims = {"A":439,"B":20,"C":36+1,"K":2}
    model = TensorModel(args, dims)
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()


( 0 ,.,.) = 
 -9.4955e-02 -1.6301e-01  6.6675e-02  ...  -1.5309e-01  3.6662e-02  5.8938e-02
 -2.0559e-03  1.1344e-01 -1.0513e-02  ...   3.7522e-02 -9.5171e-02 -2.1418e-01
 -1.4244e-01 -1.4692e-01  1.1340e-01  ...  -4.7083e-02 -6.2154e-03  5.3985e-02
                 ...                   ⋱                   ...                
  6.3456e-02 -1.0137e-02 -2.1284e-01  ...   8.4366e-02  1.0488e-01  9.2674e-02
 -3.5145e-02  1.5803e-02 -6.0672e-02  ...   9.9506e-02 -1.1886e-01  9.8183e-02
  1.6806e-01 -1.1440e-01 -9.0875e-02  ...  -7.8800e-02  7.5088e-02 -5.4279e-02

( 1 ,.,.) = 
  7.0676e-02  1.7670e-01 -1.3909e-01  ...  -1.4632e-02  7.4488e-02 -2.1645e-02
 -3.5472e-03  4.9908e-03 -9.8495e-02  ...  -7.3550e-02  1.0628e-01 -2.2763e-03
 -4.6616e-02 -1.4708e-01  1.4968e-02  ...  -3.0381e-02  2.1769e-01 -1.1509e-02
                 ...                   ⋱                   ...                
 -2.6585e-02  1.0834e-02  7.4253e-02  ...  -1.2562e-01  1.5431e-02  9.3649e-02
 -1.7755e-01 -8.2570e-04

KeyboardInterrupt: 