In [None]:
from glob import glob
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import jaccard_score as iou_score
import segmentation_models_pytorch as smp
from dataclasses import dataclass
from collections import defaultdict
from torchvision.transforms import GaussianBlur
from utils import read_img

In [None]:
# --- Data / training hyper-parameters ---------------------------------------
BATCH_SIZE, SIZE, EPOCHS = 4, 1024, 20   # SIZE is the side of the square crops fed to the model
N_CHANNELS, N_CLASSES = 3, 6

# --- Model configuration -----------------------------------------------------
encoder, dataset = 'efficientnet-b0', 'imagenet'   # `dataset` is passed as encoder_weights
# Settings for the auxiliary classification head (passed to smp.Unet as aux_params).
aux_params = {
    'pooling': 'avg',
    'dropout': 0.5,
    'activation': None,
    'classes': N_CLASSES,
}

# --- Device selection --------------------------------------------------------
# Resolve the device name once and derive the torch.device from it.
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_str)
print(f'Using {device} device')

In [None]:
# Shared blur transform (3x3 kernel, sigma range 0.01-1.0); PoreDataset.__getitem__
# applies it to every image tensor, for training and validation alike.
blur = GaussianBlur(kernel_size=3, sigma=(0.01, 1.0))

In [None]:
def transform_train():
    """Build the albumentations pipeline used for training samples.

    Only a random SIZE x SIZE crop is active; the remaining augmentations are
    kept below, disabled, together with their parameters.
    """
    return A.Compose([
        A.RandomCrop(height=SIZE, width=SIZE, p=1),
        # Disabled augmentations:
        # A.SafeRotate(p=0.4),
        # A.HorizontalFlip(p=0.5),
        # A.Transpose(p=0.5),
        # A.ColorJitter(brightness=0.33,
        #               contrast=0.19,
        #               saturation=0.19,
        #               hue=(-0.05, 0.095),
        #               p=1),
    ])


def transform_valid():
    """Build the albumentations pipeline used for validation samples.

    Takes a random (SIZE+1) x (SIZE+1) crop and resizes it to SIZE x SIZE.
    """
    # NOTE(review): the crop is random, so validation samples differ between
    # runs/epochs -- confirm this is intended.
    oversized_crop = A.RandomCrop(height=SIZE + 1, width=SIZE + 1, p=1)
    shrink = A.Resize(height=SIZE, width=SIZE, p=1)
    return A.Compose([oversized_crop, shrink])


def to_tensor():
    """Build a one-step pipeline that converts image/mask arrays to torch tensors."""
    return A.Compose([ToTensorV2(p=1)])

In [None]:
from segmentation_models_pytorch.encoders import get_preprocessing_fn

class PoreDataset(Dataset):
    """Dataset of (image, mask) pairs listed in a DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have 'image' and 'mask' columns holding file paths and an index
        that can be addressed with ``df.loc[i, ...]`` for ``i`` in
        ``range(len(df))``.
    transforms : callable
        Albumentations-style pipeline called as ``transforms(image=..., mask=...)``
        and returning a dict with 'image' and 'mask' entries.
    """

    def __init__(self, df, transforms):
        super().__init__()
        self.df = df
        self.transforms = transforms
        self.to_tensor = to_tensor()
        # Normalisation function matching the pretrained encoder weights.
        self.preprocess_input = get_preprocessing_fn(encoder, pretrained='imagenet')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load, augment and normalise sample ``index``.

        Returns
        -------
        (image, mask) : tuple of torch.Tensor
            ``image`` is float and channels-first, ``mask`` is long; both are
            moved to the module-level ``device``.
        """
        image = read_img(self.df.loc[index, 'image'])
        mask = read_img(self.df.loc[index, 'mask'], rgb=False)

        transformed = self.transforms(image=image, mask=mask)
        image, mask = transformed['image'], transformed['mask']

        transformed = self.to_tensor(image=image, mask=mask)
        image, mask = transformed['image'], transformed['mask']
        image = blur(image)

        # The preprocessing function broadcasts per-channel statistics over the
        # LAST axis, so the tensor has to be channels-last while it runs.
        # permute() actually moves the axes; the previous torch.reshape() only
        # re-interpreted the flat buffer, which applied the channel mean/std to
        # the wrong elements.
        image = image.permute(1, 2, 0)
        image = self.preprocess_input(image)
        # Back to channels-first for the network; contiguous() gives downstream
        # code a plain dense layout.
        image = torch.as_tensor(image).permute(2, 0, 1).contiguous()
        return image.float().to(device), mask.long().to(device)

In [None]:
def get_name(image):
    """Return the file name of ``image`` without its directory and extension.

    Works for both '/'- and '\\'-separated paths (the paths globbed in this
    notebook are Windows paths) and for extensions of any length.

    Parameters
    ----------
    image : str
        Path to a file.

    Returns
    -------
    str
        The base name with the final extension removed.
    """
    # ntpath treats both separators as directory separators on every OS.
    import ntpath

    return ntpath.splitext(ntpath.basename(image))[0]

In [None]:
np.random.seed(42)

# glob() returns matches in arbitrary order, but images and masks are paired
# purely by position below. Sort every directory listing so the pairing is
# deterministic.
# NOTE(review): this assumes an image and its mask sort to the same position
# within their folders (e.g. identical file names) -- confirm.
images = (
    sorted(glob(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\Images\Sihor\images\*"))
    +
    sorted(glob(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\Images\Surhar\images\*"))
)

masks = (
    sorted(glob(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\Images\Sihor\masks\*"))
    +
    sorted(glob(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\Images\Surhar\masks\*"))
)


# images = sorted(list(images))[2:3]
# masks = sorted(list(masks))[2:3]

print('images count: ', len(images))
print('masks count: ', len(masks))

length = len(images)
assert length == len(masks), (
    f'image/mask count mismatch: {length} images vs {len(masks)} masks'
)
# Random ~80/20 flag per sample (reproducible through the seed set above).
is_valid = np.random.choice([False, True], length, p=[0.8, 0.2])
df = pd.DataFrame({'image': images, 'mask': masks, 'is_valid': is_valid})

# df.to_csv('/content/drive/MyDrive/pore segmentation/data.csv', sep=' ')

# PoreDataset addresses rows with df.loc[i, ...] for i in range(len(df)), so the
# index must be 0..n-1 after filtering.
train_df = df[~df['is_valid']].reset_index(drop=True)
# NOTE(review): validation currently runs on a copy of the TRAINING rows, so
# validation metrics do not measure generalisation. The held-out split is the
# disabled line below -- confirm which one is intended.
# valid_df = df[df['is_valid']]
valid_df = train_df.copy()

In [None]:
# Disabled debug aid: stack BATCH_SIZE copies of the training frame and
# re-index it with np.arange(BATCH_SIZE).
# train_df = pd.concat([train_df for _ in range(BATCH_SIZE)])
# train_df.index = np.arange(BATCH_SIZE)
# Display the training frame.
train_df

In [None]:
# Display the validation frame.
valid_df

In [None]:
# One PoreDataset per split, each paired with its own augmentation pipeline.
train_datasets, valid_datasets = (
    PoreDataset(frame, transforms=pipeline)
    for frame, pipeline in (
        (train_df, transform_train()),
        (valid_df, transform_valid()),
    )
)

In [None]:
# Smoke test: run the first training sample through the dataset pipeline and
# show the resulting image / mask tensor shapes.
x, y = train_datasets[0]
x.shape, y.shape

In [None]:
# Show the sample image (channels moved last for imshow) next to its mask.
# NOTE(review): x has already been through the encoder preprocessing in
# PoreDataset, so the displayed colours may not look natural -- confirm.
fig, axs = plt.subplots(1, 2)
axs[0].imshow(x.permute(1, 2, 0).cpu().numpy());
axs[1].imshow(y.cpu().numpy());

In [None]:
# Both loaders share batch size and run in the main process (num_workers=0);
# only the training loader shuffles. pin_memory is left at its default
# (the original `# pin_memory=True` option stays disabled).
train_loader, valid_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=0)
    for split, do_shuffle in ((train_datasets, True), (valid_datasets, False))
)

In [None]:
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss

In [None]:
class CombinedLoss(torch.nn.Module):
    """Weighted sum of a multiclass Dice loss and a multiclass Focal loss.

    Parameters
    ----------
    dice_weight : float
        Multiplier applied to the Dice term (default 0.2).
    focal_weight : float
        Multiplier applied to the Focal term (default 0.8).
    """

    def __init__(self, dice_weight=0.2, focal_weight=0.8):
        super().__init__()
        self.dice_weight, self.focal_weight = dice_weight, focal_weight
        self.dice_loss = DiceLoss('multiclass', from_logits=True)
        self.focal_loss = FocalLoss('multiclass', gamma=4)

    def forward(self, outputs, targets):
        """Return ``dice_weight * dice + focal_weight * focal`` for the batch."""
        weighted_dice = self.dice_weight * self.dice_loss(outputs, targets)
        weighted_focal = self.focal_weight * self.focal_loss(outputs, targets)
        return weighted_dice + weighted_focal

In [None]:
# import sys
# sys.path.append("C:/Users/Viktor/Documents/IT/ReservoirRockAnalysis/src/NeuralNetwork/")
from BlissLearn import BlissLearner
# from BlissLearn import SegmentationMetricsCallback

In [None]:
class Unet(torch.nn.Module):
    """Thin wrapper around ``smp.Unet`` that exposes only its first output.

    The wrapped network is built from the notebook-level configuration
    (``encoder``, ``dataset``, ``N_CHANNELS``, ``N_CLASSES``, ``aux_params``)
    and moved to ``device``.
    """

    def __init__(self):
        super().__init__()
        network = smp.Unet(
            encoder_name=encoder,
            encoder_weights=dataset,
            in_channels=N_CHANNELS,
            classes=N_CLASSES,
            aux_params=aux_params,
        )
        self.model = network.to(device)

    def forward(self, x):
        """Run the wrapped model and return element 0 of its output.

        Presumably the network returns a (masks, labels) pair because
        ``aux_params`` is set, making element 0 the segmentation logits.
        """
        outputs = self.model(x)
        return outputs[0]

model = Unet()

# Disabled alternative backbone:
#   from models import HrSegNetB64
#   model = HrSegNetB64(num_classes=N_CLASSES, in_channels=N_CHANNELS).to(device)

# Disabled alternative criteria:
#   loss_fn = torch.nn.functional.cross_entropy
#   loss_fn = DiceLoss('multiclass', from_logits=True)  # with alpha=0.7, beta=0.3 also disabled
loss_fn = CombinedLoss()

# BlissLearner receives, positionally: the model, the criterion, the optimiser
# class, the optimiser keyword arguments and the two data loaders.
learner = BlissLearner(
    model,
    loss_fn,
    torch.optim.SGD,
    {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 0.0005},
    train_loader,
    valid_loader,
)


In [None]:
# Short training run; the argument is presumably the number of epochs --
# confirm against BlissLearn.
learner.fit(5)

In [None]:
# Free memory between experiments: run Python's garbage collector first, then
# ask PyTorch to release its cached CUDA memory.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# learner.plot_lr_finding(init_value=1e-9, final_value=10)

In [None]:
# (class id, RGB colour) pairs used to render label masks. Any id that is not
# listed here is drawn with _UNKNOWN_COLOR.
_CLASS_COLORS = (
    (0, (0, 0, 0)),
    (1, (0, 255, 0)),
    (2, (255, 0, 255)),
    (3, (255, 255, 0)),
    (4, (255, 0, 0)),
    (5, (0, 255, 255)),
)
_UNKNOWN_COLOR = (255, 255, 255)


def get_color(clls):
    """Return the ``[R, G, B]`` colour for class id ``clls``.

    Ids 0-5 have dedicated colours; every other value maps to white.
    A new list is returned on every call.
    """
    for class_id, rgb in _CLASS_COLORS:
        if clls == class_id:
            return list(rgb)
    return list(_UNKNOWN_COLOR)


def get_image_mask(mask):
    """Convert an array of class ids into an RGB image.

    Parameters
    ----------
    mask : array-like
        Class ids, any shape (typically ``(H, W)``).

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``mask.shape + (3,)`` holding the colour that
        ``get_color`` assigns to each element.
    """
    labels = np.asarray(mask)
    # Palette rows follow _CLASS_COLORS; the extra last row is the fallback.
    palette = np.array([rgb for _, rgb in _CLASS_COLORS] + [_UNKNOWN_COLOR])
    # Start with every pixel on the fallback row, then overwrite known classes.
    # This replaces one Python-level get_color call per pixel with a handful of
    # vectorised comparisons.
    palette_row = np.full(labels.shape, len(_CLASS_COLORS), dtype=np.intp)
    for row, (class_id, _) in enumerate(_CLASS_COLORS):
        palette_row[labels == class_id] = row
    return palette[palette_row]

In [None]:
# Main training run; 300 is presumably the number of epochs -- confirm against
# BlissLearn.
learner.train_model(300)

In [None]:
# Plot the training information collected by the learner.
learner.plot_learning_info()

In [None]:
# model.load_state_dict(torch.load('/content/drive/MyDrive/pore segmentation/manet.pkl', map_location=device))

In [None]:
# Run the learner's validation routine; the argument is presumably the number
# of validation passes -- confirm against BlissLearn.
learner.validate_n_epochs(5)

In [None]:
# Fetch a validation batch explicitly: xb / yb were never defined anywhere in
# the notebook, so this cell failed on a fresh kernel.
xb, yb = next(iter(valid_loader))

# Inference only: switch off dropout / batch-norm updates and gradient
# tracking, then restore whatever mode the model was in.
was_training = model.training
model.eval()
with torch.no_grad():
    # Unet.forward already returns only the mask logits, so the output keeps
    # its batch axis; indexing it with [0] here would drop that axis and make
    # argmax run over the wrong dimension.
    preds = model(xb.to(device).float())
model.train(was_training)

idx = 0
img = xb[idx].permute(1, 2, 0).cpu().numpy()
mask = yb[idx].cpu().numpy()
# argmax over the class axis gives one class id per pixel.
pred_mask = preds.argmax(dim=1)[idx].cpu().numpy()

# Input image, ground-truth mask and predicted mask side by side.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(img);
axs[1].imshow(get_image_mask(mask));
axs[2].imshow(get_image_mask(pred_mask));

In [None]:
# One row per validation sample: input image, ground-truth mask, predicted mask.
N = len(valid_datasets)
# squeeze=False keeps axs two-dimensional even when N == 1; the figure height
# scales with the number of rows instead of being hard-coded.
fig, axs = plt.subplots(N, 3, figsize=(15, 5 * N), squeeze=False)

# Inference only: switch off dropout / batch-norm updates and gradient
# tracking, then restore whatever mode the model was in.
was_training = model.training
model.eval()
with torch.no_grad():
    # Index explicitly over range(N) so iteration never runs past the last
    # sample (the original needed a manual `break` for that).
    for i in range(N):
        x, y = valid_datasets[i]
        # Unet.forward already returns only the mask logits; keep the batch
        # axis so argmax runs over the class axis.
        preds = model(torch.unsqueeze(x.to(device).float(), 0))
        img = x.permute(1, 2, 0).cpu().numpy()
        mask = y.cpu().numpy()
        pred_mask = preds.argmax(dim=1)[0].cpu().numpy()

        axs[i, 0].imshow(img);
        axs[i, 0].axis(False)
        axs[i, 1].imshow(get_image_mask(mask));
        axs[i, 1].axis(False)
        axs[i, 2].imshow(get_image_mask(pred_mask));
        axs[i, 2].axis(False)
model.train(was_training)