In [74]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# --- Manual filtering ---
mask = [
    env == 'City' 
    and obj == 'Car_Red' 
    and ori == 'Orbit'
    for env, obj, ori in zip(result['environment'], result['object_class'], result['orientation'])
]

features = torch.stack([f for f, m in zip(result['features'], mask) if m]).to(torch.float32)[:, 0, :]
angles = torch.stack([a for a, m in zip(result['number'], mask) if m]).to(torch.float32) / 360.0

# --- Split ---
train_mask = angles < 0.7
val_mask = ~train_mask
train_x, train_y = features[train_mask], angles[train_mask]
val_x, val_y = features[val_mask], angles[val_mask]

# --- DataLoader ---
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=32, shuffle=True)

# --- Model ---
model = nn.Linear(train_x.shape[1], 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# --- Training ---
for epoch in range(100):
    pbar = tqdm(train_loader)
    for x, y in pbar:
        pred = model(x).squeeze(1).sigmoid()
        loss = loss_fn(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")

# --- Validation ---
model.eval()
with torch.no_grad():
    preds = model(val_x).squeeze(1).sigmoid()
    val_loss = loss_fn(preds, val_y).item()
print(f"\nValidation Loss (second half): {val_loss:.4f}")

Epoch 1 | Loss: 0.0320: 100%|██████████| 5/5 [00:00<00:00, 505.12it/s]
Epoch 2 | Loss: 0.0365: 100%|██████████| 5/5 [00:00<00:00, 705.80it/s]
Epoch 3 | Loss: 0.0245: 100%|██████████| 5/5 [00:00<00:00, 716.09it/s]
Epoch 4 | Loss: 0.0245: 100%|██████████| 5/5 [00:00<00:00, 724.91it/s]
Epoch 5 | Loss: 0.0527: 100%|██████████| 5/5 [00:00<00:00, 721.17it/s]
Epoch 6 | Loss: 0.0144: 100%|██████████| 5/5 [00:00<00:00, 719.36it/s]
Epoch 7 | Loss: 0.0452: 100%|██████████| 5/5 [00:00<00:00, 742.72it/s]
Epoch 8 | Loss: 0.0194: 100%|██████████| 5/5 [00:00<00:00, 726.46it/s]
Epoch 9 | Loss: 0.0163: 100%|██████████| 5/5 [00:00<00:00, 726.56it/s]
Epoch 10 | Loss: 0.0271: 100%|██████████| 5/5 [00:00<00:00, 718.08it/s]
Epoch 11 | Loss: 0.0237: 100%|██████████| 5/5 [00:00<00:00, 683.94it/s]
Epoch 12 | Loss: 0.0189: 100%|██████████| 5/5 [00:00<00:00, 580.61it/s]
Epoch 13 | Loss: 0.0260: 100%|██████████| 5/5 [00:00<00:00, 693.43it/s]
Epoch 14 | Loss: 0.0133: 100%|██████████| 5/5 [00:00<00:00, 624.39it/s]
E


Validation Loss (second half): 0.2523



