In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import gcsfs
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, TensorDataset, random_split
import torchvision.models as models
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
class GCSImageSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        self.fs = gcsfs.GCSFileSystem()
        self.image_files = self.fs.ls(image_dir)
        self.mask_files = self.fs.ls(mask_dir)

        self.image_files = sorted([f for f in self.image_files if f.endswith(('jpg', 'jpeg', 'png'))])
        self.mask_files = sorted([f for f in self.mask_files if f.endswith(('jpg', 'jpeg', 'png'))])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        with self.fs.open(image_path, 'rb') as img_file:
            image = Image.open(img_file).convert('RGB')

        with self.fs.open(mask_path, 'rb') as mask_file:
            mask = Image.open(mask_file).convert('L')

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

In [None]:
# mean = 0.0
# std = 0.0
# n_samples = 0

# for images, _ in dataloader:
#     batch_samples = images.size(0)
#     images = images.view(batch_samples, images.size(1), -1)
#     mean += images.mean(2).sum(0)
#     std += images.std(2).sum(0)
#     n_samples += batch_samples

# mean /= n_samples
# std /= n_samples

# print(f"Mean: {mean}")
# print(f"Standard Deviation: {std}")

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.4843, 0.3917, 0.3575], std = [0.2620, 0.2456, 0.2405])
])

target_transform = transforms.Compose([
    transforms.ToTensor()
])
# Training on personal Google Cloud Cluster
dataset = GCSImageSegmentationDataset(
    image_dir='gs://wound-image-seg/train_images/',
    mask_dir='gs://wound-image-seg/train_masks/',
    transform=transform,
    target_transform=target_transform
)

# DataLoader to load the dataset in batches
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 4, shuffle=True)


In [5]:
def image_segmentation_model(model):
    if model == "Deeplab":
      model = models.segmentation.deeplabv3_resnet101(pretrained=True)
      criterion = nn.BCEWithLogitsLoss()
      optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
      model.classifier[4] = torch.nn.Conv2d(256, 1, kernel_size=1)
      return model, criterion, optimizer
    else:
        pass

In [6]:
deeplab_resnet101, criterion, optimizer = image_segmentation_model("Deeplab")
deeplab_resnet101



Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth


  0%|          | 0.00/233M [00:00<?, ?B/s]

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [7]:
def train_model(model, train_loader, optimizer, criterion, num_epochs):
  model.train()
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.enabled = False
  model = model.to(device)
  for epoch in range(num_epochs):
      model.train()
      total_loss = 0
      progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}')
      for batch_idx, (inputs, labels) in progress_bar:
          inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(inputs)['out']
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()
          total_loss += loss.item()
          progress_bar.set_postfix({'loss': total_loss/(batch_idx+1)})
      print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}')
  return model

In [8]:
trained_deeplab = train_model(deeplab_resnet101, dataloader, optimizer, criterion, num_epochs = 50)

Epoch 1/50:   0%|          | 0/552 [00:00<?, ?it/s]

In [None]:
fs = gcsfs.GCSFileSystem(project = 'long-sonar-426316-c7', )
with fs.open("gs://wound-image-seg/"+f'model/trained_deeplab_resnet101.pt','wb') as f: 
    torch.save(trained_deeplab, f)