In [2]:
import numpy as np
import torch

# Run on the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

device

device(type='cuda', index=0)

In [3]:
import torchvision
from torchvision import models
from torchvision import transforms
import os
import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models


# ToTensor turns a PIL image into a float tensor with pixel values scaled to
# [0, 1] (i.e. divided by 255); ToPILImage converts such a tensor back.
loader = transforms.Compose([transforms.ToTensor()])
unloader = transforms.ToPILImage()

# Model input pipeline: scale to [0, 1], then normalize each RGB channel with
# the standard ImageNet mean/std values.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


In [4]:
class SegDataset(Dataset):
    """Pairs RGB images with same-named binary masks for segmentation training.

    Sample ``i`` is the image ``img_dir/<name>`` together with the mask
    ``mask_dir/<name>``. Only file names present in *both* directories are
    used, so an image without a mask (or vice versa) is skipped up front
    instead of raising ``FileNotFoundError`` in the middle of an epoch.

    Args:
        mask_dir: Directory holding the mask images.
        img_dir: Directory holding the input images.
        img_transform: Optional callable applied to the PIL RGB image
            (e.g. ``ToTensor`` + ``Normalize``).

    Each item is ``(img, mask)`` where ``mask`` is a float32 tensor of shape
    ``(1, H, W)`` containing 0.0 (background) or 1.0 (foreground).
    """

    # Grayscale values strictly above this threshold count as foreground (1.0).
    MASK_THRESHOLD = 200

    def __init__(self, mask_dir, img_dir, img_transform=None):
        self.mask_dir = mask_dir
        self.img_dir = img_dir
        img_files = self._list_files(img_dir)
        mask_files = self._list_files(mask_dir)
        self.n_mask = len(mask_files)
        self.n_img = len(img_files)
        # List the directories once, in a deterministic (sorted) order, rather
        # than calling os.listdir on every __getitem__.
        self.filenames = sorted(set(img_files) & set(mask_files))
        self.img_transform = img_transform

    @staticmethod
    def _list_files(directory):
        """Return names of regular files in ``directory`` (sub-directories such as
        ``.ipynb_checkpoints`` are skipped)."""
        return [name for name in os.listdir(directory)
                if os.path.isfile(os.path.join(directory, name))]

    @classmethod
    def _binarize_mask(cls, mask_array):
        """Threshold a 2-D grayscale array into a ``(1, H, W)`` float32 0/1 tensor."""
        # np.asarray on a PIL image may be read-only, so build a new array from
        # the comparison instead of assigning into it in place.
        binary = (np.asarray(mask_array) > cls.MASK_THRESHOLD).astype(np.float32)
        return torch.from_numpy(binary[np.newaxis, :, :])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]

        # os.path.join works whether or not the directory ends with a slash.
        with Image.open(os.path.join(self.img_dir, filename)) as raw_img:
            img = raw_img.convert("RGB")
        if self.img_transform:
            img = self.img_transform(img)

        with Image.open(os.path.join(self.mask_dir, filename)) as raw_mask:
            mask = raw_mask.convert("L")

        return img, self._binarize_mask(mask)

# Expected layout: ./images/{train,validation}/{img,mask}/ with matching file names.
BATCH_SIZE = 2

train_mask_dir = "./images/train/mask/"
train_img_dir = "./images/train/img/"
training_set = SegDataset(train_mask_dir, train_img_dir, preprocess)
# drop_last avoids a trailing batch of size 1: the DeepLabHead's ASPP pooling
# branch applies BatchNorm to a 1x1 feature map, which raises
# "Expected more than 1 value per channel when training" for a single sample.
training_generator = DataLoader(training_set, batch_size=BATCH_SIZE,
                                shuffle=True, drop_last=True)

val_mask_dir = "./images/validation/mask/"
val_img_dir = "./images/validation/img/"
validation_set = SegDataset(val_mask_dir, val_img_dir, preprocess)
# Evaluation needs no shuffling; a fixed order keeps validation runs reproducible.
validation_generator = DataLoader(validation_set, batch_size=BATCH_SIZE,
                                  shuffle=False)


In [5]:
def createDeepLabv3(outputchannels=1):
    """Build a pretrained torchvision DeepLabv3-ResNet101 with a custom head.

    The model's ``classifier`` is replaced by a freshly initialised
    ``DeepLabHead`` that emits ``outputchannels`` maps, so the network can be
    fine-tuned on masks with that many channels.

    Args:
        outputchannels (int, optional): The number of output channels
            in your dataset masks. Defaults to 1.

    Returns:
        torch.nn.Module: the DeepLabv3 model (ResNet101 backbone), already
        switched to training mode.
    """
    seg = models.segmentation
    weights_enum = getattr(seg, "DeepLabV3_ResNet101_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
        # DEFAULT is the same checkpoint that pretrained=True used to load.
        model = seg.deeplabv3_resnet101(weights=weights_enum.DEFAULT,
                                        progress=True)
    else:
        # Older torchvision without the weights-enum API.
        model = seg.deeplabv3_resnet101(pretrained=True, progress=True)
    # 2048 is the channel count of the ResNet-101 backbone's last stage,
    # which is what the segmentation head receives.
    model.classifier = DeepLabHead(2048, outputchannels)
    # Set the model in training mode
    model.train()
    return model

model = createDeepLabv3()


In [None]:
model.to(device)
num_epochs = 10

# Specify the loss function
# NOTE(review): the head outputs unbounded logits and the masks are 0/1;
# BCEWithLogitsLoss is the usual choice for binary masks — consider switching.
criterion = torch.nn.MSELoss(reduction='mean')
# Specify the optimizer with a lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Make sure BatchNorm/Dropout are in training mode even if eval() ran earlier.
model.train()
for epoch in range(1, num_epochs + 1):
    total_loss = 0.0  # sum of per-batch mean losses over the epoch
    for img, mask in training_generator:
        img = img.to(device, dtype=torch.float32)
        mask = mask.to(device, dtype=torch.float32)
        # Clear gradients every step; otherwise backward() accumulates them over
        # all batches of the epoch and each optimizer.step() uses a growing sum.
        optimizer.zero_grad()
        output = model(img)['out']
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: {total_loss:.4f}")


In [None]:
# Pickles the whole nn.Module; loading it later requires the same class
# definitions to be importable. Create the target directory first so the
# save does not fail on a fresh checkout.
MODEL_PATH = "./ignore/trained_model"
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model, MODEL_PATH)