In [None]:
# Clone the Indofashionclip repository into the current working directory.
!git clone https://github.com/shashnkvats/Indofashionclip.git

Cloning into 'Indofashionclip'...
remote: Enumerating objects: 17, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 17 (delta 6), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (17/17), 6.30 KiB | 6.30 MiB/s, done.
Resolving deltas: 100% (6/6), done.


In [None]:
import os
# Work from inside the cloned repository so that relative paths used by the
# following cells (e.g. requirements.txt) resolve against it.
os.chdir('/content/Indofashionclip')

In [None]:
# Install the packages listed in requirements.txt (looked up in the current directory).
!pip install -r requirements.txt

In [None]:
from google.colab import drive
# Mount Google Drive at /content/mydrive; the dataset archive is read from there below.
drive.mount("/content/mydrive")

Drive already mounted at /content/mydrive; to attempt to forcibly remount, call drive.mount("/content/mydrive", force_remount=True).


In [None]:
# Dataset reference: https://www.kaggle.com/datasets/validmodel/indo-fashion-dataset
# Extract the dataset archive from the mounted Drive (no target directory is
# given, so files land in the current working directory).
!unzip -qq '/content/mydrive/MyDrive/Colab Notebooks/COSE474/archive.zip'

In [None]:
# NOTE(review): none of the visible cells import tensorflow — confirm this
# install is actually required before keeping it.
!pip install tensorflow-gpu==2.8.0

In [None]:
import json
from PIL import Image

from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import clip
from transformers import CLIPProcessor, CLIPModel

In [None]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the pre-trained CLIP ViT-B/32 weights together with the matching image
# preprocessing callable (non-JIT variant).
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [None]:
# Custom dataset pairing image files with tokenized captions
class image_title_dataset():
    """Map-style dataset yielding ``(preprocessed image, tokenized caption)`` pairs.

    Captions are tokenized once, at construction time, with ``clip.tokenize``.
    Images are opened and passed through the module-level ``preprocess``
    callable lazily, one item at a time, when indexed.
    """

    def __init__(self, list_image_path,list_txt):
        """Keep the image paths and tokenize every caption in ``list_txt``.

        Args:
            list_image_path: indexable collection of image file paths.
            list_txt: captions, passed as-is to ``clip.tokenize``.
        """
        self.image_path = list_image_path
        self.title  = clip.tokenize(list_txt)

    def __len__(self):
        """Number of items, defined by the number of tokenized captions."""
        n_items = len(self.title)
        return n_items

    def __getitem__(self, idx):
        """Return the preprocessed image and the tokenized caption at ``idx``."""
        opened = Image.open(self.image_path[idx])
        return preprocess(opened), self.title[idx]

In [None]:
# Build the training dataloader
json_path = '/content/Indofashionclip/train_data.json'
image_path = '/content/Indofashionclip/images/train/'

# The annotation file holds one JSON object per line.
with open(json_path, 'r') as f:
    input_data = [json.loads(line) for line in f]

# Image location = split directory + the file name taken from the record's own path.
list_image_path = [image_path + item['image_path'].split('/')[-1] for item in input_data]
# Caption = the record's class label, cut to at most 40 characters.
list_txt = [item['class_label'][:40] for item in input_data]

dataset = image_title_dataset(list_image_path, list_txt)
# Training batches: 256 items each, reshuffled by the DataLoader.
train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [None]:
# Build the validation dataloader
json_path = '/content/Indofashionclip/val_data.json'
image_path = '/content/Indofashionclip/images/val/'

# The annotation file holds one JSON object per line.
with open(json_path, 'r') as f:
    input_data = [json.loads(line) for line in f]

# Image location = split directory + the file name taken from the record's own path.
list_image_path = [image_path + item['image_path'].split('/')[-1] for item in input_data]
# Caption = the record's class label, cut to at most 40 characters.
list_txt = [item['class_label'][:40] for item in input_data]

dataset = image_title_dataset(list_image_path, list_txt)
# Validation batches: 256 items each, reshuffled by the DataLoader.
val_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [None]:
# Function to convert model's parameters to FP32 format
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient when present, to float32.

    Used right before ``optimizer.step()`` so the update is applied to
    full-precision weights and gradients.

    Args:
        model: object exposing ``parameters()`` that yields tensors with
            ``.data`` and ``.grad`` attributes (e.g. a ``torch.nn.Module``).

    Returns:
        None. Parameters are modified in place.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Parameters that received no gradient (frozen, or unused in the
        # forward pass) have ``grad is None``; skip them instead of failing
        # with AttributeError on ``None.data``.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


# On CPU the model is cast to float32 once, up front; the training loop then
# steps the optimizer directly on that device without per-step dtype conversion.
if device == "cpu":
  model.float()

In [None]:
# Install wandb, which the training cell uses to log metrics.
!pip install wandb

In [None]:
import os
import wandb

# Initialize WandB (replace the placeholder project / run names before running)
wandb.init(project="Set Your Project", name="Set Your Log Name")

# Define variables to keep track of the best model and its corresponding loss
best_loss = float('inf')
best_model_path = 'best_model(30).pth'
epoch_losses = []
validation_losses = []  # List to store validation losses

# Use AdamW rather than Adam: torch.optim.Adam treats `weight_decay` as an L2
# term added to the gradient before the adaptive scaling, which with a
# coefficient as large as 0.2 can dominate the update and drag the weights
# toward zero. AdamW applies the same coefficient as decoupled weight decay.
# All other hyper-parameters are unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
# Separate criteria for the image->text and text->image directions.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

# Validation function with accuracy calculation
def validate(model, dataloader, loss_img, loss_txt, device):
    """Evaluate ``model`` on ``dataloader``; return mean loss and in-batch accuracy.

    Every batch is scored contrastively: row ``i`` of both logit matrices is
    expected to peak at column ``i`` (image ``i`` matches text ``i``).

    The model is switched to eval mode for the duration of the call and put
    back into train mode afterwards if it was in train mode on entry, so
    calling this between training epochs does not silently change the mode
    the next epoch trains in.

    Args:
        model: callable as ``model(images, texts)`` returning
            ``(logits_per_image, logits_per_text)`` and exposing ``eval()``.
        dataloader: sized iterable of ``(images, texts)`` batches.
        loss_img: criterion applied to ``logits_per_image``.
        loss_txt: criterion applied to ``logits_per_text``.
        device: device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy_img, accuracy_txt)``: ``loss`` is the mean of
        the per-batch image and text losses (each batch weighted equally);
        the accuracies are the fraction of samples whose arg-max column equals
        their own index within the batch.

    Raises:
        ZeroDivisionError: if ``dataloader`` contains no batches.
    """
    # Remember the mode on entry; objects without a `training` flag are
    # treated as "not training" and are simply left in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    total_loss_img = 0.0
    total_loss_txt = 0.0
    correct_img = 0
    correct_txt = 0
    total_samples = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                images, texts = batch
                images = images.to(device)
                texts = texts.to(device)

                # Forward pass
                logits_per_image, logits_per_text = model(images, texts)

                # Compute loss: the matching pair for row i sits at column i
                ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
                loss_img_val = loss_img(logits_per_image, ground_truth)
                loss_txt_val = loss_txt(logits_per_text, ground_truth)

                # Accumulate loss
                total_loss_img += loss_img_val.item()
                total_loss_txt += loss_txt_val.item()

                # Accuracy calculation
                predicted_img = torch.argmax(logits_per_image, 1)
                predicted_txt = torch.argmax(logits_per_text, 1)
                correct_img += (predicted_img == ground_truth).sum().item()
                correct_txt += (predicted_txt == ground_truth).sum().item()

                total_samples += len(images)
    finally:
        # Restore train mode even if evaluation raised.
        if was_training:
            model.train()

    # Losses are averaged over batches, accuracies over samples.
    avg_loss_img = total_loss_img / len(dataloader)
    avg_loss_txt = total_loss_txt / len(dataloader)

    # Calculate accuracy
    accuracy_img = correct_img / total_samples
    accuracy_txt = correct_txt / total_samples

    return (avg_loss_img + avg_loss_txt) / 2, accuracy_img, accuracy_txt


# Train the model
num_epochs = 30
# Epochs are numbered 1..num_epochs inclusive so that exactly `num_epochs`
# passes are made and the "Epoch k/num_epochs" label ends at num_epochs.
for epoch in range(1, num_epochs + 1):
    # validate() puts the model in eval mode; make sure every training epoch
    # starts in train mode regardless of what ran before it.
    model.train()
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    epoch_loss = 0.0

    for batch in pbar:
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss: symmetric cross-entropy where pair i matches column i.
        # NOTE(review): captions are class labels and batches hold 256 items,
        # so a batch likely contains many identical captions while the target
        # treats every pair as unique — confirm this is intended.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        # Backward pass
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            # Step the optimizer on float32 weights/gradients, then hand the
            # model back to clip.model.convert_weights (presumably restoring
            # the reduced-precision weights — confirm against the CLIP source).
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        epoch_loss += total_loss.item()
        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")

    # Calculate average loss for the epoch
    avg_epoch_loss = epoch_loss / len(train_dataloader)
    epoch_losses.append(avg_epoch_loss)

    # Validate the model on the validation set
    val_loss, acc_img, acc_txt = validate(model, val_dataloader, loss_img, loss_txt, device)
    validation_losses.append(val_loss)

    # Save the best model
    # NOTE(review): "best" is decided on the training loss, not on val_loss —
    # confirm that is the intended selection criterion.
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        torch.save(model.state_dict(), best_model_path)

    # Log metrics to WandB
    wandb.log({"epoch": epoch, "loss": avg_epoch_loss, "validation_loss": val_loss, "img_acc": acc_img, "txt_acc": acc_txt})

# Print and save the best loss
print(f"Best Loss: {best_loss:.4f}")
# The training loop numbers epochs from 1, so the log lines are numbered the
# same way (a plain enumerate() would label the first epoch as "Epoch 0").
with open('epoch_losses.txt', 'w') as f:
    for epoch, loss in enumerate(epoch_losses, start=1):
        f.write(f"Epoch {epoch}/{num_epochs}, Loss: {loss:.4f}\n")

# Save validation losses to a file, using the same 1-based epoch numbering
with open('validation_losses.txt', 'w') as f:
    for epoch, val_loss in enumerate(validation_losses, start=1):
        f.write(f"Epoch {epoch}/{num_epochs}, Validation Loss: {val_loss:.4f}\n")

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▄▅▇█
img_acc,█▁▁▁▁▁
loss,▁█████
txt_acc,█▁▁▁▁▁
validation_loss,▁█████

0,1
epoch,6.0
img_acc,0.004
loss,5.54085
txt_acc,0.004
validation_loss,5.50638


Epoch 1/30, Loss: 1.9902: 100%|██████████| 357/357 [08:37<00:00,  1.45s/it]
Epoch 2/30, Loss: 2.3770: 100%|██████████| 357/357 [08:39<00:00,  1.45s/it]
Epoch 3/30, Loss: 2.0078: 100%|██████████| 357/357 [08:42<00:00,  1.46s/it]
Epoch 4/30, Loss: 2.2480: 100%|██████████| 357/357 [08:38<00:00,  1.45s/it]
Epoch 5/30, Loss: 2.2324: 100%|██████████| 357/357 [08:36<00:00,  1.45s/it]
Epoch 6/30, Loss: 2.2500: 100%|██████████| 357/357 [08:42<00:00,  1.46s/it]
Epoch 7/30, Loss: 2.2285: 100%|██████████| 357/357 [08:39<00:00,  1.46s/it]
Epoch 8/30, Loss: 2.2148: 100%|██████████| 357/357 [08:31<00:00,  1.43s/it]
Epoch 9/30, Loss: 2.0332: 100%|██████████| 357/357 [08:38<00:00,  1.45s/it]
Epoch 10/30, Loss: 2.4023: 100%|██████████| 357/357 [09:04<00:00,  1.53s/it]
Epoch 11/30, Loss: 2.3164: 100%|██████████| 357/357 [09:12<00:00,  1.55s/it]
Epoch 12/30, Loss: 2.0898: 100%|██████████| 357/357 [09:11<00:00,  1.55s/it]
Epoch 13/30, Loss: 2.2422: 100%|██████████| 357/357 [09:10<00:00,  1.54s/it]
Epoch 14

Best Loss: 3.8397


In [None]:
# Build the test dataloader
json_path = '/content/Indofashionclip/test_data.json'
image_path = '/content/Indofashionclip/images/test/'

# The annotation file holds one JSON object per line.
with open(json_path, 'r') as f:
    input_data = [json.loads(line) for line in f]

# Image location = split directory + the file name taken from the record's own path.
list_image_path = [image_path + item['image_path'].split('/')[-1] for item in input_data]
# Caption = the record's class label, cut to at most 40 characters.
list_txt = [item['class_label'][:40] for item in input_data]

dataset = image_title_dataset(list_image_path, list_txt)
# Test batches: 256 items each, reshuffled by the DataLoader.
test_dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [None]:
# Test: evaluate on the held-out test split and report loss and in-batch accuracy
test_loss, test_acc_img, test_acc_txt = validate(model, test_dataloader, loss_img, loss_txt, device) # testing reuses the validate function as-is
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy - Image: {test_acc_img:.4f}, Text: {test_acc_txt:.4f}")

In [None]:
# Inference: rank the class prompts against a single image
path = "/content/Indofashionclip/images/test/7500.jpeg" # Set your path
image = Image.open(path)

indo_classes = ['saree', 'blouse', 'dhoti_pants', 'dupattas', 'gowns', 'kurta_men', 'leggings_and_salwars', 'lehenga', 'mojaris_men', 'mojaris_women', 'nehru_jackets', 'palazzos', 'petticoats', 'sherwanis', 'women_kurta']

# One-image batch, plus one tokenized prompt per candidate class.
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = clip.tokenize([f"a photo of a {c}" for c in indo_classes]).to(device)

# Encode both modalities without tracking gradients
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Normalize each feature vector to unit length, then turn the scaled
# similarities into a probability distribution over the classes.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
scaled_logits = (100.0 * image_features) @ text_features.T
similarity = scaled_logits.softmax(dim=-1)

# Keep the 5 most likely labels for the (single) image
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for index, value in zip(indices, values):
    label = indo_classes[index]
    percent = 100 * value.item()
    print(f"{label:>16s}: {percent:.2f}%")

In [None]:
# Utility cell: run this if fine-tuning hits a CUDA out-of-memory error.
# It triggers Python garbage collection and then empties the CUDA cache.
import torch, gc
gc.collect()
torch.cuda.empty_cache()