In [1]:
import torch
import random
import math
import numpy
import os

from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np

In [2]:
#---------------------------------------------------#
#   设置种子
#---------------------------------------------------#
def seed_everything(seed=3407):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and torch, exports
    ``PYTHONHASHSEED`` and, when CUDA is available, also seeds all CUDA devices
    and forces deterministic cuDNN kernels with autotuning disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Only processes started from here on (e.g. subprocesses) see this value;
    # the running interpreter's hash seed was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
        
def worker_init_fn(worker_id, seed=3407):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker process.

    PyTorch already seeds each worker's torch RNG with ``base_seed + worker_id``,
    where ``base_seed`` is drawn from the main process's (seeded) generator each
    time a DataLoader iterator is created, i.e. once per epoch. Deriving the
    ``random``/``numpy`` seeds from ``torch.initial_seed()`` keeps runs
    reproducible under ``seed_everything`` while giving every worker *and every
    epoch* its own stream. Seeding with a constant such as ``worker_id + seed``
    (and re-seeding torch with it) would make each epoch replay the same random
    sequence in each worker.

    Args:
        worker_id: worker index passed by the DataLoader; inside a worker it is
            already folded into ``torch.initial_seed()``.
        seed: constant offset mixed into the derived seed (kept so existing
            callers and the signature stay unchanged).
    """
    # NumPy only accepts seeds in [0, 2**32).
    worker_seed = (torch.initial_seed() + seed) % 2 ** 32
    random.seed(worker_seed)
    np.random.seed(worker_seed)

seed = 3407  # global seed shared by random / numpy / torch
seed_everything(seed=seed)

## 划分数据

## 加载模型

- LoRA 微调 CLIP: https://github.com/kesimeg/LORA-turkish-clip
- PEFT: https://github.com/datawhalechina/self-llm/blob/master/models/Gemma2/04-Gemma-2-9b-it%20peft%20lora%E5%BE%AE%E8%B0%83.md

In [3]:
# import clip

# device = 'cuda' if torch.cuda.is_available() else "cpu"
# model, preprocess = clip.load("/root/.cache/torch/hub/checkpoints/ViT-B-16.pt", device=device, jit=False)

In [4]:
import os

# huggingface_hub copies HF_ENDPOINT into its module constants when it is first
# imported, so the mirror must be configured *before* the import below; setting
# it afterwards does not change snapshot_download's default endpoint.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download

# Local directory holding the openai/clip-vit-base-patch16 checkpoint
# (loaded below with from_pretrained).
save_dir = "/root/.cache/torch/hub/checkpoints/clip-vit-base-patch16"

# One-time download; uncomment on a machine without the checkpoint.
# os.makedirs(save_dir, exist_ok=True)
# snapshot_download(repo_id="openai/clip-vit-base-patch16", local_dir=save_dir)

In [56]:
from transformers import CLIPModel, CLIPProcessor

# Load the CLIP processor (text tokenizer + image preprocessing) and the model
# weights from the local checkpoint directory, then move the model to the GPU
# when one is available.
device = 'cuda' if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(save_dir)
pretrained_model = CLIPModel.from_pretrained(save_dir)
pretrained_model = pretrained_model.to(device)



In [57]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over every parameter yielded by ``model.parameters()`` and
    over the subset with ``requires_grad=True``, then prints one summary line.
    A model without parameters reports ``0.00`` percent instead of raising
    ``ZeroDivisionError``.

    Args:
        model: a ``torch.nn.Module`` (here, the PEFT-wrapped CLIP model).
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

In [58]:
from peft import LoraConfig, get_peft_model

# LoRA adapters (rank 16, alpha 16, dropout 0.1) on every module whose name
# matches "q_proj" or "v_proj"; bias terms stay frozen.
lora_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)
# Apply the PEFT (LoRA) strategy to the pretrained CLIP model.
model = get_peft_model(pretrained_model, lora_config)

In [59]:
# Report how many parameters the LoRA wrapper leaves trainable.
print_trainable_parameters(model=model)

trainable params: 983040 || all params: 150603777 || trainable%: 0.65


## 加载数据

In [9]:
def print_dataset_folder_info(*dataset_folders):
    """Print size, class names and class->index mapping of each dataset.

    Each argument only needs ``__len__``, ``.classes`` and ``.class_to_idx``
    (as ``torchvision.datasets.ImageFolder`` provides). One line is printed per
    argument, with the three fields separated by single spaces.
    """
    for folder in dataset_folders:
        fields = (len(folder), folder.classes, folder.class_to_idx)
        print(*fields)

In [37]:
dataset_root = "./dataset"
batch_size = 32
num_workers = 4

# NOTE(review): MyDataset, transform, texts and customBatchBuilder are not defined
# in any live cell above (MyDataset only appears commented out near the end, and
# `texts` may be meant to be `prompts`) - this cell depends on kernel state from
# elsewhere; confirm before a Restart & Run All.

# Image-text matching datasets, one per split directory ("test" is reported as test-dev).
train_folder, val_folder, test_folder = (
    MyDataset(datasets.ImageFolder(root=f"{dataset_root}/{split}", transform=transform), texts)
    for split in ("train", "validation", "test")
)

# Data loaders; validation is shuffled too (presumably so each batch mixes
# classes for the in-batch contrastive targets - confirm), test is not.
train_loader, val_loader, test_loader = (
    DataLoader(
        folder,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        collate_fn=customBatchBuilder,
    )
    for folder, shuffle in ((train_folder, True), (val_folder, True), (test_folder, False))
)

['cats', 'dogs'] {'cats': 0, 'dogs': 1}
['cats', 'dogs'] {'cats': 0, 'dogs': 1}
['cats', 'dogs'] {'cats': 0, 'dogs': 1}


In [38]:
# Number of classes and per-sample targets of the training split
# (presumably mirroring ImageFolder.targets / .classes - confirm in MyDataset).
class_nums = len(train_folder.classes)
train_labels = train_folder.targets

In [39]:
# Pull a single training batch to sanity-check the collated tensor shapes.
inputs = next(iter(train_loader))
for name in inputs.keys():
    tensor = inputs[name]
    print("Sample {} shape ".format(name), tensor.shape)

Sample input_ids shape  torch.Size([32, 6])
Sample attention_mask shape  torch.Size([32, 6])
Sample pixel_values shape  torch.Size([32, 3, 224, 224])


## 模型

In [40]:
import torch
from torch import nn

In [41]:
def calc_classes_weights(labels, method="balanced"):
    """Compute per-class loss weights from a sequence of labels.

    The per-class sample counts are printed first. Weights are ordered like
    ``np.unique(labels)``.

    Args:
        labels: 1-D sequence of class labels (list, tuple or array). It is
            converted with ``np.asarray`` so plain Python lists such as
            ``ImageFolder.targets`` are counted correctly.
        method: one of
            - ``"balanced"``: ``n_samples / (n_classes * count_c)``, the same
              formula as scikit-learn's ``compute_class_weight("balanced")``;
            - ``"max"``: largest class count divided by ``count_c``;
            - ``"reciprocal"``: ``1 / count_c``.

    Returns:
        A float ndarray for ``"balanced"``, a list of floats otherwise.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    nums_list = counts.tolist()
    print(nums_list)
    if method == "balanced":
        return len(labels) / (len(classes) * counts)
    if method == "max":
        # Largest class count divided by each class's count.
        max_nums = max(nums_list, default=0)
        return [max_nums / nums for nums in nums_list]
    if method == "reciprocal":
        return [1 / nums for nums in nums_list]
    raise ValueError(f"unknown method {method!r}; expected 'balanced', 'max' or 'reciprocal'")

### 02-CLIP-预训练模型

In [42]:
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
torch.set_default_dtype(torch.float32)

# Run name (used in checkpoint file names) and the Adam learning rate.
model_name = "xxx"
start_lr = 1e-4

# One cross-entropy per direction of the CLIP similarity logits
# (image->text and text->image); `train` averages the two.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=start_lr,
    betas=(0.9, 0.999),
    eps=1e-8,
)

In [43]:
# # HWC_in 卷积 KKC_inC_out  --> H'W' KKC_in * KKC_in C_out  --> H'W'C_out
# # 其中 H'W' 就是 patch 个数，可以看到第二步其实就是在将原始 patch 的得到的特征维度 [KKC_in] --> projection --> C_out
# # stride = kernel_size 无重叠 patch
# model.visual.conv1

In [44]:
# print(model.visual.input_resolution)
# # 取 eot_token 作为特征表述
# print("image feature dim: ", model.visual.proj.shape)
# print("text feature dim", model.text_projection.shape)

## 训练

In [45]:
# def train(device, model, dataloader, loss_img, loss_txt, train=True, optimizer=None, useproba=True, weights=None, verbose=False):
#     correct = 0
#     error = 0
#     total = 0
    
#     fin_probas = None
#     fin_ls = None
    
#     for batch, (images, texts) in enumerate(dataloader):
#         images, texts = images.to(device), texts.to(device)
#         n = images.shape[0]
        
#         logits_per_image, logits_per_text = model(images, texts)
        
#         ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
#         loss = 0.5 * ( loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth) )

#         if train:
#             # 开始优化网络权重
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#         error += loss.item()
        
#         # 计算准确率
#         p = torch.max(logits_per_image,1)[1].to(device)
#         correct += (p == ground_truth).sum()
#         total += n
        
#         if verbose:
#             pass
            
#         if useproba:
#             probas = logits_per_image.detach().cpu()
#             l = ground_truth.detach().cpu()
#             fin_probas = probas if fin_probas is None else np.concatenate([fin_probas, probas], axis=0)
#             fin_ls = l if fin_ls is None else np.concatenate([fin_ls, l], axis=0)
    
#     return error / (batch+1), correct / total, fin_probas, fin_ls

In [63]:
def train(device, model, dataloader, loss_img, loss_txt, train=True, optimizer=None, useproba=True, weights=None, verbose=False):
    """Run one pass of the symmetric CLIP contrastive loss over ``dataloader``.

    Each batch must be a mapping with ``"input_ids"``, ``"pixel_values"`` and
    ``"attention_mask"`` tensors. The i-th image of a batch is treated as the
    match of the i-th text, so the target in both directions is
    ``arange(batch_size)``. The loss is the mean of the image->text and
    text->image cross-entropies.

    Args:
        device: device the batch tensors are moved to.
        model: callable returning an object with ``logits_per_image`` and
            ``logits_per_text`` (e.g. ``CLIPModel`` or its PEFT wrapper).
        dataloader: iterable of batches as described above.
        loss_img: loss applied to ``logits_per_image``.
        loss_txt: loss applied to ``logits_per_text``.
        train: if True, back-propagate and step ``optimizer`` on every batch.
        optimizer: required when ``train`` is True.
        useproba: if True, also collect the per-batch image logits and targets.
        weights, verbose: accepted for interface compatibility; unused.

    Returns:
        ``(mean_batch_loss, accuracy, probas, labels)``:
        - ``accuracy`` is a 0-d tensor on ``device``: the fraction of images whose
          highest-scoring text is their own.
        - ``probas`` / ``labels`` are numpy arrays stacked over batches, or
          ``None`` when ``useproba`` is False or no batch was seen.
        - An empty ``dataloader`` yields ``(0.0, tensor(0.), None, None)``.

    Raises:
        ValueError: if ``train`` is True and ``optimizer`` is None.
    """
    if train and optimizer is None:
        raise ValueError("train=True requires an optimizer")

    correct = torch.tensor(0, device=device)
    error = 0.0
    total = 0
    num_batches = 0

    # Collected per batch and concatenated once at the end (avoids O(n^2) copies).
    batch_probas = []
    batch_labels = []

    for ipts in dataloader:
        input_ids = ipts["input_ids"].to(device)
        pixel_values = ipts["pixel_values"].to(device)
        attention_mask = ipts["attention_mask"].to(device)
        n = input_ids.shape[0]

        # Keyword arguments so the call does not depend on forward()'s positional order.
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        logits_per_image, logits_per_text = outputs.logits_per_image, outputs.logits_per_text

        # NOTE(review): when several samples in a batch share the same caption
        # (e.g. one prompt per class), this diagonal target treats identical texts
        # as negatives and argmax ties resolve to the first occurrence, so the
        # accuracy below is not a per-class accuracy - confirm this is intended.
        ground_truth = torch.arange(n, dtype=torch.long, device=device)
        loss = 0.5 * (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth))

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        error += loss.item()
        num_batches += 1

        # Count images whose best-scoring text is the one at the same index.
        pred = torch.max(logits_per_image, 1)[1]
        correct += (pred == ground_truth).sum()
        total += n

        if useproba:
            # NOTE(review): for CLIP, logits_per_image is (n_images, n_texts), so a
            # smaller final batch gives a narrower array and the concatenation
            # below raises; consider collecting predictions instead.
            batch_probas.append(logits_per_image.detach().cpu().numpy())
            batch_labels.append(ground_truth.cpu().numpy())

    fin_probas = np.concatenate(batch_probas, axis=0) if batch_probas else None
    fin_ls = np.concatenate(batch_labels, axis=0) if batch_labels else None
    return error / max(num_batches, 1), correct / max(total, 1), fin_probas, fin_ls

@torch.no_grad()
def test(device, model, dataloader, loss_img, loss_txt, useproba=True, weights=None, verbose=False):
    """Evaluate ``model`` with autograd disabled.

    Delegates to ``train`` with ``train=False`` and no optimizer, so no
    parameters are updated; every other argument is forwarded unchanged and
    ``train``'s return value is returned as-is.
    """
    return train(
        device, model, dataloader, loss_img, loss_txt,
        train=False,
        optimizer=None,
        useproba=useproba,
        weights=weights,
        verbose=verbose,
    )

In [52]:
# Training schedule and checkpoint settings.
total_epochs = 300        # training-loop length
save_epoch_fre = 50       # checkpoint every this many epochs
save_root = "./results"   # directory for checkpoint files

In [53]:
def get_epoch_lr(cur_epoch):
    """Placeholder learning-rate schedule.

    NOTE(review): currently returns ``cur_epoch`` itself (not a learning rate)
    and is not called by any visible cell.
    """
    epoch_lr = cur_epoch
    return epoch_lr

In [61]:
# Per-epoch history (epoch index, losses, accuracies) plotted by the training loop.
idxs, train_errors, train_accs = [], [], []
val_errors, val_accs = [], []

In [None]:
# NOTE(review): not used by the loop below; checkpoints are driven by save_epoch_fre.
saved_epoches = range(50, total_epochs, 50)

epoch_s = 0
epoch_e = total_epochs

rows = 1
cols = 2

# torch.save does not create missing directories.
os.makedirs(save_root, exist_ok=True)

# Runs epochs epoch_s .. epoch_e - 1, i.e. exactly `total_epochs` epochs from 0.
# An inclusive upper bound would train one extra epoch that is never checkpointed.
for i in range(epoch_s, epoch_e):
    model.train()
    train_error, train_acc, train_probas, train_ls = train(
        device, model, train_loader, loss_img, loss_txt,
        train=True, optimizer=optimizer, useproba=False,
    )
    model.eval()
    val_error, val_acc, val_probas, val_ls = test(device, model, val_loader, loss_img, loss_txt, useproba=False)

    # Checkpoint after every `save_epoch_fre` completed epochs. Both the file
    # suffix and the stored 'epoch' are the number of epochs trained so far.
    epochs_done = i + 1
    if epochs_done % save_epoch_fre == 0:
        state = {'model': model.state_dict(), 'epoch': epochs_done, "lr": start_lr}
        path = f"{save_root}/{model_name}_{epochs_done}.pth"
        torch.save(state, path)

    idxs.append(i)

    train_errors.append(train_error)
    val_errors.append(val_error)

    train_accs.append(train_acc.cpu().item())
    val_accs.append(val_acc.cpu().item())

    # Replace the previous epoch's figure with updated learning curves.
    display.clear_output(wait=True)

    fig = plt.figure(figsize=(cols * 5, rows * 5))
    ax_loss = fig.add_subplot(rows, cols, 1)
    ax_loss.plot(idxs, train_errors, c='red', label="train_loss")
    ax_loss.plot(idxs, val_errors, c='blue', label="val_loss")
    ax_loss.set(title="loss", xlabel="epoch")
    ax_loss.legend(bbox_to_anchor=(1.5, 1), loc=1)

    ax_acc = fig.add_subplot(rows, cols, 2)
    ax_acc.plot(idxs, train_accs, c='red', label="train_acc")
    ax_acc.plot(idxs, val_accs, c='blue', label="val_acc")
    ax_acc.set(title="accuracy", xlabel="epoch")
    ax_acc.legend(bbox_to_anchor=(1.5, 1), loc=1)

    fig.tight_layout()
    plt.show()
    plt.pause(0.05)
    # Release the figure so a long run does not accumulate open figures.
    plt.close(fig)

## 保存

## Zero-Shot 效果

In [None]:
# Put the model in evaluation mode (disables dropout layers, such as the LoRA
# dropout configured earlier) before the evaluation cells below.
model.eval()

In [None]:
def calc_top_k_acc(probas, lables, k=1):
    """Top-k accuracy: fraction of rows whose true label is among the k
    highest-scoring columns.

    Args:
        probas: 2-D array-like of scores, one row per sample.
        lables: 1-D array-like of true column indices, same length as
            ``probas`` (parameter name kept unchanged for existing callers).
        k: number of top-scoring columns that count as a hit. ``k=1`` is plain
            accuracy via ``argmax``; ties resolve to the first maximal column.

    Returns:
        The accuracy as a Python float.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    probas = np.asarray(probas)
    lables = np.asarray(lables)
    if k == 1:
        hits = np.argmax(probas, axis=1) == lables
    else:
        # Stable sort of the negated scores: descending order, with ties broken
        # by the lower column index (consistent with argmax for k == 1).
        top_k = np.argsort(-probas, axis=1, kind="stable")[:, :k]
        hits = (top_k == lables[:, None]).any(axis=1)
    return np.count_nonzero(hits) / len(lables)

In [None]:
# Evaluate `model` on the validation split without collecting logits.
val_error, val_acc, val_probas, val_ls = test(
    device, model, val_loader, loss_img, loss_txt, useproba=False
)

In [None]:
# class MyDataset(Dataset):
#     def __init__(self, dataset_folder, prompts: list[str], context_length: Optional[int]=77):
#         print_dataset_folder_info(dataset_folder)
        
#         self.dataset_folder = dataset_folder
#         self.prompts = prompts
#         self.context_length = 77 if context_length is None else context_length

#         self.targets = dataset_folder.targets
#         self.classes = dataset_folder.classes
#         self.class_to_idx = dataset_folder.class_to_idx
        
#     def __len__(self):
#         return len(self.dataset_folder)
        
#     def __getitem__(self, idx):
#         X, y = self.dataset_folder[idx]
#         text = clip.tokenize(self.prompts[y], self.context_length).reshape((-1,))
#         return X, text, y

In [None]:
# prompts = [
#     "This image suggests a severe anxiety psychological tendency",  # 重度
#     "This image suggests a mild anxiety psychological tendency",  # 中度
#     "This image suggests no anxiety psychological tendency",  # 无
# ]

# One text prompt per class, presumably index-aligned with
# class_to_idx {'cats': 0, 'dogs': 1} - confirm.
prompts = ["This is a cat", "This is a dog"]