# dAiv AI_Competition[2024]_Pro

## Import Libraries

In [None]:
#%pip install pygwalker wandb

In [None]:
from os import path, rename, mkdir, listdir

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, utils, transforms, models

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pygwalker as pyg
import wandb

datasets.utils.tqdm = tqdm  # swap in tqdm.notebook; presumably so torchvision's download helpers render a notebook progress bar
%matplotlib inline

In [None]:
# WandB Initialization
#wandb.init(project="dAiv-ai-competition-2024-pro")

### Check GPU Availability

In [None]:
# Show GPU status (IPython shell escape, not plain Python)
!nvidia-smi

In [None]:
# Index of the CUDA device to use (only consulted when CUDA is available)
DEVICE_NUM = 0

if torch.cuda.is_available():
    torch.cuda.set_device(DEVICE_NUM)  # make the bare "cuda" device refer to DEVICE_NUM
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("INFO: Using device -", device)

## Load DataSets

In [None]:
from typing import Callable, Optional
from sklearn.model_selection import train_test_split


class ImageDataset(datasets.ImageFolder):
    """Competition image dataset built on torchvision's ``ImageFolder``.

    Exactly one split is selected by the boolean flags:

    * ``train`` / ``valid``: the ``split_ratio`` / ``1 - split_ratio`` parts of a stratified,
      seeded (``random_state``) split of ``<root>/train``. If both are set, ``train`` wins.
    * ``test``: ``<root>/test``.
    * ``unlabeled``: ``<root>/unlabeled``.

    For ``test`` and ``unlabeled`` every image is moved into its own sub-directory named after
    the file stem (see ``download``), so ``classes`` holds the image ids.
    """
    download_url = "https://daiv-cnu.duckdns.org/contest/ai_competition[2024]_pro/dataset/archive.zip"
    random_state = 20241028

    def __init__(
            self, root: str, force_download: bool = True,
            train: bool = False, valid: bool = False, split_ratio: float = 0.8,
            test: bool = False, unlabeled: bool = False,
            transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
    ):
        """
        Args:
            root: directory holding (or receiving) ``archive.zip`` and its extracted splits.
            force_download: run the download/extract step even if ``archive.zip`` already exists.
            train, valid, test, unlabeled: split selector flags (see class docstring).
            split_ratio: fraction of ``<root>/train`` assigned to the train split.
            transform, target_transform: forwarded to ``ImageFolder``.

        Raises:
            ValueError: if none of the split flags is set.
        """
        # Resolve the split directory first so a bad call fails before any download.
        if train or valid:
            split_dir = "train"
        elif test:
            split_dir = "test"
        elif unlabeled:
            split_dir = "unlabeled"
        else:
            raise ValueError("ImageDataset: set one of train, valid, test or unlabeled to True.")

        self.download(root, force=force_download)  # Download Dataset from server

        # Initialize ImageFolder
        super().__init__(root=path.join(root, split_dir), transform=transform, target_transform=target_transform)

        if train or valid:  # Split Train and Validation Set
            train_samples, valid_samples, train_targets, valid_targets = train_test_split(
                self.samples, self.targets, test_size=1 - split_ratio,
                stratify=self.targets, random_state=self.random_state,
            )
            if train:
                self.samples, self.targets = train_samples, train_targets
            else:
                self.samples, self.targets = valid_samples, valid_targets
            self.imgs = self.samples  # keep ImageFolder's `imgs` alias in sync with `samples`

    @property
    def df(self) -> pd.DataFrame:
        """Return the split as a DataFrame with ``path`` and ``label`` (class name) columns."""
        return pd.DataFrame(dict(path=[d[0] for d in self.samples], label=[self.classes[lb] for lb in self.targets]))

    @classmethod
    def download(cls, root: str, force: bool = False):
        """Download and extract the dataset archive into ``root`` unless it is already there.

        After extraction, every file directly inside ``<root>/test`` and ``<root>/unlabeled`` is
        moved into a sub-directory named after its stem, the layout ``ImageFolder`` needs to
        treat each image id as its own "class". Entries that are already directories are left
        alone, so extracting again over an arranged tree does not fail.
        """
        import os

        if not force and path.isfile(path.join(root, "archive.zip")):
            print("INFO: Dataset archive found in the root directory. Skipping download.")
            return

        # Download and Extract Dataset
        datasets.utils.download_and_extract_archive(cls.download_url, download_root=root, extract_root=root, filename="archive.zip")

        # Arrange Dataset Directory: <split>/<id>.jpg -> <split>/<id>/<id>.jpg
        for split in ("test", "unlabeled"):
            split_dir = path.join(root, split)
            for name in listdir(split_dir):
                src = path.join(split_dir, name)
                stem, ext = path.splitext(name)
                if not path.isfile(src) or not ext:
                    continue  # already-arranged per-image directory, or no stem to name a folder after
                image_dir = path.join(split_dir, stem)
                os.makedirs(image_dir, exist_ok=True)
                os.replace(src, path.join(image_dir, name))  # overwrites a stale copy from a previous extraction

        print("INFO: Dataset archive downloaded and extracted.")

### Dataset Initialization

In [None]:
# Evaluation-time preprocessing: fixed-size resize, tensor conversion, ImageNet normalization
IMG_SIZE = (256, 256)
IMG_NORM = {  # per-channel ImageNet statistics
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

resizer = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMG_NORM["mean"], std=IMG_NORM["std"]),
])

In [None]:
DATA_ROOT = path.join(".", "data")

# Labeled splits: stratified train/valid partition of ./data/train
train_dataset = ImageDataset(DATA_ROOT, force_download=False, train=True, transform=resizer)
valid_dataset = ImageDataset(DATA_ROOT, force_download=False, valid=True, transform=resizer)

# Unlabeled splits used for submission and (optionally) extra data
test_dataset = ImageDataset(DATA_ROOT, force_download=False, test=True, transform=resizer)
unlabeled_dataset = ImageDataset(DATA_ROOT, force_download=False, unlabeled=True, transform=resizer)

print(f"INFO: Dataset loaded successfully. Number of samples - Train({len(train_dataset)}), Valid({len(valid_dataset)}), Test({len(test_dataset)}), Unlabeled({len(unlabeled_dataset)})")

### Visualize Dataset Distribution
- Sanity-check the labels and class balance of the train and validation splits.

In [None]:
# Print the class names of the first five training samples as a sanity check
for i, label in enumerate(train_dataset.targets[:5]):
    print(i, "-", train_dataset.classes[label])

In [None]:
# Display the train split as a table of image paths and class names
train_dataset.df

In [None]:
# Train Dataset Distribution: interactive PyGWalker explorer over the (path, label) table
pyg.walk(train_dataset.df)

In [None]:
# Display the validation split as a table of image paths and class names
valid_dataset.df

In [None]:
# Valid Dataset Distribution: same PyGWalker explorer for the validation split
# NOTE(review): the result is bound to `walker`; confirm the installed pygwalker version still renders it from here
walker = pyg.walk(valid_dataset.df, theme_key="streamlit")

## Data Augmentation (optional)

In [None]:
# Rotation range in degrees, and the shared ColorJitter strength for the augmenter
ROTATE_ANGLE, COLOR_TRANSFORM = 20, 0.1

In [None]:
augmenter = transforms.Compose([
    # geometric augmentation
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(ROTATE_ANGLE),
    # photometric augmentation, same strength on every channel property
    transforms.ColorJitter(**dict.fromkeys(("brightness", "contrast", "saturation", "hue"), COLOR_TRANSFORM)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0), ratio=(0.75, 1.333)),
    # then the shared resize -> tensor -> normalize pipeline
    resizer,
])

In [None]:
# Rebuild the train split so every sample access goes through the random augmenter
train_dataset = ImageDataset(DATA_ROOT, force_download=False, train=True, transform=augmenter)
print(f"INFO: Train dataset has been overridden with augmented state. Number of samples - Train({len(train_dataset)})")

## DataLoader

In [None]:
# Samples per mini-batch for every DataLoader
BATCH_SIZE: int = 128

In [None]:
MULTI_PROCESSING = True  # Set False if DataLoader is causing issues

from platform import system

# num_workers=0 loads batches in the main process. Worker processes are skipped on Windows,
# where spawn-based DataLoader workers commonly fail for classes defined inside a notebook.
if not MULTI_PROCESSING or system() == "Windows":
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")
else:
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    print(f"INFO: Number of CPU cores - {cpu_cores}")

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=cpu_cores)
valid_loader = DataLoader(valid_dataset, BATCH_SIZE, shuffle=False, num_workers=cpu_cores)
test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, num_workers=cpu_cores)

In [None]:
# Image Visualizer
def imshow(image_list, mean=IMG_NORM['mean'], std=IMG_NORM['std']):
    """Show a normalized CHW image (e.g. a ``utils.make_grid`` output) with normalization undone.

    Args:
        image_list: array-like of shape (C, H, W) normalized with ``mean``/``std``; pass a CPU
            tensor or a NumPy array (move GPU tensors to the CPU first).
        mean, std: per-channel statistics that were used for normalization.
    """
    np_image = np.asarray(image_list).transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    de_norm_image = np_image * np.asarray(std) + np.asarray(mean)
    # Floating-point error can leave values marginally outside [0, 1]; clip so imshow
    # does not warn and clip on its own.
    de_norm_image = np.clip(de_norm_image, 0.0, 1.0)
    plt.figure(figsize=(10, 10))
    plt.imshow(de_norm_image)

In [None]:
#images, targets = next(iter(train_loader))
#grid_images = utils.make_grid(images, nrow=8, padding=10)
#imshow(grid_images)

## Define Model

In [None]:
class VisualEmbedding(nn.Module):
    """ResNet-34 image encoder that outputs ``embedding_dim``-dimensional vectors.

    It also owns ``class_lookup``, a learnable (num_classes, embedding_dim) table of class
    embeddings that is not used by ``forward`` but is read by the classifier wrapping it.
    """

    def __init__(self, num_classes: int, embedding_dim: int):
        super().__init__()

        # Backbone: ImageNet-pretrained ResNet-34 with global max-pooling instead of average-pooling.
        # NOTE: newer torchvision deprecates `pretrained=` in favour of `weights=`.
        backbone = models.resnet34(pretrained=True)
        backbone.avgpool = nn.AdaptiveMaxPool2d((1, 1))
        self.hidden_size = backbone.fc.in_features
        bottleneck = self.hidden_size // 2
        # Projection head replacing the ImageNet classifier
        backbone.fc = nn.Sequential(
            nn.Linear(self.hidden_size, bottleneck),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(bottleneck, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )
        self.image_embedding = backbone

        # Class embedding table; the randn values are immediately overwritten by xavier_uniform_.
        self.class_lookup = nn.Parameter(torch.randn(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.class_lookup)

    def forward(self, x) -> torch.Tensor:
        """Return the projected (un-normalized) image embeddings for a batch ``x``."""
        return self.image_embedding(x)

In [None]:
class ImageClassifier(nn.Module):
    """Cosine-similarity classifier over a learnable class-embedding table.

    Image embeddings from :class:`VisualEmbedding` and the rows of ``class_lookup`` are both
    L2-normalized; ``forward`` returns their cosine similarity divided by a learnable
    ``temperature``, i.e. logits suitable for a softmax/cross-entropy loss.
    """

    def __init__(self, embedding_dim: int, num_classes: int, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax

        # Visual Embedding (the class table is shared with it, not copied)
        self.visual_embedding = VisualEmbedding(num_classes, embedding_dim)
        self.class_lookup = self.visual_embedding.class_lookup

        # Cosine Similarity Temperature
        self.temperature = nn.Parameter(torch.ones(1) * 0.2)

    def cosine_similarity(self, x) -> torch.Tensor:
        """Return the (batch, num_classes) cosine similarity between images ``x`` and each class."""
        img_embeddings = F.normalize(self.visual_embedding(x), p=2, dim=1)
        cls_embeddings = F.normalize(self.class_lookup, p=2, dim=1)
        return torch.mm(img_embeddings, cls_embeddings.t())

    def forward(self, x) -> torch.Tensor:
        """Return temperature-scaled cosine similarities (logits), shape (batch, num_classes)."""
        return self.cosine_similarity(x) / self.temperature

    def predict_top_k(self, x, k=2, threshold=0.5, min_similarity=0.3):
        """Dispatch to the softmax-threshold or cosine-similarity predictor per ``use_softmax``."""
        if self.use_softmax:
            return self.predict_top_k_by_threshold(x, k=k, threshold=threshold)
        else:
            return self.predict_top_k_by_similarity(x, k=k, min_similarity=min_similarity)

    def predict_top_k_by_threshold(self, x, k=2, threshold=0.5):
        """Return the top-``k`` classes by softmax probability, shape (batch, k).

        When ``k > 1``, the second slot is replaced by -1 if its probability is less than
        ``threshold`` times the top probability (a confident single-class prediction).
        """
        similarity = self(x)
        probabilities = F.softmax(similarity, dim=1)

        top_probs, top_classes = torch.topk(probabilities, k, dim=1)

        if k > 1:  # drop a weak runner-up so confident images keep a single class
            relative_probs = top_probs[:, 1] / top_probs[:, 0]
            mask = relative_probs < threshold
            top_classes[mask, 1] = -1

        return top_classes

    def predict_top_k_by_similarity(self, x, k=2, min_similarity=0.3):
        """Return up to ``k`` classes per image whose cosine similarity is >= ``min_similarity``.

        Classes are ordered by decreasing similarity; unused slots (too dissimilar, or ``k``
        larger than the number of classes) are filled with -1. ``min_similarity`` is compared
        with the raw cosine in [-1, 1], not the temperature-scaled logits from ``forward``.

        Returns:
            int64 tensor of shape (batch, k) on the same device as the model output.
        """
        cosine = self.cosine_similarity(x)
        scores, classes = cosine.sort(dim=1, descending=True)
        scores, classes = scores[:, :k], classes[:, :k]

        # Rows are sorted, so the rejected entries always form a suffix of each row.
        classes = classes.masked_fill(~(scores >= min_similarity), -1)

        missing = k - classes.size(1)
        if missing > 0:  # k exceeds the number of classes
            padding = classes.new_full((classes.size(0), missing), -1)
            classes = torch.cat([classes, padding], dim=1)

        return classes

In [None]:
# One class slot beyond the dataset's own classes; its index (CLASS_LABELS - 1) is reported as -2 in the submission
CLASS_LABELS = 1 + len(train_dataset.classes)
EMBEDDING_DIM = 16
USE_SOFTMAX = False

MODEL_PARAMS = {
    "embedding_dim": EMBEDDING_DIM,
    "num_classes": CLASS_LABELS,
    "use_softmax": USE_SOFTMAX,
}

In [None]:
# Build the classifier directly on the selected device (Module.to returns the module itself)
model = ImageClassifier(**MODEL_PARAMS).to(device)
model

In [None]:
class CrossContrastiveLoss(nn.Module):
    """Cross-entropy blended with a batch-level margin ranking term.

    The ranking term is ``max(0, mean(non-target scores) - mean(target scores) + margin)``;
    the result is ``alpha * CE + (1 - alpha) * ranking``.
    """

    def __init__(self, margin=0.1, alpha=0.7):
        super().__init__()
        self.margin = margin
        self.alpha = alpha
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, similarity, labels):
        ce_term = self.cross_entropy(similarity, labels)

        # Boolean matrix marking each sample's target column
        rows = torch.arange(similarity.size(0))
        is_target = torch.zeros_like(similarity, dtype=torch.bool)
        is_target[rows, labels] = True

        mean_target = similarity[is_target].mean()
        mean_other = similarity[~is_target].mean()
        ranking_term = (mean_other - mean_target + self.margin).clamp(min=0.0)

        return self.alpha * ce_term + (1 - self.alpha) * ranking_term

In [None]:
class CrossSimilarityLoss(nn.Module):
    """Cross-entropy blended with per-entry hinge penalties on the similarity matrix.

    Target scores below ``pos_margin`` and non-target scores above ``neg_margin`` are
    penalized linearly (each averaged separately); the result is
    ``alpha * CE + (1 - alpha) * (pos_penalty + neg_penalty)``.
    """

    def __init__(self, pos_margin=0.7, neg_margin=0.3, alpha=0.7):
        super().__init__()
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.alpha = alpha
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, similarity, labels):
        ce_term = self.cross_entropy(similarity, labels)

        # Boolean matrix marking each sample's target column
        rows = torch.arange(similarity.size(0))
        is_target = torch.zeros_like(similarity, dtype=torch.bool)
        is_target[rows, labels] = True

        pos_penalty = (self.pos_margin - similarity[is_target]).clamp(min=0.0).mean()
        neg_penalty = (similarity[~is_target] - self.neg_margin).clamp(min=0.0).mean()

        return self.alpha * ce_term + (1 - self.alpha) * (pos_penalty + neg_penalty)

In [None]:
LEARNING_RATE = 0.001

# Loss follows the prediction mode: hinge-on-similarity loss unless softmax prediction is used.
criterion = CrossSimilarityLoss() if not USE_SOFTMAX else CrossContrastiveLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# OneCycleLR is sized for 50 epochs x len(train_loader) batches, i.e. one step() per batch;
# `epochs` must stay in sync with the training loop's epoch count.
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=50,
    steps_per_epoch=len(train_loader),
)

## Training Loop

In [None]:
from IPython.display import display
import ipywidgets as widgets

# Interactive Loss Plot Update
def create_plot():
    """Create a live loss curve in an output widget and return a callback that extends it.

    The returned ``update_plot(new_loss)`` appends ``new_loss.item()`` to the history and
    redraws the figure inside the widget.
    """
    history = []

    plt.ion()  # interactive mode

    fig, ax = plt.subplots(figsize=(6, 2))
    (curve,) = ax.plot(history)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Cross Entropy Loss")

    output = widgets.Output()
    display(output)

    def update_plot(new_loss):
        history.append(new_loss.item())
        curve.set_data(range(len(history)), history)
        ax.relim()
        ax.autoscale_view()
        with output:
            output.clear_output(wait=True)
            display(fig)

    return update_plot

In [None]:
#wandb.watch(model, criterion, log="all", log_freq=10)

In [None]:
# Number of passes over the training set (keep equal to the OneCycleLR `epochs` argument)
num_epochs: int = 50

In [None]:
train_length, valid_length = map(len, (train_loader, valid_loader))

epochs = tqdm(range(num_epochs), desc="Running Epochs")
with (tqdm(total=train_length, desc="Training") as train_progress,
      tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars
    update = create_plot()  # Create Loss Plot

    for epoch in epochs:
        train_progress.reset(total=train_length)
        valid_progress.reset(total=valid_length)

        # Training
        model.train()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # OneCycleLR was built with steps_per_epoch=len(train_loader), so it must advance once
            # per batch; without this call the LR stays at its initial value (max_lr / div_factor).
            # It raises if stepped more than epochs * steps_per_epoch times, so its `epochs`
            # argument must equal num_epochs.
            lr_scheduler.step()

            update(loss)
            train_progress.update(1)
            #if i != train_length-1: wandb.log({'Loss': loss.item()})
            print(f"\rEpoch [{epoch+1:2}/{num_epochs}], Step [{i+1:2}/{train_length}], Loss: {loss.item():.6f}", end="")

        val_correct, val_loss = 0, 0.0

        # Validation
        model.eval()
        with torch.no_grad():
            for inputs, targets in valid_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)

                val_loss += criterion(outputs, targets).item() / valid_length  # mean of per-batch losses
                val_correct += (outputs.argmax(dim=1) == targets).sum().item()
                valid_progress.update(1)
        # Count correct predictions as an int and divide once, instead of accumulating device tensors
        val_acc = val_correct / len(valid_dataset)

        #wandb.log({'Loss': loss.item(), 'Val Acc': val_acc, 'Val Loss': val_loss})
        print(f"\rEpoch [{epoch+1:2}/{num_epochs}], Step [{train_length}/{train_length}], Loss: {loss.item():.6f}, Valid Acc: {val_acc:.6%}, Valid Loss: {val_loss:.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == num_epochs else "")

In [None]:
# Persist the trained weights (state_dict only) under ./models, creating the folder on first use
if not path.isdir(path.join(".", "models")):
    mkdir(path.join(".", "models"))

save_path = path.join(".", "models", "visual_embedding.pt")
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

## Model Evaluation

In [None]:
# Load Model
model_id = "visual_embedding"

model = ImageClassifier(**MODEL_PARAMS)
# map_location remaps tensors that were saved on a GPU onto the active device, so the
# checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load(path.join(".", "models", f"{model_id}.pt"), map_location=device))
model.to(device)

In [None]:
_ids, _preds = [], []
test_length = len(test_dataset)

model.eval()
with torch.no_grad():
    for inputs, ids in tqdm(test_loader):
        inputs = inputs.to(device)
        # Each test image sits in its own folder, so its ImageFolder "class" name is the image id
        _ids += [test_dataset.classes[idx] for idx in ids]
        # One (k,)-shaped row of predicted class indices per image
        _preds += model.predict_top_k(inputs, k=2, min_similarity=0.3)

In [None]:
# One row per test image; the extra class slot (index CLASS_LABELS - 1) is written as -2,
# and the two labels are stored in ascending order.
results = {"id": [], "label1": [], "label2": []}
for i, labels in zip(_ids, _preds):
    results["id"].append(i)
    labels = [v.item() for v in labels[:2]]
    labels = [-2 if v == CLASS_LABELS - 1 else v for v in labels]
    results["label1"].append(min(labels))
    results["label2"].append(max(labels))

results_df = pd.DataFrame(results)
results_df

In [None]:
# Save Results: write the submission CSV as submissions/<model_id>.csv
submission_dir = "submissions"
if not path.isdir(submission_dir):
    mkdir(submission_dir)

submit_file_path = path.join(submission_dir, model_id + ".csv")
results_df.to_csv(submit_file_path, index=False)
print("File saved to", submit_file_path)