In [2]:
import pandas as pd
import numpy as np
import torch
import math
import torch.nn as nn
from torch.nn import LayerNorm
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from utils import *
import torch.nn.functional as F

In [3]:
node_coords = pd.read_csv("data/sioux/SiouxFalls_node.tntp", sep='\t')
node_coords_arr = np.array(node_coords[['X', 'Y']])

sioux = create_network_df(network_name="SiouxFalls")
T_0, C = prepare_network_data(sioux)

directory = "/home/podozerovapo/traffic_assignment/data/sioux/uncongested"

inputs = []
outputs = []
metadata = []

for filename in sorted(os.listdir(directory)):
    if filename.endswith(".pkl"):
        filepath = os.path.join(directory, filename)
        
        with open(filepath, 'rb') as f:
            data_pair = pickle.load(f)
            
            inputs.append(data_pair['input'])
            outputs.append(data_pair['output'])
            metadata.append(data_pair.get('metadata', None))

input_matrices = np.array(inputs)  # [num_samples, num_nodes, num_nodes]
output_matrices = np.array(outputs)  # [num_samples, num_nodes, num_nodes]


In [80]:
class FeatureEmbedding(nn.Module):
    def __init__(self, input_size, embedding_size=32):
        super(FeatureEmbedding, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, embedding_size),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.network(x)

class SharedEncoderComponents(nn.Module):
    """Shared components between V-Encoder and R-Encoder"""
    def __init__(self, node_feat_size, n_heads, dropout=0.1):
        super().__init__()
        self.node_feat_size = node_feat_size
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        
        # Linear transformations for Q, K, V
        self.q_linear = nn.Linear(node_feat_size, node_feat_size)
        self.k_linear = nn.Linear(node_feat_size, node_feat_size)
        self.v_linear = nn.Linear(node_feat_size, node_feat_size)
        
        # Output transformations
        self.out_ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size),
            nn.ReLU(),
            nn.Linear(node_feat_size, node_feat_size)
        )
        
        self.layer_norm = nn.LayerNorm(node_feat_size)
        self.dropout = nn.Dropout(dropout)

class VEncoderLayer(nn.Module):
    """Virtual Link Encoder Layer"""
    def __init__(self, node_feat_size, edge_feat_size, n_heads, dropout=0.1):
        super().__init__()
        self.shared = SharedEncoderComponents(node_feat_size, n_heads, dropout)
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        # FFN for adaptive virtual edge weights
        self.edge_weight_ffn = nn.Sequential(
            nn.Linear(2 * node_feat_size, edge_feat_size),
            nn.ReLU(),
            nn.Linear(edge_feat_size, n_heads),
            nn.Sigmoid()
        )
        
        # Additional FFN for the layer
        self.ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size * 2),
            nn.ReLU(),
            nn.Linear(node_feat_size * 2, node_feat_size),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, adj_mask):
        batch_size, num_nodes, _ = x.size()
        #print(x.shape) # [64, 24, 3] batch, nodes, 
        
        # Project to Q, K, V
        q = self.shared.q_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        k = self.shared.k_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        v = self.shared.v_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        
        # Prepare for attention
        q = q.transpose(1, 2)  # [batch, heads, nodes, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Compute scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Compute adaptive edge weights for virtual links
        src_nodes = x.unsqueeze(2).expand(-1, -1, num_nodes, -1)
        dst_nodes = x.unsqueeze(1).expand(-1, num_nodes, -1, -1)
        node_pairs = torch.cat([src_nodes, dst_nodes], dim=-1)
        beta = self.edge_weight_ffn(node_pairs).permute(0, 3, 1, 2)  # [batch, heads, nodes, nodes]
        
        # Apply adjacency mask and combine with attention scores
        #print(adj_mask.shape)
        adj_mask = adj_mask.unsqueeze(1)  # Add head dimension
        #print(scores.shape, adj_mask.shape, beta.shape)
        scores = scores * adj_mask * beta  # Apply both mask and adaptive weights
        
        # Apply softmax
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.shared.dropout(attn_weights)
        
        # Apply attention to values
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, num_nodes, self.shared.node_feat_size)
        
        # Residual connection and layer norm
        output = self.shared.out_ffn(output)
        output = self.shared.dropout(output)
        output = self.shared.layer_norm(x + output)
        
        # FFN block
        ffn_output = self.ffn(output)
        output = self.shared.layer_norm(output + ffn_output)
        
        return output

class REncoderLayer(nn.Module):
    """Real Link Encoder Layer"""
    def __init__(self, node_feat_size, edge_feat_size, n_heads, dropout=0.1):
        super().__init__()
        #print('node feat size', node_feat_size)
        self.shared = SharedEncoderComponents(node_feat_size, n_heads, dropout)
        self.n_heads = n_heads
        self.head_dim = node_feat_size // n_heads
        #print(edge_feat_size)
        # Edge feature transformation
        self.edge_feat_transform = nn.Linear(edge_feat_size, n_heads)
        
        # Additional FFN for the layer
        self.ffn = nn.Sequential(
            nn.Linear(node_feat_size, node_feat_size * 2),
            nn.ReLU(),
            nn.Linear(node_feat_size * 2, node_feat_size),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, adj_mask, edge_features):
        batch_size, num_nodes, _ = x.size()
        
        # Project to Q, K, V
        q = self.shared.q_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        k = self.shared.k_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        v = self.shared.v_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        
        # Prepare for attention
        q = q.transpose(1, 2)  # [batch, heads, nodes, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Compute scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Transform edge features
        edge_weights = self.edge_feat_transform(edge_features).permute(0, 3, 1, 2)  # [batch, heads, nodes, nodes]
        
        # Apply adjacency mask and edge features
        adj_mask = adj_mask.unsqueeze(1)  # Add head dimension
        scores = scores * adj_mask * torch.sigmoid(edge_weights)  # Combine with edge features
        
        # Apply softmax
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.shared.dropout(attn_weights)
        
        # Apply attention to values
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, num_nodes, self.shared.node_feat_size)
        
        # Residual connection and layer norm
        output = self.shared.out_ffn(output)
        output = self.shared.dropout(output)
        output = self.shared.layer_norm(x + output)
        
        # FFN block
        ffn_output = self.ffn(output)
        output = self.shared.layer_norm(output + ffn_output)
        
        return output

class VEncoder(nn.Module):
    """Stacked Virtual Link Encoder"""
    def __init__(self, node_feat_size, edge_feat_size, n_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            VEncoderLayer(node_feat_size, edge_feat_size, n_heads, dropout)
            for _ in range(n_layers)
        ])
        
    def forward(self, x, v_adj_mask):
        for layer in self.layers:
            x = layer(x, v_adj_mask)
        return x

class REncoder(nn.Module):
    """Stacked Real Link Encoder"""
    def __init__(self, node_feat_size, edge_feat_size, n_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            REncoderLayer(node_feat_size, edge_feat_size, n_heads, dropout)
            for _ in range(n_layers)
        ])
        
    def forward(self, x, adj_mask, edge_features):
        for layer in self.layers:
            x = layer(x, adj_mask, edge_features)
        return x

class DualGraphEncoder(nn.Module):
    """Complete architecture with both V-Encoder and R-Encoder"""
    def __init__(self, node_feat_size, edge_feat_size,
                 v_layers=3, r_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.vencoder = VEncoder(node_feat_size, edge_feat_size, v_layers, n_heads, dropout)
        self.rencoder = REncoder(node_feat_size, edge_feat_size, r_layers, n_heads, dropout)

    def forward(self, node_features, real_adj_mask, virtual_adj_mask, edge_features):
        """
        Args:
            node_features: Input node features [batch, nodes, feat_size]
            real_adj_mask: Adjacency mask for real edges [batch, nodes, nodes]
            edge_features: Edge features for real links [batch, nodes, nodes, edge_feat_size]
            virtual_adj_mask: Mask for virtual edges [batch, nodes, nodes]
        
        Returns:
            Tuple of (v_encoded, r_encoded) features
        """
        batch_size, num_nodes, _ = node_features.size()
        
        # Create virtual adjacency mask if not provided (self-loops only)
        if virtual_adj_mask is None:
            raise ValueError("no virtual adjacency mask provided")
        
        # Combine real and virtual edges for V-Encoder
        # v_adj_mask = real_adj_mask + virtual_adj_mask
        # v_adj_mask = (v_adj_mask > 0).float()
        
        # Process through V-Encoder (captures long-range dependencies)
        v_encoded = self.vencoder(node_features, virtual_adj_mask)
        
        # Process through R-Encoder (captures local topology)
        r_encoded = self.rencoder(node_features, real_adj_mask, edge_features)
        
        return v_encoded, r_encoded

class TrafficAssignmentModel(nn.Module):
    """Complete model with preprocessing, dual encoders, and flow prediction"""
    def __init__(self, num_nodes, node_feat_size=32, edge_feat_size=2,
                 v_layers=2, r_layers=2, n_heads=4, dropout=0.1):
        super().__init__()
        
        self.feature_preprocessor = FeatureEmbedding(input_size=num_nodes + 2, embedding_size=32)
        
        # Feature preprocessing would be done separately
        self.dual_encoder = DualGraphEncoder(node_feat_size, edge_feat_size, 
                                           v_layers, r_layers, n_heads, dropout)
        
        # Flow prediction head
        self.flow_predictor = nn.Sequential(
            nn.Linear(4 * node_feat_size, node_feat_size),
            nn.ReLU(),
            nn.Linear(node_feat_size, 1),
            nn.ReLU()  # Flow must be non-negative
        )

    def forward(self, node_features, real_adj_mask, virt_adj_mask, edge_features):
        
        node_features = self.feature_preprocessor(node_features)
        #print('!!!!!!', real_adj_mask.shape)
        v_encoded, r_encoded = self.dual_encoder(node_features, real_adj_mask, virt_adj_mask, edge_features)
        #print(v_encoded.shape, r_encoded.shape)
        combined = torch.cat([v_encoded, r_encoded], dim=-1)
        src_nodes = combined.unsqueeze(2).expand(-1, -1, combined.size(1), -1)
        dst_nodes = combined.unsqueeze(1).expand(-1, combined.size(1), -1, -1)
        #print(src_nodes.shape, dst_nodes.shape)
        pair_features = torch.cat([src_nodes, dst_nodes], dim=-1)
        #print(pair_features.shape)
        
        flows = self.flow_predictor(pair_features).squeeze(-1) # [batch, nodes, nodes, 1]
        flows = flows * real_adj_mask
        
        return flows

In [5]:
coords_for_concat = np.repeat(np.expand_dims(node_coords_arr, 0), repeats=input_matrices.shape[0],axis=0)

In [6]:
X = np.concatenate([input_matrices, coords_for_concat], axis=2)
X.shape

(4096, 24, 26)

In [7]:
capacities = np.repeat(np.expand_dims(C, 0), axis=0, repeats=input_matrices.shape[0])
free_flow_times = np.repeat(np.expand_dims(T_0, 0), axis=0, repeats=input_matrices.shape[0])

In [19]:
class TrafficDataset(Dataset):
    def __init__(self, X_data, capacity_data, free_flow_data, flows_data):
        self.data = []
        
        # Нормализация данных
        for x, cap, fft, fl in zip(X_data, capacity_data, free_flow_data, flows_data):
            # Отделяем OD-матрицу (спрос) и координаты
            od_matrix = x[:, :-2]
            coordinates = x[:, -2:]
            
            # Нормализация спроса
            od_mean = od_matrix.mean()
            od_std = od_matrix.std() + 1e-8
            od_matrix = (od_matrix - od_mean) / od_std
    
            # Нормализация координат
            coord_mean = coordinates.mean(axis=0)
            coord_std = coordinates.std(axis=0) + 1e-8
            coordinates = (coordinates - coord_mean) / coord_std
    
            x_normalized = np.concatenate([od_matrix, coordinates], axis=1)
            
            # Нормализация capacity и free_flow_time
            cap = (cap - cap.mean()) / (cap.std() + 1e-8)
            fft = (fft - fft.mean()) / (fft.std() + 1e-8)
            
            # Нормализация целевых потоков
            fl = (fl - fl.mean()) / (fl.std() + 1e-8)
            
            self.data.append({
                'od_matrix': torch.FloatTensor(od_matrix),
                'coordinates': torch.FloatTensor(coordinates),
                'capacity': torch.FloatTensor(cap),
                'free_flow': torch.FloatTensor(fft),
                'flows': torch.FloatTensor(fl),
                'x_normalized' : torch.FloatTensor(x_normalized)
            })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Создание масок с учетом величины спроса
        real_adj_mask = (item['capacity'] > 0).float()
        virtual_adj = item['od_matrix'].abs()  # Учитываем абсолютное значение спроса
        
        # Объединение признаков ребер
        edge_features = torch.stack([
            item['capacity'],
            item['free_flow']
        ], dim=-1)
        
        # Нормализованные node features = координаты + агрегированный спрос
        agg_demand = item['od_matrix'].sum(dim=1, keepdim=True)
        node_features = torch.cat([item['coordinates'], agg_demand], dim=1)
        
        return {
            'node_features': node_features,
            'real_adj_mask': real_adj_mask,
            'edge_features': edge_features,
            'virtual_adj': virtual_adj,
            'target_flows': item['flows'],
            'x_normalized' : item['x_normalized']
        }

# class TrafficAssignmentModel(nn.Module):
#     def __init__(self, node_feat_size=32, edge_feat_size=2,
#                  v_layers=2, r_layers=2, n_heads=4, dropout=0.1):
#         super().__init__()
        
#         # Нормализация входных данных
#         self.node_bn = LayerNorm(3)  # координаты(2) + агрегированный спрос(1)
#         self.edge_bn = LayerNorm(edge_feat_size)
        
#         # Кодировщики
#         self.dual_encoder = DualGraphEncoder(
#             node_feat_size, edge_feat_size, v_layers, r_layers, n_heads, dropout
#         )
        
#         # Предиктор потоков с учетом спроса
#         self.flow_predictor = nn.Sequential(
#             nn.Linear(2*node_feat_size + 1, node_feat_size),  # +1 для спроса
#             nn.ReLU(),
#             LayerNorm(node_feat_size),
#             nn.Linear(node_feat_size, 1),
#             nn.ReLU()
#         )

#     def forward(self, node_features, real_adj_mask, edge_features, virtual_adj):
#         # Нормализация
#         node_features = self.node_bn(node_features)
#         edge_features = self.edge_bn(edge_features)
        
#         # Кодирование
#         v_encoded, r_encoded = self.dual_encoder(
#             node_features, real_adj_mask, edge_features, virtual_adj
#         )
        
#         # Комбинирование признаков с добавлением спроса
#         batch_size, n_nodes, _ = v_encoded.size()
#         src = v_encoded.unsqueeze(2) + r_encoded.unsqueeze(2)
#         dst = v_encoded.unsqueeze(1) + r_encoded.unsqueeze(1)
        
#         # Добавляем информацию о спросе к парам узлов
#         demand = virtual_adj.unsqueeze(-1)  # [batch, n, n, 1]
#         pair_features = torch.cat([
#             src.expand(-1,-1,n_nodes,-1),
#             dst.expand(-1,n_nodes,-1,-1),
#             demand
#         ], dim=-1)
        
#         # Прогнозирование потоков
#         flows = self.flow_predictor(pair_features).squeeze(-1)
#         return flows * real_adj_mask

# class VEncoderLayer(nn.Module):
#     """Virtual Link Encoder Layer"""
#     def __init__(self, node_feat_size, edge_feat_size, n_heads, dropout=0.1):
#         super().__init__()
#         self.shared = SharedEncoderComponents(node_feat_size, n_heads, dropout)
        
#         # FFN for adaptive virtual edge weights
#         self.edge_weight_ffn = nn.Sequential(
#             nn.Linear(2 * node_feat_size, edge_feat_size),
#             nn.ReLU(),
#             nn.Linear(edge_feat_size, n_heads),
#             nn.Sigmoid()
#         )
        
#         # Additional FFN for the layer
#         self.ffn = nn.Sequential(
#             nn.Linear(node_feat_size, node_feat_size * 2),
#             nn.ReLU(),
#             nn.Linear(node_feat_size * 2, node_feat_size),
#             nn.Dropout(dropout)
#         )
        
#     def forward(self, x, adj_mask, demand):
#         batch_size, num_nodes, _ = x.size()
        
#         # Project to Q, K, V
#         q = self.shared.q_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
#         k = self.shared.k_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
#         v = self.shared.v_linear(x).view(batch_size, num_nodes, self.n_heads, self.head_dim)
        
#         # Prepare for attention
#         q = q.transpose(1, 2)  # [batch, heads, nodes, head_dim]
#         k = k.transpose(1, 2)
#         v = v.transpose(1, 2)
        
#         # Compute scaled dot-product attention
#         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
#         # Compute adaptive edge weights for virtual links
#         src_demand = demand.unsqueeze(-1)
#         dst_demand = demand.unsqueeze(-2)
#         demand_pairs = torch.cat([src_demand, dst_demand], dim=-1)
        
#         # Модифицированное вычисление beta с учетом спроса
#         node_pairs = torch.cat([
#             x.unsqueeze(2).expand(-1,-1,num_nodes,-1),
#             x.unsqueeze(1).expand(-1,num_nodes,-1,-1),
#             demand_pairs
#         ], dim=-1)
        
#         beta = self.edge_weight_ffn(node_pairs).permute(0,3,1,2) # [batch, heads, nodes, nodes]
#         # Apply adjacency mask and combine with attention scores
#         adj_mask = adj_mask.unsqueeze(1)  # Add head dimension
#         scores = scores * adj_mask * beta  # Apply both mask and adaptive weights
        
#         # Apply softmax
#         attn_weights = F.softmax(scores, dim=-1)
#         attn_weights = self.shared.dropout(attn_weights)
        
#         # Apply attention to values
#         output = torch.matmul(attn_weights, v)
#         output = output.transpose(1, 2).contiguous()
#         output = output.view(batch_size, num_nodes, self.shared.node_feat_size)
        
#         # Residual connection and layer norm
#         output = self.shared.out_ffn(output)
#         output = self.shared.dropout(output)
#         output = self.shared.layer_norm(x + output)
        
#         # FFN block
#         ffn_output = self.ffn(output)
#         output = self.shared.layer_norm(output + ffn_output)
        
#         return output

# Пример использования с нормализацией и учетом спроса

In [56]:
from tqdm import tqdm

In [84]:
def train_model(model, train_loader, val_loader, num_epochs=100, lr=1e-3, patience=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)
    criterion = nn.MSELoss()
    
    best_loss = float('inf')
    epochs_no_improve = 0
    train_history = []
    val_history = []
    
    for epoch in tqdm(range(num_epochs)):
        #print('train')
        # Training phase
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            # Prepare batch data
            # node_feat = batch['node_features'].to(device)
            node_feat = batch['x_normalized'].to(device)
            real_mask = batch['real_adj_mask'].to(device)
            edge_feat = batch['edge_features'].to(device)
            virt_adj = batch['virtual_adj'].to(device)
            targets = batch['target_flows'].to(device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(node_feat, real_mask, virt_adj, edge_feat)
            
            # Compute loss only on real edges
            loss = criterion(outputs[real_mask.bool()], targets[real_mask.bool()])
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            
            train_loss += loss.item() * node_feat.size(0)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            #print('val')
            for batch in val_loader:
                #node_feat = batch['node_features'].to(device)
                node_feat = batch['x_normalized'].to(device)
                real_mask = batch['real_adj_mask'].to(device)
                edge_feat = batch['edge_features'].to(device)
                virt_adj = batch['virtual_adj'].to(device)
                targets = batch['target_flows'].to(device)
                
                outputs = model(node_feat, real_mask, virt_adj, edge_feat)
                
                loss = criterion(outputs[real_mask.bool()], targets[real_mask.bool()])
                val_loss += loss.item() * node_feat.size(0)
        
        # Calculate epoch metrics
        train_loss = train_loss / len(train_loader.dataset)
        val_loss = val_loss / len(val_loader.dataset)
        train_history.append(train_loss)
        val_history.append(val_loss)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Print progress
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
        print(f'LR: {optimizer.param_groups[0]["lr"]:.2e}')
        
        # Early stopping
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f'Early stopping after {epoch+1} epochs')
                break
    
    # Plot training history
    plt.figure(figsize=(10, 6))
    plt.plot(train_history, label='Train Loss')
    plt.plot(val_history, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    plt.savefig('training_history.png')
    plt.close()
    
    # Load best model
    model.load_state_dict(torch.load('best_model.pth'))
    return model

In [89]:
def model_predict(model, data_loader, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = model.to(device)
    model.eval()
    
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Making predictions'):
            node_feat = batch['x_normalized'].to(device)
            real_mask = batch['real_adj_mask'].to(device)
            edge_feat = batch['edge_features'].to(device)
            virt_adj = batch['virtual_adj'].to(device)
            targets = batch['target_flows'].to(device)
            
            outputs = model(node_feat, real_mask, virt_adj, edge_feat)
            
            mask = real_mask.bool()
            all_predictions.append(outputs[mask].cpu().numpy())
            all_targets.append(targets[mask].cpu().numpy())
            break
    
    predictions = np.concatenate(all_predictions)
    targets = np.concatenate(all_targets)
    
    return predictions, targets

In [91]:
ans = model_predict(model, train_loader)

Making predictions:   0%|          | 0/52 [00:00<?, ?it/s]


In [105]:
print(ans[0], ans[1], sep='\n')

[0.52857995 1.0203347  0.5007929  ... 2.244346   2.7004042  0.9136704 ]
[0.51010585 1.0598823  0.5754481  ... 2.2259939  2.9019318  0.9558847 ]


In [None]:
n_nodes = 24
X_data = X
capacity_data = capacities
free_flow_data = free_flow_times
flows_data = output_matrices

full_dataset = TrafficDataset(X_data, capacity_data, free_flow_data, flows_data)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

model = TrafficAssignmentModel(24)

# Запуск обучения
trained_model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=100,
    lr=1e-3,
    patience=10
)

# Сохранение финальной модели
torch.save(trained_model.state_dict(), 'final_model.pth')

  1%|          | 1/100 [00:03<05:46,  3.50s/it]

Epoch 1/100
Train Loss: 1.7905 | Val Loss: 0.8613
LR: 1.00e-03


  2%|▏         | 2/100 [00:06<05:42,  3.49s/it]

Epoch 2/100
Train Loss: 0.6359 | Val Loss: 0.3121
LR: 1.00e-03


  3%|▎         | 3/100 [00:10<05:47,  3.59s/it]

Epoch 3/100
Train Loss: 0.1756 | Val Loss: 0.0997
LR: 1.00e-03


  4%|▍         | 4/100 [00:14<05:37,  3.51s/it]

Epoch 4/100
Train Loss: 0.0990 | Val Loss: 0.0806
LR: 1.00e-03


  5%|▌         | 5/100 [00:17<05:43,  3.61s/it]

Epoch 5/100
Train Loss: 0.0864 | Val Loss: 0.0763
LR: 1.00e-03


  6%|▌         | 6/100 [00:21<05:30,  3.52s/it]

Epoch 6/100
Train Loss: 0.0797 | Val Loss: 0.0718
LR: 1.00e-03


  7%|▋         | 7/100 [00:24<05:26,  3.51s/it]

Epoch 7/100
Train Loss: 0.0746 | Val Loss: 0.0678
LR: 1.00e-03


  8%|▊         | 8/100 [00:28<05:17,  3.45s/it]

Epoch 8/100
Train Loss: 0.0713 | Val Loss: 0.0669
LR: 1.00e-03


  9%|▉         | 9/100 [00:31<05:20,  3.52s/it]

Epoch 9/100
Train Loss: 0.0688 | Val Loss: 0.0634
LR: 1.00e-03


 10%|█         | 10/100 [00:35<05:19,  3.55s/it]

Epoch 10/100
Train Loss: 0.0664 | Val Loss: 0.0617
LR: 1.00e-03


 11%|█         | 11/100 [00:39<05:19,  3.59s/it]

Epoch 11/100
Train Loss: 0.0646 | Val Loss: 0.0603
LR: 1.00e-03


 12%|█▏        | 12/100 [00:42<05:16,  3.60s/it]

Epoch 12/100
Train Loss: 0.0628 | Val Loss: 0.0582
LR: 1.00e-03


 13%|█▎        | 13/100 [00:45<05:06,  3.52s/it]

Epoch 13/100
Train Loss: 0.0608 | Val Loss: 0.0587
LR: 1.00e-03


 14%|█▍        | 14/100 [00:49<05:02,  3.52s/it]

Epoch 14/100
Train Loss: 0.0590 | Val Loss: 0.0551
LR: 1.00e-03


 15%|█▌        | 15/100 [00:53<05:00,  3.54s/it]

Epoch 15/100
Train Loss: 0.0580 | Val Loss: 0.0545
LR: 1.00e-03


 16%|█▌        | 16/100 [00:56<04:47,  3.42s/it]

Epoch 16/100
Train Loss: 0.0561 | Val Loss: 0.0521
LR: 1.00e-03


 17%|█▋        | 17/100 [00:59<04:36,  3.34s/it]

Epoch 17/100
Train Loss: 0.0546 | Val Loss: 0.0526
LR: 1.00e-03


 18%|█▊        | 18/100 [01:02<04:29,  3.29s/it]

Epoch 18/100
Train Loss: 0.0537 | Val Loss: 0.0502
LR: 1.00e-03


 19%|█▉        | 19/100 [01:04<04:03,  3.00s/it]

Epoch 19/100
Train Loss: 0.0519 | Val Loss: 0.0496
LR: 1.00e-03


 20%|██        | 20/100 [01:07<03:50,  2.88s/it]

Epoch 20/100
Train Loss: 0.0510 | Val Loss: 0.0504
LR: 1.00e-03


 21%|██        | 21/100 [01:10<04:01,  3.05s/it]

Epoch 21/100
Train Loss: 0.0507 | Val Loss: 0.0478
LR: 1.00e-03


 22%|██▏       | 22/100 [01:14<04:07,  3.18s/it]

Epoch 22/100
Train Loss: 0.0500 | Val Loss: 0.0481
LR: 1.00e-03


 23%|██▎       | 23/100 [01:17<04:05,  3.18s/it]

Epoch 23/100
Train Loss: 0.0482 | Val Loss: 0.0468
LR: 1.00e-03


 24%|██▍       | 24/100 [01:21<04:09,  3.28s/it]

Epoch 24/100
Train Loss: 0.0475 | Val Loss: 0.0462
LR: 1.00e-03


 25%|██▌       | 25/100 [01:24<04:09,  3.33s/it]

Epoch 25/100
Train Loss: 0.0468 | Val Loss: 0.0456
LR: 1.00e-03


 26%|██▌       | 26/100 [01:28<04:11,  3.40s/it]

Epoch 26/100
Train Loss: 0.0468 | Val Loss: 0.0438
LR: 1.00e-03


 27%|██▋       | 27/100 [01:31<04:09,  3.42s/it]

Epoch 27/100
Train Loss: 0.0460 | Val Loss: 0.0451
LR: 1.00e-03


 28%|██▊       | 28/100 [01:35<04:10,  3.47s/it]

Epoch 28/100
Train Loss: 0.0453 | Val Loss: 0.0444
LR: 1.00e-03


 29%|██▉       | 29/100 [01:38<04:08,  3.51s/it]

Epoch 29/100
Train Loss: 0.0449 | Val Loss: 0.0425
LR: 1.00e-03


 30%|███       | 30/100 [01:42<04:07,  3.53s/it]

Epoch 30/100
Train Loss: 0.0443 | Val Loss: 0.0430
LR: 1.00e-03


 31%|███       | 31/100 [01:45<04:03,  3.53s/it]

Epoch 31/100
Train Loss: 0.0440 | Val Loss: 0.0420
LR: 1.00e-03


 32%|███▏      | 32/100 [01:49<04:04,  3.60s/it]

Epoch 32/100
Train Loss: 0.0437 | Val Loss: 0.0427
LR: 1.00e-03


 33%|███▎      | 33/100 [01:53<04:01,  3.61s/it]

Epoch 33/100
Train Loss: 0.0434 | Val Loss: 0.0413
LR: 1.00e-03


 34%|███▍      | 34/100 [01:56<03:52,  3.52s/it]

Epoch 34/100
Train Loss: 0.0422 | Val Loss: 0.0426
LR: 1.00e-03


 35%|███▌      | 35/100 [02:00<03:47,  3.50s/it]

Epoch 35/100
Train Loss: 0.0419 | Val Loss: 0.0397
LR: 1.00e-03


 36%|███▌      | 36/100 [02:03<03:43,  3.49s/it]

Epoch 36/100
Train Loss: 0.0415 | Val Loss: 0.0405
LR: 1.00e-03


 37%|███▋      | 37/100 [02:06<03:39,  3.49s/it]

Epoch 37/100
Train Loss: 0.0409 | Val Loss: 0.0390
LR: 1.00e-03


 38%|███▊      | 38/100 [02:10<03:37,  3.51s/it]

Epoch 38/100
Train Loss: 0.0404 | Val Loss: 0.0405
LR: 1.00e-03


 39%|███▉      | 39/100 [02:14<03:33,  3.50s/it]

Epoch 39/100
Train Loss: 0.0397 | Val Loss: 0.0395
LR: 1.00e-03


 40%|████      | 40/100 [02:17<03:31,  3.52s/it]

Epoch 40/100
Train Loss: 0.0396 | Val Loss: 0.0379
LR: 1.00e-03


 41%|████      | 41/100 [02:20<03:24,  3.47s/it]

Epoch 41/100
Train Loss: 0.0391 | Val Loss: 0.0402
LR: 1.00e-03


 42%|████▏     | 42/100 [02:24<03:19,  3.43s/it]

Epoch 42/100
Train Loss: 0.0389 | Val Loss: 0.0383
LR: 1.00e-03


 43%|████▎     | 43/100 [02:27<03:15,  3.42s/it]

Epoch 43/100
Train Loss: 0.0382 | Val Loss: 0.0362
LR: 1.00e-03


 44%|████▍     | 44/100 [02:30<03:08,  3.37s/it]

Epoch 44/100
Train Loss: 0.0376 | Val Loss: 0.0365
LR: 1.00e-03


 45%|████▌     | 45/100 [02:34<03:07,  3.40s/it]

Epoch 45/100
Train Loss: 0.0376 | Val Loss: 0.0359
LR: 1.00e-03


 46%|████▌     | 46/100 [02:37<03:06,  3.46s/it]

Epoch 46/100
Train Loss: 0.0368 | Val Loss: 0.0377
LR: 1.00e-03


 47%|████▋     | 47/100 [02:41<03:05,  3.49s/it]

Epoch 47/100
Train Loss: 0.0367 | Val Loss: 0.0353
LR: 1.00e-03


 48%|████▊     | 48/100 [02:45<03:01,  3.48s/it]

Epoch 48/100
Train Loss: 0.0364 | Val Loss: 0.0361
LR: 1.00e-03


 49%|████▉     | 49/100 [02:48<02:59,  3.51s/it]

Epoch 49/100
Train Loss: 0.0362 | Val Loss: 0.0353
LR: 1.00e-03


 50%|█████     | 50/100 [02:52<02:55,  3.51s/it]

Epoch 50/100
Train Loss: 0.0362 | Val Loss: 0.0345
LR: 1.00e-03


 51%|█████     | 51/100 [02:55<02:51,  3.51s/it]

Epoch 51/100
Train Loss: 0.0360 | Val Loss: 0.0367
LR: 1.00e-03


 52%|█████▏    | 52/100 [02:59<02:48,  3.52s/it]

Epoch 52/100
Train Loss: 0.0360 | Val Loss: 0.0340
LR: 1.00e-03


 53%|█████▎    | 53/100 [03:02<02:43,  3.49s/it]

Epoch 53/100
Train Loss: 0.0350 | Val Loss: 0.0337
LR: 1.00e-03


 54%|█████▍    | 54/100 [03:06<02:41,  3.51s/it]

Epoch 54/100
Train Loss: 0.0350 | Val Loss: 0.0337
LR: 1.00e-03


In [66]:
104*64 / 2 / 0.8

4160.0

In [62]:
s = """torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([64, 4, 24, 24]) torch.Size([64, 1, 24, 24]) torch.Size([64, 4, 24, 24])
torch.Size([12, 4, 24, 24]) torch.Size([12, 1, 24, 24]) torch.Size([12, 4, 24, 24])
torch.Size([12, 4, 24, 24]) torch.Size([12, 1, 24, 24]) torch.Size([12, 4, 24, 24])"""
s.count('torch') / 3

104.0

In [43]:
36864 / 64 / 24

24.0

In [38]:
49152 / 64 / 24 / 4

8.0

In [32]:
768 / 24

32.0

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import *

In [102]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv  # Using PyG for optimized GNN operations

class FlowDistributionModel(nn.Module):
    def __init__(self, num_nodes, hidden_dim=128, num_heads=2):
        super().__init__()
        self.num_nodes = num_nodes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Input feature encoder (3 features: OD, C, T0)
        # self.encoder = nn.Sequential(
        #     nn.Linear(3, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, hidden_dim)
        #     #nn.Identity()
        # ).to(self.device)
        
        # Graph Attention Layers (using PyG for better GPU utilization)
        self.gat1 = GATConv(3, 16, heads=num_heads).to(self.device)
        self.gat2 = GATConv(16, 16, heads=num_heads).to(self.device)
        self.gat3 = GATConv(32, 1, heads=1).to(self.device)
        
        # Output decoder
        # self.decoder = nn.Sequential(
        #     nn.Linear(hidden_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, 1)
        # ).to(self.device)
        
        # Prepare static edge index for complete graph
        self.edge_index = self._create_complete_graph_edge_index(num_nodes).to(self.device)
    
    def _create_complete_graph_edge_index(self, num_nodes):
        """Create edge index for fully connected graph"""
        adj = torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes)
        edge_index = adj.nonzero(as_tuple=False).t()  # Get indices and transpose
        return edge_index
    
    def forward(self, OD, C, T0):
        batch_size = OD.shape[0]
        
        # Move inputs to GPU
        OD, C, T0 = OD.to(self.device), C.to(self.device), T0.to(self.device)
        
        # Prepare edge features [batch_size * num_edges, 3]
        edge_features = torch.stack([OD, C, T0], dim=-1)  # [batch, N, N, 3]
        edge_features = edge_features.view(-1, 3)  # [batch*N*N, 3]
        
        x = edge_features
        
        # Process each graph in batch separately (using vmap would be better)
        outputs = []
        for i in range(batch_size):
            # Get features for this graph
            start_idx = i * self.num_nodes * self.num_nodes
            end_idx = (i+1) * self.num_nodes * self.num_nodes
            x_graph = x[start_idx:end_idx][self.edge_index[0] * self.num_nodes + self.edge_index[1]]
            
            # GAT layers
            h = F.softplus(self.gat1(x_graph, self.edge_index))
            h = F.softplus(self.gat2(x_graph, self.edge_index))
            h = F.softplus(self.gat3(h, self.edge_index))
            
            # Decode to flows
            flows = torch.abs(h.squeeze(-1))
            
            # Reshape to adjacency matrix
            flow_mat = torch.zeros(self.num_nodes, self.num_nodes, device=self.device)
            flow_mat[self.edge_index[0], self.edge_index[1]] = flows
            outputs.append(flow_mat)
        
        return torch.stack(outputs)  # [batch, N, N]

In [57]:
def loss_fn(pred_flows, target_flows, OD, C):
    # Основная MSE потеря
    mse_loss = F.mse_loss(pred_flows, target_flows)
    
    # Потеря на сохранение потока (входной поток = выходной)
    inflow = pred_flows.sum(dim=1)  # [batch, num_nodes]
    outflow = pred_flows.sum(dim=2)  # [batch, num_nodes]
    flow_conservation_loss = F.mse_loss(inflow, outflow)
    
    # Потеря на превышение пропускной способности
    capacity_violation = torch.relu(pred_flows - C).mean()
    
    # Комбинированная потеря
    total_loss = mse_loss + 2 * flow_conservation_loss + 2 * capacity_violation
    return total_loss

In [75]:
class TrafficDataset(torch.utils.data.Dataset):
    """
    Dataset class for traffic assignment problem
    """
    def __init__(self, od_matrices, capacities, free_flow_times, link_flows):
        """
        Args:
            od_matrices: List/array of OD demand matrices [num_samples, num_nodes, num_nodes]
            capacities: List/array of capacity matrices [num_samples, num_edges]
            free_flow_times: List/array of free-flow time matrices [num_samples, num_nodes, num_nodes]
            link_flows: List/array of ground truth link flows [num_samples, num_edges]
        """
        self.nodes_num = len(od_matrices)
        self.od_matrices = torch.tensor(od_matrices, dtype=torch.float32)
        self.capacities = torch.tensor([capacities for _ in range(self.nodes_num)], dtype=torch.float32)
        self.free_flow_times = torch.tensor([free_flow_times for _ in range(self.nodes_num)], dtype=torch.float32)
        self.link_flows = torch.tensor(link_flows, dtype=torch.float32)
        
    def __len__(self):
        return len(self.od_matrices)
    
    def __getitem__(self, idx):
        return {
            'od_matrix': self.od_matrices[idx],
            'capacity': self.capacities[idx],
            'free_flow_time': self.free_flow_times[idx],
            'link_flows': self.link_flows[idx]
        }
    
    def get_batches(self, batch_size):
        dataloader = torch.utils.data.DataLoader(
            self, batch_size=batch_size, shuffle=True
        )
        return dataloader

In [7]:
sioux = create_network_df(network_name="SiouxFalls")
T_0, C = prepare_network_data(sioux)

directory = "/home/podozerovapo/traffic_assignment/data/sioux/uncongested"

inputs = []
outputs = []
metadata = []

for filename in sorted(os.listdir(directory)):
    if filename.endswith(".pkl"):
        filepath = os.path.join(directory, filename)
        
        with open(filepath, 'rb') as f:
            data_pair = pickle.load(f)
            
            inputs.append(data_pair['input'])
            outputs.append(data_pair['output'])
            metadata.append(data_pair.get('metadata', None))

input_matrices = np.array(inputs)  # [num_samples, num_nodes, num_nodes]
output_matrices = np.array(outputs)  # [num_samples, num_nodes, num_nodes]

print(f"Loaded {len(inputs)} samples")
print(f"Input shape: {input_matrices.shape}")
print(f"Output shape: {output_matrices.shape}")

dataset = TrafficDataset(input_matrices, C, T_0, output_matrices)
dataloader = dataset.get_batches(batch_size=128)
num_nodes = len(dataset[0]['od_matrix'])
num_edges = (dataset[0]['capacity'] > 0).sum()

Loaded 4096 samples
Input shape: (4096, 24, 24)
Output shape: (4096, 24, 24)


AttributeError: 'TrafficDataset' object has no attribute 'get_batches'

In [18]:

for i in dataloader:
    a = i
    break

In [22]:
a['od_matrix'].shape

torch.Size([128, 24, 24])

In [7]:
from tqdm import tqdm

In [103]:
model = FlowDistributionModel(num_nodes, hidden_dim=32).cuda()

In [104]:
def train(model, dataloader, epochs=5):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for batch in dataloader:
            OD, C, T0, targets = batch.values()
            # Move data to GPU
            OD, C, T0, targets = OD.cuda(), C.cuda(), T0.cuda(), targets.cuda()
            
            optimizer.zero_grad()
            preds = model(OD, C, T0)
            loss = F.mse_loss(preds, targets)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(dataloader):.4f}")

train(model, dataloader)

  0%|          | 0/5 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (552x3 and 16x32)

In [80]:
model = FlowDistributionModel(num_nodes=num_nodes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in tqdm(range(100)):
    for batch in dataloader:
        OD_batch, C_batch, T0_batch, target_flows_batch = batch.values()

        optimizer.zero_grad()
        
        pred_flows = model(OD_batch, C_batch, T0_batch)
        loss = loss_fn(pred_flows, target_flows_batch, OD_batch, C_batch)
        
        loss.backward()
        optimizer.step()

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [39]:
dataset[0].keys()

dict_keys(['od_matrix', 'capacity', 'free_flow_time', 'link_flows'])

In [81]:
OD, C, T_0, true = dataset[11].values()

In [82]:
OD.shape

torch.Size([24, 24])

In [101]:

pred_flows = model(OD.unsqueeze(0), C.unsqueeze(0), T_0.unsqueeze(0))[0]

# Loss calculation
pred_flows[true==0] = 0
loss = F.mse_loss(pred_flows, true.unsqueeze(0).cuda())
print(f"Initial loss: {loss.item():.4f}")

Initial loss: 894766.4375


  loss = F.mse_loss(pred_flows, true.unsqueeze(0).cuda())


In [84]:
torch.sum(pred_flows), torch.sum(true)

(tensor(380693.3750, device='cuda:0', grad_fn=<SumBackward0>), tensor(269510.))

In [6]:
torch.abs(pred_flows.sum(1) - true.sum(1))

NameError: name 'pred_flows' is not defined

In [96]:
true[true==0].to(int).sum()

tensor(0)

In [97]:
pred_flows[0][true==0].to(int).sum()

tensor(158915, device='cuda:0')

In [85]:
true[3].to(int)

tensor([   0,    0, 3056,    0, 4744,    0,    0,    0,    0,    0, 2630,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0])

In [86]:
pred_flows[0][3].to(int)

tensor([  10,    8, 4703,    0, 4910,   21,    7,   22,   29,   34, 1390,   16,
          17,   31,   33,   47,   23,    4,    5,    6,    8,    8,   25,    9],
       device='cuda:0')

In [51]:
torch.abs(true.cuda().to(int) - pred_flows[0].to(int))

tensor([[   0, 1175, 2180,  165,  165,  165,  165,  165,  165,  165,  165,  165,
          165,  165,  165,  165,  165,  165,  165,  165,  165,  165,  165,  165],
        [ 988,    0,    2,    1,    3,  903,    1,    0,    1,    0,    1,    3,
            0,    2,    2,    0,    2,    0,    2,    2,    0,    4,    0,    0],
        [ 957,    2,    0, 1342,    2,    1,    2,    2,    2,    1,    0, 1162,
            2,    2,    2,    2,    3,    0,    0,    0,    0,    2,    3,    0],
        [   1,    1,  452,    0, 1305,    0,    1,    0,    1,    2,  509,    0,
            0,    1,    1,    4,    0,    2,    2,    2,    1,    1,    0,    1],
        [   1,    2,    2, 1756,    0, 3106,    1,    0,  322,    5,    1,    3,
            2,    2,    1,    0,    2,    0,    2,    3,    2,    1,    4,    0],
        [   0,  753,    1,    1, 2486,    0,    0, 3291,    0,    0,    0,    2,
            1,    2,    2,    2,    0,    2,    2,    0,    3,    1,    2,    2],
        [   1,    1,  