<a href="https://colab.research.google.com/github/pascalghanimi/Ski-Classification-AI/blob/main/Predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler, LabelEncoder
import pickle

# ─── Attention module ─────────────────────────────────────────────────────────
class Attention(nn.Module):
    """Additive (concat) attention that pools a sequence into one vector.

    The query ``hidden`` is broadcast over every time step of ``outputs``,
    concatenated with it, scored by ``v(tanh(attn(.)))``, and the
    softmax-normalised scores weight a sum over the time axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Input width is 4 * hidden_size: in SkiSwingLSTM both the query and
        # every output step are 2 * hidden_size wide (bidirectional LSTM).
        self.attn = nn.Linear(hidden_size * 4, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, outputs):
        """Return the attention-weighted context vector.

        Args:
            hidden: query of shape (batch, q_dim).
            outputs: sequence of shape (batch, seq_len, o_dim); q_dim + o_dim
                must equal the input width of ``self.attn``.

        Returns:
            Tensor of shape (batch, o_dim).
        """
        steps = outputs.size(1)
        # Same values as repeat(1, steps, 1), but without materialising copies
        # before the concatenation.
        query = hidden.unsqueeze(1).expand(-1, steps, -1)
        features = torch.cat((query, outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(features))).squeeze(2)
        weights = scores.softmax(dim=1)
        return (weights.unsqueeze(2) * outputs).sum(dim=1)

# ─── Style classifier model ───────────────────────────────────────────────────
class SkiSwingLSTM(nn.Module):
    """Whole-sequence classifier: BiLSTM encoder + attention pooling + linear head.

    A 2-layer bidirectional LSTM encodes the input; the top layer's final
    forward and backward hidden states form the attention query, and the
    resulting context vector goes through dropout and a linear layer that
    yields one logit per class.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.4,
        )
        self.attention = Attention(hidden_size)
        self.dropout = nn.Dropout(0.6)
        # Context vector is 2 * hidden_size wide (forward + backward).
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for ``x``.

        ``x`` is (batch, seq_len, input_size) because the LSTM uses
        ``batch_first=True``.
        """
        seq_out, (h_n, _cell) = self.lstm(x)
        # h_n stacks (num_layers * 2) states; the last two belong to the top
        # layer's forward and backward directions.
        top_fwd, top_bwd = h_n[-2], h_n[-1]
        query = torch.cat((top_fwd, top_bwd), dim=1)
        pooled = self.attention(query, seq_out)
        return self.fc(self.dropout(pooled))

# ─── Configuration ────────────────────────────────────────────────────────────
MAX_LENGTH = 400  # fixed number of frames fed to the style classifier
FEATURE_FILE = "PHALP_AMichi_2_features.pkl"
SCALER_FILE  = "scaler_schwung.pkl"
ENCODER_FILE = "encoder_schwung.pkl"
INPUT_VIDEO = "PHALP_AMichi_2.mp4"
OUTPUT_VIDEO = "annotated_output.mp4"
MODEL = "ski_schwung_classifier.pt"

# ─── Load scaler & label encoder ──────────────────────────────────────────────
# NOTE(review): pickle.load can execute arbitrary code; only load these
# artifacts from a trusted source.
with open(SCALER_FILE, "rb") as f:
    scaler = pickle.load(f)

with open(ENCODER_FILE, "rb") as f:
    le = pickle.load(f)

# ─── Load per-frame features ──────────────────────────────────────────────────
with open(FEATURE_FILE, "rb") as f:
    features = pickle.load(f)

# ─── Build the frame × feature matrix ─────────────────────────────────────────
# Only integer keys of features["COM_to_ground"] are used as frame keys; any
# other keys are ignored.
frames = sorted(k for k in features["COM_to_ground"] if isinstance(k, int))

data = []
for frame in frames:
    # Scalar features first, then the nested angle dicts flattened in their
    # iteration (insertion) order.
    # NOTE(review): this column order presumably has to match the one used
    # when the scaler/model were fitted — confirm against the training code.
    frame_features = [
        features["COM_to_ground"][frame],
        features["knee_angles_right"][frame],
        features["knee_angles_left"][frame],
    ]

    # Append joint_angles
    for axis in features["joint_angles"][frame]:
        frame_features.extend(features["joint_angles"][frame][axis].values())

    # Append axis_angles
    for axis in features["axis_angles"][frame]:
        frame_features.extend(features["axis_angles"][frame][axis].values())

    # Append COM_angles
    frame_features.extend(features["COM_angles"][frame].values())

    data.append(frame_features)

# Without frames np.array([]) is 1-D and data.shape[1] below would raise an
# obscure IndexError; fail with a clear message instead.
if not data:
    raise RuntimeError("Feature-Sequenz ist leer – bitte Features prüfen!")

data = np.array(data, dtype=np.float32)

# Pad with zero rows or truncate to exactly MAX_LENGTH frames.
# NOTE(review): padding happens before scaling, so the padded rows are
# generally not zero after scaler.transform — confirm this matches training.
if data.shape[0] < MAX_LENGTH:
    pad = np.zeros((MAX_LENGTH - data.shape[0], data.shape[1]), dtype=np.float32)
    data = np.vstack([data, pad])
else:
    data = data[:MAX_LENGTH]

# Scale features
data_scaled = scaler.transform(data)

# Add a batch dimension: [1, seq_len, feat_dim]
x = torch.tensor(data_scaled, dtype=torch.float32).unsqueeze(0)

# ─── Load the model ───────────────────────────────────────────────────────────
# Allow-list the classes needed to unpickle the full model object when
# torch.load runs in weights_only mode.
torch.serialization.add_safe_globals([
    SkiSwingLSTM,
    Attention,
    nn.LSTM,
    nn.Linear,
    nn.Dropout
])

model = torch.load(MODEL, map_location="cpu")
model.eval()  # inference mode: dropout disabled

# ─── Predict ──────────────────────────────────────────────────────────────────
with torch.no_grad():
    logits = model(x)
    style_idx = logits.argmax(dim=1).item()

# Decode the class index back to its label name
style_str = le.inverse_transform([style_idx])[0]
print("Vorhergesagte Schwungart:", style_str)

Vorhergesagte Schwungart: PDK


In [None]:
import csv
import cv2

# ─── 2. Turn (left/right) classification model ────────────────────────────────
class FrameWiseLSTM(nn.Module):
    """Sequence-labelling model that emits one logit vector per time step.

    A 2-layer bidirectional LSTM encodes the sequence; every time step's
    output passes through dropout and a shared linear layer.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.3,
        )
        self.dropout = nn.Dropout(0.3)
        # Each step's output is 2 * hidden_size wide (forward + backward).
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        """Map (batch, seq_len, input_size) to (batch, seq_len, num_classes)."""
        per_step = self.lstm(x)[0]
        return self.fc(self.dropout(per_step))

# Allow-list the classes needed to unpickle the full turn model when torch.load
# runs in weights_only mode (Linear/Dropout listed too so this cell does not
# rely on the previous cell having registered them).
torch.serialization.add_safe_globals([FrameWiseLSTM, nn.LSTM, nn.Linear, nn.Dropout])

TURN_MODEL_FILE = "left_right_classifier.pt"

# ─── 3. Load and process features ─────────────────────────────────────────────
with open(FEATURE_FILE, "rb") as f:
    features = pickle.load(f)

# Integer keys only; other keys are ignored.
frame_indices = sorted(k for k in features["COM_to_ground"] if isinstance(k, int))

# Build the input sequence for the turn classifier.
seq_turn = []
# Position (within frame_indices) of every frame that made it into seq_turn,
# so predictions can be mapped back even when frames are skipped below.
kept_positions = []
for pos, i in enumerate(frame_indices):
    try:
        v = [
            features["COM_to_ground"][i],
            features["knee_angles_right"][i],
            features["knee_angles_left"][i]
        ]

        # Append joint_angles
        for axis in features["joint_angles"][i]:
            v.extend(features["joint_angles"][i][axis].values())

        # Append axis_angles
        for axis in features["axis_angles"][i]:
            v.extend(features["axis_angles"][i][axis].values())

        # Append COM_angles
        v.extend(features["COM_angles"][i].values())

        # Append the predicted style index from the previous cell as an extra
        # feature.
        v.append(style_idx)
    except KeyError as e:
        print(f"Fehlendes Feature in Frame {i}: {e}")
        continue

    seq_turn.append(v)
    kept_positions.append(pos)

if not seq_turn:
    raise RuntimeError("Turn-Sequence ist leer – bitte Features prüfen!")

# Convert to a tensor with a batch dimension: [1, n_frames, feat_dim + 1]
arr_turn = np.array(seq_turn, dtype=np.float32)
x_turn = torch.tensor(arr_turn).unsqueeze(0).float()

# ─── 4. Load the model and predict ────────────────────────────────────────────
model_turn = torch.load(TURN_MODEL_FILE, map_location="cpu")
model_turn.eval()

with torch.no_grad():
    logits = model_turn(x_turn)
    # Take batch element 0 explicitly: .squeeze() would also drop the time
    # axis for a one-frame sequence and yield a non-iterable 0-d array.
    preds = logits.argmax(dim=-1)[0].numpy()

# Build labels
labels_turn = ["Linksschwung" if p == 1 else "Rechtsschwung" for p in preds]
print("Schwung pro Frame:", labels_turn)

# Key each label by its position in frame_indices, so a skipped frame leaves a
# gap instead of shifting every later label onto the wrong video frame.
# NOTE(review): like before, this assumes the k-th sorted feature frame is the
# k-th video frame — confirm against the feature export.
label_by_position = dict(zip(kept_positions, labels_turn))

# ─── 5. Annotate the video ────────────────────────────────────────────────────
cap = cv2.VideoCapture(INPUT_VIDEO)
out = None
frame_count = 0
try:
    if not cap.isOpened():
        raise RuntimeError(f"Konnte Video {INPUT_VIDEO} nicht öffnen")

    # Writer uses the same fps and frame size as the input video.
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (w, h))
    # VideoWriter does not raise on failure (e.g. bad codec or fps == 0);
    # without this check the output would silently be empty.
    if not out.isOpened():
        raise RuntimeError(f"Konnte Ausgabevideo {OUTPUT_VIDEO} nicht öffnen")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Only annotate frames for which a label exists.
        swing_txt = label_by_position.get(frame_count)
        if swing_txt is not None:
            cv2.putText(frame, swing_txt, (50, 80),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3)
            cv2.putText(frame, f"Fahrstil: {style_str}", (50, h-30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.2, (255, 255, 255), 2)

        out.write(frame)
        frame_count += 1
finally:
    # Release resources even if reading/writing fails part-way.
    cap.release()
    if out is not None:
        out.release()

print(f"Annotiertes Video wurde gespeichert als '{OUTPUT_VIDEO}'")


Schwung pro Frame: ['Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Rechtsschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Linksschwung', 'Link