In [117]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [118]:
# Personal Color 분류 CLIP 3 - CoCoOp

!pip install git+https://github.com/openai/CLIP.git
!pip install ftfy regex tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from PIL import Image
import os
import glob
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.cuda.amp import autocast, GradScaler
from typing import List, Tuple, Optional

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-fsjkvukw
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-fsjkvukw
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [119]:
class PromptLearner(nn.Module):
    def __init__(self, classnames, clip_model, device, n_ctx=16, ctx_init=None):
        super().__init__()
        self.device = device
        self.n_classes = len(classnames)
        self.ctx_dim = clip_model.ln_final.weight.shape[0]
        self.n_ctx = n_ctx
        self.dtype = clip_model.dtype

        # Improved Meta-Net with Layer Normalization and Residual Connection
        self.meta_net = nn.Sequential(
            nn.LayerNorm(self.ctx_dim),
            nn.Linear(self.ctx_dim, self.ctx_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(self.ctx_dim * 4, self.ctx_dim),
            nn.Dropout(0.1)
        )

        # Improved Context Initialization
        if ctx_init is None:
            ctx_vectors = []
            for name in classnames:
                with torch.no_grad():
                    tokens = clip.tokenize(f"a photo of a person with {name} color tone").to(device)
                    ctx_vector = clip_model.token_embedding(tokens[0]).detach()
                    ctx_vectors.append(ctx_vector)
            ctx_init = torch.stack(ctx_vectors).mean(dim=0)[:self.n_ctx]

        # Add positional embeddings to context initialization
        pos_embeddings = torch.arange(self.n_ctx).float().to(device)
        pos_embeddings = pos_embeddings / self.n_ctx
        pos_embeddings = pos_embeddings.unsqueeze(1).expand(-1, self.ctx_dim)
        ctx_init = ctx_init + 0.1 * pos_embeddings

        self.ctx = nn.Parameter(ctx_init)

        # Improved prompt construction
        prompt_templates = [
            f"a photo of a person with {name} color tone",
            f"a portrait showing {name} seasonal color characteristics",
            f"an image demonstrating {name} personal color features"
        ]

        # Use the first prompt template as the base
        template = prompt_templates[0]
        self.tokenized_prompts = torch.cat([clip.tokenize(template.format(name=name)) for name in classnames]).to(device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(self.tokenized_prompts).type(self.dtype)

        self.register_buffer('embedding', embedding)
        self.prompt_prefix_length = 4
        self.name_length = embedding.size(1) - self.prompt_prefix_length

    def forward(self, batch_size):
        # Dynamic context generation with residual connection
        ctx_features = self.meta_net(self.ctx)
        ctx = self.ctx + ctx_features  # Residual connection

        # Add batch dimension for processing
        ctx = ctx.unsqueeze(0)  # Shape: [1, n_ctx, ctx_dim]

        # Multi-scale context features
        ctx_scales = [ctx]
        for i in range(2):  # Generate 3 scales
            # Correct permutation for 1D pooling
            ctx_permuted = ctx.permute(0, 2, 1)  # Shape: [1, ctx_dim, n_ctx]
            ctx_scaled = F.adaptive_avg_pool1d(ctx_permuted, ctx.size(1) // (2 ** (i+1)))
            ctx_scaled = F.interpolate(ctx_scaled, size=ctx.size(1), mode='linear')
            ctx_scales.append(ctx_scaled.permute(0, 2, 1))  # Shape: [1, n_ctx, ctx_dim]

        # Average the multi-scale features
        ctx = torch.stack(ctx_scales).mean(dim=0)  # Shape: [1, n_ctx, ctx_dim]
        ctx = ctx.expand(batch_size * self.n_classes, -1, -1)

        # Process name embeddings
        name_embeddings = self.embedding[:, self.prompt_prefix_length:, :]
        name_embeddings = name_embeddings.unsqueeze(0).expand(batch_size, -1, -1, -1)
        name_embeddings = name_embeddings.reshape(-1, self.name_length, self.ctx_dim)

        # Concatenate context and name embeddings
        prompts = torch.cat([ctx, name_embeddings], dim=1)

        # Handle prompt length
        if prompts.size(1) > 77:
            prompts = prompts[:, :77, :]
        else:
            padding = torch.zeros(
                prompts.size(0),
                77 - prompts.size(1),
                prompts.size(2),
                dtype=self.dtype,
                device=self.device
            )
            prompts = torch.cat([prompts, padding], dim=1)

        return prompts

In [120]:
class CustomCLIP(nn.Module):
    def __init__(self, classnames, clip_model, device):
        super().__init__()
        self.device = device
        # Initialize model parameters in float32 for training stability
        self.dtype = torch.float32
        self.clip_dtype = clip_model.dtype  # Store original CLIP dtype for inference

        self.prompt_learner = PromptLearner(classnames, clip_model, device)
        self.image_encoder = clip_model.visual
        self.text_encoder = clip_model.transformer
        self.text_projection = clip_model.text_projection
        self.logit_scale = clip_model.logit_scale
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final

        # Configure fusion components
        hidden_dim = clip_model.visual.output_dim
        self.fusion = nn.ModuleDict({
            'attention': nn.MultiheadAttention(
                embed_dim=hidden_dim,
                num_heads=8,
                dropout=0.1,
                batch_first=True
            ),
            'norm1': nn.LayerNorm(hidden_dim),
            'norm2': nn.LayerNorm(hidden_dim),
            'mlp': nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 4, hidden_dim),
                nn.Dropout(0.1)
            )
        })

        # Convert all modules to float32
        self.to(dtype=torch.float32)

    def forward(self, image):
        batch_size = image.size(0)

        # Ensure input image is in float32
        image = image.to(device=self.device, dtype=torch.float32)

        # Process image features
        with torch.amp.autocast('cuda', dtype=torch.float32):
            image_features = self.image_encoder(image)
            image_features = image_features.to(dtype=torch.float32)

            # Prepare for attention
            image_features = image_features.unsqueeze(1)  # [batch_size, 1, hidden_dim]

            # Apply attention
            attn_output, _ = self.fusion['attention'](
                image_features, image_features, image_features
            )

            # Residual connection and normalization
            image_features = image_features + attn_output
            image_features = self.fusion['norm1'](image_features)

            # MLP with residual connection
            mlp_output = self.fusion['mlp'](image_features)
            image_features = image_features + mlp_output
            image_features = self.fusion['norm2'](image_features)

            # Remove sequence dimension and normalize
            image_features = image_features.squeeze(1)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Process text features
            prompts = self.prompt_learner(batch_size)
            x = prompts + self.positional_embedding.type(torch.float32)
            x = x.permute(1, 0, 2)
            x = self.text_encoder(x)
            x = x.permute(1, 0, 2)
            x = self.ln_final(x)

            # Extract and normalize text features
            text_features = x[:, -1, :] @ self.text_projection.type(torch.float32)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute logits
            image_features = image_features.unsqueeze(1).expand(-1, self.prompt_learner.n_classes, -1)
            image_features = image_features.reshape(-1, image_features.size(-1))

            logit_scale = self.logit_scale.exp()
            logits = logit_scale * torch.sum(image_features * text_features, dim=-1)
            logits = logits.view(batch_size, self.prompt_learner.n_classes)

            return logits

In [121]:
class PersonalColorDataset(Dataset):
    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        transform: Optional[transforms.Compose] = None,
        augment: bool = False
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.augment = augment

        # Improvement 5: Enhanced color augmentation
        if augment:
            self.color_aug = transforms.Compose([
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
                transforms.RandomAutocontrast(p=0.5)
            ])
        else:
            self.color_aug = None

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        image = Image.open(self.image_paths[idx]).convert('RGB')

        if self.transform:
            image = self.transform(image)
            if self.augment and np.random.random() > 0.5:
                image = self.color_aug(image)

        return image, self.labels[idx]

In [122]:
def train_model(model, train_loader, val_loader, num_epochs, device):
    scaler = torch.amp.GradScaler()

    # Configure optimizer
    optimizer = torch.optim.AdamW([
        {'params': [p for n, p in model.prompt_learner.named_parameters() if "meta_net" not in n], 'lr': 2e-3},
        {'params': model.prompt_learner.meta_net.parameters(), 'lr': 1e-3},
        {'params': model.fusion.parameters(), 'lr': 5e-4}
    ], weight_decay=0.05)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=num_epochs // 3,
        T_mult=2,
        eta_min=1e-6
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    criterion = criterion.to(device=device)

    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for images, labels in train_loader:
            images = images.to(device=device)
            labels = labels.to(device=device)

            optimizer.zero_grad()

            # Use float32 for training
            with torch.amp.autocast('cuda', dtype=torch.float32):
                logits = model(images)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        scheduler.step()
        val_acc = validate_with_tta(model, val_loader, device)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss/len(train_loader):.4f}')
        print(f'Validation Accuracy: {val_acc:.4f}')

In [123]:
def validate_with_tta(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device=device)
            labels = labels.to(device=device)

            # Test-time augmentation
            tta_outputs = []

            # Use float32 for validation as well
            with torch.amp.autocast('cuda', dtype=torch.float32):
                for flip in [False, True]:
                    img = images.flip(3) if flip else images
                    for scale in [0.9, 1.0, 1.1]:
                        size = int(224 * scale)
                        if size != 224:
                            img_scaled = F.interpolate(img, size=(size, size),
                                                     mode='bilinear',
                                                     align_corners=False)
                            img_scaled = F.interpolate(img_scaled, size=(224, 224),
                                                     mode='bilinear',
                                                     align_corners=False)
                        else:
                            img_scaled = img

                        outputs = model(img_scaled)
                        tta_outputs.append(outputs)

            # Average predictions
            outputs = torch.stack(tta_outputs).mean(0)
            _, predicted = outputs.max(1)

            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return correct / total

In [124]:
def main():
    dataset_dir = '/content/drive/Othercomputers/내 노트북/personal-color-data/'
    dataset_types = ['train', 'test']
    class_folders = ['spring', 'summer', 'fall', 'winter']

    image_paths = {'train': [], 'test': []}
    labels = {'train': [], 'test': []}

    for dataset_type in dataset_types:
        for idx, class_folder in enumerate(class_folders):
            class_dir = os.path.join(dataset_dir, dataset_type, class_folder)
            for img_path in glob.glob(os.path.join(class_dir, '*.*')):
                if img_path.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_paths[dataset_type].append(img_path)
                    labels[dataset_type].append(idx)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model, _ = clip.load("ViT-B/32", device=device)
    model = CustomCLIP(class_folders, clip_model, device).to(device)

    # 개선 9: Strong augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.RandomApply([transforms.ColorJitter(0.3, 0.3, 0.3, 0.1)], p=0.5),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                           std=(0.26862954, 0.26130258, 0.27577711))
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                           std=(0.26862954, 0.26130258, 0.27577711))
    ])

    train_dataset = PersonalColorDataset(image_paths['train'], labels['train'],
                                       transform=train_transform, augment=True)
    val_dataset = PersonalColorDataset(image_paths['test'], labels['test'],
                                     transform=val_transform, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                            num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                          num_workers=4, pin_memory=True)

    train_model(model, train_loader, val_loader, num_epochs=15, device=device)

if __name__ == "__main__":
    main()

Epoch 1/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 2/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 3/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 4/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 5/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 6/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 7/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 8/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 9/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 10/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 11/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 12/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 13/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 14/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
Epoch 15/15:
Training Loss: 1.3863
Validation Accuracy: 0.2284
