In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip -d /content/data /content/drive/MyDrive/classified_photos.zip

In [None]:
!pip install wandb transformers torch torchvision scikit-learn

In [72]:
import torch
import pandas as pd
import wandb
import random
from torchvision import transforms
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
import os
import sys
# sys.path.append(os.path.abspath(".."))
# from models.ui_dataset import UIDataset
from torch.utils.data import random_split, DataLoader
from transformers import ViTImageProcessor,ViTForImageClassification, TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


In [50]:
import torch
from torch.utils.data import Dataset
import os
from PIL import Image

class UIDataset(Dataset):
    def __init__(self, root_dir, processor=None):
        self.root_dir = root_dir
        self.processor = processor
        self.image_paths = []
        self.label_to_index = {}  # Dictionary to store label encoding
        self.index_to_label = {}  # (Optional) Reverse mapping for decoding

        # Collect unique labels
        unique_labels = sorted([label for label in os.listdir(self.root_dir) if os.path.isdir(os.path.join(self.root_dir, label))])

        # Create a mapping from label names to integers
        self.label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
        self.index_to_label = {idx: label for label, idx in self.label_to_index.items()}  # Reverse mapping (optional)

        # Store image paths and corresponding encoded labels
        for label in unique_labels:
            subdir_path = os.path.join(self.root_dir, label)
            for filename in os.listdir(subdir_path):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp')):
                    file_path = os.path.join(subdir_path, filename)
                    self.image_paths.append((file_path, self.label_to_index[label]))  # Store encoded label

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        file_path, label = self.image_paths[index]  # label is now an integer

        image = Image.open(file_path).convert("RGB")

        if self.processor:
            inputs = self.processor(images=image, return_tensors="pt")
            image_tensor = inputs["pixel_values"].squeeze(0)

        return image_tensor, label  # Label is now an integer


In [51]:
wandb.init(
    project="ui-classification-experiments",
    config={
        "learning_rate": 0.02,
        "architecture": "VIT Transformer",
        "dataset": "DesktopUI",
        "epochs": 10,
    }
)

In [52]:
!export WANDB_API_KEY="7e54b169f751025afccd3ef14aacaf584f797cb1"

In [53]:
os.environ["WANDB_PROJECT"] = "ui-classification-experiments"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

In [54]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [55]:
labels = ["clean-ui", "ui-to-crop", "unnecessary"]

In [56]:
dataset_path = "/content/data/classified_photos_blip_xgboost"

In [57]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

In [58]:
ui_dataset_all = UIDataset(root_dir=dataset_path, processor=processor)

ui_train_size = int(len(ui_dataset_all) * 0.6)
ui_val_size = int(len(ui_dataset_all) * 0.2)

train_dataset, val_dataset, test_dataset = random_split(ui_dataset_all, [ui_train_size, ui_val_size, len(ui_dataset_all) - ui_val_size - ui_train_size])

In [59]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [60]:
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [66]:
def collate_fn(batch):
    images, labels = zip(*batch)
    return {
        'pixel_values': torch.stack(images),
        'labels': torch.tensor([int(label) for label in labels])
    }

In [73]:

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="weighted")

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }


In [75]:
training_args = TrainingArguments(
    output_dir="./vit-base-beans",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=10,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
    report_to="wandb",
    fp16=True
)




In [76]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=processor,
)

  trainer = Trainer(


In [77]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1 Score
100,0.3734,0.481041,0.834741,0.835451,0.834741,0.834599
200,0.1859,0.404243,0.880579,0.880571,0.880579,0.880489
300,0.2009,0.364972,0.864897,0.865849,0.864897,0.864658
400,0.1383,0.458759,0.858866,0.861446,0.858866,0.858949
500,0.0552,0.610033,0.852835,0.858238,0.852835,0.85261
600,0.0988,0.465402,0.891435,0.89123,0.891435,0.891039
700,0.0115,0.496726,0.874548,0.874928,0.874548,0.873562
800,0.0091,0.494349,0.890229,0.890017,0.890229,0.889937
900,0.034,0.508053,0.892642,0.893233,0.892642,0.892211
1000,0.0138,0.505846,0.885404,0.885924,0.885404,0.885148


[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-100)... Done. 6.8s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-200)... Done. 8.0s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-300)... Done. 5.9s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-400)... Done. 5.5s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-500)... Done. 27.8s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-600)... Done. 14.6s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-700)... Done. 5.2s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-800)... Done. 8.8s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-900)... Done. 6.5s
[34m[1mwandb[0m: Adding directory to artifact (./vit-base-beans/checkpoint-1000)... Done. 14.0s
[34m[1mwandb[0m: Adding

***** train metrics *****
  epoch                    =         10.0
  total_flos               = 1796328377GF
  train_loss               =        0.087
  train_runtime            =   0:51:31.45
  train_samples_per_second =        8.051
  train_steps_per_second   =        0.505


In [79]:
metrics = trainer.evaluate(val_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)



***** eval metrics *****
  epoch                   =       10.0
  eval_accuracy           =     0.8649
  eval_f1_score           =     0.8647
  eval_loss               =      0.365
  eval_precision          =     0.8658
  eval_recall             =     0.8649
  eval_runtime            = 0:01:03.98
  eval_samples_per_second =     12.956
  eval_steps_per_second   =      1.625
