In [None]:
from datasets import load_dataset
from transformers import CLIPTextModel, CLIPProcessor, AdamW, DataCollatorWithPadding
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

In [None]:
# Fetch the MRPC configuration of the GLUE benchmark; the returned object is
# indexed by split name further down ("train", "validation", "test").
dataset = load_dataset(path="glue", name="mrpc")

In [None]:
# A single checkpoint id feeds both loaders so the processor's tokenizer and
# the text-encoder weights always come from the same pretrained model.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
clip_model = CLIPTextModel.from_pretrained(model_name)

In [None]:
class CLIPForSequenceClassification(nn.Module):
    """CLIP text encoder topped with a linear classification head.

    Args:
        clip_model: Text encoder. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` / ``attention_mask``, return an
            object carrying ``last_hidden_state`` and ``pooler_output``.
        num_labels: Number of target classes (width of the returned logits).
    """

    def __init__(self, clip_model, num_labels):
        super().__init__()
        self.clip_model = clip_model
        self.classifier = nn.Linear(clip_model.config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Encode the text, classify it and optionally compute the loss.

        Args:
            input_ids: Token ids forwarded to the text encoder.
            attention_mask: Optional 1/0 mask marking real tokens. Assumed to
                be right-padded (real tokens first, padding last).
            labels: Optional class indices; when given, a cross-entropy loss
                over the logits is returned as well.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if ``labels`` is None.
        """
        encoder_outputs = self.clip_model(input_ids=input_ids, attention_mask=attention_mask)

        if attention_mask is None:
            # Without a mask there is no way to locate the last real token, so
            # fall back to the encoder's own pooled representation.
            text_features = encoder_outputs.pooler_output
        else:
            # CLIP's pooler_output is read at the FIRST end-of-text token. For a
            # tokenized sentence pair that token sits right after sentence 1,
            # and the text encoder attends causally, so the pooled vector would
            # not contain sentence 2 at all. Pool at the last attended position
            # instead; for single-sentence input this is the same token.
            hidden_states = encoder_outputs.last_hidden_state
            last_token = attention_mask.long().sum(dim=-1).clamp(min=1) - 1
            last_token = last_token.to(hidden_states.device)
            rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
            text_features = hidden_states[rows, last_token]

        logits = self.classifier(text_features)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return logits, loss

# Seed PyTorch's global RNG before the classifier head is created so that its
# random initialisation, and anything drawn from the global RNG afterwards,
# is the same after every Restart & Run All.
torch.manual_seed(42)
model = CLIPForSequenceClassification(clip_model, num_labels=2)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of sentence pairs and attach their labels.

    Args:
        examples: Batch dict as handed over by ``Dataset.map(batched=True)``;
            the ``"sentence1"``, ``"sentence2"`` and ``"label"`` columns are
            read.

    Returns:
        The tokenizer's encoding with an additional ``"labels"`` entry.
        Sequences are truncated to the tokenizer's maximum length but are
        intentionally left unpadded.
    """
    # Only text is involved, so call the tokenizer directly rather than going
    # through the multimodal processor. No padding and no tensors here: padding
    # inside a batched map would pad every example to the longest sequence of
    # its map batch, whereas DataCollatorWithPadding pads each training batch
    # to its own longest sequence when the DataLoader assembles it.
    inputs = processor.tokenizer(
        examples["sentence1"],
        examples["sentence2"],
        truncation=True,
    )
    inputs["labels"] = examples["label"]
    return inputs

# Drop every original column after tokenization so that each split keeps only
# what preprocess_function returns.
original_columns = dataset["train"].column_names
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=original_columns,
)

In [None]:
# The collator pads the examples of each batch using the processor's tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=processor.tokenizer)

train_dataset, eval_dataset, test_dataset = (
    tokenized_datasets[split] for split in ("train", "validation", "test")
)

BATCH_SIZE = 8


def _build_loader(split_dataset, shuffle=False):
    """Wrap one split in a DataLoader that shares batch size and collator."""
    return DataLoader(
        split_dataset,
        shuffle=shuffle,
        batch_size=BATCH_SIZE,
        collate_fn=data_collator,
    )


# Only the training split is shuffled; evaluation order stays fixed.
train_dataloader = _build_loader(train_dataset, shuffle=True)
eval_dataloader = _build_loader(eval_dataset)
test_dataloader = _build_loader(test_dataset)

In [None]:
# Prefer the GPU when one is visible to PyTorch, and move the model there
# before the optimizer is built. Re-binding the name (instead of a bare
# `model.to(device)` as the last expression) keeps the cell from printing the
# whole module repr.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to the values transformers.AdamW used by default so
# the optimisation behaves as it did before the switch.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
NUM_EPOCHS = 5
MODEL_INPUT_KEYS = ("input_ids", "attention_mask")

# Put the model in training mode once; nothing inside the loop switches it back.
model.train()
for epoch in range(NUM_EPOCHS):
    for batch in train_dataloader:
        # Forward only the keys the model's forward() accepts as inputs.
        model_inputs = {}
        for key, tensor in batch.items():
            if key in MODEL_INPUT_KEYS:
                model_inputs[key] = tensor.to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        _, batch_loss = model(labels=targets, **model_inputs)

        batch_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} completed")

In [None]:
# Switch to evaluation mode and collect predictions and gold labels over the
# whole eval_dataloader.
model.eval()
all_predictions = []
all_labels = []

feature_keys = ("input_ids", "attention_mask")

with torch.no_grad():
    for batch in eval_dataloader:
        model_inputs = {}
        for key, tensor in batch.items():
            if key in feature_keys:
                model_inputs[key] = tensor.to(device)
        labels = batch["labels"].to(device)

        # No labels are passed, so the returned loss is None and is ignored.
        logits, _ = model(**model_inputs)
        predictions = logits.argmax(dim=-1)

        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

In [None]:
from sklearn.metrics import accuracy_score, recall_score, precision_score

# The predictions scored here were collected from eval_dataloader, which wraps
# the "validation" split, so the report is labelled as validation metrics
# rather than test metrics.
accuracy = accuracy_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
print(f"Validation Accuracy: {accuracy}")
print(f"Validation Recall: {recall}")
print(f"Validation Precision: {precision}")