# Action DSARF

In [None]:
from dsarf import DSARF, compute_NRMSE, ELBO_Loss
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
import scipy.signal
# Seed both the PyTorch and the NumPy global random generators with the same
# fixed value (10) for reproducibility.
torch.manual_seed(10)
np.random.seed(10)

In [None]:
class classifier(nn.Module):
    """Per-time-step classifier: 2-layer bidirectional LSTM followed by an MLP head.

    Args:
        d_in: feature size of each input time step.
        d_h: LSTM hidden size (per direction) and width of the hidden linear layer.
        A: number of output classes (size of the last dimension of the output).
    """

    def __init__(self, d_in, d_h, A):
        super().__init__()
        # Sub-modules are created in the order rnn, fc1, fc2 so that parameter
        # initialisation under a fixed seed and the state_dict keys stay the same.
        self.rnn = nn.LSTM(
            input_size=d_in,
            hidden_size=d_h,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.fc1 = nn.Linear(in_features=2 * d_h, out_features=d_h)
        self.fc2 = nn.Linear(in_features=d_h, out_features=A)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class scores for every time step.

        Args:
            x: tensor laid out as N x T x D (batch first).

        Returns:
            Tensor of shape N x T x A.
        """
        # Both LSTM directions are concatenated: N x T x (2 * d_h).
        seq_features, _last_state = self.rnn(x)
        hidden = self.relu(self.fc1(seq_features))  # N x T x d_h
        scores = self.fc2(hidden)  # N x T x A
        return scores

class action_dsarf(nn.Module):
    """Joint model: one DSARF per joint group plus a per-frame action classifier.

    Each entry of ``D``, ``factor_dim``, ``L`` and ``S`` configures one DSARF;
    the classifier consumes the concatenation of a per-group DSARF output
    (selected by ``S2A``) and predicts one of ``A`` classes per time step.

    Args:
        D, factor_dim, L, S: per-group sequences forwarded to ``DSARF`` as
            ``D``, ``factor_dim``, ``L`` and ``S`` respectively.
        A: number of action classes.
        rc: forwarded to ``DSARF`` as ``recurrent``.
        VI: dict forwarded to ``DSARF`` as ``VI``; ``VI['rnn_dim']`` is also
            used as the classifier hidden size. ``None`` selects
            ``{'rnn_dim': None, 'combine': False, 'S': False}``.
            NOTE(review): with ``rnn_dim`` left as ``None`` the classifier's
            LSTM presumably cannot be built - callers should pass ``rnn_dim``.
        fc: forwarded to ``DSARF`` as ``factorization``.
        bs: mini-batch size used by ``fit``.
        lr: Adam learning rate used by ``fit``.
        S2A: if true the classifier input size is ``sum(S)`` and ``fit`` feeds
            it ``out[13]``; otherwise ``sum(factor_dim)`` and ``out[9]``.
    """

    def __init__(self, D, factor_dim, L, S, A, rc = True,
                 VI = None,
                 fc = False, bs=100, lr = 1e-2, S2A=True):
        super().__init__()
        # A fresh dict per instance instead of a mutable default argument that
        # would be shared by every instance created without VI.
        if VI is None:
            VI = {'rnn_dim': None, 'combine': False, 'S': False}
        self.dsarf = nn.ModuleList([DSARF(D=d, factor_dim=k, L=l, S=s, VI=VI, factorization=fc, recurrent = rc)
                                    for d, k, l, s in zip(D, factor_dim, L, S)])
        cls_in = sum(S) if S2A else sum(factor_dim)
        self.cls = classifier(cls_in, VI['rnn_dim'], A)
        self.bs, self.lr, self.L, self.VI = bs, lr, L, VI
        # self.grad decides whether fit() leaves this module's parameters
        # trainable; infer() temporarily switches it off.
        self.grad = True
        self.S2A, self.A = S2A, A

    def fit(self, data, labels, epoch=500):
        """Optimise the ELBO plus a cross-entropy term with Adam for ``epoch`` passes.

        Args:
            data: list with one entry per DSARF (joint group); each entry is a
                sequence of per-sample arrays. Samples are normalised with the
                ``mean`` / ``std`` attributes of the matching DSARF.
            labels: NumPy integer array indexed as ``labels[batch]`` and
                ``labels[batch, 0]`` (the notebook passes an N x T array with
                the sample label repeated along T).
            epoch: number of passes over the data.

        Returns:
            The ``action_dsarf.DSARF`` wrapper module that was optimised.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('Using device:', device)
        # Normalisation statistics are whatever m.mean / m.std currently are on
        # each DSARF. To estimate them from the training data instead, set them
        # here while self.grad is true, e.g.
        #   data_cat = np.concatenate(d, axis = 0)
        #   m.mean = data_cat[~np.isnan(data_cat)].mean()
        #   m.std = data_cat[~np.isnan(data_cat)].std()
        # (The previous code built data_cat on every call without using it.)
        for p in self.parameters(): #turn gradients on/off
            p.requires_grad = self.grad

        data_in = data.copy()  # un-normalised data, used for report_stats below
        data = [np.array([(di - m.mean)/m.std for di in d]) for d, m in zip(data, self.dsarf)]
        n_data = [len(d) for d in data]
        lens = [[len(di) for di in d] for d in data]

        # The loader iterates over sample indices; batches are gathered by hand.
        idxs = [torch.LongTensor([i]) for i in range(len(data[0]))]

        model = self.DSARF(self, n_data, lens).to(device)
        optim_dsarf = optim.Adam(model.parameters(), lr = self.lr)
        CELoss = nn.CrossEntropyLoss(reduction = 'sum')

        params = {'batch_size': self.bs,
                  'shuffle': True,
                  'num_workers': 0}

        train_loader = DataLoader(idxs, **params)

        for i in tqdm(range(epoch)):
            loss_value = 0.0
            acc = np.zeros((self.A, self.A))  # confusion matrix of this epoch
            for bidxs in train_loader:
                bidxs = bidxs.reshape(-1)
                # Index NumPy arrays with a NumPy index array. A one-element
                # torch tensor would be treated by NumPy as a scalar index and
                # silently drop the batch axis for mini-batches of size 1.
                np_idx = bidxs.numpy()
                mb = [torch.FloatTensor(d[np_idx]).to(device) for d in data]
                outs = model(mb, [bidxs]*len(data))
                mask = [~torch.isnan(d) for d in mb]  # NaN entries are excluded from the loss
                loss_dsarf = [ELBO_Loss(d[m], out[0][m], out[1], out[2],
                                      out[3][:, max(l):], out[4][:, max(l):],
                                      out[5], out[6], out[7], out[8],
                                      out[9][:, max(l):], out[10][:, max(l):],
                                      out[11][:,:, max(l):], out[12][:,:, max(l):],
                                      0.001) for d, out, m, l in zip(mb, outs, mask, self.L)]
                loss_dsarf = sum(loss_dsarf)
                if self.S2A:
                    q_in = torch.cat([out[13] for out in outs], dim=-1) #q_s_t_log
                else:
                    q_in = torch.cat([out[9] for out in outs], dim=-1) #q_z_mus
                target = model.cls(q_in)

                batch_labels = torch.LongTensor(labels[np_idx]).reshape(-1).to(device)
                loss_dsarf = loss_dsarf + CELoss(target.reshape(-1, self.A), batch_labels)

                optim_dsarf.zero_grad()
                loss_dsarf.backward()
                optim_dsarf.step()
                loss_value += loss_dsarf.item()

                # Per-frame predictions are median-filtered along time (kernel 5)
                # and the most frequent class per sequence is the prediction.
                target = target.argmax(-1).type(torch.FloatTensor).detach().cpu().numpy()
                target = torch.LongTensor(scipy.signal.medfilt(target,[1,5]))
                acc += confusion_matrix(labels[np_idx, 0],
                                        target.mode(dim=-1)[0].numpy(),
                                        labels = np.arange(len(acc)))

            # Reconstruction statistics are refreshed every 50 epochs and on the
            # last one; in between the most recent values are re-printed.
            if (i % 50 == 0) or (i == epoch - 1):
                stats = [m.report_stats(d) for m, d in zip(model.dsarf, data_in)]
                NRMSE = {'NRMSE_recv': sum([e['NRMSE_recv'] for e in stats])/len(stats),
                         'NRMSE_pred': sum([e['NRMSE_pred'] for e in stats])/len(stats)}
                epch = i + 1

            print('ELBO_Loss: %0.4f, Accuracy: %0.2f, Epoch %d: {NRMSE_recv : %0.2f, NRMSE_pred : %0.2f}'
                  % (loss_value / len(train_loader.dataset),
                     acc.trace()/len(train_loader.dataset)*100,
                     epch, NRMSE['NRMSE_recv'], NRMSE['NRMSE_pred']),
                  end="\r", flush=True)
        return model

    def infer(self, data, labels, epoch = 1):
        """Run ``fit`` with this module's own parameters frozen.

        ``self.grad`` is switched off for the duration of the call and is
        restored to ``True`` afterwards, even if ``fit`` raises.

        Returns:
            The model returned by ``fit``.
        """
        self.grad = False
        try:
            return self.fit(data, labels, epoch)
        finally:
            self.grad = True

    class DSARF(nn.Module):
        """Wrapper bundling one ``DSARF_`` module per group with the shared classifier."""

        def __init__(self, dsarf, n_data, lens):
            super().__init__()
            # m.DSARF_ is provided by the dsarf package; it is built per group
            # from the parent DSARF, the number of samples and their lengths.
            self.dsarf = nn.ModuleList([m.DSARF_(m, n, l) for m,n,l in zip(dsarf.dsarf, n_data, lens)])
            self.cls = dsarf.cls

        def forward(self, mb, mbi):
            """Apply each group's module to its mini-batch and batch indices.

            Returns:
                List with one output per group, in group order.
            """
            return [m.forward(i, j) for i,j, m in zip(mb, mbi, self.dsarf)]
                      

In [None]:
# Driver cell: load the skeleton data and labels, build the per-group inputs,
# then train on the training split and run inference on the validation split.
path = '../data/train_data.npy'
# Directory prefix shared by all dataset files (path without the file name).
data_dir = path[:-len('train_data.npy')]

# Move the last axis to the front and keep only its first entry, then flatten
# to (samples, T, D, 3).
# NOTE(review): presumably the stored layout is (N, 3, T, joints, bodies) -
# confirm against the dataset.
data = np.load(path).transpose(4, 0, 2, 3, 1)[0]
T = data.shape[1]
D = data.shape[2]
data = data.reshape(-1, T, D, 3)
data_val = np.load(data_dir+'val_data.npy').transpose(4, 0, 2, 3, 1)[0]
data_val = data_val.reshape(-1,T,D,3)

import pickle
# NOTE: pickle.load executes code embedded in the file; only load trusted files.
# The label is element [1] of each pickle and is repeated along time so that
# every frame carries its sequence's label.
with open(data_dir+'train_label.pkl', 'rb') as f:
    labels = pickle.load(f)[1]
labels = np.tile(np.array(labels).reshape(-1,1), (1, T))
with open(data_dir+'val_label.pkl', 'rb') as f:
    labels_val = pickle.load(f)[1]
labels_val = np.tile(np.array(labels_val).reshape(-1,1), (1, T))

# A single group containing joint indices 0..24; the commented alternative
# splits the joints into five groups.
groups = [np.arange(25)]#[[13,14,15], [17,18,19], [0,12,16,1,4,20,8,2,3],
          #[5,6,7,21,22],[9,10,11,23,24]] #group of joints LF, RF, T, LH, RH
# For each group, every sample becomes a (T, len(group)*3) matrix.
data_train = [[d[:,g].reshape(-1,len(g)*3) for d in data] for g in groups]
data_test = [[d[:,g].reshape(-1,len(g)*3) for d in data_val] for g in groups]

# One entry per group for each DSARF setting.
D = [d[0].shape[-1] for d in data_train]
L = [[1,2]]*len(groups)
K = [15]#[3,3,3,3,3] #[5]*len(groups)
S = [20]#[4]*len(groups)
VI = {'rnn_dim': K[0], 'combine': False, 'S': True}
dsarf = action_dsarf(D, K, L, S, A=labels.max()+1, VI=VI, bs=3000, S2A=True)
model_train = dsarf.fit(data_train, labels, 300)
model_test = dsarf.infer(data_test, labels_val, 5)