**Code reference:** https://www.appsloveworld.com/python/1383/how-to-extract-foreground-objects-from-coco-dataset-or-open-images-v6-dataset?expand_article=1

### **Preprocess COCO dataset**

In [1]:
import os
import cv2 as cv
import numpy as np

In [2]:
def extract_classwise_instances(samples, output_dir, label_field, ext=".png"):
    print("Extracting object instances...")
    for sample in samples.iter_samples(progress=True):
        img = cv.imread(sample.filepath)
        img_h, img_w, c = img.shape
        for det in sample[label_field].detections:
            mask = det.mask
            [x, y, w, h] = det.bounding_box
            x = int(x * img_w)
            y = int(y * img_h)
            h, w = mask.shape
            mask_img = img[y:y+h, x:x+w, :]
            alpha = mask.astype(np.uint8)*255
            alpha = np.expand_dims(alpha, 2)
            mask_img = np.concatenate((mask_img, alpha), axis=2)

            label = det.label
            label_dir = os.path.join(output_dir, label)
            if not os.path.exists(label_dir):
                os.mkdir(label_dir)
            output_filepath = os.path.join(label_dir, det.id+ext)
            cv.imwrite(output_filepath, mask_img)

In [3]:
def save_composite(samples, output_dir, label_field, ext=".png"):
    print("Saving composite images...")
    for sample in samples.iter_samples(progress=True):
        img = cv.imread(sample.filepath)
        img_h, img_w, c = img.shape
        output_filepath = output_dir

        counter = 0
        for i, det in enumerate(sample[label_field].detections):
            if counter > 0:
              break
            label = det.label
            label_dir = os.path.join(output_dir, label)
            if not os.path.exists(label_dir):
                os.mkdir(label_dir)
            output_filepath = os.path.join(label_dir, det.id+ext)
        cv.imwrite(output_filepath, img)

In [4]:
!pip install fiftyone
!pip install fiftyone-db-ubuntu2204

Collecting fiftyone
  Downloading fiftyone-0.22.0-py3-none-any.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles (from fiftyone)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting argcomplete (from fiftyone)
  Downloading argcomplete-3.1.2-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.5/41.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting boto3 (from fiftyone)
  Downloading boto3-1.28.57-py3-none-any.whl (135 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.8/135.8 kB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
Collecting dacite<1.8.0,>=1.6.0 (from fiftyone)
  Downloading dacite-1.7.0-py3-none-any.whl (12 kB)
Collecting Deprecated (from fiftyone)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting ftfy (from fiftyone)
  Downloading ftfy-6.1.1-py3-none-any.whl (53 

In [5]:
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F

Migrating database to v0.22.0


INFO:fiftyone.migrations.runner:Migrating database to v0.22.0


In [6]:
dataset_name = "coco-image-example"
if dataset_name in fo.list_datasets():
    fo.delete_dataset(dataset_name)

In [7]:
label_field = "ground_truth"
classes = ["horse", "airplane"]

In [8]:
dataset = foz.load_zoo_dataset(
    "coco-2017",
    split="train",
    label_types=["segmentations"],
    classes=classes,
    max_samples=500,
    shuffle=True,
    label_field=label_field,
    dataset_name=dataset_name,
)

Downloading split 'train' to '/root/fiftyone/coco-2017/train' if necessary


INFO:fiftyone.zoo.datasets:Downloading split 'train' to '/root/fiftyone/coco-2017/train' if necessary


Downloading annotations to '/root/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


INFO:fiftyone.utils.coco:Downloading annotations to '/root/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


 100% |██████|    1.9Gb/1.9Gb [5.0s elapsed, 0s remaining, 418.6Mb/s]       


INFO:eta.core.utils: 100% |██████|    1.9Gb/1.9Gb [5.0s elapsed, 0s remaining, 418.6Mb/s]       


Extracting annotations to '/root/fiftyone/coco-2017/raw/instances_train2017.json'


INFO:fiftyone.utils.coco:Extracting annotations to '/root/fiftyone/coco-2017/raw/instances_train2017.json'


Downloading 500 images


INFO:fiftyone.utils.coco:Downloading 500 images


 100% |██████████████████| 500/500 [16.8s elapsed, 0s remaining, 28.7 images/s]      


INFO:eta.core.utils: 100% |██████████████████| 500/500 [16.8s elapsed, 0s remaining, 28.7 images/s]      


Writing annotations for 500 downloaded samples to '/root/fiftyone/coco-2017/train/labels.json'


INFO:fiftyone.utils.coco:Writing annotations for 500 downloaded samples to '/root/fiftyone/coco-2017/train/labels.json'


Dataset info written to '/root/fiftyone/coco-2017/info.json'


INFO:fiftyone.zoo.datasets:Dataset info written to '/root/fiftyone/coco-2017/info.json'


Loading 'coco-2017' split 'train'


INFO:fiftyone.zoo.datasets:Loading 'coco-2017' split 'train'


 100% |█████████████████| 500/500 [5.2s elapsed, 0s remaining, 106.6 samples/s]      


INFO:eta.core.utils: 100% |█████████████████| 500/500 [5.2s elapsed, 0s remaining, 106.6 samples/s]      


Dataset 'coco-image-example' created


INFO:fiftyone.zoo.datasets:Dataset 'coco-image-example' created


In [9]:
view = dataset.filter_labels(label_field, F("label").is_in(classes))

In [10]:
foreground_output_dir = "/data/foreground_dataset"
composite_output_dir = "/data/composite_dataset"
os.makedirs(foreground_output_dir, exist_ok=True)
os.makedirs(composite_output_dir, exist_ok=True)

In [11]:
extract_classwise_instances(view, foreground_output_dir, label_field)

Extracting object instances...
 100% |█████████████████| 500/500 [8.1s elapsed, 0s remaining, 59.6 samples/s]       


INFO:eta.core.utils: 100% |█████████████████| 500/500 [8.1s elapsed, 0s remaining, 59.6 samples/s]       


In [12]:
save_composite(view, composite_output_dir, label_field)

Saving composite images...
 100% |█████████████████| 500/500 [12.5s elapsed, 0s remaining, 40.3 samples/s]      


INFO:eta.core.utils: 100% |█████████████████| 500/500 [12.5s elapsed, 0s remaining, 40.3 samples/s]      


### **Turn preprocessed images into a custom dataset**

In [13]:
import torch
from torchvision import transforms, datasets

In [17]:
data_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                          std=[0.229, 0.224, 0.225]),
                                     transforms.Resize([400, 600])])

In [18]:
foreground_dataset = datasets.ImageFolder(root='/data/foreground_dataset',
                                          transform=data_transform)

In [19]:
fore_dataset_loader = torch.utils.data.DataLoader(foreground_dataset,
                                                  batch_size=4, shuffle=True,
                                                  num_workers=1)

In [20]:
composite_dataset = datasets.ImageFolder(root='/data/composite_dataset',
                                         transform=data_transform)

In [21]:
generator = torch.Generator().manual_seed(42)
composite_train, composite_test = torch.utils.data.random_split(composite_dataset,
                                                                [int(len(composite_dataset)*0.8),
                                                                 int(len(composite_dataset)*0.2)],
                                                                generator=generator)


In [22]:
composite_dataset_loader = torch.utils.data.DataLoader(composite_dataset,
                                                       batch_size=4, shuffle=True,
                                                       num_workers=1)

In [23]:
print(fore_dataset_loader.dataset.classes)
print(composite_dataset_loader.dataset.classes)

['airplane', 'horse']
['airplane', 'horse']


**Code reference:** https://debuggercafe.com/training-resnet18-from-scratch-using-pytorch/

### **Build ResNet-18**

In [24]:
import torch.nn as nn

from torch import Tensor
from typing import Type

In [25]:
class BasicBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expansion: int = 1,
        downsample: nn.Module = None
    ) -> None:
        super(BasicBlock, self).__init__()
        # Multiplicative factor for the subsequent conv2d layer's output channels
        # It is 1 for ResNet18 and ResNet34
        self.expansion = expansion
        self.downsample = downsample
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels*self.expansion,
            kernel_size=3,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return  out

In [26]:
class ResNet(nn.Module):
    def __init__(
        self,
        img_channels: int,
        num_layers: int,
        block: Type[BasicBlock],
        num_classes: int  = 1000
    ) -> None:
        super(ResNet, self).__init__()
        if num_layers == 18:
            # The following `layers` list defines the number of `BasicBlock`
            # to use to build the network and how many basic blocks to stack together
            layers = [2, 2, 2, 2]
            self.expansion = 1

        self.in_channels = 64
        # All ResNets (18 to 152) contain a Conv2d => BN => ReLU for the first
        # three layers. Here, kernel size is 7
        self.conv1 = nn.Conv2d(
            in_channels=img_channels,
            out_channels=self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*self.expansion, num_classes)

    def _make_layer(
        self,
        block: Type[BasicBlock],
        out_channels: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        downsample = None
        if stride != 1:
            """
            This should pass from `layer2` to `layer4` or
            when building ResNets50 and above. Section 3.3 of the paper
            Deep Residual Learning for Image Recognition
            (https://arxiv.org/pdf/1512.03385v1.pdf).
            """
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels*self.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(out_channels * self.expansion),
            )
        layers = []
        layers.append(
            block(
                self.in_channels, out_channels, stride, self.expansion, downsample
            )
        )
        self.in_channels = out_channels * self.expansion

        for i in range(1, blocks):
            layers.append(block(
                self.in_channels,
                out_channels,
                expansion=self.expansion
            ))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # The spatial dimension of the final layer's feature
        # map should be (7, 7) for all ResNets
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

### **Utils**

In [27]:
import matplotlib.pyplot as plt
import os

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

plt.style.use('ggplot')

In [28]:
def save_plots(train_acc, valid_acc, train_loss, valid_loss, name=None):
    """
    Function to save the loss and accuracy plots to disk.
    """
    # Accuracy plots
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_acc, color='tab:blue', linestyle='-',
        label='train accuracy'
    )
    plt.plot(
        valid_acc, color='tab:red', linestyle='-',
        label='validataion accuracy'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    if not os.path.exists('outputs'):
        os.makedirs('outputs')
    plt.savefig(os.path.join('outputs', name+'_accuracy.png'))

    # Loss plots
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss, color='tab:blue', linestyle='-',
        label='train loss'
    )
    plt.plot(
        valid_loss, color='tab:red', linestyle='-',
        label='validataion loss'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join('outputs', name+'_loss.png'))

### **Set up for training**

In [29]:
from tqdm import tqdm

In [30]:
def train(model, trainloader, optimizer, criterion, device):
    model.train()
    print('Training...')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
        counter += 1
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(image)
        # Calculate loss
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # Calculate accuracy
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        # Backpropagation
        loss.backward()
        # Update weights
        optimizer.step()

    # Loss & accuracy for the complete epoch
    epoch_loss = train_running_loss / counter
    # epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc

In [31]:
"""
def validate(model, testloader, criterion, device):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0

    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            counter += 1

            image, labels = data
            image = image.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(image)
            # Calculate loss
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # Calculate accuracy
            _, preds = torch.max(outputs.data, 1)
            valid_running_correct += (preds == labels).sum().item()

    # Loss & accuracy for the complete epoch
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    return epoch_loss, epoch_acc
"""

"\ndef validate(model, testloader, criterion, device):\n    model.eval()\n    print('Validation')\n    valid_running_loss = 0.0\n    valid_running_correct = 0\n    counter = 0\n\n    with torch.no_grad():\n        for i, data in tqdm(enumerate(testloader), total=len(testloader)):\n            counter += 1\n\n            image, labels = data\n            image = image.to(device)\n            labels = labels.to(device)\n            # Forward pass\n            outputs = model(image)\n            # Calculate loss\n            loss = criterion(outputs, labels)\n            valid_running_loss += loss.item()\n            # Calculate accuracy\n            _, preds = torch.max(outputs.data, 1)\n            valid_running_correct += (preds == labels).sum().item()\n\n    # Loss & accuracy for the complete epoch\n    epoch_loss = valid_running_loss / counter\n    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))\n    return epoch_loss, epoch_acc\n"

### **Training**

In [32]:
import torch.optim as optim
import numpy as np
import random

In [33]:
# Set seed
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
np.random.seed(seed)
random.seed(seed)

In [34]:
epochs = 20
batch_size = 64
learning_rate = 0.01
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [35]:
model = ResNet(img_channels=3, num_layers=18, block=BasicBlock, num_classes=2).to(device)
plot_name = 'ResNet-18'

In [36]:
# Total parameters & trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Loss function
criterion = nn.CrossEntropyLoss()

11,177,538 total parameters.
11,177,538 training parameters.


In [37]:
# Lists to keep track of losses & accuracies
train_loss = []
train_acc = []

# Start training
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model,
                                              composite_dataset_loader,
                                              optimizer,
                                              criterion,
                                              device)
    train_loss.append(train_epoch_loss)
    train_acc.append(train_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print('-'*50)

# Save the loss & accuracy plots
# save_plots(train_acc, valid_acc, train_loss, valid_loss, name=plot_name)
# print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 20
Training...


100%|██████████| 125/125 [00:17<00:00,  7.13it/s]

Training loss: 0.730, training acc: 67.200
--------------------------------------------------
[INFO]: Epoch 2 of 20
Training...



100%|██████████| 125/125 [00:10<00:00, 12.45it/s]

Training loss: 0.731, training acc: 68.000
--------------------------------------------------
[INFO]: Epoch 3 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.53it/s]

Training loss: 0.568, training acc: 72.600
--------------------------------------------------
[INFO]: Epoch 4 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.61it/s]

Training loss: 0.552, training acc: 73.600
--------------------------------------------------
[INFO]: Epoch 5 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.73it/s]

Training loss: 0.553, training acc: 76.000
--------------------------------------------------
[INFO]: Epoch 6 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.65it/s]


Training loss: 0.505, training acc: 75.800
--------------------------------------------------
[INFO]: Epoch 7 of 20
Training...


100%|██████████| 125/125 [00:09<00:00, 12.66it/s]

Training loss: 0.524, training acc: 76.200
--------------------------------------------------
[INFO]: Epoch 8 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.69it/s]

Training loss: 0.458, training acc: 79.800
--------------------------------------------------
[INFO]: Epoch 9 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.65it/s]

Training loss: 0.468, training acc: 80.200
--------------------------------------------------
[INFO]: Epoch 10 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.68it/s]

Training loss: 0.452, training acc: 80.400
--------------------------------------------------
[INFO]: Epoch 11 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.63it/s]

Training loss: 0.439, training acc: 79.200
--------------------------------------------------
[INFO]: Epoch 12 of 20
Training...



100%|██████████| 125/125 [00:10<00:00, 12.45it/s]

Training loss: 0.476, training acc: 79.000
--------------------------------------------------
[INFO]: Epoch 13 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.75it/s]

Training loss: 0.429, training acc: 82.600
--------------------------------------------------
[INFO]: Epoch 14 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.61it/s]


Training loss: 0.399, training acc: 82.200
--------------------------------------------------
[INFO]: Epoch 15 of 20
Training...


100%|██████████| 125/125 [00:09<00:00, 12.68it/s]

Training loss: 0.377, training acc: 84.400
--------------------------------------------------
[INFO]: Epoch 16 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.71it/s]

Training loss: 0.365, training acc: 84.400
--------------------------------------------------
[INFO]: Epoch 17 of 20
Training...



100%|██████████| 125/125 [00:10<00:00, 12.46it/s]

Training loss: 0.360, training acc: 84.800
--------------------------------------------------
[INFO]: Epoch 18 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.55it/s]

Training loss: 0.360, training acc: 84.200
--------------------------------------------------
[INFO]: Epoch 19 of 20
Training...



100%|██████████| 125/125 [00:10<00:00, 12.29it/s]

Training loss: 0.367, training acc: 85.200
--------------------------------------------------
[INFO]: Epoch 20 of 20
Training...



100%|██████████| 125/125 [00:09<00:00, 12.74it/s]

Training loss: 0.370, training acc: 84.000
--------------------------------------------------



