In [14]:
import os
import cv2
import numpy as np
import torch
import clip
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [15]:
# Load the CLIP model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device)

In [17]:
# Step 1: Load and Preprocess Image Data for Text Extraction
class ImageTextDataset(Dataset):
    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        self.transform = transform
        self.images, self.labels = self.load_images()

    def load_images(self):
        images = []
        labels = []
        for class_label, class_name in enumerate(os.listdir(self.data_folder)):
            class_folder = os.path.join(self.data_folder, class_name)
            if not os.path.isdir(class_folder):
                continue
            for file_name in os.listdir(class_folder):
                file_path = os.path.join(class_folder, file_name)
                image = cv2.imread(file_path)
                images.append(image)
                labels.append(class_label)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label

In [23]:
# Define the transformation to preprocess images for the CLIP model
target_size = (224, 224)  # Specify the target size for images
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [24]:
# Step 2: Extract Text from Images and Perform Text Classification
image_text_data_folder = "text-dataset"  # Change this to your dataset folder
image_text_dataset = ImageTextDataset(image_text_data_folder, transform=transform)

In [25]:
# Prepare DataLoader
batch_size = 32
image_text_loader = DataLoader(image_text_dataset, batch_size=batch_size, shuffle=True)

In [26]:
# Step 3: Text Classification using CLIP
classifier_head = torch.nn.Linear(model.text_projection.shape[1], len(set(image_text_dataset.labels))).to(device)

num_epochs = 10
optimizer = torch.optim.Adam(classifier_head.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    classifier_head.train()
    for images, labels in image_text_loader:
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            text_features = model.encode_image(images)

        optimizer.zero_grad()
        logits = classifier_head(text_features)
        loss = criterion(logits, labels.long())
        loss.backward()
        optimizer.step()


In [27]:
# Step 4: Evaluate the Text Classification Model
text_pred = []
text_labels = []
classifier_head.eval()

with torch.no_grad():
    for images, labels in image_text_loader:
        images = images.to(device)
        labels = labels.to(device)

        text_features = model.encode_image(images)
        logits = classifier_head(text_features)
        pred = logits.argmax(dim=-1).cpu().numpy()
        text_pred.extend(pred)
        text_labels.extend(labels.cpu().numpy())

text_accuracy = accuracy_score(text_labels, text_pred)
text_precision = precision_score(text_labels, text_pred, average="weighted")
text_recall = recall_score(text_labels, text_pred, average="weighted")
text_f1_score = f1_score(text_labels, text_pred, average="weighted")

print("Text Accuracy:", text_accuracy)
print("Text Precision:", text_precision)
print("Text Recall:", text_recall)
print("Text F1-score:", text_f1_score)

Text Accuracy: 0.802
Text Precision: 0.8021740522540983
Text Recall: 0.802
Text F1-score: 0.8019714838936808


In [28]:
# Step 5: Save Text Classification Model
text_model_file = "text_model.pt"
torch.save(classifier_head.state_dict(), text_model_file)
print("Text Classification Model saved successfully.")

Text Classification Model saved successfully.
