In [None]:
# detection
# https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
# https://chatgpt.com/c/67fa0c48-7acc-8003-bb1b-966899fcc1a4
# pip install tqdm pandas pillow torch torchvision

from glob import glob
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import pandas as pd

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn

### Data preparing


In [2]:
CATEGORIES_FN = "../data/UECFOOD256/category.txt"

with open(CATEGORIES_FN, "r") as file:
    data = file.readlines()

    id2name = {}
    for i in range(1, len(data)):
        category_id, category_name = data[i].split("\t")
        label_id, name = int(category_id), category_name.strip()
        id2name[label_id] = name

id2name

{1: 'rice',
 2: 'eels on rice',
 3: 'pilaf',
 4: "chicken-'n'-egg on rice",
 5: 'pork cutlet on rice',
 6: 'beef curry',
 7: 'sushi',
 8: 'chicken rice',
 9: 'fried rice',
 10: 'tempura bowl',
 11: 'bibimbap',
 12: 'toast',
 13: 'croissant',
 14: 'roll bread',
 15: 'raisin bread',
 16: 'chip butty',
 17: 'hamburger',
 18: 'pizza',
 19: 'sandwiches',
 20: 'udon noodle',
 21: 'tempura udon',
 22: 'soba noodle',
 23: 'ramen noodle',
 24: 'beef noodle',
 25: 'tensin noodle',
 26: 'fried noodle',
 27: 'spaghetti',
 28: 'Japanese-style pancake',
 29: 'takoyaki',
 30: 'gratin',
 31: 'sauteed vegetables',
 32: 'croquette',
 33: 'grilled eggplant',
 34: 'sauteed spinach',
 35: 'vegetable tempura',
 36: 'miso soup',
 37: 'potage',
 38: 'sausage',
 39: 'oden',
 40: 'omelet',
 41: 'ganmodoki',
 42: 'jiaozi',
 43: 'stew',
 44: 'teriyaki grilled fish',
 45: 'fried fish',
 46: 'grilled salmon',
 47: 'salmon meuniere',
 48: 'sashimi',
 49: 'grilled pacific saury',
 50: 'sukiyaki',
 51: 'sweet and sour

In [3]:
images = glob("../data/UECFOOD256/*/*.jpg")
boxes_files = glob("../data/UECFOOD256/*/bb_info.txt")


ids, boxes = [], []
for j in range(len(boxes_files)):
    with open(boxes_files[j], "r") as file:
        data = file.readlines()

        for i in range(1, len(data)):
            line = list(map(int, data[i].split()))
            ids.append(line[0])
            boxes.append(line[1:5])


id2path = {int(Path(image).stem): image for image in images}
id2class = {int(Path(image).stem): int(Path(image).parts[-2]) for image in images}


df = pd.DataFrame({"image_id": ids, "bbox": boxes})
df["image_path"] = df["image_id"].apply(lambda x: id2path.get(x))
df["class"] = df["image_id"].apply(lambda x: id2class.get(x))

df

Unnamed: 0,image_id,bbox,image_path,class
0,190485,"[25, 1, 571, 398]",../data/UECFOOD256/135/190485.jpg,135
1,89365,"[29, 38, 406, 486]",../data/UECFOOD256/135/89365.jpg,135
2,89147,"[88, 49, 442, 325]",../data/UECFOOD256/135/89147.jpg,135
3,89414,"[10, 12, 493, 479]",../data/UECFOOD256/135/89414.jpg,135
4,190903,"[4, 70, 496, 472]",../data/UECFOOD256/135/190903.jpg,135
...,...,...,...,...
31640,16769,"[2, 3, 557, 414]",../data/UECFOOD256/25/16769.jpg,25
31641,16770,"[2, 46, 556, 375]",../data/UECFOOD256/25/16770.jpg,25
31642,16773,"[92, 41, 545, 472]",../data/UECFOOD256/25/16773.jpg,25
31643,16781,"[3, 8, 473, 355]",../data/UECFOOD256/25/16781.jpg,25


### Dataset class


In [4]:
class CustomObjectDetectionDataset(Dataset):
    def __init__(self):
        self.df = df
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return df.shape[0]

    def __getitem__(self, idx):
        r = df.iloc[idx]

        image = self.transform(Image.open(r["image_path"]).convert("RGB"))
        boxes = torch.Tensor(r["bbox"]).reshape(-1, 4)
        labels = torch.tensor([r["class"]], dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        return image, target

### Train


In [5]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [9]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

device = "cpu"
dataset = CustomObjectDetectionDataset()

dataloader = DataLoader(dataset, batch_size=5, shuffle=True, collate_fn=collate_fn)

model = fasterrcnn_resnet50_fpn(pretrained=True)
model.to(device)

num_classes = df["class"].max() + 1
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = (
    torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
)

model.train()

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

In [10]:
num_epochs = 1

for epoch in range(num_epochs):
    for images, targets in tqdm(dataloader):

        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        losses.backward()
        optimizer.step()

    print(f"Epoch #{epoch} Loss: {losses.item()}")

torch.save(model.state_dict(), "fasterrcnn.pth")

  0%|          | 5/6329 [02:32<53:45:09, 30.60s/it]


KeyboardInterrupt: 