In [11]:
import os, pickle, math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
from sklearn.model_selection import train_test_split

In [12]:
speed_path = '..\data\processed\speeds.pkl'

with open(speed_path, "rb") as f:
    speed_tensor = pickle.load(f)

print(speed_tensor.shape)

torch.Size([11978, 33])


In [13]:
class SpeedNetwork(nn.Module):
    def __init__(self):
        super(SpeedNetwork, self).__init__()
        self.fc1 = nn.Linear(32, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 2)  # mu, sigma

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)

        mu, sigma_logit = x[:, :1], x[:, 1:]
        sigma = F.softplus(sigma_logit) + 1e-6
        return mu, sigma

In [14]:
class LOOFramesDataset(Dataset):
    """
    Für jeden Frame mit 33 Fischen erzeugt dies 33 Samples:
    Input: 32 Geschwindigkeiten (alle außer fokal)
    Target: 1 Geschwindigkeit (fokaler Fisch)
    """
    def __init__(self, speeds: torch.Tensor, frame_indices: np.ndarray):
        self.speeds = speeds
        self.frames = np.asarray(frame_indices, dtype=np.int64)
        self.M = speeds.shape[1]  # 33

    def __len__(self):
        return len(self.frames) * self.M

    def __getitem__(self, idx):
        frame_idx = self.frames[idx // self.M]
        focal = idx % self.M
        row = self.speeds[frame_idx]                      # (33,)
        target = row[focal]                               # ()
        # Leave-one-out Features (32,)
        x = torch.cat([row[:focal], row[focal+1:]])       # (32,)
        return x, target.unsqueeze(0)                     # (32,), (1,)

In [15]:
N_frames = speed_tensor.shape[0]
all_frames = np.arange(N_frames)
train_frames, val_frames = train_test_split(all_frames, test_size=0.2, shuffle=True, random_state=42)

train_ds = LOOFramesDataset(speed_tensor, train_frames)
val_ds   = LOOFramesDataset(speed_tensor, val_frames)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=0, pin_memory=False)

# ---- 5) Training mit RMSE-Loss ---------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpeedNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)

def rmse_loss(pred, target, eps=1e-12):
    # pred, target: (B, 1)
    mse = F.mse_loss(pred, target, reduction='mean')
    return torch.sqrt(mse + eps)

def epoch_pass(loader, train=True):
    if train:
        model.train()
    else:
        model.eval()
    total_rmse = 0.0
    total_count = 0
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)          # (B,1)
            mu, _ = model(x)          # (B,1)
            loss = rmse_loss(mu, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
                optimizer.step()
            # RMSE mitteln über Batches gewichtet nach Batchgröße
            bs = x.size(0)
            total_rmse += loss.item() * bs
            total_count += bs
    return total_rmse / max(1, total_count)

best_val = float('inf')
patience, patience_left = 7, 7

for epoch in range(1, 51):  # 50 Epochen als Startpunkt
    train_rmse = epoch_pass(train_loader, train=True)
    val_rmse   = epoch_pass(val_loader,   train=False)
    print(f"Epoch {epoch:03d} | Train RMSE: {train_rmse:.6f} | Val RMSE: {val_rmse:.6f}")

    # Simple Early Stopping
    if val_rmse + 1e-6 < best_val:
        best_val = val_rmse
        patience_left = patience
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    else:
        patience_left -= 1
        if patience_left == 0:
            print("Early stopping.")
            break


Epoch 001 | Train RMSE: 0.039492 | Val RMSE: 0.038319
Epoch 002 | Train RMSE: 0.039186 | Val RMSE: 0.038726
Epoch 003 | Train RMSE: 0.039158 | Val RMSE: 0.038210
Epoch 004 | Train RMSE: 0.039094 | Val RMSE: 0.038236
Epoch 005 | Train RMSE: 0.039085 | Val RMSE: 0.038202
Epoch 006 | Train RMSE: 0.039037 | Val RMSE: 0.039368
Epoch 007 | Train RMSE: 0.039042 | Val RMSE: 0.038455
Epoch 008 | Train RMSE: 0.039016 | Val RMSE: 0.038206
Epoch 009 | Train RMSE: 0.038985 | Val RMSE: 0.038182
Epoch 010 | Train RMSE: 0.038985 | Val RMSE: 0.038374
Epoch 011 | Train RMSE: 0.038965 | Val RMSE: 0.038243
Epoch 012 | Train RMSE: 0.038967 | Val RMSE: 0.038470
Epoch 013 | Train RMSE: 0.038957 | Val RMSE: 0.038693
Epoch 014 | Train RMSE: 0.038929 | Val RMSE: 0.038265
Epoch 015 | Train RMSE: 0.038910 | Val RMSE: 0.038195
Epoch 016 | Train RMSE: 0.038886 | Val RMSE: 0.038637
Early stopping.


In [16]:
if 'best_state' in locals():
    model.load_state_dict(best_state)

# Finaler Val-RMSE
final_val_rmse = epoch_pass(val_loader, train=False)
print(f"Final Validation RMSE: {final_val_rmse:.6f}")

# Beispielvorhersage auf einem Val-Batch
model.eval()
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device); y = y.to(device)
        pred, _ = model(x)
        # RMSE dieses Batches
        batch_rmse = torch.sqrt(F.mse_loss(pred, y) + 1e-12).item()
        print(f"Sample batch RMSE: {batch_rmse:.6f}")
        # zeige 5 Beispiele
        for i in range(min(5, x.size(0))):
            print(f"y_true={y[i,0].item():.4f} | y_pred={pred[i,0].item():.4f}")
        break

Final Validation RMSE: 0.038182
Sample batch RMSE: 0.032820
y_true=0.0472 | y_pred=0.0660
y_true=0.0415 | y_pred=0.0660
y_true=0.0538 | y_pred=0.0657
y_true=0.0574 | y_pred=0.0656
y_true=0.0343 | y_pred=0.0663
