# Self-supervised learning augmentations

In [None]:
import albumentations as A
from utils.imageIO import read_img
import pandas as pd
import cv2 # todo: move save to utils
import os
from tqdm import tqdm

In [None]:
# Side length (in pixels) of the square output of the initial crop/resize step.
TARGET_SIZE = 2048

# Geometric augmentations. apply_transforms runs this pipeline first, and its
# output feeds both the saved RGB image and the grayscale branch.
all_images_pipeline = A.Compose(
    [
        # Random crop of most of the frame, resized to TARGET_SIZE x TARGET_SIZE.
        A.RandomResizedCrop(
            size=(TARGET_SIZE, TARGET_SIZE),
            scale=(0.9, 1),
            ratio=(0.95, 1.05),
            p=1.0,
        ),

        # Basic geometric symmetries (D4 transform).
        A.D4(p=0.5),

        # Mild affine jitter: small rotation, translation and shear.
        A.Affine(
            rotate=(-15, 15),
            translate_percent=(-0.05, 0.05),
            shear=(-5, 5),
            p=0.6,
        ),

        # Elastic distortion (keep only if it is relevant to the domain).
        A.ElasticTransform(
            alpha=0.3,
            sigma=30,
            p=0.5,
        ),
    ]
)

# Photometric / degradation augmentations. apply_transforms runs this pipeline
# on the already geometrically augmented RGB image; the result is then turned
# into the grayscale image. Each OneOf group picks at most one of its members.
gray_images_pipeline = A.Compose(
    [
        # Dropout / occlusion (lower probability of being applied).
        A.OneOf(
            [
                A.CoarseDropout(
                    num_holes_range=(1, 4),
                    hole_height_range=(0.05, 0.15),
                    hole_width_range=(0.05, 0.15),
                    fill_value=0,
                    p=0.8,
                ),
                A.GridDropout(
                    ratio=0.3,
                    unit_size_range=(10, 20),
                    p=0.4,
                ),
            ],
            p=0.4,
        ),

        # Grayscale conversion is intentionally not part of this pipeline
        # (no A.ToGray here); it is applied later, after the pipeline has run.

        # Colour augmentations: brightness, contrast, saturation, hue, gamma.
        A.OneOf(
            [
                A.RandomBrightnessContrast(
                    brightness_limit=0.1,
                    contrast_limit=0.1,
                    p=0.7,
                ),
                A.ColorJitter(
                    brightness=0.1,
                    contrast=0.1,
                    saturation=0.1,
                    hue=0.05,
                    p=0.7,
                ),
                A.HueSaturationValue(
                    hue_shift_limit=10,
                    sat_shift_limit=20,
                    val_shift_limit=10,
                    p=0.7,
                ),
                A.RandomGamma(
                    gamma_limit=(90, 110),
                    p=0.7,
                ),
            ],
            p=0.6,
        ),

        # Blur.
        A.OneOf(
            [
                A.GaussianBlur(blur_limit=(1, 3), p=0.4),
                A.MedianBlur(blur_limit=3, p=0.4),
            ],
            p=0.3,
        ),

        # Noise.
        A.OneOf(
            [
                A.GaussNoise(
                    std_range=(0.02, 0.05),
                    p=0.3,
                ),
                A.MultiplicativeNoise(
                    multiplier=(0.95, 1.05),
                    per_channel=True,
                    p=0.3,
                ),
                A.SaltAndPepper(p=0.3),
            ],
            p=0.3,
        ),

        # Compression / downscaling artifacts (lower probability of being applied).
        A.OneOf(
            [
                A.ImageCompression(quality_range=(5, 15), p=0.3),
                A.Downscale(scale_range=(0.4, 0.6), p=0.4),
            ],
            p=0.1,
        ),
    ]
)

In [None]:
# All inputs and outputs of this notebook live in one train/test data folder.
_train_test_dir = r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\train-test'

# Spreadsheets listing the balanced train / test samples.
path_to_train_data = _train_test_dir + r'\balanced-train-data.xlsx'
path_to_test_data = _train_test_dir + r'\balanced-test-data.xlsx'

# Folders that receive the generated self-supervised-learning images.
path_to_save_train = _train_test_dir + r'\ssl-train'
path_to_save_test = _train_test_dir + r'\ssl-test'

In [None]:
def apply_transforms(image_path, path_to_save, image_number=0):
    """Augment one image and save an RGB and a grayscale version of it.

    The image is first passed through ``all_images_pipeline``; that result is
    what gets saved as the RGB file. The same result is then passed through
    ``gray_images_pipeline`` and converted to a single-channel grayscale
    image, which is saved as the second file.

    Args:
        image_path: Path of the source image, loaded with ``read_img``.
        path_to_save: Directory that receives both output files. It is
            created if it does not exist yet.
        image_number: Number embedded in both output file names (between the
            prefix and the original file name), so that repeated calls for
            the same source image can produce distinct files.

    Returns:
        Tuple ``(output_image_path, output_gray_path)`` with the paths of the
        saved RGB image and the saved grayscale image.

    Raises:
        OSError: If OpenCV reports that one of the files could not be written.
    """
    # NOTE(review): the channel handling below assumes read_img returns an
    # RGB array (RGB2GRAY conversion, channel reversal before imwrite) -
    # confirm against utils.imageIO.
    image = read_img(image_path)
    augmented_rgb = all_images_pipeline(image=image)['image']
    rgb_after_gray = gray_images_pipeline(image=augmented_rgb)['image']
    augmented_gray = cv2.cvtColor(rgb_after_gray, cv2.COLOR_RGB2GRAY)

    name = os.path.basename(image_path)
    # os.path.join instead of a hard-coded "\\" keeps the paths valid on any OS.
    output_image_path = os.path.join(path_to_save, f'image{image_number}{name}')
    output_gray_path = os.path.join(path_to_save, f'gray_image{image_number}{name}')

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(path_to_save, exist_ok=True)

    # Reverse the channel axis (RGB -> BGR) because cv2.imwrite expects BGR.
    if not cv2.imwrite(output_image_path, augmented_rgb[:, :, ::-1]):
        raise OSError(f'Failed to write image: {output_image_path}')
    if not cv2.imwrite(output_gray_path, augmented_gray):
        raise OSError(f'Failed to write image: {output_gray_path}')

    return output_image_path, output_gray_path


In [None]:
def save_test_data(image_path, path_to_save):
    """Save an unaugmented image together with its grayscale version.

    Args:
        image_path: Path of the source image, loaded with ``read_img``.
        path_to_save: Directory that receives both output files. It is
            created if it does not exist yet.

    Returns:
        Tuple ``(output_image_path, output_gray_path)`` with the paths of the
        saved RGB image and the saved grayscale image.

    Raises:
        OSError: If OpenCV reports that one of the files could not be written.
    """
    # NOTE(review): the channel handling below assumes read_img returns an
    # RGB array (RGB2GRAY conversion, channel reversal before imwrite) -
    # confirm against utils.imageIO.
    image = read_img(image_path)
    gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    name = os.path.basename(image_path)
    # os.path.join instead of a hard-coded "\\" keeps the paths valid on any OS.
    output_image_path = os.path.join(path_to_save, 'image' + name)
    output_gray_path = os.path.join(path_to_save, 'gray_image' + name)

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(path_to_save, exist_ok=True)

    # Reverse the channel axis (RGB -> BGR) because cv2.imwrite expects BGR.
    if not cv2.imwrite(output_image_path, image[:, :, ::-1]):
        raise OSError(f'Failed to write image: {output_image_path}')
    if not cv2.imwrite(output_gray_path, gray_image):
        raise OSError(f'Failed to write image: {output_gray_path}')

    return output_image_path, output_gray_path

In [None]:
# Load the training-set spreadsheet into a DataFrame.
train_data = pd.read_excel(path_to_train_data)
# Bare expression: shows the table as the notebook cell output.
train_data

In [None]:
# Load the test-set spreadsheet into a DataFrame; its 'name' column is
# iterated further down to locate the images to export.
test_data = pd.read_excel(path_to_test_data)
# Bare expression: shows the table as the notebook cell output.
test_data

In [None]:
# augmented_train = list()

# for num in range(20):
#     for mask_path in tqdm(train_data['name'], total=len(train_data)):
#         image_path = mask_path.replace('masks', 'images')
#         image_path, gray_image_path = apply_transforms(image_path, path_to_save_train, num)
#         augmented_train.append((image_path, gray_image_path))

In [None]:
# augmented_train_df = pd.DataFrame(augmented_train)
# augmented_train_df

In [None]:
# augmented_train_df.to_excel(r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\train-test\ssl-train-data.xlsx')

In [None]:
# Collected (rgb_path, gray_path) pairs, one per exported test image.
test = []

# The 'name' column holds mask paths; the matching image path is obtained by
# swapping 'masks' for 'images' in it.
for mask_path in tqdm(test_data['name'], total=len(test_data)):
    image_path, gray_image_path = save_test_data(
        mask_path.replace('masks', 'images'),
        path_to_save_test,
    )
    test.append((image_path, gray_image_path))

In [None]:
# Turn the list of (rgb_path, gray_path) tuples into a two-column DataFrame.
test_df = pd.DataFrame(test)
# Bare expression: shows the table as the notebook cell output.
test_df

In [None]:
# Persist the table of exported test image paths as an Excel workbook.
test_df.to_excel(r'C:\Users\Viktor\Documents\IT\ReservoirRockAnalysis\data\train-test\ssl-test-data.xlsx')