In [24]:

from tqdm import tqdm
import torch
from torch import nn
from torchvision.ops import box_convert
import yaml
from transformers import OwlViTProcessor, OwlViTForObjectDetection

In [7]:
from src.dataset import get_dataloaders
from src.losses import ContrastiveDetectionLoss

In [8]:
import importlib
import src.dataset

In [9]:
# Pick up edits made to src/dataset.py without restarting the kernel.
import src.dataset
importlib.reload(src.dataset)
# reload() re-executes the module but does not touch names that were copied out
# of it earlier with `from src.dataset import ...`; rebind get_dataloaders so
# later cells call the reloaded implementation.
from src.dataset import get_dataloaders

<module 'src.dataset' from '/scratch/sd5251/cap/OWL4PACO/OWL-ViT-Object-Detection/src/dataset.py'>

In [10]:
# Pick up edits made to src/utils.py without restarting the kernel.
# reload() re-executes the module in place; the cell displays the reloaded module object.
import src.utils
importlib.reload(src.utils)

<module 'src.utils' from '/scratch/sd5251/cap/OWL4PACO/OWL-ViT-Object-Detection/src/utils.py'>

In [11]:
# Pick up edits made to src/losses.py without restarting the kernel.
# Import the module that is actually reloaded (this cell previously imported
# src.utils, so it only worked when src.losses had already been loaded elsewhere).
import src.losses
importlib.reload(src.losses)
# Rebind the class so `ContrastiveDetectionLoss` refers to the reloaded definition.
from src.losses import ContrastiveDetectionLoss

In [12]:
# Run on the first CUDA device when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device

'cuda:0'

In [15]:
# Detection model from the "google/owlvit-base-patch32" checkpoint, moved to `device`.
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
# Matching processor for the same checkpoint (image processor + text tokenizer).
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

In [16]:
# Build the train/test loaders from the project helper in src.dataset.
train_dataloader, test_dataloader = get_dataloaders(
    4,  # NOTE(review): presumably the batch size — confirm against src/dataset.py
    processor,
)

In [17]:
def get_training_config(config_path="config.yaml"):
    """Return the ``training`` section of a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``"config.yaml"`` in the current working directory, which is
            what this function always read before the parameter existed.

    Returns:
        Whatever value is stored under the top-level ``training`` key.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If the file is empty, is not a mapping at the top level,
            or has no ``training`` key.
    """
    # YAML files are UTF-8 by convention; do not depend on the locale default.
    with open(config_path, "r", encoding="utf-8") as stream:
        data = yaml.safe_load(stream)

    # safe_load returns None for an empty document; report that the same way
    # as a missing section instead of failing with an opaque TypeError.
    if not isinstance(data, dict) or "training" not in data:
        raise KeyError(f"'training' section not found in {config_path}")
    return data["training"]

In [18]:
# Read the `training` section of config.yaml once; later cells take
# learning_rate, weight_decay and n_epochs from this object.
training_cfg = get_training_config()

In [19]:
# Project loss from src.losses, built with its default arguments. The training
# loop calls it as criterion(logits, pred_boxes, boxes, target_labels, metadata).
criterion = ContrastiveDetectionLoss()

In [20]:
# AdamW over every model parameter, with hyperparameters taken from config.yaml.
# Both values go through float(): PyYAML resolves exponent notation written
# without a decimal point (e.g. `1e-4`) as a string, which AdamW cannot use.
# Previously only learning_rate was converted.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=float(training_cfg["learning_rate"]),
    weight_decay=float(training_cfg["weight_decay"]),
)

In [25]:
# One progress-bar tick per optimizer step: batches per epoch times epochs.
num_epochs = training_cfg["n_epochs"]
num_training_steps = len(train_dataloader) * num_epochs
# The bar is advanced manually with .update(1) in the training cell.
progress_bar = tqdm(range(num_training_steps))

  0%|          | 0/100 [00:00<?, ?it/s]

In [25]:
# Fine-tuning loop: one optimizer step per batch of `train_dataloader`.
# `num_epochs` and `progress_bar` both come from the preceding cell, so the
# number of steps taken here always matches the total the bar was sized with.
# Re-run that cell before re-running this one to get a fresh bar.
model.train()
for epoch in range(num_epochs):
    for inputs, target_labels, boxes, metadata in train_dataloader:
        optimizer.zero_grad()

        # Reshape the text inputs into 2-D rows of 16 token ids each.
        # NOTE(review): 16 is presumably the tokenizer's padded query length —
        # confirm against the processor settings used in src/dataset.py.
        inputs['input_ids'] = inputs['input_ids'].view(-1, 16)
        inputs['attention_mask'] = inputs['attention_mask'].view(-1, 16)
        inputs = inputs.to(device)

        outputs = model(**inputs)
        logits = outputs["logits"]
        pred_boxes = outputs["pred_boxes"]

        # Not used by the loop itself; kept because a scratch cell further
        # down the notebook reads `batch_size`.
        batch_size = boxes.shape[0]

        # Targets must live on the same device as the model outputs.
        target_labels = target_labels.to(device)
        boxes = boxes.to(device)

        loss = criterion(logits, pred_boxes, boxes, target_labels, metadata)
        loss.backward()
        optimizer.step()

        progress_bar.update(1)
        progress_bar.set_description(f"Loss: {loss.item():.3f}")

Loss: 473.836: 100%|██████████| 100/100 [00:45<00:00,  2.58it/s]

In [None]:
# Scratch check: shape of `pred_boxes` (left over from the last training batch)
# after slicing its third axis from index 2 onward.
pred_boxes[:, :, 2:].shape

In [None]:
# Scratch: build a one-hot row for class index 0 and tile it once per batch item,
# giving a (batch_size, 1, num_queries) integer tensor on `device`.
# NOTE(review): `num_queries` is not defined in any visible cell and `batch_size`
# is whatever the last training batch left behind — confirm before relying on this.
target_labels = (
    nn.functional.one_hot(torch.zeros(1, dtype=torch.int64), num_classes=num_queries)
    .to(device)
    .repeat(batch_size, 1, 1)
)

In [None]:
# Scratch check: shape of the token-id tensor from the last training batch,
# i.e. after the training cell reshaped it with view(-1, 16).
inputs["input_ids"].shape

In [2]:
from datetime import datetime

In [6]:
# Current local time as a compact YYYYMMDD_HHMM string (no seconds, no timezone).
datetime.now().strftime("%Y%m%d_%H%M")

'20231116_0021'

In [7]:
# Run the project's main.py as a shell command; its tqdm and wandb output is
# streamed into this cell.
! python3 main.py

  0%|                                                  | 0/4500 [00:00<?, ?it/s]wandb: Currently logged in as: sharad-dargan (a-is-all-we-need). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.0
wandb: Run data is saved locally in /scratch/sd5251/cap/OWL4PACO/OWL-ViT-Object-Detection/wandb/run-20231120_215719-nulxo1xy
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fancy-valley-4
wandb: ⭐️ View project at https://wandb.ai/a-is-all-we-need/owl-vit
wandb: 🚀 View run at https://wandb.ai/a-is-all-we-need/owl-vit/runs/nulxo1xy
Loss: 0.039:   8%|██▎                        | 375/4500 [05:17<58:34,  1.17it/s]Loss: 0.088, Focal Loss: 0.224, BBox Loss: 0.002, GIOU Loss: -0.002
Loss: 0.077:  17%|████▌                      | 750/4500 [10:07<49:58,  1.25it/s]Loss: 0.025