In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from peft import LoraConfig, get_peft_model
from torch.optim import AdamW


In [3]:
# Load the pretrained CLIP checkpoint and the processor published under the
# same name, so text and images are encoded the way the checkpoint expects.
model_name_ = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name_)
processor = CLIPProcessor.from_pretrained(model_name_)


# LoRA adapter configuration.
# NOTE: task_type is intentionally not passed. The previous value
# "vision_language_modeling" is not a member of peft.TaskType; recent PEFT
# releases reject unknown task types when the config is constructed, while
# older releases silently fall back to the generic PeftModel wrapper. Leaving
# it unset selects that generic wrapper explicitly on every version.
lora_config = LoraConfig(
    r=8,  # Rank of the adaptation matrix
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout for LoRA layers
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],  # Targeting Linear layers within CLIPAttention
)

# Wrap the base model with the LoRA adapters; `model` now refers to the wrapper.
model = get_peft_model(model, lora_config)

dataset = load_dataset("cifar10")

# Show one raw training example to confirm the column names used below
# ("img" and "label").
print(dataset['train'][0])

# Batched preprocessing: turns each label into a text prompt and encodes the
# prompts together with the images through the module-level `processor`.
def preprocess_function(examples):
    """Encode a batch of examples into model inputs.

    Args:
        examples: Batched mapping providing a "label" sequence and an "img"
            sequence.

    Returns:
        The object returned by `processor` for the prompts and images,
        requested as PyTorch tensors with padding enabled.
    """
    prompts = []
    for label in examples["label"]:
        # The label value is stringified as-is and appended to the fixed prefix.
        prompts.append(" photo of a " + str(label))

    images = list(examples["img"])

    return processor(
        text=prompts,
        images=images,
        return_tensors="pt",
        padding=True,
    )

def _encode_batch(examples):
    """On-the-fly formatting transform: encode a raw batch and keep its labels.

    Args:
        examples: Batched mapping with the raw "img" and "label" columns.

    Returns:
        A plain dict holding everything `preprocess_function` returns plus a
        "label" tensor, so downstream loops can read the same keys as before
        ("input_ids", "attention_mask", "pixel_values", "label").
    """
    encoded = dict(preprocess_function(examples))
    # A transform's output replaces the raw columns, so the label has to be
    # carried over explicitly.
    encoded["label"] = torch.tensor(examples["label"])
    return encoded


# Encode lazily at access time instead of materialising with `dataset.map`:
# mapping eagerly would write the processor's pixel tensors for every image
# into the dataset cache before the first training step can run.
dataset = dataset.with_transform(_encode_batch)
train_dataloader = DataLoader(dataset["train"], batch_size=32, shuffle=True)
eval_dataloader = DataLoader(dataset["test"], batch_size=64)


# Frozen parameters never receive gradients, so the optimizer only ever
# updates the parameters left trainable on `model`.
optimizer = AdamW(model.parameters(), lr=5e-5)


def _label_aware_contrastive_loss(logits_per_image, logits_per_text, labels):
    """Symmetric contrastive loss that treats same-label pairs as positives.

    The text prompts are derived from the class label, so a batch usually
    contains several identical prompts. Plain `arange` targets would mark
    those identical prompts as negatives of each other. Here the target
    distribution of row i is uniform over every column whose label equals
    label i; when all labels in the batch are distinct this reduces to the
    usual one-hot diagonal targets.

    Args:
        logits_per_image: (batch, batch) image-to-text similarity logits.
        logits_per_text: (batch, batch) text-to-image similarity logits.
        labels: (batch,) class ids of the examples in the batch.

    Returns:
        Scalar tensor: mean of the image-to-text and text-to-image
        cross-entropy terms.
    """
    labels = labels.to(logits_per_image.device)
    same_class = (labels.unsqueeze(0) == labels.unsqueeze(1)).to(logits_per_image.dtype)
    # Every row contains at least its own diagonal entry, so the sum is >= 1.
    targets = same_class / same_class.sum(dim=1, keepdim=True)
    image_loss = torch.nn.functional.cross_entropy(logits_per_image, targets)
    text_loss = torch.nn.functional.cross_entropy(logits_per_text, targets)
    return (image_loss + text_loss) / 2


model.train()
for epoch in range(3):
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch in train_dataloader:
        optimizer.zero_grad()

        # Forward pass
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"]
        )

        # Contrastive loss with label-aware targets (see helper docstring).
        loss = _label_aware_contrastive_loss(
            outputs.logits_per_image,
            outputs.logits_per_text,
            batch["label"],
        )

        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than the last batch's loss, which
    # is noisy and would be undefined for an empty loader.
    mean_epoch_loss = epoch_loss_sum / max(num_batches, 1)
    print(f"Epoch {epoch + 1} completed. Loss: {mean_epoch_loss}")


model.eval()

# Classification with CLIP means scoring each image against ONE prompt per
# class, so that argmax over the text axis is a class id. Scoring against the
# batch's own per-example prompts (as before) yields an in-batch position
# (0..batch_size-1), which is not comparable to the class labels.
# The prompt template must mirror the one used in `preprocess_function`.
num_classes = dataset["test"].features["label"].num_classes
class_prompts = processor(
    text=[" photo of a " + str(class_id) for class_id in range(num_classes)],
    return_tensors="pt",
    padding=True,
)

all_preds = []
all_labels = []
with torch.no_grad():
    for batch in eval_dataloader:
        outputs = model(
            input_ids=class_prompts["input_ids"],
            attention_mask=class_prompts["attention_mask"],
            pixel_values=batch["pixel_values"]
        )

        # logits_per_image is (images_in_batch, num_classes); column index
        # equals class id because prompts were built in class-id order.
        preds = torch.argmax(outputs.logits_per_image, dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(batch["label"].cpu().tolist())

# ------------------accuracy-------------------------------------
accuracy = sum(pred == label for pred, label in zip(all_preds, all_labels)) / len(all_preds)
print(f"Accuracy: {accuracy}")


{'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x1045C8790>, 'label': 0}


KeyboardInterrupt: 

In [None]:
#ReFT
import torch
from transformers import CLIPProcessor, CLIPModel, Trainer, TrainingArguments
from datasets import load_dataset
from pyreft import ReFTTrainer, ReFTConfig
