In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
torch.cuda.empty_cache()

from torchvision.models import resnet18, ResNet18_Weights

In [2]:
from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# Smoke test: run KeypointPredictor on one random 128x128 RGB image.
# The first of the three outputs holds the keypoints; its printed shape below is
# [1, 100, 2], i.e. 100 two-coordinate keypoints. The other two outputs are
# presumably soft keypoints and heatmaps (cf. how LitKeypointDetector unpacks them).
x = torch.rand(1,3,128,128)
model = KeypointPredictor()
out, soft, m = model(x)
print(out.shape)

torch.Size([1, 100, 2])




In [3]:
# Shape smoke tests for the decoder and encoder in isolation.
# Decoder input: a 64-channel 32x32 feature map plus 100 two-coordinate keypoints.
x = torch.rand(1,64,32,32)
y = torch.rand(1,100,2)

model = ImageDecoder(64, 100, 3)
out = model(x, y)
# Only a cell's last expression is displayed, so print the decoder shape explicitly.
print(out.shape)

# Encoder: a 128x128 RGB image -> feature map (displayed as [1, 64, 32, 32]).
# NOTE: the later `out.size()` cell reads this `out`, so the name is kept.
model = ImageEncoder()
out = model(torch.rand(1,3,128,128))
out.shape

torch.Size([1, 64, 32, 32])

In [4]:
import lightning as L

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import io
from PIL import Image

from lightning.pytorch.callbacks import Callback

class AlphaScheduler(Callback):
    """Lightning callback that sets ``pl_module.alpha`` at the start of every training epoch.

    Schedule (0-based epochs): ``alpha`` is 0 during the first ``warmup_epochs``
    epochs, then grows linearly from 0 over ``ramp_epochs`` epochs, and is held at
    ``final_alpha`` from then on.
    """

    def __init__(self, warmup_epochs=10, ramp_epochs=20, final_alpha=0.01):
        """
        Args:
            warmup_epochs: Number of epochs to keep alpha = 0
            ramp_epochs: Number of epochs to linearly ramp up alpha after warmup
            final_alpha: Target alpha value
        """
        self.warmup_epochs = warmup_epochs
        self.ramp_epochs = ramp_epochs
        self.final_alpha = final_alpha

    def _alpha_at(self, epoch):
        """Return the scheduled alpha for the given epoch index."""
        if epoch < self.warmup_epochs:
            return 0.0
        epochs_into_ramp = epoch - self.warmup_epochs
        if epochs_into_ramp < self.ramp_epochs:
            # Linear interpolation from 0 towards final_alpha.
            return (epochs_into_ramp / self.ramp_epochs) * self.final_alpha
        return self.final_alpha

    def on_train_epoch_start(self, trainer, pl_module):
        pl_module.alpha = self._alpha_at(trainer.current_epoch)

def draw_keypoints_on_image(image_tensor, keypoints, color='r'):
    """
    Render keypoints on top of an image and return the rendering as a tensor.

    Args:
        image_tensor: [3, H, W] float tensor with values in [0, 1].
        keypoints: [N, 2] tensor with (x, y) format in image coordinates.
        color: Matplotlib colour used for the markers.

    Returns:
        Float tensor [3, H', W'] of the JPEG-encoded Matplotlib figure. Its size is
        set by the 4x4-inch figure, not by the input resolution.
    """
    # Detach so tensors that still require grad can be converted for plotting.
    image = TF.to_pil_image(image_tensor.detach().cpu())
    points = keypoints.detach().cpu().numpy().reshape(-1, 2)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.axis("off")
        # A single scatter call for all points instead of one call per point.
        ax.scatter(points[:, 0], points[:, 1], c=color, s=10)
        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as pil_image:
        return TF.to_tensor(pil_image.convert("RGB"))
    
class LitKeypointDetector(L.LightningModule):
    """Keypoint learner trained by reconstructing one frame from another.

    For a pair of frames ``(v1, vt)``:
      * appearance features come from ``v1``,
      * keypoints come from ``vt``,
      * the decoder must reconstruct ``vt`` from that combination.
    The heatmaps of both frames are also penalised by their entropy (the
    "condensation" term), weighted by ``self.alpha``. ``AlphaScheduler`` sets
    ``alpha`` each epoch.

    Args:
        keypoint_encoder: module returning a 3-tuple whose 2nd/3rd elements are used as
            soft keypoints and heatmaps (assumed from how the tuple is unpacked).
        feature_encoder: module mapping images to feature maps.
        feature_decoder: module mapping ``(features, keypoints)`` back to an image.
    """

    def __init__(self, keypoint_encoder, feature_encoder, feature_decoder):
        super().__init__()
        self.keypoint_generator = keypoint_encoder
        self.feature_encoder = feature_encoder
        self.image_reconstructor = feature_decoder
        self.alpha = 0  # condensation-loss weight; overwritten by AlphaScheduler

    def _reconstruct(self, v1, vt):
        """Shared forward pass: reconstruct ``vt`` from v1's features and vt's keypoints.

        Returns:
            (vt_pred, v1_soft_keypoints, vt_soft_keypoints, v1_heatmaps, vt_heatmaps)
        """
        v1_features = self.feature_encoder(v1)
        _, v1_soft_keypoints, v1_heatmaps = self.keypoint_generator(v1)
        _, vt_soft_keypoints, vt_heatmaps = self.keypoint_generator(vt)
        vt_pred = self.image_reconstructor(v1_features, vt_soft_keypoints)
        return vt_pred, v1_soft_keypoints, vt_soft_keypoints, v1_heatmaps, vt_heatmaps

    def training_step(self, batch, batch_idx):
        v1, vt = batch
        vt_pred, _, _, v1_heatmaps, vt_heatmaps = self._reconstruct(v1, vt)
        loss, mse, condens = self.keypoint_loss(vt, vt_pred, v1_heatmaps, vt_heatmaps)

        self.log("train/total_loss", loss)
        self.log("train/mse", mse)
        self.log("train/condensation_loss", condens)
        self.log("alpha", self.alpha)

        return loss

    def validation_step(self, batch, batch_idx):
        v1, vt = batch
        vt_pred, v1_soft_kps, vt_soft_kps, v1_heatmaps, vt_heatmaps = self._reconstruct(v1, vt)
        loss, mse, condens = self.keypoint_loss(vt, vt_pred, v1_heatmaps, vt_heatmaps)

        self.log("val/total_loss", loss)
        self.log("val/mse", mse)
        self.log("val/condensation_loss", condens)

        if batch_idx == 0:
            self._log_visualizations(v1, vt, vt_pred, v1_soft_kps, vt_soft_kps, vt_heatmaps)

    def _log_visualizations(self, v1, vt, vt_pred, v1_soft_kps, vt_soft_kps, vt_heatmaps):
        """Log keypoint overlays and vt's per-keypoint heatmaps for the first sample."""
        img_H, img_W = v1.shape[-2:]
        heatmap_H, heatmap_W = vt_heatmaps.shape[-2:]
        scale_x = img_W / heatmap_W
        scale_y = img_H / heatmap_H

        # Keypoints are in heatmap coordinates; rescale them to image pixels.
        # Each frame is overlaid with its OWN keypoints (v1's image used to show vt's).
        kps_v1 = self._scale_keypoints(v1_soft_kps[0], scale_x, scale_y)
        kps_vt = self._scale_keypoints(vt_soft_kps[0], scale_x, scale_y)

        # Clamp images for visualization
        v1_img = v1[0].clamp(0, 1)
        vt_img = vt[0].clamp(0, 1)
        vt_pred_img = vt_pred[0].clamp(0, 1)

        vis_v1 = draw_keypoints_on_image(v1_img, kps_v1, color='b')  # Blue keypoints
        vis_vt = draw_keypoints_on_image(vt_img, kps_vt, color='r')  # Red keypoints
        vis_pred = draw_keypoints_on_image(vt_pred_img, kps_vt, color='r')

        grid = make_grid([vis_v1, vis_vt, vis_pred])
        self.logger.experiment.add_image("val/visualize_key_reconstruction", grid, self.current_epoch)

        # Heatmaps of vt (tag name kept unchanged for TensorBoard continuity).
        sample_heatmaps = vt_heatmaps[0].unsqueeze(1)  # [K, 1, H, W]
        heatmap_grid = make_grid(sample_heatmaps, nrow=10, normalize=True, scale_each=True)
        self.logger.experiment.add_image("val/heatmaps_v1", heatmap_grid, self.current_epoch)

    @staticmethod
    def _scale_keypoints(keypoints, scale_x, scale_y):
        """Return a copy of (x, y) keypoints with x scaled by scale_x and y by scale_y."""
        scaled = keypoints.clone()
        scaled[..., 0] *= scale_x
        scaled[..., 1] *= scale_y
        return scaled

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

    def keypoint_loss(self, vt, vt_pred, v0_heat, vt_heat):
        """Reconstruction MSE plus alpha-weighted heatmap condensation loss.

        Returns:
            (total, mse, condensation). Here ``total = mse + alpha * condensation``,
            and ``condensation`` is the average heatmap entropy over the two frames.
        """
        mse = nn.functional.mse_loss(vt_pred, vt)
        condensation_loss_1 = self.condensation_loss_entropy(v0_heat)
        condensation_loss_t = self.condensation_loss_entropy(vt_heat)
        condensation_loss = (condensation_loss_1 + condensation_loss_t) / 2

        return mse + self.alpha * condensation_loss, mse, condensation_loss

    def condensation_loss_entropy(self, heatmaps):
        """Mean Shannon entropy of each heatmap over its two spatial dims.

        Args:
            heatmaps: [B, K, H, W]; each [H, W] map is assumed to be softmax-normalised.
        Returns:
            Scalar; lower means more peaked (condensed) heatmaps.
        """
        eps = 1e-8  # avoids log(0)
        log_h = torch.log(heatmaps + eps)
        entropy = -torch.sum(heatmaps * log_h, dim=(-2, -1))  # [B, K]
        return entropy.mean()


In [5]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CameraSequentialPairs(Dataset):
    def __init__(self, root, transform=None, min_offset=1, max_offset=5, random_offset=False):
        """
        Pairs of frames ``(earlier, later)`` taken from per-camera image sequences.

        Args:
            root (str): Root directory containing camera folders. Each folder must hold an
                ``images/`` subfolder of ``*_<frame>.png`` files; other folders are skipped.
            transform: Transformations to apply to each image (default: ``ToTensor``).
            min_offset (int): Minimum frame difference between pairs. Only used when
                ``random_offset`` is True.
            max_offset (int): Maximum frame difference between pairs.
            random_offset (bool): If False (default), the offset is always
                ``min(max_offset, frames left in the sequence)``. If True, it is drawn
                uniformly from ``[min_offset, that maximum]``.
        """
        self.transform = transform or transforms.ToTensor()
        self.min_offset = min_offset
        self.max_offset = max_offset
        self.random_offset = random_offset
        self.sequences = []

        # List all camera directories
        camera_dirs = sorted([os.path.join(root, d) for d in os.listdir(root)
                              if os.path.isdir(os.path.join(root, d))])

        # For each camera directory, get sorted image paths (using numeric sorting)
        for cam_dir in camera_dirs:
            image_dir = os.path.join(cam_dir, "images")
            if not os.path.exists(image_dir):
                continue
            image_files = sorted(
                [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')],
                key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1])
            )
            if len(image_files) > 1:
                self.sequences.append(image_files)

        # cumulative_lengths[i] = number of pairs in sequences[0..i]; every image except
        # the last of its sequence can start a pair.
        self.cumulative_lengths = []
        total = 0
        for seq in self.sequences:
            total += len(seq) - 1
            self.cumulative_lengths.append(total)

    def __len__(self):
        return self.cumulative_lengths[-1] if self.cumulative_lengths else 0

    @staticmethod
    def _load_rgb(path):
        """Load an image file as RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        from bisect import bisect_right

        total = len(self)
        # Support negative indices like a sequence; reject anything out of range
        # (a negative idx used to produce a backwards pair).
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"Index out of range for dataset of size {total}")

        # The first sequence whose cumulative pair count exceeds idx owns this pair.
        seq_idx = bisect_right(self.cumulative_lengths, idx)
        seq = self.sequences[seq_idx]
        local_idx = idx - (self.cumulative_lengths[seq_idx - 1] if seq_idx else 0)

        # Largest offset that stays inside this camera's sequence.
        effective_max_offset = min(self.max_offset, len(seq) - 1 - local_idx)
        if self.random_offset:
            # If fewer than min_offset frames remain, fall back to effective_max_offset.
            low = min(max(1, self.min_offset), effective_max_offset)
            offset = random.randint(low, effective_max_offset)
        else:
            offset = effective_max_offset

        img1_path = seq[local_idx]
        img2_path = seq[local_idx + offset]

        try:
            img1 = self._load_rgb(img1_path)
            img2 = self._load_rgb(img2_path)
        except Exception as e:
            print(f"Error loading images: {img1_path} or {img2_path}: {e}")
            raise

        return self.transform(img1), self.transform(img2)



In [6]:
import os
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# Assuming CameraSequentialPairs is defined as in the previous code block.
# For example:
# class CameraSequentialPairs(Dataset):
#     ... (code from previous message) ...

def get_camera_dataloaders(data_root="./dataset_00", batch_size=8, min_offset=5, max_offset=30, num_workers=8):
    """Build train/val DataLoaders of CameraSequentialPairs.

    The data comes from ``<data_root>/train`` and ``<data_root>/val``. Both splits
    use ``ToTensor``. Dataset sizes are printed for debugging. The training loader
    shuffles; the validation loader does not.

    Returns:
        (train_loader, val_loader)
    """
    def build_split(split_name):
        return CameraSequentialPairs(
            root=os.path.join(data_root, split_name),
            transform=transforms.ToTensor(),
            min_offset=min_offset,
            max_offset=max_offset,
        )

    train_dataset = build_split("train")
    val_dataset = build_split("val")

    print("Training dataset pairs:", len(train_dataset))
    print("Validation dataset pairs:", len(val_dataset))

    shared = {"batch_size": batch_size, "num_workers": num_workers}
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader

# Example usage: build loaders with small frame offsets and inspect one batch.
# Note: CameraSequentialPairs does not use min_offset by default (the offset is
# min(max_offset, frames remaining)), so only max_offset=5 matters here.
train_loader, val_loader = get_camera_dataloaders(
    data_root="./dataset_00", batch_size=16, min_offset=1, max_offset=5, num_workers=8
)

# For testing, iterate over one batch:
for img1, img2 in train_loader:
    print("Batch shapes:", img1.shape, img2.shape)
    break


Training dataset pairs: 199998
Validation dataset pairs: 19998
Batch shapes: torch.Size([16, 3, 128, 128]) torch.Size([16, 3, 128, 128])


In [7]:
# Rebuild the loaders with the function's default offsets (min_offset=5, max_offset=30),
# replacing the loaders created in the previous cell.
train_loader, val_loader = get_camera_dataloaders(data_root="./dataset_00", batch_size=16)
print(len(train_loader))  # number of batches per epoch
for img1, img2 in train_loader:
    print(img1.shape, img2.shape)  # e.g., [16, 3, H, W]
    break


Training dataset pairs: 199998
Validation dataset pairs: 19998
12500
torch.Size([16, 3, 128, 128]) torch.Size([16, 3, 128, 128])


In [8]:
# trainer = L.Trainer(max_epochs=200, callbacks=[AlphaScheduler()])
# model = LitKeypointDetector(KeypointPredictor(100), ImageEncoder(), ImageDecoder(64, 100))

# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


In [9]:
# Stand-alone transformer over per-frame keypoint tokens. d_model=200 matches one
# frame's 100 keypoints x 2 coordinates flattened into a token (the [8, 100, 200]
# input built in the next cells).
# NOTE(review): LitPredictorT2NOD projects tokens to 160 dims, so this 200-dim
# encoder is not directly usable there — confirm which width is intended.
encoder_layer = nn.TransformerEncoderLayer(d_model=200, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# out = transformer_encoder(test)

In [10]:
class SequenceKeypointDetector(nn.Module):
    """Apply a per-image keypoint detector independently to every frame of a sequence.

    Args:
        keypoint_detector: module taking an (N, C, H, W) batch and returning a tuple
            whose first element has shape (N, num_keypoints, 2).
    """

    def __init__(self, keypoint_detector):
        super().__init__()
        self.keypoint_detector = keypoint_detector  # This is your existing 4D keypoint model

    def forward(self, x):
        """
        Args:
          x: input tensor of shape (B, T, C, H, W)
        Returns:
          Output: tensor of shape (B, T, num_keypoints, 2), built from the first
          element of the detector's output tuple.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted sequences, work too.
        frames = x.reshape(B * T, C, H, W)
        keypoints_flat = self.keypoint_detector(frames)[0]
        return keypoints_flat.reshape(B, T, -1, 2)

In [11]:

from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# Run the per-frame keypoint detector over a batch of frame sequences, then flatten
# each frame's keypoints into a single transformer token.
seq_detector = SequenceKeypointDetector(KeypointPredictor(100))
# Sample input: (batch=8, T=100 frames, 3, 128, 128)
input_tensor = torch.randn(8, 100, 3, 128, 128)
output_keypoints = seq_detector(input_tensor)              # (8, 100, 100, 2)
transformer_input = output_keypoints.flatten(start_dim=2)  # (8, 100, 200)
print(transformer_input.size())  # Should print torch.Size([8, 100, 200])



torch.Size([8, 100, 200])


In [12]:
# out = transformer_encoder(transformer_input)

In [13]:
# NOTE(review): hidden state — the assignment in the previous cell is commented out,
# so `out` is still the ImageEncoder output from the shape-test cell ([1, 64, 32, 32]),
# not a transformer output.
out.size()

torch.Size([1, 64, 32, 32])

In [14]:
from losses.masked_mse import masked_mse
from losses.dice import dice_loss
import torchvision.utils as vutils
from utils import draw_keypoints_on_image

class LitPredictorT2NOD(L.LightningModule):
    """Predict future keypoints from an observed frame sequence and decode T2NOD maps.

    Per batch ``(seq_images, t2nod, future_images)``:
      1. A frozen per-frame detector extracts keypoints from ``seq_images``.
      2. Each frame's keypoints become one token. A transformer maps the T tokens to
         T sets of predicted keypoints, one per future frame.
      3. The frozen ``future_decoder`` renders future frames. It uses the last observed
         frame's features plus the predicted keypoints, giving an auxiliary loss.
      4. ``T2N_decoder`` decodes the features plus all predicted keypoints into T2NOD maps.

    Args:
        image_encoder: encodes the last observed frame into a feature map.
        seq_keypoint_encoder: maps (B, T, C, H, W) frames to (B, T, K, 2) keypoints.
            It runs under ``torch.no_grad`` and is kept in eval mode.
        transformer_encoder: sequence model over (B, T, d_model) tokens (batch_first).
        T2N_decoder: decodes ``(features, keypoints)`` into the T2NOD prediction.
        future_decoder: decoder rendering a frame from ``(features, keypoints)``; frozen.
        num_keypoints: keypoints per frame (K).
        d_model: token width fed to ``transformer_encoder`` (default 160).
        max_seq_len: length of the learned positional embedding (default 50).
        sample_size: future frames decoded per training step (default 10).
    """

    def __init__(self, image_encoder, seq_keypoint_encoder, transformer_encoder, T2N_decoder,
                 future_decoder, num_keypoints, d_model=160, max_seq_len=50, sample_size=10):
        super().__init__()
        self.seq_keypoint_encoder = seq_keypoint_encoder
        self.transformer_encoder = transformer_encoder
        self.image_encoder = image_encoder
        self.image_decoder = T2N_decoder
        # Learned positional embedding, one row per time step: [1, max_seq_len, d_model].
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, d_model))

        # Each frame's K (x, y) keypoints -> one d_model-dim token, and back.
        self.linear_projto_attnn = nn.Linear(num_keypoints * 2, d_model)
        self.transform_to_kps = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, num_keypoints * 2),
            nn.Tanh()
        )

        # The keypoint detector is frozen; `train()` below keeps it in eval mode.
        self.seq_keypoint_encoder.eval()
        self.act = nn.SiLU()

        self.future_predictor = future_decoder
        for param in self.future_predictor.parameters():
            param.requires_grad = False

        self.sample_size = sample_size

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen keypoint detector in eval mode.

        Lightning switches the whole module to train mode during fitting. Without this
        override, that would undo the ``eval()`` call made in ``__init__``.
        """
        super().train(mode)
        self.seq_keypoint_encoder.eval()
        return self

    def _forward_keypoints(self, seq_images):
        """Shared forward pass for training and validation.

        Returns:
            seq_keypoints: detected keypoints, shape (B, T, K, 2).
            pred_keypoints: predicted keypoints, shape (B, T * K, 2), in the same 0..32
                range the detected keypoints are normalised by.
            features: image-encoder features of the last observed frame.
        """
        with torch.no_grad():
            seq_keypoints = self.seq_keypoint_encoder(seq_images)

        B, T, num_kps, _ = seq_keypoints.shape
        if T > self.pos_embed.shape[1]:
            raise ValueError(
                f"sequence length {T} exceeds positional-embedding length {self.pos_embed.shape[1]}"
            )

        # Keypoints are normalised by 32 (their 0..32 range is scaled by 128/32 to
        # image pixels when visualised).
        flat_seq_keypoints = (seq_keypoints / 32.0).reshape(B, T, -1)
        tokens = self.act(self.linear_projto_attnn(flat_seq_keypoints))
        tokens = tokens + self.pos_embed[:, :T]

        flat_pred_keypoints = self.transform_to_kps(self.transformer_encoder(tokens))
        # Tanh output in [-1, 1] -> keypoint range [0, 32].
        pred_keypoints = 32 * ((flat_pred_keypoints.view(B, T * num_kps, 2) + 1) / 2)

        features = self.image_encoder(seq_images[:, -1])  # last observed frame: B, C, H, W
        return seq_keypoints, pred_keypoints, features

    def training_step(self, batch, batch_idx):
        seq_images, t2nod, future_images = batch
        seq_keypoints, pred_keypoints, features = self._forward_keypoints(seq_images)
        B, T, num_kps, _ = seq_keypoints.shape

        # Decoding every future frame is expensive, so supervise a random subset.
        # The subset is capped at T so short sequences do not break the reshapes below.
        n_samples = min(self.sample_size, T)
        sampled_idxs = torch.randperm(T)[:n_samples]

        future_keypoints = pred_keypoints.view(B, T, num_kps, 2)[:, sampled_idxs]
        future_keypoints = future_keypoints.reshape(B * n_samples, num_kps, 2)
        # Every future frame is decoded from the same last-frame features.
        future_features = (
            features.unsqueeze(1)
            .expand(-1, n_samples, -1, -1, -1)
            .reshape(B * n_samples, *features.shape[1:])
        )

        future_images_pred = self.future_predictor(future_features, future_keypoints)
        # Targets are mapped from [0, 1] to [-1, 1] before comparison.
        future_targets = future_images[:, sampled_idxs].reshape(B * n_samples, -1) * 2 - 1
        future_loss = F.mse_loss(future_images_pred.reshape(B * n_samples, -1), future_targets)

        out = self.image_decoder(features, pred_keypoints)

        total_loss, bce, t2no_mse, t2nd_mse, t2no_mask = masked_mse(out, t2nod)
        dice = dice_loss(out[:, 0], t2no_mask)

        self.log("train/masked_mse", total_loss)
        self.log("train/bce", bce)
        self.log("train/t2no_mse", t2no_mse)
        self.log("train/t2nd_mse", t2nd_mse)
        self.log("train/future_loss", future_loss)
        self.log("train/total_loss", total_loss + future_loss)
        self.log("train/dice_loss", dice)

        # NOTE(review): only `future_loss` is returned. The T2NOD decoder
        # (`self.image_decoder`) therefore receives no gradient, and the logged
        # `total_loss`/`dice` are not optimised — confirm this is intentional.
        return future_loss

    def validation_step(self, batch, batch_idx):
        seq_images, t2nod, future_images = batch
        with torch.no_grad():
            seq_keypoints, pred_keypoints, features = self._forward_keypoints(seq_images)
            B, T, num_kps, _ = seq_keypoints.shape

            # Decode every future frame (no sub-sampling in validation).
            features_repeated = (
                features.unsqueeze(1)
                .expand(-1, T, -1, -1, -1)
                .reshape(B * T, *features.shape[1:])
            )
            future_keypoints = pred_keypoints.view(B * T, num_kps, 2)
            future_images_pred = self.future_predictor(features_repeated, future_keypoints)
            future_loss = F.mse_loss(
                future_images_pred.reshape(B * T, -1), future_images.reshape(B * T, -1) * 2 - 1
            )

            out, pred_heatmaps = self.image_decoder(features, pred_keypoints, return_heatmaps=True)

            total_loss, bce, t2no_mse, t2nd_mse, t2no_mask = masked_mse(out, t2nod)
            dice = dice_loss(out[:, 0], t2no_mask)

            self.log("val/masked_mse", total_loss)
            self.log("val/bce", bce)
            self.log("val/t2no_mse", t2no_mse)
            self.log("val/t2nd_mse", t2nd_mse)
            self.log("val/future_loss", future_loss)
            self.log("val/dice_loss", dice)
            self.log("val/total_loss", total_loss + future_loss)

        if batch_idx == 0:  # Log only the first batch per epoch
            self._log_validation_images(
                seq_images, t2nod, future_images, seq_keypoints, pred_keypoints,
                future_images_pred, out, t2no_mask, pred_heatmaps,
            )

    def _log_validation_images(self, seq_images, t2nod, future_images, seq_keypoints,
                               pred_keypoints, future_images_pred, out, t2no_mask, pred_heatmaps):
        """Log keypoint overlays, T2NOD comparison tiles and heatmaps for sample 0."""
        B, T, num_kps, _ = seq_keypoints.shape
        last_image = seq_images[0, -1]
        last_future_image = future_images[0, -1]
        last_future_image_pred = future_images_pred.view(B, T, *future_images_pred.shape[1:])[0, -1]

        # Keypoints are in 0..32 units; rescale to 128x128 image pixels.
        # Multiplying (instead of scaling in place) leaves pred_keypoints untouched.
        kp_scale = 128 / 32
        scaled_last_keypoints = seq_keypoints[0, -1] * kp_scale
        scaled_pred_keypoints = pred_keypoints[0] * kp_scale
        # The last future frame's keypoints are the final num_kps rows.
        last_pred_keypoints = scaled_pred_keypoints[-num_kps:]

        vis_seq_keypoints = draw_keypoints_on_image(last_image, scaled_last_keypoints, color="g")
        vis_pred_keypoints = draw_keypoints_on_image(last_image, scaled_pred_keypoints, color="b")
        vis_future_keypoints = draw_keypoints_on_image(last_future_image, last_pred_keypoints)
        # The future predictor is trained against targets in [-1, 1]; map back to [0, 1].
        vis_future_pred_keypoints = draw_keypoints_on_image(
            ((last_future_image_pred + 1) / 2).clamp(0, 1), last_pred_keypoints
        )

        target_size = vis_seq_keypoints.shape[-2:]

        def _resize(img):
            return F.interpolate(
                img.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
            ).squeeze(0)

        grid = make_grid([
            vis_seq_keypoints,
            _resize(vis_future_keypoints),
            _resize(vis_future_pred_keypoints),
            _resize(vis_pred_keypoints),
        ])
        self.logger.experiment.add_image("val/visualize_keypoints", grid, self.current_epoch)

        sample_grid = self.visualize_predictions_for_tensorboard(
            seq_images[0, -1],  # RGB image (last observed frame)
            t2nod[0],           # [2, H, W] GT
            out[0],             # [3, H, W] prediction
            t2no_mask,          # whole-batch mask; only mask[0] is used inside
        )

        num_kp_to_show = min(16, pred_heatmaps.shape[1])
        heatmaps = pred_heatmaps[0, :num_kp_to_show]  # [K, H, W]
        heatmaps = (heatmaps - heatmaps.min()) / (heatmaps.max() - heatmaps.min() + 1e-6)
        heatmaps_rgb = heatmaps.unsqueeze(1).repeat(1, 3, 1, 1)  # [K, 3, H, W]
        heatmap_grid = vutils.make_grid(heatmaps_rgb, nrow=4, padding=2)

        self.logger.experiment.add_image("val/pred_heatmaps", heatmap_grid, self.current_epoch)
        self.logger.experiment.add_image("val/sample_grid", sample_grid, self.current_epoch)

    def visualize_predictions_for_tensorboard(self, rgb_img, gt, pred, mask):
        """
        Build an image grid comparing T2NOD targets and predictions for one sample.

        Args:
            rgb_img: [3, H, W] input frame (min-max normalised for display).
            gt: [2, H, W] target. Channel 0 is t2no and channel 1 is the t2nd - t2no
                delta, as built by SequenceDataset.
            pred: [3, H, W] prediction. Channel 0 is a mask logit (passed through a
                sigmoid); channels 1 and 2 are the t2no / t2nd regressions.
            mask: mask tensor; only ``mask[0]`` is used. It is assumed boolean, since
                ``~`` is applied to it.

        Returns:
            A grid of 10 three-channel tiles laid out 7 per row (i.e. two rows).
        """
        def to_3ch(x):
            if x.ndim == 2:
                return x.unsqueeze(0).repeat(3, 1, 1)
            elif x.shape[0] == 1:
                return x.repeat(3, 1, 1)
            return x

        def normalize(x):
            return (x - x.min()) / (x.max() - x.min() + 1e-5)

        rgb_img = normalize(rgb_img)

        t2no_mask     = (gt[0] != 1).float()
        gt_t2no       = (gt[0]+1)/2
        gt_t2nd       = (gt[1]+1)/2 + gt_t2no

        pred_mask_asis = normalize(torch.sigmoid(pred[0]))
        pred_mask = (torch.sigmoid(pred[0]) > 0.5).float()

        # Outside the mask, fill with 1 so only masked regions show predictions.
        pred_t2no_gt_mask     = ((pred[1]+1)/2) * mask[0] + (~mask[0])
        pred_t2nd_gt_mask     = ((pred[2]+1)/2) * mask[0] + (~mask[0])

        tiles = [
            to_3ch(rgb_img),
            to_3ch(t2no_mask),
            to_3ch(normalize(gt_t2no)),
            to_3ch(normalize(gt_t2nd)),
            to_3ch(pred_mask),
            to_3ch(pred_mask_asis),
            to_3ch(normalize(pred_t2no_gt_mask)),
            to_3ch(normalize(pred_t2nd_gt_mask)),
            to_3ch(normalize(pred[1])),
            to_3ch(normalize(pred[2])),
        ]

        grid = vutils.make_grid(tiles, nrow=7)
        return grid

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    

In [15]:
import torch

def logsumexp_pooling(heatmaps, group_size=100, beta=10.0):
    """
    Consolidate heatmaps by LogSumExp-pooling consecutive groups of channels.

    Channels ``[g * group_size, (g + 1) * group_size)`` are pooled into output
    channel ``g`` as ``(1 / beta) * logsumexp(beta * x)``. This is a smooth maximum
    that approaches the element-wise max as ``beta`` grows.

    Args:
        heatmaps: Tensor of shape (B, K, H, W) where K = num_keypoints.
        group_size: Number of consecutive heatmaps pooled into each output map;
            must be a positive divisor of K.
        beta: Temperature (sharpness) parameter for LogSumExp.

    Returns:
        Tensor of shape (B, K // group_size, H, W).

    Raises:
        ValueError: if ``group_size`` is not a positive divisor of K.
    """
    B, K, H, W = heatmaps.shape
    if group_size <= 0 or K % group_size != 0:
        raise ValueError(f"group_size={group_size} must be a positive divisor of K={K}")
    num_groups = K // group_size

    # reshape (not view) so non-contiguous inputs are accepted.
    grouped = heatmaps.reshape(B, num_groups, group_size, H, W)

    # LogSumExp pooling over the group dimension -> (B, num_groups, H, W)
    return (1.0 / beta) * torch.logsumexp(beta * grouped, dim=2)


In [16]:
# Synthetic inputs for the condensing decoder below: a batch of 8 samples with
# 10,000 random two-coordinate keypoints each (shape [8, 10000, 2]), plus a
# 32-channel 32x32 feature map per sample.
keypoints = torch.randn(8, 20000)  # Batch of 8
keypoints = keypoints.view(8,10000,-1)
features = torch.randn(8, 32, 32, 32)  # Batch of 8


In [17]:

# Smoke test for the condensing ImageDecoder (output shown below: [8, 3, 128, 128]).
# NOTE(review): the decoder is built with 1000 as its second argument but is fed
# 10,000 keypoints — confirm how condense=True maps keypoints to that count.
decoder = ImageDecoder(32, 1000, 3, condense=True)
decoder(features, keypoints).size()

torch.Size([8, 3, 128, 128])

In [18]:

# encoder_layer = nn.TransformerEncoderLayer(d_model=200, nhead=8, batch_first=True)
# transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    
# trainer = L.Trainer(max_epochs=200, callbacks=[AlphaScheduler()])
# model = LitPredictorT2NOD(ImageEncoder(), SequenceKeypointDetector(KeypointPredictor(100)), transformer_encoder, ImageDecoder(64, 1000, 3, condense=True))

# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [19]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T

import re

def natural_sort(l):
    """Return the strings in ``l`` sorted in natural (human) order, as a new list.

    Each string is split into alternating text and digit runs. Digit runs compare
    as integers and text compares case-insensitively, so ``"frame_2.png"`` sorts
    before ``"frame_10.png"``.
    """
    def _key(name):
        parts = []
        for chunk in re.split(r'(\d+)', name):
            parts.append(int(chunk) if chunk.isdigit() else chunk.lower())
        return parts

    return sorted(l, key=_key)

class SequenceDataset(Dataset):
    """Sliding windows of consecutive frames, the frames that follow, and T2NOD targets.

    With ``L = sequence_length``, sample ``start_idx`` of a camera folder returns:
      * ``input_sequence``: frames ``start_idx .. start_idx + L - 1`` as ``[L, 3, H, W]``;
      * ``target``: ``[2, 128, 128]``, built from the t2no/t2nd maps of the last input frame.
        Channel 0 is ``t2no / 25 - 1`` and channel 1 is ``(t2nd - t2no) / 25 - 1``;
      * ``future_sequence``: the next ``L`` frames as ``[L, 3, H, W]``.
    """

    def __init__(self,
                 root_dir,
                 mode='train',
                 sequence_length=5,
                 input_transform=None,
                 target_transform=None):
        """
        Args:
            root_dir (str): Path to the dataset folder (e.g., 'dataset_00').
            mode (str): 'train' or 'val'.
            sequence_length (int): Number of consecutive images in the input sequence
                (the same number of future frames is returned as well).
            input_transform (callable, optional): Transformation applied to each input and
                future image; defaults to ``ToTensor``.
            target_transform (callable, optional): Stored but currently NOT applied. Targets
                are always resized to 128x128 and scaled as described in the class docstring.

        Raises:
            ValueError: if a camera's images/t2no/t2nd folders hold different file counts.
        """
        self.sequence_length = sequence_length
        self.input_transform = input_transform
        self.target_transform = target_transform
        # Built once instead of on every __getitem__ call.
        self._to_tensor = T.ToTensor()
        self._target_resize = T.Resize((128, 128))

        self.samples = []  # Will hold dictionaries with sample info.
        mode_dir = os.path.join(root_dir, mode)

        # Get each camera directory (e.g., camera_0, camera_1, etc.)
        camera_dirs = sorted([d for d in os.listdir(mode_dir) if os.path.isdir(os.path.join(mode_dir, d))])

        for camera in camera_dirs:
            camera_path = os.path.join(mode_dir, camera)
            images_dir = os.path.join(camera_path, 'images')
            t2no_dir   = os.path.join(camera_path, 't2no')
            t2nd_dir   = os.path.join(camera_path, 't2nd')

            # List and sort file names in each folder
            image_files = natural_sort(os.listdir(images_dir))
            t2no_files  = natural_sort(os.listdir(t2no_dir))
            t2nd_files  = natural_sort(os.listdir(t2nd_dir))

            # Check that all three folders contain the same number of files.
            if not (len(image_files) == len(t2no_files) == len(t2nd_files)):
                raise ValueError(f"Mismatch in file counts in camera folder: {camera}")

            num_frames = len(image_files)
            # A window needs sequence_length input frames plus sequence_length future frames.
            for start_idx in range(num_frames - 2 * sequence_length + 1):
                self.samples.append({
                    "images_dir": images_dir,
                    "t2no_dir": t2no_dir,
                    "t2nd_dir": t2nd_dir,
                    "start_idx": start_idx,
                    "file_list": image_files  # Assuming the same ordering applies for all directories.
                })

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_rgb(path, transform):
        """Open an image as RGB, apply ``transform`` and close the file."""
        with Image.open(path) as img:
            return transform(img.convert('RGB'))

    def _load_target(self, folder, filename):
        """Load a grayscale target, resize it to 128x128 and return it as a float64 array."""
        with Image.open(os.path.join(folder, filename)) as img:
            resized = self._target_resize(img.convert('L'))
        # float64 (not the uint8 pixels) so later subtraction cannot wrap around.
        return np.asarray(resized, dtype=np.float64)

    def __getitem__(self, idx):
        sample_info = self.samples[idx]
        images_dir = sample_info["images_dir"]
        start_idx  = sample_info["start_idx"]
        file_list  = sample_info["file_list"]
        transform = self.input_transform or self._to_tensor

        input_sequence = []
        future_sequence = []
        for i in range(start_idx, start_idx + self.sequence_length):
            future_name = file_list[i + self.sequence_length]
            input_sequence.append(self._load_rgb(os.path.join(images_dir, file_list[i]), transform))
            future_sequence.append(self._load_rgb(os.path.join(images_dir, future_name), transform))

        # Stack into tensors of shape [sequence_length, 3, H, W]
        input_sequence = torch.stack(input_sequence, dim=0)
        future_sequence = torch.stack(future_sequence, dim=0)

        # The target belongs to the last frame of the input sequence.
        target_idx = start_idx + self.sequence_length - 1
        t2no = self._load_target(sample_info["t2no_dir"], file_list[target_idx])
        t2nd = self._load_target(sample_info["t2nd_dir"], file_list[target_idx])

        delta = torch.from_numpy(((t2nd - t2no) / 25) - 1).float().unsqueeze(0)
        t2no = torch.from_numpy((t2no / 25) - 1).float().unsqueeze(0)

        # Concatenate the two target maps along the channel dimension.
        target = torch.cat([t2no, delta], dim=0)  # Expected shape: [2, 128, 128]

        return input_sequence, target, future_sequence

# Example usage:
if __name__ == '__main__':
    # Define transforms if desired.
    input_transform = T.Compose([
        # Example: You can add T.Resize, T.RandomCrop, etc.
        T.ToTensor()
    ])

    # NOTE: SequenceDataset stores target_transform but does not apply it; targets
    # are always resized to 128x128 and scaled as v / 25 - 1 inside __getitem__.
    target_transform = T.Compose([
        T.ToTensor(),
        lambda x: (x/25)-1
    ])

    # Create dataset instances for training and validation.
    dataset_root = 'dataset_01'  # Replace with your dataset path.
    train_dataset = SequenceDataset(root_dir=dataset_root,
                                    mode='train',
                                    sequence_length=20,
                                    input_transform=input_transform,
                                    target_transform=target_transform)

    val_dataset = SequenceDataset(root_dir=dataset_root,
                                  mode='val',
                                  sequence_length=20,
                                  input_transform=input_transform,
                                  target_transform=target_transform)

    # Test by loading one sample.
    sample_in, sample_target, sample_future = train_dataset[0]
    print("Input sequence shape:", sample_in.shape)        # Expected: [sequence_length, 3, H, W]
    print("Target shape:", sample_target.shape)            # Expected: [2, 128, 128]
    print("Future sequence shape:", sample_future.shape)   # Expected: [sequence_length, 3, H, W]
    print("Target max:", sample_target.max())
    print("Target min:", sample_target.min())




    

Input sequence shape: torch.Size([20, 3, 128, 128])
Target shape: torch.Size([2, 128, 128])
Target shape: torch.Size([20, 3, 128, 128])
Target shape: tensor(1.)
Target shape: tensor(-1.)


In [20]:
from train_predict_keypoints import LitKeypointDetector
from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# Load the Lightning checkpoint of the keypoint-detector run and rebuild its
# keypoint generator as a standalone KeypointPredictor.
# NOTE(review): absolute, user-specific path — consider making it configurable.
KEYPOINT_CKPT_PATH = "/home/sgarikipati7/packages/SSTA_2025/lightning_logs/version_3/checkpoints/epoch=99-step=1250000.ckpt"

# map_location="cpu": the tensors are only copied into freshly built modules via
# load_state_dict, so there is no need to materialize them on the GPU they were
# saved from. Kept under the name `model` because the next cell reads
# model["state_dict"].
model = torch.load(KEYPOINT_CKPT_PATH, map_location="cpu")

KEYPOINT_PREFIX = "keypoint_generator."
# Strip only the leading prefix; str.replace would also drop any later
# occurrence of the substring inside a key.
keypoint_generator_state_dict = {
    k[len(KEYPOINT_PREFIX):]: v
    for k, v in model["state_dict"].items()
    if k.startswith(KEYPOINT_PREFIX)
}

key_gen = KeypointPredictor(100)
key_gen.load_state_dict(keypoint_generator_state_dict)

<All keys matched successfully>

In [21]:
# Extract the image-reconstructor weights from the checkpoint loaded in the
# previous cell and load them into a fresh ImageDecoder; the training cell passes
# this module in as `future_decoder`.
# NOTE(review): `model` must still be the checkpoint dict here — the training cell
# rebinds `model` to a LitPredictorT2NOD, so re-running this cell after it fails.
IMAGE_RECONSTRUCTOR_PREFIX = "image_reconstructor."
# Strip only the leading prefix (str.replace would remove every occurrence).
image_decoder_state_dict = {
    k[len(IMAGE_RECONSTRUCTOR_PREFIX):]: v
    for k, v in model["state_dict"].items()
    if k.startswith(IMAGE_RECONSTRUCTOR_PREFIX)
}
print(image_decoder_state_dict.keys())
future_decoder = ImageDecoder(64, 100, 3)
future_decoder.load_state_dict(image_decoder_state_dict)



dict_keys(['x_map', 'y_map', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'deconv1.weight', 'deconv1.bias', 'deconv2.weight', 'deconv2.bias', 'deconv3.weight', 'deconv3.bias', 'conv4.weight', 'conv4.bias', 'conv5.weight', 'conv5.bias'])


<All keys matched successfully>

In [None]:
from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# ------------------------------
# Configuration
# ------------------------------
DATASET_ROOT = "dataset_01"
SEQ_LEN = 50
BATCH_SIZE = 16
NUM_WORKERS = 8
MAX_EPOCHS = 200
NUM_KEYPOINTS = 100
# Lightning checkpoint to resume from, or None to start a fresh run, e.g.
# "/home/sgarikipati7/packages/SSTA_2025/lightning_logs/version_11/checkpoints/epoch=6-step=174832.ckpt"
RESUME_CKPT_PATH = None

# ------------------------------
# Transforms
# ------------------------------
input_transform = T.Compose([
    T.Resize((128, 128)),  # Or whatever size your model expects
    T.ToTensor(),
])
# Defined in this cell so training no longer depends on the
# `if __name__ == '__main__':` example cell having run; identical to the
# transform defined there.
# NOTE(review): for 8-bit images ToTensor() already scales to [0, 1], so
# "/ 25 - 1" lands in roughly [-1, -0.96] — confirm this is what SequenceDataset expects.
# NOTE(review): the lambda cannot be pickled, so NUM_WORKERS > 0 relies on the
# "fork" multiprocessing start method (the Linux default).
target_transform = T.Compose([
    T.ToTensor(),
    lambda x: (x / 25) - 1,
])

# ------------------------------
# Dataset & DataLoaders
# ------------------------------
train_dataset = SequenceDataset(
    root_dir=DATASET_ROOT,
    mode="train",
    sequence_length=SEQ_LEN,
    input_transform=input_transform,
    target_transform=target_transform,
)

val_dataset = SequenceDataset(
    root_dir=DATASET_ROOT,
    mode="val",
    sequence_length=SEQ_LEN,
    input_transform=input_transform,
    target_transform=target_transform,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ------------------------------
# Model Initialization
# ------------------------------
encoder_layer = nn.TransformerEncoderLayer(
    d_model=160,
    nhead=4,
    dim_feedforward=512,
    activation=nn.GELU(approximate="tanh"),
    batch_first=True,
)

image_encoder = ImageEncoder()

# Keypoint detector initialized from `keypoint_generator_state_dict`, extracted
# from the keypoint-training checkpoint in an earlier cell.
detector = KeypointPredictor(NUM_KEYPOINTS)
detector.load_state_dict(keypoint_generator_state_dict)
seq_keypoint_encoder = SequenceKeypointDetector(detector)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# NOTE(review): assumes ImageDecoder's second argument is the keypoint count
# (it is smoke-tested with (1, 100, 2) keypoints) — confirm.
T2N_decoder = ImageDecoder(64, NUM_KEYPOINTS, 3, condense=True)

# `future_decoder` is the pretrained ImageDecoder loaded from the checkpoint's
# image_reconstructor weights in an earlier cell, not a fresh module.
model = LitPredictorT2NOD(
    image_encoder=image_encoder,
    seq_keypoint_encoder=seq_keypoint_encoder,
    transformer_encoder=transformer_encoder,
    T2N_decoder=T2N_decoder,
    future_decoder=future_decoder,
    num_keypoints=NUM_KEYPOINTS,
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
)
torch.cuda.empty_cache()
# ckpt_path=None (the default) starts a fresh run; set RESUME_CKPT_PATH to resume.
trainer.fit(model, train_loader, val_loader, ckpt_path=RESUME_CKPT_PATH)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type                     | Params | Mode 
--------------------------------------------------------------------------
0 | seq_keypoint_encoder | SequenceKeypointDetector | 164 K  | eval 
1 | transformer_encoder  | TransformerEncoder       | 1.6 M  | train
2 | image_encoder        | ImageEncoder             | 157 K  |

Epoch 0:  96%|█████████▌| 11981/12488 [33:27<01:24,  5.97it/s, v_num=8]    