In [112]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt

In [113]:
# Directory that holds the saved representation files.
repr_dic = "/shared/results/common/kargin/unreal_engine/features/initial_test"
# Model tag; it selects the file `repr_<MODEL>.pt` inside `repr_dic`.
MODEL = "spa"
# NOTE(security): torch.load unpickles the file, so only load files from a trusted source.
# map_location="cpu" places every stored tensor on the CPU: the probe below is built on the CPU
# and the plotting cells call .numpy(), both of which fail on CUDA tensors. It also lets a file
# that was saved from a GPU be opened on a CPU-only machine.
result = torch.load(f"{repr_dic}/repr_{MODEL}.pt", map_location="cpu")

In [114]:
# --- Manual filtering ---
# One boolean per sample: True only when its (environment, object_class, orientation)
# triple matches the single combination studied in this notebook.
wanted = ('City', 'SM_vehCar_vehicle05_Red', 'Orbit')
mask = [
    sample == wanted
    for sample in zip(result['environment'], result['object_class'], result['orientation'])
]

In [115]:
# Which token(s) of each sample's representation are fed to the probe:
#   "CLS"    -> the token at index 0
#   "MEAN"   -> the average over the last 196 tokens (presumably a 14x14 patch grid - TODO confirm)
#   "CENTER" -> the token at index 105
# NOTE(review): "MEAN" indexes from the end while "CENTER" uses an absolute index; if the sequence
# carries prefix tokens before the patches, index 105 is not the same patch as (-196 + 105) - confirm.
FEATURES = "CLS"

# Keep only the samples selected by the manual filter.
kept_features = [f for f, m in zip(result['features'], mask) if m]
kept_angles = [a for a, m in zip(result['number'], mask) if m]
if not kept_features:
    # torch.stack on an empty list fails with an unhelpful message; fail early and clearly instead.
    raise ValueError("The filter mask selected no samples; check the environment/object/orientation values.")

features = torch.stack(kept_features).to(torch.float32)
# Rescale the per-sample 'number' entry by 1/360 (presumably degrees -> fraction of a full turn - TODO confirm).
angles = torch.stack(kept_angles).to(torch.float32) / 360.0

# Reduce the token axis (dim 1) to a single feature vector per sample.
if FEATURES == "CLS":
    features = features[:, 0, :]
elif FEATURES == "MEAN":
    features = features[:, -196:, :].mean(1)
elif FEATURES == "CENTER":
    features = features[:, 105, :]
else:
    raise ValueError(f"Unknown FEATURES mode {FEATURES!r}; expected 'CLS', 'MEAN' or 'CENTER'.")


In [None]:
# Sort the samples by angle. The permutation is computed once and applied to both tensors
# so features and angles stay aligned.
order = torch.argsort(angles)
features = features[order]
angles = angles[order]

# Regression targets: the point on the unit circle at each sample's own angle.
# `angles` was divided by 360 when it was built, so it is treated here as a fraction of a full
# turn (assumes the raw values are degrees - TODO confirm) and theta = 2*pi*angles.
# Building theta from a linspace over the sample index instead is only correct for perfectly
# evenly spaced, duplicate-free angles, and its closed endpoint (2*pi) puts the last sample on
# the same point as the first.
theta = 2 * torch.pi * angles
radius = 1
a = radius * torch.cos(theta)
b = radius * torch.sin(theta)
circles = torch.stack([a, b])  # shape (2, N): row 0 holds x, row 1 holds y


# --- Split ---
# Time-based split: samples whose normalised angle is below 0.90 train the probe,
# the remaining ones are held out for validation.
# (Random alternative: train_mask = torch.rand(len(angles)) < 0.7)
train_mask = angles < 0.90
val_mask = torch.logical_not(train_mask)

# `circles` is laid out as (2, N); transpose once so targets are indexed per sample,
# exactly like the features.
targets = circles.T
train_x, val_x = features[train_mask], features[val_mask]
train_y, val_y = targets[train_mask], targets[val_mask]

# Seed PyTorch's global RNG before anything stochastic is created, so that weight
# initialisation, dropout masks and batch shuffling are repeatable across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# --- DataLoader ---
# Mini-batches of 64 (feature vector, circle point) pairs, reshuffled every epoch.
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=64, shuffle=True)

# --- Model ---
# One-hidden-layer MLP probe: feature vector -> 256 hidden units -> 2 outputs (x, y).
model = nn.Sequential(
    nn.Linear(train_x.shape[1], 256),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    nn.Linear(256, 2)
)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# Mean squared error between predicted and target (x, y) coordinates.
loss_fn = nn.MSELoss()
N_EPOCHS = 300   # number of passes over the training set
EVAL_EVERY = 10  # log train/val loss once every this many epochs

# Loss history, one entry per logged epoch (epoch 0, EVAL_EVERY, 2*EVAL_EVERY, ...).
train_loss_list = []
val_loss_list = []

# --- Training ---
# A single progress bar over epochs; one bar per epoch would print N_EPOCHS bars.
pbar = tqdm(range(N_EPOCHS))
for epoch in pbar:
    model.train()
    for x, y in train_loader:
        pred = model(x)
        loss = loss_fn(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Shows the loss of the last mini-batch of the epoch (train mode, dropout active).
    pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")

    if epoch % EVAL_EVERY == 0:
        # Log both losses under identical conditions: eval mode (dropout off) and the full split.
        # Logging the last mini-batch loss instead would be noisy and not comparable to the
        # validation value.
        model.eval()
        with torch.no_grad():
            val_loss_list.append(loss_fn(model(val_x), val_y).item())
            train_loss_list.append(loss_fn(model(train_x), train_y).item())
        model.train()


# --- Validation ---
# Final evaluation of the trained probe on both splits, with dropout disabled
# and gradient tracking switched off.
model.eval()
with torch.no_grad():
    preds_train = model(train_x)
    preds_val = model(val_x)
    train_loss = loss_fn(preds_train, train_y).item()
    val_loss = loss_fn(preds_val, val_y).item()
print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

In [None]:
# Loss curves. The training loop logs one value every 10 epochs (epoch 0, 10, 20, ...),
# so rebuild the epoch axis instead of plotting against the bare list index.
LOG_EVERY = 10  # keep in sync with the logging cadence of the training loop
logged_epochs = [i * LOG_EVERY for i in range(len(val_loss_list))]

fig, ax = plt.subplots()
ax.plot(logged_epochs, train_loss_list, label='Training Loss')
ax.plot(logged_epochs, val_loss_list, label='Validation Loss')
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE loss")
ax.set_title("Training vs. Validation Loss")
ax.legend()
plt.show()

In [None]:
# Validation split: ground-truth circle points vs. the probe's predictions.
# detach()/cpu() make the conversion safe regardless of how the tensors were produced.
true_xy = val_y.detach().cpu().numpy()
pred_xy = preds_val.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(6, 6))

# Plot ground truth
ax.scatter(true_xy[:, 0], true_xy[:, 1], c='blue', label='Ground Truth', alpha=0.6)

# Plot predictions
ax.scatter(pred_xy[:, 0], pred_xy[:, 1], c='red', label='Predictions', alpha=0.75)

# Draw connection lines between each ground-truth point and its prediction
for gt, pred in zip(true_xy, pred_xy):
    ax.plot([gt[0], pred[0]], [gt[1], pred[1]], c='gray', alpha=0.5, linewidth=0.8)

# Draw unit circle for reference
ax.add_patch(plt.Circle((0, 0), 1, color='gray', fill=False, linestyle='--'))

# Equal scaling so the circle is not distorted.
ax.set_aspect('equal')
# Frame the view around every plotted point. A fixed window (e.g. only the fourth quadrant)
# silently clips predictions that land elsewhere, hiding exactly the errors this plot should show.
if len(true_xy) > 0:
    pad = 0.1
    xs = (true_xy[:, 0], pred_xy[:, 0])
    ys = (true_xy[:, 1], pred_xy[:, 1])
    ax.set_xlim(min(v.min() for v in xs) - pad, max(v.max() for v in xs) + pad)
    ax.set_ylim(min(v.min() for v in ys) - pad, max(v.max() for v in ys) + pad)

ax.set_title("Predicted vs. Ground Truth Positions on Unit Circle")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Training split: ground-truth circle points vs. the probe's predictions.
# Both tensors carry no gradient history here, so they convert to numpy directly.
true_xy = train_y.numpy()
pred_xy = preds_train.numpy()

fig, ax = plt.subplots(figsize=(6, 6))

# Ground truth in blue, predictions in red.
ax.scatter(true_xy[:, 0], true_xy[:, 1], c='blue', label='Ground Truth', alpha=0.6)
ax.scatter(pred_xy[:, 0], pred_xy[:, 1], c='red', label='Predictions', alpha=0.6)

# Dashed unit circle as a visual reference.
ax.add_artist(plt.Circle((0, 0), 1, color='gray', fill=False, linestyle='--'))

ax.axis('equal')
ax.set_title("Predicted vs. Ground Truth Positions on Unit Circle")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
ax.grid(True)
plt.show()