<a href="https://colab.research.google.com/github/pascalghanimi/Injury-Prediction-in-Runners/blob/main/Predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler, LabelEncoder
import pickle

# ─── Attention module ─────────────────────────────────────────────────────────
class Attention(nn.Module):
    """Additive attention that pools a sequence of outputs into one context vector.

    A query vector is compared with every time step of ``outputs``.
    The score for a step is ``v(tanh(attn([query; output_t])))``.
    The softmax over time of these scores weights the sum of the outputs.

    Args:
        hidden_size: Per-direction LSTM hidden size. The concatenated query and
            output features must total ``4 * hidden_size``. In ``SkiSwingLSTM``
            both are ``2 * hidden_size`` wide.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Projects [query; output_t] (4 * hidden_size wide) down to hidden_size.
        self.attn = nn.Linear(hidden_size * 4, hidden_size)
        # Reduces each projected step to a single unnormalised score.
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, outputs):
        """Return the attention-weighted sum of ``outputs`` over the time axis.

        Args:
            hidden: Query tensor of shape [batch, d_q].
            outputs: Sequence tensor of shape [batch, seq_len, d_o], time on dim 1.

        Returns:
            Tensor of shape [batch, d_o].
        """
        seq_len = outputs.size(1)
        # Broadcast the query to every time step: [B, d_q] -> [B, T, d_q].
        query = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(self.attn(torch.cat((query, outputs), dim=2)))
        scores = self.v(energy).squeeze(2)        # [B, T]
        weights = torch.softmax(scores, dim=1)    # sums to 1 over time
        return (weights.unsqueeze(2) * outputs).sum(dim=1)

# ─── Sequence classifier ──────────────────────────────────────────────────────
class SkiSwingLSTM(nn.Module):
    """Classify a whole feature sequence into a single turn style (Schwungart).

    The pipeline is:
    - a 2-layer bidirectional LSTM encoder;
    - attention pooling over its outputs, using the top layer's final
      forward/backward hidden states as the query;
    - dropout;
    - a linear classification head.

    The attribute names (``lstm``, ``attention``, ``dropout``, ``fc``) must not
    change: the notebook loads a pickled model object, which restores them by name.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size per LSTM direction.
        num_classes: Number of output classes (logits).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.4,  # applied between the two LSTM layers while training
        )
        # Query and per-step outputs are both 2 * hidden_size wide (bidirectional).
        self.attention = Attention(hidden_size)
        self.dropout = nn.Dropout(p=0.6)
        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return class logits of shape [batch, num_classes] for x [batch, seq, feat]."""
        seq_out, (h_n, _) = self.lstm(x)
        # h_n[-2] / h_n[-1] are the top layer's forward / backward final states.
        query = torch.cat([h_n[-2], h_n[-1]], dim=1)
        pooled = self.attention(query, seq_out)
        pooled = self.dropout(pooled)
        return self.fc(pooled)

# ─── Configuration ────────────────────────────────────────────────────────────
MAX_LENGTH = 400                                   # frames per sequence fed to the style model
FEATURE_FILE = "PHALP_AJohannes_3_features.pkl"    # pickled per-frame feature dict
SCALER_FILE = "scaler_schwung.pkl"                 # pickled fitted scaler (used via .transform)
ENCODER_FILE = "encoder_schwung.pkl"               # pickled fitted label encoder
MODEL = "ski_schwung_classifier.pt"                # pickled style-classifier model object


# ─── Load scaler, encoder and features ───────────────────────────────────────
def _unpickle(path):
    """Return the object stored in the pickle file at *path*.

    pickle can execute arbitrary code while loading, so only use this on
    files from a trusted source.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


scaler = _unpickle(SCALER_FILE)
le = _unpickle(ENCODER_FILE)
features = _unpickle(FEATURE_FILE)

# ─── Prepare data ────────────────────────────────────────────────────────────
# Build one flat feature vector per frame. The column order is:
#   1. COM_to_ground, right knee angle, left knee angle;
#   2. every joint_angles value, then every axis_angles value, then every
#      COM_angles value, with nested dicts flattened in their iteration order.
data = []
# Only integer keys are frames. Accept NumPy integers too; a plain
# isinstance(k, int) check would silently drop them.
frames = sorted(
    k for k in features["COM_to_ground"] if isinstance(k, (int, np.integer))
)

for frame in frames:
    frame_features = [
        features["COM_to_ground"][frame],
        features["knee_angles_right"][frame],
        features["knee_angles_left"][frame],
    ]
    for axis in features["joint_angles"][frame]:
        frame_features.extend(features["joint_angles"][frame][axis].values())
    for axis in features["axis_angles"][frame]:
        frame_features.extend(features["axis_angles"][frame][axis].values())
    frame_features.extend(features["COM_angles"][frame].values())
    data.append(frame_features)

# Fail with a clear message instead of an IndexError on the 2-D shape below.
if not data:
    raise RuntimeError("Feature-Sequenz ist leer – bitte Features prüfen!")

data = np.array(data, dtype=np.float32)  # [n_frames, feat_dim]

# Pad with zero rows or truncate to exactly MAX_LENGTH frames.
# NOTE(review): padding happens *before* scaling, so padded rows become
# scaler.transform(0) rather than 0. Confirm this matches how the model was trained.
n_frames, feat_dim = data.shape
if n_frames < MAX_LENGTH:
    pad = np.zeros((MAX_LENGTH - n_frames, feat_dim), dtype=np.float32)
    data = np.vstack([data, pad])
else:
    data = data[:MAX_LENGTH]

# Scale with the scaler loaded from SCALER_FILE.
data_scaled = scaler.transform(data)

# Convert to a tensor and add a batch dimension: [1, MAX_LENGTH, feat_dim].
x = torch.tensor(data_scaled, dtype=torch.float32).unsqueeze(0)

# ─── Load model ───────────────────────────────────────────────────────────────
# Allow-list the classes needed to unpickle the saved model object.
# torch.load deserialises a pickle, so MODEL must come from a trusted source.
_trusted_model_classes = [SkiSwingLSTM, Attention, nn.LSTM, nn.Linear, nn.Dropout]
torch.serialization.add_safe_globals(_trusted_model_classes)

model = torch.load(MODEL, map_location="cpu")
model.eval()  # inference mode: dropout disabled

# ─── Predict ─────────────────────────────────────────────────────────────────
with torch.no_grad():
    logits = model(x)
# x holds a single sequence (batch of 1), so take the argmax of row 0.
style_idx = int(torch.argmax(logits, dim=1)[0])

# Map the class index back to its label via the loaded label encoder.
style_str = le.inverse_transform([style_idx])[0]
print("Vorhergesagte Schwungart:", style_str)

Vorhergesagte Schwungart: EKK


In [6]:
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.serialization

# ─── Settings ─────────────────────────────────────────────────────────────────
# This cell relies on FEATURE_FILE, style_idx and style_str from the previous cell.

# ─── 1) Load data & models ───────────────────────────────────────────────────
# pickle can execute arbitrary code while loading: only use trusted feature files.
with open(FEATURE_FILE, mode="rb") as feature_file:
    features = pickle.load(feature_file)

class FrameWiseLSTM(nn.Module):
    """Per-frame classifier that emits class logits for every time step.

    A 2-layer bidirectional LSTM is followed by dropout and a linear head,
    which is applied independently at each step.
    The attribute names (``lstm``, ``dropout``, ``fc``) must not change:
    the model is loaded as a pickled object, which restores them by name.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden size per LSTM direction.
        num_classes: Number of classes predicted per frame.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        self.dropout = nn.Dropout(p=0.3)
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        """Return logits [batch, seq, num_classes] for x of shape [batch, seq, feat]."""
        per_step = self.lstm(x)[0]  # [B, T, 2 * hidden_size]
        return self.fc(self.dropout(per_step))

# Allow-list the classes needed to unpickle the turn classifier.
torch.serialization.add_safe_globals([FrameWiseLSTM, nn.LSTM])
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code. Only load checkpoint files from a trusted source.
model_turn = torch.load("left_right_classifier.pt", map_location="cpu", weights_only=False)
model_turn.eval()

# ─── 2) Build the turn sequence (style as the last feature) ─────────────────
# Use the same frame selection as the style cell: integer frame keys only
# (built-in or NumPy), in ascending order. Sorting the raw key set would raise
# a TypeError if any non-integer key were present.
frame_indices = sorted(
    k for k in features["COM_to_ground"] if isinstance(k, (int, np.integer))
)
seq_turn = []
for i in frame_indices:
    try:
        v = [
            features["COM_to_ground"][i],
            features["knee_angles_right"][i],
            features["knee_angles_left"][i],
        ]
        for axis in features["joint_angles"][i]:
            v.extend(features["joint_angles"][i][axis].values())
        for axis in features["axis_angles"][i]:
            v.extend(features["axis_angles"][i][axis].values())
        v.extend(features["COM_angles"][i].values())
    except KeyError:
        # Skip frames that are missing any feature group.
        continue
    # Append the style class index from the previous cell as a raw extra feature.
    # NOTE(review): nothing here passes through the scaler. Confirm the turn
    # model was trained on unscaled inputs.
    v.append(style_idx)
    seq_turn.append(v)

if not seq_turn:
    raise RuntimeError("Turn-Sequence ist leer – bitte Features prüfen!")

arr_turn = np.array(seq_turn, dtype=np.float32)  # [n_frames, feat_dim + 1]
x_turn = torch.tensor(arr_turn).unsqueeze(0)      # [1, n_frames, feat_dim + 1]

# ─── 3) Predict left/right per frame ────────────────────────────────────────
with torch.no_grad():
    # Select batch element 0 rather than calling .squeeze(). With a single
    # frame, squeeze() would also drop the time axis and yield a 0-d array
    # that cannot be iterated below.
    preds = model_turn(x_turn).argmax(dim=-1)[0].numpy()

labels_turn = ["Linksschwung" if p == 1 else "Rechtsschwung" for p in preds]
print("Schwung pro Frame:", labels_turn)

# ─── 4) Annotate & save the video ────────────────────────────────────────────
import cv2

# NOTE(review): INPUT_VIDEO and OUTPUT_VIDEO are not defined in this cell or
# the previous one. They must be set elsewhere before this cell runs.
# NOTE(review): labels_turn[k] is drawn on the k-th video frame. This assumes
# the feature frames start at video frame 0 with no gaps; frames skipped in
# step 2 (missing features) would shift the labels. Confirm.
cap = cv2.VideoCapture(INPUT_VIDEO)
out = None
try:
    if not cap.isOpened():
        raise RuntimeError(f"Video konnte nicht geöffnet werden: {INPUT_VIDEO}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (w, h))
    if not out.isOpened():
        raise RuntimeError(f"Ausgabevideo konnte nicht erstellt werden: {OUTPUT_VIDEO}")

    # Write one annotated frame per predicted label; stop early if the video ends.
    for swing_txt in labels_turn:
        ret, frame = cap.read()
        if not ret:
            break
        # Per-frame turn direction, top-left, red.
        cv2.putText(frame, swing_txt, (50, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3)
        # Overall style, bottom-left, white.
        cv2.putText(frame, f"Fahrstil: {style_str}", (50, h-30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255,255,255), 2)
        out.write(frame)
finally:
    # Always release the capture and writer, even if annotation fails midway.
    cap.release()
    if out is not None:
        out.release()

print(f"Annotated video saved as '{OUTPUT_VIDEO}'")


Schwung pro Frame: ['Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Lin