In [13]:
import torch
from torch.utils.data import Dataset

class SquatRepDataset(Dataset):
    def __init__(self, data_list, label_list):
        self.data_list = data_list  # list of dicts like the one you provided
        self.label_list = label_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        sequence = torch.tensor([
            data["knee_angle"],
            data["torso_angle"],
            data["hip_angle"],
            data["symmetry_score"],
            data["alignment_score"],
            data["head_angle"],
            data["toe_distance"],
            data["heel_angle"],
            data["back_angle"]
        ], dtype=torch.float).T  # shape: [seq_len, num_features]

        label = torch.tensor(self.label_list[idx], dtype=torch.long)
        return sequence, label


In [14]:
import torch.nn as nn

class SquatClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SquatClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x_packed):
        packed_output, (h_n, c_n) = self.lstm(x_packed)
        out = self.fc(h_n[-1])  # Final hidden state from last LSTM layer
        return out



In [None]:
# Example data prep
import json
from glob import glob
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

def squat_collate_fn(batch):
    sequences, labels = zip(*batch)
    lengths = torch.tensor([seq.shape[0] for seq in sequences])  # original lengths
    padded_sequences = pad_sequence(sequences, batch_first=True)
    labels = torch.stack(labels)
    return padded_sequences, lengths, labels

# print(data_list[0])
    
# data_list = [your_json_dict]  # can add more samples here
# label_list = [1]*len(data_list_new)  # class index for this sample

bad_data_list=glob('temp_data/*')
print(bad_data_list)
bad_list_new_list=[]
labels_list=[]
for i in range(len(bad_data_list)):
    bad_list_new_list.extend(glob(bad_data_list[i]+'/*.json'))
    one_hot=[0]*6
    one_hot[i]=1
    # print(one_hot)
    for i in range(len(glob(bad_data_list[i]+'/*.json'))):
        labels_list.append(one_hot)
    # break
# print(labels_list[:2])

bad_data_list_new=[json.load(open(file)) for file in bad_list_new_list]
# print(bad_list_new_list[0])
with open(bad_list_new_list[0]) as f:
    data = json.load(f)
# print(data[1])
bad_data_list_new_actual=[]
label_list_actual=[]
for i in bad_data_list_new:
    for k in range(len(i)):
        label_list_actual.append(labels_list[bad_data_list_new.index(i)])
    for j in i:
        bad_data_list_new_actual.append(j)
bad_data_list=bad_data_list_new_actual
bad_label_list=label_list_actual
# print(data_list[0])
# print(bad_data_list[0])
# print(bad_label_list[0])


    
# data_list = [your_json_dict]  # can add more samples here
# bad_label_list = [0]*len(bad_data_list_new)  # class index for this sample

dataset = SquatRepDataset(bad_data_list, bad_label_list)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=squat_collate_fn)
train_loader,val_loader=torch.utils.data.random_split(dataloader, [int(len(dataloader)*0.8), int(len(dataloader)*0.2)])
# print(iter(dataloader).next())  # print one batch of data
# print(dataset[1][0].size())
# # Model
model = SquatClassifier(input_size=9, hidden_size=64, num_classes=6)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training loop (for demonstration, not epoched)
for epoch in range(100):
    avg_loss=0
    for X_batch, lengths, y_batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()

        # Pack sequences
        packed_input = nn.utils.rnn.pack_padded_sequence(X_batch, lengths, batch_first=True, enforce_sorted=False)

        # Pass to model (LSTM expects PackedSequence)
        packed_output, (h_n, c_n) = model.lstm(packed_input)

        # Use the final hidden state from LSTM (last layer)
        logits = model.fc(h_n[-1])  # shape: (batch_size, num_classes)
        y_batch = torch.argmax(y_batch, dim=1)  # from one-hot to class index
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()
        # print(f"Loss: {loss.item()}")
        avg_loss+=loss.item()
    avg_loss/=len(dataloader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")


['temp_data\\bad_back_warp', 'temp_data\\bad_head', 'temp_data\\bad_inner_thigh', 'temp_data\\bad_shallow', 'temp_data\\bad_toe', 'temp_data\\good']
{'rep_number': 1, 'start_frame': 2, 'end_frame': 82, 'knee_angle': [102.8005691707072, 113.91209659355584, 122.4856580896193, 124.23367348207394, 125.13090771159824, 125.46758144236492, 131.4560942316897, 137.0548553818299, 143.43744864880958, 147.04516211824085, 149.57033533940643, 149.6766748773775, 148.8539307388724, 152.41428322566145, 154.56602008748916, 155.9728300735592, 157.4139006257187, 159.01114993160354, 161.03929817062624, 162.07943527335277, 160.59734225880146, 163.8540089276289, 163.46088574128026, 165.15614859681506, 164.13001383150026, 165.24503357434867, 161.72190632739733, 159.70694843314033, 159.24182051193833, 158.96956331180675, 155.23073948564564, 156.12767506704955, 156.71975271596943, 158.137711947676, 161.0935851340399, 161.1630859972194, 161.05230076327302, 160.0993778656175, 160.05478846976797, 160.2069028518410

Training: 100%|██████████| 95/95 [00:02<00:00, 38.10it/s]


Epoch 1, Loss: 1.7438


Training: 100%|██████████| 95/95 [00:02<00:00, 39.58it/s]


Epoch 2, Loss: 1.6713


Training: 100%|██████████| 95/95 [00:02<00:00, 38.69it/s]


Epoch 3, Loss: 1.6338


Training: 100%|██████████| 95/95 [00:02<00:00, 39.90it/s]


Epoch 4, Loss: 1.6053


Training: 100%|██████████| 95/95 [00:02<00:00, 40.12it/s]


Epoch 5, Loss: 1.5823


Training: 100%|██████████| 95/95 [00:02<00:00, 38.12it/s]


Epoch 6, Loss: 1.5462


Training: 100%|██████████| 95/95 [00:02<00:00, 37.82it/s]


Epoch 7, Loss: 1.5139


Training: 100%|██████████| 95/95 [00:02<00:00, 38.77it/s]


Epoch 8, Loss: 1.4831


Training: 100%|██████████| 95/95 [00:02<00:00, 39.81it/s]


Epoch 9, Loss: 1.4634


Training: 100%|██████████| 95/95 [00:02<00:00, 40.78it/s]


Epoch 10, Loss: 1.4704


Training: 100%|██████████| 95/95 [00:02<00:00, 38.88it/s]


Epoch 11, Loss: 1.4765


Training: 100%|██████████| 95/95 [00:02<00:00, 38.03it/s]


Epoch 12, Loss: 1.4277


Training: 100%|██████████| 95/95 [00:02<00:00, 38.33it/s]


Epoch 13, Loss: 1.4206


Training: 100%|██████████| 95/95 [00:02<00:00, 38.46it/s]


Epoch 14, Loss: 1.3923


Training: 100%|██████████| 95/95 [00:02<00:00, 40.68it/s]


Epoch 15, Loss: 1.3694


Training: 100%|██████████| 95/95 [00:02<00:00, 39.49it/s]


Epoch 16, Loss: 1.3729


Training: 100%|██████████| 95/95 [00:02<00:00, 39.41it/s]


Epoch 17, Loss: 1.3658


Training: 100%|██████████| 95/95 [00:02<00:00, 38.75it/s]


Epoch 18, Loss: 1.3453


Training: 100%|██████████| 95/95 [00:02<00:00, 39.23it/s]


Epoch 19, Loss: 1.3510


Training: 100%|██████████| 95/95 [00:02<00:00, 40.76it/s]


Epoch 20, Loss: 1.3129


Training: 100%|██████████| 95/95 [00:02<00:00, 38.34it/s]


Epoch 21, Loss: 1.3110


Training: 100%|██████████| 95/95 [00:02<00:00, 39.93it/s]


Epoch 22, Loss: 1.3076


Training: 100%|██████████| 95/95 [00:02<00:00, 38.69it/s]


Epoch 23, Loss: 1.3045


Training: 100%|██████████| 95/95 [00:02<00:00, 39.69it/s]


Epoch 24, Loss: 1.2678


Training: 100%|██████████| 95/95 [00:02<00:00, 38.68it/s]


Epoch 25, Loss: 1.2448


Training: 100%|██████████| 95/95 [00:02<00:00, 36.11it/s]


Epoch 26, Loss: 1.2489


Training: 100%|██████████| 95/95 [00:02<00:00, 37.12it/s]


Epoch 27, Loss: 1.2363


Training: 100%|██████████| 95/95 [00:02<00:00, 37.59it/s]


Epoch 28, Loss: 1.2454


Training: 100%|██████████| 95/95 [00:02<00:00, 36.95it/s]


Epoch 29, Loss: 1.1910


Training: 100%|██████████| 95/95 [00:02<00:00, 37.17it/s]


Epoch 30, Loss: 1.2179


Training: 100%|██████████| 95/95 [00:02<00:00, 37.97it/s]


Epoch 31, Loss: 1.2005


Training: 100%|██████████| 95/95 [00:02<00:00, 36.92it/s]


Epoch 32, Loss: 1.1700


Training: 100%|██████████| 95/95 [00:02<00:00, 34.76it/s]


Epoch 33, Loss: 1.1671


Training: 100%|██████████| 95/95 [00:02<00:00, 34.42it/s]


Epoch 34, Loss: 1.1645


Training: 100%|██████████| 95/95 [00:02<00:00, 34.45it/s]


Epoch 35, Loss: 1.1658


Training: 100%|██████████| 95/95 [00:02<00:00, 36.68it/s]


Epoch 36, Loss: 1.1672


Training: 100%|██████████| 95/95 [00:02<00:00, 36.80it/s]


Epoch 37, Loss: 1.1445


Training: 100%|██████████| 95/95 [00:02<00:00, 39.39it/s]


Epoch 38, Loss: 1.1470


Training: 100%|██████████| 95/95 [00:02<00:00, 38.84it/s]


Epoch 39, Loss: 1.1503


Training: 100%|██████████| 95/95 [00:02<00:00, 37.57it/s]


Epoch 40, Loss: 1.1204


Training: 100%|██████████| 95/95 [00:02<00:00, 39.36it/s]


Epoch 41, Loss: 1.1283


Training: 100%|██████████| 95/95 [00:02<00:00, 39.19it/s]


Epoch 42, Loss: 1.1504


Training: 100%|██████████| 95/95 [00:02<00:00, 35.71it/s]


Epoch 43, Loss: 1.1245


Training: 100%|██████████| 95/95 [00:02<00:00, 36.36it/s]


Epoch 44, Loss: 1.1220


Training: 100%|██████████| 95/95 [00:02<00:00, 36.77it/s]


Epoch 45, Loss: 1.1440


Training: 100%|██████████| 95/95 [00:02<00:00, 35.13it/s]


Epoch 46, Loss: 1.1136


Training: 100%|██████████| 95/95 [00:02<00:00, 37.35it/s]


Epoch 47, Loss: 1.1077


Training: 100%|██████████| 95/95 [00:02<00:00, 33.87it/s]


Epoch 48, Loss: 1.0810


Training: 100%|██████████| 95/95 [00:02<00:00, 36.78it/s]


Epoch 49, Loss: 1.0947


Training: 100%|██████████| 95/95 [00:02<00:00, 37.54it/s]


Epoch 50, Loss: 1.1031


Training: 100%|██████████| 95/95 [00:02<00:00, 38.53it/s]


Epoch 51, Loss: 1.0789


Training: 100%|██████████| 95/95 [00:02<00:00, 37.92it/s]


Epoch 52, Loss: 1.1046


Training: 100%|██████████| 95/95 [00:02<00:00, 38.29it/s]


Epoch 53, Loss: 1.0473


Training: 100%|██████████| 95/95 [00:02<00:00, 39.28it/s]


Epoch 54, Loss: 1.0325


Training: 100%|██████████| 95/95 [00:02<00:00, 38.20it/s]


Epoch 55, Loss: 1.0396


Training: 100%|██████████| 95/95 [00:02<00:00, 37.62it/s]


Epoch 56, Loss: 1.0228


Training: 100%|██████████| 95/95 [00:02<00:00, 38.10it/s]


Epoch 57, Loss: 1.0305


Training: 100%|██████████| 95/95 [00:02<00:00, 38.10it/s]


Epoch 58, Loss: 1.0268


Training: 100%|██████████| 95/95 [00:02<00:00, 38.08it/s]


Epoch 59, Loss: 1.0010


Training: 100%|██████████| 95/95 [00:02<00:00, 36.65it/s]


Epoch 60, Loss: 1.0260


Training: 100%|██████████| 95/95 [00:02<00:00, 37.99it/s]


Epoch 61, Loss: 1.0149


Training: 100%|██████████| 95/95 [00:02<00:00, 37.93it/s]


Epoch 62, Loss: 1.0154


Training: 100%|██████████| 95/95 [00:02<00:00, 36.74it/s]


Epoch 63, Loss: 1.0056


Training: 100%|██████████| 95/95 [00:02<00:00, 37.31it/s]


Epoch 64, Loss: 1.0087


Training: 100%|██████████| 95/95 [00:02<00:00, 38.32it/s]


Epoch 65, Loss: 0.9937


Training: 100%|██████████| 95/95 [00:02<00:00, 32.78it/s]


Epoch 66, Loss: 1.0089


Training: 100%|██████████| 95/95 [00:02<00:00, 36.60it/s]


Epoch 67, Loss: 0.9890


Training: 100%|██████████| 95/95 [00:02<00:00, 35.30it/s]


Epoch 68, Loss: 0.9764


Training: 100%|██████████| 95/95 [00:02<00:00, 36.53it/s]


Epoch 69, Loss: 0.9779


Training: 100%|██████████| 95/95 [00:02<00:00, 33.08it/s]


Epoch 70, Loss: 0.9751


Training: 100%|██████████| 95/95 [00:02<00:00, 35.91it/s]


Epoch 71, Loss: 0.9712


Training: 100%|██████████| 95/95 [00:02<00:00, 34.81it/s]


Epoch 72, Loss: 0.9793


Training: 100%|██████████| 95/95 [00:02<00:00, 36.75it/s]


Epoch 73, Loss: 1.0248


Training: 100%|██████████| 95/95 [00:02<00:00, 34.62it/s]


Epoch 74, Loss: 0.9926


Training: 100%|██████████| 95/95 [00:02<00:00, 36.49it/s]


Epoch 75, Loss: 0.9814


Training: 100%|██████████| 95/95 [00:02<00:00, 33.87it/s]


Epoch 76, Loss: 0.9446


Training: 100%|██████████| 95/95 [00:02<00:00, 34.91it/s]


Epoch 77, Loss: 0.9573


Training: 100%|██████████| 95/95 [00:02<00:00, 36.03it/s]


Epoch 78, Loss: 0.9852


Training: 100%|██████████| 95/95 [00:02<00:00, 35.98it/s]


Epoch 79, Loss: 0.9197


Training: 100%|██████████| 95/95 [00:02<00:00, 32.60it/s]


Epoch 80, Loss: 0.9250


Training: 100%|██████████| 95/95 [00:02<00:00, 37.24it/s]


Epoch 81, Loss: 0.9335


Training: 100%|██████████| 95/95 [00:02<00:00, 37.37it/s]


Epoch 82, Loss: 0.9434


Training: 100%|██████████| 95/95 [00:02<00:00, 34.40it/s]


Epoch 83, Loss: 0.9451


Training: 100%|██████████| 95/95 [00:02<00:00, 34.14it/s]


Epoch 84, Loss: 0.8869


Training: 100%|██████████| 95/95 [00:03<00:00, 28.12it/s]


Epoch 85, Loss: 0.9500


Training: 100%|██████████| 95/95 [00:02<00:00, 32.35it/s]


Epoch 86, Loss: 0.9274


Training: 100%|██████████| 95/95 [00:03<00:00, 29.49it/s]


Epoch 87, Loss: 0.9555


Training: 100%|██████████| 95/95 [00:02<00:00, 35.76it/s]


Epoch 88, Loss: 0.9341


Training: 100%|██████████| 95/95 [00:03<00:00, 29.24it/s]


Epoch 89, Loss: 0.9373


Training: 100%|██████████| 95/95 [00:03<00:00, 26.55it/s]


Epoch 90, Loss: 0.8872


Training: 100%|██████████| 95/95 [00:02<00:00, 32.60it/s]


Epoch 91, Loss: 0.8948


Training: 100%|██████████| 95/95 [00:02<00:00, 32.63it/s]


Epoch 92, Loss: 0.9059


Training: 100%|██████████| 95/95 [00:03<00:00, 30.80it/s]


Epoch 93, Loss: 0.8919


Training: 100%|██████████| 95/95 [00:02<00:00, 32.12it/s]


Epoch 94, Loss: 0.8657


Training: 100%|██████████| 95/95 [00:03<00:00, 29.53it/s]


Epoch 95, Loss: 0.8598


Training: 100%|██████████| 95/95 [00:03<00:00, 27.10it/s]


Epoch 96, Loss: 0.8903


Training: 100%|██████████| 95/95 [00:03<00:00, 29.15it/s]


Epoch 97, Loss: 0.9001


Training: 100%|██████████| 95/95 [00:02<00:00, 32.83it/s]


Epoch 98, Loss: 0.8666


Training: 100%|██████████| 95/95 [00:02<00:00, 33.79it/s]


Epoch 99, Loss: 0.9084


Training: 100%|██████████| 95/95 [00:02<00:00, 33.08it/s]

Epoch 100, Loss: 0.8835





In [62]:
torch.save(model.state_dict(), 'squat_classifier.pth')

In [50]:
with open('temp_data/bad_back_warp/rep_metrics_0918_squat_000010.json') as f:
    data = json.load(f)
print(data[1])

{'rep_number': 2, 'start_frame': 82, 'end_frame': 156, 'knee_angle': [124.16441978570647, 130.96553520402105, 131.64582712399448, 137.27906987436342, 135.26355534266918, 138.60424044112585, 143.2184553278396, 144.82204794138158, 144.8207910610513, 145.6226769910278, 147.97534945997094, 152.34888490014322, 147.90362323433322, 160.3455422840239, 159.58675237523295, 162.4428614434911, 158.03828325717126, 161.29662452973952, 163.3872190537278, 163.654650930814, 164.23001837848642, 162.60574696257925, 162.04662921417548, 162.76508158065556, 160.5059524914733, 160.6326679858921, 161.86499512922785, 161.21635111925153, 161.55576077098942, 162.12622109059532, 161.9819949855299, 161.6932530774037, 161.66029037515216, 163.17646968024513, 162.22538795387683, 161.5817015656864, 160.51330863125028, 160.99033184512223, 161.51282808179263, 161.61229042262164, 157.42449251825343, 159.44232030841755, 154.0305009588413, 147.81131522097223, 138.87210625829073, 136.59304050252717, 132.38326836297796, 129.