## Seeing whether we can use the CLIP model with Lightning
Lightning handles the training loop, logging, and more. It's a great tool for training models. Let's see if we can use it with CLIP.

### Defining CLIP in a `pl.LightningModule` by reference only
Time to wrap CLIP in a `LightningModule`. `open_clip.create_model_and_transforms` is a factory that returns a `CLIP` `nn.Module` (together with its preprocessing transforms), so the Lightning module can simply hold it as a submodule. See [model.py](https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L155).

In [20]:
import open_clip
import pytorch_lightning as pl
import torch
import torch.nn as nn

from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split


class CringeCLIPModel(pl.LightningModule):
    """
        CringeCLIP

        LightningModule wrapping an open_clip ``CLIP`` model so it can be trained
        and checkpointed with a ``pl.Trainer``. ``forward`` returns L2-normalised
        image or text embeddings.

        Args:
            model_type: open_clip model name used for both the model and the tokenizer.
            hparams: Unused; accepted for backward compatibility.
            has_cross_attention: Unused; accepted for backward compatibility.
            img_dim: Unused; accepted for backward compatibility.
            checkpoint_path: State dict loaded into the CLIP module after it is
                built. ``None`` skips loading.
            lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512,
                 checkpoint_path="rn50-quickgelu-yfcc15m.pt", lr=5e-5):
        super().__init__()
        self.lr = lr

        self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms(model_type) # type: ignore
        if checkpoint_path is not None:
            # NOTE(review): the default checkpoint name suggests QuickGELU-trained weights,
            # while open_clip's plain model definitions may use nn.GELU; consider
            # model_type="RN50-quickgelu" (or force_quick_gelu=True) -- confirm.
            # map_location="cpu" lets a GPU-saved state dict load on CPU-only machines.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.clip_module.load_state_dict(state_dict)
        self.tokenizer = open_clip.get_tokenizer(model_type) # type: ignore

    def forward(self, text = None, image = None):
        """
            Encode ``image`` or ``text`` and return L2-normalised embeddings.

            If both are given, ``image`` takes precedence and ``text`` is ignored.

            Raises:
                ValueError: if neither ``text`` nor ``image`` is provided.
        """
        if text is None and image is None:
            raise ValueError("Must provide either text or image")

        if image is not None:
            features = self.clip_module.encode_image(image)
        else:
            features = self.clip_module.encode_text(text)

        # Out-of-place division: an in-place `/=` overwrites the tensor that the
        # norm op saved for backward, so loss.backward() would raise.
        return features / features.norm(dim=-1, keepdim=True)

    def configure_optimizers(self):
        """
            configure_optimizers

            This is the optimizer for the model: Adam over all parameters with
            learning rate ``self.lr``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _mse_step(self, batch):
        """
            Return the MSE between the text embeddings of ``batch`` and its targets.

            ``batch`` is a ``(text_tokens, target)`` pair.
        """
        text, target = batch
        prediction = self.forward(text=text)
        return F.mse_loss(prediction, target)

    def training_step(self, train_batch, batch_idx):
        """Return the MSE training loss for one ``(text, target)`` batch."""
        return self._mse_step(train_batch)

    def validation_step(self, train_batch, batch_idx):
        """Return the MSE validation loss for one ``(text, target)`` batch."""
        return self._mse_step(train_batch)
        

    

### Load `YFCC15M` pretrained model

Instantiate the module, which loads the RN50 weights pretrained on YFCC15M.

In [21]:
# Build the Lightning module; its constructor loads the RN50 weights from
# "rn50-quickgelu-yfcc15m.pt".
model = CringeCLIPModel(model_type="RN50")

### Infer `YFCC15M` pretrained model

Embed a text prompt with the pretrained model.

In [22]:
# Embed a single prompt; forward() returns the L2-normalised text embedding.
example_output = model(text=model.tokenizer(["a photo of a cat",]))
# Print the result we just computed (the previous `print(x)` named an undefined variable).
print(example_output)

tensor([[-0.0199,  0.0126,  0.0156,  ...,  0.0032,  0.0122,  0.0059]],
       grad_fn=<DivBackward0>)


### Dump the model
Now that we've got the model loaded, let's dump it to a file. 😂

First, we need dummy data to pass to the model. We'll pair a fixed text prompt with random target embeddings (no images are involved).

Then we "train" for zero steps (`max_steps=0`) and save a checkpoint.

In [24]:
import numpy as np
import torch

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    """
        Fixed-size dataset of (token ids, random target) pairs for smoke tests.

        Every item pairs the same tokenised prompt with a freshly drawn
        standard-normal target, so targets differ on every access.

        Args:
            length: Number of items reported by ``__len__``.
            prompt: Text prompt, tokenised once and returned for every item.
            tokenizer: Callable taking a list of strings and returning a tensor
                whose leading dim is squeezed away. Defaults to the notebook's
                global ``model.tokenizer``.
            target_shape: Shape of each random target. Defaults to the shape of
                the notebook's global ``example_output`` after ``squeeze(0)``.
    """

    def __init__(self, length=100, prompt="a photo of a cat", tokenizer=None, target_shape=None):
        super().__init__()
        # Not read by this class; kept so existing attribute access keeps working.
        self.constant = 2
        self.batch_size = 20
        self.length = length

        if tokenizer is None:
            tokenizer = model.tokenizer
        if target_shape is None:
            target_shape = example_output.squeeze(0).shape

        # The prompt is constant, so tokenise it once instead of on every access.
        self.text = tokenizer([prompt]).squeeze(0)
        self.target_shape = tuple(target_shape)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        output = torch.randn(self.target_shape)
        return self.text, output

dataset = DummyDataset()
# Take the validation size as the remainder so the lengths always sum to
# len(dataset); int(n * 0.8) + int(n * 0.2) can fall short (e.g. n = 99 -> 98),
# which makes random_split raise.
n_train = int(len(dataset) * 0.8)
training_set, validation_set = torch.utils.data.random_split(dataset, [n_train, len(dataset) - n_train])
train_loader = DataLoader(training_set, batch_size=10)
val_loader = DataLoader(validation_set, batch_size=10)

# max_steps=0: fit() performs no optimizer steps (only the validation sanity
# check runs), so the checkpoint holds the weights the model was built with.
model_trainer = pl.Trainer(max_steps=0)
model_trainer.fit(model, train_loader, val_loader)
model_trainer.save_checkpoint("cringe_clip.ckpt")

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /mnt/e/Source/teeny-transformer-experiments/lightning_logs

  | Name        | Type | Params
-------------------------------------
0 | clip_module | CLIP | 102 M 
-------------------------------------
102 M     Trainable params
0         Non-trainable params
102 M     Total params
408.029   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

`Trainer.fit` stopped: `max_steps=0` reached.




### Confirm the model is saved and working!
Now, let's redefine the module without loading the pretrained weights, restore them from the checkpoint, and check that inference still works.

In [26]:
import open_clip
import pytorch_lightning as pl
import torch
import torch.nn as nn

from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split


class CringeCLIPModel(pl.LightningModule):
    """
        CringeCLIP

        LightningModule wrapping an open_clip ``CLIP`` model. This definition does
        not load pretrained weights by default; restore them from a checkpoint.
        ``forward`` returns L2-normalised image or text embeddings.

        Args:
            model_type: open_clip model name used for both the model and the tokenizer.
            hparams: Unused; accepted for backward compatibility.
            has_cross_attention: Unused; accepted for backward compatibility.
            img_dim: Unused; accepted for backward compatibility.
            checkpoint_path: Optional state dict loaded into the CLIP module.
                ``None`` (the default) skips loading.
            lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, model_type="RN50", hparams=None, has_cross_attention=False, img_dim=512,
                 checkpoint_path=None, lr=5e-5):
        super().__init__()
        self.lr = lr

        self.clip_module, _, self.preprocess = open_clip.create_model_and_transforms(model_type) # type: ignore
        if checkpoint_path is not None:
            # map_location="cpu" lets a GPU-saved state dict load on CPU-only machines.
            self.clip_module.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.tokenizer = open_clip.get_tokenizer(model_type) # type: ignore

    def forward(self, text = None, image = None):
        """
            Encode ``image`` or ``text`` and return L2-normalised embeddings.

            If both are given, ``image`` takes precedence and ``text`` is ignored.

            Raises:
                ValueError: if neither ``text`` nor ``image`` is provided.
        """
        if text is None and image is None:
            raise ValueError("Must provide either text or image")

        if image is not None:
            features = self.clip_module.encode_image(image)
        else:
            features = self.clip_module.encode_text(text)

        # Out-of-place division: an in-place `/=` overwrites the tensor that the
        # norm op saved for backward, so loss.backward() would raise.
        return features / features.norm(dim=-1, keepdim=True)

    def configure_optimizers(self):
        """
            configure_optimizers

            This is the optimizer for the model: Adam over all parameters with
            learning rate ``self.lr``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _mse_step(self, batch):
        """
            Return the MSE between the text embeddings of ``batch`` and its targets.

            ``batch`` is a ``(text_tokens, target)`` pair.
        """
        text, target = batch
        prediction = self.forward(text=text)
        return F.mse_loss(prediction, target)

    def training_step(self, train_batch, batch_idx):
        """Return the MSE training loss for one ``(text, target)`` batch."""
        return self._mse_step(train_batch)

    def validation_step(self, train_batch, batch_idx):
        """Return the MSE validation loss for one ``(text, target)`` batch."""
        return self._mse_step(train_batch)
    

Then load the checkpoint and run inference.

In [27]:
# Build the module from the definition above, then restore its weights from the
# Lightning checkpoint, which stores the module weights under "state_dict".
model = CringeCLIPModel()
# map_location="cpu" lets a checkpoint written on a GPU machine load anywhere.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may
# reject Lightning checkpoints; pass weights_only=False for this trusted file if so.
checkpoint = torch.load("cringe_clip.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

# Expect the same embedding the pretrained model printed before saving.
example_output = model(text=model.tokenizer(["a photo of a cat",]))
print(example_output)

tensor([[-0.0199,  0.0126,  0.0156,  ...,  0.0032,  0.0122,  0.0059]],
       grad_fn=<DivBackward0>)
