In [12]:
import agml
import numpy as np

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split

In [13]:
class MyDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset, batch_size=32, num_workers=4):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Setup function can be used to split dataset if needed
        # For example, you can split a single dataset into train/val/test here
        pass

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [14]:
class AgMLDatasetAdaptor(object):
    """Adapts an AgML dataset for use in a `LightningDataModule`."""

    def __init__(self, loader, adapt_class = False):
        self.loader = loader
        self.adapt_class = adapt_class

    def __len__(self) -> int:
        return len(self.loader)

    def get_image_and_labels_by_idx(self, index):
        image, annotation = self.loader[index]
        image = Image.fromarray(image)
        bboxes = np.array(annotation['bbox']).astype(np.int32)
        x_min = bboxes[:, 0]
        y_min = bboxes[:, 1]
        x_max = bboxes[:, 2] + x_min
        y_max = bboxes[:, 3] + y_min
        x_min, y_min = np.clip(x_min, 0, image.width), np.clip(y_min, 0, image.height)
        x_max, y_max = np.clip(x_max, 0, image.width), np.clip(y_max, 0, image.height)
        bboxes = np.dstack((x_min, y_min, x_max, y_max)).squeeze(axis = 0)
        class_labels = np.array(annotation['category_id']).squeeze()
        if self.adapt_class:
            class_labels = np.ones_like(class_labels)
        return image, bboxes, class_labels, index

In [15]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18
import pytorch_lightning as pl

class LitDetectorModel(pl.LightningModule):
    
    def __init__(self, num_classes: int = 1, learning_rate: float = 2e-4):
        """Simple Classification model built with PyTorch Lightning.

        Args:
            num_classes (int, optional): Number of classes. Defaults to 1.
            hidden_dim (int, optional): Number of hidden layers. Defaults to 64.
            learning_rate (float, optional): Rate at which to adjust model weights. Defaults to 2e-4.
        """
        # update and save hyperparameters
        super().__init__()
        
        # define properties
        self.save_hyperparameters()
        self.hparams.num_classes = num_classes
        self.backbone = resnet18(pretrained=False)
        self.backbone.fc = nn.Identity()
        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 4 + self.hparams.num_classes)  # 4 for bbox coordinates and 1 for class
        )
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.fc(x)
        
    def training_step(self, batch):
        images, targets = batch
        outputs = self(images)
        loss = self.compute_loss(outputs, targets)
        
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch):
        images, targets = batch
        outputs = self(images)
        loss = self.compute_loss(outputs, targets)
        
        self.log('val_loss', loss)
        return loss
    
    def test_step(self, batch):
        images, targets = batch
        outputs = self(images)
        loss = self.compute_loss(outputs, targets)
        
        self.log('test_loss', loss)
        return loss
    
    def compute_loss(self, outputs, targets):
        bboxes_preds = outputs[:, :4]
        class_preds = outputs[:, 4:]
        bboxes_targets = targets[:, :4]
        class_targets = targets[:, 4].long()

        bbox_loss = nn.MSELoss()(bboxes_preds, bboxes_targets)
        class_loss = nn.CrossEntropyLoss()(class_preds, class_targets)
        total_loss = bbox_loss + class_loss
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [18]:
loader = agml.data.AgMLDataLoader(
    dataset='ghai_iceberg_lettuce_detection',
    dataset_path='/data2/eranario/data/AgML'
)
loader.shuffle()
loader.split(train=0.8, val=0.1, test=0.1)

dm = MyDataModule(
    train_dataset = AgMLDatasetAdaptor(loader.train_data),
    val_dataset = AgMLDatasetAdaptor(loader.val_data),
    test_dataset = AgMLDatasetAdaptor(loader.test_data),
    num_workers = 4, batch_size = 4)
model = LitDetectorModel(num_classes = 1, learning_rate = 2e-4)
trainer = pl.Trainer(max_epochs=1, default_root_dir='/data2/eranario/intermediate_data/Active-Learning/AgML_logs/tests')
trainer.fit(model, dm)

/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/torchvision/models/_utils.py:208: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/torchvision/models/_utils.py:223: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.
06-06-2024 22:29:51 INFO - pytorch_lightning.utilities.rank_zero: GPU available: True (cuda), used: True
06-06-2024 22:29:51 INFO - pytorch_lightning.utilities.rank_zero: TPU available: False, using: 0 TPU cores
06-06-2024 22:29:51 INFO - pytorch_lightning.utilities.rank_zero: IPU available: False, using: 0 IPUs
06-06-2024 22:29:51 INFO - pytorch_lightning.utilities.rank_zero: HPU available: False, using: 0 HPUs
06-06-2024 22:29:51 INFO - pytorch_lightning.accelerators.cuda: LOCAL_RANK: 0 - CUDA_VISI

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/eranario/miniconda3/envs/lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
TypeError: 'AgMLDatasetAdaptor' object is not subscriptable
