# Fine-Tuning Clip Visual Encoder on Custom Data

The text encoder is frozen, the visual encoder is finetuned on the custom data

In [None]:
!pip install torch pillow open_clip_torch

In [1]:
import json
import os

from PIL import Image

import torch

import open_clip

Find device (CUDA or CPU)

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Get OpenClip Model

https://github.com/mlfoundations/open_clip

In [3]:
model, _, preprocess = open_clip.create_model_and_transforms("RN50", "cc12m")
model.transformer.eval()
tokenizer = open_clip.get_tokenizer("RN50")

  checkpoint = torch.load(checkpoint_path, map_location=map_location)


### Define Dataset and Dataloader

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, descr_fpath, image_folder, preprocess, tokenizer):
        self.descriptions = list(json.load(open(descr_fpath, mode="r")).items())
        self.image_folder = image_folder
        self.preprocess = preprocess
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.descriptions)

    def __getitem__(self, idx):
        img_fname, description = self.descriptions[idx]
        img = Image.open(os.path.join(self.image_folder, img_fname))
        img = self.preprocess(img)
        tokens = tokenizer(description)
        return img, tokens

In [5]:
dataset = Dataset("data/descriptions.json", "data/images", preprocess, tokenizer)

In [6]:
dataset_train, dataset_test = torch.utils.data.random_split(dataset, [0.7, 0.3])

In [7]:
batch_size = 16

In [8]:
dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    shuffle=True,
    batch_size=batch_size,
    num_workers=4,
    drop_last=True,
)

In [9]:
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    shuffle=False,
    batch_size=batch_size,
    num_workers=4,
    drop_last=True,
)

### Training Loop

In [10]:
num_epochs = 8
optimizer = torch.optim.AdamW(model.visual.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
model = model.to(device)
scaler = torch.amp.GradScaler(device)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.0001,
    steps_per_epoch=len(dataloader_train),
    epochs=num_epochs,
    pct_start=0.1,
)

In [11]:
for epoch in range(num_epochs):
    print("epoch", epoch + 1)
    model.visual.train()

    losses = []
    for img, tokens in dataloader_train:
        with torch.autocast(device, dtype=torch.float16):
            # encode features
            img_features = model.encode_image(img.to(device))
            text_features = model.encode_text(tokens.squeeze(1).to(device))

            # normalize features
            img_features = img_features / img_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # get loss
            logits = 100.0 * img_features @ text_features.T
            loss = loss_fn(logits, torch.arange(batch_size).to(device))

        # do optimization step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
        losses.append(loss.item())

    print("train:", sum(losses) / len(losses))

    losses = []
    model.visual.eval()
    with torch.no_grad(), torch.autocast(device, dtype=torch.float16):
        img_features = model.encode_image(img.to(device))
        text_features = model.encode_text(tokens.squeeze(1).to(device))
        
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    
        logits = 100.0 * img_features @ text_features.T
        loss = loss_fn(logits, torch.arange(batch_size).to(device))
        losses.append(loss.item())

    print("test:", sum(losses) / len(losses))
    print()

epoch 1
train: 1.6339112717744917
test: 0.48242297768592834

epoch 2
train: 0.618115863307964
test: 0.29135844111442566

epoch 3
train: 0.25435521696196045
test: 0.054301679134368896

epoch 4
train: 0.11629492373660553
test: 0.31455808877944946

epoch 5
train: 0.07687708475562029
test: 0.0015368163585662842

epoch 6
train: 0.06449450136617173
test: 0.0017068013548851013

epoch 7
train: 0.03913865966159244
test: 0.04371177405118942

epoch 8
train: 0.05220240943653639
test: 0.0405363067984581



### Save Model

In [12]:
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": epoch,
}
torch.save(checkpoint, "models/model.chkpt")