In [None]:
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

def exists(val):
    """Return True when *val* is anything other than ``None``."""
    return not (val is None)

def uniq(arr):
    """Return the distinct elements of *arr* in first-seen order.

    The result is a ``dict_keys`` view (not a list): it keeps insertion
    order and supports fast membership tests.
    """
    return dict.fromkeys(arr, True).keys()

def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    When the fallback *d* is a plain Python function (per
    ``inspect.isfunction``) it is called with no arguments, so costly
    defaults are only built when actually needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d

def max_neg_value(t):
    """Return the most negative finite value representable in ``t.dtype``.

    Useful as a finite stand-in for minus infinity when masking logits.
    """
    info = torch.finfo(t.dtype)
    return info.max * -1

def init_(tensor):
    """Fill *tensor* in place from U(-1/sqrt(d), 1/sqrt(d)), d = last-dim size.

    The fill runs under ``torch.no_grad()`` so the helper also works on
    ``nn.Parameter`` leaves; an in-place op on a leaf that requires grad
    otherwise raises ``RuntimeError``.

    Returns:
        The same tensor object, to allow chaining.
    """
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    with torch.no_grad():
        tensor.uniform_(-std, std)
    return tensor

# 辅助模块
class GEGLU(nn.Module):
    """Gated-GELU projection: ``value * gelu(gate)``.

    A single linear layer of width ``2 * dim_out`` yields both halves.

    Args:
        dim_in: size of the input's last dimension.
        dim_out: size of the output's last dimension.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One fused projection produces the value and the gate together.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)

class FeedForward(nn.Module):
    """Position-wise MLP: (Linear + GELU | GEGLU) -> Dropout -> Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to ``dim`` when None.
        mult: hidden-width multiplier (hidden = ``int(dim * mult)``).
        glu: use a GEGLU input projection instead of Linear + GELU.
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)

        if glu:
            expand = GEGLU(dim, hidden)
        else:
            expand = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        # Layout 0: expand, 1: dropout, 2: contract -- keeps state_dict keys stable.
        self.net = nn.Sequential(
            expand,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

class PreNorm(nn.Module):
    """LayerNorm the input over its last dimension, then call the wrapped module.

    Keyword arguments given to ``forward`` are passed through to ``fn``
    unchanged (e.g. ``context=`` / ``mask=`` for attention layers).
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class TimeShiftedMultiModalAttention(nn.Module):
    """Multi-head (cross-)attention with a learnable time-lag bias.

    Queries come from ``x``; keys/values come from ``context`` (or from ``x``
    when no context is given). Before the softmax, logit ``(i, j)`` receives
    ``lag_weights[clamp(i - j, 0, max_time_lag)]``; the weights are shared by
    all heads.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size.
        max_time_lag: largest distinct lag; longer lags share the last weight.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, max_time_lag=3, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.max_time_lag = max_time_lag

        # Learnable lag bias, one scalar per lag 0..max_time_lag.
        self.lag_weights = nn.Parameter(torch.randn(max_time_lag + 1))

        # Standard Q/K/V projections.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: 3-D query tensor ``[b, len_q, query_dim]``.
            context: 3-D key/value tensor ``[b, len_k, context_dim]``;
                defaults to ``x``.
            mask: optional key mask, ``[b, len_k]`` (trailing dims are
                flattened); truthy entries are kept, falsy ones are masked.

        Returns:
            ``(out, attn)``: the projected output ``[b, len_q, query_dim]``
            and the detached attention weights ``[b * heads, len_q, len_k]``.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # [batch, time, (heads dim)] -> [(batch heads), time, dim]
        q, k, v = map(lambda t: rearrange(t, 'b t (h d) -> (b h) t d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        # Lag between query step i and key step j, on the real (len_q, len_k)
        # grid so cross-attention with len_q != len_k broadcasts correctly.
        len_q, len_k = sim.shape[-2], sim.shape[-1]
        rows = torch.arange(len_q, device=sim.device).view(-1, 1)
        cols = torch.arange(len_k, device=sim.device).view(1, -1)
        # Keys later than the query (j > i) clamp to lag 0 and share its weight.
        time_lags = (rows - cols).clamp(min=0, max=self.max_time_lag)
        # NOTE(review): with a single query step (len_q == 1) every key gets
        # lag 0, so the bias is constant per row and cancels in the softmax.
        sim = sim + self.lag_weights[time_lags]  # broadcast over (b h)

        if exists(mask):
            # Coerce to bool so 0/1 integer or float masks also work with `~`.
            mask = rearrange(mask, 'b ... -> b (...)').bool()
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim = sim.masked_fill(~mask, max_neg_value(sim))

        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) t d -> b t (h d)', h=h)
        return self.to_out(out), attn.detach()
        
# 空间注意力（保持不变）
class SpatialAttention(nn.Module):
    """Multi-head self-attention over the token axis of a ``[b, n, dim]`` input.

    The output projection is replaced by ``nn.Identity`` only when a single
    head already has width ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Fused projection yielding queries, keys and values in one matmul.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if heads != 1 or dim_head != dim:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # Unpacking also enforces a 3-D input.
        _batch, _tokens, _features = x.shape
        heads = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=-1))

        logits = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = logits.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

# 时间注意力（保持不变）
class TemporalAttention(nn.Module):
    """Multi-head self-attention over time with an optional additive logit bias.

    When ``bias`` is given, it is projected with the same ``to_qkv`` layer,
    and its query/key dot products are added to the attention logits.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if heads != 1 or dim_head != dim:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, bias=None):
        # Unpacking also enforces a 3-D input.
        _batch, _steps, _features = x.shape
        heads = self.heads

        def to_heads(t):
            return rearrange(t, 'b t (h d) -> b h t d', h=heads)

        q, k, v = map(to_heads, self.to_qkv(x).chunk(3, dim=-1))
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if bias is not None:
            # Only the query/key parts of the projected bias contribute.
            bq, bk, _ = map(to_heads, self.to_qkv(bias).chunk(3, dim=-1))
            dots += einsum('b h i d, b h j d -> b h i j', bq, bk) * self.scale

        attn = dots.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 修改后的多模态Transformer
class MultiModalTransformer(nn.Module):
    """Stack of pre-norm time-shifted cross-attention + feed-forward layers.

    Each layer computes ``x = attn(norm(x), context) + x`` and then
    ``x = ff(norm(x)) + x``; a final LayerNorm is applied to the result.
    """

    def __init__(self, dim, depth, heads, dim_head, context_dim=9, max_time_lag=3, mult=4, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, TimeShiftedMultiModalAttention(
                    dim,
                    context_dim=context_dim,
                    heads=heads,
                    dim_head=dim_head,
                    max_time_lag=max_time_lag,
                    dropout=dropout
                )),
                PreNorm(dim, FeedForward(dim, dim_out=dim, mult=mult, dropout=dropout))
            ]))

    def forward(self, x, context=None, mask=None):
        """Run every layer.

        Returns:
            ``(norm(x), attn)``, where ``attn`` is the detached attention map
            of the last layer, or ``None`` when there are no layers
            (``depth == 0``).
        """
        last_attn = None
        for attn_layer, ff_layer in self.layers:
            # Keep the layer and its returned weights under distinct names.
            attn_out, last_attn = attn_layer(x, context=context, mask=mask)
            x = attn_out + x
            x = ff_layer(x) + x
        return self.norm(x), last_attn

# 空间Transformer（保持不变）
class SpatialTransformer(nn.Module):
    """Pre-norm Transformer encoder made of ``SpatialAttention`` + ``FeedForward`` layers."""

    def __init__(self, dim, depth, heads, dim_head, mult=4, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            attention = SpatialAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, dim_out=dim, mult=mult, dropout=dropout)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, feed_forward),
            ]))

    def forward(self, x):
        # Residual connections around each pre-normalised sub-layer.
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

# 时间Transformer（保持不变）
class TemporalTransformer(nn.Module):
    """Pre-norm Transformer encoder over time built from ``TemporalAttention``.

    The optional ``bias`` passed to ``forward`` is handed to every attention
    layer, which turns it into an additive logit term.
    """

    def __init__(self, dim, depth, heads, dim_head, mult=4, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            attention = TemporalAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, dim_out=dim, mult=mult, dropout=dropout)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, feed_forward),
            ]))

    def forward(self, x, bias=None):
        for attention, feed_forward in self.layers:
            x = x + attention(x, bias=bias)
            x = x + feed_forward(x)
        return self.norm(x)

In [None]:
import models_pvt
from attention import TimeShiftedMultiModalAttention
from torch import nn
from einops import rearrange

class PreNorm(nn.Module):
    """Normalise the input with LayerNorm before delegating to ``fn``.

    Any keyword arguments are forwarded to ``fn`` as-is.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP ``dim -> hidden -> dim`` with GELU and dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: hidden width; any falsy value (None or 0) selects ``4 * dim``.
        dropout: dropout probability after the activation and after the output layer.
    """

    def __init__(self, dim, hidden_dim=None, dropout=0.):
        super().__init__()
        width = hidden_dim or dim * 4
        layers = [
            nn.Linear(dim, width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(width, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class PVTSimCLR(nn.Module):
    """PVT image encoder refined by time-shifted attention over a context sequence.

    Each image is encoded by a ``models_pvt`` backbone, pooled to one vector,
    and projected to ``out_dim``. It then goes through ``mm_depth``
    pre-norm attention + feed-forward layers that attend to the projected
    context.

    Args:
        base_model: name of a constructor in ``models_pvt``.
        out_dim: embedding size produced by ``forward``.
        context_dim: feature size of the raw context sequence.
        num_head: attention heads (``dim_head = out_dim // num_head``).
        mm_depth: number of attention + feed-forward layers.
        dropout: dropout used inside attention and feed-forward layers.
        max_time_lag: forwarded to ``TimeShiftedMultiModalAttention``.
        pretrained: forwarded to the backbone constructor.
    """

    def __init__(self, base_model, out_dim=512, context_dim=9, num_head=8,
                 mm_depth=2, dropout=0., max_time_lag=3, pretrained=True):
        super(PVTSimCLR, self).__init__()

        self.backbone = models_pvt.__dict__[base_model](pretrained=pretrained)
        num_ftrs = self.backbone.head.in_features

        self.proj = nn.Linear(num_ftrs, out_dim)
        self.proj_context = nn.Linear(context_dim, out_dim)
        self.norm1 = nn.LayerNorm(context_dim)

        dim_head = out_dim // num_head
        self.mm_transformer = nn.ModuleList([
            PreNorm(out_dim, TimeShiftedMultiModalAttention(
                query_dim=out_dim,
                context_dim=out_dim,
                heads=num_head,
                dim_head=dim_head,
                max_time_lag=max_time_lag,
                dropout=dropout
            )) for _ in range(mm_depth)
        ])
        self.ff = nn.ModuleList([
            PreNorm(out_dim, FeedForward(out_dim, dropout=dropout))
            for _ in range(mm_depth)
        ])

    def forward(self, x, context=None, mask=None):
        """Encode images and refine the embedding with context attention.

        Args:
            x: image batch passed to ``backbone.forward_features``.
            context: optional context sequence ``[B, T, context_dim]``. When
                ``None``, the attention layers attend over the image token
                itself (previously this crashed in ``norm1``).
            mask: optional key mask ``[B, T]``, or ``[T]`` shared across the batch.

        Returns:
            A ``[B, out_dim]`` embedding.
        """
        feats = self.backbone.forward_features(x)
        # Pool token features into one vector per image when given a sequence.
        if feats.dim() == 3:
            feats = feats.mean(dim=1)

        x = self.proj(feats).unsqueeze(1)  # [B, 1, D]
        if context is not None:
            context = self.proj_context(self.norm1(context))  # [B, T, D]

        # A 1-D mask is shared by every sample in the batch.
        if mask is not None and mask.dim() == 1:
            mask = mask.unsqueeze(0).expand(x.size(0), -1)

        for attn, ff in zip(self.mm_transformer, self.ff):
            attn_out, _ = attn(x, context=context, mask=mask)
            x = attn_out + x
            x = ff(x) + x

        return x.squeeze(1)  # [B, D]

In [None]:
import torch
from torch import nn
from einops import rearrange, repeat

from attention import SpatialTransformer, TemporalTransformer

from models_pvt_simclr import PVTSimCLR


class MMST_ViT(nn.Module):
    """Multi-modal spatio-temporal ViT.

    Pipeline:
    1. Per-grid images are encoded by ``pvt_backbone`` together with the
       short-term context.
    2. A spatial Transformer runs over the grid cells of each time step.
    3. A temporal Transformer runs over the time steps, biased by the
       projected long-term context.
    4. An MLP head produces the output.
    """

    def __init__(self, out_dim=2, num_grid=64, num_short_term_seq=6, num_long_term_seq=12, num_year=5,
                 pvt_backbone=None, context_dim=9, dim=192, batch_size=64, depth=4, heads=3, pool='cls', dim_head=64,
                 dropout=0., emb_dropout=0., scale_dim=4, ):
        super().__init__()

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.batch_size = batch_size
        self.pvt_backbone = pvt_backbone

        # Flattened long-term context -> one dim-sized vector per short-term step.
        self.proj_context = nn.Linear(num_year * num_long_term_seq * context_dim, num_short_term_seq * dim)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_short_term_seq, num_grid, dim))
        self.space_token = nn.Parameter(torch.randn(1, 1, dim))
        self.space_transformer = SpatialTransformer(dim, depth, heads, dim_head, mult=scale_dim, dropout=dropout)

        self.temporal_token = nn.Parameter(torch.randn(1, 1, dim))
        self.temporal_transformer = TemporalTransformer(dim, depth, heads, dim_head, mult=scale_dim, dropout=dropout)

        self.dropout = nn.Dropout(emb_dropout)
        self.pool = pool

        self.norm1 = nn.LayerNorm(dim)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, out_dim)
        )

    def forward_features(self, x, ys):
        """Encode every (batch, time, grid) image with ``pvt_backbone``.

        Samples are flattened to ``(b t g)`` and fed through the backbone in
        chunks of ``self.batch_size`` to bound peak memory.

        Args:
            x: images ``[b, t, g, c, h, w]``.
            ys: short-term context ``[b, t, g, n, d]``.

        Returns:
            Backbone outputs for all samples, concatenated along dim 0.
        """
        x = rearrange(x, 'b t g c h w -> (b t g) c h w')
        ys = rearrange(ys, 'b t g n d -> (b t g) n d')

        # Collect the chunks and concatenate once; calling torch.cat inside
        # the loop re-copies everything gathered so far (quadratic).
        step = self.batch_size
        chunks = [
            self.pvt_backbone(x[start:start + step], context=ys[start:start + step])
            for start in range(0, x.shape[0], step)
        ]
        if not chunks:
            return torch.empty(0, device=x.device)
        return torch.cat(chunks, dim=0)

    def forward(self, x, ys=None, yl=None):
        """Predict from images ``x``, short-term ``ys`` and long-term ``yl`` context.

        Args:
            x: images ``[b, t, g, c, h, w]``.
            ys: short-term context ``[b, t, g, n, d]`` (required).
            yl: long-term context ``[b, num_year, num_long_term_seq, context_dim]``
                (required).

        Returns:
            ``[b, out_dim]`` predictions.
        """
        b, t, g, _, _, _ = x.shape
        x = self.forward_features(x, ys)
        x = rearrange(x, '(b t g) d -> b t g d', b=b, t=t, g=g)

        cls_space_tokens = repeat(self.space_token, '() g d -> b t g d', b=b, t=t)
        x = torch.cat((cls_space_tokens, x), dim=2)
        # NOTE(review): pos_embedding has num_grid slots on dim 2 but x now has
        # g + 1 tokens, so this only broadcasts while g < num_grid.
        x += self.pos_embedding[:, :, :(g + 1)]
        x = self.dropout(x)

        x = rearrange(x, 'b t g d -> (b t) g d')
        x = self.space_transformer(x)
        # Keep only each time step's spatial cls token.
        x = rearrange(x[:, 0], '(b t) ... -> b t ...', b=b)

        cls_temporal_tokens = repeat(self.temporal_token, '() t d -> b t d', b=b)
        x = torch.cat((cls_temporal_tokens, x), dim=1)

        # Concatenate the long-term parameters of all years/months.
        yl = rearrange(yl, 'b y m d -> b (y m d)')
        yl = self.proj_context(yl)
        yl = rearrange(yl, 'b (t d) -> b t d', t=t)

        # The same cls token is prepended so yl aligns with x step by step.
        yl = torch.cat((cls_temporal_tokens, yl), dim=1)
        yl = self.norm1(yl)

        x = self.temporal_transformer(x, yl)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        return self.mlp_head(x)


if __name__ == "__main__":
    # Smoke test with random inputs.
    # images: [B, T, G, C, H, W]
    images = torch.randn((1, 6, 10, 3, 224, 224))
    # short-term context: [B, T, G, N1, d]
    short_ctx = torch.randn((1, 6, 10, 28, 9))
    # long-term context: [B, years, months, d]
    long_ctx = torch.randn((1, 5, 12, 9))

    backbone = PVTSimCLR("pvt_tiny", out_dim=512, context_dim=9)
    model = MMST_ViT(out_dim=4, pvt_backbone=backbone, dim=512)

    prediction = model(images, ys=short_ctx, yl=long_ctx)
    print(prediction)
    print(prediction.shape)

In [None]:
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat


def exists(val):
    """Tell whether *val* holds a value, i.e. is not ``None``."""
    if val is None:
        return False
    return True


def uniq(arr):
    """Return the unique items of *arr* in first-occurrence order, as a ``dict_keys`` view."""
    seen = {}
    for el in arr:
        seen[el] = True
    return seen.keys()


def default(val, d):
    """Fall back to *d* when *val* is ``None``.

    A fallback that is a plain Python function (``inspect.isfunction``) is
    called lazily with no arguments; any other fallback is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val


def max_neg_value(t):
    """Most negative finite number of ``t``'s dtype (a finite mask fill value)."""
    largest = torch.finfo(t.dtype).max
    return -largest


def init_(tensor):
    """Fill *tensor* in place from U(-1/sqrt(d), 1/sqrt(d)), d = last-dim size.

    The fill runs under ``torch.no_grad()`` so the helper also works on
    ``nn.Parameter`` leaves; an in-place op on a leaf that requires grad
    otherwise raises ``RuntimeError``.

    Returns:
        The same tensor object, to allow chaining.
    """
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    with torch.no_grad():
        tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Projects to ``2 * dim_out`` features, splits them into value and gate
    halves, and returns ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        activated_gate = F.gelu(gate)
        return value * activated_gate


class FeedForward(nn.Module):
    """Position-wise MLP: (Linear + GELU | GEGLU) -> Dropout -> Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; ``dim`` when None.
        mult: hidden-width multiplier (hidden = ``int(dim * mult)``).
        glu: use GEGLU instead of Linear + GELU for the input projection.
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        target = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        # Index layout (0, 1, 2) keeps state_dict keys compatible.
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(hidden, target))

    def forward(self, x):
        return self.net(x)


class PreNorm(nn.Module):
    """Pre-normalisation wrapper: ``fn(LayerNorm(x), **kwargs)``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x_normed = self.norm(x)
        return self.fn(x_normed, **kwargs)


class MultiModalAttention(nn.Module):
    """Multi-head cross-attention from ``x`` (queries) to ``context`` (keys/values).

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (``x`` itself when None).

        Args:
            x: 3-D query tensor ``[b, n, query_dim]``.
            context: 3-D key/value tensor; defaults to ``x``.
            mask: optional per-key mask, ``[b, ...]`` flattened to
                ``[b, n_keys]``; truthy entries are kept. Bool, integer and
                float (0/1) masks are all accepted.

        Returns:
            Output tensor ``[b, n, query_dim]``.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # Coerce to bool: `~` on an int mask is bitwise-not, not keep/drop.
            mask = rearrange(mask, 'b ... -> b (...)').bool()
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value(sim))

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class SpatialAttention(nn.Module):
    """Multi-head self-attention across the tokens of a ``[b, n, dim]`` input.

    The output projection is ``nn.Identity`` only when there is a single
    head of width ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        # Unpacking also enforces a 3-D input.
        _b, _n, _c = x.shape
        n_heads = self.heads
        chunks = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = [rearrange(c, 'b n (h d) -> b h n d', h=n_heads) for c in chunks]

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        probs = scores.softmax(dim=-1)

        context = einsum('b h i j, b h j d -> b h i d', probs, v)
        context = rearrange(context, 'b h n d -> b n (h d)')
        return self.to_out(context)

class TimeShiftedCrossModalAttention(nn.Module):
    """Causal multi-head cross-attention with a learnable per-head time-lag bias.

    Queries come from ``x`` (remote-sensing image features) and keys/values
    from ``context`` (meteorological features), or from ``x`` when no context
    is given. Query step ``i`` may only attend to context steps ``j <= i``.
    Each allowed logit gets ``lag_weights[head, min(i - j, max_time_lag)]``
    added.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size.
        max_time_lag: largest distinct lag; longer lags share the last weight.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, max_time_lag=5, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.max_time_lag = max_time_lag

        # Image features -> queries.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        # Meteorological features -> keys / values.
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # Learnable lag bias, one independent row per head: [heads, max_time_lag + 1].
        self.lag_weights = nn.Parameter(torch.randn(heads, max_time_lag + 1))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None):
        """Attend causally from ``x`` to ``context``.

        Args:
            x: 3-D query tensor ``[b, len_q, query_dim]``.
            context: 3-D key/value tensor ``[b, len_k, context_dim]``;
                defaults to ``x``. Steps are assumed to be aligned with
                ``x`` by index (step ``i`` of both refers to the same time).

        Returns:
            Output tensor ``[b, len_q, query_dim]``.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b t (h d) -> (b h) t d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        len_q, len_k = sim.shape[-2], sim.shape[-1]
        rows = torch.arange(len_q, device=sim.device).view(-1, 1)
        cols = torch.arange(len_k, device=sim.device).view(1, -1)
        offsets = rows - cols  # [len_q, len_k], i - j

        # Causal mask: query step i cannot see context steps j > i.
        sim = sim.masked_fill((offsets < 0).unsqueeze(0), max_neg_value(sim))

        # Per-head lag bias looked up directly: [heads, len_q, len_k].
        time_lags = offsets.clamp(min=0, max=self.max_time_lag)
        lag_bias = self.lag_weights[:, time_lags]
        # Split (b h) back out so each head's bias broadcasts over the batch.
        sim = rearrange(sim, '(b h) i j -> b h i j', h=h) + lag_bias
        sim = rearrange(sim, 'b h i j -> (b h) i j')

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) t d -> b t (h d)', h=h)
        return self.to_out(out)

class MultiModalTransformer(nn.Module):
    """Pre-norm cross-attention Transformer (``MultiModalAttention`` + ``FeedForward`` per layer)."""

    def __init__(self, dim, depth, heads, dim_head, context_dim=9, mult=4, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            cross_attn = MultiModalAttention(dim, context_dim=context_dim, heads=heads,
                                             dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, dim_out=dim, mult=mult, dropout=dropout)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, cross_attn),
                PreNorm(dim, feed_forward),
            ]))

    def forward(self, x, context=None):
        for cross_attn, feed_forward in self.layers:
            x = x + cross_attn(x, context=context)
            x = x + feed_forward(x)
        return self.norm(x)


class SpatialTransformer(nn.Module):
    """Pre-norm Transformer over tokens: ``SpatialAttention`` + ``FeedForward`` per layer."""

    def __init__(self, dim, depth, heads, dim_head, mult=4, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            block_attn = SpatialAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            block_ff = FeedForward(dim, dim_out=dim, mult=mult, dropout=dropout)
            self.layers.append(nn.ModuleList([PreNorm(dim, block_attn), PreNorm(dim, block_ff)]))

    def forward(self, x):
        for block_attn, block_ff in self.layers:
            x = x + block_attn(x)
            x = x + block_ff(x)
        return self.norm(x)


class TemporalTransformer(nn.Module):
    """Pre-norm Transformer over time built on ``TimeShiftedCrossModalAttention``.

    Args:
        dim: token feature size (also used as the context feature size).
        depth: number of attention + feed-forward layers.
        heads: attention head count.
        dim_head: per-head width.
        max_time_lag: forwarded to ``TimeShiftedCrossModalAttention``.
        dropout: dropout for attention output and feed-forward.
        mult: feed-forward hidden-width multiplier (default 4, the value
            previously fixed by ``FeedForward``'s default).
    """

    def __init__(self, dim, depth, heads, dim_head, max_time_lag=5, dropout=0., mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, TimeShiftedCrossModalAttention(
                    query_dim=dim,
                    context_dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    max_time_lag=max_time_lag,
                    dropout=dropout
                )),
                PreNorm(dim, FeedForward(dim, mult=mult, dropout=dropout))
            ]))

    def forward(self, x, bias=None, context=None):
        """Run all layers.

        Args:
            x: query sequence ``[b, t, dim]``.
            bias: unused; kept for call compatibility. A warning is emitted
                when it is given, because a positional second argument
                (``transformer(x, yl)``) lands here and would otherwise be
                dropped silently. Pass ``context=`` instead.
            context: key/value sequence ``[b, t, dim]``; when None each layer
                falls back to self-attention over ``x``.

        Returns:
            The LayerNorm-ed output ``[b, t, dim]``.
        """
        if bias is not None:
            import warnings
            warnings.warn("TemporalTransformer ignores `bias`; pass `context=` instead.", stacklevel=2)
        for attn, ff in self.layers:
            # context=None makes the attention layer attend over x itself.
            x = attn(x, context=context) + x
            x = ff(x) + x
        return self.norm(x)

In [None]:
import torch
from torch import nn
from einops import rearrange, repeat

from attention import SpatialTransformer, TemporalTransformer

from models_pvt_simclr import PVTSimCLR


class MMST_ViT(nn.Module):
    """Multi-modal spatio-temporal ViT with time-shifted cross-modal temporal attention.

    Pipeline:
    1. Per-grid images are encoded by ``pvt_backbone`` with the short-term
       context.
    2. A spatial Transformer runs over the grid cells of each step.
    3. A temporal Transformer attends from the step tokens to the projected
       long-term context.
    4. An MLP head produces the output.
    """

    def __init__(self, out_dim=2, num_grid=64, num_short_term_seq=6, num_long_term_seq=12, num_year=5,
                 pvt_backbone=None, context_dim=9, dim=192, batch_size=64, depth=4, heads=3, pool='cls', dim_head=64,
                 dropout=0., emb_dropout=0., scale_dim=4, max_time_lag=5):
        super().__init__()

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.batch_size = batch_size
        self.pvt_backbone = pvt_backbone

        # Flattened long-term context -> one dim-sized vector per short-term step.
        self.proj_context = nn.Linear(num_year * num_long_term_seq * context_dim, num_short_term_seq * dim)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_short_term_seq, num_grid, dim))
        self.space_token = nn.Parameter(torch.randn(1, 1, dim))
        self.space_transformer = SpatialTransformer(dim, depth, heads, dim_head, mult=scale_dim, dropout=dropout)

        self.temporal_token = nn.Parameter(torch.randn(1, 1, dim))
        # Temporal stage uses time-shifted cross-modal attention with max_time_lag.
        self.temporal_transformer = TemporalTransformer(
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            max_time_lag=max_time_lag,
            dropout=dropout
        )

        self.dropout = nn.Dropout(emb_dropout)
        self.pool = pool

        self.norm1 = nn.LayerNorm(dim)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, out_dim)
        )

    def forward_features(self, x, ys):
        """Encode every (batch, time, grid) image with ``pvt_backbone``.

        Samples are flattened to ``(b t g)`` and fed through the backbone in
        chunks of ``self.batch_size`` to bound peak memory.

        Args:
            x: images ``[b, t, g, c, h, w]``.
            ys: short-term context ``[b, t, g, n, d]``.

        Returns:
            Backbone outputs for all samples, concatenated along dim 0.
        """
        x = rearrange(x, 'b t g c h w -> (b t g) c h w')
        ys = rearrange(ys, 'b t g n d -> (b t g) n d')

        # Collect the chunks and concatenate once; calling torch.cat inside
        # the loop re-copies everything gathered so far (quadratic).
        step = self.batch_size
        chunks = [
            self.pvt_backbone(x[start:start + step], context=ys[start:start + step])
            for start in range(0, x.shape[0], step)
        ]
        if not chunks:
            return torch.empty(0, device=x.device)
        return torch.cat(chunks, dim=0)

    def forward(self, x, ys=None, yl=None):
        """Predict from images ``x``, short-term ``ys`` and long-term ``yl`` context.

        Args:
            x: images ``[b, t, g, c, h, w]``.
            ys: short-term context ``[b, t, g, n, d]`` (required).
            yl: long-term context ``[b, num_year, num_long_term_seq, context_dim]``
                (required).

        Returns:
            ``[b, out_dim]`` predictions.
        """
        b, t, g, _, _, _ = x.shape
        x = self.forward_features(x, ys)
        x = rearrange(x, '(b t g) d -> b t g d', b=b, t=t, g=g)

        cls_space_tokens = repeat(self.space_token, '() g d -> b t g d', b=b, t=t)
        x = torch.cat((cls_space_tokens, x), dim=2)
        # NOTE(review): pos_embedding has num_grid slots on dim 2 but x now has
        # g + 1 tokens, so this only broadcasts while g < num_grid.
        x += self.pos_embedding[:, :, :(g + 1)]
        x = self.dropout(x)

        x = rearrange(x, 'b t g d -> (b t) g d')
        x = self.space_transformer(x)
        # Keep only each time step's spatial cls token.
        x = rearrange(x[:, 0], '(b t) ... -> b t ...', b=b)

        cls_temporal_tokens = repeat(self.temporal_token, '() t d -> b t d', b=b)
        x = torch.cat((cls_temporal_tokens, x), dim=1)

        # Concatenate the long-term parameters of all years/months.
        yl = rearrange(yl, 'b y m d -> b (y m d)')
        yl = self.proj_context(yl)
        yl = rearrange(yl, 'b (t d) -> b t d', t=t)

        # The same cls token is prepended so yl aligns with x step by step.
        yl = torch.cat((cls_temporal_tokens, yl), dim=1)
        yl = self.norm1(yl)

        # Long-term meteorological data serves as the key/value context.
        x = self.temporal_transformer(x, context=yl)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        return self.mlp_head(x)


if __name__ == "__main__":
    # Smoke test with random inputs.
    # images: [B, T, G, C, H, W]
    images = torch.randn((1, 6, 10, 3, 224, 224))
    # short-term context: [B, T, G, N1, d]
    short_ctx = torch.randn((1, 6, 10, 28, 9))
    # long-term context: [B, years, months, d]
    long_ctx = torch.randn((1, 5, 12, 9))

    backbone = PVTSimCLR("pvt_tiny", out_dim=512, context_dim=9)
    model = MMST_ViT(out_dim=4, pvt_backbone=backbone, dim=512, max_time_lag=5)

    prediction = model(images, ys=short_ctx, yl=long_ctx)
    print(prediction)
    print(prediction.shape)