In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from transformers import CLIPProcessor, CLIPModel

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 1. Load CLIP: the model is moved to `device`; the processor is used later
#    to turn PIL images into `pixel_values` tensors.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# 2. Dataset definition
class IconDataset(Dataset):
    """Image/label pairs described by a header-less, two-column CSV file.

    Each CSV row is read as ``fname,label``. Rows whose label is missing or
    blank are dropped when the dataset is constructed, and the remaining
    labels are cast to ``int``.

    Args:
        csv_file: Path to the CSV file. It is read without a header row.
            NOTE(review): if the file actually has a header line, that line
            is treated as data and the integer cast will fail - confirm.
        img_dir: Directory that the ``fname`` values are joined onto.
        processor: Callable invoked as
            ``processor(images=<PIL image>, return_tensors="pt")`` that returns
            a mapping containing ``"pixel_values"``.
    """

    def __init__(self, csv_file, img_dir, processor):
        raw = pd.read_csv(csv_file, names=["fname", "label"])

        # Drop rows without a usable label (NaN or whitespace-only string).
        labeled = raw.dropna(subset=["label"])
        labeled = labeled[labeled["label"].astype(str).str.strip() != ""]

        # Cast labels to int via `assign`, which builds a new frame instead of
        # writing into a filtered slice of the raw one.
        self.data = (
            labeled
            .assign(label=labeled["label"].astype(int))
            .reset_index(drop=True)
        )
        self.img_dir = img_dir
        self.processor = processor

    def __len__(self):
        """Return the number of labeled rows kept from the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for the row at position ``idx``.

        ``pixel_values`` is the processor output with its leading axis
        squeezed away; ``label`` is a scalar ``torch.long`` tensor.
        """
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row["fname"])

        # The context manager closes the underlying file deterministically;
        # `convert` returns a new, fully loaded image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        label = torch.tensor(int(row["label"]), dtype=torch.long)
        return pixel_values, label

# Locations of the icon images and of the label CSV (relative paths).
img_dir = "./drive/MyDrive/lol_profile_icons"
dataset = IconDataset(
    csv_file="./drive/MyDrive/icon_labels.csv",
    img_dir=img_dir,
    processor=clip_processor,
)


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

In [2]:
# Freeze every CLIP parameter so no gradients are tracked for the backbone.
clip_model.requires_grad_(False)

# Classification head
class Classifier(torch.nn.Module):
    """Two-layer MLP head mapping feature vectors to class logits.

    Args:
        feature_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits. Defaults to 2 (binary
            classification), which matches the original fixed output size.
    """

    def __init__(self, feature_dim, hidden_dim=256, num_classes=2):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return unnormalized logits; the last axis has ``num_classes`` entries."""
        return self.net(x)

# The head's input width must equal the size of CLIP's projected image
# embedding. Read it from the loaded model's config (512 for ViT-B/32)
# instead of hard-coding it, so swapping the checkpoint cannot desynchronize them.
clf = Classifier(feature_dim=clip_model.config.projection_dim).to(device)


In [8]:
from torch.utils.data import random_split

# Train/validation split (80% / 20%). A seeded generator keeps the split
# identical across kernel restarts, so validation accuracy is comparable
# between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): this value was annotated as 32 but is set to 1 - confirm
# which batch size is intended before relying on the training results.
batch_size = 1
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)


In [9]:

# Sanity check: print the shape of the CLIP image features for one batch
# (expected: (batch_size, 512)).
for X, y in train_loader:
    # The batch must be on the same device as the model, otherwise the forward
    # pass fails when CUDA is in use; no gradients are needed for a shape check.
    with torch.no_grad():
        print(clip_model.get_image_features(X.to(device)).shape)
    break


torch.Size([1, 512])


In [10]:


criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=1e-3)

# CLIP is used purely as a fixed feature extractor; make sure it is in
# inference mode so train-only layers (e.g. dropout) are disabled.
clip_model.eval()

EPOCHS = 5
for epoch in range(EPOCHS):
    # --- Training: only the classifier head receives gradient updates. ---
    clf.train()
    for pixel_values, labels in train_loader:
        pixel_values, labels = pixel_values.to(device), labels.to(device)

        # Feature extraction runs without autograd bookkeeping.
        with torch.no_grad():
            feats = clip_model.get_image_features(pixel_values)
        logits = clf(feats)

        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # --- Validation ---
    clf.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for pixel_values, labels in val_loader:
            pixel_values, labels = pixel_values.to(device), labels.to(device)
            feats = clip_model.get_image_features(pixel_values)
            logits = clf(feats)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    # An empty validation split (possible with very small datasets) would
    # otherwise raise ZeroDivisionError; report NaN instead.
    acc = correct / total if total else float("nan")
    print(f"Epoch {epoch+1}: val acc={acc:.4f}")

# --- Score every row of the full dataset with the trained head. ---
clf.eval()
results = []
with torch.no_grad():
    # No shuffling, so `results` stays aligned with the row order of `dataset.data`.
    for pixel_values, labels in DataLoader(dataset, batch_size=32):
        feats = clip_model.get_image_features(pixel_values.to(device))
        logits = clf(feats)
        probs = torch.softmax(logits, dim=1)[:,1]  # probability of class 1 ("like")
        results.extend(probs.cpu().numpy())

# Attach the scores to the label table and write it to disk.
dataset.data["prob_like"] = results
dataset.data.to_csv("captions.csv", index=False)


Epoch 1: val acc=1.0000
Epoch 2: val acc=1.0000
Epoch 3: val acc=1.0000
Epoch 4: val acc=0.0000
Epoch 5: val acc=0.0000
