# CLIP training using HuggingFace libs

This notebook demonstrates how to finetune a CLIP model with the HuggingFace libraries, using a local folder of captioned images. It starts from `openai/clip-vit-large-patch14-336` (aka `ViT-L/14@336px`), the 336-pixel variant of the `ViT-L/14` CLIP model whose text encoder Stable Diffusion v1.4-1.5 uses.

## Setup
Run these first. Assumes that PyTorch is already installed.

!pip install -q datasets pillow
# we need v4.26 of transformers - as of writing pip only provides up to v4.25
!pip install -q git+https://github.com/huggingface/transformers
print("--\nDONE")

## Convert the data folder of text/image pairs to a huggingface dataset-compatible json

Replace `root_folder` in the next cell with the top-level folder containing your images, and `out_json` with a path to where the json file representing the image/caption pairs in that folder should be saved.

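Note that the cell below does not read `.txt` caption files: it expects one sub-folder per caption under `root_folder`, and uses the sub-folder's name as the caption for every `.png`, `.jpg`, `.jpeg` or `.gif` image inside it.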
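For a dataset that stores its captions as `filename.jpg`/`filename.txt` (or `filename.jpeg`/`filename.txt`) pairs, a collector along the following lines could stand in for `collect_captioned_images` in the conversion cell. This is only a sketch: the function name `collect_txt_captioned_images` is hypothetical, and it assumes each caption file is UTF-8 text sitting next to its image.

```python
import os
from typing import Iterator, Tuple

IMAGE_EXTENSIONS = (".jpg", ".jpeg")

def collect_txt_captioned_images(root_folder: str) -> Iterator[Tuple[str, str]]:
    """Yield (image_path, caption) for every image that has a same-named .txt file."""
    for folder, _subfolders, filenames in os.walk(root_folder):
        for filename in sorted(filenames):
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in IMAGE_EXTENSIONS:
                continue
            caption_path = os.path.join(folder, stem + ".txt")
            if not os.path.isfile(caption_path):
                continue  # skip images that have no caption file
            with open(caption_path, encoding="utf-8") as f:
                caption = f.read().strip()
            yield os.path.join(folder, filename), caption
```

It yields the same `(image_path, caption)` tuples as the folder-based collector, so the rest of the conversion code would not need to change.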

In [None]:
import os
import json
import pathlib
from typing import Generator

def collect_captioned_images(root_folder: str) -> Generator[tuple[str, str], None, None]:
    """Yield (image_path, caption) pairs, using each sub-folder's name as the caption."""
    image_extensions = ('.png', '.jpg', '.jpeg', '.gif')
    # sorted() keeps the output order stable between runs
    for folder_name in sorted(os.listdir(root_folder)):
        folder_path = os.path.join(root_folder, folder_name)
        if not os.path.isdir(folder_path):
            continue
        for filename in sorted(os.listdir(folder_path)):
            if filename.lower().endswith(image_extensions):
                yield os.path.join(folder_path, filename), folder_name

def convert_text_image_pairs_to_huggingface_json(root_folder: str, out_json: str):
    out_folder = os.path.dirname(out_json)
    pathlib.Path(out_folder).mkdir(parents=True, exist_ok=True)
    with open(out_json, "w") as f:
        written_count = 0
        for image_path, caption in collect_captioned_images(root_folder):
            line_dict = {"image": image_path, "caption": caption}
            json_line = json.dumps(line_dict, indent=None, separators=(",", ":"))
            f.write(json_line + "\n")
            written_count += 1
        print(f"wrote {written_count} lines to {out_json}")

root_folder = "/s/bach/n/under/truongak/am/aesthetics_images_large"
out_json = "/s/bach/n/under/truongak/am/aesthetics_large.json"
convert_text_image_pairs_to_huggingface_json(root_folder, out_json)

In [None]:
!cat "/s/bach/n/under/truongak/am/aesthetics_large.json"
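Printing the whole file gets unwieldy for large datasets. As an alternative, a small check like the one below counts the records and lists any image paths that no longer exist on disk. It is a sketch that assumes the one-JSON-object-per-line layout written above, with an `image` key in each record; the helper name `check_caption_file` is made up for illustration.

```python
import json
import os

def check_caption_file(json_path):
    """Return (record_count, missing_image_paths) for a JSON-lines caption file."""
    total = 0
    missing = []
    with open(json_path) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            total += 1
            if not os.path.isfile(record["image"]):
                missing.append(record["image"])
    return total, missing
```

Calling `check_caption_file(out_json)` should return the same count that the conversion cell printed, with an empty list of missing paths.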

Test that it worked by running the following cell:

In [None]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('json', data_files={'data': out_json})

# Shuffle the dataset
shuffled_dataset = dataset['data'].shuffle()

# Split the shuffled dataset into training and validation sets
# Here we split 10% of the data for validation
split_dataset = shuffled_dataset.train_test_split(test_size=0.1)

# Get the training and validation datasets
train_dataset = split_dataset['train']
val_dataset = split_dataset['test']

# Check the sizes of the datasets
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

# Save the training and validation datasets to separate JSON files
train_json_path = 'train_aesthetics_large.json'
val_json_path = 'val_aesthetics_large.json'

train_dataset.to_json(train_json_path)
val_dataset.to_json(val_json_path)

print(f"Training dataset saved to: {train_json_path}")
print(f"Validation dataset saved to: {val_json_path}")

from IPython.display import Image
item = train_dataset[1]
print(item['caption'])
Image(filename=item['image'])

In [None]:
import random
item = random.choice(train_dataset)
print(item['caption'])
Image(filename=item['image'])

## Run the finetuning

### Configuration

`repo_id` - The starting point for finetuning. By default this uses the `openai/clip-vit-large-patch14-336` pre-trained CLIP weights. Stable Diffusion versions up to 1.5 used the closely related 224-pixel `openai/clip-vit-large-patch14`. Another option you might want to consider is `laion/CLIP-ViT-H-14-laion2B-s32B-b79K`, which was used for Stable Diffusion 2.0 onwards.

`output_folder` - Where to store the output. The saving process writes multiple files to this folder, so it should be empty.

`batch_size` - Training batch size. Don't go lower than 8 - try 32 or 64 (unless you only have a few images).

`num_train_epochs` - How many epochs to train. With fewer than 500 images, each epoch takes a few minutes on a 3090. Start with a small number, say `3`, and check the loss when training finishes before increasing the number of epochs. With 3 epochs my loss went down to around 2; after 10 epochs it was down to 0.63. Be careful not to overfit.
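To get a feel for how `batch_size` and `num_train_epochs` interact, the number of optimizer steps can be estimated up front. The sketch below assumes a single device and no gradient accumulation; `total_steps` is a hypothetical helper, not part of any library.

```python
import math

def total_steps(num_examples, batch_size, num_epochs):
    """Optimizer steps for a run: batches per epoch (rounded up) times epochs."""
    return math.ceil(num_examples / batch_size) * num_epochs

# Compare a few batch sizes for a 500-image dataset trained for 3 epochs
for bs in (8, 32, 64, 128):
    print(f"batch_size={bs}: {total_steps(500, bs, 3)} steps")
```

If the training command is run with `--auto_find_batch_size=True`, the batch size may be lowered automatically when memory runs out, in which case the progress bar will show a larger step total than this estimate.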

In [7]:
train_json_path = 'train_aesthetics_large.json'
val_json_path = 'val_aesthetics_large.json'

repo_id =  "openai/clip-vit-large-patch14-336"
output_folder = "./output/clip-finetuned-csu-p14-336-e4l57-l"
batch_size = 128
num_train_epochs = 3

In [None]:
print(f"Finetuning {repo_id} for {num_train_epochs} epochs with batch size {batch_size}, and then saving output to {output_folder}.")
print(f"train file {train_json_path}")
print(f"validation file {val_json_path}")
!python huggingface_finetune_clip.py \
    --output_dir {output_folder} \
    --model_name_or_path {repo_id} \
    --train_file {train_json_path} \
    --validation_file {val_json_path} \
    --save_total_limit=2 \
    --eval_strategy="steps" \
    --load_best_model_at_end=True \
    --image_column image \
    --caption_column caption \
    --max_seq_length=77 \
    --num_train_epochs={num_train_epochs} \
    --remove_unused_columns=False \
    --do_train \
    --do_eval \
    --per_device_train_batch_size={batch_size} \
    --learning_rate="5e-7" --warmup_steps="0" --weight_decay 0.5 \
    --auto_find_batch_size=True \
    --hub_token="$HF_TOKEN" \
    --push_to_hub \
    --hub_strategy="all_checkpoints" \


#     --test_file {test_json_path} \
#     --do_predict \
#     --resume_from_checkpoint="./output/clip-finetuned-csu-b32-b8e3l55/checkpoint-6500"
#     --overwrite_output_dir=True \


print("--\nDONE")
print(f"If it worked, the finetuned model should be in {output_folder}")

Finetuning openai/clip-vit-large-patch14-336 for 3 epochs with batch size 128, and then saving output to ./output/clip-finetuned-csu-p14-336-e4l57-l.
train file train_aesthetics_large.json
validation file val_aesthetics_large.json
------training phase
 does not have profile information (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484806139/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:104.)
  return forward_call(*input, **kwargs)

  0%|                                                   | 0/441 [00:10<?, ?it/s]
  0%|                                                   | 0/879 [00:03<?, ?it/s]
  0%|                                                  | 0/1758 [00:01<?, ?it/s]
  0%|                                                  | 0/3516 [00:01<?, ?it/s]
  0%|                                                  | 0/7032 [00:00<?, ?it/s]
  0%|                                       | 1/14064 [00:01<6:34:15,  1.68s/it]
{'loss': 0.3808, 'grad_norm': 401.6118774414062
...
  9%|███▏                                | 2500/28128 [28:22<3:19:00,  2.15it/s]
...
{'eval_loss': 0.7635349035263062, 'eval_runtime': 62.8257, 'eval_samples_per_second': 15.71, 'eval_steps_per_second': 1.974, 'epoch': 0.37}
 12%|████▍                               | 3500/28128 [41:49<2:57:36,  2.31it/s]
100%|█████████████████████████████████████████| 124/124 [01:02<00:00,  2.15it/s]
{'loss': 0.1149, 'grad_norm': 7.581352710723877, 'learning_rate': 4.2889647326507393e-07, 'epoch': 0.43}
...
{'eval_loss': 0.7118513584136963, 'eval_runtime': 63.9674, 'eval_samples_
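The metric lines in the output above are printed as Python dict literals, so they can be pulled out of a captured log and compared without scrolling through the progress bars. The following is a rough sketch: `extract_metrics` is a hypothetical helper, and it assumes the log is available as one string with each metric dict on a single line. Lines that were cut off mid-dict are skipped.

```python
import ast
import re

# Matches a complete {'loss': ...} or {'eval_loss': ...} dict on one line
METRIC_DICT = re.compile(r"\{'(?:loss|eval_loss)'.*?\}")

def extract_metrics(log_text):
    """Return the metric dicts found in the training output, in order of appearance."""
    records = []
    for match in METRIC_DICT.finditer(log_text):
        try:
            records.append(ast.literal_eval(match.group(0)))
        except (ValueError, SyntaxError):
            continue  # not a well-formed dict literal
    return records
```

Filtering the result for entries that contain `eval_loss` gives a quick view of whether the validation loss is still improving.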

If it all worked, your finetuned CLIP model is in the `output_folder` defined above.
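One way to try the result is to score a few candidate captions against an image. The sketch below assumes that the output folder holds weights loadable with `CLIPModel.from_pretrained`, and it loads the processor from the base `repo_id` in case the training script did not save one alongside the model. The function name `rank_captions` and its parameters are made up for illustration.

```python
def rank_captions(model_dir, image_path, captions,
                  processor_id="openai/clip-vit-large-patch14-336"):
    """Return (caption, probability) pairs for one image, best match first."""
    # Imported here so that defining the function does not load the heavy libraries
    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained(model_dir).eval()
    processor = CLIPProcessor.from_pretrained(processor_id)

    image = Image.open(image_path).convert("RGB")
    inputs = processor(text=captions, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        # logits_per_image has shape (1, len(captions))
        logits = model(**inputs).logits_per_image
    probs = logits.softmax(dim=-1)[0].tolist()
    return sorted(zip(captions, probs), key=lambda pair: pair[1], reverse=True)
```

For example, `rank_captions(output_folder, val_dataset[0]['image'], ["caption a", "caption b"])` would rank two placeholder captions for the first validation image.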

In [None]:
import os
last_checkpoint = "output/clip-finetuned-csu-p14-336-e3l37-l/checkpoint-26649/trainer_state.json"
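if os.path.exists(last_checkpoint):
    # The expected checkpoint of the previous run exists, so delete that run's local output folder
    print("REMOVED")
    !rm -rf output/clip-finetuned-csu-p14-336-e3l37-l
else:
    # Abort the cell here rather than starting the next run
    print("not REMOVED")
    raise FileNotFoundError(last_checkpoint)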

    
train_json_path = 'train_aesthetics_large.json'
val_json_path = 'val_aesthetics_large.json'

repo_id =  "openai/clip-vit-large-patch14-336"
output_folder = "./output/clip-finetuned-csu-p14-336-e3l87-l"
batch_size = 128
num_train_epochs = 3

print(f"Finetuning {repo_id} for {num_train_epochs} epochs with batch size {batch_size}, and then saving output to {output_folder}.")
print(f"train file {train_json_path}")
print(f"validation file {val_json_path}")
!python huggingface_finetune_clip.py \
    --output_dir {output_folder} \
    --model_name_or_path {repo_id} \
    --train_file {train_json_path} \
    --validation_file {val_json_path} \
    --save_total_limit=2 \
    --eval_strategy="steps" \
    --load_best_model_at_end=True \
    --image_column image \
    --caption_column caption \
    --max_seq_length=77 \
    --num_train_epochs={num_train_epochs} \
    --remove_unused_columns=False \
    --do_train \
    --do_eval \
    --per_device_train_batch_size={batch_size} \
    --learning_rate="8e-7" --warmup_steps="0" --weight_decay 0.1 \
    --auto_find_batch_size=True \
    --hub_token="$HF_TOKEN" \
    --push_to_hub \
    --hub_strategy="all_checkpoints" \

#     --test_file {test_json_path} \
#     --do_predict \
#     --resume_from_checkpoint="./output/clip-finetuned-csu-b32-b8e3l55/checkpoint-6500"
#     --overwrite_output_dir=True \

print("--\nDONE")
print(f"If it worked, the finetuned model should be in {output_folder}")

In [None]:
import os
last_checkpoint = "output/clip-finetuned-csu-p14-336-e3l87-l/checkpoint-26649/trainer_state.json"
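if os.path.exists(last_checkpoint):
    # The expected checkpoint of the previous run exists, so delete that run's local output folder
    print("REMOVED")
    !rm -rf output/clip-finetuned-csu-p14-336-e3l87-l
else:
    # Abort the cell here rather than starting the next run
    print("not REMOVED")
    raise FileNotFoundError(last_checkpoint)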

    
train_json_path = 'train_aesthetics_large.json'
val_json_path = 'val_aesthetics_large.json'

repo_id =  "openai/clip-vit-large-patch14-336"
output_folder = "./output/clip-finetuned-csu-p14-336-e4l27-l"
batch_size = 128
num_train_epochs = 4

print(f"Finetuning {repo_id} for {num_train_epochs} epochs with batch size {batch_size}, and then saving output to {output_folder}.")
print(f"train file {train_json_path}")
print(f"validation file {val_json_path}")
!python huggingface_finetune_clip.py \
    --output_dir {output_folder} \
    --model_name_or_path {repo_id} \
    --train_file {train_json_path} \
    --validation_file {val_json_path} \
    --save_total_limit=2 \
    --eval_strategy="steps" \
    --load_best_model_at_end=True \
    --image_column image \
    --caption_column caption \
    --max_seq_length=77 \
    --num_train_epochs={num_train_epochs} \
    --remove_unused_columns=False \
    --do_train \
    --do_eval \
    --per_device_train_batch_size={batch_size} \
    --learning_rate="2e-7" --warmup_steps="0" --weight_decay 0.1 \
    --auto_find_batch_size=True \
    --hub_token="$HF_TOKEN" \
    --push_to_hub \
    --hub_strategy="all_checkpoints" \

#     --test_file {test_json_path} \
#     --do_predict \
#     --resume_from_checkpoint="./output/clip-finetuned-csu-b32-b8e3l55/checkpoint-6500"
#     --overwrite_output_dir=True \

print("--\nDONE")
print(f"If it worked, the finetuned model should be in {output_folder}")
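The runs above differ only in learning rate, weight decay and epoch count, so it helps to compare their validation loss curves. The `trainer_state.json` file inside each checkpoint folder can be read for this. The sketch below assumes the usual HuggingFace `Trainer` layout, in which that file has a `log_history` list whose evaluation entries carry `step` and `eval_loss`; `eval_loss_history` is a hypothetical helper name.

```python
import json

def eval_loss_history(trainer_state_path):
    """Return (step, eval_loss) pairs recorded in a trainer_state.json file."""
    with open(trainer_state_path) as f:
        state = json.load(f)
    return [
        (entry["step"], entry["eval_loss"])
        for entry in state.get("log_history", [])
        if "eval_loss" in entry
    ]
```

An `eval_loss` that starts rising while the training loss keeps falling is the usual sign of overfitting, and a reason to stop earlier or lower the learning rate.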