In [None]:
from torchvision.transforms import v2 as transforms

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from keyrover.datasets import *
from keyrover.vision import *
from keyrover import *

# Display the compute device as the cell output.
# NOTE(review): `device` is not defined in this notebook; it presumably comes from one of
# the star imports above — confirm which module exports it.
device

In [None]:
# Size handed to both datasets via their `size=` argument.
SIZE = (256, 256)

# The split helper returns three groups; only the first and the last are used here.
train_paths, _, valid_paths = split_train_test_valid(image_paths, 1, 0.1)

# Build the training and validation datasets with identical settings.
train_dataset, valid_dataset = (
    KeyboardCameraTransformDataset(paths, size=SIZE)
    for paths in (train_paths, valid_paths)
)

# Cell output: number of samples in each dataset.
len(train_dataset), len(valid_dataset)

In [None]:
# Training-time pipeline: convert to float32 (scaled), add Gaussian noise, then normalise.
# NOTE(review): `mean` / `std` are not defined in any earlier cell, so they presumably come
# from one of the star imports. A later cell rebinds both names to 6-element tensors, so
# re-running this cell after that one would pick up those values instead — confirm.
train_dataset.set_augmentations([
    transforms.ToDtype(torch.float32, scale=True),
    # NOTE(review): noise is applied unconditionally here and then once more with p=0.5 on
    # the next line — confirm the double application is intended.
    transforms.GaussianNoise(sigma=0.01),
    transforms.RandomApply([transforms.GaussianNoise(sigma=0.01)], p=0.5),
    transforms.Normalize(mean, std),
])

# Stand-alone pipeline for inputs that do not come from a dataset (used on a raw video
# frame further down): convert to an image tensor, resize to SIZE, scale to float32,
# normalise.
test_transforms = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize(SIZE),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean, std),
])

# Validation pipeline: same conversion and normalisation as training, without the noise.
valid_dataset.set_augmentations([
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean, std),
])

In [None]:
# Six-component offset and scale applied as `x * std + mean` to targets and predictions
# before they are converted to texture coordinates.
# NOTE(review): these rebind the names `mean` / `std` that the augmentation cell passes to
# `transforms.Normalize` — confirm the two uses are meant to share a name.
mean = torch.tensor([0.29174, 8.5515e-06, 0.023512, -0.20853, -0.80377, 3.3909], device=device)
std = torch.tensor([0.14669, 0.047459, 1.1898, 0.9208, 0.71566, 0.85268], device=device)

# Pull one random training sample and print its raw target and image value range.
img, target = train_dataset.random_img()
print("Target:", target)
print("Image:", img.min(), img.max())

# Add a batch dimension, move to the compute device and rescale the target.
target = target.unsqueeze(0).to(device)
target = target * std + mean

# Overlay the texture coordinates derived from the target on the sampled image.
texcoords = prediction_to_texture_coordinates(target)
imshow(img, texcoords[0])

In [None]:
from torch.utils.data import DataLoader

# Number of samples per batch for both loaders.
BATCH_SIZE = 128

# Settings shared by the training and validation loaders.
dl_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    persistent_workers=True,
    pin_memory=False,
)

# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **dl_kwargs)
valid_dataloader = DataLoader(valid_dataset, **dl_kwargs)

In [None]:
from torchvision import models


class CornersRegressionModel(pl.LightningModule):
    """ResNet-18 backbone with a small MLP head that regresses 6 values per image.

    The network is trained with mean-squared error between its raw output and the
    targets yielded by the dataloaders; a ``{stage}_loss`` metric is logged on
    every train / val / test step.

    Args:
        lr: Learning rate handed to AdamW in ``configure_optimizers``. It may be
            left as ``None`` when the model is only built for inference (e.g.
            right before ``load_state_dict``); trying to train such a model
            raises a ``ValueError``.
    """

    def __init__(self, lr: float | None = None) -> None:
        super().__init__()

        # Pretrained torchvision backbone. No layers are frozen: every parameter is
        # fine-tuned together with the new head.
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        self.loss_fn = torch.nn.MSELoss()

        # Replace the classification layer with an MLP that narrows down to the
        # 6 regression outputs.
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(self.model.fc.in_features, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 6),
        )

        # Both attribute names are kept so existing code reading either one still works.
        self.learning_rate = lr
        self.lr = self.learning_rate
        self.save_hyperparameters()

    def predict(self, image: torch.Tensor) -> np.ndarray:
        """Run inference and convert the output to texture coordinates.

        Args:
            image: A single image (3-D tensor) or a batch (4-D tensor). A 3-D
                input gets a leading batch dimension added. The tensor is moved
                to the module's device.

        Returns:
            The result of ``prediction_to_texture_coordinates``. When the batch
            holds exactly one element, that element is returned without the
            batch dimension.
        """
        image = image.to(self.device)
        if image.ndim == 3:
            image = image.unsqueeze(0)

        # The ResNet backbone contains BatchNorm layers. In training mode they
        # normalise with the statistics of the current (possibly single-image)
        # batch and update their running averages even under no_grad, so inference
        # has to run in eval mode. The previous mode is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                pred = self(image)
        finally:
            self.train(was_training)

        # NOTE(review): the notebook cells that call the model directly apply
        # `pred * std + mean` before this conversion; this method does not —
        # confirm whether the raw output is the intended input here.
        pred = prediction_to_texture_coordinates(pred)

        if len(pred) == 1:
            return pred[0]
        return pred

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the raw 6-value regression output for a batch of images."""
        return self.model(image)

    def _step(self, batch: tuple[torch.Tensor, torch.Tensor], stage: str) -> torch.Tensor:
        """Compute and log the MSE loss for one batch.

        Args:
            batch: ``(image, target)`` pair from a dataloader.
            stage: Prefix for the logged metric name (``"{stage}_loss"``).

        Returns:
            The loss tensor for this batch.
        """
        image, target = batch
        predictions = self(image)

        loss = self.loss_fn(predictions, target)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the loss for one training batch (logged as ``train_loss``)."""
        return self._step(batch, "train")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the loss for one validation batch (logged as ``val_loss``)."""
        return self._step(batch, "val")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the loss for one test batch (logged as ``test_loss``)."""
        return self._step(batch, "test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the AdamW optimizer over all parameters.

        Raises:
            ValueError: If the model was built without a learning rate.
        """
        if self.lr is None:
            # Without this guard AdamW fails on `lr=None` with an unrelated-looking error.
            raise ValueError(
                "CornersRegressionModel was created without a learning rate; "
                "pass `lr=...` to the constructor before training."
            )
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [None]:
import wandb
from lightning.pytorch.loggers import WandbLogger

# Authenticate with Weights & Biases so the logger created further down can upload metrics.
wandb.login()

In [None]:
# Learning rate passed to the model constructor below.
LEARNING_RATE = 1e-4

# Finish any W&B run that is still open (e.g. from a previous execution of the training
# cell) before building a fresh model.
wandb.finish()
model = CornersRegressionModel(lr=LEARNING_RATE)
# Cell output: the module's repr.
model

In [None]:
# Show a summary of the model.
# NOTE(review): `summarize` is not imported explicitly in this notebook; it presumably
# comes from one of the star imports — confirm.
summarize(model)

In [None]:
# Metrics are sent to this Weights & Biases project.
logger = WandbLogger(project="mrover-keyboard-corner-prediction")

# Track the checkpoint with the lowest logged validation loss.
checkpoint_callback = ModelCheckpoint(mode="min", monitor="val_loss")

# Train for up to 100 epochs, logging on every optimisation step.
trainer = pl.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

In [None]:
import os

# torch.save does not create missing parent directories, so make sure the target folder
# exists before writing (no-op when it is already there).
os.makedirs("models/transform-prediction", exist_ok=True)

# Store the weights under the name of the current W&B run.
# NOTE(review): this saves the weights as they are at the end of training, not the
# best-`val_loss` checkpoint tracked by `checkpoint_callback` — confirm that is intended.
torch.save(model.state_dict(), f"models/transform-prediction/{wandb.run.name}.pt")

In [None]:
# Build an inference-only model (no learning rate needed) and restore saved weights.
model = CornersRegressionModel()

# map_location="cpu" lets a checkpoint written from a GPU session load on a machine
# without CUDA; the freshly built model lives on the CPU at this point and is moved to
# `device` in the next cell. weights_only=True restricts unpickling to tensor data.
state_dict = torch.load(
    "models/transform-prediction/balmy-sponge-6.pt",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)

In [None]:
# Inference setup: model on the compute device, in eval mode.
model.to(device)
model.eval()

# One random validation sample; both tensors are moved to the device and given a
# leading batch dimension.
image, target = (
    tensor.to(device).unsqueeze(0)
    for tensor in valid_dataset.random_img()
)

with torch.no_grad():
    pred = model(image)

# Rescale prediction and target with `std` / `mean`, then convert each one to
# texture coordinates.
pred, target = (
    prediction_to_texture_coordinates(values * std + mean)
    for values in (pred, target)
)

# Display order: input, prediction, input, ground truth.
show_images([image[0], pred[0], image[0], target[0]])

In [None]:
# Grab a single frame from a test video and run the model on it.
vidcap = cv2.VideoCapture(f"{TEST_DATASET}/110.mp4")
try:
    total = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
    frame = 150
    vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame)
    # `ok` is False when the file could not be opened or the frame does not exist;
    # `image` is then None. (Not assigned to `_`, which IPython uses for the last output.)
    ok, image = vidcap.read()
finally:
    # Free the decoder / file handle even if seeking or reading raises.
    vidcap.release()

if not ok:
    raise RuntimeError(
        f"Could not read frame {frame} (of {total}) from {TEST_DATASET}/110.mp4"
    )

# NOTE(review): OpenCV returns frames in BGR channel order — confirm this matches the
# channel order of the images the model was trained on.
image = test_transforms(image).unsqueeze(0)
image = image.to(device)

# Do not rely on an earlier cell having prepared the model: both calls are idempotent.
model.to(device)
model.eval()

with torch.no_grad():
    pred = model(image)

# Rescale the raw output with `std` / `mean` before converting to texture coordinates.
pred = (pred * std + mean)
pred = prediction_to_texture_coordinates(pred)

imshow(image[0], pred[0])

In [None]:
# Sanity check on the most recently assigned `image`: value range, then shape as the
# cell output.
print(image.min(), image.max())
image.shape