In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
# Ask PyTorch's CUDA caching allocator to release unused cached memory
# (handy when re-running the notebook in a kernel that already used the GPU).
torch.cuda.empty_cache()

from torchvision.models import resnet18, ResNet18_Weights

In [2]:
from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# Smoke test: run a default-constructed KeypointPredictor on one random
# 3x128x128 image.
x = torch.rand(1,3,128,128)
model = KeypointPredictor()
# The model returns three values; only the shape of the first is inspected.
out, soft, m = model(x)
print(out.shape)

torch.Size([1, 100, 2])




In [3]:
# Smoke test for the decoder: a (1, 64, 32, 32) feature map together with a
# (1, 100, 2) keypoint tensor.
x = torch.rand(1,64,32,32)
y = torch.rand(1,100,2)

model = ImageDecoder(64, 100, 3)
out = model(x,y)
# NOTE: this is not the last expression of the cell, so its value is never displayed.
out.shape

# Smoke test for the encoder. `model` and `out` are rebound here, so only the
# encoder's output shape is shown by the cell.
model = ImageEncoder()
out = model(torch.rand(1,3,128,128))
out.shape

torch.Size([1, 64, 32, 32])

In [4]:
import lightning as L

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import io
from PIL import Image

from lightning.pytorch.callbacks import Callback

class AlphaScheduler(Callback):
    """Callback that schedules ``pl_module.alpha`` once per training epoch.

    The value stays at 0 during the warm-up, then grows linearly over the ramp
    period, and is held at ``final_alpha`` afterwards.

    Args:
        warmup_epochs: Number of initial epochs during which alpha is 0.
        ramp_epochs: Number of epochs of the linear ramp that follows warm-up.
        final_alpha: Value alpha takes once the ramp is over.
    """

    def __init__(self, warmup_epochs=10, ramp_epochs=20, final_alpha=0.01):
        self.warmup_epochs = warmup_epochs
        self.ramp_epochs = ramp_epochs
        self.final_alpha = final_alpha

    def on_train_epoch_start(self, trainer, pl_module):
        """Write the alpha for ``trainer.current_epoch`` onto ``pl_module``."""
        current = trainer.current_epoch

        # Phase 1: warm-up, the extra loss term is switched off.
        if current < self.warmup_epochs:
            pl_module.alpha = 0.0
            return

        # Phase 2: linear ramp; the fraction is 0 on the first ramp epoch.
        if current < self.warmup_epochs + self.ramp_epochs:
            fraction_done = (current - self.warmup_epochs) / self.ramp_epochs
            pl_module.alpha = fraction_done * self.final_alpha
            return

        # Phase 3: plateau at the target value.
        pl_module.alpha = self.final_alpha

def draw_keypoints_on_image(image_tensor, keypoints, color='r'):
    """Render an image with keypoints drawn on top and return it as a tensor.

    Args:
        image_tensor: Image tensor of shape [3, H, W] with float values in [0, 1].
        keypoints: Tensor of shape [N, 2] holding (x, y) positions in image
            pixel coordinates.
        color: Matplotlib colour used for every keypoint marker.

    Returns:
        A float tensor produced by ``TF.to_tensor`` from the rendered figure.
        Its spatial size is determined by the 4x4-inch figure and Matplotlib's
        tight bounding box, not by the input image size.
    """
    # detach() lets the helper be called on tensors that still carry autograd history.
    image = TF.to_pil_image(image_tensor.detach().cpu())
    points = keypoints.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(4, 4))
    buf = io.BytesIO()
    try:
        ax.imshow(image)
        ax.axis("off")
        if points.size > 0:
            # One vectorised call instead of one scatter artist per keypoint.
            ax.scatter(points[:, 0], points[:, 1], c=color, s=10)
        fig.savefig(buf, format='jpeg', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as pil_image:
        # to_tensor copies the pixel data, so the result outlives the handle.
        return TF.to_tensor(pil_image)
    
class LitKeypointDetector(L.LightningModule):
    """Train a keypoint detector by reconstructing a target frame.

    Each batch is a pair ``(v1, vt)`` of images. Appearance features are taken
    from ``v1``, soft keypoints from ``vt``, and the decoder has to rebuild
    ``vt`` from the two. The loss is the reconstruction MSE plus ``alpha``
    times an entropy ("condensation") penalty on the heatmaps of both frames.

    Args:
        keypoint_encoder: Module called as ``keypoint_encoder(images)``. It must
            return a 3-tuple whose second item is the soft keypoints and whose
            third item is the heatmaps.
        feature_encoder: Module mapping images to a feature map.
        feature_decoder: Module called as ``feature_decoder(features, keypoints)``
            that returns the reconstructed image.
    """

    def __init__(self, keypoint_encoder, feature_encoder, feature_decoder):
        super().__init__()
        self.keypoint_generator = keypoint_encoder
        self.feature_encoder = feature_encoder
        self.image_reconstructor = feature_decoder
        # Weight of the condensation term. It starts switched off and can be
        # changed from outside (e.g. by the AlphaScheduler callback).
        self.alpha = 0

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(v1, vt)`` batch."""
        v1, vt = batch

        v1_features = self.feature_encoder(v1)

        # Only the heatmaps of v1 are needed (for the condensation term);
        # the keypoints that drive the reconstruction come from vt.
        _, _, v1_heatmaps = self.keypoint_generator(v1)
        _, vt_soft_keypoints, vt_heatmaps = self.keypoint_generator(vt)

        vt_pred = self.image_reconstructor(v1_features, vt_soft_keypoints)
        loss, mse, condens = self.keypoint_loss(vt, vt_pred, v1_heatmaps, vt_heatmaps)
        self.log("train/total_loss", loss)
        self.log("train/mse", mse)
        self.log("train/condensation_loss", condens)
        self.log("alpha", self.alpha)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and, for the first batch, example images."""
        v1, vt = batch

        # Appearance features come from the source frame v1.
        v1_features = self.feature_encoder(v1)

        # Keep the keypoints of BOTH frames under distinct names so the
        # visualisation can draw each image with its own keypoints.
        _, soft_kp_v1, v1_heatmaps = self.keypoint_generator(v1)
        _, soft_kp_vt, vt_heatmaps = self.keypoint_generator(vt)

        # Predict vt from v1 features + vt keypoints.
        vt_pred = self.image_reconstructor(v1_features, soft_kp_vt)
        loss, mse, condens = self.keypoint_loss(vt, vt_pred, v1_heatmaps, vt_heatmaps)

        self.log("val/total_loss", loss)
        self.log("val/mse", mse)
        self.log("val/condensation_loss", condens)

        if batch_idx == 0:
            self._log_validation_visuals(v1, vt, vt_pred, soft_kp_v1, soft_kp_vt, vt_heatmaps)

    def _log_validation_visuals(self, v1, vt, vt_pred, soft_kp_v1, soft_kp_vt, vt_heatmaps):
        """Log keypoint overlays and heatmaps for the first sample of a batch."""
        # Image logging needs an experiment object exposing add_image
        # (as TensorBoard's SummaryWriter does); otherwise skip quietly.
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_image"):
            return

        _, _, img_H, img_W = v1.shape
        _, _, heatmap_H, heatmap_W = vt_heatmaps.shape

        # NOTE(review): the rescaling assumes the soft keypoints are expressed
        # in heatmap-pixel coordinates - confirm against KeypointPredictor.
        scale_x = img_W / heatmap_W
        scale_y = img_H / heatmap_H

        def to_image_coords(kps):
            scaled = kps.detach().clone()
            scaled[..., 0] *= scale_x
            scaled[..., 1] *= scale_y
            return scaled

        kps_v1 = to_image_coords(soft_kp_v1[0])
        kps_vt = to_image_coords(soft_kp_vt[0])

        # Clamp images into [0, 1] for display.
        v1_img = v1[0].detach().clamp(0, 1)
        vt_img = vt[0].detach().clamp(0, 1)
        vt_pred_img = vt_pred[0].detach().clamp(0, 1)

        # Source frame with its own keypoints (blue); target frame and the
        # reconstruction with the target keypoints (red).
        vis_v1 = draw_keypoints_on_image(v1_img, kps_v1, color='b')
        vis_vt = draw_keypoints_on_image(vt_img, kps_vt, color='r')
        vis_pred = draw_keypoints_on_image(vt_pred_img, kps_vt, color='r')

        grid = make_grid([vis_v1, vis_vt, vis_pred])
        experiment.add_image("val/visualize_key_reconstruction", grid, self.current_epoch)

        # One tile per keypoint channel of the target frame. The tag name is
        # kept as-is so previously logged runs stay comparable.
        sample_heatmaps = vt_heatmaps[0].detach().unsqueeze(1)  # [K, 1, H, W]
        heatmap_grid = make_grid(sample_heatmaps, nrow=10, normalize=True, scale_each=True)
        experiment.add_image("val/heatmaps_v1", heatmap_grid, self.current_epoch)

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a learning rate of 1e-3."""
        return optim.AdamW(self.parameters(), lr=1e-3)

    def keypoint_loss(self, vt, vt_pred, v0_heat, vt_heat):
        """Combine the reconstruction error with the heatmap entropy penalty.

        Args:
            vt: Target image batch.
            vt_pred: Reconstructed image batch.
            v0_heat: Heatmaps of the source frame.
            vt_heat: Heatmaps of the target frame.

        Returns:
            Tuple ``(total, mse, condensation)`` where
            ``total = mse + self.alpha * condensation``.
        """
        # Conventional (input, target) order; the value is identical either way.
        mse = nn.functional.mse_loss(vt_pred, vt)

        # Average the entropy of both frames' heatmaps. (A peak-minus-mean
        # statistic was an earlier alternative for this term.)
        condensation_loss_1 = self.condensation_loss_entropy(v0_heat)
        condensation_loss_t = self.condensation_loss_entropy(vt_heat)
        condensation_loss = (condensation_loss_1 + condensation_loss_t) / 2

        return mse + self.alpha * condensation_loss, mse, condensation_loss

    def condensation_loss_entropy(self, heatmaps):
        """Mean entropy of the heatmaps, summed over the two trailing dims.

        Args:
            heatmaps: Tensor of shape [B, K, H, W]; presumably each map is
                already normalised (softmaxed) over H x W - verify against the
                keypoint model.

        Returns:
            Scalar tensor: entropy averaged over batch and keypoint channels.
        """
        eps = 1e-8  # keeps log() finite where a heatmap is exactly zero
        log_h = torch.log(heatmaps + eps)
        entropy = -torch.sum(heatmaps * log_h, dim=(-2, -1))  # [B, K]
        return entropy.mean()


In [5]:
# Scratch check on uniform random data: per-map spatial max minus spatial
# mean, averaged over the two leading dimensions.
x = torch.rand(4,64,32,32)
(x.amax(dim=(-1,-2)) - x.mean(dim=(-1,-2))).mean()

tensor(0.4991)

In [6]:
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
import random
import torch
from torchvision import transforms

class MovingMNISTPairs(Dataset):
    """Pairs of frames showing one MNIST digit before and after a small shift.

    Each item places a digit at a random position on a black square canvas and
    returns that canvas together with a second canvas on which the same digit
    has been moved by a random offset (clamped so the digit stays inside).

    Args:
        root: Directory where MNIST is stored / downloaded to.
        train: Use the MNIST training split if True, else the test split.
        transform: Callable applied to the PIL digit; defaults to ``ToTensor``.
        canvas_size: Side length of the square canvas, in pixels.
        digit_size: Side length of the (transformed) digit, in pixels. It must
            match what ``transform`` produces.
        max_shift: Largest absolute displacement, per axis, between the frames.

    Raises:
        ValueError: If the digit does not fit on the canvas or ``max_shift`` is
            negative.
    """

    def __init__(self, root, train=True, transform=None,
                 canvas_size=128, digit_size=28, max_shift=4):
        if digit_size > canvas_size:
            raise ValueError("digit_size must not exceed canvas_size")
        if max_shift < 0:
            raise ValueError("max_shift must be non-negative")

        self.dataset = MNIST(root=root, train=train, download=True)
        self.transform = transform or transforms.ToTensor()
        self.canvas_size = canvas_size
        self.digit_size = digit_size
        self.max_shift = max_shift
        # Largest top-left coordinate at which the digit still fits.
        self.max_pos = self.canvas_size - self.digit_size

    def __len__(self):
        return len(self.dataset)

    def _paste(self, img, x, y):
        """Return a fresh black canvas with ``img`` pasted at top-left (x, y)."""
        canvas = torch.zeros(1, self.canvas_size, self.canvas_size)
        canvas[:, y:y + self.digit_size, x:x + self.digit_size] = img
        return canvas

    def __getitem__(self, idx):
        """Return ``(frame, moved_frame)``, each of shape [3, canvas, canvas]."""
        img, _ = self.dataset[idx]
        img = self.transform(img)  # [1, digit_size, digit_size]

        # Random top-left position where the digit fits.
        x = random.randint(0, self.max_pos)
        y = random.randint(0, self.max_pos)
        canvas = self._paste(img, x, y)

        # Simulated motion, clamped so the digit never leaves the canvas.
        dx = random.randint(-self.max_shift, self.max_shift)
        dy = random.randint(-self.max_shift, self.max_shift)
        new_x = min(max(x + dx, 0), self.max_pos)
        new_y = min(max(y + dy, 0), self.max_pos)
        moved = self._paste(img, new_x, new_y)

        # expand() gives a 3-channel view without copying; the channels share memory.
        return canvas.expand(3, -1, -1), moved.expand(3, -1, -1)

In [7]:
from torch.utils.data import DataLoader, random_split

def get_dataloaders(data_root="./data", batch_size=8, num_workers=8):
    """Build train/validation loaders over ``MovingMNISTPairs``.

    The MNIST training split is divided at random into 90% training and 10%
    validation pairs.

    Args:
        data_root: Directory where MNIST is stored / downloaded to.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes for each loader (same meaning
            as in ``get_camera_dataloaders``).

    Returns:
        Tuple ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    full_train = MovingMNISTPairs(data_root, train=True)
    val_size = int(0.1 * len(full_train))
    train_size = len(full_train) - val_size

    # NOTE: the split is not seeded, so it differs from run to run.
    train_dataset, val_dataset = random_split(full_train, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader

In [8]:
import bisect
import os
import random

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CameraSequentialPairs(Dataset):
    """Pairs of frames taken from per-camera image sequences.

    ``root`` is scanned for sub-directories; each one that has an ``images``
    folder with at least two ``.png`` files contributes one sequence, ordered
    by the integer that follows the last ``_`` in the file name. Every frame
    except the last of a sequence is the first element of exactly one pair.

    The second frame lies ``min(max_offset, frames remaining)`` positions
    later, i.e. the offset is deterministic.

    Args:
        root (str): Root directory containing camera folders.
        transform: Transformation applied to each image; defaults to ``ToTensor``.
        min_offset (int): Minimum frame difference between pairs. Stored but
            currently not used by the pairing rule.
        max_offset (int): Maximum frame difference between pairs.
    """

    def __init__(self, root, transform=None, min_offset=1, max_offset=5):
        self.transform = transform or transforms.ToTensor()
        self.min_offset = min_offset
        self.max_offset = max_offset
        self.sequences = []

        # All camera directories, in a stable (sorted) order.
        camera_dirs = sorted(
            path
            for path in (os.path.join(root, name) for name in os.listdir(root))
            if os.path.isdir(path)
        )

        for cam_dir in camera_dirs:
            image_dir = os.path.join(cam_dir, "images")
            # isdir (not exists): a stray file called "images" is skipped too.
            if not os.path.isdir(image_dir):
                continue
            image_files = sorted(
                (os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')),
                key=self._frame_number,
            )
            if len(image_files) > 1:
                self.sequences.append(image_files)

        # cumulative_lengths[i] = number of pairs in sequences[0..i].
        self.cumulative_lengths = []
        total = 0
        for seq in self.sequences:
            total += len(seq) - 1  # every frame but the last starts a pair
            self.cumulative_lengths.append(total)

    @staticmethod
    def _frame_number(path):
        """Integer after the last '_' of the file stem (numeric sort key)."""
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.split('_')[-1])

    def __len__(self):
        return self.cumulative_lengths[-1] if self.cumulative_lengths else 0

    def __getitem__(self, idx):
        """Return the transformed ``(first, second)`` frames of pair ``idx``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError("Index out of range for CameraSequentialPairs")

        # First sequence whose cumulative pair count exceeds idx.
        seq_idx = bisect.bisect_right(self.cumulative_lengths, idx)
        seq = self.sequences[seq_idx]
        pairs_before = self.cumulative_lengths[seq_idx - 1] if seq_idx else 0
        local_idx = idx - pairs_before

        # Never step past the last frame of the sequence.
        remaining = len(seq) - 1 - local_idx
        offset = min(self.max_offset, remaining)

        img1_path = seq[local_idx]
        img2_path = seq[local_idx + offset]

        try:
            # Context managers close the file handles deterministically.
            with Image.open(img1_path) as first:
                img1 = first.convert("RGB")
            with Image.open(img2_path) as second:
                img2 = second.convert("RGB")
        except Exception as e:
            print(f"Error loading images: {img1_path} or {img2_path}: {e}")
            raise

        return self.transform(img1), self.transform(img2)



In [9]:
import os
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# Requires CameraSequentialPairs, which is defined in the previous cell.

def get_camera_dataloaders(data_root="./dataset_00", batch_size=8, min_offset=5, max_offset=30, num_workers=8):
    """Create train and validation loaders over camera frame pairs.

    The training pairs are read from ``<data_root>/train`` and the validation
    pairs from ``<data_root>/val``. Both dataset sizes are printed.

    Args:
        data_root: Directory holding the ``train`` and ``val`` subfolders.
        batch_size: Batch size for both loaders.
        min_offset: Forwarded to ``CameraSequentialPairs``.
        max_offset: Forwarded to ``CameraSequentialPairs``.
        num_workers: Worker processes per loader.

    Returns:
        Tuple ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    def build_split(split_name):
        # One dataset per subfolder, all sharing the same pairing settings.
        return CameraSequentialPairs(
            root=os.path.join(data_root, split_name),
            transform=transforms.ToTensor(),
            min_offset=min_offset,
            max_offset=max_offset,
        )

    train_dataset = build_split("train")
    val_dataset = build_split("val")

    # Dataset sizes, printed as a quick sanity check.
    print("Training dataset pairs:", len(train_dataset))
    print("Validation dataset pairs:", len(val_dataset))

    shared_options = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_options)
    return train_loader, val_loader

# Example usage: build loaders from ./dataset_00 (its "train" and "val"
# subfolders), pairing frames with max_offset=5.
train_loader, val_loader = get_camera_dataloaders(
    data_root="./dataset_00", batch_size=16, min_offset=1, max_offset=5, num_workers=8
)

# Sanity check: pull a single batch and print the shapes of both frames.
for img1, img2 in train_loader:
    print("Batch shapes:", img1.shape, img2.shape)
    break


Training dataset pairs: 199998
Validation dataset pairs: 19998
Batch shapes: torch.Size([16, 3, 128, 128]) torch.Size([16, 3, 128, 128])


In [10]:
# Rebuild the loaders with the function's default offsets (min_offset=5,
# max_offset=30); this rebinds train_loader / val_loader from the previous cell.
train_loader, val_loader = get_camera_dataloaders(data_root="./dataset_00", batch_size=16)
# Number of training batches per epoch.
print(len(train_loader))
for img1, img2 in train_loader:
    print(img1.shape, img2.shape)  # e.g., [16, 3, H, W]
    break


Training dataset pairs: 199998
Validation dataset pairs: 19998
12500
torch.Size([16, 3, 128, 128]) torch.Size([16, 3, 128, 128])


In [12]:
# trainer = L.Trainer(max_epochs=200, callbacks=[AlphaScheduler()])
# model = LitKeypointDetector(KeypointPredictor(100), ImageEncoder(), ImageDecoder(64, 100))

# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


In [22]:
# Stand-alone transformer encoder over sequences of 200-dim vectors
# (d_model=200 split across 8 heads), with batch-first inputs.
encoder_layer = nn.TransformerEncoderLayer(d_model=200, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Dummy (batch, seq, d_model) input so this cell runs on a fresh kernel.
test = torch.rand(2, 10, 200)
out = transformer_encoder(test)

In [24]:
class SequenceKeypointDetector(nn.Module):
    """Apply a per-image keypoint detector to every frame of a sequence.

    Args:
        keypoint_detector: Module taking a (N, C, H, W) batch and returning a
            tuple whose first item has shape (N, num_keypoints, 2).
    """

    def __init__(self, keypoint_detector):
        super().__init__()
        self.keypoint_detector = keypoint_detector

    def forward(self, x):
        """
        Args:
          x: input tensor of shape (B, T, C, H, W)
        Returns:
          Output: tensor of shape (B, T, num_keypoints, 2)
        """
        B, T, C, H, W = x.shape
        # Fold time into the batch dimension. reshape (unlike view) also
        # accepts non-contiguous inputs, e.g. after a permute/transpose.
        x_flat = x.reshape(B * T, C, H, W)
        # Only the first element of the detector's output tuple is used.
        keypoints_flat = self.keypoint_detector(x_flat)[0]
        # Unfold back to (B, T, num_keypoints, 2).
        return keypoints_flat.reshape(B, T, -1, 2)

In [30]:

from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

seq_detector = SequenceKeypointDetector(KeypointPredictor(100))
# Sample input: (8, 100, 3, 128, 128) = (batch, time, channels, height, width)
input_tensor = torch.randn(8, 100, 3, 128, 128)
output_keypoints = seq_detector(input_tensor)
# Flatten each frame's keypoints into one feature vector per time step.
transformer_input = output_keypoints.view(8,100,-1)
print(transformer_input.size())  # Prints torch.Size([8, 100, 200])



torch.Size([8, 100, 200])


In [33]:
# Run the transformer encoder over the per-frame keypoint vectors.
out = transformer_encoder(transformer_input)

In [34]:
# Display the shape of the encoder output.
out.size()

torch.Size([8, 100, 200])

In [None]:
class LitPredictorT2NOD(L.LightningModule):
    """Skeleton Lightning module combining a keypoint encoder, a transformer
    encoder and a decoder.

    The step methods are placeholders that defer to ``LightningModule``.

    Args:
        keypoint_encoder: Module that produces keypoints.
        transformer_encoder: Transformer encoder module.
        DecoderT2NOD: Decoder module; stored as ``self.decoder``.
    """

    def __init__(self, keypoint_encoder, transformer_encoder, DecoderT2NOD):
        super().__init__()
        self.keypoint_encoder = keypoint_encoder
        self.transformer_encoder = transformer_encoder
        # Keep the decoder instead of dropping the argument, so it is
        # available to the steps and tracked with the other submodules.
        self.decoder = DecoderT2NOD

    def training_step(self, *args, **kwargs):
        # TODO: implement; currently falls back to the base-class behaviour.
        return super().training_step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        # TODO: implement; currently falls back to the base-class behaviour.
        return super().validation_step(*args, **kwargs)