In [109]:
import pandas as pd
import numpy as np
import torch
import math
import torch.nn as nn
from tqdm import tqdm
from torch.nn import LayerNorm
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from utils import *
import torch.nn.functional as F

In [3]:
node_coords = pd.read_csv("data/sioux/SiouxFalls_node.tntp", sep='\t')
node_coords_arr = np.array(node_coords[['X', 'Y']])

sioux = create_network_df(network_name="SiouxFalls")
T_0, C = prepare_network_data(sioux)

directory = "/home/podozerovapo/traffic_assignment/data/sioux/uncongested"

inputs = []
outputs = []
metadata = []

for filename in sorted(os.listdir(directory)):
    if filename.endswith(".pkl"):
        filepath = os.path.join(directory, filename)
        
        with open(filepath, 'rb') as f:
            data_pair = pickle.load(f)
            
            inputs.append(data_pair['input'])
            outputs.append(data_pair['output'])
            metadata.append(data_pair.get('metadata', None))

input_matrices = np.array(inputs)  # [num_samples, num_nodes, num_nodes]
output_matrices = np.array(outputs)  # [num_samples, num_nodes, num_nodes]


In [None]:
class FeatureEmbedding(nn.Module):
    def __init__(self, input_size, embedding_size=32):
        super(FeatureEmbedding, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, embedding_size),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.network(x)

class SharedEncoderComponents(nn.Module):
    def __init__(self, node_feat_size, n_heads, dropout=0.1):
        super().__init__()
        self.node_feat_size = node_feat_size
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        
        self.q_linear = nn.Linear(node_feat_size, node_feat_size)
        self.k_linear = nn.Linear(node_feat_size, node_feat_size)
        self.v_linear = nn.Linear(node_feat_size, node_feat_size)
        
        self.out_ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size),
            nn.ReLU(),
            nn.Linear(node_feat_size, node_feat_size)
        )
        
        self.layer_norm = nn.LayerNorm(node_feat_size)
        self.dropout = nn.Dropout(dropout)

class VEncoderLayer(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size, n_heads, dropout=0.1):
        super().__init__()
        self.shared = SharedEncoderComponents(node_feat_size, n_heads, dropout)
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        self.edge_weight_ffn = nn.Sequential(
            nn.Linear(2 * node_feat_size, edge_feat_size),
            nn.ReLU(),
            nn.Linear(edge_feat_size, n_heads),
            nn.Sigmoid()
        )
        
        self.ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size * 2),
            nn.ReLU(),
            nn.Linear(node_feat_size * 2, node_feat_size),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, adj_mask):
        batch_size, num_nodes, _ = x.size()

        q = self.shared.q_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        k = self.shared.k_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        v = self.shared.v_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        src_nodes = x.unsqueeze(2).expand(-1, -1, num_nodes, -1)
        dst_nodes = x.unsqueeze(1).expand(-1, num_nodes, -1, -1)
        node_pairs = torch.cat([src_nodes, dst_nodes], dim=-1)
        beta = self.edge_weight_ffn(node_pairs).permute(0, 3, 1, 2)
        
        adj_mask = adj_mask.unsqueeze(1)
        scores = scores * adj_mask * beta
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.shared.dropout(attn_weights)
        
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, num_nodes, self.shared.node_feat_size)
        
        output = self.shared.out_ffn(output)
        output = self.shared.dropout(output)
        output = self.shared.layer_norm(x + output)
        
        ffn_output = self.ffn(output)
        output = self.shared.layer_norm(output + ffn_output)
        
        return output

class REncoderLayer(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size, n_heads, dropout=0.1):
        super().__init__()
        self.shared = SharedEncoderComponents(node_feat_size, n_heads, dropout)
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        self.edge_feat_transform = nn.Linear(edge_feat_size, n_heads)
        
        self.ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size * 2),
            nn.ReLU(),
            nn.Linear(node_feat_size * 2, node_feat_size),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, adj_mask, edge_features):
        batch_size, num_nodes, _ = x.size()
        
        q = self.shared.q_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        k = self.shared.k_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        v = self.shared.v_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        edge_weights = self.edge_feat_transform(edge_features).permute(0, 3, 1, 2)
    
        adj_mask = adj_mask.unsqueeze(1)
        scores = scores * adj_mask * torch.sigmoid(edge_weights)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.shared.dropout(attn_weights)
        
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, num_nodes, self.shared.node_feat_size)
        
        output = self.shared.out_ffn(output)
        output = self.shared.dropout(output)
        output = self.shared.layer_norm(x + output)
        
        ffn_output = self.ffn(output)
        output = self.shared.layer_norm(output + ffn_output)
        
        return output

class VEncoder(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size, n_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            VEncoderLayer(node_feat_size, edge_feat_size, n_heads, dropout)
            for _ in range(n_layers)
        ])
        
    def forward(self, x, v_adj_mask):
        for layer in self.layers:
            x = layer(x, v_adj_mask)
        return x

class REncoder(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size, n_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            REncoderLayer(node_feat_size, edge_feat_size, n_heads, dropout)
            for _ in range(n_layers)
        ])
        
    def forward(self, x, adj_mask, edge_features):
        for layer in self.layers:
            x = layer(x, adj_mask, edge_features)
        return x

class DualGraphEncoder(nn.Module):
    def __init__(self, node_feat_size, edge_feat_size,
                 v_layers=3, r_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.vencoder = VEncoder(node_feat_size, edge_feat_size, v_layers, n_heads, dropout)
        self.rencoder = REncoder(node_feat_size, edge_feat_size, r_layers, n_heads, dropout)

    def forward(self, node_features, real_adj_mask, virtual_adj_mask, edge_features):
        v_encoded = self.vencoder(node_features, virtual_adj_mask)
        r_encoded = self.rencoder(node_features, real_adj_mask, edge_features)
        return v_encoded, r_encoded

class TrafficAssignmentModel(nn.Module):
    def __init__(self, num_nodes, node_feat_size=32, edge_feat_size=2,
                 v_layers=2, r_layers=2, n_heads=4, dropout=0.1):
        super().__init__()
        
        self.feature_preprocessor = FeatureEmbedding(input_size=num_nodes + 2, embedding_size=32)
        
        self.dual_encoder = DualGraphEncoder(node_feat_size, edge_feat_size, 
                                           v_layers, r_layers, n_heads, dropout)
        
        self.flow_predictor = nn.Sequential(
            nn.Linear(4 * node_feat_size, node_feat_size),
            nn.ReLU(),
            nn.Linear(node_feat_size, 1),
            nn.ReLU()
        )

    def forward(self, node_features, real_adj_mask, virt_adj_mask, edge_features):
        
        node_features = self.feature_preprocessor(node_features)
        v_encoded, r_encoded = self.dual_encoder(node_features, real_adj_mask, virt_adj_mask, edge_features)
        combined = torch.cat([v_encoded, r_encoded], dim=-1)
        src_nodes = combined.unsqueeze(2).expand(-1, -1, combined.size(1), -1)
        dst_nodes = combined.unsqueeze(1).expand(-1, combined.size(1), -1, -1)
        pair_features = torch.cat([src_nodes, dst_nodes], dim=-1)

        flows = self.flow_predictor(pair_features).squeeze(-1)
        flows = flows * real_adj_mask
        
        return flows

In [5]:
coords_for_concat = np.repeat(np.expand_dims(node_coords_arr, 0), repeats=input_matrices.shape[0],axis=0)

In [6]:
X = np.concatenate([input_matrices, coords_for_concat], axis=2)
X.shape

(4096, 24, 26)

In [7]:
capacities = np.repeat(np.expand_dims(C, 0), axis=0, repeats=input_matrices.shape[0])
free_flow_times = np.repeat(np.expand_dims(T_0, 0), axis=0, repeats=input_matrices.shape[0])

In [None]:
class TrafficDataset(Dataset):
    def __init__(self, X_data, capacity_data, free_flow_data, flows_data):
        self.data = []
        for x, cap, fft, fl in zip(X_data, capacity_data, free_flow_data, flows_data):
            od_matrix = x[:, :-2]
            coordinates = x[:, -2:]

            od_mean = od_matrix.mean()
            od_std = od_matrix.std() + 1e-8
            od_matrix = (od_matrix - od_mean) / od_std

            coord_mean = coordinates.mean(axis=0)
            coord_std = coordinates.std(axis=0) + 1e-8
            coordinates = (coordinates - coord_mean) / coord_std

            x_normalized = np.concatenate([od_matrix, coordinates], axis=1)

            cap = (cap - cap.mean()) / (cap.std() + 1e-8)
            fft = (fft - fft.mean()) / (fft.std() + 1e-8)
            fl = (fl - fl.mean()) / (fl.std() + 1e-8)

            self.data.append({
                'od_matrix': torch.FloatTensor(od_matrix),
                'coordinates': torch.FloatTensor(coordinates),
                'capacity': torch.FloatTensor(cap),
                'free_flow': torch.FloatTensor(fft),
                'flows': torch.FloatTensor(fl),
                'x_normalized' : torch.FloatTensor(x_normalized)
            })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        real_adj_mask = (item['capacity'] > 0).float()
        virtual_adj = item['od_matrix'].abs()

        edge_features = torch.stack([
            item['capacity'],
            item['free_flow']
        ], dim=-1)

        agg_demand = item['od_matrix'].sum(dim=1, keepdim=True)
        node_features = torch.cat([item['coordinates'], agg_demand], dim=1)

        return {
            'node_features': node_features,
            'real_adj_mask': real_adj_mask,
            'edge_features': edge_features,
            'virtual_adj': virtual_adj,
            'target_flows': item['flows'],
            'x_normalized' : item['x_normalized']
        }

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=100, lr=1e-3, patience=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)
    criterion = nn.MSELoss()
    
    best_loss = float('inf')
    epochs_no_improve = 0
    train_history = []
    val_history = []
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            node_feat = batch['x_normalized'].to(device)
            real_mask = batch['real_adj_mask'].to(device)
            edge_feat = batch['edge_features'].to(device)
            virt_adj = batch['virtual_adj'].to(device)
            targets = batch['target_flows'].to(device)
            
            optimizer.zero_grad()
            outputs = model(node_feat, real_mask, virt_adj, edge_feat)
            loss = criterion(outputs[real_mask.bool()], targets[real_mask.bool()])
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item() * node_feat.size(0)
        
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                node_feat = batch['x_normalized'].to(device)
                real_mask = batch['real_adj_mask'].to(device)
                edge_feat = batch['edge_features'].to(device)
                virt_adj = batch['virtual_adj'].to(device)
                targets = batch['target_flows'].to(device)
                
                outputs = model(node_feat, real_mask, virt_adj, edge_feat)
                
                loss = criterion(outputs[real_mask.bool()], targets[real_mask.bool()])
                val_loss += loss.item() * node_feat.size(0)
        
        train_loss = train_loss / len(train_loader.dataset)
        val_loss = val_loss / len(val_loader.dataset)
        train_history.append(train_loss)
        val_history.append(val_loss)
        
        scheduler.step(val_loss)
        
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
        print(f'LR: {optimizer.param_groups[0]["lr"]:.2e}')
        
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping after {epoch+1} epochs')
                break
    
    plt.figure(figsize=(10, 6))
    plt.plot(train_history, label='Train Loss')
    plt.plot(val_history, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    plt.savefig('training_history.png')
    plt.close()
    
    model.load_state_dict(torch.load('best_model.pth'))
    return model

In [None]:
def model_predict(model, data_loader, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = model.to(device)
    model.eval()

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Making predictions'):
            node_feat = batch['x_normalized'].to(device)
            real_mask = batch['real_adj_mask'].to(device)
            edge_feat = batch['edge_features'].to(device)
            virt_adj = batch['virtual_adj'].to(device)
            targets = batch['target_flows'].to(device)

            outputs = model(node_feat, real_mask, virt_adj, edge_feat)

            mask = real_mask.bool()
            all_predictions.append(outputs[mask].cpu().numpy())
            all_targets.append(targets[mask].cpu().numpy())
            break

    predictions = np.concatenate(all_predictions)
    targets = np.concatenate(all_targets)

    return predictions, targets

In [107]:
ans = model_predict(model, train_loader)

Making predictions:   0%|          | 0/52 [00:00<?, ?it/s]


In [108]:
print(ans[0], ans[1], sep='\n')

[0.48642147 0.9556702  0.5360018  ... 2.403835   2.7833183  1.0279554 ]
[0.4724989  0.96577114 0.63058925 ... 2.6740074  2.8549805  0.9699092 ]


In [106]:
n_nodes = 24
X_data = X
capacity_data = capacities
free_flow_data = free_flow_times
flows_data = output_matrices

full_dataset = TrafficDataset(X_data, capacity_data, free_flow_data, flows_data)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

model = TrafficAssignmentModel(24)

# Запуск обучения
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=100,
    lr=1e-3,
    patience=10
)

# Сохранение финальной модели
torch.save(trained_model.state_dict(), 'final_model.pth')

  1%|          | 1/100 [00:03<05:46,  3.50s/it]

Epoch 1/100
Train Loss: 1.7905 | Val Loss: 0.8613
LR: 1.00e-03


  2%|▏         | 2/100 [00:06<05:42,  3.49s/it]

Epoch 2/100
Train Loss: 0.6359 | Val Loss: 0.3121
LR: 1.00e-03


  3%|▎         | 3/100 [00:10<05:47,  3.59s/it]

Epoch 3/100
Train Loss: 0.1756 | Val Loss: 0.0997
LR: 1.00e-03


  4%|▍         | 4/100 [00:14<05:37,  3.51s/it]

Epoch 4/100
Train Loss: 0.0990 | Val Loss: 0.0806
LR: 1.00e-03


  5%|▌         | 5/100 [00:17<05:43,  3.61s/it]

Epoch 5/100
Train Loss: 0.0864 | Val Loss: 0.0763
LR: 1.00e-03


  6%|▌         | 6/100 [00:21<05:30,  3.52s/it]

Epoch 6/100
Train Loss: 0.0797 | Val Loss: 0.0718
LR: 1.00e-03


  7%|▋         | 7/100 [00:24<05:26,  3.51s/it]

Epoch 7/100
Train Loss: 0.0746 | Val Loss: 0.0678
LR: 1.00e-03


  8%|▊         | 8/100 [00:28<05:17,  3.45s/it]

Epoch 8/100
Train Loss: 0.0713 | Val Loss: 0.0669
LR: 1.00e-03


  9%|▉         | 9/100 [00:31<05:20,  3.52s/it]

Epoch 9/100
Train Loss: 0.0688 | Val Loss: 0.0634
LR: 1.00e-03


 10%|█         | 10/100 [00:35<05:19,  3.55s/it]

Epoch 10/100
Train Loss: 0.0664 | Val Loss: 0.0617
LR: 1.00e-03


 11%|█         | 11/100 [00:39<05:19,  3.59s/it]

Epoch 11/100
Train Loss: 0.0646 | Val Loss: 0.0603
LR: 1.00e-03


 12%|█▏        | 12/100 [00:42<05:16,  3.60s/it]

Epoch 12/100
Train Loss: 0.0628 | Val Loss: 0.0582
LR: 1.00e-03


 13%|█▎        | 13/100 [00:45<05:06,  3.52s/it]

Epoch 13/100
Train Loss: 0.0608 | Val Loss: 0.0587
LR: 1.00e-03


 14%|█▍        | 14/100 [00:49<05:02,  3.52s/it]

Epoch 14/100
Train Loss: 0.0590 | Val Loss: 0.0551
LR: 1.00e-03


 15%|█▌        | 15/100 [00:53<05:00,  3.54s/it]

Epoch 15/100
Train Loss: 0.0580 | Val Loss: 0.0545
LR: 1.00e-03


 16%|█▌        | 16/100 [00:56<04:47,  3.42s/it]

Epoch 16/100
Train Loss: 0.0561 | Val Loss: 0.0521
LR: 1.00e-03


 17%|█▋        | 17/100 [00:59<04:36,  3.34s/it]

Epoch 17/100
Train Loss: 0.0546 | Val Loss: 0.0526
LR: 1.00e-03


 18%|█▊        | 18/100 [01:02<04:29,  3.29s/it]

Epoch 18/100
Train Loss: 0.0537 | Val Loss: 0.0502
LR: 1.00e-03


 19%|█▉        | 19/100 [01:04<04:03,  3.00s/it]

Epoch 19/100
Train Loss: 0.0519 | Val Loss: 0.0496
LR: 1.00e-03


 20%|██        | 20/100 [01:07<03:50,  2.88s/it]

Epoch 20/100
Train Loss: 0.0510 | Val Loss: 0.0504
LR: 1.00e-03


 21%|██        | 21/100 [01:10<04:01,  3.05s/it]

Epoch 21/100
Train Loss: 0.0507 | Val Loss: 0.0478
LR: 1.00e-03


 22%|██▏       | 22/100 [01:14<04:07,  3.18s/it]

Epoch 22/100
Train Loss: 0.0500 | Val Loss: 0.0481
LR: 1.00e-03


 23%|██▎       | 23/100 [01:17<04:05,  3.18s/it]

Epoch 23/100
Train Loss: 0.0482 | Val Loss: 0.0468
LR: 1.00e-03


 24%|██▍       | 24/100 [01:21<04:09,  3.28s/it]

Epoch 24/100
Train Loss: 0.0475 | Val Loss: 0.0462
LR: 1.00e-03


 25%|██▌       | 25/100 [01:24<04:09,  3.33s/it]

Epoch 25/100
Train Loss: 0.0468 | Val Loss: 0.0456
LR: 1.00e-03


 26%|██▌       | 26/100 [01:28<04:11,  3.40s/it]

Epoch 26/100
Train Loss: 0.0468 | Val Loss: 0.0438
LR: 1.00e-03


 27%|██▋       | 27/100 [01:31<04:09,  3.42s/it]

Epoch 27/100
Train Loss: 0.0460 | Val Loss: 0.0451
LR: 1.00e-03


 28%|██▊       | 28/100 [01:35<04:10,  3.47s/it]

Epoch 28/100
Train Loss: 0.0453 | Val Loss: 0.0444
LR: 1.00e-03


 29%|██▉       | 29/100 [01:38<04:08,  3.51s/it]

Epoch 29/100
Train Loss: 0.0449 | Val Loss: 0.0425
LR: 1.00e-03


 30%|███       | 30/100 [01:42<04:07,  3.53s/it]

Epoch 30/100
Train Loss: 0.0443 | Val Loss: 0.0430
LR: 1.00e-03


 31%|███       | 31/100 [01:45<04:03,  3.53s/it]

Epoch 31/100
Train Loss: 0.0440 | Val Loss: 0.0420
LR: 1.00e-03


 32%|███▏      | 32/100 [01:49<04:04,  3.60s/it]

Epoch 32/100
Train Loss: 0.0437 | Val Loss: 0.0427
LR: 1.00e-03


 33%|███▎      | 33/100 [01:53<04:01,  3.61s/it]

Epoch 33/100
Train Loss: 0.0434 | Val Loss: 0.0413
LR: 1.00e-03


 34%|███▍      | 34/100 [01:56<03:52,  3.52s/it]

Epoch 34/100
Train Loss: 0.0422 | Val Loss: 0.0426
LR: 1.00e-03


 35%|███▌      | 35/100 [02:00<03:47,  3.50s/it]

Epoch 35/100
Train Loss: 0.0419 | Val Loss: 0.0397
LR: 1.00e-03


 36%|███▌      | 36/100 [02:03<03:43,  3.49s/it]

Epoch 36/100
Train Loss: 0.0415 | Val Loss: 0.0405
LR: 1.00e-03


 37%|███▋      | 37/100 [02:06<03:39,  3.49s/it]

Epoch 37/100
Train Loss: 0.0409 | Val Loss: 0.0390
LR: 1.00e-03


 38%|███▊      | 38/100 [02:10<03:37,  3.51s/it]

Epoch 38/100
Train Loss: 0.0404 | Val Loss: 0.0405
LR: 1.00e-03


 39%|███▉      | 39/100 [02:14<03:33,  3.50s/it]

Epoch 39/100
Train Loss: 0.0397 | Val Loss: 0.0395
LR: 1.00e-03


 40%|████      | 40/100 [02:17<03:31,  3.52s/it]

Epoch 40/100
Train Loss: 0.0396 | Val Loss: 0.0379
LR: 1.00e-03


 41%|████      | 41/100 [02:20<03:24,  3.47s/it]

Epoch 41/100
Train Loss: 0.0391 | Val Loss: 0.0402
LR: 1.00e-03


 42%|████▏     | 42/100 [02:24<03:19,  3.43s/it]

Epoch 42/100
Train Loss: 0.0389 | Val Loss: 0.0383
LR: 1.00e-03


 43%|████▎     | 43/100 [02:27<03:15,  3.42s/it]

Epoch 43/100
Train Loss: 0.0382 | Val Loss: 0.0362
LR: 1.00e-03


 44%|████▍     | 44/100 [02:30<03:08,  3.37s/it]

Epoch 44/100
Train Loss: 0.0376 | Val Loss: 0.0365
LR: 1.00e-03


 45%|████▌     | 45/100 [02:34<03:07,  3.40s/it]

Epoch 45/100
Train Loss: 0.0376 | Val Loss: 0.0359
LR: 1.00e-03


 46%|████▌     | 46/100 [02:37<03:06,  3.46s/it]

Epoch 46/100
Train Loss: 0.0368 | Val Loss: 0.0377
LR: 1.00e-03


 47%|████▋     | 47/100 [02:41<03:05,  3.49s/it]

Epoch 47/100
Train Loss: 0.0367 | Val Loss: 0.0353
LR: 1.00e-03


 48%|████▊     | 48/100 [02:45<03:01,  3.48s/it]

Epoch 48/100
Train Loss: 0.0364 | Val Loss: 0.0361
LR: 1.00e-03


 49%|████▉     | 49/100 [02:48<02:59,  3.51s/it]

Epoch 49/100
Train Loss: 0.0362 | Val Loss: 0.0353
LR: 1.00e-03


 50%|█████     | 50/100 [02:52<02:55,  3.51s/it]

Epoch 50/100
Train Loss: 0.0362 | Val Loss: 0.0345
LR: 1.00e-03


 51%|█████     | 51/100 [02:55<02:51,  3.51s/it]

Epoch 51/100
Train Loss: 0.0360 | Val Loss: 0.0367
LR: 1.00e-03


 52%|█████▏    | 52/100 [02:59<02:48,  3.52s/it]

Epoch 52/100
Train Loss: 0.0360 | Val Loss: 0.0340
LR: 1.00e-03


 53%|█████▎    | 53/100 [03:02<02:43,  3.49s/it]

Epoch 53/100
Train Loss: 0.0350 | Val Loss: 0.0337
LR: 1.00e-03


 54%|█████▍    | 54/100 [03:06<02:41,  3.51s/it]

Epoch 54/100
Train Loss: 0.0350 | Val Loss: 0.0337
LR: 1.00e-03


 55%|█████▌    | 55/100 [03:08<02:27,  3.28s/it]

Epoch 55/100
Train Loss: 0.0347 | Val Loss: 0.0348
LR: 1.00e-03


 56%|█████▌    | 56/100 [03:12<02:24,  3.29s/it]

Epoch 56/100
Train Loss: 0.0343 | Val Loss: 0.0330
LR: 1.00e-03


 57%|█████▋    | 57/100 [03:15<02:24,  3.35s/it]

Epoch 57/100
Train Loss: 0.0343 | Val Loss: 0.0331
LR: 1.00e-03


 58%|█████▊    | 58/100 [03:19<02:22,  3.39s/it]

Epoch 58/100
Train Loss: 0.0339 | Val Loss: 0.0327
LR: 1.00e-03


 59%|█████▉    | 59/100 [03:22<02:19,  3.39s/it]

Epoch 59/100
Train Loss: 0.0335 | Val Loss: 0.0326
LR: 1.00e-03


 60%|██████    | 60/100 [03:26<02:16,  3.41s/it]

Epoch 60/100
Train Loss: 0.0335 | Val Loss: 0.0326
LR: 1.00e-03


 61%|██████    | 61/100 [03:29<02:14,  3.45s/it]

Epoch 61/100
Train Loss: 0.0332 | Val Loss: 0.0322
LR: 1.00e-03


 62%|██████▏   | 62/100 [03:33<02:12,  3.49s/it]

Epoch 62/100
Train Loss: 0.0333 | Val Loss: 0.0319
LR: 1.00e-03


 63%|██████▎   | 63/100 [03:36<02:06,  3.43s/it]

Epoch 63/100
Train Loss: 0.0328 | Val Loss: 0.0339
LR: 1.00e-03


 64%|██████▍   | 64/100 [03:39<01:56,  3.24s/it]

Epoch 64/100
Train Loss: 0.0332 | Val Loss: 0.0316
LR: 1.00e-03


 65%|██████▌   | 65/100 [03:42<01:56,  3.34s/it]

Epoch 65/100
Train Loss: 0.0325 | Val Loss: 0.0317
LR: 1.00e-03


 66%|██████▌   | 66/100 [03:45<01:42,  3.01s/it]

Epoch 66/100
Train Loss: 0.0321 | Val Loss: 0.0307
LR: 1.00e-03


 67%|██████▋   | 67/100 [03:48<01:43,  3.14s/it]

Epoch 67/100
Train Loss: 0.0321 | Val Loss: 0.0323
LR: 1.00e-03


 68%|██████▊   | 68/100 [03:51<01:43,  3.22s/it]

Epoch 68/100
Train Loss: 0.0321 | Val Loss: 0.0316
LR: 1.00e-03


 69%|██████▉   | 69/100 [03:55<01:43,  3.33s/it]

Epoch 69/100
Train Loss: 0.0317 | Val Loss: 0.0309
LR: 1.00e-03


 70%|███████   | 70/100 [03:59<01:41,  3.39s/it]

Epoch 70/100
Train Loss: 0.0317 | Val Loss: 0.0323
LR: 5.00e-04


 71%|███████   | 71/100 [04:02<01:38,  3.41s/it]

Epoch 71/100
Train Loss: 0.0309 | Val Loss: 0.0298
LR: 5.00e-04


 72%|███████▏  | 72/100 [04:06<01:36,  3.45s/it]

Epoch 72/100
Train Loss: 0.0305 | Val Loss: 0.0306
LR: 5.00e-04


 73%|███████▎  | 73/100 [04:09<01:33,  3.46s/it]

Epoch 73/100
Train Loss: 0.0305 | Val Loss: 0.0291
LR: 5.00e-04


 74%|███████▍  | 74/100 [04:13<01:30,  3.48s/it]

Epoch 74/100
Train Loss: 0.0302 | Val Loss: 0.0291
LR: 5.00e-04


 75%|███████▌  | 75/100 [04:16<01:27,  3.49s/it]

Epoch 75/100
Train Loss: 0.0301 | Val Loss: 0.0295
LR: 5.00e-04


 76%|███████▌  | 76/100 [04:20<01:23,  3.50s/it]

Epoch 76/100
Train Loss: 0.0302 | Val Loss: 0.0291
LR: 5.00e-04


 77%|███████▋  | 77/100 [04:23<01:20,  3.50s/it]

Epoch 77/100
Train Loss: 0.0299 | Val Loss: 0.0295
LR: 2.50e-04


 78%|███████▊  | 78/100 [04:27<01:17,  3.54s/it]

Epoch 78/100
Train Loss: 0.0296 | Val Loss: 0.0288
LR: 2.50e-04


 79%|███████▉  | 79/100 [04:30<01:13,  3.51s/it]

Epoch 79/100
Train Loss: 0.0295 | Val Loss: 0.0284
LR: 2.50e-04


 80%|████████  | 80/100 [04:33<01:09,  3.47s/it]

Epoch 80/100
Train Loss: 0.0294 | Val Loss: 0.0284
LR: 2.50e-04


 81%|████████  | 81/100 [04:37<01:05,  3.45s/it]

Epoch 81/100
Train Loss: 0.0294 | Val Loss: 0.0285
LR: 2.50e-04


 82%|████████▏ | 82/100 [04:40<01:02,  3.45s/it]

Epoch 82/100
Train Loss: 0.0293 | Val Loss: 0.0284
LR: 2.50e-04


 83%|████████▎ | 83/100 [04:44<00:58,  3.47s/it]

Epoch 83/100
Train Loss: 0.0293 | Val Loss: 0.0285
LR: 2.50e-04


 84%|████████▍ | 84/100 [04:47<00:55,  3.50s/it]

Epoch 84/100
Train Loss: 0.0292 | Val Loss: 0.0282
LR: 2.50e-04


 85%|████████▌ | 85/100 [04:51<00:52,  3.50s/it]

Epoch 85/100
Train Loss: 0.0292 | Val Loss: 0.0282
LR: 2.50e-04


 86%|████████▌ | 86/100 [04:54<00:48,  3.50s/it]

Epoch 86/100
Train Loss: 0.0292 | Val Loss: 0.0281
LR: 2.50e-04


 87%|████████▋ | 87/100 [04:58<00:45,  3.54s/it]

Epoch 87/100
Train Loss: 0.0291 | Val Loss: 0.0283
LR: 2.50e-04


 88%|████████▊ | 88/100 [05:01<00:41,  3.49s/it]

Epoch 88/100
Train Loss: 0.0291 | Val Loss: 0.0279
LR: 2.50e-04


 89%|████████▉ | 89/100 [05:05<00:37,  3.45s/it]

Epoch 89/100
Train Loss: 0.0290 | Val Loss: 0.0282
LR: 2.50e-04


 90%|█████████ | 90/100 [05:08<00:34,  3.49s/it]

Epoch 90/100
Train Loss: 0.0290 | Val Loss: 0.0281
LR: 2.50e-04


 91%|█████████ | 91/100 [05:12<00:30,  3.40s/it]

Epoch 91/100
Train Loss: 0.0289 | Val Loss: 0.0280
LR: 2.50e-04


 92%|█████████▏| 92/100 [05:15<00:27,  3.45s/it]

Epoch 92/100
Train Loss: 0.0289 | Val Loss: 0.0281
LR: 1.25e-04


 93%|█████████▎| 93/100 [05:18<00:23,  3.34s/it]

Epoch 93/100
Train Loss: 0.0287 | Val Loss: 0.0277
LR: 1.25e-04


 94%|█████████▍| 94/100 [05:22<00:20,  3.36s/it]

Epoch 94/100
Train Loss: 0.0286 | Val Loss: 0.0276
LR: 1.25e-04


 95%|█████████▌| 95/100 [05:25<00:16,  3.35s/it]

Epoch 95/100
Train Loss: 0.0286 | Val Loss: 0.0275
LR: 1.25e-04


 96%|█████████▌| 96/100 [05:28<00:13,  3.39s/it]

Epoch 96/100
Train Loss: 0.0286 | Val Loss: 0.0275
LR: 1.25e-04


 97%|█████████▋| 97/100 [05:32<00:10,  3.44s/it]

Epoch 97/100
Train Loss: 0.0285 | Val Loss: 0.0276
LR: 1.25e-04


 98%|█████████▊| 98/100 [05:35<00:06,  3.44s/it]

Epoch 98/100
Train Loss: 0.0286 | Val Loss: 0.0275
LR: 1.25e-04


 99%|█████████▉| 99/100 [05:39<00:03,  3.46s/it]

Epoch 99/100
Train Loss: 0.0285 | Val Loss: 0.0275
LR: 6.25e-05


100%|██████████| 100/100 [05:43<00:00,  3.43s/it]

Epoch 100/100
Train Loss: 0.0284 | Val Loss: 0.0273
LR: 6.25e-05





In [7]:
from tqdm import tqdm

In [103]:
model = FlowDistributionModel(num_nodes, hidden_dim=32).cuda()

In [104]:
def train(model, dataloader, epochs=5):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for batch in dataloader:
            OD, C, T0, targets = batch.values()
            # Move data to GPU
            OD, C, T0, targets = OD.cuda(), C.cuda(), T0.cuda(), targets.cuda()
            
            optimizer.zero_grad()
            preds = model(OD, C, T0)
            loss = F.mse_loss(preds, targets)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(dataloader):.4f}")

train(model, dataloader)

  0%|          | 0/5 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (552x3 and 16x32)

In [80]:
model = FlowDistributionModel(num_nodes=num_nodes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in tqdm(range(100)):
    for batch in dataloader:
        OD_batch, C_batch, T0_batch, target_flows_batch = batch.values()

        optimizer.zero_grad()
        
        pred_flows = model(OD_batch, C_batch, T0_batch)
        loss = loss_fn(pred_flows, target_flows_batch, OD_batch, C_batch)
        
        loss.backward()
        optimizer.step()

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!