In [1]:
# CityCAN

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalLocalAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, local_window_size):
        super(GlobalLocalAttentionEncoder, self).__init__()
        self.local_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.global_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.local_window_size = local_window_size

    def forward(self, x):
        # x: [batch_size, num_regions, seq_len, input_dim]
        batch_size, num_regions, seq_len, input_dim = x.size()
        x = x.view(batch_size * num_regions, seq_len, input_dim)
        x = self.fc(x)  # [batch_size*num_regions, seq_len, hidden_dim]
        
        # Local Attention
        local_outputs = []
        for i in range(seq_len - self.local_window_size + 1):
            local_x = x[:, i:i + self.local_window_size, :]
            local_output, _ = self.local_attn(local_x, local_x, local_x)
            local_outputs.append(local_output[:, -1, :])
        local_outputs = torch.stack(local_outputs, dim=1)
        
        # Global Attention
        global_output, _ = self.global_attn(x, x, x)
        
        return local_outputs, global_output

class CityCAN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, local_window_size, num_blocks):
        super(CityCAN, self).__init__()
        self.blocks = nn.ModuleList([
            GlobalLocalAttentionEncoder(input_dim if i == 0 else hidden_dim, hidden_dim, num_heads, local_window_size)
            for i in range(num_blocks)
        ])
        self.fc_output = nn.Linear(hidden_dim, 1)  # Final output layer
        
    def forward(self, x):
        for block in self.blocks:
            local_outputs, global_output = block(x)
            x = local_outputs + global_output
        
        x = x.mean(dim=1)  # [batch_size, num_regions, hidden_dim]
        out = self.fc_output(x)
        return out

# Loss function
def citywide_loss(pred, target, region_prior):
    mse_loss = F.mse_loss(pred, target)
    cosine_loss = 1 - F.cosine_similarity(pred, target).mean()
    return mse_loss + cosine_loss * region_prior

# Example usage
batch_size = 16
num_regions = 77
seq_len = 6
input_dim = 32
hidden_dim = 128
num_heads = 4
local_window_size = 3
num_blocks = 3

model = CityCAN(input_dim, hidden_dim, num_heads, local_window_size, num_blocks)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example input
x = torch.randn(batch_size, num_regions, seq_len, input_dim)
target = torch.randn(batch_size, num_regions, 1)
region_prior = torch.ones(batch_size, num_regions)  # Example prior

# Training loop
model.train()
for epoch in range(10):
    optimizer.zero_grad()
    pred = model(x)
    loss = citywide_loss(pred, target, region_prior)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")



RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 1