In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Personal Color 분류 CLIP 5 - LoRA

# 라이브러리 임포트
!pip install git+https://github.com/openai/CLIP.git
!pip install ftfy regex tqdm

import torch
import torch.nn as nn
import clip
from PIL import Image
import os
import glob
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-889g96gr
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-889g96gr
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369489 sha256=d703af7cd887d486855d8349c6701516b74e5582551fbe95f196a657f984045d
  Stored in directory: /tmp/pip-ephem-wheel-cache-x9y0_xlx/wheels/da/2b/4c/d6691fa9597aac8bb

In [3]:
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank=2):
        super().__init__()
        self.down = nn.Linear(in_dim, rank, bias=False)
        self.up = nn.Linear(rank, out_dim, bias=False)

        # Initialize weights
        nn.init.kaiming_uniform_(self.down.weight)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x))

In [4]:
class AttentionWithLoRA(nn.Module):
    def __init__(self, dim, num_heads=8, lora_rank=4):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3)

        # LoRA layer
        self.lora = LoRALayer(dim, dim, rank=lora_rank)

        # Output projection
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Add LoRA contribution
        lora_out = self.lora(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q + lora_out

        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

In [5]:
class LoRACLIP(nn.Module):
    def __init__(self, classnames, clip_model, device, lora_rank=4):
        super().__init__()
        self.clip_model = clip_model
        self.device = device
        self.classnames = classnames

        # Convert model to float32
        self.clip_model = self.clip_model.float()
        for param in self.clip_model.parameters():
            param.data = param.data.float()

        # Initialize attention with LoRA layers
        hidden_size = self.clip_model.visual.transformer.width

        # CLIP ViT-B/32는 12개의 attention head를 사용
        num_heads = 12

        # Replace attention layers with custom attention
        for block in self.clip_model.visual.transformer.resblocks:
            block.attn = AttentionWithLoRA(hidden_size, num_heads, lora_rank=lora_rank)

        # Prepare class token embeddings
        with torch.no_grad():
            self.tokenized_prompts = torch.cat([
                clip.tokenize(f"a photo of a person with {name} color tone")
                for name in classnames
            ]).to(device)

    def encode_image(self, x):
        # Input to float32
        x = x.float()

        x = self.clip_model.visual.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([self.clip_model.visual.class_embedding.to(x.dtype) +
                      torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
        x = self.clip_model.visual.ln_pre(x)

        # Transformer blocks
        for block in self.clip_model.visual.transformer.resblocks:
            x = x + block.attn(block.ln_1(x))
            x = x + block.mlp(block.ln_2(x))

        x = self.clip_model.visual.ln_post(x[:, 0, :])

        if self.clip_model.visual.proj is not None:
            x = x @ self.clip_model.visual.proj.float()

        return x

    def forward(self, image):
        # image 인코딩
        image_features = self.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # text 인코딩
        with torch.no_grad():
            text_features = self.clip_model.encode_text(self.tokenized_prompts).float()
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # 유사도 계산
        logit_scale = self.clip_model.logit_scale.exp().float()
        logits = logit_scale * image_features @ text_features.t()

        return logits

In [6]:
def evaluate_model(model, data_loader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    class_names = ['spring', 'summer', 'fall', 'winter']

    report = classification_report(all_labels, all_preds, target_names=class_names)
    conf_matrix = confusion_matrix(all_labels, all_preds)

    return accuracy, report, conf_matrix

In [8]:
def main():
    dataset_dir = '/content/drive/Othercomputers/내 노트북/personal-color-data/'
    dataset_types = ['train', 'test']
    class_folders = ['spring', 'summer', 'fall', 'winter']

    image_paths = {'train': [], 'test': []}
    labels = {'train': [], 'test': []}

    for dataset_type in dataset_types:
        for idx, class_folder in enumerate(class_folders):
            class_dir = os.path.join(dataset_dir, dataset_type, class_folder)
            for img_path in glob.glob(os.path.join(class_dir, '*.*')):
                if img_path.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_paths[dataset_type].append(img_path)
                    labels[dataset_type].append(idx)

    # CLIP
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    # Initialize LoRA CLIP
    model = LoRACLIP(class_folders, clip_model, device).to(device)

    # Data loaders
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                           (0.26862954, 0.26130258, 0.27577711))
    ])

    class PersonalColorDataset(torch.utils.data.Dataset):
        def __init__(self, image_paths, labels, transform=None):
            self.image_paths = image_paths
            self.labels = labels
            self.transform = transform

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image = Image.open(self.image_paths[idx]).convert('RGB')
            if self.transform:
                image = self.transform(image)
            label = self.labels[idx]
            return image, label

    train_dataset = PersonalColorDataset(image_paths['train'], labels['train'], transform)
    val_dataset = PersonalColorDataset(image_paths['test'], labels['test'], transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    # Optimizer
    optimizer = torch.optim.AdamW([
        {'params': [p for n, p in model.named_parameters() if 'lora' in n or 'attn' in n], 'lr': 1e-3}
    ])
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

    # Training
    num_epochs = 10

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        scheduler.step()

        # Evaluate
        val_acc, val_report, val_conf_matrix = evaluate_model(model, val_loader, device)
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Training Loss: {train_loss/len(train_loader):.4f}")
        print(f"Validation Accuracy: {val_acc:.4f}")
        print("\nConfusion Matrix:")
        print(val_conf_matrix)
        print("\nClassification Report:")
        print(val_report)


if __name__ == "__main__":
    main()

100%|██████████| 81/81 [00:57<00:00,  1.40it/s]
100%|██████████| 15/15 [00:05<00:00,  2.62it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1/10
Training Loss: 1.6525
Validation Accuracy: 0.2860

Confusion Matrix:
[[  0   0   0 214]
 [  0   0   0 189]
 [  0   0   0 266]
 [  0   0   0 268]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.00      0.00      0.00       214
      summer       0.00      0.00      0.00       189
        fall       0.00      0.00      0.00       266
      winter       0.29      1.00      0.44       268

    accuracy                           0.29       937
   macro avg       0.07      0.25      0.11       937
weighted avg       0.08      0.29      0.13       937



100%|██████████| 81/81 [00:48<00:00,  1.68it/s]
100%|██████████| 15/15 [00:05<00:00,  2.67it/s]



Epoch 2/10
Training Loss: 1.3823
Validation Accuracy: 0.2785

Confusion Matrix:
[[ 42  11  12 149]
 [ 51  11  10 117]
 [ 13   7  36 210]
 [  4   9  83 172]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.38      0.20      0.26       214
      summer       0.29      0.06      0.10       189
        fall       0.26      0.14      0.18       266
      winter       0.27      0.64      0.38       268

    accuracy                           0.28       937
   macro avg       0.30      0.26      0.23       937
weighted avg       0.29      0.28      0.24       937



100%|██████████| 81/81 [00:48<00:00,  1.66it/s]
100%|██████████| 15/15 [00:05<00:00,  2.64it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 3/10
Training Loss: 1.4516
Validation Accuracy: 0.2615

Confusion Matrix:
[[ 10   0   0 204]
 [  8   1   0 180]
 [ 30   1   0 235]
 [ 30   4   0 234]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.13      0.05      0.07       214
      summer       0.17      0.01      0.01       189
        fall       0.00      0.00      0.00       266
      winter       0.27      0.87      0.42       268

    accuracy                           0.26       937
   macro avg       0.14      0.23      0.12       937
weighted avg       0.14      0.26      0.14       937



100%|██████████| 81/81 [00:48<00:00,  1.66it/s]
100%|██████████| 15/15 [00:05<00:00,  2.57it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 4/10
Training Loss: 1.4044
Validation Accuracy: 0.2465

Confusion Matrix:
[[206   0   0   8]
 [184   0   0   5]
 [258   0   0   8]
 [243   0   0  25]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.23      0.96      0.37       214
      summer       0.00      0.00      0.00       189
        fall       0.00      0.00      0.00       266
      winter       0.54      0.09      0.16       268

    accuracy                           0.25       937
   macro avg       0.19      0.26      0.13       937
weighted avg       0.21      0.25      0.13       937



100%|██████████| 81/81 [00:48<00:00,  1.67it/s]
100%|██████████| 15/15 [00:06<00:00,  2.49it/s]



Epoch 5/10
Training Loss: 1.3814
Validation Accuracy: 0.3511

Confusion Matrix:
[[ 93   0   9 112]
 [ 78   0   4 107]
 [ 76   0  14 176]
 [ 35   1  10 222]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.33      0.43      0.38       214
      summer       0.00      0.00      0.00       189
        fall       0.38      0.05      0.09       266
      winter       0.36      0.83      0.50       268

    accuracy                           0.35       937
   macro avg       0.27      0.33      0.24       937
weighted avg       0.29      0.35      0.26       937



100%|██████████| 81/81 [00:48<00:00,  1.66it/s]
100%|██████████| 15/15 [00:05<00:00,  2.53it/s]



Epoch 6/10
Training Loss: 1.3768
Validation Accuracy: 0.3511

Confusion Matrix:
[[ 93   0   9 112]
 [ 78   0   4 107]
 [ 76   0  14 176]
 [ 35   1  10 222]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.33      0.43      0.38       214
      summer       0.00      0.00      0.00       189
        fall       0.38      0.05      0.09       266
      winter       0.36      0.83      0.50       268

    accuracy                           0.35       937
   macro avg       0.27      0.33      0.24       937
weighted avg       0.29      0.35      0.26       937



100%|██████████| 81/81 [00:48<00:00,  1.67it/s]
100%|██████████| 15/15 [00:05<00:00,  2.63it/s]



Epoch 7/10
Training Loss: 1.3801
Validation Accuracy: 0.3767

Confusion Matrix:
[[ 76  44   0  94]
 [ 50  55   0  84]
 [ 76  17   2 171]
 [ 30  15   3 220]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.33      0.36      0.34       214
      summer       0.42      0.29      0.34       189
        fall       0.40      0.01      0.01       266
      winter       0.39      0.82      0.53       268

    accuracy                           0.38       937
   macro avg       0.38      0.37      0.31       937
weighted avg       0.38      0.38      0.30       937



100%|██████████| 81/81 [00:48<00:00,  1.67it/s]
100%|██████████| 15/15 [00:05<00:00,  2.59it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 8/10
Training Loss: 1.3817
Validation Accuracy: 0.3650

Confusion Matrix:
[[  0   2 194  18]
 [  0   4 159  26]
 [  0   2 230  34]
 [  0   1 159 108]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.00      0.00      0.00       214
      summer       0.44      0.02      0.04       189
        fall       0.31      0.86      0.46       266
      winter       0.58      0.40      0.48       268

    accuracy                           0.36       937
   macro avg       0.33      0.32      0.24       937
weighted avg       0.34      0.36      0.27       937



100%|██████████| 81/81 [00:48<00:00,  1.67it/s]
100%|██████████| 15/15 [00:05<00:00,  2.64it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 9/10
Training Loss: 1.3685
Validation Accuracy: 0.3671

Confusion Matrix:
[[ 41 101   0  72]
 [ 36  98   0  55]
 [ 59  46   0 161]
 [ 35  28   0 205]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.24      0.19      0.21       214
      summer       0.36      0.52      0.42       189
        fall       0.00      0.00      0.00       266
      winter       0.42      0.76      0.54       268

    accuracy                           0.37       937
   macro avg       0.25      0.37      0.29       937
weighted avg       0.25      0.37      0.29       937



100%|██████████| 81/81 [00:48<00:00,  1.66it/s]
100%|██████████| 15/15 [00:05<00:00,  2.60it/s]


Epoch 10/10
Training Loss: 1.3274
Validation Accuracy: 0.4045

Confusion Matrix:
[[126  26  45  17]
 [110  29  33  17]
 [101   3  96  66]
 [ 64   1  75 128]]

Classification Report:
              precision    recall  f1-score   support

      spring       0.31      0.59      0.41       214
      summer       0.49      0.15      0.23       189
        fall       0.39      0.36      0.37       266
      winter       0.56      0.48      0.52       268

    accuracy                           0.40       937
   macro avg       0.44      0.40      0.38       937
weighted avg       0.44      0.40      0.39       937




