<a href="https://colab.research.google.com/github/shitkov/movements/blob/main/neck_static.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install mediapipe

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision import models

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import cv2
from google.colab.patches import cv2_imshow
import mediapipe as mp

from os import listdir
from os.path import isfile, join

In [4]:
path_0 = [
    '/content/drive/MyDrive/ml/neck/static_correct_1.MOV'
]

In [5]:
path_0_90 = ['/content/drive/MyDrive/ml/neck/static_me.mp4']

In [6]:
path_1 = [
    '/content/drive/MyDrive/ml/neck/static_correct_2.mp4',
    '/content/drive/MyDrive/ml/neck/static_correct_3.MOV',
    '/content/drive/MyDrive/ml/neck/static_head_goes_forward_1.mp4',
    '/content/drive/MyDrive/ml/neck/static_head_goes_forward_2.mp4',
    '/content/drive/MyDrive/ml/neck/static_nearest_shoulder_rises_1.mp4',
    '/content/drive/MyDrive/ml/neck/static_nearest_shoulder_rises_2.mp4',
    '/content/drive/MyDrive/ml/neck/static_shoulder_rises_1.mp4',
    '/content/drive/MyDrive/ml/neck/static_shoulder_rises_2.mp4'

]

In [7]:
key_points = [0, 7, 8, 11, 12, 23, 24, 25, 26]

In [8]:
def get_points(path, key_points):
    vidcap = cv2.VideoCapture(path)
    points_list = []
    for _ in tqdm(range(int(cv2.VideoCapture(path).get(cv2.CAP_PROP_FRAME_COUNT)))):
        _, image = vidcap.read()
        try:
            with mp.solutions.pose.Pose(static_image_mode=False, min_detection_confidence=0.3, model_complexity=1) as pose:
                results = pose.process(image)
            points = []
            for i, point in enumerate(results.pose_world_landmarks.landmark):
                if i in key_points:
                    points.append([point.x, point.y, point.z])
            points_list.append(np.array(points).reshape(-1))
        except:
            # points_list.append(np.zeros(len(key_points) * 3))
            pass
    return np.array(points_list)

In [9]:
def get_points_rotate(path, key_points):
    vidcap = cv2.VideoCapture(path)
    points_list = []
    for _ in tqdm(range(int(cv2.VideoCapture(path).get(cv2.CAP_PROP_FRAME_COUNT)))):
        _, image = vidcap.read()
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        try:
            with mp.solutions.pose.Pose(static_image_mode=False, min_detection_confidence=0.3, model_complexity=1) as pose:
                results = pose.process(image)
            points = []
            for i, point in enumerate(results.pose_world_landmarks.landmark):
                if i in key_points:
                    points.append([point.x, point.y, point.z])
            points_list.append(np.array(points).reshape(-1))
        except:
            # points_list.append(np.zeros(len(key_points) * 3))
            pass
    return np.array(points_list)

In [10]:
points_0 = get_points(path_0[0], key_points)

100%|██████████| 505/505 [01:46<00:00,  4.73it/s]


In [11]:
points_0_90 = get_points_rotate(path_0_90[0], key_points)

100%|██████████| 787/787 [02:42<00:00,  4.84it/s]


In [12]:
points_0 = points_0.tolist()

In [13]:
points_0_90 = points_0_90.tolist()

In [14]:
points_0 += points_0_90

In [15]:
points_1_list = []

for path in path_1:
    points_1_list.append(get_points(path, key_points))

100%|██████████| 406/406 [01:18<00:00,  5.20it/s]
100%|██████████| 255/255 [00:49<00:00,  5.12it/s]
100%|██████████| 341/341 [01:06<00:00,  5.12it/s]
100%|██████████| 371/371 [01:12<00:00,  5.11it/s]
100%|██████████| 368/368 [01:11<00:00,  5.12it/s]
100%|██████████| 341/341 [01:06<00:00,  5.12it/s]
100%|██████████| 361/361 [01:12<00:00,  4.98it/s]
100%|██████████| 338/338 [01:06<00:00,  5.10it/s]


In [16]:
points_1 = []
for arr in points_1_list:
    points_1 += arr.tolist()

In [17]:
import pickle

with open('/content/drive/MyDrive/ml/neck/' + 'points_0.p', 'wb') as f:
    pickle.dump(points_0, f)

with open('/content/drive/MyDrive/ml/neck/' + 'points_1.p', 'wb') as f:
    pickle.dump(points_1, f)

In [18]:
data = pd.DataFrame()
data['points'] = points_0 + points_1

data['label'] = [0] * len(points_0) + [1] * len(points_1)

In [19]:
data = data.sample(frac=1)

In [20]:
data = data.reset_index(drop = True)

In [21]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(data, stratify=list(data['label']), test_size = 0.2)

In [22]:
train.to_csv('/content/drive/MyDrive/ml/neck/static_train.csv', index=False)
test.to_csv('/content/drive/MyDrive/ml/neck/static_test.csv', index=False)

In [23]:
class CustomDataset(Dataset):
    
    def __init__(self, points, labels):
        self.points = points
        self.labels = labels
        
    def __getitem__(self, index):
        return self.points[index], self.labels[index]
        
    def __len__ (self):
        return len(self.points)

In [24]:
train_dataset = CustomDataset(torch.FloatTensor(list(train['points'])), torch.FloatTensor(list(train['label'])))

In [87]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [88]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [89]:
class SimleNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimleNetwork,self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim * 2)
        self.linear2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear3(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.out(x)
        return x

In [90]:
input_dim = len(list(train['points'])[0])
hidden_dim = 128
output_dim = 1

In [91]:
loss_fn = nn.BCEWithLogitsLoss()

In [92]:
model = SimleNetwork(input_dim, hidden_dim, output_dim)

In [93]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)

SimleNetwork(
  (linear1): Linear(in_features=27, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
)

In [94]:
model.to(device)

SimleNetwork(
  (linear1): Linear(in_features=27, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
)

In [95]:
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

In [96]:
def eval(model, x, y):
    model.eval()
    with torch.no_grad():
        X = torch.FloatTensor(x)
        y = torch.FloatTensor(y)
        X = X.to(device)
        y_test_pred = model(X)
        y_test_pred = torch.sigmoid(y_test_pred)
        y_test_pred = torch.round(y_test_pred)
        y_test_pred = y_test_pred.reshape(1, -1).cpu().detach().numpy()[0]

    return y_test_pred

In [97]:
test_points = torch.FloatTensor(np.array(list(test['points'])))
test_labels = torch.FloatTensor(np.array(list(test['label'])))

In [98]:
for epoch in range(100):
    model.train()
    for data in tqdm(train_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        try:
            outputs = torch.sigmoid(model(inputs))
            loss = loss_fn(outputs, labels.reshape(-1, 1))
            loss.backward()
            optimizer.step()
        except:
            pass

    predictions = eval(model, test_points, test_labels)
    print(epoch, f1_score(test_labels.tolist(), predictions.tolist(), average='macro'))

100%|██████████| 102/102 [00:00<00:00, 455.86it/s]


0 0.8533034538958191


100%|██████████| 102/102 [00:00<00:00, 446.21it/s]


1 0.8595754661208319


100%|██████████| 102/102 [00:00<00:00, 472.53it/s]


2 0.8621698171294871


100%|██████████| 102/102 [00:00<00:00, 472.11it/s]


3 0.8666743744640673


100%|██████████| 102/102 [00:00<00:00, 479.48it/s]


4 0.8660774410774411


100%|██████████| 102/102 [00:00<00:00, 462.85it/s]


5 0.870189656596864


100%|██████████| 102/102 [00:00<00:00, 424.24it/s]


6 0.870189656596864


100%|██████████| 102/102 [00:00<00:00, 435.53it/s]


7 0.870189656596864


100%|██████████| 102/102 [00:00<00:00, 407.89it/s]


8 0.8647135349506273


100%|██████████| 102/102 [00:00<00:00, 424.14it/s]


9 0.870189656596864


100%|██████████| 102/102 [00:00<00:00, 436.13it/s]


10 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 431.45it/s]


11 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 433.76it/s]


12 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 414.85it/s]


13 0.8664586083492938


100%|██████████| 102/102 [00:00<00:00, 416.77it/s]


14 0.870189656596864


100%|██████████| 102/102 [00:00<00:00, 425.42it/s]


15 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 433.26it/s]


16 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 414.21it/s]


17 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 422.20it/s]


18 0.874400662913964


100%|██████████| 102/102 [00:00<00:00, 427.70it/s]


19 0.8715673295979312


100%|██████████| 102/102 [00:00<00:00, 421.44it/s]


20 0.8754208754208754


100%|██████████| 102/102 [00:00<00:00, 415.50it/s]


21 0.8781868776129129


100%|██████████| 102/102 [00:00<00:00, 432.02it/s]


22 0.8754208754208754


100%|██████████| 102/102 [00:00<00:00, 427.95it/s]


23 0.8768015847623718


100%|██████████| 102/102 [00:00<00:00, 424.46it/s]


24 0.8788569533953623


100%|██████████| 102/102 [00:00<00:00, 413.07it/s]


25 0.8778462678203387


100%|██████████| 102/102 [00:00<00:00, 420.82it/s]


26 0.8802263723825694


100%|██████████| 102/102 [00:00<00:00, 407.56it/s]


27 0.8806030993132745


100%|██████████| 102/102 [00:00<00:00, 420.20it/s]


28 0.883616278369755


100%|██████████| 102/102 [00:00<00:00, 411.58it/s]


29 0.8866790078863404


100%|██████████| 102/102 [00:00<00:00, 403.90it/s]


30 0.890040450584995


100%|██████████| 102/102 [00:00<00:00, 422.57it/s]


31 0.8880576911121703


100%|██████████| 102/102 [00:00<00:00, 427.28it/s]


32 0.9051729398385711


100%|██████████| 102/102 [00:00<00:00, 423.18it/s]


33 0.9044264412351767


100%|██████████| 102/102 [00:00<00:00, 427.49it/s]


34 0.9311494114280674


100%|██████████| 102/102 [00:00<00:00, 422.51it/s]


35 0.9442678716856573


100%|██████████| 102/102 [00:00<00:00, 409.27it/s]


36 0.9433951162106502


100%|██████████| 102/102 [00:00<00:00, 409.85it/s]


37 0.938807157405222


100%|██████████| 102/102 [00:00<00:00, 396.35it/s]


38 0.9518661518661518


100%|██████████| 102/102 [00:00<00:00, 398.63it/s]


39 0.9554972390793286


100%|██████████| 102/102 [00:00<00:00, 369.16it/s]


40 0.9567640734363086


100%|██████████| 102/102 [00:00<00:00, 376.93it/s]


41 0.955908540012052


100%|██████████| 102/102 [00:00<00:00, 395.71it/s]


42 0.9747727272727272


100%|██████████| 102/102 [00:00<00:00, 393.92it/s]


43 0.984537408798515


100%|██████████| 102/102 [00:00<00:00, 383.82it/s]


44 0.9735513163632632


100%|██████████| 102/102 [00:00<00:00, 396.25it/s]


45 0.9804172251434855


100%|██████████| 102/102 [00:00<00:00, 384.36it/s]


46 0.980378787878788


100%|██████████| 102/102 [00:00<00:00, 377.04it/s]


47 0.9873487890169668


100%|██████████| 102/102 [00:00<00:00, 391.69it/s]


48 0.9735513163632632


100%|██████████| 102/102 [00:00<00:00, 381.46it/s]


49 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 397.43it/s]


50 0.966748366013072


100%|██████████| 102/102 [00:00<00:00, 376.52it/s]


51 0.9804172251434855


100%|██████████| 102/102 [00:00<00:00, 371.64it/s]


52 0.9873487890169668


100%|██████████| 102/102 [00:00<00:00, 340.30it/s]


53 0.9915231880033324


100%|██████████| 102/102 [00:00<00:00, 345.05it/s]


54 0.9915403961671967


100%|██████████| 102/102 [00:00<00:00, 359.46it/s]


55 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 362.62it/s]


56 0.9873487890169668


100%|██████████| 102/102 [00:00<00:00, 368.90it/s]


57 0.9749193763736828


100%|██████████| 102/102 [00:00<00:00, 351.37it/s]


58 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 362.72it/s]


59 0.9749193763736828


100%|██████████| 102/102 [00:00<00:00, 374.28it/s]


60 0.9901404185736657


100%|██████████| 102/102 [00:00<00:00, 354.62it/s]


61 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 361.61it/s]


62 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 367.79it/s]


63 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 366.84it/s]


64 0.9901404185736657


100%|██████████| 102/102 [00:00<00:00, 352.89it/s]


65 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 345.09it/s]


66 0.9901404185736657


100%|██████████| 102/102 [00:00<00:00, 362.01it/s]


67 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 351.31it/s]


68 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 368.48it/s]


69 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 371.74it/s]


70 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 359.33it/s]


71 0.9915403961671967


100%|██████████| 102/102 [00:00<00:00, 354.71it/s]


72 0.9943487920022216


100%|██████████| 102/102 [00:00<00:00, 365.88it/s]


73 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 363.28it/s]


74 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 343.09it/s]


75 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 358.26it/s]


76 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 371.90it/s]


77 0.9887432238079433


100%|██████████| 102/102 [00:00<00:00, 353.21it/s]


78 0.9873487890169668


100%|██████████| 102/102 [00:00<00:00, 361.29it/s]


79 0.9971627349283364


100%|██████████| 102/102 [00:00<00:00, 370.70it/s]


80 0.9845681086884961


100%|██████████| 102/102 [00:00<00:00, 349.55it/s]


81 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 364.42it/s]


82 0.9831818181818182


100%|██████████| 102/102 [00:00<00:00, 367.53it/s]


83 0.9804172251434855


100%|██████████| 102/102 [00:00<00:00, 365.87it/s]


84 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 339.92it/s]


85 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 362.35it/s]


86 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 363.36it/s]


87 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 348.65it/s]


88 0.9887432238079433


100%|██████████| 102/102 [00:00<00:00, 364.69it/s]


89 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 359.78it/s]


90 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 355.82it/s]


91 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 371.09it/s]


92 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 365.40it/s]


93 0.9804172251434855


100%|██████████| 102/102 [00:00<00:00, 364.01it/s]


94 0.985957091496277


100%|██████████| 102/102 [00:00<00:00, 359.43it/s]


95 0.9985828370071049


100%|██████████| 102/102 [00:00<00:00, 355.96it/s]


96 0.9957572567068181


100%|██████████| 102/102 [00:00<00:00, 363.04it/s]


97 0.9971685971685972


100%|██████████| 102/102 [00:00<00:00, 354.86it/s]


98 0.9929431795917447


100%|██████████| 102/102 [00:00<00:00, 345.11it/s]


99 0.9985828370071049


In [100]:
# Save NN
path = '/content/drive/MyDrive/ml/neck/neck_v0.pt'
torch.save(model.state_dict(), path)

In [101]:
model_neck = SimleNetwork(27, 128, 1)

In [102]:
model_neck.load_state_dict(torch.load(path))

<All keys matched successfully>

In [104]:
predicts = eval(model_neck, test_points, test_labels)

In [105]:
f1_score(test_labels.tolist(), predicts.tolist(), average='macro')

0.9985828370071049