In [None]:
!pip install albumentations==1.2.1
!pip install timm
!pip install openvino
from typing import Any, Callable, List, Optional, Type, Union
import timm
import os
import zipfile
import pathlib
import sklearn
import torchvision
import random
import torch
import torch.nn.functional as F
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torchvision.transforms as tt
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
from tqdm import tqdm
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from types import SimpleNamespace
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
import cv2
from tqdm import tqdm
import gc
from google.colab.patches import cv2_imshow
from torch.utils.tensorboard import SummaryWriter
from torch.optim.swa_utils import AveragedModel, SWALR
from google.colab import drive
drive.mount('/content/drive')
%matplotlib inline

os.environ['PYTHONHASHSEED'] = str(27)
random.seed(27)
np.random.seed(27)
torch.manual_seed(27)
torch.cuda.manual_seed(27)
torch.cuda.manual_seed_all(27)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.11-py3-none-any.whl (548 kB)
[K     |████████████████████████████████| 548 kB 30.5 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 68.0 MB/s 
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.10.1 timm-0.6.11
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting openvino
  Using cached openvino-2022.2.0-7713-cp37-cp37m-manylinux_2_27_x86_64.whl (26.8 MB)
Installing collected packages: openvino
Successfully installed openvino-2022.2.0
Mounted at /content/drive


#ResNet

In [None]:
class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None
    ):
      super(BasicBlock, self).__init__()
      self.stride = stride
      self.downsample = downsample

      self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
      self.bn1 = nn.BatchNorm2d(out_channels)
      self.relu = nn.ReLU(True)
      self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
      self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
      identity = x

      out = self.conv1(x)
      out = self.bn1(out)
      out = self.relu(out)

      out = self.conv2(out)
      out = self.bn2(out)

      if self.downsample is not None:
          identity = self.downsample(x)

      out = torch.add(out, identity)
      out = self.relu(out)

      return out

class Bottleneck(nn.Module):
    expansion: int = 4
    def __init__(
        self, 
        in_channels: int, 
        hidden_dims:int, 
        stride: int = 1, 
        identity_downsample: Optional[nn.Module] = None
    ):
      super(Bottleneck, self).__init__()
      self.conv1 = nn.Conv2d(in_channels, hidden_dims, kernel_size=1, stride=1, padding=0, bias=False)
      self.bn1 = nn.BatchNorm2d(hidden_dims)
      self.conv2 = nn.Conv2d(hidden_dims, hidden_dims, kernel_size=3, stride=stride, padding=1, bias=False)
      self.bn2 = nn.BatchNorm2d(hidden_dims)
      self.conv3 = nn.Conv2d(hidden_dims, hidden_dims * self.expansion, kernel_size=1, stride=1, padding=0, bias=False
      )
      self.bn3 = nn.BatchNorm2d(hidden_dims * self.expansion)
      self.relu = nn.ReLU()
      self.identity_downsample = identity_downsample
      self.stride = stride

    def forward(self, x):
      identity = x.clone()

      x = self.conv1(x)
      x = self.bn1(x)
      x = self.relu(x)
      x = self.conv2(x)
      x = self.bn2(x)
      x = self.relu(x)
      x = self.conv3(x)
      x = self.bn3(x)

      if self.identity_downsample is not None:
          identity = self.identity_downsample(identity)

      x += identity
      x = self.relu(x)
      return x


class ResNet(nn.Module):
    def __init__(self, block, layers, image_channels=3, num_classes=1000):
      super(ResNet, self).__init__()
      self.in_channels = 64
      stages = []
      self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
      self.bn1 = nn.BatchNorm2d(64)
      self.relu = nn.ReLU()
      stages.append(nn.Sequential(self.conv1, self.bn1, self.relu))
      self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

      self.layer1 = self._make_layer(
        block, layers[0], hidden_dims=64, stride=1
      )
      stages.append(nn.Sequential(self.maxpool, self.layer1))
      self.layer2 = self._make_layer(
        block, layers[1], hidden_dims=128, stride=2
      )
      stages.append(self.layer2)
      self.layer3 = self._make_layer(
        block, layers[2], hidden_dims=256, stride=2
      )
      stages.append(self.layer3)
      self.layer4 = self._make_layer(
        block, layers[3], hidden_dims=512, stride=2
      )
      stages.append(self.layer4)

      self.stages = nn.ModuleList(stages)

      self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
      self.fc = nn.Linear(512 * block.expansion, num_classes)

    def forward(self, x):
      for stage in self.stages:
        x = stage(x)

      x = self.avgpool(x)
      x = torch.flatten(x, 1)
      x = self.fc(x)

      return x

    def _make_layer(self, block, num_residual_blocks, hidden_dims, stride):
      identity_downsample = None
      layers = []

      if stride != 1 or self.in_channels != hidden_dims * block.expansion:
        identity_downsample = nn.Sequential(
          nn.Conv2d(self.in_channels, hidden_dims * block.expansion, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(hidden_dims * block.expansion),
        )

      layers.append(
          block(self.in_channels, hidden_dims, stride, identity_downsample)
      )

      self.in_channels = hidden_dims * block.expansion

      for i in range(num_residual_blocks - 1):
        layers.append(block(self.in_channels, hidden_dims))

      return nn.Sequential(*layers)


def ResNet18(img_channel=3, num_classes=1000):
  return ResNet(BasicBlock, [2, 2, 2, 2], img_channel, num_classes)

def ResNet34(img_channel=3, num_classes=1000):
  return ResNet(BasicBlock, [3, 4, 6, 3], img_channel, num_classes)

def ResNet50(img_channel=3, num_classes=1000):
  return ResNet(Bottleneck, [3, 4, 6, 3], img_channel, num_classes)


def ResNet101(img_channel=3, num_classes=1000):
  return ResNet(Bottleneck, [3, 4, 23, 3], img_channel, num_classes)


def ResNet152(img_channel=3, num_classes=1000):
  return ResNet(Bottleneck, [3, 8, 36, 3], img_channel, num_classes)

#MobilenetV2

In [None]:
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class InvertedResidual(nn.Module):
  def __init__(
      self, in_channels: int, out_channels: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
  ):
    super().__init__()
    self.stride = stride
    if stride not in [1, 2]:
        raise ValueError(f"stride should be 1 or 2 insted of {stride}")

    if norm_layer is None:
        norm_layer = nn.BatchNorm2d

    hidden_dims = int(round(in_channels * expand_ratio))
    self.use_res_connect = ((self.stride == 1) and (in_channels == out_channels))

    layers: List[nn.Module] = []
    if expand_ratio != 1:
      layers.append(
        nn.Sequential(
          nn.Conv2d(in_channels, hidden_dims, kernel_size=1, stride=1),
          nn.BatchNorm2d(hidden_dims),
          nn.ReLU6()
        ),
      )
    layers.extend(
      [
        nn.Sequential(
          nn.Conv2d(hidden_dims, hidden_dims, kernel_size=3, stride=self.stride, padding=1, groups=hidden_dims),
          nn.BatchNorm2d(hidden_dims),
          nn.ReLU6()
        ),
        nn.Conv2d(hidden_dims, out_channels, 1, 1, 0, bias=False),
        norm_layer(out_channels),
      ]
    )
    self.conv = nn.Sequential(*layers)
    self.out_channels = out_channels

  def forward(self, x):
    if self.use_res_connect:
      return x + self.conv(x)
    else:
      return self.conv(x)


In [None]:
class MobileNetV2(nn.Module):
  def __init__(
    self,
    num_classes: int = 1000,
    width_mult: float = 1.0,
    inverted_residual_setting: Optional[List[List[int]]] = None,
    round_nearest: int = 8,
    block: Optional[Callable[..., nn.Module]] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    dropout: float = 0.2,
  ):
    super().__init__()
    if block is None:
      block = InvertedResidual

    if norm_layer is None:
      norm_layer = nn.BatchNorm2d

    input_channel = 32
    last_channel = 1280

    if inverted_residual_setting is None:
      inverted_residual_setting = [
        # t, c, n, s
        [1, 16, 1, 1],
        [6, 24, 2, 2],
        [6, 32, 3, 2],
        [6, 64, 4, 2],
        [6, 96, 3, 1],
        [6, 160, 3, 2],
        [6, 320, 1, 1],
      ]

    if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
      raise ValueError(
        f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
      )


    input_channel = _make_divisible(input_channel * width_mult, round_nearest)
    self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
    stages: List[nn.Module] = [
      nn.Sequential(
        nn.Conv2d(3, input_channel, kernel_size=3, stride=2),
        nn.BatchNorm2d(input_channel),
        nn.ReLU6()
      ),
    ]

    for t, c, n, s in inverted_residual_setting:
      output_channel = _make_divisible(c * width_mult, round_nearest)
      stage = []
      for i in range(n):
        stride = s if i == 0 else 1
        stage.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
        input_channel = output_channel
      stages.append(nn.Sequential(*stage))


    stages.append(
      nn.Sequential(
          nn.Conv2d(input_channel, self.last_channel, kernel_size=1, stride=stride),
          nn.BatchNorm2d(self.last_channel),
          nn.ReLU6()
        ),
    )

    self.stages = nn.ModuleList(stages)

    # building classifier
    self.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(self.last_channel, num_classes),
    )

    # weight initialization
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
          nn.init.zeros_(m.bias)
      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
      elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)

  def forward(self, x):
    for stage in self.stages:
      x = stage(x)
    x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x

In [None]:
# model = MobileNetV2()
# batch = torch.randn((1,3,256,256))
# res = model(batch)
# print(res.shape)

#Unet

In [None]:
class ResNetEncoder(ResNet):
  def __init__(self, block = BasicBlock, layers = [3, 4, 6, 3], image_channels = 3, out_channels = [3, 64, 64, 128, 256, 512]):
    super().__init__(block, layers, image_channels)
    self._out_channels = out_channels
    self.identity = nn.Identity()
  
  def forward(self, x):
    features = []
    x = self.identity(x)
    features.append(x)
    for stage in self.stages:
      x = stage(x)
      features.append(x)

    return x, features

class MobileNetV2Encoder(MobileNetV2):
  def __init__(self, out_channels = [3, 32, 16, 24, 32, 64, 96, 160, 320, 1280]):
    super().__init__()
    self._out_channels = out_channels
    self.identity = nn.Identity()
  
  def forward(self, x):
    features = []
    x = self.identity(x)
    features.append(x)
    for stage in self.stages:
      x = stage(x)
      features.append(x)

    return x, features

class Conv_block(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, inputs):
    x = self.conv1(inputs)
    x = self.bn1(x)
    x = self.relu(x)

    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)

    return x

class Downsample(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv = Conv_block(in_channels, out_channels)
    self.pool = nn.MaxPool2d((2, 2))

  def forward(self, inputs):
    x = self.conv(inputs)
    p = self.pool(x)

    return x, p


class Upsample(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
    self.conv = Conv_block(out_channels + out_channels, out_channels)

  def forward(self, inputs, skip):
    x = self.up(inputs)
    if x.shape != skip.shape:
      x = TF.resize(x, size=skip.shape[2:])
    x = torch.cat([x, skip], axis=1)
    x = self.conv(x)

    return x

class Decoder(nn.Module):
  def __init__(self, encoder_out_channels, num_classes=1):
    super().__init__()
    in_channels = encoder_out_channels[-1]
    self.transition = Conv_block(in_channels, in_channels*2)

    encoder_out_channels = encoder_out_channels[::-1]
    in_channels = in_channels*2
    ups = []
    for out_channels in encoder_out_channels:
      ups.append(Upsample(in_channels, out_channels))
      in_channels = out_channels
    self.ups = nn.ModuleList(ups)

    self.outputs = nn.Conv2d(in_channels, num_classes, kernel_size=1, padding=0)

  def forward(self, inputs, skips):
    x = self.transition(inputs)


    for  up in self.ups:
      skip = skips.pop() 
      x = up(x, skip)


    outputs = self.outputs(x)
    return outputs

class UnetEncoder(nn.Module):
  def __init__(self, encoder_out_channels = [64, 128, 256, 512], in_channels=3):
    super().__init__()
    self._out_channels = encoder_out_channels

    downs = []
    for out_channels in self._out_channels:
      downs.append(Downsample(in_channels, out_channels))
      in_channels = out_channels
    
    self.downs = nn.ModuleList(downs)

  def forward(self, x):
    skips = []
    for down in self.downs:
      skip, x = down(x)
      skips.append(skip)
    
    return x, skips

class unet(nn.Module):
  def __init__(self, encoder_type = UnetEncoder, in_channels=3, num_classes=1):
    super().__init__()

    self.encoder = encoder_type()
    self.decoder = Decoder(self.encoder._out_channels, num_classes)

  def forward(self, inputs):
    x, skips = self.encoder(inputs)
    output = self.decoder(x, skips)
    return output

In [None]:
# model = unet()
# batch = torch.randn((1,3,512,512))
# res = model(batch)
# print(res.shape)

#Config

In [None]:
def seed_everything(seed=27):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
seed_everything()
WEIGHT_DECAY = 1e-4
LEARNING_RATE = 3e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 20
NUM_WORKERS = 2
IMAGE_HEIGHT = 720  # 1080 originally
IMAGE_WIDTH = 1088  # 1920 originally
PIN_MEMORY = True
LOAD_MODEL = False

# Dataset

In [None]:
zipFile = zipfile.ZipFile('/content/drive/MyDrive/Datasets/synthetic.zip', 'r')
zipFile.extractall('dataset')
zipFile.close()

In [None]:
class SyntheticDataset(Dataset):
  def __init__(self, list_path, dir="/content/dataset/synthetic", transform=None ):
    self.file_path = list_path
    self.dir = dir
    self.transforms = transform
    self.images_path = self.dir+"/images"
    self.masks_path = self.dir+"/masks"
    self.file = open(self.file_path, encoding='utf-8')
    self.text = self.file.read()
    self.file.close
    self.image_descriptions = self.text.split("\n")
    if '' in self.image_descriptions:
      self.image_descriptions.remove('')
    

  def __getitem__(self, index):
    img_path = self.images_path+"/Img_"+self.image_descriptions[index]+".jpeg"
    mask_path = self.masks_path+"/Mask_"+self.image_descriptions[index]+".png"
    image = np.array(Image.open(img_path).convert("RGB"))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = np.array(Image.open(mask_path).convert("L"))
    mask[mask == 107] = 0
    mask[mask == 195] = 1
    mask[mask == 88] = 2
    mask[mask == 70] = 3
    mask[mask == 225] = 4
    
    if self.transforms is not None:
      augmentations = self.transforms(image=image, mask=mask)
      image = augmentations["image"]
      mask = augmentations["mask"]
    return image, mask

  
  def __len__(self):
    return len(self.image_descriptions)

In [None]:
seed_everything()
mean = [0.5585, 0.6353, 0.6439]
std = [0.1218, 0.1545, 0.1860]
train_transforms1 = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.9),
            A.VerticalFlip(p=0.9),
            A.Normalize(
                mean=mean,
                std=std,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

train_transforms2 = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            # A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25),
            # A.Blur(blur_limit=5, p=0.7),
            # A.RandomBrightnessContrast(p=0.5),
            # A.ChannelShuffle(p=0.7),
            # A.Solarize(threshold=128, p=0.5),
            # A.GaussNoise(),
            # A.InvertImg(p=0.7),
            A.Normalize(
                mean=mean,
                std=std,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

train_transforms3 = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            # A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25),
            # A.Blur(blur_limit=5, p=0.7),
            # A.RandomBrightnessContrast(p=0.5),
            # A.ChannelShuffle(p=0.7),
            A.Solarize(threshold=128, p=0.5),
            A.GaussNoise(),
            A.InvertImg(p=0.7),
            A.Normalize(
                mean=mean,
                std=std,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=mean,
                std=std,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

In [None]:
seed_everything()
train_dataset1 = SyntheticDataset(
        list_path = "/content/dataset/synthetic/lists/train_lst.txt",
        transform=train_transforms1,
    )

train_dataset2 = SyntheticDataset(
        list_path = "/content/dataset/synthetic/lists/train_lst.txt",
        transform=train_transforms2,
    )

train_dataset3 = SyntheticDataset(
        list_path = "/content/dataset/synthetic/lists/train_lst.txt",
        transform=train_transforms3,
    )

train_dataset = torch.utils.data.ConcatDataset([train_dataset1, train_dataset2, train_dataset3])

train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
    )

val_dataset = SyntheticDataset(
        list_path = "/content/dataset/synthetic/lists/val_lst.txt",
        transform=val_transforms,
    )

val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=False,
    )

#Assets

In [None]:
def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    jaccard_idx = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = torch.softmax(model(x), 1)
            _, preds = torch.max(preds, dim=1)
            preds = preds.type(torch.float)
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * num_correct +1e-8) / (
                2 * num_correct + num_pixels - num_correct + 1e-8
            )
            jaccard_idx += (num_correct + 1e-8) / (
                num_correct + (preds != y).sum()+ 1e-8
            )
    acc = num_correct/num_pixels
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    dice = dice_score/len(loader)
    print(f"Dice score: {dice_score/len(loader)}")

    jaccard = jaccard_idx/len(loader)
    print(f"Jaccard index: {jaccard}")
    model.train()

    return acc, dice, jaccard

#Training

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    seed_everything()
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.type(torch.LongTensor)
        targets = targets.to(device=DEVICE)

        with torch.cuda.amp.autocast():
          predictions = model(data)
          loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

In [None]:
seed_everything()
weights = torch.tensor([0.4, 1.9, 0.7, 0.4, 1.5], device=DEVICE)
ClassicUnet = unet(encoder_type = UnetEncoder, num_classes=5)
ClassicUnet = ClassicUnet.to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight = weights)
optimizer = optim.Adam(ClassicUnet.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler()


writer = SummaryWriter("/content/boards/classicunet")
step = 0


for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, ClassicUnet, optimizer, loss_fn, scaler)


  val_acc, dice_score, jaccard_idx = check_accuracy(val_loader, ClassicUnet, device=DEVICE)


  writer.add_scalar("Validation accuracy", val_acc, global_step=step)
  writer.add_scalar("Dice score", dice_score, global_step=step)
  writer.add_scalar("Jaccard index", jaccard_idx, global_step=step) 
  step += 1
  
writer.close()
gc.collect()
torch.cuda.empty_cache()

100%|██████████| 75/75 [01:41<00:00,  1.35s/it, loss=0.736]


Got 11275616/12533760 with acc 89.96
Dice score: 0.9413691163063049
Jaccard index: 0.9413865804672241


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.575]


Got 11190497/12533760 with acc 89.28
Dice score: 0.9376745223999023
Jaccard index: 0.9377182722091675


100%|██████████| 75/75 [01:34<00:00,  1.27s/it, loss=0.467]


Got 11030288/12533760 with acc 88.00
Dice score: 0.9235711693763733
Jaccard index: 0.9255033731460571


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.385]


Got 9366227/12533760 with acc 74.73
Dice score: 0.8349645733833313
Jaccard index: 0.8419570922851562


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.201]


Got 11826185/12533760 with acc 94.35
Dice score: 0.9683389663696289
Jaccard index: 0.9679786562919617


100%|██████████| 75/75 [01:34<00:00,  1.25s/it, loss=0.172]


Got 11747745/12533760 with acc 93.73
Dice score: 0.9642733335494995
Jaccard index: 0.9639766216278076


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.139]


Got 11850362/12533760 with acc 94.55
Dice score: 0.9681392908096313
Jaccard index: 0.9679662585258484


100%|██████████| 75/75 [01:34<00:00,  1.25s/it, loss=0.113]


Got 11992076/12533760 with acc 95.68
Dice score: 0.9756700992584229
Jaccard index: 0.9753456115722656


100%|██████████| 75/75 [01:34<00:00,  1.25s/it, loss=0.11]


Got 11785158/12533760 with acc 94.03
Dice score: 0.9644649028778076
Jaccard index: 0.9645889401435852


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.1]


Got 12039095/12533760 with acc 96.05
Dice score: 0.9778341054916382
Jaccard index: 0.9775028228759766


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.187]


Got 10558999/12533760 with acc 84.24
Dice score: 0.9138653874397278
Jaccard index: 0.912909209728241


100%|██████████| 75/75 [01:34<00:00,  1.25s/it, loss=0.114]


Got 11866440/12533760 with acc 94.68
Dice score: 0.9689728021621704
Jaccard index: 0.9689285159111023


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0951]


Got 11997128/12533760 with acc 95.72
Dice score: 0.9762411117553711
Jaccard index: 0.9758787751197815


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0838]


Got 12056785/12533760 with acc 96.19
Dice score: 0.9785735607147217
Jaccard index: 0.9782595634460449


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0772]


Got 12024406/12533760 with acc 95.94
Dice score: 0.9768967032432556
Jaccard index: 0.9766157865524292


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0936]


Got 9397192/12533760 with acc 74.98
Dice score: 0.8483207821846008
Jaccard index: 0.8514636754989624


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.22]


Got 11044173/12533760 with acc 88.12
Dice score: 0.9309117794036865
Jaccard index: 0.9311120510101318


100%|██████████| 75/75 [01:34<00:00,  1.26s/it, loss=0.0976]


Got 11925417/12533760 with acc 95.15
Dice score: 0.9726930856704712
Jaccard index: 0.9723756909370422


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0778]


Got 12070803/12533760 with acc 96.31
Dice score: 0.9791549444198608
Jaccard index: 0.9788599014282227


100%|██████████| 75/75 [01:33<00:00,  1.25s/it, loss=0.0755]


Got 12076403/12533760 with acc 96.35
Dice score: 0.9791728258132935
Jaccard index: 0.9789113402366638


In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir /content/boards/classicunet

In [None]:
seed_everything()
weights = torch.tensor([0.4, 1.9, 0.7, 0.4, 1.5], device=DEVICE)
UnetWithResNetEncoder = unet(encoder_type = ResNetEncoder, num_classes=5)
UnetWithResNetEncoder = UnetWithResNetEncoder.to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight = weights)
optimizer = optim.Adam(UnetWithResNetEncoder.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler()


writer = SummaryWriter("/content/boards/resnet")
step = 0


for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, UnetWithResNetEncoder, optimizer, loss_fn, scaler)


  val_acc, dice_score, jaccard_idx = check_accuracy(val_loader, UnetWithResNetEncoder, device=DEVICE)


  writer.add_scalar("Validation accuracy", val_acc, global_step=step)
  writer.add_scalar("Dice score", dice_score, global_step=step)
  writer.add_scalar("Jaccard index", jaccard_idx, global_step=step) 
  step += 1
  
writer.close()
gc.collect()
torch.cuda.empty_cache()

100%|██████████| 75/75 [00:36<00:00,  2.03it/s, loss=1.92]


Got 3269114/12533760 with acc 26.08
Dice score: 0.4453437924385071
Jaccard index: 0.46349892020225525


100%|██████████| 75/75 [00:36<00:00,  2.05it/s, loss=1.88]


Got 3412863/12533760 with acc 27.23
Dice score: 0.45821642875671387
Jaccard index: 0.4762350618839264


100%|██████████| 75/75 [00:37<00:00,  2.03it/s, loss=1.86]


Got 3416089/12533760 with acc 27.26
Dice score: 0.4585648775100708
Jaccard index: 0.47655874490737915


100%|██████████| 75/75 [00:36<00:00,  2.04it/s, loss=1.83]


Got 3408360/12533760 with acc 27.19
Dice score: 0.45778408646583557
Jaccard index: 0.4758162498474121


100%|██████████| 75/75 [00:38<00:00,  1.94it/s, loss=1.77]


Got 3296807/12533760 with acc 26.30
Dice score: 0.4487047493457794
Jaccard index: 0.4665164351463318


100%|██████████| 75/75 [00:36<00:00,  2.04it/s, loss=1.73]


Got 3411188/12533760 with acc 27.22
Dice score: 0.4582372307777405
Jaccard index: 0.4762048125267029


100%|██████████| 75/75 [00:37<00:00,  2.02it/s, loss=1.66]


Got 3405462/12533760 with acc 27.17
Dice score: 0.45765650272369385
Jaccard index: 0.4756571650505066


100%|██████████| 75/75 [00:36<00:00,  2.05it/s, loss=1.56]


Got 3414220/12533760 with acc 27.24
Dice score: 0.45834091305732727
Jaccard index: 0.4763602018356323


100%|██████████| 75/75 [00:36<00:00,  2.04it/s, loss=1.47]


Got 3459774/12533760 with acc 27.60
Dice score: 0.4623924493789673
Jaccard index: 0.48034054040908813


100%|██████████| 75/75 [00:38<00:00,  1.96it/s, loss=1.4]


Got 5807357/12533760 with acc 46.33
Dice score: 0.6467040777206421
Jaccard index: 0.6569055318832397


100%|██████████| 75/75 [00:37<00:00,  2.02it/s, loss=1.28]


Got 6894432/12533760 with acc 55.01
Dice score: 0.7097415924072266
Jaccard index: 0.7189478278160095


100%|██████████| 75/75 [00:36<00:00,  2.06it/s, loss=1.18]


Got 8932702/12533760 with acc 71.27
Dice score: 0.8274720907211304
Jaccard index: 0.8304508924484253


100%|██████████| 75/75 [00:36<00:00,  2.03it/s, loss=1.1]


Got 9706785/12533760 with acc 77.45
Dice score: 0.8619572520256042
Jaccard index: 0.8650549054145813


100%|██████████| 75/75 [00:36<00:00,  2.04it/s, loss=0.984]


Got 10554446/12533760 with acc 84.21
Dice score: 0.9054545760154724
Jaccard index: 0.9067991971969604


100%|██████████| 75/75 [00:38<00:00,  1.94it/s, loss=0.902]


Got 10775451/12533760 with acc 85.97
Dice score: 0.9171605706214905
Jaccard index: 0.917889416217804


100%|██████████| 75/75 [00:36<00:00,  2.04it/s, loss=0.805]


Got 10737312/12533760 with acc 85.67
Dice score: 0.9166252613067627
Jaccard index: 0.9170735478401184


100%|██████████| 75/75 [00:36<00:00,  2.03it/s, loss=0.744]


Got 10787394/12533760 with acc 86.07
Dice score: 0.9186697006225586
Jaccard index: 0.9191606044769287


100%|██████████| 75/75 [00:37<00:00,  2.02it/s, loss=0.762]


Got 10728603/12533760 with acc 85.60
Dice score: 0.9168728590011597
Jaccard index: 0.9170849323272705


100%|██████████| 75/75 [00:36<00:00,  2.03it/s, loss=0.643]


Got 10838727/12533760 with acc 86.48
Dice score: 0.9219534993171692
Jaccard index: 0.9222301244735718


100%|██████████| 75/75 [00:37<00:00,  1.99it/s, loss=0.597]


Got 10880189/12533760 with acc 86.81
Dice score: 0.9233851432800293
Jaccard index: 0.923737645149231


In [None]:
%tensorboard --logdir /content/boards/resnet

In [None]:
seed_everything()
weights = torch.tensor([0.4, 1.9, 0.7, 0.4, 1.5], device=DEVICE)
UnetWithMobileNetV2Encoder = unet(encoder_type = MobileNetV2Encoder, num_classes=5)
UnetWithMobileNetV2Encoder = UnetWithMobileNetV2Encoder.to(DEVICE)
loss_fn = nn.CrossEntropyLoss(weight = weights)
optimizer = optim.Adam(UnetWithMobileNetV2Encoder.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler()


writer = SummaryWriter("/content/boards/mobilenetv2")
step = 0


for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, UnetWithMobileNetV2Encoder, optimizer, loss_fn, scaler)


  val_acc, dice_score, jaccard_idx = check_accuracy(val_loader, UnetWithMobileNetV2Encoder, device=DEVICE)


  writer.add_scalar("Validation accuracy", val_acc, global_step=step)
  writer.add_scalar("Dice score", dice_score, global_step=step)
  writer.add_scalar("Jaccard index", jaccard_idx, global_step=step) 
  step += 1
  
writer.close()
gc.collect()
torch.cuda.empty_cache()

100%|██████████| 75/75 [00:47<00:00,  1.57it/s, loss=1.09]


Got 8775901/12533760 with acc 70.02
Dice score: 0.80437171459198
Jaccard index: 0.8126125335693359


100%|██████████| 75/75 [00:47<00:00,  1.58it/s, loss=1.01]


Got 8830281/12533760 with acc 70.45
Dice score: 0.8071077466011047
Jaccard index: 0.815250039100647


100%|██████████| 75/75 [00:46<00:00,  1.60it/s, loss=0.963]


Got 8859921/12533760 with acc 70.69
Dice score: 0.808976411819458
Jaccard index: 0.8169353008270264


100%|██████████| 75/75 [00:46<00:00,  1.60it/s, loss=0.9]


Got 8748645/12533760 with acc 69.80
Dice score: 0.8010371923446655
Jaccard index: 0.8100195527076721


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.836]


Got 8931984/12533760 with acc 71.26
Dice score: 0.8126356601715088
Jaccard index: 0.8204694390296936


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.777]


Got 9252236/12533760 with acc 73.82
Dice score: 0.8310174345970154
Jaccard index: 0.8374994397163391


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.725]


Got 9150734/12533760 with acc 73.01
Dice score: 0.8260908722877502
Jaccard index: 0.8327156901359558


100%|██████████| 75/75 [00:47<00:00,  1.60it/s, loss=0.672]


Got 9022202/12533760 with acc 71.98
Dice score: 0.8193233013153076
Jaccard index: 0.8262869119644165


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.628]


Got 9011517/12533760 with acc 71.90
Dice score: 0.818800687789917
Jaccard index: 0.8257733583450317


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.588]


Got 9032029/12533760 with acc 72.06
Dice score: 0.8197404742240906
Jaccard index: 0.826711893081665


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.548]


Got 9042642/12533760 with acc 72.15
Dice score: 0.8204469680786133
Jaccard index: 0.8273395299911499


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.513]


Got 9029110/12533760 with acc 72.04
Dice score: 0.8194764852523804
Jaccard index: 0.8264937996864319


100%|██████████| 75/75 [00:45<00:00,  1.63it/s, loss=0.482]


Got 9049820/12533760 with acc 72.20
Dice score: 0.8208000063896179
Jaccard index: 0.8276822566986084


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.457]


Got 9039520/12533760 with acc 72.12
Dice score: 0.8203004598617554
Jaccard index: 0.8271961212158203


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.428]


Got 9062675/12533760 with acc 72.31
Dice score: 0.821506142616272
Jaccard index: 0.828344464302063


100%|██████████| 75/75 [00:46<00:00,  1.62it/s, loss=0.407]


Got 9057611/12533760 with acc 72.27
Dice score: 0.8212207555770874
Jaccard index: 0.8280788660049438


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.391]


Got 9032695/12533760 with acc 72.07
Dice score: 0.819736123085022
Jaccard index: 0.8267197608947754


100%|██████████| 75/75 [00:46<00:00,  1.63it/s, loss=0.367]


Got 9060187/12533760 with acc 72.29
Dice score: 0.8213868141174316
Jaccard index: 0.8282291889190674


100%|██████████| 75/75 [00:45<00:00,  1.63it/s, loss=0.348]


Got 9059025/12533760 with acc 72.28
Dice score: 0.8213143348693848
Jaccard index: 0.8281625509262085


100%|██████████| 75/75 [00:47<00:00,  1.58it/s, loss=0.331]


Got 9054263/12533760 with acc 72.24
Dice score: 0.8210656046867371
Jaccard index: 0.8279259204864502


In [None]:
%tensorboard --logdir /content/boards/mobilenetv2