In [None]:
import os
import torch
from torchvision.transforms import v2
import numpy as np
import pandas as pd
from PIL import Image
import open_clip
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import random
import re
import spacy
from huggingface_hub import HfApi


os.environ['CUDA_VISIBLE_DEVICES'] = '7'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Model(torch.nn.Module):
    def __init__(self, model_name='ViT-B-32', pretrained='laion2b_s34b_b79k') -> None:
        super().__init__()
        self.tokenizer = open_clip.get_tokenizer(model_name)
        self.feature_extractor, _, self.processor = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained
        )
        
        # Get CLIP embedding dimension
        self.embed_dim = self.feature_extractor.visual.output_dim
        
        # Additional projection layers
        self.query_projection = torch.nn.Sequential(
            torch.nn.Linear(self.embed_dim * 2, self.embed_dim),
            torch.nn.LayerNorm(self.embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.embed_dim, self.embed_dim)
        )
        
        self.database_projection = torch.nn.Sequential(
            torch.nn.Linear(self.embed_dim, self.embed_dim),
            torch.nn.LayerNorm(self.embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.embed_dim, self.embed_dim)
        )
        
        self.set_param_trainable_mode(module=self.feature_extractor, status=False)


    def set_param_trainable_mode(self, module, status):
        for param in module.parameters():
            param.requires_grad = status
    
    def save(self, path):
        torch.save(self.state_dict(), path)

    def load(self, path):
        self.load_state_dict(torch.load(path, weights_only=True))

    def forward(self, query_image, query_text):
        # Get base embeddings from CLIP
        image_features = self.feature_extractor.encode_image(query_image)
        text_features = self.feature_extractor.encode_text(query_text)
        
        # Concatenate image and text features
        combined_features = torch.cat([image_features, text_features], dim=1)
        
        # Project through learnable layers
        query_embedding = self.query_projection(combined_features)
        
        return query_embedding
    
    def encode_database_image(self, image):
        image_features = self.feature_extractor.encode_image(image)
        database_embedding = self.database_projection(image_features)
        return database_embedding

In [None]:
class InfoNCELoss(torch.nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, query_embeds, database_embeds):
        """
        InfoNCE loss implementation
        
        Args:
            query_embeds: Query embeddings [batch_size, embed_dim]
            database_embeds: Database embeddings [batch_size, embed_dim]
            
        Returns:
            loss: InfoNCE loss value
        """
        # Normalize embeddings
        query_embeds = torch.nn.functional.normalize(query_embeds, dim=1)
        database_embeds = torch.nn.functional.normalize(database_embeds, dim=1)
        
        # Calculate similarity matrix
        similarity_matrix = torch.matmul(query_embeds, database_embeds.T) / self.temperature
        
        # Labels are the diagonal elements (positive pairs)
        labels = torch.arange(len(query_embeds)).to(query_embeds.device)
        
        # Calculate loss in both directions (query->database and database->query)
        loss_q2d = self.criterion(similarity_matrix, labels)
        loss_d2q = self.criterion(similarity_matrix.T, labels)
        
        # Total loss is the average of both directions
        return (loss_q2d + loss_d2q) / 2

In [4]:
SPLIT_RATIO = 0.95
IMAGE_ROOT_DIR = os.path.join(os.getcwd(), 'dataset', 'images')
ANNOTATIONS_FILE_PATH = os.path.join(os.getcwd(), 'dataset', 'data.csv')
TEST_ROOT_DIR = os.path.join(os.getcwd(), 'sample_evaluation', 'images')
TEST_ANNOTATIONS_FILE_PATH = os.path.join(os.getcwd(), 'sample_evaluation', 'data.csv')
BATCH_SIZE = 80
NUM_WORKERS = 128
NUM_EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps")
MODEL_NAME = 'ViT-B-32'
PRETRAINED_WEIGHTS = 'laion2b_s34b_b79k'
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
LOSS_TEMPERATURE = 0.07
SCHEDULER_T_0 = 5
SCHEDULER_T_MULT = 2


model = Model(model_name=MODEL_NAME, pretrained=PRETRAINED_WEIGHTS).to(DEVICE)
criterion = InfoNCELoss(temperature=LOSS_TEMPERATURE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=SCHEDULER_T_0, T_mult=SCHEDULER_T_MULT)

In [5]:
class RetrievalDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir_path: str, annotations_file_path: str, split: str, transform=None, tokenizer=None) -> None:
        self.img_dir_path = img_dir_path
        self.transform = transform
        self.tokenizer = tokenizer
        self.split = split
        self.annotations = self.split_data(
            # self.data_health_check(
                self.convert_image_names_to_path(
                    pd.read_csv(annotations_file_path)
                )
            # )
        )
    
    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> tuple:
        query_img_path = self.annotations.iloc[idx]['query_image']
        query_text = self.annotations.iloc[idx]['query_text']
        target_img_path = self.annotations.iloc[idx]['target_image']
        query_img = Image.open(query_img_path).convert('RGB')
        target_img = Image.open(target_img_path).convert('RGB')
        # query_img = torchvision.io.read_image(path=query_img_path, mode=torchvision.io.image.ImageReadMode.RGB)
        # target_img = torchvision.io.read_image(path=target_img_path, mode=torchvision.io.image.ImageReadMode.RGB)
        if self.transform:
            query_img = self.transform(query_img)
            target_img = self.transform(target_img)
        if self.tokenizer:
            query_text = self.tokenizer(query_text).squeeze(0)
        return query_img, query_text, target_img
    
    def split_data(self, annotations):
        shuffled_df = annotations.sample(frac=1, random_state=42).reset_index(drop=True)
        if self.split == "test":
            return shuffled_df # sample test set
        if self.split == "train":
            return shuffled_df.iloc[:int(SPLIT_RATIO * len(shuffled_df))] # train set
        if self.split == "validation":
            return shuffled_df.iloc[int(SPLIT_RATIO * len(shuffled_df)):] # validation set
        raise Exception("split is not valid")

    def load_queries(self):
        return self.annotations.drop(columns=["target_image"])
    
    def load_database(self):
        return self.annotations[["target_image"]]
    
    def convert_image_names_to_path(self, df):
        df["query_image"] = self.img_dir_path + "/" + df["query_image"]
        df["target_image"] = self.img_dir_path + "/" + df["target_image"]
        return df
    
    # def data_health_check(self, annotations):
    #     img_files = os.listdir(self.img_dir_path)
    #     broken_files = [img for img in img_files if self.is_truncated(os.path.join(self.img_dir_path, img))]
    #     annotations = annotations[
    #         ~annotations['target_image'].isin(broken_files) &
    #         ~annotations['query_image'].isin(broken_files)
    #     ]
    #     return annotations
    
    # def is_truncated(self, image_path):
    #     try:
    #         with Image.open(image_path) as img:
    #             img.verify()
    #         return False
    #     except (IOError, SyntaxError, Image.DecompressionBombError) as e:
    #         return True

In [6]:
class UniqueTargetImageBatchSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, batch_size, shuffle=True):
        """
        Initializes the sampler.

        Args:
            dataset (RetrievalDataset): The dataset to sample from.
            batch_size (int): Number of samples per batch.
            shuffle (bool): Whether to shuffle the data every epoch.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Create a mapping from target_image to list of indices
        self.target_to_indices = defaultdict(list)
        for idx in range(len(self.dataset)):
            target_image = self.dataset.annotations.iloc[idx]['target_image']
            self.target_to_indices[target_image].append(idx)
        
        # List of unique target_images
        self.unique_target_images = list(self.target_to_indices.keys())
        if self.shuffle:
            random.shuffle(self.unique_target_images)
            for indices in self.target_to_indices.values():
                random.shuffle(indices)

    def __iter__(self):
        """
        Yields lists of indices where each list represents a batch with unique target_images.
        """
        # Create a copy of indices per target_image to preserve original order
        queues = [indices.copy() for indices in self.target_to_indices.values()]
        
        if self.shuffle:
            random.shuffle(queues)
        
        batch = []
        while any(queues):
            for queue in queues:
                if queue:
                    batch.append(queue.pop())
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []
            # Optional: Shuffle queues after each full pass to ensure randomness
            if self.shuffle:
                random.shuffle(queues)
        
        if batch:
            yield batch

    def __len__(self):
        """
        Returns the number of batches per epoch.
        """
        total = len(self.dataset)
        return (total + self.batch_size - 1) // self.batch_size


In [7]:
train_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.8, 1.0)),
    v2.RandomHorizontalFlip(),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


train_dataset = RetrievalDataset(
    img_dir_path=IMAGE_ROOT_DIR,
    annotations_file_path=ANNOTATIONS_FILE_PATH,
    split='train',
    transform=model.processor if hasattr(model, 'processor') else test_transform,
    tokenizer=model.tokenizer
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    # batch_size=BATCH_SIZE, 
    # shuffle=True,
    num_workers=NUM_WORKERS,
    batch_sampler=UniqueTargetImageBatchSampler(dataset=train_dataset, batch_size=BATCH_SIZE)
)

val_dataset = RetrievalDataset(
    img_dir_path=IMAGE_ROOT_DIR,
    annotations_file_path=ANNOTATIONS_FILE_PATH,
    split='validation',
    transform=model.processor if hasattr(model, 'processor') else test_transform,
    tokenizer=model.tokenizer
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

test_dataset = RetrievalDataset(
    img_dir_path=TEST_ROOT_DIR,
    annotations_file_path=TEST_ANNOTATIONS_FILE_PATH,
    split='test',
    transform=model.processor if hasattr(model, 'processor') else test_transform,
    tokenizer=model.tokenizer
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [8]:
def parse_actions(text):
    # Load English language model
    nlp = spacy.load("en_core_web_sm")
    
    # List of common action verbs and their normalized forms
    action_verbs = {
        'add': 'add',
        'insert': 'add',
        'introduce': 'add',
        'bring': 'add',
        'place': 'add',
        'include': 'add',
        'remove': 'remove',
        'eliminate': 'remove',
        'discard': 'remove',
        'take': 'remove',
        'get rid': 'remove'
    }
    
    # Process the text
    doc = nlp(text.lower())
    
    # Initialize result structure
    result = [
        {"verb": [], "nouns": []},
        {"verb": [], "nouns": []}
    ]
    
    # Split text into two parts if possible
    parts = re.split(r'\s+(?:and|,|then|\.|\s+)+\s*', text.lower())
    
    current_action = 0
    current_verb = None
    
    for token in doc:
        # Check for verbs
        lemma = token.lemma_.lower()
        if any(verb in lemma for verb in action_verbs.keys()):
            # Handle multi-word verbs
            verb_phrase = lemma
            if token.i + 1 < len(doc) and doc[token.i + 1].text in ['in', 'away', 'out']:
                verb_phrase += ' ' + doc[token.i + 1].text
            
            normalized_verb = None
            for verb, norm in action_verbs.items():
                if verb in verb_phrase:
                    normalized_verb = norm
                    break
            
            if normalized_verb:
                if current_verb != normalized_verb:
                    current_verb = normalized_verb
                    if current_action < 2:
                        result[current_action]["verb"] = [normalized_verb]
                        current_action += 1
        
        # Collect nouns
        elif token.pos_ == "NOUN":
            if current_action > 0 and len(result[current_action-1]["verb"]) > 0:
                if token.text not in result[current_action-1]["nouns"]:
                    result[current_action-1]["nouns"].append(token.text)
    
    # Clean up empty actions and ensure proper structure
    for action in result:
        if not action["verb"]:
            action["verb"] = ["add" if not result[0]["verb"] else "remove"]
    
    return result

In [None]:
def encode_queries(df: pd.DataFrame) -> np.ndarray:
    """
    Process query pairs and generate embeddings.

    Args:
    df (pd. DataFrame ): DataFrame with columns:
    - query_image: str, paths to query images
    - query_text: str, text descriptions

    Returns:
    np.ndarray: Embeddings array (num_queries, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), BATCH_SIZE)):
        query_imgs = torch.stack([model.processor(Image.open(query_image_path)) for query_image_path in df['query_image'][i:i+BATCH_SIZE]]).to(DEVICE)
        samples = []
        for sample in df['query_text'][i:i+BATCH_SIZE].map(lambda text: parse_actions(text)):
            with torch.no_grad():
                samples.append({sample[0]['verb'][0]: model.feature_extractor.encode_text(model.tokenizer(list(map(lambda x: "a photo of " + x, sample[0]['nouns']))).to(DEVICE)), sample[1]['verb'][0]: model.feature_extractor.encode_text(model.tokenizer(list(map(lambda x: "a photo of " + x, sample[1]['nouns']))).to(DEVICE))})
        # query_texts = model.tokenizer(df['query_text'][i:i+BATCH_SIZE]).to(DEVICE)
        with torch.no_grad():
            # query_embedding = model(query_imgs, query_texts)
            query_embedding = model.feature_extractor.encode_image(query_imgs)
        assert len(samples) == len(query_embedding)
        for j in range(len(samples)):
            for val in samples[j]['add']:
                query_embedding[j] += val
            for val in samples[j]['remove']:
                query_embedding[j] -= val
        query_embedding = torch.nn.functional.normalize(query_embedding, dim=1, p=2)
        all_embeddings.append(query_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)


def encode_database(df: pd.DataFrame) -> np.ndarray :
    """
    Process database images and generate embeddings.

    Args:
    df (pd. DataFrame ): DataFrame with column:
    - target_image: str, paths to database images

    Returns:
    np.ndarray: Embeddings array (num_images, embedding_dim)
    """
    model.eval()
    all_embeddings = []
    for i in tqdm(range(0, len(df), BATCH_SIZE)):
        target_imgs = torch.stack([model.processor(Image.open(target_image_path)) for target_image_path in df['target_image'][i:i+BATCH_SIZE]]).to(DEVICE)
        with torch.no_grad():
            # target_imgs_embedding = model.encode_database_image(target_imgs)
            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
        target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
    return np.concatenate(all_embeddings)

In [10]:
def calculate_accuracy(predictions: np.ndarray, ground_truth: np.ndarray) -> float:
    assert predictions.shape == ground_truth.shape, "Predictions and ground truth must have the same shape."
    
    # Calculate the number of correct predictions
    correct_predictions = (predictions == ground_truth).sum()
    total_predictions = len(predictions)
    
    # Calculate accuracy as a percentage
    accuracy = correct_predictions / total_predictions
    return accuracy

def evaluate(dataset):
    query_embeddings = encode_queries(dataset.load_queries())
    database_embeddings = encode_database(dataset.load_database())
    similarities = cosine_similarity(query_embeddings, database_embeddings)
    predictions = np.argmax(similarities, axis=1)
    ground_truth = np.arange(len(database_embeddings))
    accuracy = calculate_accuracy(predictions, ground_truth)
    return accuracy

In [None]:
print(f"Validation Accuracy: {100*evaluate(val_dataset)}")
print(f"Test Accuracy: {100*evaluate(test_dataset)}")

In [14]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    with tqdm(train_loader, desc="Training") as pbar:
        for batch_idx, (query_imgs, query_texts, target_imgs) in enumerate(pbar):
            optimizer.zero_grad()
            
            # Move data to device
            query_imgs = query_imgs.to(device)
            target_imgs = target_imgs.to(device)
            query_texts = query_texts.to(device)
            
            # Forward pass
            query_embeds = model(query_imgs, query_texts)
            database_embeds = model.encode_database_image(target_imgs)
            
            # Calculate loss
            loss = criterion(query_embeds, database_embeds)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            total_loss += loss.item()
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'avg_loss': f"{total_loss / (batch_idx + 1):.4f}"
            })
    
    return total_loss / len(train_loader)

In [13]:
def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for query_imgs, query_texts, target_imgs in tqdm(loader, desc="Validation"):
            # Move data to device
            query_imgs = query_imgs.to(device)
            target_imgs = target_imgs.to(device)
            
            # Forward pass
            query_embeds = model(query_imgs, query_texts)
            database_embeds = model.encode_database_image(target_imgs)
            
            # Calculate loss
            loss = criterion(query_embeds, database_embeds)
            total_loss += loss.item()
            
            # Calculate accuracy
            similarity = torch.matmul(query_embeds, database_embeds.T)
            predictions = similarity.argmax(dim=1)
            labels = torch.arange(len(predictions)).to(device)
            correct += (predictions == labels).sum().item()
            total += len(predictions)
    
    return total_loss / len(loader), correct / total

In [None]:
best_val_acc = 0
model.set_param_trainable_mode(model.feature_extractor, status=True)
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    print(f"Test Accuracy: {100*evaluate(test_dataset)}")
    scheduler.step()

In [20]:
print(f"Validation Accuracy: {100*evaluate(val_dataset)}")
print(f"Test Accuracy: {100*evaluate(test_dataset)}")

100%|██████████| 1/1 [01:49<00:00, 109.09s/it]
100%|██████████| 1/1 [01:47<00:00, 107.93s/it]


Validation Accuracy: 67.08860759493672


100%|██████████| 1/1 [00:09<00:00,  9.18s/it]
100%|██████████| 1/1 [00:09<00:00,  9.14s/it]

Test Accuracy: 55.00000000000001





In [35]:
model.save("/home/nafisi/temp/rayan-phase2-q1/weights.pth")

In [None]:
api = HfApi()
api.upload_file(
    path_or_fileobj="/home/nafisi/temp/rayan-phase2-q1/weights.pth",
    path_in_repo="weights.pth",
    repo_id="safinal/rayan-phase2-q1",
    repo_type="model",
)