In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import glob
import PIL.Image as Image
import torch.utils.data as data
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from ipywidgets import interact, fixed

PREFIX = '/kaggle/input/vesuvius-challenge-ink-detection/train/1/'
BUFFER = 30  # Buffer size in x and y direction
Z_START = 27 # First slice in the z direction to use
Z_DIM = 10   # Number of slices in the z direction
TRAINING_STEPS = 30000
LEARNING_RATE = 0.03
BATCH_SIZE = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

plt.imshow(Image.open(PREFIX+"ir.png"), cmap="gray")

In [19]:
import numba
import numpy as np
from PIL import Image
import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from typing import Tuple, List
from torchvision.transforms import transforms

class VesuviusTrainData(Dataset):
    _relative_sv_dir = "surface_volume"  # relative to the directory of train data

    def __init__(self,
                 dir_path: str = "train/1",
                 z_start: int = 0,
                 z_end: int = 64,
                 nucleus_shape: Tuple[int, int] = (16, 16),
                 hull_size: Tuple[int, int] = (64, 64),
                 compress_depth =  None,
                 give_indx = False
                 ):
        self.dir_path = dir_path

        self.z_start = z_start
        self.z_end = z_end

        self.nucleus_shape = nucleus_shape
        self.hull_size = hull_size
        
        self.compress_depth = compress_depth
        self.give_indx = give_indx
        
        self._setup()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index: int):
        i, j = self._get_pixel_from_index(self.indices[index])
        nucleus_height, nucleus_width = self.nucleus_shape

        if self.give_indx == False:
            return (torch.from_numpy(
                    self.images[
                    :,
                    i - self._left_hull + self._left_pad_image:
                    i + nucleus_width + self._right_hull + self._left_pad_image,
                    j - self._top_hull + self._top_pad_image:
                    j + nucleus_height + self._bottom_hull + self._top_pad_image
                     ]),
                    torch.from_numpy(
                    self.inklabels[i: i + nucleus_width,
                    j: j + nucleus_height]))
        else:
            return (torch.from_numpy(
                self.images[
                :,
                i - self._left_hull + self._left_pad_image:
                i + nucleus_width + self._right_hull + self._left_pad_image,
                j - self._top_hull + self._top_pad_image:
                j + nucleus_height + self._bottom_hull + self._top_pad_image
                 ]),
                torch.from_numpy(
                self.inklabels[i: i + nucleus_width,
                               j: j + nucleus_height]),
                   np.stack(([i+k for k in range(nucleus_width)],[
                    j+k for k in range(nucleus_height)]),axis=1))

    def _load_images(self, z_start=None,z_end=None):
        if z_end is None:
            z_end = self.z_end 
        if z_start is None:
            z_start = self.z_start
            
        num_images = z_end - z_start + 1
        self.images = np.empty(
            (
                num_images,
                self.mask.shape[0] + self._left_pad_image + self._right_pad_image,
                self.mask.shape[1] + self._top_pad_image + self._bottom_pad_image
            ), 
            dtype=np.float32
        )
        for index, i in enumerate(range(z_start, z_end)):
            # noinspection PyTypeChecker
            image = np.array(Image.open(f"{self.dir_path}"
                                        f"/{self._relative_sv_dir}"
                                        f"/{self.z_start + i:02d}.tif"),
                             dtype=np.float32) / 65535.0
            image = np.pad(image, ((self._left_pad_mask + self._left_pad_image,
                                    self._right_pad_mask + self._right_pad_image),
                                   (self._top_pad_mask + self._top_pad_image,
                                    self._bottom_pad_mask + self._bottom_pad_image)),
                           'constant', constant_values=0)
            self.images[index, :, :] = image
            
    def _load_scale(self):
        current_stack_layer = self.z_start
        images = np.empty(
            (
                len(self.compress_depth),
                self.mask.shape[0] + self._left_pad_image + self._right_pad_image,
                self.mask.shape[1] + self._top_pad_image + self._bottom_pad_image
            ), 
            dtype=np.float32
        )
        for index,layer in enumerate(self.compress_depth):
            self._load_images(current_stack_layer, current_stack_layer + layer)
            current_stack_layer += layer + 1
            layer_stack = self.images
            images[index,:,:] = np.mean(layer_stack,axis = 0)
            
        self.images = images
    
    # Internal utility methods
    def _setup(self):
        # noinspection PyTypeChecker
        self.mask = np.array(Image.open(f"{self.dir_path}/mask.png"))
        # noinspection PyTypeChecker
        self.inklabels = np.array(Image.open(f"{self.dir_path}/inklabels.png"))
        self._granulate_mask()
        self._fill_new_mask()
        if self.compress_depth is None:
            self._load_images()
        else:
            self._load_scale()
        self._get_leftover_hull()

    def _granulate_mask(self):
        self._get_paddings()
        self.mask = np.pad(self.mask,
                           ((self._left_pad_mask, self._right_pad_mask),
                            (self._top_pad_mask, self._bottom_pad_mask)),
                           'constant', constant_values=0)
        self.indices = _find_labeled_nucleus(self.mask, self.nucleus_shape)
        self.inklabels = np.pad(self.inklabels, 
                                ((self._left_pad_mask, self._right_pad_mask),
                                 (self._top_pad_mask, self._bottom_pad_mask)),
                                'constant', 
                                constant_values=0)

    def _get_paddings(self):
        width, height = self.mask.shape
        nucleus_width, nucleus_height = self.nucleus_shape
        hull_width, hull_height = self.hull_size
        self._left_pad_mask = (width % nucleus_width) // 2
        self._right_pad_mask = width % nucleus_width - self._left_pad_mask
        self._top_pad_mask = (height % nucleus_height) // 2
        self._bottom_pad_mask = height % nucleus_height - self._top_pad_mask

        self._left_pad_image = (hull_width - nucleus_width) // 2
        self._right_pad_image = hull_width - nucleus_width - self._left_pad_image
        self._top_pad_image = (hull_height - nucleus_height) // 2
        self._bottom_pad_image = hull_height - nucleus_height - self._top_pad_image

    def _fill_new_mask(self):
        nucleus_width, nucleus_height = self.nucleus_shape
        self.mask = np.zeros_like(self.mask)
        for index in self.indices:
            i, j = self._get_pixel_from_index(index)
            self.mask[
                i * nucleus_width: (i + 1) * nucleus_width,
                j * nucleus_height: (j + 1) * nucleus_height
            ] = 1

    def _get_pixel_from_index(self, index: int):
        mask_width, mask_height = self.mask.shape
        nucleus_width, nucleus_height = self.nucleus_shape
        height = mask_height // nucleus_height
        return index // height * nucleus_height, index % height * nucleus_width

    def _get_leftover_hull(self):
        nucleus_width, nucleus_height = self.nucleus_shape
        hull_width, hull_height = self.hull_size
        self._left_hull = (hull_width - nucleus_width) // 2
        self._right_hull = hull_width - nucleus_width - self._left_hull
        self._top_hull = (hull_height - nucleus_height) // 2
        self._bottom_hull = hull_height - nucleus_height - self._top_hull


@numba.njit()
def _find_labeled_nucleus(mask, nuclues_shape):
    width, height = mask.shape
    nucleus_width, nucleus_height = nuclues_shape

    indices = - np.ones((width // nucleus_width) * (height // nucleus_height), dtype=np.int32)
    current_index = 0
    for i in range(width // nucleus_width):
        for j in range(height // nucleus_height):
            if np.sum(
                    mask[i * nucleus_width: (i + 1) * nucleus_width,
                         j * nucleus_height: (j + 1) * nucleus_height]
                      ) > 0:
                indices[current_index] = i * (height // nucleus_height) + j
                current_index += 1
    return indices[indices != -1]


In [31]:
data1 = VesuviusTrainData(
    dir_path="/kaggle/input/vesuvius-challenge-ink-detection/train/1"\
    ,nucleus_shape = (12,12),hull_size=(32,32), z_start=0\
    , z_end=60,compress_depth=[20,20,20],give_indx=True)

In [None]:
data3 = VesuviusTrainData(
    dir_path="/kaggle/input/vesuvius-challenge-ink-detection/train/3"\
    ,nucleus_shape = (12,12),hull_size=(32,32), z_start=0\
    , z_end=60,compress_depth=[20,20,20],give_indx=True,istransform = False)

In [72]:
from torch.utils.data import ConcatDataset
merged_dataset = ConcatDataset([data1, data3])

In [88]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU()
        )
        
        self.loss_list = []
    
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
    
    def augment(self,dataset):
        augmentation = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(),
        transforms.ToTensor()
        ])

        augmented_data = []
        for image in dataset:
            augmented_image = augmentation(image)
            augmented_data.append(augmented_image)

        return torch.stack(augmented_data)
    
    
    def fit(self,data,criterion,optimizer,batch_size=1024, num_epochs=20,verbose = False):
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
        for epoch in range(num_epochs):
            epoch_loss = []
            for batch_idx, (data, target, indxs) in tqdm(enumerate(dataloader)):
                inputs = self.augment(data)
                outputs = autoencoder(inputs)
                loss = criterion(outputs, inputs)
                result = np.sum(np.abs(self.forward(data).detach().numpy()-data.numpy()))
                epoch_loss.append(result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if batch_idx == 20:
                    break
                
            self.loss_list.append(np.mean(epoch_loss))
            if verbose == True:    
                print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {self.loss_list[-1]:.4f}")

In [92]:
autoencoder = Autoencoder()
#autoencoder.load_state_dict(torch.load('/kaggle/input/auto-encoder-weight/autoencoder_weights.pth'\
#                                 , map_location=torch.device('cpu')))
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.005)
num_epochs=10
autoencoder.fit(merged_dataset,criterion=criterion,optimizer=optimizer,num_epochs=num_epochs\
                ,verbose =True)

20it [00:12,  1.56it/s]


Epoch [1/10], Loss: 686952.0625


20it [00:12,  1.60it/s]


Epoch [2/10], Loss: 405766.6562


20it [00:12,  1.60it/s]


Epoch [3/10], Loss: 286760.4688


20it [00:13,  1.52it/s]


Epoch [4/10], Loss: 211429.7812


20it [00:12,  1.64it/s]


Epoch [5/10], Loss: 131932.8750


20it [00:12,  1.60it/s]


Epoch [6/10], Loss: 109890.2969


20it [00:12,  1.62it/s]


Epoch [7/10], Loss: 106929.4297


20it [00:12,  1.61it/s]


Epoch [8/10], Loss: 100020.9609


20it [00:12,  1.65it/s]


Epoch [9/10], Loss: 97181.9609


20it [00:12,  1.64it/s]

Epoch [10/10], Loss: 95427.6875





In [93]:
autoencoder.encoder(data1[0][0]).shape

torch.Size([10, 8, 8])

In [103]:
num_epochs=10
autoencoder.fit(merged_dataset,criterion=criterion,optimizer=optimizer,num_epochs=num_epochs\
                ,verbose =True)

20it [00:12,  1.67it/s]


Epoch [1/10], Loss: 34905.5195


20it [00:12,  1.66it/s]


Epoch [2/10], Loss: 33654.4688


20it [00:12,  1.60it/s]


Epoch [3/10], Loss: 33438.1172


20it [00:11,  1.69it/s]


Epoch [4/10], Loss: 33170.9180


20it [00:12,  1.64it/s]


Epoch [5/10], Loss: 34437.7422


20it [00:12,  1.65it/s]


Epoch [6/10], Loss: 35078.1445


20it [00:12,  1.65it/s]


Epoch [7/10], Loss: 32606.7734


20it [00:12,  1.66it/s]


Epoch [8/10], Loss: 32581.2715


20it [00:12,  1.64it/s]


Epoch [9/10], Loss: 32213.3242


20it [00:12,  1.64it/s]

Epoch [10/10], Loss: 37251.1953





In [71]:
torch.save(autoencoder.state_dict(), 'autoencoder_weights.pth')

In [63]:
data2 = VesuviusTrainData(
    dir_path="/kaggle/input/vesuvius-challenge-ink-detection/train/2"\
    ,nucleus_shape = (12,12),hull_size=(32,32), z_start=0\
    , z_end=60,compress_depth=[20,20,20],give_indx=True,istransform = False)



In [104]:
loader = DataLoader(data2, batch_size=1024, shuffle=True)

result_loss = []
for batch_idx, (data, target, indices) in enumerate(loader): 
    result_loss.append(np.sum(np.abs(autoencoder(data).detach().numpy()-data.numpy())))
    


In [105]:
print(np.mean(result_loss))

37647.38


In [84]:
print((12*12*1024-33604.33)/(12*12*1024))

0.7721060519748264
