In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
torch.cuda.empty_cache()

from torchvision.models import resnet18, ResNet18_Weights

In [2]:
from models.keypoints.KeypointPredictor import KeypointPredictor
from models.keypoints.ImageEncoder import ImageEncoder
from models.keypoints.ImageDecoder import ImageDecoder

# Smoke test: push one random 3x128x128 image through a default-constructed
# KeypointPredictor and print the shape of the first of its three outputs.
x = torch.rand(1,3,128,128)
model = KeypointPredictor()
# The model returns a 3-tuple; only `out` is inspected here.
out, soft, m = model(x)
print(out.shape)

torch.Size([1, 100, 2])




In [3]:
# Smoke test for the decoder and the encoder on random inputs.
x = torch.rand(1,64,32,32)
y = torch.rand(1,100,2)

# Decoder: called with a 64-channel 32x32 feature map and a [1, 100, 2] tensor.
model = ImageDecoder(64, 100, 3)
out = model(x,y)
# A bare expression is only displayed when it is the last line of a cell, so
# the decoder shape has to be printed explicitly to be visible.
print(out.shape)

# Encoder: called with one random 3x128x128 image; its shape is the cell output.
model = ImageEncoder()
out = model(torch.rand(1,3,128,128))
out.shape

torch.Size([1, 64, 32, 32])

In [None]:
import lightning as L

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import io
from PIL import Image

from lightning.pytorch.callbacks import Callback

class AlphaScheduler(Callback):
    """Callback that sets ``pl_module.alpha`` at the start of every training epoch.

    The value is 0 during the warm-up epochs, then grows linearly over the
    ramp epochs, and stays at ``final_alpha`` afterwards.
    """

    def __init__(self, warmup_epochs=2, ramp_epochs=10, final_alpha=0.01):
        """
        Args:
            warmup_epochs: Number of epochs during which alpha is held at 0.
            ramp_epochs: Number of epochs over which alpha rises linearly
                once the warm-up is over.
            final_alpha: Value alpha keeps after the ramp has finished.
        """
        self.warmup_epochs = warmup_epochs
        self.ramp_epochs = ramp_epochs
        self.final_alpha = final_alpha

    def on_train_epoch_start(self, trainer, pl_module):
        """Compute alpha for ``trainer.current_epoch`` and store it on the module."""
        # Number of epochs elapsed since the warm-up ended (negative during warm-up).
        epochs_into_ramp = trainer.current_epoch - self.warmup_epochs

        if epochs_into_ramp < 0:
            pl_module.alpha = 0.0
            return

        if epochs_into_ramp < self.ramp_epochs:
            # Fraction of the ramp already completed, scaled to the target value.
            pl_module.alpha = (epochs_into_ramp / self.ramp_epochs) * self.final_alpha
            return

        pl_module.alpha = self.final_alpha

def draw_keypoints_on_image(image_tensor, keypoints, color='r'):
    """
    Render an image with keypoints drawn on top and return the rendering as a tensor.

    Args:
        image_tensor: [3, H, W] (float in [0,1])
        keypoints: [N, 2] with (x, y) format in image coordinates
        color: matplotlib colour used for every keypoint marker

    Returns:
        The result of ``TF.to_tensor`` applied to the JPEG rendering of the
        figure. Its spatial size is set by the 4x4 inch figure and the tight
        bounding box, not by the size of ``image_tensor``.
    """
    # detach() so tensors that still carry autograd history can be converted.
    image = TF.to_pil_image(image_tensor.detach().cpu())
    points = keypoints.detach().float().cpu().numpy().reshape(-1, 2)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(image)
        ax.axis("off")
        # One vectorised call instead of one scatter artist per keypoint.
        ax.scatter(points[:, 0], points[:, 1], c=color, s=10)

        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails part-way.
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as pil_image:
        return TF.to_tensor(pil_image)
    
class LitKeypointDetector(L.LightningModule):
    """Lightning module that trains a keypoint generator through image reconstruction.

    For a pair of views, features of the first view and keypoints of the second
    view are passed to a reconstructor whose output is compared with the second
    view. A heatmap "condensation" term, weighted by ``self.alpha``, is added
    to the reconstruction error.

    Args:
        keypoint_encoder: Module called on an image batch; its result is
            unpacked as a 3-tuple, used here as
            (keypoints, soft keypoints, heatmaps).
        feature_encoder: Module called on an image batch to produce features.
        feature_decoder: Module called as ``decoder(features, keypoints)`` to
            produce the reconstructed image batch.
    """

    def __init__(self, keypoint_encoder, feature_encoder, feature_decoder):
        super().__init__()
        self.keypoint_generator = keypoint_encoder
        self.feature_encoder = feature_encoder
        self.image_reconstructor = feature_decoder
        # Weight of the condensation term in keypoint_loss. Starts at 0 and is
        # meant to be overwritten from outside (e.g. by AlphaScheduler).
        self.alpha = 0.0

    def training_step(self, batch, batch_idx):
        """Reconstruct ``vt`` from ``v1`` features and ``vt`` soft keypoints; return the loss."""
        v1, vt = batch

        v1_features = self.feature_encoder(v1)

        _, _, v1_heatmaps = self.keypoint_generator(v1)
        _, vt_soft_keypoints, vt_heatmaps = self.keypoint_generator(vt)

        vt_pred = self.image_reconstructor(v1_features, vt_soft_keypoints)
        loss, mse, condens = self.keypoint_loss(vt, vt_pred, v1_heatmaps, vt_heatmaps)
        self.log("train_loss", loss)
        self.log("mse", mse)
        self.log("condensation_loss", condens)
        self.log("alpha", self.alpha)

        return loss

    def on_before_optimizer_step(self, optimizer):
        """Log per-parameter gradient norms for the step that is about to be applied.

        This hook runs after the backward pass, so the norms describe the
        current step. Inside ``training_step`` the gradients of the current
        batch do not exist yet.
        """
        norms = {
            f"grad_norm/{name}": param.grad.norm()
            for name, param in self.named_parameters()
            if param.grad is not None
        }
        if norms:
            self.log_dict(norms)

    def validation_step(self, batch, batch_idx):
        """Run a reconstruction and, for the first batch only, log visualisations."""
        v0, v1 = batch
        v0_features = self.feature_encoder(v0)
        # Get keypoints and heatmaps for v1; notice we capture all three outputs.
        v1_keypoints, _, v1_heatmaps = self.keypoint_generator(v1)
        v1_pred = self.image_reconstructor(v0_features, v1_keypoints)

        # Images are logged once per validation run, and only when the attached
        # logger exposes an `add_image` method (skipped when no logger is set).
        experiment = getattr(self.logger, "experiment", None)
        if batch_idx != 0 or not hasattr(experiment, "add_image"):
            return

        img_H, img_W = v0.shape[-2:]
        heatmap_H, heatmap_W = 32, 32  # or dynamically infer from v1_heatmaps shape

        # Log the raw heatmaps for the first sample.
        # v1_heatmaps shape: [B, num_keypoints, H, W]; here we pick the first sample.
        sample_heatmaps = v1_heatmaps[0]  # shape: [num_keypoints, H, W]

        # make_grid expects a channel dimension: [num_keypoints, 1, H, W].
        sample_heatmaps_unsq = sample_heatmaps.unsqueeze(1)
        # Grid with 5 heatmaps per row, each normalised independently.
        heatmap_grid = make_grid(sample_heatmaps_unsq, nrow=5, normalize=True, scale_each=True)
        experiment.add_image("v1_heatmaps", heatmap_grid, self.current_epoch)

        # Scale keypoints to image resolution for visualization.
        scaled_kps = v1_keypoints.clone()
        scaled_kps[..., 0] *= img_W / heatmap_W
        scaled_kps[..., 1] *= img_H / heatmap_H

        # Grab first samples to show reconstructions and keypoints.
        v0_img = v0[0].clamp(0, 1)
        v1_img = v1[0].clamp(0, 1)
        v1_pred_img = v1_pred[0].clamp(0, 1)
        kps = scaled_kps[0]

        # Overlay the same (v1) keypoints on all three images.
        vis_v0 = draw_keypoints_on_image(v0_img, kps)
        vis_v1 = draw_keypoints_on_image(v1_img, kps)
        vis_pred = draw_keypoints_on_image(v1_pred_img, kps)

        # Combine images side-by-side and log.
        comparison = make_grid([vis_v0, vis_v1, vis_pred])
        experiment.add_image("./val/v0_v1_pred_with_keypoints", comparison, self.current_epoch)

    def configure_optimizers(self):
        """Optimise every parameter of the module with AdamW at lr=1e-3."""
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _heatmap_spread(heatmaps):
        """Mean over all leading dims of (max - mean) taken over the last two dims."""
        return (heatmaps.amax(dim=(-1, -2)) - heatmaps.mean(dim=(-1, -2))).mean()

    def keypoint_loss(self, vt, vt_pred, v0_heat, vt_heat):
        """Combine reconstruction error with the alpha-weighted condensation term.

        Args:
            vt: Target image batch.
            vt_pred: Reconstructed image batch.
            v0_heat: Heatmaps of the first view.
            vt_heat: Heatmaps of the target view.

        Returns:
            Tuple ``(total, mse, condensation_loss)`` where
            ``total = mse + self.alpha * condensation_loss``.
        """
        mse = nn.functional.mse_loss(vt_pred, vt)

        # Negated so that minimising the loss increases the max-minus-mean
        # spread of the heatmaps; averaged over the two views.
        condensation_loss_1 = self._heatmap_spread(v0_heat)
        condensation_loss_t = self._heatmap_spread(vt_heat)
        condensation_loss = -(condensation_loss_1 + condensation_loss_t)/2

        return mse+self.alpha*condensation_loss, mse, condensation_loss


In [5]:
# Scratch check of the (max - mean) spread statistic on uniform random input;
# this is the same expression keypoint_loss applies to the heatmaps.
x = torch.rand(4,64,32,32)
(x.amax(dim=(-1,-2)) - x.mean(dim=(-1,-2))).mean()

tensor(0.4999)

In [6]:
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
import random
import torch
from torchvision import transforms

class MovingMNISTPairs(Dataset):
    """Pairs of canvases showing the same MNIST digit before and after a small shift.

    Args:
        root: Directory passed to ``torchvision.datasets.MNIST`` (downloaded if missing).
        train: Selects the MNIST train or test split.
        transform: Callable applied to each MNIST image; defaults to ``ToTensor()``.
            Its output must have spatial size ``digit_size`` x ``digit_size``.
        canvas_size: Side length of the square black canvas.
        digit_size: Side length of the (transformed) digit image.
        max_shift: Largest absolute displacement, in pixels, applied per axis
            between the two canvases.

    Raises:
        ValueError: If the digit does not fit on the canvas or ``max_shift`` is negative.
    """

    def __init__(self, root, train=True, transform=None,
                 canvas_size=128, digit_size=28, max_shift=4):
        if digit_size > canvas_size:
            raise ValueError(
                f"digit_size ({digit_size}) must not exceed canvas_size ({canvas_size})"
            )
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")

        self.dataset = MNIST(root=root, train=train, download=True)
        self.transform = transform or transforms.ToTensor()
        self.canvas_size = canvas_size
        self.digit_size = digit_size
        self.max_shift = max_shift
        # Largest top-left coordinate at which the digit still fits on the canvas.
        self.max_pos = self.canvas_size - self.digit_size

    def __len__(self):
        return len(self.dataset)

    def _paste(self, img, x, y):
        """Return a black single-channel canvas with ``img`` placed at top-left (x, y)."""
        canvas = torch.zeros(1, self.canvas_size, self.canvas_size)
        canvas[:, y:y + self.digit_size, x:x + self.digit_size] = img
        return canvas

    def __getitem__(self, idx):
        """Return ``(canvas, moved)``, both expanded to 3 channels."""
        img, _ = self.dataset[idx]
        img = self.transform(img)  # [1, 28, 28] with the default transform

        # Random top-left position where digit fits
        x = random.randint(0, self.max_pos)
        y = random.randint(0, self.max_pos)

        # Simulate motion; the shifted position is clamped so the digit stays on canvas.
        dx = random.randint(-self.max_shift, self.max_shift)
        dy = random.randint(-self.max_shift, self.max_shift)
        new_x = min(max(x + dx, 0), self.max_pos)
        new_y = min(max(y + dy, 0), self.max_pos)

        canvas = self._paste(img, x, y)
        moved = self._paste(img, new_x, new_y)

        return canvas.expand(3, -1, -1), moved.expand(3, -1, -1)  # Convert to RGB

In [7]:
from torch.utils.data import DataLoader, random_split

def get_dataloaders(data_root="./data", batch_size=8, num_workers=8, val_fraction=0.1, seed=None):
    """Build train/validation loaders over a random split of ``MovingMNISTPairs``.

    Args:
        data_root: Directory handed to ``MovingMNISTPairs``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per loader.
        val_fraction: Fraction of the dataset held out for validation
            (rounded down to a whole number of samples).
        seed: If given, seeds the generator used for the split so the same
            split is produced on every call; if ``None`` the split uses
            torch's global RNG.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_fraction <= 1.0:
        raise ValueError(f"val_fraction must be in [0, 1], got {val_fraction}")

    full_train = MovingMNISTPairs(data_root, train=True)
    val_size = int(val_fraction * len(full_train))
    train_size = len(full_train) - val_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        full_train, [train_size, val_size], **split_kwargs
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, val_loader

In [None]:
import bisect
import os
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CameraSequentialPairs(Dataset):
    """Pairs of frames taken from the same camera sequence, a few frames apart.

    Expected layout: ``root/<camera>/images/*.png``. Each camera directory with
    at least two PNG files contributes one sequence (files in sorted order).
    Every frame except the last of a sequence is the first element of one pair;
    the second element lies 1..``max_offset`` frames later, clamped to the last
    frame of the sequence.

    Args:
        root: Directory containing one sub-directory per camera.
        transform: Callable applied to each RGB PIL image; defaults to ``ToTensor()``.
        max_offset: Largest frame distance between the two images of a pair.

    Raises:
        ValueError: If ``max_offset`` is smaller than 1.
    """

    def __init__(self, root, transform=None, max_offset=5):
        if max_offset < 1:
            raise ValueError(f"max_offset must be >= 1, got {max_offset}")

        self.transform = transform or transforms.ToTensor()
        self.max_offset = max_offset
        self.sequences = []

        camera_dirs = sorted(
            os.path.join(root, d) for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d))
        )

        for cam_dir in camera_dirs:
            image_dir = os.path.join(cam_dir, "images")
            if not os.path.isdir(image_dir):
                continue
            image_files = sorted(
                os.path.join(image_dir, f)
                for f in os.listdir(image_dir) if f.endswith('.png')
            )
            # A sequence needs at least two frames to form a pair.
            if len(image_files) > 1:
                self.sequences.append(image_files)

        # cumulative_lengths[i] = number of pairs in sequences[0..i]; used to
        # map a flat dataset index to (sequence, frame).
        self.cumulative_lengths = []
        total = 0
        for seq in self.sequences:
            total += len(seq) - 1
            self.cumulative_lengths.append(total)

    def __len__(self):
        return self.cumulative_lengths[-1] if self.cumulative_lengths else 0

    def _locate(self, idx):
        """Map a flat index to ``(sequence index, frame index within that sequence)``.

        Negative indices count from the end. Raises ``IndexError`` when the
        index is out of range.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of size {size}")

        # First sequence whose cumulative pair count exceeds idx.
        seq_idx = bisect.bisect_right(self.cumulative_lengths, idx)
        first_pair = self.cumulative_lengths[seq_idx - 1] if seq_idx else 0
        return seq_idx, idx - first_pair

    def __getitem__(self, idx):
        """Return the transformed ``(frame, later_frame)`` pair for ``idx``."""
        seq_idx, local_idx = self._locate(idx)
        seq = self.sequences[seq_idx]

        offset = random.randint(1, self.max_offset)
        # Clamp so the partner frame never runs past the end of the sequence.
        img2_idx = min(local_idx + offset, len(seq) - 1)

        # convert() returns a fully loaded copy, so the files can be closed here.
        with Image.open(seq[local_idx]) as first, Image.open(seq[img2_idx]) as second:
            img1 = first.convert("RGB")
            img2 = second.convert("RGB")

        return self.transform(img1), self.transform(img2)


In [9]:
from torch.utils.data import DataLoader
from torchvision import transforms

def get_camera_dataloaders(data_root="./dataset_00", batch_size=8, max_offset=5, num_workers=8):
    """Create loaders over the ``train`` and ``val`` sub-directories of ``data_root``.

    Both splits are wrapped in ``CameraSequentialPairs`` with a ``ToTensor``
    transform and the given ``max_offset``. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``
    """
    loaders = []
    # (sub-directory name, whether its loader shuffles)
    for split_name, shuffle in (("train", True), ("val", False)):
        pairs = CameraSequentialPairs(
            root=os.path.join(data_root, split_name),
            transform=transforms.ToTensor(),
            max_offset=max_offset,
        )
        loaders.append(
            DataLoader(pairs, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        )

    train_loader, val_loader = loaders
    return train_loader, val_loader


In [12]:
# Build the camera loaders and inspect a single training batch.
train_loader, val_loader = get_camera_dataloaders(data_root="./dataset_00", batch_size=16)
# Number of batches per training epoch.
print(len(train_loader))
# Only the first batch is needed to check the tensor shapes.
for img1, img2 in train_loader:
    print(img1.shape, img2.shape)  # e.g., [16, 3, H, W]
    break


1250


torch.Size([16, 3, 128, 128]) torch.Size([16, 3, 128, 128])


In [11]:
# Train for up to 100 epochs. AlphaScheduler() is used with its defaults
# (2 warm-up epochs, 10 ramp epochs, final alpha 0.01) and updates model.alpha
# at the start of every training epoch.
trainer = L.Trainer(max_epochs=100, callbacks=[AlphaScheduler()])
# 100 is passed to both KeypointPredictor and ImageDecoder.
# NOTE(review): presumably the number of keypoints — confirm against the model definitions.
model = LitKeypointDetector(KeypointPredictor(100), ImageEncoder(), ImageDecoder(64, 100))

# Uses the loaders created by the get_camera_dataloaders cell.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type              | Params | Mode 
------------------------------------------------------------------
0 | keypoint_generator  | KeypointPredictor | 164 K  | train
1 | feature_encoder     | ImageEncoder      | 157 K  | train
2 | image_reconstructor | ImageDecoder      | 190 K  | train
---------------------------------

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 24:  55%|█████▍    | 684/1250 [00:15<00:12, 44.99it/s, v_num=0]      


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined