<a href="https://colab.research.google.com/github/pitthexai/IEEE_BHI_2023_Tutorial_From_Few_to_None/blob/main/Code/FewShotSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install segmentation-models-pytorch --quiet

In [12]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os

from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils as smp_utils

from zipfile import ZipFile

In [None]:
! wget https://github.com/pitthexai/IEEE_BHI_2023_Tutorial_From_Few_to_None/raw/bcce5fd52b349659fb03fd065f9037e70acc83a9/SampleDataset/BHI_Segmentation.zip
! unzip /content/BHI_Segmentation.zip

In [21]:
# Dataset root produced by extracting BHI_Segmentation.zip into /content.
DATA_ROOT = "/content/BHI_Segmentation"
# One seed for the patient split, few-shot sampling, and the global numpy/torch RNGs.
RANDOM_STATE = 42

# Seed the global RNGs so model initialisation and DataLoader shuffling are repeatable.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

In [22]:
class JointSpaceSegmentationDataset(Dataset):
    """Paired image/mask dataset for joint-space segmentation.

    Args:
        img_root: directory containing the input images.
        mask_root: directory containing the annotation masks.
        image_files: indexable sequence of image file names relative to ``img_root``.
        mask_files: indexable sequence of mask file names relative to ``mask_root``;
            ``mask_files[i]`` is the annotation for ``image_files[i]``.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys.
        preprocessing: optional callable with the same calling convention, applied
            after ``transforms``.

    Each item is ``(image, mask)``:
        * ``image``: float32 tensor, channel-first; values are NOT rescaled
          (e.g. 0-255 for 8-bit files).
        * ``mask``: float32 tensor with a leading channel axis, divided by 255.
          NOTE(review): assumes masks are stored as 0/255 -- confirm against the data.
    """

    def __init__(self, img_root, mask_root, image_files, mask_files, transforms=None, preprocessing=None):
        self.img_root = img_root
        self.mask_root = mask_root
        self.img_files = image_files
        self.mask_files = mask_files
        self.transforms = transforms
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.img_files)

    def _load_array(self, root, file_name):
        """Read ``root/file_name`` with PIL into a numpy array, closing the file afterwards."""
        with Image.open(os.path.join(root, file_name)) as img:
            return np.array(img)

    @staticmethod
    def _image_to_tensor(image):
        """Convert an (H, W) or (H, W, C) array to a (C, H, W) tensor; tensors pass through."""
        if torch.is_tensor(image):
            return image
        image = np.asarray(image)
        if image.ndim == 2:
            image = image[np.newaxis]  # (H, W) -> (1, H, W)
        else:
            image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        return torch.from_numpy(np.ascontiguousarray(image))

    def __getitem__(self, idx):
        image = self._load_array(self.img_root, self.img_files[idx])
        mask = self._load_array(self.mask_root, self.mask_files[idx])

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]
        # Run preprocessing on the un-expanded mask, before adding the channel axis.
        if self.preprocessing is not None:
            transformed = self.preprocessing(image=image, mask=mask)
            image, mask = transformed["image"], transformed["mask"]

        # Pipelines may or may not end in a tensor conversion; normalise here so the
        # dataset also works with transforms=None.
        image = self._image_to_tensor(image)
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)  # (H, W) -> (1, H, W) to match the model output

        return image.float(), mask.float() / 255.0

In [24]:
from sklearn.model_selection import train_test_split
def generate_datasets(root_dir):
    """Split the images under ``root_dir/Images`` into train/valid/test frames by patient.

    Each image's mask is assumed to share its file name (under ``root_dir/Annotations``),
    so the ``images`` and ``masks`` columns hold the same name. The patient id (``pid``)
    is the file stem minus its last character (presumably a side/visit suffix -- TODO
    confirm the naming scheme). Splitting is done on unique pids so no patient spans two
    splits: about half the patients go to train and the rest are halved into valid/test.

    Args:
        root_dir: dataset root containing an ``Images`` directory.

    Returns:
        (train, valid, test) DataFrames with columns ``pid``, ``images``, ``masks``
        and a fresh 0..n-1 index.
    """
    x_dir = os.path.join(root_dir, "Images")
    # os.listdir order is filesystem-dependent, and the order of pids fed to
    # train_test_split decides the split even with a fixed random_state -> sort.
    # Hidden files (e.g. ".DS_Store") are not images, so skip them.
    image_names = sorted(name for name in os.listdir(x_dir) if not name.startswith("."))
    records = [[name.split(".")[0][:-1], name, name] for name in image_names]

    data_records = pd.DataFrame(records, columns=["pid", "images", "masks"])

    train_ids, holdout_ids = train_test_split(data_records.pid.unique(), test_size=0.5, random_state=RANDOM_STATE)
    valid_ids, test_ids = train_test_split(holdout_ids, test_size=0.5, random_state=RANDOM_STATE)

    train = data_records[data_records.pid.isin(train_ids)].reset_index(drop=True)
    valid = data_records[data_records.pid.isin(valid_ids)].reset_index(drop=True)
    test = data_records[data_records.pid.isin(test_ids)].reset_index(drop=True)

    return train, valid, test

In [25]:
def get_few_shot_sample(dataset, k=1, random_state=RANDOM_STATE):
    """Return a reproducible random subset of ``k`` rows from ``dataset``.

    Args:
        dataset: DataFrame to sample from.
        k: number of rows to draw.
        random_state: seed forwarded to ``DataFrame.sample``.

    Returns:
        A new DataFrame of ``k`` sampled rows re-indexed 0..k-1, or, when ``k``
        exceeds ``len(dataset)``, the input ``dataset`` object itself, unchanged.
    """
    too_few_rows = len(dataset) < k
    if too_few_rows:
        return dataset
    subset = dataset.sample(k, random_state=random_state)
    return subset.reset_index(drop=True)

In [26]:
# Patient-level train/valid/test split of everything under DATA_ROOT.
train, valid, test = generate_datasets(DATA_ROOT)

In [27]:
# Few-shot subsets: at most 5 records each for training and validation.
train_few, valid_few = (get_few_shot_sample(split, k=5) for split in (train, valid))

In [28]:
def get_preprocessing_fn(encoder, encoder_weights):
    """Return the preprocessing function smp provides for ``encoder`` pretrained with ``encoder_weights``."""
    normalize = smp.encoders.get_preprocessing_fn(encoder, encoder_weights)
    return normalize

def to_tensor(x, **kwargs):
    """Convert an image array to channel-first float32 layout.

    Args:
        x: array-like of shape (H, W, C), or (H, W) for single-channel images.
        **kwargs: accepted and ignored.

    Returns:
        float32 numpy array of shape (C, H, W), or (1, H, W) for 2-D input.
    """
    x = np.asarray(x)
    if x.ndim == 2:
        # Grayscale (H, W): add the channel axis instead of failing in transpose(2, 0, 1).
        return x[np.newaxis].astype('float32')
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalization function (can be specific
            to each pretrained network); applied to the ``image`` target only.

    Returns:
        albumentations.Compose: ``preprocessing_fn`` on the image, followed by ``ToTensorV2``.
    """
    steps = [A.Lambda(image=preprocessing_fn), ToTensorV2()]
    return A.Compose(steps)

In [42]:
# Model / preprocessing configuration.
encoder, encoder_weights = "resnet34", "imagenet"
activation = "sigmoid"
num_classes = 1  # single output channel: joint space (1) vs. background (0)

preprocessing_fn = get_preprocessing_fn(encoder, encoder_weights)
# Resize every image/mask pair to 256x256, then convert to tensors.
augmentations = A.Compose([
    A.Resize(256, 256),
    ToTensorV2(),
])

In [38]:
def make_joint_space_dataset(records, transforms=None, preprocessing=None):
    """Build a JointSpaceSegmentationDataset over the rows of ``records``.

    Args:
        records: DataFrame with ``images`` and ``masks`` file-name columns
            (as returned by ``generate_datasets`` / ``get_few_shot_sample``).
        transforms: optional transform pipeline applied to each image/mask pair.
        preprocessing: optional preprocessing pipeline forwarded to the dataset.
    """
    return JointSpaceSegmentationDataset(
        os.path.join(DATA_ROOT, "Images"),
        os.path.join(DATA_ROOT, "Annotations"),
        records.images,
        records.masks,
        transforms=transforms,
        preprocessing=preprocessing,
    )


# Few-shot train/valid subsets; the test set uses every held-out record.
train_set = make_joint_space_dataset(train_few, transforms=augmentations)
valid_set = make_joint_space_dataset(valid_few, transforms=augmentations)
test_set = make_joint_space_dataset(test, transforms=augmentations)

In [43]:
import copy
# U-Net with a pretrained encoder, adapted to single-channel input and
# `num_classes` output channels passed through `activation`.
model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=encoder_weights,
    in_channels=1,
    classes=num_classes,
    activation=activation,
)
# `requires_grad_` is a method and must be called. Assigning to it would only shadow
# the method with a bool and leave the encoder parameters untouched.
model.encoder.requires_grad_(True)

# BCELoss expects probabilities; the model's activation is "sigmoid" (config cell).
loss = nn.BCELoss()
loss.__name__ = "loss"  # presumably used by smp's epoch runners as the log key -- verify
# Pass `threshold` by keyword: Fscore's first positional parameter is `beta`, so
# Fscore(0.5) would compute F0.5 instead of F1 thresholded at 0.5.
metrics = [smp_utils.metrics.IoU(threshold=0.5), smp_utils.metrics.Fscore(threshold=0.5)]
optimizer = torch.optim.Adam(model.parameters(), lr=5e-04)

In [44]:
# Batch size 1 for both loaders; only the training loader reshuffles each epoch.
train_loader = DataLoader(dataset=train_set, batch_size=1, num_workers=2, shuffle=True)
valid_loader = DataLoader(dataset=valid_set, batch_size=1, num_workers=2, shuffle=False)

In [45]:
# Epoch runners: each `.run(loader)` makes one pass over a DataLoader and returns a
# logs dict (the training loop below reads its 'iou_score' entry).
train_epoch = smp_utils.train.TrainEpoch(
    model,
    device="cpu",
    verbose=True,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
)

# Same runner without an optimizer, used for validation passes.
valid_epoch = smp_utils.train.ValidEpoch(
    model,
    device="cpu",
    verbose=True,
    loss=loss,
    metrics=metrics,
)

In [None]:
# Train for NUM_EPOCHS epochs, checkpointing whenever validation IoU improves,
# then restore the best-validation weights into `model`.
import copy

NUM_EPOCHS = 49  # epochs are numbered 1..NUM_EPOCHS

# -inf so the first epoch is always checkpointed, even if its IoU is 0.
max_score = float("-inf")
best_state = None

for i in range(1, NUM_EPOCHS + 1):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Keep the model with the best validation IoU seen so far.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        best_state = copy.deepcopy(model.state_dict())
        print('Model saved!')

# Later cells evaluate `model` directly, so load the best weights back instead of
# keeping whatever the final epoch produced.
if best_state is not None:
    model.load_state_dict(best_state)


In [47]:
test_loader = DataLoader(dataset=test_set, batch_size=1, num_workers=2, shuffle=False)
# Pull one batch as a quick sanity check that the test pipeline loads.
test_img, test_mask = next(iter(test_loader))

In [None]:
# Evaluate on the full test split: mean IoU and F1 per batch (= per image at batch_size=1).
iou_metric = smp_utils.metrics.IoU(threshold=0.5)
fscore_metric = smp_utils.metrics.Fscore(threshold=0.5)

n_batches = len(test_loader)
if n_batches == 0:
    raise ValueError("test_loader yielded no batches; nothing to evaluate")

# Keep inputs on the same device as the model weights.
device = next(model.parameters()).device
model.eval()  # inference behaviour for batch-norm/dropout layers
total_iou = 0.0
total_fscore = 0.0
with torch.no_grad():  # no autograd graph needed for evaluation
    for img, mask in test_loader:
        img, mask = img.to(device), mask.to(device)
        out = model(img)
        total_iou += iou_metric(out, mask).item()
        total_fscore += fscore_metric(out, mask).item()

avg_iou = total_iou / n_batches
avg_fscore = total_fscore / n_batches
print(avg_iou, avg_fscore)