In [None]:
import typing
from typing import Any, Dict

import lightning as L
import plotly.graph_objects as go
import rpad.visualize_3d.plots as v3p
import torch
import torch_geometric.data as tgd
from flowbot3d.models.artflownet import artflownet_loss, flow_metrics
from online_adaptation.nets.history_nets import *
from plotly.subplots import make_subplots
from torch import optim
from online_adaptation.models.history_tformer import FlowHistoryTformerPredictorTrainingModule


In [None]:
# Build a synthetic batch of 16 samples. Each sample carries a current point
# cloud (`pos`) plus a variable-length history stored flat: `history` and
# `flow_history` are the per-step clouds concatenated along dim 0, and
# `lengths[k]` is the number of points in history step k.
data = []

# Local generator: reproducible draws without touching the global RNG state.
rng = torch.Generator().manual_seed(0)

for i in range(16):
    # Lower bound of 1 so that no point cloud is ever empty (an empty cloud
    # would give the point-cloud encoder nothing to pool over).
    d = tgd.Data(
        x=None,
        pos=torch.rand(int(torch.randint(1, 1000, tuple(), generator=rng)), 3, generator=rng),
    )

    history = []
    flow_history = []
    lengths = []
    # Between 1 and 9 history steps per sample.
    for _ in range(int(torch.randint(1, 10, tuple(), generator=rng))):
        N = int(torch.randint(1, 1000, tuple(), generator=rng))
        history.append(torch.rand(N, 3, generator=rng))
        flow_history.append(torch.rand(N, 3, generator=rng))
        lengths.append(N)

    # At least one step is always generated, so the lists are never empty.
    d.history = torch.cat(history, dim=0)
    d.flow_history = torch.cat(flow_history, dim=0)
    d.lengths = torch.tensor(lengths)
    data.append(d)

batch = tgd.Batch.from_data_list(data)

In [None]:
def get_history_batch(batch):
    """Flatten every history step of every sample into one batch of point clouds.

    Each sample in ``batch`` stores its history concatenated along dim 0, with
    ``lengths`` giving the number of points per step. Every step becomes its own
    ``tgd.Data`` (``pos`` = history points, ``x`` = history flow), ordered by
    sample and then by step, so that each step can be encoded separately.
    """
    steps = []
    for sample in batch.to_data_list():
        # Cumulative lengths are the end offsets of the steps; prepend 0 as the first start.
        bounds = [0, *sample.lengths.cumsum(0).tolist()]
        for start, stop in zip(bounds[:-1], bounds[1:]):
            steps.append(
                tgd.Data(
                    x=sample.flow_history[start:stop],
                    pos=sample.history[start:stop],
                )
            )

    return tgd.Batch.from_data_list(steps)

def history_latents_to_nested_list(batch, history_latents):
    """Regroup flat per-step latents into one chunk per sample.

    ``history_latents`` is indexed along dim 0 by history step, ordered sample
    by sample. Sample ``k`` owns ``len(sample.lengths)`` consecutive rows; the
    returned list holds that slice for each sample of ``batch``, in order.
    """
    per_sample = []
    offset = 0
    for sample in batch.to_data_list():
        n_steps = len(sample.lengths)
        per_sample.append(history_latents[offset:offset + n_steps])
        offset += n_steps

    return per_sample

# Flatten all history steps of `batch` into a single batch, one element per step.
history_batch = get_history_batch(batch)


In [None]:
# Encode every history step in one pass. Downstream code slices `results` by
# step count, so one latent row per step is assumed — TODO confirm with the
# shape displayed below.
# NOTE(review): `pnp` is not imported explicitly in this notebook; presumably it
# is re-exported by the star import from `history_nets` — confirm.
encoder = pnp.PN2Encoder(in_dim=3, out_dim=256)
results = encoder(history_batch)
results.shape

In [None]:
# Regroup the flat encoder output into a list with one slice per sample.
history_nested_list = history_latents_to_nested_list(batch, results)

In [None]:
# Show the shape of each sample's slice of history latents.
print([x.shape for x in history_nested_list])

In [None]:
# Transformer over the per-sample sequences of history latents; d_model is set
# to 256, the same value passed as the encoder's out_dim.
# NOTE(review): `nn` is not imported explicitly in this notebook; presumably it
# comes from the star import of `history_nets` — confirm.
tformer = nn.Transformer(d_model=256)


In [None]:
# The list of history latents is the input to the transformer.
# Each element in the list is a variable-length sequence of latents, with shape [Ni, 256].
# The transformer expects the input to have shape [S, N, E], where S is the sequence length,
# N is the batch size, and E is the embedding size.
# We need to pad and mask:

# Pad the sequences to the same length, using torch's pad_sequence function.
src_padded = nn.utils.rnn.pad_sequence(history_nested_list, batch_first=True, padding_value=0)  # [N, S, E]
print(src_padded.shape)

# Create the key padding mask from the true sequence lengths rather than from
# the padded values, so a genuine all-zero latent is never mistaken for padding.
# PyTorch convention: True marks a PADDED position that attention must ignore.
seq_lengths = torch.tensor(
    [len(latents) for latents in history_nested_list], device=src_padded.device
)
positions = torch.arange(src_padded.shape[1], device=src_padded.device)
src_mask = positions[None, :] >= seq_lengths[:, None]  # [N, S]
print(src_mask.shape)

# Permute the padded input from [N, S, E] to the [S, N, E] layout the transformer expects.
# The key padding mask keeps its [N, S] layout, so it is not transposed.
src_padded = src_padded.permute(1, 0, 2)

# This is our query vector. It has shape [T, N, E] with a single target position (T = 1).
# Batch and embedding sizes are read from the source so they always agree with it.
_, batch_size, embed_dim = src_padded.shape
tgt = torch.ones(1, batch_size, embed_dim, device=src_padded.device)

# The transformer also expects the input to be of type float.
src_padded = src_padded.float()
tgt = tgt.float()

print(src_padded.shape, tgt.shape, src_mask.shape)

# Pass the input through the transformer. The same padding mask is given for the
# encoder self-attention (src_key_padding_mask) and for the decoder's attention
# over the encoder output (memory_key_padding_mask), so padding is ignored in both.
# NOTE: a sample with zero history steps would be fully masked and yield NaNs.
out = tformer(
    src_padded,
    tgt,
    src_key_padding_mask=src_mask,
    memory_key_padding_mask=src_mask,
)

In [None]:
# Inspect the shape of the key padding mask.
src_mask.shape

In [None]:
# Inspect the shape of the encoder output.
results.shape


In [None]:
# Inspect the history latents of the first sample. The previous name,
# `post_encoder_batch`, is not defined anywhere in this notebook and raised a
# NameError on a fresh kernel.
history_nested_list[0]