In [145]:
import logging
import numpy as np
import torch
from PIL import Image
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, random_split
from torch import optim


class BasicDataset(Dataset):
    def __init__(self, images_dir, mask_dir):
        self.images_dir = Path(images_dir)
        self.mask_dir = Path(mask_dir)
       

        self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')]
        
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
     
    
    #change preprocess depending on how much tranformations you want
    def img_preprocess(self, image):
        #w, h = pil_img.size
        
        #image = TF.to_tensor(image)
        image = np.asarray(image)
        return image
        
    def mask_preprocess(self, mask):
        #w, h = pil_img.size
        #img = np.asarray(pil_img)
        #mask = TF.to_tensor(mask)
        mask = np.asarray(mask)
        return mask

    def __getitem__(self, idx):
        #name from idx e.g. [0]
        name = self.ids[idx]
        
        
        #the one mask file from the given idx
        mask_file = list(self.mask_dir.glob(name  + '.*'))
        #same but for img_file
        img_file = list(self.images_dir.glob(name + '.*'))

        #makes the file an image
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        #
        #image = np.load(img_path_input_patch).astype('float32')
        #mask = np.load(img_path_tgt_patch).astype('float32')
        img = self.img_preprocess(img)
        mask = self.mask_preprocess(mask)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

    def __len__(self):
        return len(self.ids)

#creating the dataset
dataset = BasicDataset("/Users/zachderse//Documents/mice_training_data/images/",
                     "/Users/zachderse//Documents/mice_training_data/labels/"
                    )


In [146]:
#splitting the dataset
n_val = int(len(dataset) * 0.8)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

#loading the dataset
train_loader = DataLoader(train_set, shuffle=True, batch_size = 10, num_workers=0)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, batch_size = 10, num_workers=0)



In [147]:
def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer):
#train mode
    model.train()
    
    #set loss and accuracy values to 0
    train_loss = 0
    
    # Loop through data loader data batches
    for batch in dataloader:
        
        images, true_masks = batch['image'], batch['mask']

        # assert images.shape[1] == model.n_channels, \
        #     f'Network has been defined with {model.n_channels} input channels, ' \
        #     f'but loaded images have {images.shape[1]} channels. Please check that ' \
        #     'the images are loaded correctly.'

        images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
        true_masks = true_masks.to(device=device, dtype=torch.long)

        with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
            masks_pred = model(images)
            loss = criterion(masks_pred, true_masks)
            loss += dice_loss(
                F.softmax(masks_pred, dim=1).float(),
                F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                multiclass=True
            )

        optimizer.zero_grad(set_to_none=True)
        grad_scaler.scale(loss).backward()
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        global_step += 1
        epoch_loss += loss.item()


        return train_loss
        #eval step, will try to not use
        
#         division_step = (n_train // (5 * batch_size))
#                 if division_step > 0:
#                     if global_step % division_step == 0:
#                         histograms = {}
#                         for tag, value in model.named_parameters():
#                             tag = tag.replace('/', '.')
#                             if not (torch.isinf(value) | torch.isnan(value)).any():
#                                 histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
#                             if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
#                                 histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

#                         val_score = evaluate(model, val_loader, device, amp)
#                         scheduler.step(val_score)









        
#         #old stuff
#         y_pred = model(X)
#         print(y)
        
        
#         loss = loss_fn(y_pred, y)
#         train_loss += loss.item() 

#         # Zero the gradient to remove previous optimizer data
#         optimizer.zero_grad()

#         # back propagation
#         loss.backward()

#         #Step the optimizer
#         optimizer.step()
        

#         # Calculate and accumulate accuracy metric across all batches
#         y_pred_class = torch.argmax(torch.log_softmax(y_pred, dim=1), dim=1)
#         train_acc += (y_pred_class == y).sum().item()/len(y_pred)

#     # Adjust metrics to get average loss and accuracy per batch 
#     train_loss = train_loss / len(dataloader)
#     train_acc = train_acc / len(dataloader)
#     return train_loss, train_acc

In [148]:
def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module):
    # Switch to eval mode
    model.eval() 
    
    num_val_batches = len(dataloader)
    # Set loss and accuracy to zero
    dice_score = 0
    
    
    for batch in dataloader:
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = model(image)

            
            assert mask_true.min() >= 0 and mask_true.max() < model.n_classes, 'True mask indices should be in [0, n_classes['
            # convert to one-hot format
            mask_true = F.one_hot(mask_true, model.n_classes).permute(0, 3, 1, 2).float()
            mask_pred = F.one_hot(mask_pred.argmax(dim=1), model.n_classes).permute(0, 3, 1, 2).float()
            # compute the Dice score, ignoring background
            dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    return dice_score
    
    
#     # Turn on inference 
#     with torch.inference_mode():
#         # Loop through DataLoader batches
#         for batch, (X, y) in enumerate(dataloader):
#             # Send data to target device
#             X, y = X.to('cpu'), y.to('cpu')
#             print("X ", X)
#             print("y ", y)
            
            
#             test_pred_logits = model(X)

            
#             loss = loss_fn(test_pred_logits, y)
#             test_loss += loss.item()
            
#             # Calculate accuracy
#             test_pred_labels = test_pred_logits.argmax(dim=1)
#             test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
            
#     # Adjust metrics to get average loss and accuracy per batch 
#     test_loss = test_loss / len(dataloader)
#     test_acc = test_acc / len(dataloader)
#     return test_loss, test_acc

In [149]:
from tqdm.auto import tqdm
def train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer,loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),epochs: int = 5):
    #initialize loss and accuracy values
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []}
    
    #loop through epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn)
    
    #print and update results
        # print(
        #         f"Epoch: {epoch+1} | "
        #         f"train_loss: {train_loss:.4f} | "
        #         f"train_acc: {train_acc:.4f} | "
        #         f"test_loss: {test_loss:.4f} | "
        #         f"test_acc: {test_acc:.4f}")
        # results["train_loss"].append(train_loss)
        # results["train_acc"].append(train_acc)
        # results["test_loss"].append(test_loss)
        # results["test_acc"].append(test_acc)
    return 0


In [143]:
def train_model(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer,loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),epochs: int = 5):
    #initialize loss and accuracy values
    for epoch in tqdm(range(epochs)):
        train_loss = train_step(model=model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer)
        dice_score = test_step(model=model, dataloader=test_dataloader, loss_fn=loss_fn)
        print(dice_score)

In [144]:
# Set random seeds
torch.manual_seed(42) 
torch.cuda.manual_seed(42)

# Set number of epochs
epochs = 5

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


from timeit import default_timer as timer 
start_time = timer()

# Train model_0 
model_results = train_model(model=model, 
                        train_dataloader=train_loader,
                        test_dataloader=val_loader,
                        optimizer=optimizer,
                        loss_fn=loss_fn, 
                        epochs=epochs)

# End the timer and print out how long it took
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

  0%|          | 0/5 [00:00<?, ?it/s]

RuntimeError: required rank 4 tensor to use channels_last format

In [44]:
from unet import UNet
device = "cpu"
model = UNet()
model = model.to(memory_format=torch.channels_last)
model.to(device=device)

UNet(
  (encoder): Encoder(
    (encoding_blocks): ModuleList(
      (0): EncodingBlock(
        (conv1): ConvolutionalBlock(
          (conv_layer): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
          (activation_layer): ReLU()
          (block): Sequential(
            (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): ReLU()
          )
        )
        (conv2): ConvolutionalBlock(
          (conv_layer): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
          (activation_layer): ReLU()
          (block): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): ReLU()
          )
        )
        (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): EncodingBlock(
        (conv1): ConvolutionalBlock(
          (conv_layer): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (activation_layer): ReLU()
          (block): Sequential(
            (0): C