In [1]:
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from tqdm import tqdm

from torch import nn
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split, DataLoader

import lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger

from dlc.trainers.img_classification import launch_training
from dlc.vq_vae.model import VQVAE
from dlc.trainers.vqvae import VQVAELightningModule

In [2]:
# Data configuration. The split seed makes the train/test membership
# reproducible across kernel restarts; without it every run evaluates on a
# different held-out subset.
DATA_ROOT = "D:/data/images/galaxy10_unamur/train"
SPLIT_SEED = 42
BATCH_SIZE = 16

# ToTensor followed by Normalize(mean=0.5, std=0.5) per channel; the inverse
# used for display is `img * 0.5 + 0.5`.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
dataset = ImageFolder(root=DATA_ROOT, transform=transform)

# 80/20 split driven by a dedicated, seeded generator so it does not depend on
# (or disturb) the global torch RNG state.
train_dataset, test_dataset = random_split(
    dataset,
    lengths=(0.8, 0.2),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print("#train samples:", len(train_dataset))
print("#test samples:", len(test_dataset))

# Only the training loader shuffles; the held-out loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
print("#train batches:", len(train_dataloader))
print("#test batches:", len(test_dataloader))

#train samples: 14691
#test samples: 3672
#train batches: 919
#test batches: 230


In [3]:
def display_sample(sample, unnormalize=True, ax=None):
    """Draw one ``(image, label)`` pair on a matplotlib axes.

    Args:
        sample: ``(img, label)`` pair. ``img`` is a channel-first image tensor
            (it is permuted to channel-last before plotting); ``label`` is
            used to index ``dataset.classes`` for the title.
        unnormalize: If True, apply ``img * 0.5 + 0.5`` before plotting, which
            inverts a ``Normalize(mean=0.5, std=0.5)`` transform.
        ax: Axes to draw on. Defaults to the current pyplot axes.
    """
    ax = plt.gca() if ax is None else ax
    img, label = sample
    # Detach first: `.numpy()` refuses tensors that still track gradients
    # (e.g. images coming straight out of a model).
    img = img.detach()
    if unnormalize:
        img = img * 0.5 + 0.5
    if img.is_floating_point():
        # imshow clips float RGB data to [0, 1] anyway (with a warning);
        # clamp explicitly so out-of-range pixels are handled silently.
        img = img.clamp(0.0, 1.0)
    ax.set_title(dataset.classes[label])
    ax.axis('off')
    ax.imshow(img.permute(1, 2, 0).cpu().numpy())

def display_batch(batch, unnormalize=True):
    """Show up to 8 randomly chosen samples of a batch in a 2x4 grid.

    Args:
        batch: ``(imgs, labels)`` pair of equally long sequences, as yielded
            by the data loaders.
        unnormalize: Forwarded to ``display_sample``.
    """
    imgs, labels = batch
    pairs = list(zip(imgs, labels))
    # The last batch of an epoch can hold fewer than 8 items; random.sample
    # raises ValueError when asked for more items than are available.
    n_shown = min(8, len(pairs))
    samples = random.sample(pairs, k=n_shown)
    fig = plt.figure(figsize=(16, 8))
    for i, sample in enumerate(samples):
        ax = fig.add_subplot(2, 4, i + 1)
        display_sample(sample, unnormalize=unnormalize, ax=ax)
    plt.tight_layout()
    plt.show()

In [4]:
# Channel widths handed to the encoder; the decoder receives the same widths
# in reverse order.
encoder_channels = (64, 128, 256, 512, 512)

vq_vae_kwargs = dict(
    in_channels=3,
    embedding_dim=256,
    n_embeddings=512,
    hidden_channels_enc=encoder_channels,
    hidden_channels_dec=tuple(reversed(encoder_channels)),
    commitment_loss_factor=0.25,
    quantization_loss_factor=1.0,
)
vq_vae = VQVAE(**vq_vae_kwargs)

# Optimisation settings for the Lightning wrapper around the model.
optim_kwargs = dict(
    learning_rate=4e-4,
    n_warmup_epochs=20,
    plateau_patience=5,
    plateau_factor=0.5,
)
lightning_module = VQVAELightningModule(vq_vae=vq_vae, **optim_kwargs)

In [5]:
# Smoke test: with fast_dev_run enabled Lightning runs a single batch through
# the training and validation loops and suppresses logging/checkpointing, so
# max_epochs and the logger only take effect once the flag is turned off.
tb_logger = TensorBoardLogger(save_dir="logs")
trainer_kwargs = dict(
    fast_dev_run=True,
    max_epochs=500,
    logger=tb_logger,
)
lightning_trainer = pl.Trainer(**trainer_kwargs)

# Fit on the training loader, validating on the held-out loader.
fit_loaders = dict(
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)
lightning_trainer.fit(lightning_module, **fit_loaders)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
  return _C._get_float32_matmul_precision()
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type          | Params | Mode 
------------------------------------

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.
