In [1]:
import sys, os

# sys path hack to allow importing the encoding functions and other modules
sys.path.insert(0, os.path.abspath('../src'))
sys.path.insert(0, os.path.abspath('../externals'))

In [2]:
import torch 

if torch.cuda.is_available():
    gpu = torch.device("cuda")
elif torch.backends.mps.is_available():
    gpu = torch.device("mps")
else:
    print("Warning: no GPU detected, falling back to CPU")
    gpu = torch.device("cpu")

In [3]:
%load_ext tensorboard

In [4]:
from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(gpu)
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

  from .autonotebook import tqdm as notebook_tqdm
2023-05-28 19:58:18.342079: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-28 19:58:18.366558: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
from prompt_dataset import PromptDataset
import pandas as pd
from torch.utils.data import DataLoader

training_rel_samples = pd.read_pickle("../data/finetuning/train_rel_samples_all.pkl")
training_rel_set = PromptDataset(training_rel_samples, prompt_transform=lambda e: [
    f"{e['object0_name']} {e['rel']} {e['object1_name']}"
], img_size=224, mode="pad")
training_loader = DataLoader(training_rel_set, batch_size=32, shuffle=True)

In [9]:
validation_rel_samples = pd.read_pickle("../data/finetuning/val_rel_samples_50k.pkl")
validation_rel_set = PromptDataset(validation_rel_samples, prompt_transform=lambda e: [
    f"{e['object0_name']} {e['rel']} {e['object1_name']}"
], img_size=224, mode="pad")
validation_loader = DataLoader(validation_rel_set, batch_size=32, shuffle=True)

# Finetuning

In [16]:
import transformers

def get_linear_schedule_with_warumup(optimizer, len_dataset, batch_size, num_epochs, warmup_percentage=0.2):
    steps_per_epoch = len_dataset // batch_size
    total_steps = steps_per_epoch * num_epochs
    warmup_steps = total_steps * warmup_percentage
    return transformers.get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps - warmup_steps)

In [17]:
from torch.nn.functional import cross_entropy

def loss_fn(logits):
    labels = torch.arange(logits.shape[0])
    loss_images = cross_entropy(logits, labels)
    loss_texts = cross_entropy(logits.t(), labels)
    return (loss_images + loss_texts) / 2

In [18]:
def predict(batch):
    images = batch[0].unbind(0)
    texts = list(batch[1][0])

    image_inputs = image_processor(images, return_tensors="pt", do_resize=False, do_center_crop=False).to(gpu)
    text_inputs = tokenizer(texts, return_tensors="pt", padding=True).to(gpu)

    logits = model(**image_inputs, **text_inputs)["logits_per_image"]
    logits_cpu = logits.to("cpu")
    del images, texts, image_inputs, text_inputs, logits
    return logits_cpu

In [19]:
def train_one_epoch(epoch, optimizer, scheduler, training_loader, tb_writer):
    running_loss = 0.
    last_loss = 0.

    for i, batch in enumerate(training_loader):
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        logits = predict(batch)

        # Compute the loss and its gradients
        loss = loss_fn(logits)
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 200 == 199:
            last_loss = running_loss / 200 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        del logits, loss

    return last_loss

In [20]:
def validate(validation_loader):
    running_vloss = 0.0
    for i, batch in enumerate(validation_loader):
        logits = predict(batch)
        vloss = loss_fn(logits)
        running_vloss += vloss.item()
        del logits, vloss

    return running_vloss / (i + 1)

In [21]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/clip_finetune_{}'.format(timestamp))

epoch_number = 0

EPOCHS = 5

optimizer = torch.optim.Adam(model.parameters(), lr=5e-7)
scheduler = get_linear_schedule_with_warumup(optimizer, training_rel_set.__len__(), training_loader.batch_size, EPOCHS)
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)

    # Train for one epoch
    avg_loss = train_one_epoch(epoch, optimizer, scheduler, training_loader, writer)

    # We don't need gradients on to do reporting
    model.train(False)

    torch.cuda.empty_cache()

    # Validate
    avg_vloss = validate(validation_loader)

    print('LOSS train {} validation {}'.format(avg_loss, avg_vloss))
    writer.add_scalars('Training vs. Validation Loss', {
        'Training' : avg_loss, 
        'Validation' : avg_vloss 
    }, epoch + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        print('New best model found!')
        best_vloss = avg_vloss
        model.save_pretrained(f"model_snapshots/model_{timestamp}_{epoch}_{best_vloss:.2f}")

EPOCH 1:
  batch 200 loss: 1.2607923579216003
  batch 400 loss: 1.253549875319004
  batch 600 loss: 1.248550270795822
  batch 800 loss: 1.2203076487779618
  batch 1000 loss: 1.1222080099582672
  batch 1200 loss: 1.074778725504875
  batch 1400 loss: 1.0473621149361134
  batch 1600 loss: 0.9823834873735905
  batch 1800 loss: 0.9771840846538544
  batch 2000 loss: 0.9560783676803112
  batch 2200 loss: 0.9312538653612137
  batch 2400 loss: 0.8800676545500755
  batch 2600 loss: 0.8826087538897991
  batch 2800 loss: 0.8576292327046394
  batch 3000 loss: 0.8183759136497974
  batch 3200 loss: 0.8408213220536709
  batch 3400 loss: 0.7943902204930783
  batch 3600 loss: 0.7578974530100823
  batch 3800 loss: 0.734758093804121
  batch 4000 loss: 0.7302561473101378
  batch 4200 loss: 0.7191734200716019
  batch 4400 loss: 0.6998930731415749
  batch 4600 loss: 0.6993444043397904
  batch 4800 loss: 0.675366016253829
  batch 5000 loss: 0.6951696416735649
  batch 5200 loss: 0.6752667605876923
  batch 5400

  batch 14000 loss: 0.27273024601861834
LOSS train 0.27273024601861834 validation 0.4794739084255775
EPOCH 4:
  batch 200 loss: 0.25151115051470696
  batch 400 loss: 0.24852171598002315
  batch 600 loss: 0.25081537118181585
  batch 800 loss: 0.26017503429204225
  batch 1000 loss: 0.2611316084302962
  batch 1200 loss: 0.25259262027218937
  batch 1400 loss: 0.2601969661936164
  batch 1600 loss: 0.24145988695323467
  batch 1800 loss: 0.24964060816913844
  batch 2000 loss: 0.2521200533397496
  batch 2200 loss: 0.24789788333699106
  batch 2400 loss: 0.2508187795802951
  batch 2600 loss: 0.23587181514129044
  batch 2800 loss: 0.2411006654240191
  batch 3000 loss: 0.25045114394277335
  batch 3200 loss: 0.25605797324329616
  batch 3400 loss: 0.24252273278310896
  batch 3600 loss: 0.2557850718125701
  batch 3800 loss: 0.23374844616279006
  batch 4000 loss: 0.24238831143826245
  batch 4200 loss: 0.24150524575263263
  batch 4400 loss: 0.246237119352445
  batch 4600 loss: 0.2561848810128868
  batc