## Training UNets in PyTorch

In [1]:
from PIL import Image
from torch.utils.data import Dataset
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
try:
    from torchsummary import summary
except ImportError:
    !pip3 install torchsummary
    from torchsummary import summary
    
import os
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
#add DeepLabV3_ResNet50_Weights
from torchvision.models.segmentation import deeplabv3_resnet50

In [2]:
# Locations of the training images and of their segmentation masks. Both are
# handed to get_loaders() further down. CXRDataset derives each image file name
# from its mask file name by replacing "_mask.png" with ".png".
img_dir = "/Users/salma/Desktop/HC701/HC701 Assignment 3/train-test/training"
mask_dir = "/Users/salma/Desktop/HC701/HC701 Assignment 3/train-test/trainingmask"
# NOTE(review): test_dir is only assigned in the commented-out line below, yet
# the inference cell at the end of the notebook reads it -- define it before
# running inference.
# ! mkdir saved_images
# test_dir = '/Users/salma/Desktop/Leuk/data/generated_train_masks/ALL'
# #create dir for masks
# if not os.path.exists(test_dir):
#     os.makedirs(test_dir)

In [3]:
class ConsecutiveConvolution(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU6.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes
    (``input_channel`` -> ``out_channel``). The convolutions have no bias term.
    """

    def __init__(self, input_channel, out_channel):
        super().__init__()

        stages = []
        # First pass maps input_channel -> out_channel, second keeps out_channel.
        for in_ch in (input_channel, out_channel):
            stages.append(
                nn.Conv2d(in_ch, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channel))
            stages.append(nn.ReLU6(inplace=True))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv -> batch norm -> ReLU6 pair twice."""
        return self.conv(x)

In [4]:
class UNet(nn.Module):
    """U-Net style encoder/decoder assembled from ConsecutiveConvolution stages.

    Args:
        input_channel: number of channels of the input image.
        output_channel: number of channels produced by the final 1x1 conv.
        features: channel widths of the encoder stages, shallowest first. The
            decoder mirrors them in reverse and the bottleneck uses twice the
            last width.
    """

    def __init__(self, input_channel, output_channel, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default plus a local copy: nothing is shared between instances.
        features = list(features)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Encoder: one double-conv stage per feature width.
        for feat in features:
            self.encoder.append(ConsecutiveConvolution(input_channel, feat))
            input_channel = feat

        # Decoder: modules are stored in pairs -- a transpose conv that halves
        # the channels, then a double conv applied to the concatenation of the
        # skip connection and the upsampled map (hence feat * 2 inputs).
        for feat in reversed(features):
            self.decoder.append(nn.ConvTranspose2d(feat * 2, feat, kernel_size=2, stride=2))
            self.decoder.append(ConsecutiveConvolution(feat * 2, feat))

        # Bottleneck between the deepest encoder stage and the first decoder stage.
        self.bottleneck = ConsecutiveConvolution(features[-1], features[-1] * 2)

        # 1x1 conv mapping the last decoder output to the requested channels.
        self.final_layer = nn.Conv2d(features[0], output_channel, kernel_size=1)

    def forward(self, x):
        """Return the raw (un-activated) output map for input batch ``x``."""
        skip_connections = []

        # Encoding: keep each stage's output before pooling for the decoder.
        for stage in self.encoder:
            x = stage(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest skip connection first, matching the decoder order.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip_connection = skip_connections[idx // 2]

            # Compare spatial dims only. The previous check compared the full
            # 4-D shape with a 2-D shape, which is always unequal, so every
            # stage was resized even when the sizes already matched.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=list(skip_connection.shape[2:]))

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx + 1](concat_skip)

        return self.final_layer(x)
            

In [5]:
class CXRDataset(Dataset):
    """Image/mask pairs for segmentation, indexed by the files in ``mask_dir``.

    CAUTION: some masks of the images from ``image_dir`` are missing. Hence the
    dataset is driven by the mask listing, so only images whose masks are
    available are processed.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding the masks; each image name is obtained
            from the mask name by replacing "_mask.png" with ".png".
        type: "train" selects the leading ``1 - split_ratio`` share of the
            masks; any other value selects the remaining tail. (The name
            shadows the builtin but is kept because callers pass it by keyword.)
        split_ratio: fraction of the masks reserved for the non-train split.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with "image" and "mask" entries.
    """

    def __init__(self, image_dir, mask_dir, type="train", split_ratio=0.2, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order. Sort them so that the
        # "train" and "val" instances slice the same sequence and therefore
        # never overlap.
        all_masks = sorted(os.listdir(mask_dir))
        split_at = int(len(all_masks) * (1 - split_ratio))

        if type == "train":
            self.masks = all_masks[:split_at]
        else:
            self.masks = all_masks[split_at:]

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the mask file at position ``index``."""
        mask_name = self.masks[index]
        mask_path = os.path.join(self.mask_dir, mask_name)
        img_path = os.path.join(self.image_dir, mask_name.replace("_mask.png", ".png"))

        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        # Map the 255 foreground value to 1.0 for use as a binary target.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [6]:
def save_checkpoint(state, filename="my_checkpoint_448.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename`` via torch.save."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
    image_dir,
    mask_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build the training and validation DataLoaders over CXRDataset.

    Both datasets read from the same ``image_dir`` / ``mask_dir``; they differ
    only in the split they select ("train" vs "val") and in the transform
    applied. The training loader shuffles, the validation loader does not.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # Options shared by both loaders.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    training_set = CXRDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        type="train",
        transform=train_transform,
    )
    training_loader = DataLoader(training_set, shuffle=True, **loader_options)

    validation_set = CXRDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        type="val",
        transform=val_transform,
    )
    validation_loader = DataLoader(validation_set, shuffle=False, **loader_options)

    return training_loader, validation_loader

def check_accuracy(loader, model, device="mps"):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    The model is called on each input batch and its ``'out'`` entry is passed
    through a sigmoid and thresholded at 0.5 before being compared with the
    target mask. The model is put in eval mode for the pass and switched back
    to train mode afterwards, even if the pass raises.

    Args:
        loader: iterable yielding ``(x, y)`` batches; ``y`` gets a channel
            axis inserted at dim 1 before comparison.
        model: callable returning a mapping with an ``'out'`` tensor of logits.
        device: device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x)['out'])
                preds = (preds > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # The small epsilon keeps the ratio finite when both the
                # prediction and the target are empty.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
                num_batches += 1
    finally:
        model.train()

    # An empty loader previously ended in a ZeroDivisionError.
    if num_batches == 0:
        print("=> No batches to evaluate; skipping accuracy report")
        return

    print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/num_batches}")

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="mps"):
    """Write thresholded predictions and their target masks to ``folder``.

    For batch number ``idx`` two files are written: ``pred_<idx>.png`` with the
    sigmoid-thresholded (0.5) model output and ``<idx>.png`` with the target
    mask. The model is put in eval mode for the pass and switched back to
    train mode afterwards.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: callable returning a mapping with an ``'out'`` tensor of logits.
        folder: output directory; created if it does not exist.
        device: device the inputs are moved to.
    """
    # Create the target directory up front so the first save cannot fail on a
    # missing folder. An empty string means the current directory.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x)['out'])
                preds = (preds > 0.5).float()
            # Build both paths with os.path.join so they land in the same
            # directory whether or not ``folder`` ends with a separator.
            torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    finally:
        model.train()

In [7]:
# hyperparams
lr = 1e-4  # learning rate passed to optim.Adam
dev = "mps"  # torch device string used for the model and for every batch
batch_size = 16  # batch size of both the training and validation loaders
epochs = 10  # number of passes of the training loop
workers= 0  # forwarded to get_loaders as num_workers
img_h = 448  # target height used by A.Resize in the transforms
img_w = 448  # target width used by A.Resize in the transforms
pin_mem= True  # forwarded to get_loaders as pin_memory
load_model = False  # presumably meant to toggle checkpoint loading; the code that would read it is commented out


In [8]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader`` and update ``model`` after every batch.

    Batches are moved to the module-level device ``dev``; targets are cast to
    float and given a channel axis at dim 1. The forward pass runs inside
    ``torch.cuda.amp.autocast()`` and the backward/step go through ``scaler``.
    The current batch loss is shown in the tqdm progress bar.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=dev)
        labels = labels.float().unsqueeze(1).to(device=dev)

        # forward pass under autocast
        with torch.cuda.amp.autocast():
            logits = model(inputs)['out']
            batch_loss = loss_fn(logits, labels)

        # backward pass and optimizer step, both routed through the scaler
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # show the latest loss next to the progress bar
        progress.set_postfix(loss=batch_loss.item())

In [None]:
# Albumentations' Normalize computes
#     (img - mean * max_pixel_value) / (std * max_pixel_value).
# The images are loaded as 8-bit RGB arrays, so max_pixel_value has to be 255.0
# to scale them into [0, 1]. With the previous value of 1.0 the network was fed
# raw 0..255 intensities during training, while infer_transforms (which uses
# 255.0) feeds it 0..1 values at inference time.
train_transform = A.Compose(
    [
        A.Resize(height=img_h, width=img_w),
        # Augmentations: random rotation within +/-35 degrees (always applied),
        # horizontal flip half of the time, vertical flip 10% of the time.
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Validation uses the same resize and normalisation but no augmentation.
val_transforms = A.Compose(
    [
        A.Resize(height=img_h, width=img_w),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# DeepLabV3 with a ResNet-50 backbone and pretrained weights, fetched via torch.hub.
model = torch.hub.load('pytorch/vision:v0.8.0', 'deeplabv3_resnet50', pretrained=True)
# Replace the last classifier conv with a single-channel head (one logit per pixel).
model.classifier[4] = torch.nn.Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))

# If you are working with an auxiliary classifier, you should also adjust it
if model.aux_classifier is not None:
    model.aux_classifier[4] = torch.nn.Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))

model.to(dev)

# The loss, optimizer and loaders used to be created twice in this cell and the
# first set was overwritten before use. Only the effective (second)
# configuration is kept: Adam without weight decay, built after the model has
# been moved to the target device.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

train_loader, val_loader = get_loaders(
    img_dir,
    mask_dir,
    batch_size,
    train_transform,
    val_transforms,
    workers,
    pin_mem,
)

# if load_model:
#     load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


# Baseline metrics before any training step.
check_accuracy(val_loader, model, device=dev)
scaler = torch.cuda.amp.GradScaler()

# Folder receiving the sample predictions. Create it before training starts so
# that a missing directory cannot abort the run after the first epoch.
pred_dir = "saved_images_model/"
os.makedirs(pred_dir, exist_ok=True)

for epoch in range(epochs):
    print(epoch)
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model and optimizer state after every epoch
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy on the validation split
    check_accuracy(val_loader, model, device=dev)

    # print some examples to a folder
    save_predictions_as_imgs(val_loader, model, folder=pred_dir, device=dev)

In [120]:
# Restore saved weights into the already-constructed model.
# NOTE(review): save_checkpoint() writes "my_checkpoint_448.pth.tar" by default,
# which is what the training loop above produces; this cell loads a different
# file name -- confirm which checkpoint is intended.
load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)

=> Loading checkpoint


## Inference

In [124]:
class InferenceDataset(Dataset):
    """Unlabelled image dataset over the ``.png`` files of a single directory.

    Each item is an ``(image, path)`` pair: the image is loaded as RGB and, when
    a transform is supplied, replaced by the ``'image'`` entry of
    ``transform(image=...)``; the path is the file the image was read from.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Collect the full path of every entry whose name ends in ".png".
        png_paths = []
        for entry in os.listdir(image_dir):
            if entry.endswith(".png"):
                png_paths.append(os.path.join(image_dir, entry))
        self.images = png_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = self.images[index]
        pixels = np.array(Image.open(path).convert("RGB"))
        if self.transform is None:
            return pixels, path
        return self.transform(image=pixels)['image'], path

# Transform for the inference dataset: resize to img_h x img_w, rescale the
# pixel values with Normalize (zero mean, unit std, max_pixel_value=255.0),
# then convert the array to a tensor.
infer_transforms = A.Compose(
    [
        A.Resize(height=img_h, width=img_w),
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ],
)

In [138]:
def infer_and_save_predictions(model, loader, folder="saved_images/", device="cuda"):
    """Predict a binary mask for every image in ``loader`` and save it as a PNG.

    The model's ``'out'`` logits go through a sigmoid and a 0.5 threshold; each
    resulting mask is scaled to 0/255 and written to ``folder`` under the
    source file's base name with ".png" replaced by "_mask.png". The model is
    left in eval mode.

    Args:
        model: callable returning a mapping with an ``'out'`` tensor of logits.
        loader: iterable yielding ``(images, paths)`` batches, where ``paths``
            holds one source path per image in the batch.
        folder: output directory; created if it does not exist.
        device: device the inputs are moved to.
    """
    # Create the output folder once, instead of re-checking for every image.
    # An empty string means the current directory.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for x, img_paths in loader:
            x = x.to(device=device)
            preds = torch.sigmoid(model(x)['out'])
            preds = (preds > 0.5).float().cpu().numpy()

            # Walk the batch pairing every prediction with its source path;
            # this works for any batch size, not just 1.
            for pred, img_path in zip(preds, img_paths):
                mask = pred.squeeze()
                mask_img = Image.fromarray((mask * 255).astype(np.uint8))
                mask_name = os.path.basename(img_path).replace(".png", "_mask.png")
                mask_img.save(os.path.join(folder, mask_name))

# Instantiate the inference dataset and dataloader
# NOTE(review): test_dir is not assigned by any uncommented code visible in this
# notebook (there is only a commented-out assignment in the path-setup cell) --
# define it before running this cell.
infer_dataset = InferenceDataset(image_dir=test_dir, transform=infer_transforms)
infer_loader = DataLoader(infer_dataset, batch_size=1, shuffle=False)


# Perform inference and save the predicted masks
infer_and_save_predictions(model, infer_loader, folder="masks_NORMAL_train/", device=dev)
