In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
from transformers import pipeline, CLIPModel, CLIPProcessor

# Configuration
# NOTE(review): LOADER_PATCH_SIZE is not referenced by any other cell visible
# in this notebook — confirm whether it is still needed.
LOADER_PATCH_SIZE = 32


  from .autonotebook import tqdm as notebook_tqdm


# Notebook for interactive testing for CLIP

In [2]:

class Cfg:
    """Hyperparameters for the CLIP fine-tuning run, stored as class-level attributes."""
    model_id: str = "openai/clip-vit-base-patch32"
    batch_size: int = 32
    epochs: int = 20
    seed:   int = 42

    # -------- Optim & Loss ----------
    lr_head: float = 1e-3      # learning rate of the linear head
    wd_head: float = 1e-4      # weight decay of the linear head
    lr_lora: float = 1e-4      # learning rate of the injected LoRA layers
    wd_lora: float = 1e-2      # weight decay of the injected LoRA layers
    lambda_text: float = 0.3   # weight of the auxiliary text-alignment loss (path 3)

    # -------- LoRA ----------
    lora_rank: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    lora_target: tuple = ("q_proj","k_proj","v_proj","out_proj")  # apply LoRA to the attention projection layers only
    # Could also be extended to the projections inside the MLP, but watch training stability.

    # Mixed precision switch; it is only honoured when the device is CUDA.
    amp: bool = True

# Instantiate the config, pick the device and seed torch's RNG.
cfg = Cfg()
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(cfg.seed)

# Read the checkpoint name from the config instead of repeating the literal,
# so changing Cfg.model_id actually changes the model that gets loaded.
model_id = cfg.model_id
clip_model = CLIPModel.from_pretrained(model_id).to(device).eval()
processor  = CLIPProcessor.from_pretrained(model_id)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# ---------- 2. Load Dataset ----------
# No transform is attached here; image preprocessing happens inside the epoch loop.
train_set, val_set, test_set = (
    datasets.Flowers102(root="./data", split=split_name, download=True)
    for split_name in ("train", "val", "test")
)
# Class names are taken from the validation split and displayed as the cell output.
classname = val_set.classes
classname

['pink primrose',
 'hard-leaved pocket orchid',
 'canterbury bells',
 'sweet pea',
 'english marigold',
 'tiger lily',
 'moon orchid',
 'bird of paradise',
 'monkshood',
 'globe thistle',
 'snapdragon',
 "colt's foot",
 'king protea',
 'spear thistle',
 'yellow iris',
 'globe-flower',
 'purple coneflower',
 'peruvian lily',
 'balloon flower',
 'giant white arum lily',
 'fire lily',
 'pincushion flower',
 'fritillary',
 'red ginger',
 'grape hyacinth',
 'corn poppy',
 'prince of wales feathers',
 'stemless gentian',
 'artichoke',
 'sweet william',
 'carnation',
 'garden phlox',
 'love in the mist',
 'mexican aster',
 'alpine sea holly',
 'ruby-lipped cattleya',
 'cape flower',
 'great masterwort',
 'siam tulip',
 'lenten rose',
 'barbeton daisy',
 'daffodil',
 'sword lily',
 'poinsettia',
 'bolero deep blue',
 'wallflower',
 'marigold',
 'buttercup',
 'oxeye daisy',
 'common dandelion',
 'petunia',
 'wild pansy',
 'primula',
 'sunflower',
 'pelargonium',
 'bishop of llandaff',
 'gaura',

In [5]:
# ---------- 3. DataLoader ----------
# ---------- Image preprocessing is done inside the epoch loop
def collate_pil(batch):
    """Collate ``(image, label)`` samples without stacking the images.

    Args:
        batch: sequence of ``(image, int_label)`` pairs as yielded by the dataset.

    Returns:
        ``(images, labels)`` where ``images`` is a plain list (so the CLIP
        processor can consume it directly) and ``labels`` is a tensor built
        from the integer labels.
    """
    unzipped = zip(*batch)
    pil_images, targets = unzipped          # two tuples: images and labels
    return [*pil_images], torch.tensor(targets)

# Settings shared by the three loaders. num_workers stays at 0: 4 was the
# intended value, but it caused problems when run inside the notebook.
_loader_kwargs = dict(
    batch_size=cfg.batch_size,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_pil,
)

# Only the training split is shuffled.
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **_loader_kwargs)

---
## Model Setting and Training

In [6]:
# Word Prompt Embedding
import torch
from tqdm import tqdm

# Prompt templates; `{}` is replaced by the class name.
# Kept as a tuple (not a set) so the iteration order is fixed: a set of strings
# is iterated in hash order, which can change between kernel restarts and makes
# the averaged text embeddings non-reproducible at the floating-point level.
promptTemplate = (
    "A photo of {}.",
    "A photo of flower {}.",
    "Botanic picture of {}",
    "A example picture of type {}",
)
# Use more templates to reduce sensitivity to other contexts

@torch.no_grad()
def build_text_embeddings(names):
    """Build one unit-norm CLIP text embedding per class name.

    For every name, each template in ``promptTemplate`` is filled in
    (underscores in the name become spaces), encoded with the CLIP text
    tower, L2-normalised, and averaged over templates. The stacked result is
    L2-normalised again along the last dimension.

    Args:
        names: iterable of class-name strings.

    Returns:
        Tensor of shape [C, D] with one row per class name.
    """
    def _class_embedding(class_name):
        readable = class_name.replace("_"," ")
        prompts = [template.format(readable) for template in promptTemplate]
        tokens = processor(text=prompts, return_tensors="pt", padding=True).to(device)
        prompt_feats = clip_model.get_text_features(**tokens)        # [T, D]
        prompt_feats = prompt_feats / prompt_feats.norm(dim=-1, keepdim=True)
        return prompt_feats.mean(dim=0)                              # [D]

    per_class = [_class_embedding(n) for n in tqdm(names, desc="TextEmbed")]
    stacked = torch.stack(per_class, dim=0)                          # [C, D]
    return stacked / stacked.norm(dim=-1, keepdim=True)

text_embs = build_text_embeddings(classname)          # [102, D], fixed (not trained)



TextEmbed: 100%|██████████| 102/102 [00:01<00:00, 57.74it/s]


----
## Building CLIP model with LoRA and word embedding.

We have to implement the LoRA linear layer ourselves.

In [7]:
# LoRA injection (trains the LoRA matrices only)
import torch, torch.nn as nn
from transformers.models.clip.modeling_clip import CLIPVisionModel

# y = x*W0 + x*(A*B) * alpha/rank
# Shape of A: d_in by rank / Shape of B: rank by d_out
class LoRALinearLayer(nn.Module):
    """A frozen ``nn.Linear`` with an optional trainable low-rank (LoRA) branch.

    Forward pass for ``r > 0``::

        base(x) + dropout(lora_B(lora_A(x))) * (alpha / r)

    ``lora_B`` is zero-initialised, so right after construction the layer
    returns exactly what ``base`` returns. For ``r <= 0`` no LoRA branch is
    created and the layer is a plain pass-through to ``base``.

    Args:
        base: the linear layer to wrap; all of its parameters are frozen.
        r: rank of the low-rank update (``<= 0`` disables LoRA).
        alpha: scaling numerator; the LoRA branch is multiplied by ``alpha / r``.
        dropout: dropout probability applied to the output of the LoRA branch.
    """

    def __init__(self, base: nn.Linear, r=8, alpha=16, dropout=0.0):
        super().__init__()
        self.base = base  # frozen below; only the LoRA matrices are trained
        self.r = r
        # Guard the division: `alpha / r` raised ZeroDivisionError for r == 0,
        # which made the "LoRA disabled" branch below unreachable.
        self.scaling = alpha / r if r > 0 else 0.0

        if r > 0:
            # Create the LoRA matrices on the same device as the wrapped layer,
            # so injecting into a model that was already moved to the GPU does
            # not fail with a device mismatch in forward().
            target_device = base.weight.device
            self.lora_A = nn.Linear(base.in_features, r, bias=False, device=target_device)
            self.lora_B = nn.Linear(r, base.out_features, bias=False, device=target_device)
            self.dropout = nn.Dropout(dropout)
            nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5)
            # B starts at zero so the wrapped layer initially equals `base`.
            nn.init.zeros_(self.lora_B.weight)
        else:
            self.lora_A = None
            self.lora_B = None
            self.dropout = nn.Identity()

        # Freeze the wrapped layer.
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return ``base(x)`` plus the scaled LoRA update (if ``r > 0``)."""
        if self.r > 0:
            return self.base(x) + self.dropout(self.lora_B(self.lora_A(x))) * self.scaling
        return self.base(x)


# LoRA injection using the wrapped LoRA layer defined above.

def lora_injection(clip_model: nn.Module, target_names=("q_proj","k_proj","v_proj","out_proj")):
    """Wrap the targeted linear layers of ``clip_model.vision_model`` with LoRA.

    Every submodule attribute whose name is in ``target_names`` and whose value
    is an ``nn.Linear`` is replaced by a ``LoRALinearLayer`` built from
    ``cfg.lora_rank`` / ``cfg.lora_alpha`` / ``cfg.lora_dropout``. Afterwards all
    vision-model parameters are frozen except the LoRA matrices.

    The function is safe to call again (e.g. when a notebook cell is re-run):
    layers that are already wrapped are not wrapped twice, and their LoRA
    parameters are still returned and kept trainable.

    Args:
        clip_model: model exposing a ``vision_model`` submodule.
        target_names: attribute names of the linear layers to wrap.

    Returns:
        List of the trainable LoRA parameters (A and B matrices).

    Raises:
        AttributeError: if ``clip_model`` has no ``vision_model``.
    """
    # The previous assert could never fail: it dereferenced
    # `clip_model.vision_model` before reaching its `hasattr` fallback.
    if not hasattr(clip_model, "vision_model"):
        raise AttributeError("lora_injection expects a model with a `vision_model` submodule")

    lora_params = []
    # Snapshot the module list first so that replacing children below does not
    # modify the tree while it is being traversed.
    for _, module in list(clip_model.vision_model.named_modules()):
        # injection into the attention projections: q_proj/k_proj/v_proj/out_proj
        for t in target_names:
            child = getattr(module, t, None)
            if isinstance(child, LoRALinearLayer):
                lora_lin = child  # already injected by an earlier call
            elif isinstance(child, nn.Linear):
                lora_lin = LoRALinearLayer(child, r=cfg.lora_rank, alpha=cfg.lora_alpha, dropout=cfg.lora_dropout)
                setattr(module, t, lora_lin)
            else:
                continue
            # lora_A / lora_B are None when the layer was built with rank <= 0.
            if lora_lin.lora_A is not None and lora_lin.lora_B is not None:
                lora_params += list(lora_lin.lora_A.parameters()) + list(lora_lin.lora_B.parameters())

    # Freeze the whole vision tower, then re-enable only the LoRA matrices.
    for p in clip_model.vision_model.parameters():
        p.requires_grad = False
    for p in lora_params:
        p.requires_grad = True
    return lora_params

def build_head_and_optim(clip_model: CLIPModel, num_classes: int = 102):
    """Create the linear classification head, inject LoRA, and build the optimizer.

    Args:
        clip_model: CLIP model whose vision tower receives the LoRA layers
            (modified in place by ``lora_injection``).
        num_classes: output size of the linear head. Defaults to 102, the
            value that was previously hard-coded.

    Returns:
        ``(head, optim, scaler)``: the linear head on ``device``, an AdamW
        optimizer with separate parameter groups for the head and the LoRA
        matrices, and a ``GradScaler`` that is only enabled for CUDA with
        ``cfg.amp`` set.
    """
    feat_dim = clip_model.config.projection_dim  # ViT-B/32 = 512
    head = nn.Linear(feat_dim, num_classes).to(device)

    lora_params = lora_injection(clip_model, target_names=cfg.lora_target)

    # Two parameter groups, each with its own learning rate and weight decay.
    optim = torch.optim.AdamW(
        [
            {"params": head.parameters(),      "lr": cfg.lr_head, "weight_decay": cfg.wd_head},
            {"params": lora_params,            "lr": cfg.lr_lora, "weight_decay": cfg.wd_lora},
        ]
    )
    scaler = torch.amp.GradScaler(enabled=(device=="cuda" and cfg.amp))
    return head, optim, scaler


In [8]:
# Build the linear head, the optimizer and the AMP grad scaler; this call also
# injects the LoRA layers into clip_model's vision tower.
head, optimizer, scaler = build_head_and_optim(clip_model)
# Cross-entropy is used for both the linear-head loss and the text-alignment loss.
ce = torch.nn.CrossEntropyLoss()

def get_image_feats(images):
    """Encode a batch of images with CLIP and L2-normalise each embedding.

    Args:
        images: images in a form accepted by the CLIP processor
            (the loaders in this notebook pass a list of PIL images).

    Returns:
        Tensor of shape [B, D] whose rows have unit L2 norm.
    """
    pixel_inputs = processor(images=images, return_tensors="pt").to(device)
    embeddings = clip_model.get_image_features(**pixel_inputs)   # [B, D]
    row_norms = embeddings.norm(dim=-1, keepdim=True)
    return embeddings / row_norms

def supervised_logits(feats):
    """Return the linear head's class scores for image features ``feats``."""
    class_scores = head(feats)                                # [B, 102]
    return class_scores

def text_logits(feats):
    """Score image features against the fixed class text embeddings.

    Computes the dot product of ``feats`` with every row of ``text_embs``
    (a cosine similarity when both are unit-norm) and multiplies it by the
    exponential of CLIP's ``logit_scale``.
    """
    temperature = clip_model.logit_scale.exp()
    similarity = feats @ text_embs.T
    return similarity * temperature

----
## Main Training Epoch

In [12]:
def run_epoch(loader: DataLoader, train: bool=True):
    """Run one pass over ``loader`` and return averaged metrics.

    In training mode the head and the CLIP model are put in train mode and one
    optimizer step is taken per batch (through the AMP grad scaler). In
    evaluation mode both are put in eval mode and gradient tracking is turned
    off, so no autograd graph is built.

    The loss is ``loss_cls + cfg.lambda_text * loss_txt``: cross-entropy on the
    linear-head logits plus a weighted cross-entropy on the image/text
    similarity logits.

    Args:
        loader: yields ``(images, labels)`` batches.
        train: whether to update parameters.

    Returns:
        dict with sample-averaged ``"loss"``, linear-head accuracy ``"acc_cls"``
        and text-readout accuracy ``"acc_txt"``.
    """
    head.train(train)
    clip_model.train(train)
    use_amp = (device=="cuda" and cfg.amp)

    total, correct_cls, correct_txt = 0, 0, 0
    loss_sum = 0.0
    for images, labels in tqdm(loader, desc="Train" if train else "Eval"):
        labels = labels.to(device)
        # Gradients are only tracked when training; evaluation previously built
        # a full autograd graph for every batch, wasting memory and time.
        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=device, enabled=use_amp):
            feats = get_image_feats(images)                   # [B, D]

            # classification logits from the linear head
            logits_cls = supervised_logits(feats)
            loss_cls = ce(logits_cls, labels)

            # similarity logits against the fixed text embeddings
            # (auxiliary image/text alignment term)
            logits_txt = text_logits(feats)
            loss_txt = ce(logits_txt, labels)

            loss = loss_cls + cfg.lambda_text * loss_txt      # weighted sum

        if train:
            # backward pass and optimizer step through the grad scaler
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # running statistics, weighted by batch size
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct_cls += (logits_cls.argmax(dim=-1) == labels).sum().item()
        correct_txt += (logits_txt.argmax(dim=-1) == labels).sum().item()

    return {
        "loss": loss_sum/total,
        "acc_cls": correct_cls/total,   # accuracy of the linear head
        "acc_txt": correct_txt/total,   # accuracy of the text readout (zero-shot style)
    }



In [13]:
# Train for cfg.epochs epochs, keep the weights of the epoch with the best
# validation accuracy of the linear head, then evaluate them on the test split.
best_val = -1.0
best_head = None
best_lora = None

for ep in range(1,cfg.epochs+1):
    training = run_epoch(train_loader, train=True)
    val = run_epoch(val_loader, train=False)
    print(f"[{ep}/{cfg.epochs}] "
          f"Train: loss={training['loss']:.4f} acc_cls={training['acc_cls']:.4f} acc_txt={training['acc_txt']:.4f} | "
          f"Val:   loss={val['loss']:.4f} acc_cls={val['acc_cls']:.4f} acc_txt={val['acc_txt']:.4f}")

    if val["acc_cls"] > best_val:
        best_val = val["acc_cls"]
        # .clone() is required: on a CPU run `.cpu()` returns the very same
        # tensor, so without it the "snapshot" would alias the live weights and
        # silently follow every later optimizer step.
        best_head = { k: v.detach().cpu().clone() for k, v in head.state_dict().items() }
        # Snapshot the LoRA matrices too; otherwise the best-epoch head would
        # be evaluated together with the last-epoch LoRA weights.
        best_lora = {
            k: v.detach().cpu().clone()
            for k, v in clip_model.state_dict().items()
            if ".lora_A." in k or ".lora_B." in k
        }

if best_head is not None:
    head.load_state_dict({k: v.to(device) for k, v in best_head.items()})
if best_lora:
    # strict=False: only the LoRA entries are restored, everything else is untouched.
    _ = clip_model.load_state_dict(best_lora, strict=False)
te = run_epoch(test_loader, train=False)
print(f"Test: loss={te['loss']:.4f}  acc_cls={te['acc_cls']:.4f}  acc_txt={te['acc_txt']:.4f}")


Train: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.74it/s]


[1/20] Train: loss=5.0636 acc_cls=0.0529 acc_txt=0.6569 | Val:   loss=4.9526 acc_cls=0.2196 acc_txt=0.6990


Train: 100%|██████████| 32/32 [00:32<00:00,  1.00s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s]


[2/20] Train: loss=4.8573 acc_cls=0.4549 acc_txt=0.7255 | Val:   loss=4.7771 acc_cls=0.6343 acc_txt=0.7500


Train: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s]


[3/20] Train: loss=4.6089 acc_cls=0.8108 acc_txt=0.8059 | Val:   loss=4.5698 acc_cls=0.8186 acc_txt=0.7843


Train: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s]


[4/20] Train: loss=4.3251 acc_cls=0.9137 acc_txt=0.8833 | Val:   loss=4.3468 acc_cls=0.8696 acc_txt=0.8108


Train: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.72it/s]


[5/20] Train: loss=4.0390 acc_cls=0.9549 acc_txt=0.9480 | Val:   loss=4.1238 acc_cls=0.8873 acc_txt=0.8147


Train: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s]


[6/20] Train: loss=3.7671 acc_cls=0.9676 acc_txt=0.9784 | Val:   loss=3.9013 acc_cls=0.8961 acc_txt=0.8088


Train: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s]


[7/20] Train: loss=3.5170 acc_cls=0.9765 acc_txt=0.9873 | Val:   loss=3.6911 acc_cls=0.8971 acc_txt=0.8176


Train: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s]


[8/20] Train: loss=3.2853 acc_cls=0.9833 acc_txt=0.9922 | Val:   loss=3.5087 acc_cls=0.9010 acc_txt=0.8088


Train: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]


[9/20] Train: loss=3.0651 acc_cls=0.9892 acc_txt=0.9980 | Val:   loss=3.3400 acc_cls=0.9059 acc_txt=0.8039


Train: 100%|██████████| 32/32 [00:32<00:00,  1.00s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.78it/s]


[10/20] Train: loss=2.8536 acc_cls=0.9912 acc_txt=0.9961 | Val:   loss=3.1697 acc_cls=0.9127 acc_txt=0.8029


Train: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.73it/s]


[11/20] Train: loss=2.6537 acc_cls=0.9941 acc_txt=0.9971 | Val:   loss=3.0155 acc_cls=0.9059 acc_txt=0.8029


Train: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s]


[12/20] Train: loss=2.4603 acc_cls=0.9971 acc_txt=0.9980 | Val:   loss=2.8564 acc_cls=0.9059 acc_txt=0.8010


Train: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.78it/s]


[13/20] Train: loss=2.2702 acc_cls=0.9961 acc_txt=1.0000 | Val:   loss=2.7038 acc_cls=0.9078 acc_txt=0.8020


Train: 100%|██████████| 32/32 [00:31<00:00,  1.00it/s]
Eval: 100%|██████████| 32/32 [00:18<00:00,  1.77it/s]


[14/20] Train: loss=2.0881 acc_cls=0.9971 acc_txt=1.0000 | Val:   loss=2.5601 acc_cls=0.9049 acc_txt=0.8108


Train: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]
Eval: 100%|██████████| 32/32 [00:16<00:00,  1.89it/s]


[15/20] Train: loss=1.9143 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=2.4343 acc_cls=0.9206 acc_txt=0.7922


Train: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.85it/s]


[16/20] Train: loss=1.7480 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=2.3027 acc_cls=0.9127 acc_txt=0.8069


Train: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]


[17/20] Train: loss=1.5898 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=2.2090 acc_cls=0.9176 acc_txt=0.7873


Train: 100%|██████████| 32/32 [00:31<00:00,  1.00it/s]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.84it/s]


[18/20] Train: loss=1.4403 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=2.0901 acc_cls=0.9167 acc_txt=0.7971


Train: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.78it/s]


[19/20] Train: loss=1.3014 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=1.9913 acc_cls=0.9206 acc_txt=0.7951


Train: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
Eval: 100%|██████████| 32/32 [00:17<00:00,  1.88it/s]


[20/20] Train: loss=1.1747 acc_cls=1.0000 acc_txt=1.0000 | Val:   loss=1.9094 acc_cls=0.9225 acc_txt=0.7922


Eval: 100%|██████████| 193/193 [01:49<00:00,  1.77it/s]

Test: loss=1.9471  acc_cls=0.9167  acc_txt=0.7831



