In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn

from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay, f1_score
from sklearn.model_selection import StratifiedKFold, TimeSeriesSplit
from tqdm import tqdm, trange

seed=42
FTWs = [720, 540, 360, 180, 60, 30, 15, 5, 3, 2, 1, 0, 0]
ftw_window = 10

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
features = np.load('./ftw_data/features.npy')
activities = np.load('./ftw_data/activities.npy')

In [4]:
tsv = TimeSeriesSplit(n_splits=3)
kfold = StratifiedKFold(n_splits=3, shuffle=True, random_state=seed)
k = 0

# Problem about the time series split since it would have overlap with the previous
for train, test in list(tsv.split(features, activities))[-1:]:
    print('X_train shape:', features[train].shape)
    print('y_train shape:', activities[train].shape)

X_train shape: (3347, 10, 56)
y_train shape: (3347, 17)


In [5]:
class LSTM(nn.Module):
    
    def __init__(self,input_dim,hidden_dim,output_dim,layer_num):
        super(LSTM,self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim,hidden_dim,layer_num,batch_first=False)
        self.fc = nn.Linear(hidden_dim,output_dim)
        self.sigmoid = nn.Sigmoid()
        # self.bn = nn.BatchNorm1d(32)
        
    def forward(self,inputs):
        # x = self.bn(inputs)
        lstm_out,(hn,cn) = self.lstm(inputs)
        out = self.fc(lstm_out)
        # print(out.shape)
        return self.sigmoid(out)

In [6]:
ensemble_activities = activities.T

### Mean

In [7]:
def preprocess_features(features, method):
    if method not in ['mean', 'mean_std', 'mean_wth_weight', 'mean_std_max_min']:
        raise ValueError('Please double check the method parameter')

    preprocessed_features = np.array([])
    if method == 'mean':
        return features.mean(axis=1)
    elif method == 'mean_std':
        preprocess_features = np.array(features.mean(axis=1))
        print(preprocess_features.shape)
        preprocess_features = np.concatenate([preprocess_features, features.std(axis=1)], axis=1)
        print(preprocess_features.shape)
        return preprocess_features
    elif method == 'mean_wth_weight':
        out = []
        for instance in features:
            out.append(np.sum([vector*((i+1)/10) for i, vector in enumerate(instance)], axis=0))
        return np.array(out)      

In [8]:
FTWs
sum = 0
for i in range(ftw_window):
    l4, l3, l2, l1 = FTWs[i:i+4]
    sum += (l3 - l2)
sum

540

In [9]:
mean_features = preprocess_features(features, 'mean_wth_weight')

input_size = mean_features.shape[1]

print(input_size)

n_hidden = 512
n_categories = 1
n_layer = 3
rnn = LSTM(input_size,n_hidden,n_categories,n_layer)

rnn.to(device)

56


LSTM(
  (lstm): LSTM(56, 512, num_layers=3)
  (fc): Linear(in_features=512, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [10]:
print(input.shape)
ensemble_activities.shape

AttributeError: 'function' object has no attribute 'shape'

In [12]:
import torch.optim as optim
import time
import math

criterion = nn.BCELoss()
learning_rate = 0.0005
optimizer = optim.Adam(rnn.parameters(),lr=learning_rate)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)

n_iters = 1000
print_every = 10
plot_every = 10
batch_size = 256

# Keep track of losses for plotting
current_loss = 0
all_losses = []
accur = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

sleep_activities = ensemble_activities[0]
train, test = list(tsv.split(mean_features, sleep_activities))[-1]

print('X_train shape:', mean_features[train].shape)
print('y_train shape:', sleep_activities[train].shape)
train_features, train_activities = mean_features[train], sleep_activities[train]
dev_features, dev_activities = mean_features[test], sleep_activities[test]

start = time.time()

for iter in trange(1, n_iters + 1):
    
    train_labels, train_tensor = torch.tensor(train_activities, dtype=torch.float32), torch.tensor(train_features, dtype=torch.float32)
    dev_labels, dev_tensor = torch.tensor(dev_activities, dtype=torch.float32), torch.tensor(dev_features, dtype=torch.float32)

    train_tensor = train_tensor.to(device)
    train_labels = train_labels.to(device).unsqueeze(1)

    dev_tensor = dev_tensor.to(device)
    dev_labels = dev_labels.to(device).unsqueeze(1)
    
    optimizer.zero_grad()
    
    output = rnn(train_tensor)
    loss = criterion(output, train_labels)
    loss.backward()
    optimizer.step()
    #scheduler.step()
    
    prediction = rnn(dev_tensor)
    pred_cpu =  prediction.reshape(-1).cpu().detach().numpy().round()
    acc = f1_score(dev_activities, pred_cpu)
    #backprop

    if iter%print_every == 0:
        accur.append(acc)
        print("epoch {}\tloss : {}\t accuracy : {}".format(iter,loss,acc))
        print(classification_report(dev_activities, pred_cpu))
    
    current_loss += loss.item()
    
    # category = LABELS[int(category_tensor[0])]

    # # Print iter number, loss, name and guess
    # if iter % print_every == 0:
    #     guess, guess_i = categoryFromOutput(output)
    #     correct = '✓' if guess == category else '✗ (%s)' % category
    #     print('%d %d%% (%s) %.4f  / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, guess, correct))
        
    # Add current loss avg to list of losses
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

X_train shape: (3347, 56)
y_train shape: (3347,)


  1%|          | 10/1000 [00:02<04:34,  3.60it/s]

epoch 10	loss : 0.32651543617248535	 accuracy : 0.6504065040650405
              precision    recall  f1-score   support

         0.0       0.94      0.92      0.93       937
         1.0       0.63      0.67      0.65       178

    accuracy                           0.88      1115
   macro avg       0.78      0.80      0.79      1115
weighted avg       0.89      0.88      0.89      1115



  2%|▏         | 20/1000 [00:05<04:32,  3.60it/s]

epoch 20	loss : 0.28317540884017944	 accuracy : 0.694300518134715
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.64      0.75      0.69       178

    accuracy                           0.89      1115
   macro avg       0.80      0.84      0.82      1115
weighted avg       0.90      0.89      0.90      1115



  3%|▎         | 30/1000 [00:08<04:29,  3.60it/s]

epoch 30	loss : 0.2537921667098999	 accuracy : 0.7296587926509187
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.68      0.78      0.73       178

    accuracy                           0.91      1115
   macro avg       0.82      0.86      0.84      1115
weighted avg       0.91      0.91      0.91      1115



  4%|▍         | 40/1000 [00:11<04:26,  3.60it/s]

epoch 40	loss : 0.23639126121997833	 accuracy : 0.7263157894736842
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.68      0.78      0.73       178

    accuracy                           0.91      1115
   macro avg       0.82      0.85      0.84      1115
weighted avg       0.91      0.91      0.91      1115



  5%|▌         | 50/1000 [00:13<04:23,  3.60it/s]

epoch 50	loss : 0.21512241661548615	 accuracy : 0.7584415584415585
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.95       937
         1.0       0.71      0.82      0.76       178

    accuracy                           0.92      1115
   macro avg       0.84      0.88      0.85      1115
weighted avg       0.92      0.92      0.92      1115



  6%|▌         | 60/1000 [00:16<04:21,  3.60it/s]

epoch 60	loss : 0.20531873404979706	 accuracy : 0.730077120822622
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.67      0.80      0.73       178

    accuracy                           0.91      1115
   macro avg       0.82      0.86      0.84      1115
weighted avg       0.91      0.91      0.91      1115



  7%|▋         | 70/1000 [00:19<04:18,  3.60it/s]

epoch 70	loss : 0.19565366208553314	 accuracy : 0.6873508353221958
              precision    recall  f1-score   support

         0.0       0.96      0.90      0.93       937
         1.0       0.60      0.81      0.69       178

    accuracy                           0.88      1115
   macro avg       0.78      0.85      0.81      1115
weighted avg       0.90      0.88      0.89      1115



  8%|▊         | 80/1000 [00:22<04:15,  3.60it/s]

epoch 80	loss : 0.19016246497631073	 accuracy : 0.7117794486215541
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.64      0.80      0.71       178

    accuracy                           0.90      1115
   macro avg       0.80      0.86      0.82      1115
weighted avg       0.91      0.90      0.90      1115



  9%|▉         | 90/1000 [00:24<04:13,  3.59it/s]

epoch 90	loss : 0.1681666076183319	 accuracy : 0.6696230598669622
              precision    recall  f1-score   support

         0.0       0.97      0.87      0.92       937
         1.0       0.55      0.85      0.67       178

    accuracy                           0.87      1115
   macro avg       0.76      0.86      0.79      1115
weighted avg       0.90      0.87      0.88      1115



 10%|█         | 100/1000 [00:27<04:10,  3.60it/s]

epoch 100	loss : 0.17986582219600677	 accuracy : 0.6650943396226415
              precision    recall  f1-score   support

         0.0       0.96      0.89      0.92       937
         1.0       0.57      0.79      0.67       178

    accuracy                           0.87      1115
   macro avg       0.77      0.84      0.79      1115
weighted avg       0.90      0.87      0.88      1115



 11%|█         | 110/1000 [00:30<04:07,  3.59it/s]

epoch 110	loss : 0.16780371963977814	 accuracy : 0.7762803234501348
              precision    recall  f1-score   support

         0.0       0.96      0.95      0.96       937
         1.0       0.75      0.81      0.78       178

    accuracy                           0.93      1115
   macro avg       0.85      0.88      0.87      1115
weighted avg       0.93      0.93      0.93      1115



 12%|█▏        | 120/1000 [00:33<04:05,  3.59it/s]

epoch 120	loss : 0.15048253536224365	 accuracy : 0.7422680412371134
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.95       937
         1.0       0.69      0.81      0.74       178

    accuracy                           0.91      1115
   macro avg       0.82      0.87      0.84      1115
weighted avg       0.92      0.91      0.91      1115



 13%|█▎        | 130/1000 [00:36<04:02,  3.58it/s]

epoch 130	loss : 0.23275244235992432	 accuracy : 0.6156862745098038
              precision    recall  f1-score   support

         0.0       0.97      0.81      0.89       937
         1.0       0.47      0.88      0.62       178

    accuracy                           0.82      1115
   macro avg       0.72      0.85      0.75      1115
weighted avg       0.89      0.82      0.84      1115



 14%|█▍        | 140/1000 [00:38<04:00,  3.58it/s]

epoch 140	loss : 0.19344428181648254	 accuracy : 0.6788511749347259
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.93       937
         1.0       0.63      0.73      0.68       178

    accuracy                           0.89      1115
   macro avg       0.79      0.83      0.81      1115
weighted avg       0.90      0.89      0.89      1115



 15%|█▌        | 150/1000 [00:41<03:57,  3.58it/s]

epoch 150	loss : 0.17428216338157654	 accuracy : 0.7611111111111112
              precision    recall  f1-score   support

         0.0       0.96      0.95      0.95       937
         1.0       0.75      0.77      0.76       178

    accuracy                           0.92      1115
   macro avg       0.85      0.86      0.86      1115
weighted avg       0.92      0.92      0.92      1115



 16%|█▌        | 160/1000 [00:44<03:54,  3.58it/s]

epoch 160	loss : 0.15807783603668213	 accuracy : 0.680952380952381
              precision    recall  f1-score   support

         0.0       0.96      0.89      0.93       937
         1.0       0.59      0.80      0.68       178

    accuracy                           0.88      1115
   macro avg       0.78      0.85      0.80      1115
weighted avg       0.90      0.88      0.89      1115



 17%|█▋        | 170/1000 [00:47<03:51,  3.58it/s]

epoch 170	loss : 0.14421698451042175	 accuracy : 0.6561797752808989
              precision    recall  f1-score   support

         0.0       0.96      0.87      0.91       937
         1.0       0.55      0.82      0.66       178

    accuracy                           0.86      1115
   macro avg       0.75      0.85      0.79      1115
weighted avg       0.90      0.86      0.87      1115



 18%|█▊        | 180/1000 [00:50<03:48,  3.58it/s]

epoch 180	loss : 0.13152320683002472	 accuracy : 0.6341463414634146
              precision    recall  f1-score   support

         0.0       0.96      0.86      0.91       937
         1.0       0.52      0.80      0.63       178

    accuracy                           0.85      1115
   macro avg       0.74      0.83      0.77      1115
weighted avg       0.89      0.85      0.86      1115



 19%|█▉        | 190/1000 [00:52<03:46,  3.58it/s]

epoch 190	loss : 0.1478080153465271	 accuracy : 0.6181015452538632
              precision    recall  f1-score   support

         0.0       0.95      0.86      0.90       937
         1.0       0.51      0.79      0.62       178

    accuracy                           0.84      1115
   macro avg       0.73      0.82      0.76      1115
weighted avg       0.88      0.84      0.86      1115



 20%|██        | 200/1000 [00:55<03:43,  3.59it/s]

epoch 200	loss : 0.13288845121860504	 accuracy : 0.6252676659528908
              precision    recall  f1-score   support

         0.0       0.96      0.85      0.90       937
         1.0       0.51      0.82      0.63       178

    accuracy                           0.84      1115
   macro avg       0.73      0.83      0.76      1115
weighted avg       0.89      0.84      0.86      1115



 21%|██        | 210/1000 [00:58<03:40,  3.59it/s]

epoch 210	loss : 0.11720936745405197	 accuracy : 0.6285714285714287
              precision    recall  f1-score   support

         0.0       0.96      0.86      0.90       937
         1.0       0.52      0.80      0.63       178

    accuracy                           0.85      1115
   macro avg       0.74      0.83      0.77      1115
weighted avg       0.89      0.85      0.86      1115



 22%|██▏       | 220/1000 [01:01<03:38,  3.58it/s]

epoch 220	loss : 0.104091115295887	 accuracy : 0.6410835214446953
              precision    recall  f1-score   support

         0.0       0.96      0.87      0.91       937
         1.0       0.54      0.80      0.64       178

    accuracy                           0.86      1115
   macro avg       0.75      0.83      0.78      1115
weighted avg       0.89      0.86      0.87      1115



 23%|██▎       | 230/1000 [01:03<03:34,  3.59it/s]

epoch 230	loss : 0.09804610908031464	 accuracy : 0.6236559139784945
              precision    recall  f1-score   support

         0.0       0.96      0.85      0.90       937
         1.0       0.51      0.81      0.62       178

    accuracy                           0.84      1115
   macro avg       0.73      0.83      0.76      1115
weighted avg       0.89      0.84      0.86      1115



 24%|██▍       | 240/1000 [01:06<03:32,  3.58it/s]

epoch 240	loss : 0.11393619328737259	 accuracy : 0.623608017817372
              precision    recall  f1-score   support

         0.0       0.95      0.86      0.91       937
         1.0       0.52      0.79      0.62       178

    accuracy                           0.85      1115
   macro avg       0.74      0.82      0.76      1115
weighted avg       0.88      0.85      0.86      1115



 25%|██▌       | 250/1000 [01:09<03:29,  3.58it/s]

epoch 250	loss : 0.09632671624422073	 accuracy : 0.6322869955156951
              precision    recall  f1-score   support

         0.0       0.96      0.86      0.91       937
         1.0       0.53      0.79      0.63       178

    accuracy                           0.85      1115
   macro avg       0.74      0.83      0.77      1115
weighted avg       0.89      0.85      0.86      1115



 26%|██▌       | 260/1000 [01:12<03:26,  3.58it/s]

epoch 260	loss : 0.08026690781116486	 accuracy : 0.6299559471365638
              precision    recall  f1-score   support

         0.0       0.96      0.86      0.91       937
         1.0       0.52      0.80      0.63       178

    accuracy                           0.85      1115
   macro avg       0.74      0.83      0.77      1115
weighted avg       0.89      0.85      0.86      1115



 27%|██▋       | 270/1000 [01:15<03:23,  3.58it/s]

epoch 270	loss : 0.06650646775960922	 accuracy : 0.6181015452538632
              precision    recall  f1-score   support

         0.0       0.95      0.86      0.90       937
         1.0       0.51      0.79      0.62       178

    accuracy                           0.84      1115
   macro avg       0.73      0.82      0.76      1115
weighted avg       0.88      0.84      0.86      1115



 28%|██▊       | 280/1000 [01:17<03:20,  3.59it/s]

epoch 280	loss : 0.22025522589683533	 accuracy : 0.6814814814814815
              precision    recall  f1-score   support

         0.0       0.95      0.91      0.93       937
         1.0       0.61      0.78      0.68       178

    accuracy                           0.88      1115
   macro avg       0.78      0.84      0.81      1115
weighted avg       0.90      0.88      0.89      1115



 29%|██▉       | 290/1000 [01:20<03:18,  3.58it/s]

epoch 290	loss : 0.16341768205165863	 accuracy : 0.773841961852861
              precision    recall  f1-score   support

         0.0       0.96      0.95      0.96       937
         1.0       0.75      0.80      0.77       178

    accuracy                           0.93      1115
   macro avg       0.86      0.87      0.86      1115
weighted avg       0.93      0.93      0.93      1115



 30%|███       | 300/1000 [01:23<03:15,  3.58it/s]

epoch 300	loss : 0.14217552542686462	 accuracy : 0.7651715039577837
              precision    recall  f1-score   support

         0.0       0.96      0.94      0.95       937
         1.0       0.72      0.81      0.77       178

    accuracy                           0.92      1115
   macro avg       0.84      0.88      0.86      1115
weighted avg       0.93      0.92      0.92      1115



 31%|███       | 310/1000 [01:26<03:12,  3.59it/s]

epoch 310	loss : 0.12300512194633484	 accuracy : 0.776595744680851
              precision    recall  f1-score   support

         0.0       0.97      0.94      0.95       937
         1.0       0.74      0.82      0.78       178

    accuracy                           0.92      1115
   macro avg       0.85      0.88      0.87      1115
weighted avg       0.93      0.92      0.93      1115



 32%|███▏      | 320/1000 [01:29<03:09,  3.58it/s]

epoch 320	loss : 0.1034957617521286	 accuracy : 0.768831168831169
              precision    recall  f1-score   support

         0.0       0.97      0.94      0.95       937
         1.0       0.71      0.83      0.77       178

    accuracy                           0.92      1115
   macro avg       0.84      0.88      0.86      1115
weighted avg       0.93      0.92      0.92      1115



 33%|███▎      | 330/1000 [01:31<03:07,  3.58it/s]

epoch 330	loss : 0.0915493369102478	 accuracy : 0.776595744680851
              precision    recall  f1-score   support

         0.0       0.97      0.94      0.95       937
         1.0       0.74      0.82      0.78       178

    accuracy                           0.92      1115
   macro avg       0.85      0.88      0.87      1115
weighted avg       0.93      0.92      0.93      1115



 34%|███▍      | 340/1000 [01:34<03:04,  3.59it/s]

epoch 340	loss : 0.08780454099178314	 accuracy : 0.7447916666666667
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.95       937
         1.0       0.69      0.80      0.74       178

    accuracy                           0.91      1115
   macro avg       0.83      0.87      0.85      1115
weighted avg       0.92      0.91      0.91      1115



 35%|███▌      | 350/1000 [01:37<03:01,  3.58it/s]

epoch 350	loss : 0.07232514023780823	 accuracy : 0.7525773195876287
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.95       937
         1.0       0.70      0.82      0.75       178

    accuracy                           0.91      1115
   macro avg       0.83      0.88      0.85      1115
weighted avg       0.92      0.91      0.92      1115



 36%|███▌      | 360/1000 [01:40<02:58,  3.58it/s]

epoch 360	loss : 0.05581659823656082	 accuracy : 0.7435897435897436
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.95       937
         1.0       0.68      0.81      0.74       178

    accuracy                           0.91      1115
   macro avg       0.82      0.87      0.84      1115
weighted avg       0.92      0.91      0.91      1115



 37%|███▋      | 370/1000 [01:43<02:55,  3.59it/s]

epoch 370	loss : 0.045885924249887466	 accuracy : 0.7218045112781954
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.65      0.81      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.86      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 38%|███▊      | 380/1000 [01:45<02:53,  3.58it/s]

epoch 380	loss : 0.0880722776055336	 accuracy : 0.7052896725440806
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.64      0.79      0.71       178

    accuracy                           0.90      1115
   macro avg       0.80      0.85      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 39%|███▉      | 390/1000 [01:48<02:50,  3.58it/s]

epoch 390	loss : 0.04845674708485603	 accuracy : 0.7254408060453399
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.66      0.81      0.73       178

    accuracy                           0.90      1115
   macro avg       0.81      0.86      0.83      1115
weighted avg       0.91      0.90      0.91      1115



 40%|████      | 400/1000 [01:51<02:47,  3.58it/s]

epoch 400	loss : 0.03651002049446106	 accuracy : 0.7229551451187336
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.68      0.77      0.72       178

    accuracy                           0.91      1115
   macro avg       0.82      0.85      0.83      1115
weighted avg       0.91      0.91      0.91      1115



 41%|████      | 410/1000 [01:54<02:44,  3.58it/s]

epoch 410	loss : 0.028506150469183922	 accuracy : 0.7202072538860103
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.67      0.78      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.91      1115



 42%|████▏     | 420/1000 [01:56<02:41,  3.59it/s]

epoch 420	loss : 0.022163378074765205	 accuracy : 0.7139107611548555
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 43%|████▎     | 430/1000 [01:59<02:39,  3.57it/s]

epoch 430	loss : 0.017104530707001686	 accuracy : 0.7086614173228346
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 44%|████▍     | 440/1000 [02:02<02:37,  3.57it/s]

epoch 440	loss : 0.013060769066214561	 accuracy : 0.7191601049868767
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.67      0.77      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.91      1115



 45%|████▌     | 450/1000 [02:05<02:34,  3.57it/s]

epoch 450	loss : 0.009796222671866417	 accuracy : 0.7206266318537858
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.67      0.78      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.91      1115



 46%|████▌     | 460/1000 [02:08<02:31,  3.57it/s]

epoch 460	loss : 0.007312969770282507	 accuracy : 0.7098445595854922
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 47%|████▋     | 470/1000 [02:10<02:28,  3.56it/s]

epoch 470	loss : 0.005554351024329662	 accuracy : 0.7131782945736435
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.66      0.78      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 48%|████▊     | 480/1000 [02:13<02:25,  3.57it/s]

epoch 480	loss : 0.00435628229752183	 accuracy : 0.7146529562982005
              precision    recall  f1-score   support

         0.0       0.96      0.92      0.94       937
         1.0       0.66      0.78      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 49%|████▉     | 490/1000 [02:16<02:22,  3.57it/s]

epoch 490	loss : 0.0035260559525340796	 accuracy : 0.7150259067357513
              precision    recall  f1-score   support

         0.0       0.96      0.93      0.94       937
         1.0       0.66      0.78      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 50%|█████     | 500/1000 [02:19<02:20,  3.57it/s]

epoch 500	loss : 0.002933961572125554	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 51%|█████     | 510/1000 [02:22<02:17,  3.56it/s]

epoch 510	loss : 0.002500047441571951	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 52%|█████▏    | 520/1000 [02:24<02:14,  3.57it/s]

epoch 520	loss : 0.002174267778173089	 accuracy : 0.7086614173228346
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 53%|█████▎    | 530/1000 [02:27<02:11,  3.57it/s]

epoch 530	loss : 0.0019234016072005033	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 54%|█████▍    | 540/1000 [02:30<02:09,  3.56it/s]

epoch 540	loss : 0.0017266771756112576	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 55%|█████▌    | 550/1000 [02:33<02:06,  3.56it/s]

epoch 550	loss : 0.001568502513691783	 accuracy : 0.7064935064935065
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 56%|█████▌    | 560/1000 [02:36<02:03,  3.56it/s]

epoch 560	loss : 0.001439801068045199	 accuracy : 0.7064935064935065
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 57%|█████▋    | 570/1000 [02:38<02:00,  3.56it/s]

epoch 570	loss : 0.0013331834925338626	 accuracy : 0.7046632124352331
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.65      0.76      0.70       178

    accuracy                           0.90      1115
   macro avg       0.80      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 58%|█████▊    | 580/1000 [02:41<01:57,  3.57it/s]

epoch 580	loss : 0.0012434987584128976	 accuracy : 0.7046632124352331
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.65      0.76      0.70       178

    accuracy                           0.90      1115
   macro avg       0.80      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 59%|█████▉    | 590/1000 [02:44<01:55,  3.56it/s]

epoch 590	loss : 0.0011674652341753244	 accuracy : 0.7083333333333334
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 60%|██████    | 600/1000 [02:47<01:52,  3.57it/s]

epoch 600	loss : 0.0011021228274330497	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 61%|██████    | 610/1000 [02:50<01:49,  3.56it/s]

epoch 610	loss : 0.001045017852447927	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 62%|██████▏   | 620/1000 [02:52<01:46,  3.56it/s]

epoch 620	loss : 0.000995131442323327	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 63%|██████▎   | 630/1000 [02:55<01:44,  3.54it/s]

epoch 630	loss : 0.000951204274315387	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 64%|██████▍   | 640/1000 [02:58<01:40,  3.57it/s]

epoch 640	loss : 0.0009118199814110994	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 65%|██████▌   | 650/1000 [03:01<01:38,  3.57it/s]

epoch 650	loss : 0.0008767804247327149	 accuracy : 0.7083333333333334
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 66%|██████▌   | 660/1000 [03:04<01:35,  3.57it/s]

epoch 660	loss : 0.0008452608017250896	 accuracy : 0.7083333333333334
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 67%|██████▋   | 670/1000 [03:07<01:32,  3.56it/s]

epoch 670	loss : 0.0008167146006599069	 accuracy : 0.7064935064935065
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 68%|██████▊   | 680/1000 [03:09<01:29,  3.56it/s]

epoch 680	loss : 0.0007907137041911483	 accuracy : 0.7064935064935065
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 69%|██████▉   | 690/1000 [03:12<01:26,  3.56it/s]

epoch 690	loss : 0.0007671211496926844	 accuracy : 0.7064935064935065
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 70%|███████   | 700/1000 [03:15<01:24,  3.57it/s]

epoch 700	loss : 0.0007453823927789927	 accuracy : 0.7083333333333334
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 71%|███████   | 710/1000 [03:18<01:21,  3.56it/s]

epoch 710	loss : 0.0007254990050569177	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 72%|███████▏  | 720/1000 [03:21<01:18,  3.57it/s]

epoch 720	loss : 0.0007071365835145116	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 73%|███████▎  | 730/1000 [03:23<01:15,  3.56it/s]

epoch 730	loss : 0.0006901726592332125	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 74%|███████▍  | 740/1000 [03:26<01:12,  3.57it/s]

epoch 740	loss : 0.0006743734120391309	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 75%|███████▌  | 750/1000 [03:29<01:10,  3.57it/s]

epoch 750	loss : 0.0006597572937607765	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 76%|███████▌  | 760/1000 [03:32<01:07,  3.57it/s]

epoch 760	loss : 0.0006461243028752506	 accuracy : 0.7083333333333334
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 77%|███████▋  | 770/1000 [03:35<01:04,  3.57it/s]

epoch 770	loss : 0.0006333880010060966	 accuracy : 0.7049608355091384
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.70       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 78%|███████▊  | 780/1000 [03:37<01:01,  3.57it/s]

epoch 780	loss : 0.0006214801105670631	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 79%|███████▉  | 790/1000 [03:40<00:58,  3.57it/s]

epoch 790	loss : 0.0006102576735429466	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 80%|████████  | 800/1000 [03:43<00:56,  3.57it/s]

epoch 800	loss : 0.0005997555563226342	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 81%|████████  | 810/1000 [03:46<00:53,  3.57it/s]

epoch 810	loss : 0.0005898637464269996	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 82%|████████▏ | 820/1000 [03:49<00:50,  3.57it/s]

epoch 820	loss : 0.0005805211258120835	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 83%|████████▎ | 830/1000 [03:51<00:47,  3.57it/s]

epoch 830	loss : 0.0005716988816857338	 accuracy : 0.7068062827225131
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.84      0.82      1115
weighted avg       0.91      0.90      0.90      1115



 84%|████████▍ | 840/1000 [03:54<00:44,  3.57it/s]

epoch 840	loss : 0.000563319306820631	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 85%|████████▌ | 850/1000 [03:57<00:42,  3.56it/s]

epoch 850	loss : 0.0005553694209083915	 accuracy : 0.710182767624021
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.66      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 86%|████████▌ | 860/1000 [04:00<00:39,  3.57it/s]

epoch 860	loss : 0.0005478623206727207	 accuracy : 0.7120418848167538
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 87%|████████▋ | 870/1000 [04:03<00:36,  3.56it/s]

epoch 870	loss : 0.000540715700481087	 accuracy : 0.7120418848167538
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 88%|████████▊ | 880/1000 [04:05<00:33,  3.56it/s]

epoch 880	loss : 0.0005339164636097848	 accuracy : 0.7120418848167538
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 89%|████████▉ | 890/1000 [04:08<00:30,  3.57it/s]

epoch 890	loss : 0.0005274167633615434	 accuracy : 0.7120418848167538
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 90%|█████████ | 900/1000 [04:11<00:28,  3.57it/s]

epoch 900	loss : 0.00052122981287539	 accuracy : 0.7139107611548555
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 91%|█████████ | 910/1000 [04:14<00:25,  3.57it/s]

epoch 910	loss : 0.0005153032834641635	 accuracy : 0.7139107611548555
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 92%|█████████▏| 920/1000 [04:17<00:22,  3.56it/s]

epoch 920	loss : 0.0005096285021863878	 accuracy : 0.7139107611548555
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.76      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 93%|█████████▎| 930/1000 [04:19<00:19,  3.57it/s]

epoch 930	loss : 0.0005041666445322335	 accuracy : 0.7172774869109947
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.72       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.91      1115



 94%|█████████▍| 940/1000 [04:22<00:16,  3.56it/s]

epoch 940	loss : 0.000498988782055676	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 95%|█████████▌| 950/1000 [04:25<00:14,  3.57it/s]

epoch 950	loss : 0.0004939599893987179	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 96%|█████████▌| 960/1000 [04:28<00:11,  3.57it/s]

epoch 960	loss : 0.0004891288699582219	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 97%|█████████▋| 970/1000 [04:31<00:08,  3.57it/s]

epoch 970	loss : 0.0004844973445869982	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 98%|█████████▊| 980/1000 [04:33<00:05,  3.57it/s]

epoch 980	loss : 0.0004800430906470865	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



 99%|█████████▉| 990/1000 [04:36<00:02,  3.53it/s]

epoch 990	loss : 0.00047575627104379237	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115



100%|██████████| 1000/1000 [04:39<00:00,  3.58it/s]

epoch 1000	loss : 0.00047154241474345326	 accuracy : 0.7135416666666665
              precision    recall  f1-score   support

         0.0       0.95      0.93      0.94       937
         1.0       0.67      0.77      0.71       178

    accuracy                           0.90      1115
   macro avg       0.81      0.85      0.83      1115
weighted avg       0.91      0.90      0.90      1115






In [None]:
all_losses

In [None]:
# Score accuracy if at least one is true

category_tensor[0].dtype

### TODO
- Apply the model to other activities and plot heatmap
- Check any paper talked about approach for weight for window size
- Multi-label, consider the correlation between labels
- Consider more data from hh1xx dataset

### Mean and std

### Mean, std, max, min and crossing rate

### Mean with weight for each time window