In [3]:
import sys
sys.path.append('../')

import torch
import clip

In [2]:
# Run on the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 338M/338M [01:09<00:00, 5.10MiB/s]


In [3]:
# A few probe captions, tokenized and moved next to the model for a quick encoder sanity check.
text = clip.tokenize(
    ["a diagram", "a dog", "a cat"]
).to(device=device)

In [5]:
from dataset import shapenetpart_cat2id


# Inverse lookup of shapenetpart_cat2id: label id -> category name.
shapenetpart_id2cat = dict(
    (cat_id, cat_name) for cat_name, cat_id in shapenetpart_cat2id.items()
)

In [6]:
import lightning as L


class TextPointCloudCLIP(torch.nn.Module):
    """Pair a pretrained CLIP text encoder with a point-cloud encoder.

    Captions are encoded by CLIP and linearly projected to ``projection_dim`` so
    they can be compared with the point-cloud embeddings.

    Args:
        point_cloud_encoder: module applied to the encoded points. Its output width
            presumably has to equal ``projection_dim`` for the contrastive loss —
            TODO confirm against the encoder definition.
        clip_name: CLIP checkpoint name passed to ``clip.load``.
        device: device the CLIP model is loaded onto.
        projection_dim: width of the shared text / point-cloud embedding space.
    """

    def __init__(self, point_cloud_encoder, clip_name="ViT-B/32", device="cuda", projection_dim=128):
        super().__init__()
        self.pretrained_model, self.preprocess = clip.load(clip_name, device=device)
        self.point_cloud_encoder = point_cloud_encoder

        # encode_text ends with ``x @ text_projection``, so its output width is the
        # projection's second dimension (512 for ViT-B/32, 768 for ViT-L/14), not a fixed 512.
        text_embed_dim = self.pretrained_model.text_projection.shape[1]
        self.dim_reduction = torch.nn.Linear(text_embed_dim, projection_dim)

    def forward(self, text, encoded_points):
        """Encode a batch of captions and point clouds.

        Args:
            text: CLIP token ids, e.g. the output of ``clip.tokenize``.
            encoded_points: input for ``point_cloud_encoder``.

        Returns:
            ``(text_features, point_cloud_features)``.
        """
        text_features = self.pretrained_model.encode_text(text)
        # clip.load keeps fp16 weights on CUDA while the projection layer is fp32;
        # match dtypes so the Linear layer does not fail with a dtype mismatch.
        text_features = self.dim_reduction(text_features.to(self.dim_reduction.weight.dtype))
        point_cloud_features = self.point_cloud_encoder(encoded_points)
        return text_features, point_cloud_features
        
class LitTextPointCloudCLIP(L.LightningModule):
    """Lightning module that trains ``TextPointCloudCLIP`` with the symmetric CLIP loss.

    For every sample, the ShapeNetPart category name of its label is tokenized as
    the caption. Text and point-cloud embeddings are compared through a learnable
    temperature and trained with cross-entropy in both directions
    (text -> point cloud and point cloud -> text).

    Args:
        point_cloud_encoder: module applied to ``batch["points_encoded"]``. Its output
            width must match the projected text embedding of ``TextPointCloudCLIP``.
        clip_name: CLIP checkpoint name forwarded to ``clip.load``.
        device: device the CLIP model is loaded onto.
    """

    # CLIP clips the learned temperature so logits are never scaled by more than 100.
    MAX_LOGIT_SCALE = 100.0

    def __init__(self, point_cloud_encoder, clip_name="ViT-B/32", device="cuda"):
        super().__init__()
        self.clip_model = TextPointCloudCLIP(point_cloud_encoder, clip_name, device)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        # Learnable log-temperature, initialised to log(1 / 0.07) as in CLIP.
        self.logit_scale = torch.nn.Parameter(torch.tensor(1 / 0.07).log())

    def _compute_logits(self, features_a, features_b):
        """Return scaled cosine-similarity logits between two embedding batches.

        Returns:
            ``(logits_per_a, logits_per_b)``. ``logits_per_a[i, j]`` is the scaled
            cosine similarity of ``features_a[i]`` and ``features_b[j]``, and
            ``logits_per_b`` is its transpose.
        """
        # normalize() guards against division by zero for all-zero feature rows.
        features_a = torch.nn.functional.normalize(features_a, dim=1)
        features_b = torch.nn.functional.normalize(features_b, dim=1)

        logit_scale = self.logit_scale.exp().clamp(max=self.MAX_LOGIT_SCALE)
        logits_per_a = logit_scale * features_a @ features_b.t()
        return logits_per_a, logits_per_a.t()

    def _step(self, batch):
        """Compute the symmetric contrastive loss for one batch."""
        pts_enc = batch["points_encoded"]
        labels = batch["label"]

        # If labels arrive as a tensor (presumably, via default collation), its elements
        # hash by identity and would never match the int keys of shapenetpart_id2cat;
        # int() turns each element into a plain Python int.
        # NOTE(review): samples sharing a category get identical captions, so they act
        # as false negatives for each other in the contrastive loss.
        captions = [shapenetpart_id2cat[int(label)] for label in labels]
        text = clip.tokenize(captions).to(self.device)

        text_features, point_cloud_features = self.clip_model(text, pts_enc)
        logits_per_text, logits_per_point_cloud = self._compute_logits(
            text_features, point_cloud_features)

        # The i-th caption belongs to the i-th point cloud: targets are the diagonal.
        ground_truth = torch.arange(len(labels), device=self.device)
        return (self.loss_fn(logits_per_text, ground_truth)
                + self.loss_fn(logits_per_point_cloud, ground_truth)) / 2

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        """Build AdamW with CLIP's hyperparameters, without decay on gains, biases, or temperature."""
        # CLIP uses decoupled weight decay (AdamW). With plain Adam, weight_decay=0.2
        # becomes a heavy L2 term added to every gradient, including the pretrained weights.
        # Like CLIP, 0-D/1-D tensors (biases, norm gains, logit_scale) are not decayed.
        decay = [p for p in self.parameters() if p.ndim >= 2]
        no_decay = [p for p in self.parameters() if p.ndim < 2]
        optimizer = torch.optim.AdamW(
            [
                {"params": decay, "weight_decay": 0.2},
                {"params": no_decay, "weight_decay": 0.0},
            ],
            lr=5e-5,
            betas=(0.9, 0.98),
            eps=1e-6,
        )
        # NOTE(review): clip.load keeps fp16 weights on CUDA. Updating fp16 params directly
        # with AdamW is fragile; consider .float() on the CLIP model (or mixed precision).
        return optimizer

In [4]:
# Encode the probe captions without recording an autograd graph.
with torch.set_grad_enabled(False):
    text_features = model.encode_text(text)

In [6]:
# Expect one embedding row per caption.
text_features.size()

torch.Size([3, 512])

In [7]:
# Modified from https://github.com/openai/CLIP/blob/a1d071733d7111c9c014f024669f959182114e33/clip/model.py
