In [1]:
import sys
sys.path.append('../')

import os
import datetime
import lightning as L
from dataset import AEDataset
from torch.utils.data import DataLoader
from lightning_model.autoencoder import LitAE 
from lightning_model.clip import LitTextPointCloudCLIP
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [2]:
from types import SimpleNamespace

# Autoencoder hyperparameters, exposed as attributes (config.latent_dim, ...).
# They are handed to LitAE when the checkpoint is restored below.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with -- confirm against the autoencoder training run.
config = SimpleNamespace(
    enc_filters=(64, 128, 128, 256),
    latent_dim=128,
    enc_bn=True,
    dec_features=(256, 256),
    n_pts=256,
    dec_bn=False,
)

# Restore the trained autoencoder and keep a handle on its encoder sub-module.
checkpoint = 'lightning_logs/autoencoder/checkpoints/epoch=999-step=1116000.ckpt'
autoencoder = LitAE.load_from_checkpoint(checkpoint, config=config)
encoder = autoencoder.autoencoder.encoder

In [3]:
# Build the text/point-cloud CLIP Lightning module around the encoder taken
# from the restored autoencoder.
lit_clip_model = LitTextPointCloudCLIP(
    encoder,
    clip_name="ViT-B/32",
    device='cuda:0',
)

In [4]:
root = '..'
dataset_name = 'shapenetcorev2'
# Valid split names are 'train', 'test', 'all', 'trainval' and 'val';
# 'trainval' and 'val' are only supported by the shapenetcorev2 and
# shapenetpart datasets.

# Arguments shared by every split of the dataset.
dataset_kwargs = dict(root=root, dataset_name=dataset_name, num_points=config.n_pts)
train_dataset = AEDataset(split='train', **dataset_kwargs)
val_dataset = AEDataset(split='val', **dataset_kwargs)
test_dataset = AEDataset(split='test', **dataset_kwargs)

# Arguments shared by every loader; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Report the number of samples in the train, val and test splits, in that order.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print("datasize:", len(split_dataset))

datasize: 35708
datasize: 5158
datasize: 10261


In [5]:
# Timestamp used to give each run its own TensorBoard log directory name.
date_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Keep only the single best checkpoint according to the monitored "val_loss".
# NOTE(review): presumably lit_clip_model logs a metric named "val_loss" in its
# validation step -- confirm, otherwise the callback has nothing to monitor.
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")
logger = TensorBoardLogger(save_dir='lightning_logs', name=f"clip_model_{date_time}")

# `Trainer(gpus=...)` was deprecated in v1.7 and removed in Lightning 2.0;
# `accelerator` + `devices` is the supported way to request a single GPU.
trainer = L.Trainer(
    max_epochs=5000,
    accelerator='gpu',
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(model=lit_clip_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: lightning_logs/clip_model_20240405-183146
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | clip_model | TextPointCloudCLIP | 151 M 
1 | loss_fn    | CrossEntropyLoss   | 0     
--------------------------------------------------
151 M     Trainable params
0         Non-trainable params
1

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]