In [None]:
# Clone the project repository; later cells `%cd` into it and import `model` / `utils` from it.
# NOTE(review): presumably fails if the directory already exists (e.g. when re-running this cell).
!git clone https://github.com/takakib123/fed-conv-social-pooling.git

Cloning into 'fed-conv-social-pooling'...
remote: Enumerating objects: 52, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 52 (delta 0), reused 0 (delta 0), pack-reused 49 (from 1)[K
Receiving objects: 100% (52/52), 26.95 KiB | 452.00 KiB/s, done.
Resolving deltas: 100% (21/21), done.


In [None]:
# NOTE(review): `data_dir` is not referenced by any later cell. Training reads 'data/TrainSet.mat'
# and 'data/ValSet.mat' relative to the repo, and evaluation reads
# '/content/fed-conv-social-pooling/data/TestSet_Keep.mat'. Confirm how the NGSIM .mat files
# get from this Drive folder into the repo's data/ directory.
data_dir = '/content/drive/MyDrive/Datasets/ngsim dataset'

In [None]:
# Work from the repo root so the `model` / `utils` imports and the relative `data/` paths resolve.
%cd /content/fed-conv-social-pooling

/content/fed-conv-social-pooling


In [None]:
# Authenticate Weights & Biases interactively; the training cell logs metrics via wandb.log.
!wandb login

[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Create a new API key at: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Store your API key securely and do not share it.
[34m[1mwandb[0m: Paste your API key and hit enter: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33makibc123[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import logging
import time
import math
import os
import sys
import wandb
import gc
# Import necessary modules from the provided file structure
from model import highwayNet
from utils import ngsimDataset, maskedNLL, maskedMSE, maskedNLLTest

# --- Configuration & Logging Setup ---
# `force=True` removes any handlers already attached to the root logger before installing
# ours. Without it basicConfig silently does nothing when the notebook kernel (or an earlier
# run of this cell) has already configured logging, and the logger.info(...) progress
# messages emitted during training never appear.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [FL-System] - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("fl_experiment.log", mode='w'),
        logging.StreamHandler(sys.stdout),
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Model hyper-parameters handed to highwayNet (same keys as the evaluation cell's `args`).
ARGS = {
    'use_cuda': torch.cuda.is_available(),
    'encoder_size': 64,
    'decoder_size': 128,
    'in_length': 16,
    'out_length': 25,
    'grid_size': (13, 3),
    'soc_conv_depth': 64,
    'conv_3x1_depth': 16,
    'dyn_embedding_size': 32,
    'input_embedding_size': 32,
    'num_lat_classes': 3,
    'num_lon_classes': 2,
    'use_maneuvers': True,
    'train_flag': True,
}

# FL Hyperparameters
NUM_CLIENTS = 10        # number of simulated clients, each holding a contiguous shard of TrainSet
GLOBAL_ROUNDS = 8       # FedAvg aggregation rounds
LOCAL_EPOCHS = 2        # passes over its shard each client makes per round
PRETRAIN_ROUNDS = 3     # rounds < PRETRAIN_ROUNDS train/validate with masked MSE, later rounds with NLL
BATCH_SIZE = 8192       # used for both client training and validation loaders
DEVICE = torch.device("cuda" if ARGS['use_cuda'] else "cpu")
LOG_INTERVAL = 10       # log a client's batch loss every LOG_INTERVAL batches
REUSE_WEIGHTS = True    # initialise the global model from a saved checkpoint
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)


class FLClient:
    """One simulated federated-learning client.

    Each client owns a private ``highwayNet`` replica, an Adam optimizer and a
    shuffled DataLoader over its data partition. ``train`` overwrites the
    replica with the current global weights, runs ``LOCAL_EPOCHS`` local epochs
    and returns the updated weights for server-side aggregation.

    NOTE(review): the optimizer is created once in ``__init__``, so its Adam
    moment estimates persist across global rounds even though the weights are
    reset to the global model every round. Confirm this stateful-client
    behaviour is intended.
    """

    def __init__(self, client_id, dataset, device, args):
        """
        Args:
            client_id: Integer id used in log messages and W&B metric names.
            dataset: The client's data partition, either a dataset exposing
                ``collate_fn`` or a ``torch.utils.data.Subset`` wrapping one.
            device: ``torch.device`` the replica (and, with CUDA, batches) live on.
            args: Hyper-parameter dict passed to ``highwayNet``; also read for
                ``use_cuda`` and ``use_maneuvers``.
        """
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.args = args
        self.net = highwayNet(args).to(device)
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.crossEnt = torch.nn.BCELoss()
        # A Subset has no collate_fn of its own; fall back to the wrapped dataset's.
        collate_fn = getattr(dataset, 'collate_fn', None)
        if collate_fn is None:
            collate_fn = dataset.dataset.collate_fn
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=0,
            collate_fn=collate_fn
        )

    def _compute_loss(self, batch, use_mse_loss):
        """Run the forward pass on one batch and return the scalar training loss.

        With maneuvers enabled and ``use_mse_loss`` False, the loss is the
        trajectory NLL plus binary cross-entropy (``self.crossEnt``) on the
        lateral and longitudinal maneuver heads.
        """
        hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = batch
        if self.args['use_maneuvers']:
            fut_pred, lat_pred, lon_pred = self.net(hist, nbrs, mask, lat_enc, lon_enc)
            if use_mse_loss:
                return maskedMSE(fut_pred, fut, op_mask)
            return (maskedNLL(fut_pred, fut, op_mask)
                    + self.crossEnt(lat_pred, lat_enc)
                    + self.crossEnt(lon_pred, lon_enc))
        fut_pred = self.net(hist, nbrs, mask, lat_enc, lon_enc)
        if use_mse_loss:
            return maskedMSE(fut_pred, fut, op_mask)
        return maskedNLL(fut_pred, fut, op_mask)

    def train(self, global_weights, round_num, global_step_offset):
        """Run ``LOCAL_EPOCHS`` of local training starting from ``global_weights``.

        Args:
            global_weights: state_dict of the current global model; loaded into
                this client's replica before training.
            round_num: 0-based global round. Rounds < PRETRAIN_ROUNDS use masked
                MSE, later rounds masked NLL (+ maneuver BCE when enabled).
            global_step_offset: W&B ``global_step`` at the start of this round;
                logged steps are ``global_step_offset + local_step``.

        Returns:
            (state_dict, avg_loss, steps_taken): the replica's ``state_dict()``
            (live tensors, not a copy), the mean per-batch loss over all local
            epochs (0 if there were no batches), and the number of optimizer
            steps taken in this call.
        """
        self.net.load_state_dict(global_weights)
        self.net.train()
        self.net.train_flag = True

        use_mse_loss = round_num < PRETRAIN_ROUNDS
        loss_mode = "MSE" if use_mse_loss else "NLL"

        total_loss = 0
        local_step = 0  # optimizer steps in this call, counted across all local epochs

        for _ in range(LOCAL_EPOCHS):
            for i, data in enumerate(self.dataloader):
                if self.args['use_cuda']:
                    data = [t.to(self.device) for t in data]

                loss = self._compute_loss(data, use_mse_loss)

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 10)
                self.optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value
                local_step += 1

                if (i + 1) % LOG_INTERVAL == 0:
                    log_msg = (f"Client {self.client_id} | Round {round_num+1} | "
                               f"Batch {i+1} | Loss ({loss_mode}): {loss_value:.4f}")
                    logger.info(log_msg)

                    wandb.log({
                        f"client_{self.client_id}/train_loss": loss_value,
                        "round": round_num + 1,
                        "client_id": self.client_id,
                        "global_step": global_step_offset + local_step,
                        # Monotonic step within this client's round (the previous
                        # `round_num*batch_count + i` mixed a cumulative counter with
                        # a per-epoch index and jumped backwards between epochs).
                        "local_step": local_step,
                    })

        avg_loss = total_loss / local_step if local_step > 0 else 0

        gc.collect()
        torch.cuda.empty_cache()
        return self.net.state_dict(), avg_loss, local_step

def fed_avg(weights_list, client_weights=None):
    """Aggregate client state_dicts with Federated Averaging (FedAvg).

    Args:
        weights_list: Non-empty list of state_dicts sharing the same keys.
        client_weights: Optional per-client non-negative weights, e.g. local
            sample counts. ``None`` (default) weights every client equally,
            i.e. a plain element-wise mean.

    Returns:
        A new state_dict (a deep copy of the first one, so no input is
        modified). Floating-point entries hold the (weighted) mean. Integer
        entries are averaged and truncated back to their original dtype instead
        of silently turning into floats. Bool entries keep the first client's
        value.

    Raises:
        ValueError: if ``weights_list`` is empty, or ``client_weights`` has the
            wrong length or does not sum to a positive value.
    """
    if not weights_list:
        raise ValueError("fed_avg needs at least one client state_dict")
    n_clients = len(weights_list)

    coeffs = None
    if client_weights is not None:
        if len(client_weights) != n_clients:
            raise ValueError("client_weights must have one entry per state_dict")
        total_weight = float(sum(client_weights))
        if total_weight <= 0:
            raise ValueError("client_weights must sum to a positive value")
        coeffs = [float(w) / total_weight for w in client_weights]

    # Deep copy keeps the state_dict type/metadata and lets us accumulate in place.
    w_avg = copy.deepcopy(weights_list[0])
    for key in list(w_avg.keys()):
        first = w_avg[key]
        if first.dtype == torch.bool:
            continue
        is_float = torch.is_floating_point(first) or torch.is_complex(first)

        if coeffs is None:
            acc = first
            for sd in weights_list[1:]:
                acc += sd[key]
            if is_float:
                w_avg[key] = torch.div(acc, n_clients)
            else:
                w_avg[key] = torch.div(acc, n_clients, rounding_mode='trunc')
        else:
            acc = (first if is_float else first.double()) * coeffs[0]
            for coeff, sd in zip(coeffs[1:], weights_list[1:]):
                other = sd[key] if is_float else sd[key].double()
                acc = acc + other * coeff
            w_avg[key] = acc if is_float else acc.trunc().to(first.dtype)
    return w_avg

def validate_global_model(model, val_loader, round_num):
    """Return the mean per-batch validation loss of ``model`` over ``val_loader``.

    Rounds with ``round_num < PRETRAIN_ROUNDS`` are scored with masked MSE.
    Later rounds use ``maskedNLLTest`` when ``ARGS['use_maneuvers']`` is set and
    ``maskedNLL`` otherwise. Returns 0 if the loader yields no batches. The
    model is left in eval mode with ``train_flag = False``.
    """
    model.eval()
    model.train_flag = False
    mse_phase = round_num < PRETRAIN_ROUNDS
    loss_total = 0
    n_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            if ARGS['use_cuda']:
                batch = [t.to(DEVICE) for t in batch]
            hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = batch

            if not ARGS['use_maneuvers']:
                fut_pred = model(hist, nbrs, mask, lat_enc, lon_enc)
                loss_fn = maskedMSE if mse_phase else maskedNLL
                batch_loss = loss_fn(fut_pred, fut, op_mask)
            elif mse_phase:
                # Temporarily switch to train_flag=True so the maneuver model yields a
                # single trajectory usable by maskedMSE (presumably conditioned on the
                # supplied lat_enc / lon_enc -- confirm in model.py).
                model.train_flag = True
                fut_pred, _, _ = model(hist, nbrs, mask, lat_enc, lon_enc)
                batch_loss = maskedMSE(fut_pred, fut, op_mask)
                model.train_flag = False
            else:
                fut_pred, lat_pred, lon_pred = model(hist, nbrs, mask, lat_enc, lon_enc)
                batch_loss = maskedNLLTest(fut_pred, lat_pred, lon_pred, fut, op_mask,
                                           avg_along_time=True)

            loss_total += batch_loss.item()
            n_batches += 1

    return loss_total / n_batches if n_batches > 0 else 0

def _split_indices(total_samples, num_clients):
    """Split ``range(total_samples)`` into ``num_clients`` contiguous index lists.

    Every client gets ``total_samples // num_clients`` samples; the last client
    also receives the remainder so no training sample is dropped.
    """
    if num_clients <= 0:
        raise ValueError("num_clients must be positive")
    split_size = total_samples // num_clients
    parts = []
    for i in range(num_clients):
        start = i * split_size
        end = total_samples if i == num_clients - 1 else start + split_size
        parts.append(list(range(start, end)))
    return parts


def _load_weights(model, checkpoint_path, device):
    """Load a checkpoint saved either as a bare state_dict or as {'state_dict': ...}."""
    state = torch.load(checkpoint_path, map_location=device)
    if isinstance(state, dict) and 'state_dict' in state:
        state = state['state_dict']
    model.load_state_dict(state)


def main():
    """Run federated training (FedAvg) of highwayNet on NGSIM and log to W&B.

    Splits TrainSet into NUM_CLIENTS contiguous shards. Each round, every client
    trains locally, the server averages the weights, validates the global model
    on ValSet, logs metrics and saves trained_models/fl_global_round_<r>.tar.

    NOTE(review): when REUSE_WEIGHTS resumes from a checkpoint, round_num still
    starts at 0. The first PRETRAIN_ROUNDS rounds therefore use MSE again, and
    the saved checkpoints are renumbered from 1.
    """
    # reinit=True: finish any previous run still attached to this kernel.
    wandb.init(
        project="conv-social-pooling-fl",
        reinit=True,
        config={
            "num_clients": NUM_CLIENTS,
            "global_rounds": GLOBAL_ROUNDS,
            "batch_size": BATCH_SIZE,
            "local_epochs": LOCAL_EPOCHS,
            "pretrain_rounds": PRETRAIN_ROUNDS,
            "reuse_weights": REUSE_WEIGHTS,
        }
    )
    try:
        logger.info("Initializing Federated Learning Pipeline...")

        train_dataset = ngsimDataset('data/TrainSet.mat')
        val_dataset_full = ngsimDataset('data/ValSet.mat')

        # Fraction of the validation set evaluated each round (1 = full set).
        VAL_SUBSET_RATIO = 1
        val_len = len(val_dataset_full)
        short_val_len = int(val_len * VAL_SUBSET_RATIO)
        logger.info(f"Shortening validation set: {short_val_len} samples (Original: {val_len})")

        val_subset = Subset(val_dataset_full, list(range(short_val_len)))
        val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False,
                                num_workers=0, collate_fn=val_dataset_full.collate_fn)

        # Contiguous (unshuffled) shards: each client sees a distinct slice of TrainSet.
        clients = [
            FLClient(i, Subset(train_dataset, shard), DEVICE, ARGS)
            for i, shard in enumerate(_split_indices(len(train_dataset), NUM_CLIENTS))
        ]

        global_model = highwayNet(ARGS).to(DEVICE)
        if REUSE_WEIGHTS:
            # NOTE(review): requires Google Drive to be mounted; no mount cell is visible here.
            checkpoint_path = '/content/drive/MyDrive/Fed_traf/fed/fl_global_round_2.tar'  # Or your best model
            _load_weights(global_model, checkpoint_path, DEVICE)
            print("Model loaded successfully.")

        global_weights = global_model.state_dict()

        os.makedirs('trained_models', exist_ok=True)

        # Shared W&B x-axis across rounds; clients are treated as running in parallel.
        global_step_tracker = 0

        for round_num in range(GLOBAL_ROUNDS):
            logger.info(f"--- Global Round {round_num + 1}/{GLOBAL_ROUNDS} ---")

            local_weights_list = []
            local_losses = []
            max_steps_in_round = 0

            for client in clients:
                w_local, loss, steps_taken = client.train(global_weights, round_num, global_step_tracker)
                local_weights_list.append(w_local)
                local_losses.append(loss)
                max_steps_in_round = max(max_steps_in_round, steps_taken)

            # Advance by the longest client run so the next round starts after every client's steps.
            global_step_tracker += max_steps_in_round

            # Aggregation
            global_weights = fed_avg(local_weights_list)
            global_model.load_state_dict(global_weights)

            # Validation
            val_loss = validate_global_model(global_model, val_loader, round_num)

            avg_train_loss = sum(local_losses) / len(local_losses)
            loss_mode = "MSE" if round_num < PRETRAIN_ROUNDS else "NLL"

            logger.info(f"Round {round_num + 1} | Val Loss: {val_loss:.4f} [{loss_mode}]")

            wandb.log({
                "global/val_loss": val_loss,
                "global/avg_train_loss": avg_train_loss,
                "round": round_num + 1,
                "global_step": global_step_tracker
            })

            torch.save(global_weights, f'trained_models/fl_global_round_{round_num+1}.tar')

        logger.info("Federated Learning Complete.")
    finally:
        # Close the W&B run even if training fails or is interrupted.
        wandb.finish()

if __name__ == '__main__':
    main()

0,1
client_0/train_loss,█▄▃▁▃
client_id,▁▁▁▁▁
global_step,▁▃▅▆█
local_step,▁▃▅▆█
round,▁▁▁▁▁

0,1
client_0/train_loss,359.15649
client_id,0.0
global_step,50.0
local_step,49.0
round,1.0


Model loaded successfully.


0,1
client_0/train_loss,█▄▃▃▃▂▂▂▂▂▂▁▁▁
client_1/train_loss,█▃▂▂▂▂▄▂▄▂▁▂▁▁
client_2/train_loss,▇▅▅▂▂█▂▂▁▁▃▄▂▁
client_3/train_loss,▅▂▃█▅▃▁▁▂▂▁▁▂▁
client_4/train_loss,█▄▃▅▆▃▃▂▂▂▃▄▁▂
client_5/train_loss,█▅▄▃▂▂▃▂▂▃▁▁▂▁
client_6/train_loss,█▃▃▂▅▂▂▂▂▂▁▂▂▂
client_7/train_loss,█▃▂▁▁▂▁▁▂▂▂▁▂▂
client_8/train_loss,█▂▂▁▂▂▂▁▂▁▁▂▁▁
client_9/train_loss,█▄▃▄▆▄▃▂▂▂▃▁▂▄

0,1
client_0/train_loss,271.42709
client_1/train_loss,86.61971
client_2/train_loss,62.09343
client_3/train_loss,66.58579
client_4/train_loss,79.26852
client_5/train_loss,133.60329
client_6/train_loss,88.17913
client_7/train_loss,61.65257
client_8/train_loss,56.85877
client_9/train_loss,42.66939


In [None]:
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
import math
from model import highwayNet
from utils import ngsimDataset, maskedMSETest

def _most_likely_maneuver_prediction(fut_pred, lat_pred, lon_pred, num_lat_classes):
    """For each sample, pick the trajectory of its highest-scoring maneuver.

    ``fut_pred`` is the sequence of per-maneuver predictions returned by the
    maneuver model. Each element is shaped like ``fut_pred[0]``, with the batch
    on dim 1 (as the original ``[:, k, :]`` indexing assumed). The prediction for
    lateral class ``lat`` and longitudinal class ``lon`` is
    ``fut_pred[lon * num_lat_classes + lat]``. Returns a tensor shaped like
    ``fut_pred[0]``.
    """
    stacked = torch.stack(list(fut_pred), dim=0)           # (n_maneuvers, T, B, D)
    lat_idx = torch.argmax(lat_pred, dim=1)                 # (B,)
    lon_idx = torch.argmax(lon_pred, dim=1)                 # (B,)
    man_idx = lon_idx * num_lat_classes + lat_idx           # (B,)
    batch_idx = torch.arange(man_idx.shape[0], device=man_idx.device)
    # Advanced indices on dims 0 and 2 are separated by a slice, so their
    # broadcast dimension moves to the front: result is (B, T, D).
    picked = stacked[man_idx, :, batch_idx, :]
    return picked.permute(1, 0, 2).contiguous()             # back to (T, B, D)


def main():
    """Evaluate a trained checkpoint on the NGSIM test split and print RMSE / ADE / FDE.

    Errors are accumulated per prediction step in the data's units. The 0.3048
    factor converts them from feet to meters. ADE at step h is the displacement
    error averaged over all valid steps up to h, and FDE is the error at step h.
    Step h is labelled h / 5 seconds.

    NOTE(review): this cell rebinds `main`, shadowing the federated-training
    `main` defined in the previous cell.
    """
    # --- 1. Configuration ---
    args = {
        'use_cuda': torch.cuda.is_available(),
        'encoder_size': 64,
        'decoder_size': 128,
        'in_length': 16,
        # Prediction horizon in steps. With the 5-steps-per-second labelling used
        # below, 25 steps = 5 s (the old "2.5s (25)" note contradicted the table).
        'out_length': 25,
        'grid_size': (13, 3),
        'soc_conv_depth': 64,
        'conv_3x1_depth': 16,
        'dyn_embedding_size': 32,
        'input_embedding_size': 32,
        'num_lat_classes': 3,
        'num_lon_classes': 2,
        'use_maneuvers': True,
        'train_flag': False,
    }
    steps_per_second = 5      # matches the original h/5 horizon labels
    feet_to_meters = 0.3048

    device = torch.device("cuda" if args['use_cuda'] else "cpu")

    # --- 2. Load Model ---
    print(f"Loading model (expecting out_length={args['out_length']})...")
    net = highwayNet(args).to(device)

    try:
        checkpoint_path = '/content/drive/MyDrive/Fed_traf/central/trained_models/cslstm_central_final.tar'
        state = torch.load(checkpoint_path, map_location=device)
        if 'state_dict' in state:
            net.load_state_dict(state['state_dict'])
        else:
            net.load_state_dict(state)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    net.eval()

    # --- 3. Load Test Data ---
    dataset_path = '/content/fed-conv-social-pooling/data/TestSet_Keep.mat'
    print(f"Loading test data from {dataset_path}...")

    tsSet_full = ngsimDataset(dataset_path)
    tsSet = Subset(tsSet_full, list(range(len(tsSet_full))))
    tsDataloader = DataLoader(tsSet, batch_size=128, shuffle=False, num_workers=8,
                              collate_fn=tsSet_full.collate_fn)

    # --- 4. Evaluation Loop ---
    out_len = args['out_length']
    sq_err_sum = torch.zeros(out_len, device=device)    # summed squared error per step (RMSE)
    sq_err_count = torch.zeros(out_len, device=device)
    disp_sum = torch.zeros(out_len, device=device)      # summed L2 displacement per step (ADE/FDE)
    disp_count = torch.zeros(out_len, device=device)

    print("Running evaluation...")
    with torch.no_grad():
        for data in tsDataloader:
            if args['use_cuda']:
                data = [t.to(device) for t in data]
            hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask = data

            if args['use_maneuvers']:
                fut_pred, lat_pred, lon_pred = net(hist, nbrs, mask, lat_enc, lon_enc)
                pred_to_eval = _most_likely_maneuver_prediction(
                    fut_pred, lat_pred, lon_pred, args['num_lat_classes'])
            else:
                pred_to_eval = net(hist, nbrs, mask, lat_enc, lon_enc)

            # RMSE accumulation
            l, c = maskedMSETest(pred_to_eval, fut, op_mask)
            time_steps = l.shape[0]
            sq_err_sum[:time_steps] += l
            sq_err_count[:time_steps] += c

            # Euclidean distance between predicted (muX, muY) and ground truth, masked per step.
            dist = torch.norm(pred_to_eval[:, :, 0:2] - fut, dim=2)   # [Seq_Len, Batch]
            valid = op_mask[:, :, 0]                                   # [Seq_Len, Batch]
            disp_sum[:time_steps] += torch.sum(dist * valid, dim=1)
            disp_count[:time_steps] += torch.sum(valid, dim=1)

    # --- 5. Formatting Results ---
    rmse_meters = (torch.pow(sq_err_sum / sq_err_count, 0.5) * feet_to_meters).cpu().numpy()
    fde_meters = ((disp_sum / disp_count) * feet_to_meters).cpu().numpy()
    # Cumulative average over all valid steps up to each horizon.
    ade_meters = ((torch.cumsum(disp_sum, dim=0) / torch.cumsum(disp_count, dim=0))
                  * feet_to_meters).cpu().numpy()

    print("\n" + "="*65)
    print("EVALUATION RESULTS (All Metrics in Meters)")
    print("="*65)

    headers, rmse_out, ade_out, fde_out = [], [], [], []
    for h in range(5, out_len + 1, 5):   # every 5 steps, i.e. every second
        idx = h - 1
        headers.append(f"{h / steps_per_second:.1f}s")
        rmse_out.append(f"{rmse_meters[idx]:.3f}")
        ade_out.append(f"{ade_meters[idx]:.3f}")
        fde_out.append(f"{fde_meters[idx]:.3f}")

    print(f"{'Metric':<12} | " + " | ".join([f"{h:<8}" for h in headers]))
    print("-" * (12 + 11 * len(headers)))
    print(f"{'RMSE':<12} | " + " | ".join([f"{v:<8}" for v in rmse_out]))
    print(f"{'ADE':<12} | " + " | ".join([f"{v:<8}" for v in ade_out]))
    print(f"{'FDE':<12} | " + " | ".join([f"{v:<8}" for v in fde_out]))
    print("="*65)

if __name__ == '__main__':
    main()

Loading model (expecting out_length=25)...
Model loaded successfully.
Loading test data from /content/fed-conv-social-pooling/data/TestSet_Keep.mat...
Running evaluation...

EVALUATION RESULTS (All Metrics in Meters)
Metric       | 1.0s     | 2.0s     | 3.0s     | 4.0s     | 5.0s    
-------------------------------------------------------------------
RMSE         | 0.590    | 1.269    | 2.100    | 3.152    | 4.460   
ADE          | 0.228    | 0.464    | 0.729    | 1.035    | 1.392   
FDE          | 0.404    | 0.912    | 1.528    | 2.305    | 3.273   
