In [None]:
import os
from tqdm import tqdm
import copy
from PIL import Image


import torch
import torch.nn as nn
import torchvision.transforms as v2
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import ViTConfig, ViTImageProcessor, ViTModel

In [None]:
# Pick the fastest available backend: Apple-silicon MPS first (the original
# target of this notebook), then CUDA, and finally plain CPU so the notebook
# still runs anywhere.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device: {device}')

In [None]:
# Category names; a label's position in this list is its integer class id.
labels = ['bag', 'bottom', 'dress', 'hat', 'outer', 'shoes', 'top', 'etc']

# Forward (name -> id) and reverse (id -> name) lookup tables.
label2id = dict(zip(labels, range(len(labels))))
id2label = dict(enumerate(labels))

In [None]:
# Hugging Face checkpoint identifier shared by the config and the image
# processor loaded below.
ckpt = 'google/vit-base-patch16-224-in21k'

# The processor supplies the normalisation statistics (image_mean / image_std)
# and the config supplies image_size / hidden_size used later in the notebook.
image_processor = ViTImageProcessor.from_pretrained(ckpt)
config = ViTConfig.from_pretrained(ckpt)

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by per-category sub-directories.

    Expected layout is ``image_dir/<category>/<image file>``. Every category
    name is translated to an integer id through the module-level ``label2id``
    mapping. Entries whose name starts with ``'.'`` are ignored, as are
    non-directory entries found directly inside ``image_dir``.

    Samples are collected in sorted order so that indices (and therefore any
    seeded split of this dataset) do not depend on the arbitrary ordering
    returned by ``os.listdir``.

    Args:
        image_dir: Root directory holding one sub-directory per category.
        config: Object exposing ``image_size``; images are resized to
            ``(image_size, image_size)``.
        image_processor: Object exposing ``image_mean`` and ``image_std``,
            used to normalise the image tensor.

    Raises:
        KeyError: If a category directory that contains files is not a key
            of ``label2id``.
    """

    def __init__(self, image_dir, config, image_processor):
        self.label_ids = []
        self.image_paths = []
        self.image_processor = image_processor

        for category in sorted(os.listdir(image_dir)):
            category_dir = f'{image_dir}/{category}'
            # Skip hidden entries and stray regular files; listing a regular
            # file as if it were a directory would raise NotADirectoryError.
            if category.startswith('.') or not os.path.isdir(category_dir):
                continue

            image_fnames = sorted(
                f for f in os.listdir(category_dir) if not f.startswith('.')
            )
            if not image_fnames:
                continue

            label_id = label2id[category]
            for image_fname in image_fnames:
                self.label_ids.append(label_id)
                self.image_paths.append(f'{category_dir}/{image_fname}')

        # `v2` is this notebook's alias for `torchvision.transforms`.
        self.transform = v2.Compose([
            v2.Resize((config.image_size, config.image_size)),
            v2.ToTensor(),
            v2.Normalize(mean=self.image_processor.image_mean, std=self.image_processor.image_std),
        ])

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.label_ids)

    def __getitem__(self, i):
        """Return ``(label_id, transformed_image)`` for sample ``i``."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving that to the GC.
        with Image.open(self.image_paths[i]) as raw_image:
            rgb_image = raw_image.convert('RGB')
        return self.label_ids[i], self.transform(rgb_image)
    
# Build the dataset from the thumbnail folder (one sub-directory per category).
image_root = '../crawl/kream_thumbnails'
dataset = CustomDataset(image_root, config, image_processor)

In [None]:
# 90% of the samples go to training; whatever remains is held out.
n_samples = len(dataset)
train_size = int(0.9 * n_samples)
test_size = n_samples - train_size

# A dedicated, seeded generator drives the random split.
generator = torch.Generator()
generator.manual_seed(2024)
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=generator,
)

print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')

In [None]:
# Settings shared by both loaders; only the training loader is shuffled.
loader_kwargs = dict(batch_size=32, pin_memory=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
class Classifier(nn.Module):
    """ViT backbone followed by a single linear classification head.

    Args:
        num_labels: Number of output classes (size of the logits' last axis).
        ckpt: Checkpoint name or path passed to ``ViTModel.from_pretrained``.
            Defaults to the checkpoint this notebook was written for.
    """

    def __init__(self, num_labels, ckpt='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(ckpt)
        # Size the head from the backbone that was actually loaded rather than
        # from a notebook-level config object, so the two cannot drift apart.
        self.fc = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, x):
        """Return class logits computed from the backbone's ``pooler_output``."""
        pooled = self.vit(x).pooler_output
        return self.fc(pooled)

In [None]:
# Classifier with one output per entry in `labels`, moved to the chosen device.
model = Classifier(num_labels=len(labels))
model = model.to(device)

# Cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Adam over every model parameter (backbone and head alike).
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

In [None]:
def train(model, dataloader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(images)`` to produce logits.
        dataloader: Iterable yielding ``(labels, images)`` tensor batches.
        criterion: Callable ``criterion(logits, labels)`` returning a scalar loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters), so the function does not rely on a notebook-level
            ``device`` variable being defined.

    Returns:
        The mean of the per-batch loss values (each batch counts equally).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    total_loss = 0.0
    n_batches = 0
    for targets, images in tqdm(dataloader):
        targets = targets.to(device)
        images = images.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches')
    return total_loss / n_batches

def evaluate(model, dataloader, criterion, device=None):
    """Compute loss and accuracy of ``model`` over ``dataloader`` without updating it.

    Args:
        model: Module called as ``model(images)`` to produce logits.
        dataloader: Iterable yielding ``(labels, images)`` tensor batches.
        criterion: Callable ``criterion(logits, labels)`` returning a scalar loss.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters), so the function does not rely on a notebook-level
            ``device`` variable being defined.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the mean of the per-batch loss
        values (each batch counts equally) and the fraction of samples whose
        arg-max prediction equals the label.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_loss = 0.0
    n_batches = 0
    total = 0
    corrects = 0
    # Nothing in this loop needs gradients, so disable autograd for all of it.
    with torch.no_grad():
        for targets, images in tqdm(dataloader):
            targets = targets.to(device)
            images = images.to(device)

            logits = model(images)
            loss = criterion(logits, targets)
            preds = logits.argmax(dim=1)

            total += len(targets)
            corrects += (preds == targets).sum().item()
            total_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches')
    return total_loss / n_batches, corrects / total

In [None]:
# Training schedule: at most `max_epochs` passes, stopping early once the
# validation loss has failed to improve for `patience` consecutive epochs.
max_epochs = 50
patience = 3

min_loss = float('inf')
early_stop_counter = 0
model_save_dir = './model_ckpt'
os.makedirs(model_save_dir, exist_ok=True)

for epoch in range(max_epochs):
    train_loss = train(model, train_dataloader, criterion, optimizer)
    val_loss, val_accuracy = evaluate(model, val_dataloader, criterion)

    # Epochs are reported 1-based here to agree with the early-stop message.
    print(f'Epoch: {epoch + 1}, Train loss: {train_loss:.3f}, Eval loss: {val_loss:.3f}, Eval accuracy: {val_accuracy:.3f}')

    if val_loss < min_loss:
        # New best validation loss: checkpoint the weights as CPU tensors.
        # Copying only the state dict avoids duplicating the whole model
        # (and its device memory) on every save.
        cpu_state = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
        torch.save(cpu_state, f"{model_save_dir}/classifier.pt")
        min_loss = val_loss
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopped at epoch {epoch + 1}")
            break