In [1]:
!pip install torch torchvision



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
batch_size = 32


In [4]:
# trasform - download data - dataloader

transform = transforms.Compose([
    transforms.Resize(size=(32,32)),
    transforms.ToTensor(),
])

transform_target = transforms.Compose([
    transforms.Resize(size=(32,32)),
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.OxfordIIITPet(
    root='./data',
    split='trainval',
    target_types='segmentation',
    transform = transform,
    target_transform = transform_target,
    download=True,
    )

test_dataset = torchvision.datasets.OxfordIIITPet(
    root='./data',
    split='test',
    target_types='segmentation',
    transform=transform,
    target_transform = transform_target,
    download=True,
    )


# DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


Downloading https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz to data/oxford-iiit-pet/images.tar.gz


100%|██████████| 791918971/791918971 [00:37<00:00, 20952064.76it/s]


Extracting data/oxford-iiit-pet/images.tar.gz to data/oxford-iiit-pet
Downloading https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz to data/oxford-iiit-pet/annotations.tar.gz


100%|██████████| 19173078/19173078 [00:02<00:00, 8174607.51it/s] 


Extracting data/oxford-iiit-pet/annotations.tar.gz to data/oxford-iiit-pet


In [5]:
# model U-NET
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)
model = model.to(device)

Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/zipball/master" to /root/.cache/torch/hub/master.zip
Downloading: "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt" to /root/.cache/torch/hub/checkpoints/unet-e012d006.pt


In [6]:
# loss and opt

# CE / Focal Loss (better for unlances dataset) / IoU Loss (Jaccard)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [10]:
# training loop

model.train()
for epoch in range(num_epochs):
    for idx, (image, label) in enumerate(train_loader):
        image, label = image.to(device), label.to(device)

        # forward
        output = model(image)

        # loss - backward - op
        label = label*255 # to remove the ToTensor operation
        label = (label==1).float() # 1 foreground all the other points are back
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (idx+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')




Epoch [1/10], Step [100/230], Loss: 0.6559
Epoch [1/10], Step [200/230], Loss: 0.6105
Epoch [2/10], Step [100/230], Loss: 0.6570
Epoch [2/10], Step [200/230], Loss: 0.6336
Epoch [3/10], Step [100/230], Loss: 0.6435
Epoch [3/10], Step [200/230], Loss: 0.6446
Epoch [4/10], Step [100/230], Loss: 0.6355
Epoch [4/10], Step [200/230], Loss: 0.6371
Epoch [5/10], Step [100/230], Loss: 0.6517
Epoch [5/10], Step [200/230], Loss: 0.6709
Epoch [6/10], Step [100/230], Loss: 0.6409
Epoch [6/10], Step [200/230], Loss: 0.6401
Epoch [7/10], Step [100/230], Loss: 0.6325
Epoch [7/10], Step [200/230], Loss: 0.6376
Epoch [8/10], Step [100/230], Loss: 0.6548
Epoch [8/10], Step [200/230], Loss: 0.6700
Epoch [9/10], Step [100/230], Loss: 0.6228
Epoch [9/10], Step [200/230], Loss: 0.6193
Epoch [10/10], Step [100/230], Loss: 0.6245
Epoch [10/10], Step [200/230], Loss: 0.6066


In [23]:
# eval

# accuracy / IoU / mAP

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for idx, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        labels = labels*255
        labels = (labels==1).float()

        outputs = (outputs > 0.5).float()
        bs, c, h , w = labels.size()
        total += bs*h*w*c
        correct += (outputs == labels).sum().item()

accuracy = correct / total
print(f'Accuracy on the test set: {100 * accuracy:.2f}%')

Accuracy on the test set: 90.07%
