In [None]:
from pathlib import Path

import numpy as np
import pyproj
import torch
from dataset.shanghai import ShanghaiDataset
from metrics.shanghai import LONLAT_MEAN, LONLAT_STD
from model import RegionDiff
from scripts.noise_prior import noise_prior
from torch.utils.data import DataLoader
from tqdm import tqdm

# UTM zone 50 (northern hemisphere, WGS84 datum, metre units) projection for
# converting lon/lat to planar coordinates.
# NOTE(review): `proj` is not used by any cell in this notebook as shown; also,
# central Shanghai (~121.5°E) lies in UTM zone 51 (120-126°E), not zone 50 —
# confirm the intended zone before relying on this projection.
proj = pyproj.Proj("+proj=utm +zone=50 +ellps=WGS84 +datum=WGS84 +units=m +no_defs")

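#### Construct Noise Prior

Map each EPR-sampled trajectory back into the model's latent space with `model.inverse_sampling`; the resulting latents (`latents_epr`) are the starting point for trajectory generation below.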

In [None]:
# Pick the GPU when one is visible, otherwise fall back to CPU; the checkpoint's
# tensors are mapped straight onto that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RegionDiff.load_from_checkpoint(
    "/path/to/your/checkpoint",
    map_location=device,
    strict=False,
)
# Inference only: switch to eval mode and freeze every parameter.
model.eval()
model.requires_grad_(False)

# EPR-sampled trajectories to invert. The next cell standardizes them with
# LONLAT_MEAN / LONLAT_STD, so they are presumably stored in raw lon/lat units.
traj_epr = np.load("path/to/traj_epr.npy")

In [None]:
# Invert every EPR trajectory into the model's latent space.
# `loader` must stay unshuffled: batch i is paired with rows
# [i * batch_size, (i + 1) * batch_size) of `traj_epr` purely by position,
# and the generation cell below relies on the same ordering.
latents_epr = []
dataset = ShanghaiDataset("/path/to/dataset", target="lonlat")
batch_size = 2048
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Fail fast rather than handing the model a too-short slice of traj_epr on the
# final batches.
if len(traj_epr) < len(dataset):
    raise ValueError(
        f"traj_epr has {len(traj_epr)} trajectories but the dataset has "
        f"{len(dataset)} samples; each sample needs a matching EPR trajectory."
    )

for i, batch in enumerate(tqdm(loader)):
    # Use the device chosen when the model was loaded instead of hard-coding
    # CUDA, so this cell also runs on CPU-only machines.
    batch = {k: v.to(device) for k, v in batch.items()}
    traj_epr_ll = traj_epr[i * batch_size : (i + 1) * batch_size]
    # Standardize lon/lat with LONLAT_MEAN / LONLAT_STD before inversion
    # (presumably the same normalization used in training — confirm).
    sample_epr = (traj_epr_ll - LONLAT_MEAN) / LONLAT_STD
    latent_epr = model.inverse_sampling(
        data=batch,
        latent=torch.from_numpy(sample_epr),
        show_progress=False,
    )
    latents_epr.append(latent_epr.cpu().numpy())

latents_epr = np.concatenate(latents_epr, axis=0)

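#### Generate Trajectories

For each batch, pass the inverted latents through `noise_prior`, run `model.sampling` from the result, and convert the samples to indices with `model._reconstruct_idx`.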

In [None]:
# Generate trajectories from the inverted EPR latents.
# Relies on `loader` being unshuffled so that batch i lines up with rows
# [i * batch_size, (i + 1) * batch_size) of `latents_epr`.
NUM_SAMPLING_STEPS = 100  # number of steps passed to model.sampling

traj_gen = []
for i, batch in enumerate(tqdm(loader)):
    # Follow the model's device instead of hard-coding CUDA, so this cell also
    # runs on CPU-only machines.
    batch = {k: v.to(device) for k, v in batch.items()}
    latent = torch.from_numpy(latents_epr[i * batch_size : (i + 1) * batch_size])
    latent = noise_prior(latent)
    sample = (
        model.sampling(
            data=batch,
            latent=latent,
            num_steps=NUM_SAMPLING_STEPS,
            show_progress=False,
        )
        .detach()
        .cpu()
        .to(torch.float32)
    )
    # Convert the continuous samples to indices via the model's helper.
    sample_idx = model._reconstruct_idx(sample)
    traj_gen.append(sample_idx)
traj_gen = torch.cat(traj_gen, dim=0)

In [None]:
# Save the generated trajectories as a .npy array.
traj_gen_path = Path("path/to/traj_gen.npy")
# np.save does not create missing parent directories.
traj_gen_path.parent.mkdir(parents=True, exist_ok=True)
# .cpu() is a no-op for CPU tensors but avoids the TypeError .numpy() raises on
# a CUDA tensor.
np.save(traj_gen_path, traj_gen.cpu().numpy())