In [None]:
!git clone https://github.com/takakib123/fed-conv-social-pooling.git

Cloning into 'fed-conv-social-pooling'...
remote: Enumerating objects: 52, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 52 (delta 0), reused 0 (delta 0), pack-reused 49 (from 1)[K
Receiving objects: 100% (52/52), 26.95 KiB | 26.95 MiB/s, done.
Resolving deltas: 100% (21/21), done.


In [None]:
# Path to the NGSIM dataset folder (presumably on a mounted Google Drive — confirm).
# NOTE(review): no cell visible in this notebook reads `data_dir`; the training
# cell loads './data/TrainSet.mat' and './data/ValSet.mat' relative to the
# working directory. Confirm how the data is made available under ./data.
data_dir = '/content/drive/MyDrive/Datasets/ngsim dataset'

In [None]:
%cd /content/fed-conv-social-pooling

/content/fed-conv-social-pooling


In [None]:
!wandb login

[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Create a new API key at: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Store your API key securely and do not share it.
[34m[1mwandb[0m: Paste your API key and hit enter: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33makibc123[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np
import logging
import time
import math
import os
import sys
import wandb

# Import necessary modules from the provided file structure
from model import highwayNet
from utils import ngsimDataset, maskedNLL, maskedMSE, maskedNLLTest

# --- Configuration & Logging Setup ---
# force=True: logging.basicConfig() does nothing when the root logger already
# has handlers (for example after this cell has been run once, or when the
# notebook runtime pre-installs a handler). Without it the file/stdout handlers
# below can be silently ignored and INFO messages dropped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [Central-System] - %(levelname)s - %(message)s',
    handlers=[
        # mode='w' starts a fresh log file every time logging is configured.
        logging.FileHandler("central_experiment.log", mode='w'),
        logging.StreamHandler(sys.stdout)
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Model/run options. The whole dict is handed to highwayNet(...) and recorded in
# the W&B run config; this file itself only reads 'use_cuda' and 'use_maneuvers'.
ARGS = {
    # Flip the literal to False to force CPU even when CUDA is available.
    'use_cuda': True and torch.cuda.is_available(),
    'encoder_size': 64,
    'decoder_size': 128,
    'in_length': 16,
    'out_length': 25,
    'grid_size': (13, 3),
    'soc_conv_depth': 64,
    'conv_3x1_depth': 16,
    'dyn_embedding_size': 32,
    'input_embedding_size': 32,
    'num_lat_classes': 3,
    'num_lon_classes': 2,
    # When True the model is expected to return (fut_pred, lat_pred, lon_pred).
    'use_maneuvers': True,
    'train_flag': True
}


TOTAL_EPOCHS = 10           # Number of train + validate passes
PRETRAIN_EPOCHS = 5         # Leading epochs that use the MSE loss; later epochs use NLL
BATCH_SIZE = 2048           # Used for both the train and validation loaders
DEVICE = torch.device("cuda" if ARGS['use_cuda'] else "cpu")
LOG_INTERVAL = 10           # Log every 10 minibatches
VAL_SUBSET_RATIO = 1        # Fraction of the validation set that is kept (1 = all of it)

def train_epoch(model, dataloader, optimizer, epoch_num, crossEnt):
    """
    Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network called as ``model(hist, nbrs, mask, lat_enc, lon_enc)``.
            It is switched to training mode and its ``train_flag`` is set to True.
        dataloader: Iterable yielding 7-tuples
            ``(hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)``.
        optimizer: Optimizer stepped once per batch.
        epoch_num: Zero-based epoch index. Epochs below ``PRETRAIN_EPOCHS``
            use the MSE loss, later ones use NLL.
        crossEnt: Loss applied to the lateral/longitudinal maneuver predictions
            during the NLL phase when ``ARGS['use_maneuvers']`` is set.

    Returns:
        The unweighted mean of the per-batch losses, or 0 if no batch was seen.
    """
    model.train()
    model.train_flag = True

    # The loss family is fixed for the whole epoch.
    pretraining = epoch_num < PRETRAIN_EPOCHS
    loss_mode = "MSE" if pretraining else "NLL"
    trajectory_loss = maskedMSE if pretraining else maskedNLL

    running_loss = 0
    batches_seen = 0
    started_at = time.time()

    for step, batch in enumerate(dataloader, start=1):
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = batch

        if ARGS['use_cuda']:
            hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = (
                tensor.to(DEVICE)
                for tensor in (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)
            )

        # Forward pass: the model returns extra maneuver outputs only in maneuver mode.
        with_maneuvers = ARGS['use_maneuvers']
        if with_maneuvers:
            fut_pred, lat_pred, lon_pred = model(hist, nbrs, mask, lat_enc, lon_enc)
        else:
            fut_pred = model(hist, nbrs, mask, lat_enc, lon_enc)

        loss = trajectory_loss(fut_pred, fut, op_mask)
        if with_maneuvers and not pretraining:
            # Maneuver classification terms are only added in the NLL phase.
            loss = loss + crossEnt(lat_pred, lat_enc) + crossEnt(lon_pred, lon_enc)

        # Backward pass with gradient-norm clipping at 10.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        batches_seen += 1

        # Periodic console + W&B logging.
        if step % LOG_INTERVAL == 0:
            logger.info(f"Epoch {epoch_num+1} | Batch {step} | Loss ({loss_mode}): {loss_value:.4f}")
            wandb.log({
                "train_batch_loss": loss_value,
                "epoch": epoch_num + 1,
                "batch": step,
                "training_phase": loss_mode
            })

    if batches_seen > 0:
        avg_loss = running_loss / batches_seen
    else:
        avg_loss = 0

    duration = time.time() - started_at
    logger.info(f"Epoch {epoch_num+1} Complete | Avg Loss: {avg_loss:.4f} | Time: {duration:.2f}s")

    return avg_loss

def validate(model, val_loader, epoch_num):
    """
    Evaluate ``model`` on ``val_loader`` without computing gradients.

    The loss family follows the training schedule: epochs below
    ``PRETRAIN_EPOCHS`` are scored with MSE, later ones with NLL.

    Args:
        model: Network called as ``model(hist, nbrs, mask, lat_enc, lon_enc)``.
            It is put in eval mode and its ``train_flag`` is set to False.
        val_loader: Iterable yielding 7-tuples
            ``(hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)``.
        epoch_num: Zero-based epoch index used to pick the loss family.

    Returns:
        ``(mean_loss, loss_mode)`` where ``mean_loss`` is the unweighted mean of
        the per-batch losses (0 if no batch was seen) and ``loss_mode`` is
        ``"MSE"`` or ``"NLL"``.
    """
    model.eval()
    model.train_flag = False

    pretraining = epoch_num < PRETRAIN_EPOCHS
    loss_mode = "MSE" if pretraining else "NLL"

    loss_total = 0
    batches_seen = 0

    with torch.no_grad():
        for batch in val_loader:
            hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = batch

            if ARGS['use_cuda']:
                hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = (
                    tensor.to(DEVICE)
                    for tensor in (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask)
                )

            if not ARGS['use_maneuvers']:
                # Single-output model: same trajectory losses as in training.
                fut_pred = model(hist, nbrs, mask, lat_enc, lon_enc)
                criterion = maskedMSE if pretraining else maskedNLL
                batch_loss = criterion(fut_pred, fut, op_mask)
            elif pretraining:
                # MSE phase: train_flag is raised only for this forward pass and
                # lowered again right after the loss is computed.
                model.train_flag = True
                fut_pred, _, _ = model(hist, nbrs, mask, lat_enc, lon_enc)
                batch_loss = maskedMSE(fut_pred, fut, op_mask)
                model.train_flag = False
            else:
                # NLL phase: score with the test-time NLL using the maneuver outputs.
                fut_pred, lat_pred, lon_pred = model(hist, nbrs, mask, lat_enc, lon_enc)
                batch_loss = maskedNLLTest(fut_pred, lat_pred, lon_pred, fut, op_mask, avg_along_time=True)

            loss_total += batch_loss.item()
            batches_seen += 1

    if batches_seen > 0:
        final_loss = loss_total / batches_seen
    else:
        final_loss = 0
    return final_loss, loss_mode

def _build_loaders():
    """
    Load the NGSIM train/validation sets and wrap them in DataLoaders.

    Returns:
        ``(train_loader, val_loader)``, or ``None`` when loading either
        ``.mat`` file raises ``FileNotFoundError`` (the error is logged).
    """
    logger.info("Loading Data...")
    try:
        train_dataset_full = ngsimDataset('./data/TrainSet.mat')
        val_dataset_full = ngsimDataset('./data/ValSet.mat')
    except FileNotFoundError:
        logger.error("Data files not found. Check 'data/' directory.")
        return None

    # Only the first 1/20th of the training samples is used; log it so the
    # reduction is visible in the run output.
    train_len = len(train_dataset_full)
    num_train_samples = int(train_len / 20)
    logger.info(f"Shortening training set: {num_train_samples} samples (Original: {train_len})")
    train_dataset = Subset(train_dataset_full, list(range(num_train_samples)))

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        # Subset has no collate_fn of its own, so borrow the dataset's.
        collate_fn=train_dataset_full.collate_fn
    )

    # Keep the leading VAL_SUBSET_RATIO fraction of the validation set.
    val_len = len(val_dataset_full)
    short_val_len = int(val_len * VAL_SUBSET_RATIO)
    logger.info(f"Shortening validation set: {short_val_len} samples (Original: {val_len})")
    val_subset = Subset(val_dataset_full, list(range(short_val_len)))

    val_loader = DataLoader(
        val_subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        collate_fn=val_dataset_full.collate_fn
    )

    return train_loader, val_loader


def main():
    """
    Train highwayNet centrally for ``TOTAL_EPOCHS`` epochs with W&B tracking.

    Each epoch runs ``train_epoch`` then ``validate``, logs both losses, saves
    the best-validation weights to ``trained_models/cslstm_central_best.tar``
    and a periodic checkpoint every 5 epochs; the final weights are written to
    ``trained_models/cslstm_central_final.tar``.

    The W&B run is always closed: with exit code 0 after a complete run, and
    with exit code 1 if the data files are missing or an exception escapes.
    Returns ``None``.
    """
    # Initialize WandB
    wandb.init(
        project="conv-social-pooling-central",
        reinit=True,
        config={
            "total_epochs": TOTAL_EPOCHS,
            "batch_size": BATCH_SIZE,
            "pretrain_epochs": PRETRAIN_EPOCHS,
            "val_subset_ratio": VAL_SUBSET_RATIO,
            **ARGS
        }
    )

    # Assume failure until training finishes, so an early return or an
    # exception still closes the run instead of leaving it open.
    exit_code = 1
    try:
        logger.info("Initializing Centralized Training Pipeline...")

        # 1. Load Datasets
        loaders = _build_loaders()
        if loaders is None:
            return
        train_loader, val_loader = loaders

        # 2. Initialize Model & Optimizer
        net = highwayNet(ARGS).to(DEVICE)
        optimizer = torch.optim.Adam(net.parameters())
        crossEnt = torch.nn.BCELoss()

        os.makedirs('trained_models', exist_ok=True)

        best_val_loss = math.inf
        best_loss_mode = None

        # 3. Training Loop
        for epoch in range(TOTAL_EPOCHS):
            logger.info(f"--- Epoch {epoch + 1}/{TOTAL_EPOCHS} ---")

            # Train
            train_loss = train_epoch(net, train_loader, optimizer, epoch, crossEnt)

            # Validate
            val_loss, loss_mode = validate(net, val_loader, epoch)

            # Log Summary
            logger.info(f"Epoch {epoch + 1} Summary | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} [{loss_mode}]")

            wandb.log({
                "val_loss": val_loss,
                "avg_train_loss": train_loss,
                "epoch": epoch + 1,
                "training_phase": loss_mode
            })

            # MSE and NLL validation losses are on different scales, so a
            # "best" value from one phase must not gate checkpoints in the
            # other: restart the comparison whenever the loss mode changes.
            if loss_mode != best_loss_mode:
                best_loss_mode = loss_mode
                best_val_loss = math.inf

            # Save Best Model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(net.state_dict(), 'trained_models/cslstm_central_best.tar')
                logger.info(f"New best model saved (Loss: {best_val_loss:.4f})")
                wandb.run.summary["best_val_loss"] = best_val_loss

            # Periodic Checkpoint
            if (epoch + 1) % 5 == 0:
                torch.save(net.state_dict(), f'trained_models/central_epoch_{epoch+1}.tar')

        logger.info("Training Complete.")
        torch.save(net.state_dict(), 'trained_models/cslstm_central_final.tar')
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

# Entry point: run the full training pipeline when this cell/script is executed
# directly (not when the code is imported as a module).
if __name__ == '__main__':
    main()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /root/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33makibc123[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
avg_train_loss,█▆▅▄▃▁▁▁▁▁
batch,▃▄▆█▃▇▁▃▃▄▇▃▆▆▃▅▆▇█▂▄▆▇▁▃▅▆▇█▂▆▇█▃▄▇▇▂▆█
epoch,▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇██
train_batch_loss,██▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▆▅▃▃▁▁▁▁▁

0,1
avg_train_loss,5.69558
batch,140
best_val_loss,5.23793
epoch,10
train_batch_loss,5.6797
training_phase,NLL
val_loss,5.23793
