In [None]:
import gc
import json
import os

import albumentations as A
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
from albumentations.pytorch import ToTensorV2
from matplotlib import pyplot as plt
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss
from torch.utils.data import Dataset, DataLoader

from BlissLearn import BlissLearner
from BlissLearn.BlissCallbacks.Callbacks import SegmentationMetricsCallback, PrintCriteriaCallback
from utils import read_img
from utils import calculate_iou

In [None]:
def load_json(file_path):
    """Read a UTF-8 encoded JSON file and return the decoded Python object."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)

In [None]:
colors_file = r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\src\metadata\rgb_colors.json"
porosty_file = r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\src\metadata\porosty_info.json"

colors = load_json(colors_file)
porosty = load_json(porosty_file)

In [None]:
NUM_CLASSES = len(porosty[0]['classes'])
NUM_CLASSES

In [None]:
BATCH_SIZE = 2 
SIZE = 1024


path_to_train = r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\train-test\segmentation-train-data.xlsx'
path_to_test = r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\train-test\segmentation-test-data.xlsx'
path_to_models = r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\resources'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

In [None]:
train_df = pd.read_excel(path_to_train, index_col=0)
train_df['n'] = train_df[[f'class{i}' for i in range(NUM_CLASSES)]].sum(1)
train_df = train_df.sort_values(by='n') 
train_df = train_df.iloc[-240:].copy()
train_df.index = np.arange(len(train_df))
train_df

In [None]:
test_df = pd.read_excel(path_to_test, index_col=0)
test_df

In [None]:
# RGB colour per class id; the last row is the fallback for ids above 5.
colors = np.array([
    [0, 0, 0],        # class 0
    [0, 255, 0],      # class 1
    [255, 0, 255],    # class 2
    [255, 255, 0],    # class 3
    [255, 0, 0],      # class 4
    [0, 255, 255],    # class 5
    [255, 255, 255]   # fallback for unknown class
])

def get_image_mask(mask, palette=None):
    """Convert a class-id mask into an RGB image for display.

    Args:
        mask: integer array of class ids (e.g. shape (H, W)). Negative ids are
            clipped to 0, and ids beyond the palette are clipped to its last
            (fallback) row.
        palette: optional sequence of RGB rows indexed by class id. Defaults to
            the module-level ``colors``.

    Returns:
        Array of shape ``mask.shape + (3,)`` holding each pixel's palette colour.
    """
    palette = colors if palette is None else np.asarray(palette)
    # Clip class ids into the palette range (the mask may contain ids 6+).
    mask_clipped = np.clip(np.asarray(mask), 0, len(palette) - 1)
    # Vectorised lookup: fancy indexing maps every pixel to its colour at once.
    return palette[mask_clipped]

In [None]:
def transform_train(SIZE):
    """Build the training pipeline: convert the image/mask pair to torch tensors.

    No cropping or augmentation is applied.

    NOTE(review): ``SIZE`` is accepted for symmetry with ``transform_valid`` but
    is unused here, so training samples keep their original resolution. If
    training images differ in size, ``torch.stack`` in ``collate_fn`` will fail.
    """
    pipeline = [ToTensorV2(p=1)]
    return A.Compose(transforms=pipeline)

def transform_valid(SIZE):
    """Build the deterministic validation pipeline.

    Takes the central ``SIZE x SIZE`` crop of the image and mask, then converts
    both to torch tensors. A centre crop (instead of a random crop) keeps the
    validation samples identical across epochs and notebook re-runs, so the
    eval loss and metrics stay comparable from epoch to epoch.

    Args:
        SIZE: side length in pixels of the square crop. As with the previous
            random crop, images are expected to be at least this large.
    """
    return A.Compose(
        transforms=[
            A.CenterCrop(SIZE, SIZE, p=1),
            ToTensorV2(p=1),
        ]
    )

In [None]:
class SegmentationDataset(Dataset):
    """Map-style dataset of (image, mask) pairs listed in a DataFrame.

    Args:
        df: DataFrame with ``image_path`` and ``mask_path`` columns, one row per sample.
        transforms: albumentations-style callable that takes ``image=`` and ``mask=``
            keyword arguments and returns a dict with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, df, transforms):
        super().__init__()
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load, transform and return sample ``index`` as ``(image, mask)``.

        Rows are selected by position (``iloc``), so the DataFrame does not need
        a 0..n-1 index. An out-of-range ``index`` raises ``IndexError``, which also
        lets plain ``for x, y in dataset`` iteration stop cleanly.
        """
        row = self.df.iloc[index]
        image = read_img(row['image_path'])
        mask = read_img(row['mask_path'], rgb=False)

        transformed = self.transforms(image=image, mask=mask)
        # Divide by 255: presumably maps uint8 pixels in 0..255 to floats in 0..1.
        image = transformed['image'] / 255
        mask = transformed['mask']

        return image, mask

In [None]:
train_datasets = SegmentationDataset(train_df, transforms=transform_train(SIZE))
valid_datasets = SegmentationDataset(test_df, transforms=transform_valid(SIZE))

In [None]:
x, y = train_datasets[0]
x.shape, y.shape

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(x.permute(1,2,0).cpu().numpy());
axs[1].imshow(get_image_mask(y.cpu().numpy()));

In [None]:
def collate_fn(batch):
    """Stack a list of (image, mask) samples into batched tensors on ``device``.

    Images are cast to float and masks to long (integer class ids), and both are
    moved to the notebook-level ``device``.
    NOTE(review): moving to CUDA inside collate_fn only works with num_workers=0.
    """
    images, masks = zip(*batch)
    batch_images = torch.stack(images).to(device=device, dtype=torch.float)
    batch_masks = torch.stack(masks).to(device=device, dtype=torch.long)
    return batch_images, batch_masks

In [None]:
train_loader = DataLoader(
    train_datasets,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

valid_loader = DataLoader(
    valid_datasets,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

In [None]:
xb, yb = next(iter(train_loader))
xb.shape, yb.shape

In [None]:
# colorization_model = smp.Unet(
#         "mit_b3",
#         activation='sigmoid',
#         in_channels=1,
#         classes=3
#     ).to(device)

In [None]:
# colorization_model.load_state_dict(torch.load(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\resources\ColorUnetMIT.pkl", weights_only=True))

In [None]:
model = smp.Unet(
        "mit_b3",
        activation=None,
        in_channels=3,
        classes=NUM_CLASSES,
        decoder_attention_type='scse'
    ).to(device)

In [None]:
model.encoder

In [None]:
model.load_state_dict(torch.load(r"C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\resources\UnetMITSegmentationModel.pkl", weights_only=True))

In [None]:
def adapt_colorization_input(model, target_in_channels=3):
    """Swap ``model.encoder.patch_embed1.proj`` for a Conv2d with ``target_in_channels``
    inputs, reusing the existing kernel.

    The old kernel is summed over its input channels, and the result is spread
    evenly (divided by ``target_in_channels``) across the new channels. Feeding
    the new layer an image whose channels are all copies of one grayscale image
    therefore gives the same output as the old layer on that grayscale image
    replicated over its original channels. When the old layer had one input
    channel, this reduces to "repeat and divide".

    Args:
        model: network exposing ``model.encoder.patch_embed1.proj`` as a ``Conv2d``
            (e.g. an smp Unet with a MiT encoder). Modified in place.
        target_in_channels: number of input channels for the new layer.

    Returns:
        The same ``model`` object.
    """
    old_conv = model.encoder.patch_embed1.proj

    new_conv = torch.nn.Conv2d(
        in_channels=target_in_channels,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
        padding_mode=old_conv.padding_mode,
    ).to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)

    with torch.no_grad():
        # Collapse the old input channels (works for any old in_channels, not only 1),
        # then share the collapsed kernel evenly across the new channels.
        summed = old_conv.weight.sum(dim=1, keepdim=True)
        new_conv.weight.copy_(summed.repeat(1, target_in_channels, 1, 1) / target_in_channels)
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)

    model.encoder.patch_embed1.proj = new_conv
    return model

In [None]:
# colorization_model = adapt_colorization_input(colorization_model, target_in_channels=3)

In [None]:
# model.encoder.load_state_dict(colorization_model.encoder.state_dict())

In [None]:
# Freeze the whole encoder, then re-enable training for its first patch-embedding
# stage only; modules outside the encoder are left untouched.
# Trailing ';' suppresses the module repr returned by requires_grad_.
model.encoder.requires_grad_(False)
model.encoder.patch_embed1.requires_grad_(True);

In [None]:
loss_fn = DiceLoss('multiclass', from_logits=True)


class CombinedLoss(torch.nn.Module):
    """Weighted sum of a multiclass Dice loss and a multiclass Focal loss.

    ``loss = dice_weight * Dice(outputs, targets) + focal_weight * Focal(outputs, targets)``

    Args:
        dice_weight: weight of the Dice term.
        focal_weight: weight of the Focal term.
    """

    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.dice_loss = DiceLoss('multiclass', from_logits=True)
        self.focal_loss = FocalLoss('multiclass')

    def forward(self, outputs, targets):
        # outputs: raw model logits; targets: integer class-id masks (TODO confirm shapes).
        weighted_dice = self.dice_weight * self.dice_loss(outputs, targets)
        weighted_focal = self.focal_weight * self.focal_loss(outputs, targets)
        return weighted_dice + weighted_focal

def accuracy(yb, preds):
    """Fraction of elements whose arg-max class over dim 1 equals the target.

    Args:
        yb: integer target tensor.
        preds: scores/logits with the class dimension at dim 1.

    Returns:
        The accuracy as a Python float.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes.eq(yb)
    return matches.float().mean().item()

def iou(yb, preds):
    """IoU metric wrapper with the same ``(targets, predictions)`` order as ``accuracy``.

    Calls ``calculate_iou`` (the original code subscripted it with ``[...]``, which
    raises TypeError).
    NOTE(review): arguments are forwarded as ``(preds, yb)``, as the original
    intended; confirm this matches ``calculate_iou``'s expected order.
    """
    return calculate_iou(preds, yb)

# def loss(yb, preds):
#     loss = 0

#     for pred in preds:
#         loss += loss_fn(pred, yb)
#     return loss

# Pixel accuracy as a common metric; IoU passed as a class metric (presumably
# evaluated per class by the callback — confirm in BlissLearn).
# NOTE(review): `calculate_iou` is passed directly; the `iou` wrapper is not used here.
common_metrics = {'accuracy': accuracy}
class_metrics = {"iou": calculate_iou}
metrics_callback = SegmentationMetricsCallback(
    num_classes=NUM_CLASSES,
    common_metrics=common_metrics,
    class_metrics=class_metrics,
)

In [None]:
learner = BlissLearner.BlissLearner(
    model,
    CombinedLoss(),
    torch.optim.Adam,
    dict(lr=1e-3),
    train_loader,
    valid_loader,
    callbacks=[
        metrics_callback,
        PrintCriteriaCallback()
    ],
)

In [None]:
learner.fit(40)

In [None]:
plt.plot(learner._callback_state.epoch_train_loss['loss'])
plt.plot(learner._callback_state.epoch_eval_loss['loss'])
plt.grid()

In [None]:
plt.plot(learner._callback_state.epoch_train_criteria['iou_mean'])
plt.plot(learner._callback_state.epoch_eval_criteria['iou_mean'])
plt.grid()

In [None]:
plt.plot(learner._callback_state.epoch_train_criteria['accuracy'])
plt.plot(learner._callback_state.epoch_eval_criteria['accuracy'])
plt.grid()

In [None]:
# example_inputs = torch.randn(1, 3, SIZE, SIZE).to(device)
# onnx_program = torch.onnx.export(model, example_inputs, dynamo=True)
# onnx_program.save(path_to_models + r"\UnetMITSegmentationModel.onnx")

In [None]:
# torch.save(model.state_dict(), path_to_models + r"\UnetMITSegmentationModel.pkl")

In [None]:
learner.get_train_info()

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
model(xb).shape

In [None]:
N = len(valid_datasets)
# N = 10
fig, axs = plt.subplots(N, 3, figsize=(15, 120))
model.eval()

for i, (x, y) in enumerate(valid_datasets):
    with torch.inference_mode():
        inputs = torch.unsqueeze(x.to(device).float(), 0)
        preds = model(inputs)[0]


    print(preds.shape)
    img = x.permute(1, 2, 0).cpu().numpy()
    mask = y.cpu().numpy()
    pred_mask = preds.argmax(dim=0).detach().cpu().numpy()
    
    axs[i, 0].imshow(img);
    axs[i, 0].axis(False)
    axs[i, 1].imshow(get_image_mask(mask));
    axs[i, 1].axis(False)
    axs[i, 2].imshow(get_image_mask(pred_mask));
    axs[i, 2].axis(False)
    if i == N - 1:
        break

In [None]:
model