In [1]:
import sys
import os
from pathlib import Path
# Directory the kernel was started in: os.path.abspath('') resolves the empty
# path against the current working directory.
sub_project_dir = Path(os.path.abspath(''))
project_dir = sub_project_dir.parent
# Put the directory that contains `project_dir` at the front of sys.path.
# Presumably this is the repository root, which makes the `llm_papers.*`
# imports further down in this cell resolvable — confirm against the repo layout.
sys.path.insert(0, project_dir.parent.as_posix())

import evaluate
import numpy as np
import torch
from matplotlib import pyplot as plt
from transformers import DataCollatorWithPadding, Trainer, TrainingArguments

%matplotlib inline
%load_ext autoreload
%autoreload 2
from llm_papers.clip.dataset import load_cifar100, load_coco2017
from llm_papers.clip.model import CLIP, load_pretrained_roberta, load_pretrained_vit
from llm_papers.utils import device

In [2]:
# Each loader returns a pair: the input preprocessor (text tokenizer / image
# transform) and the matching pretrained model. The preprocessors are passed
# to the dataset loaders later; the models are handed to `CLIP.load`.
tokenizer, pretrained_text_model = load_pretrained_roberta()
transform, pretrained_vision_model = load_pretrained_vit()

In [3]:
# Build the CLIP wrapper and pass it the two pretrained models.
model = CLIP()
model.load(pretrained_text_model, pretrained_vision_model)
# NOTE(review): freeze() presumably disables gradients on the pretrained
# backbones — the training log reports only 787,457 trainable parameters.
# Confirm in llm_papers.clip.model.
model.freeze()
# Last expression of the cell, so the notebook displays the module tree.
model.to(device)

CLIP(
  (text_model): Roberta(
    (embeddings): RobertaEmbeddings(
      (word_embed): Embedding(50265, 768, padding_idx=1)
      (pos_embed): Embedding(514, 768, padding_idx=1)
      (type_embed): Embedding(1, 768, padding_idx=0)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (blocks): Sequential(
      (0): Block(
        (norm_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
          (Wo): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm_mlp): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (1): Block(
        (norm_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        

In [4]:
# COCO 2017 provides a training set only; CIFAR-100 provides a train/test pair.
# The CIFAR-100 test split is the evaluation set for whichever training set is
# selected in the configuration cell.
train_dataset_coco = load_coco2017(transform, tokenizer)
train_dataset_cifar, test_dataset = load_cifar100(transform, tokenizer)
# Pads tokenized text to a common length, rounded up to a multiple of 8.
data_collator_pad = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)


def compute_metrics(eval_pred):
    """Compute top-1 accuracy of the image-side predictions.

    Args:
        eval_pred: An unpackable pair ``((text_logits, image_logits), labels)``
            as passed to ``compute_metrics`` by the ``Trainer``. Only
            ``image_logits`` and ``labels`` are used; the predicted class is
            the argmax of ``image_logits`` over its last axis.
            Assumes ``labels`` holds one integer class index per example
            — TODO confirm against the dataset.

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is a Python float, the
        fraction of examples whose predicted class equals the label.

    Raises:
        ValueError: If the number of predictions and labels differ.
    """
    (_, image_logits), labels = eval_pred
    image_predictions = np.argmax(image_logits, axis=-1).reshape(-1)
    references = np.asarray(labels).reshape(-1)

    # Fail loudly instead of letting NumPy broadcast mismatched shapes.
    if image_predictions.shape != references.shape:
        raise ValueError(
            f"Mismatch in the number of predictions ({image_predictions.size}) "
            f"and references ({references.size})"
        )

    # Computed with NumPy directly so the metric is not re-loaded through
    # `evaluate.load` on every evaluation pass (one pass every `eval_steps`).
    accuracy = float(np.mean(image_predictions == references))

    return {"accuracy": accuracy}


def data_collator(features):
    """Assemble a list of per-example feature dicts into a single batch dict.

    Every field other than ``"input_ids"`` / ``"attention_mask"`` is stacked
    across examples with ``torch.stack``. The text fields are then filled in
    one of two ways, decided by the first example:

    * it has ``"input_ids"``: each example's ids and mask are padded together
      by ``data_collator_pad`` and merged into the batch;
    * it has no ``"input_ids"``: the batch receives
      ``test_dataset.input_ids`` and ``test_dataset.attention_mask`` unchanged
      (presumably the tokenized class prompts shared by all images — TODO
      confirm in ``load_cifar100``).

    Args:
        features: Non-empty list of dicts that share the same keys.

    Returns:
        dict: The collated batch.
    """
    text_keys = ("input_ids", "attention_mask")
    first = features[0]

    # Stack all non-text fields; text is handled separately below.
    batch = {
        name: torch.stack([example[name] for example in features])
        for name in first
        if name not in text_keys
    }

    if "input_ids" in first:
        # Per-example captions: pad them to a common length for this batch.
        text_features = [
            {name: example[name] for name in text_keys} for example in features
        ]
        batch.update(data_collator_pad(text_features))
        return batch

    # No per-example text: reuse the text tensors stored on the test dataset.
    batch["input_ids"] = test_dataset.input_ids
    batch["attention_mask"] = test_dataset.attention_mask
    return batch

In [5]:
def model_init():
    """Model factory passed to ``Trainer(model_init=...)``.

    Calls ``init_weights()`` on the notebook-level ``model`` and returns that
    same object; no new model is constructed.
    """
    shared_model = model  # the CLIP instance built earlier; not a copy
    shared_model.init_weights()
    return shared_model


# --- Hyperparameters -------------------------------------------------------
lr = 5e-4
# Effective batch size per optimizer step; the per-device batch size passed
# to TrainingArguments is batch_size // gradient_accumulation_steps.
batch_size = 256
gradient_accumulation_steps = 8
epochs = 5
wd = 0.0
# Logging / evaluation / checkpoint intervals, in optimizer steps.
logging_steps = 5
eval_steps = 100
save_steps = 100
report_to = "wandb"
# Training-set switch: True -> COCO 2017, False -> CIFAR-100.
# Evaluation always uses the CIFAR-100 test split.
train_coco = False
if train_coco:
    train_ds_name = "coco2017"
    train_dataset = train_dataset_coco
else:
    train_ds_name = "cifar100"
    train_dataset = train_dataset_cifar
# Checkpoint directory; its name encodes the dataset and the main
# hyperparameters and is reused as the run name below.
output_dir = (
    sub_project_dir
    / "checkpoints"
    / f"distilroberta_vit_b_32_224_{train_ds_name}_lr{lr}_ep{epochs}_bs{batch_size}_wd{wd}"
)
# output_dir.mkdir(parents=True, exist_ok=True)
# Weights & Biases entity and project, set through environment variables.
os.environ["WANDB_ENTITY"] = "ztzhu11"
os.environ["WANDB_PROJECT"] = "CLIP"
args = TrainingArguments(
    output_dir=output_dir.as_posix(),
    per_device_train_batch_size=batch_size // gradient_accumulation_steps,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=epochs,
    learning_rate=lr,
    # Cosine decay that bottoms out at 0.1 * lr rather than at zero.
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    weight_decay=wd,
    max_grad_norm=5,
    seed=42,
    fp16=True,
    fp16_full_eval=True,
    eval_strategy="steps",
    # Evaluate once before any training step.
    eval_on_start=True,
    per_device_eval_batch_size=64,
    eval_steps=eval_steps,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=4,
    dataloader_persistent_workers=True,
    # Checkpoints are written on the same step schedule as evaluation, which
    # load_best_model_at_end relies on to pick the best checkpoint.
    save_strategy="steps",
    save_steps=save_steps,
    save_total_limit=1,
    load_best_model_at_end=True,
    # Key returned by compute_metrics.
    metric_for_best_model="accuracy",
    log_level="info",
    logging_first_step=True,
    logging_steps=logging_steps,
    report_to=report_to,
    run_name=output_dir.name,
)
trainer = Trainer(
    # No model instance is passed; the Trainer obtains it from model_init.
    None,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    model_init=model_init,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

Using auto half precision backend


In [6]:
import wandb

# `try`/`finally` is sufficient here: any exception raised by training
# propagates unchanged, so no `except ...: raise` clause is needed.
try:
    trainer.train()
finally:
    # Close the W&B run whether training finished, failed or was interrupted.
    if report_to == "wandb":
        wandb.finish()

***** Running training *****
  Num examples = 50,000
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 8
  Total optimization steps = 980
  Number of trainable parameters = 787,457
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mztzhu1[0m ([33mztzhu11[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not log the number of model parameters in Weights & Biases due to an AttributeError.

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64


Step,Training Loss,Validation Loss,Accuracy
0,No log,4.60499,0.0093
100,1.025100,1.561237,0.638
200,0.717400,0.961037,0.7815
300,0.561700,0.788277,0.8128
400,0.512900,0.686301,0.8271
500,0.484600,0.636011,0.836
600,0.445100,0.607316,0.8422
700,0.452800,0.597921,0.8442
800,0.422200,0.571849,0.8471
900,0.406700,0.562375,0.8493



***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_cifar100_lr0.0005_ep5_bs256_wd0.0/checkpoint-100
Trainer.model is not a `PreTrainedModel`, only saving its state dict.

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_cifar100_lr0.0005_ep5_bs256_wd0.0/checkpoint-200
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Deleting older checkpoint [/workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_cifar100_lr0.0005_ep5_bs256_wd0.0/checkpoint-100] due to args.save_total_limit

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_

0,1
eval/accuracy,▁▆▇███████
eval/loss,█▃▂▁▁▁▁▁▁▁
eval/runtime,█▁▁▂▃▄▄▄▄▄
eval/samples_per_second,▁█▇▇▆▅▅▅▅▄
eval/steps_per_second,▁█▇▇▆▅▅▅▅▄
train/epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▇▇▇▇▇▇▇▇▇█
train/global_step,▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█████
train/grad_norm,▁▁▂▅▇▇▆▅▇▅▆█▇▅▅▆▅▅▅▅▅▅▅▆▅▄▅▅▅▆▆▅▅▅▄▅▅▅▄▅
train/learning_rate,██████▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁
train/loss,██▇▄▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/accuracy,0.8493
eval/loss,0.56238
eval/runtime,11.7985
eval/samples_per_second,847.566
eval/steps_per_second,13.307
total_flos,0
train/epoch,5
train/global_step,980
train/grad_norm,5.63026
train/learning_rate,5e-05


In [6]:
# NOTE(review): this cell repeats the previous training cell. The output saved
# beneath it comes from a COCO 2017 run (118,287 examples, "coco2017" in the
# checkpoint path), which requires `train_coco = True` and a re-run of the
# configuration cell; with `train_coco = False` it trains on CIFAR-100 again.
import wandb

# `try`/`finally` is sufficient here: any exception raised by training
# propagates unchanged, so no `except ...: raise` clause is needed.
try:
    trainer.train()
finally:
    # Close the W&B run whether training finished, failed or was interrupted.
    if report_to == "wandb":
        wandb.finish()

***** Running training *****
  Num examples = 118,287
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 8
  Total optimization steps = 2,315
  Number of trainable parameters = 787,457
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mztzhu1[0m ([33mztzhu11[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not log the number of model parameters in Weights & Biases due to an AttributeError.

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64


Step,Training Loss,Validation Loss,Accuracy
0,No log,4.60499,0.0093
100,1.008400,3.5499,0.1687
200,0.722700,3.352227,0.2028
300,0.654000,3.233922,0.2226
400,0.556700,3.176095,0.2394
500,0.530800,3.156679,0.2424
600,0.540700,3.106933,0.2525
700,0.482600,3.10388,0.2532
800,0.452200,3.091398,0.2524
900,0.461800,3.045178,0.262



***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_coco2017_lr0.0005_ep5_bs256_wd0.0/checkpoint-100
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Deleting older checkpoint [/workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_coco2017_lr0.0005_ep5_bs256_wd0.0/checkpoint-900] due to args.save_total_limit

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 64
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_coco2017_lr0.0005_ep5_bs256_wd0.0/checkpoint-200
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Deleting older checkpoint [/workspace/LLM-papers-from-scratch/llm_papers/clip/checkpoints/distilroberta_vit_b_32_224_coco2017_lr0.0005_ep5_bs256_wd0.0/checkpoint-100] due to args.save_to

0,1
eval/accuracy,▁▅▆▇▇▇▇▇▇██▇████████████
eval/loss,█▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,▄▁▂▁▂▂▃▄▄▃▂▃█▃▄▃▄▄▅▃▃▅▅▄
eval/samples_per_second,▅█▇█▇▇▆▅▅▆▇▅▁▅▅▆▅▅▄▆▆▄▄▅
eval/steps_per_second,▅█▇█▇▇▆▅▅▆▇▅▁▅▅▆▅▅▄▆▆▄▄▅
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step,▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▇▇▇█████
train/grad_norm,▁▁▅▆▅▅█▇▇▅▆▅▆▆▅▅▆▆▅▇▆▆▅▅▇▅▆▆▄▆▅▆▅▅▅▅▆▆▄▆
train/learning_rate,██████████▇▇▇▇▇▆▆▆▆▆▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁
train/loss,█▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/accuracy,0.2773
eval/loss,3.00801
eval/runtime,12.2012
eval/samples_per_second,819.59
eval/steps_per_second,12.868
total_flos,0
train/epoch,5
train/global_step,2315
train/grad_norm,13.88826
train/learning_rate,5e-05
