In [None]:
import os
import kornia as K
import rasterio
import torch
import numpy as np

def toroidal_shift(image_tensor, shift):
    """
    Applies a toroidal (wrap-around) shift to an image tensor.

    Pixels pushed past one border re-enter on the opposite border, so no data is lost.
    The output pixel at (y, x) takes the value of the input pixel at
    ((y + dy) % H, (x + dx) % W), i.e. positive shifts move the content up/left.

    Parameters:
    - image_tensor (torch.Tensor): The image tensor of shape (C, H, W). Any tensor of shape
      (..., H, W) is accepted; the shift is always applied to the last two dimensions.
    - shift (tuple): The shift/movement as (dy, dx), where dy is the vertical shift and dx is
      the horizontal shift (integers, may be negative or larger than the image size).

    Returns:
    - torch.Tensor: The shifted image tensor with the same shape, dtype and device as image_tensor.

    Raises:
    - ValueError: if image_tensor has fewer than two dimensions.
    """
    if image_tensor.dim() < 2:
        raise ValueError(
            f'toroidal_shift expects a tensor with at least 2 dims (..., H, W), got shape {tuple(image_tensor.shape)}'
        )
    dy, dx = shift

    # torch.roll moves element i to position i + s (mod size); rolling by (-dy, -dx)
    # therefore yields out[y, x] = in[(y + dy) % H, (x + dx) % W].
    return torch.roll(image_tensor, shifts=(-int(dy), -int(dx)), dims=(-2, -1))

def _select_tif_files(folder, num_files):
    """Return up to `num_files` '.tif' filenames from `folder`, sampled without replacement."""
    # Sort so the sample is reproducible under a fixed numpy seed (os.listdir order is arbitrary).
    tif_files = sorted(name for name in os.listdir(folder) if name.endswith('.tif'))
    num_files = min(num_files, len(tif_files))
    idxs = np.random.choice(len(tif_files), num_files, replace=False)
    return [tif_files[idx] for idx in idxs]


def _to_raster_dtype(array, dtype):
    """Cast a float array to the raster's on-disk dtype, rounding and clipping for integer types."""
    target = np.dtype(dtype)
    if np.issubdtype(target, np.integer):
        # Rotation interpolates, producing non-integral (and possibly out-of-range) values;
        # round and clip so the cast neither truncates nor wraps around.
        info = np.iinfo(target)
        array = np.clip(np.rint(array), info.min, info.max)
    return array.astype(target)


def augment_and_save_images(input_folder, output_folder, num_augmented_images=3, num_sel_imgs=50, max_shift=0):
    """
    Randomly select '.tif' images from a folder, augment them and write the results to disk.

    Every augmented variant starts from the ORIGINAL image: an optional random toroidal shift
    (see toroidal_shift), then random horizontal/vertical flips (p=0.5 each) and a random
    rotation of up to 40 degrees (Kornia). Variants are written as 'augmented_<i>_<filename>'
    in output_folder, reusing the source file's rasterio metadata (including its dtype).

    If output_folder is the same as input_folder, previously written 'augmented_*' files are
    themselves candidates for selection on later runs.

    Parameters:
    - input_folder (str): Folder with the source images; only names ending in '.tif' are used.
    - output_folder (str): Destination folder; created if it does not exist.
    - num_augmented_images (int): Number of augmented variants written per selected image.
    - num_sel_imgs (int): Number of '.tif' files sampled without replacement; if the folder
      holds fewer, all of them are used.
    - max_shift (int): Maximum absolute toroidal shift in pixels along each axis; 0 disables it.
    """
    # Ensure the output directory exists
    os.makedirs(output_folder, exist_ok=True)

    # Define a series of Kornia augmentations for data augmentation
    augmentations = torch.nn.Sequential(
        K.augmentation.RandomHorizontalFlip(p=0.5),
        K.augmentation.RandomVerticalFlip(p=0.5),
        K.augmentation.RandomRotation(degrees=40.0),
        # Add more transformations as needed, ensuring they are suitable for multi-channel images
    )

    for filename in _select_tif_files(input_folder, num_sel_imgs):
        file_path = os.path.join(input_folder, filename)

        # Read all bands as (channels, height, width) and keep the metadata for writing.
        with rasterio.open(file_path) as src:
            image = src.read()
            meta = src.meta.copy()

        # Convert in numpy first: torch.from_numpy rejects uint16 arrays on older torch versions.
        image_tensor = torch.from_numpy(image.astype(np.float32))

        for i in range(num_augmented_images):
            # Start from the original each time so shifts do not accumulate across variants.
            source_tensor = image_tensor
            if max_shift > 0:
                dy = torch.randint(-max_shift, max_shift + 1, (1,)).item()
                dx = torch.randint(-max_shift, max_shift + 1, (1,)).item()
                source_tensor = toroidal_shift(image_tensor, (dy, dx))

            # Kornia expects a batch dimension; remove only that one afterwards so
            # single-band (or 1-pixel-wide) images keep their (C, H, W) shape.
            augmented_tensor = augmentations(source_tensor.unsqueeze(0))
            augmented_image = _to_raster_dtype(augmented_tensor.squeeze(0).numpy(), meta['dtype'])

            # Save the augmented image using the metadata of the original image
            output_path = os.path.join(output_folder, f'augmented_{i}_{filename}')
            with rasterio.open(output_path, 'w', **meta) as dst:
                dst.write(augmented_image)

# Seed numpy (image selection) and torch (shifts + Kornia augmentation parameters) so the
# generated dataset is reproducible across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DS_DIR = '/home/roberto/PythonProjects/S2RAWVessel/mmdetection/data/Venus/classification/ds'
# NOTE: input and output are the same folder, so augmented files are written next to the
# originals and can be re-selected as inputs if this cell is run again.
input_folder = os.path.join(DS_DIR, 'Container Ship')  # Replace with your input folder path
output_folder = os.path.join(DS_DIR, 'Container Ship')  # Replace with your output folder path
augment_and_save_images(input_folder, output_folder, num_augmented_images=2, num_sel_imgs=150)


In [None]:
def random_eraser(folder_path, num_erase=500):
    """
    Permanently delete `num_erase` randomly chosen files from a folder.

    Only regular files are candidates; sub-directories are skipped because os.remove
    cannot delete them. Files are sampled without replacement.

    Parameters:
    - folder_path (str): Folder to thin out.
    - num_erase (int): Number of files to delete.

    Raises:
    - ValueError: if num_erase is negative or exceeds the number of files in the folder.
      The check happens before any deletion, so the folder is left untouched.
    """
    # Sorted so the choice is reproducible under a fixed numpy seed.
    candidates = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    if not 0 <= num_erase <= len(candidates):
        raise ValueError(
            f'Cannot erase {num_erase} files: {folder_path!r} contains {len(candidates)} files'
        )
    for idx in np.random.choice(len(candidates), num_erase, replace=False):
        os.remove(os.path.join(folder_path, candidates[idx]))
        
        
# Destructive: permanently removes 300 randomly chosen files from `input_folder` (set in the augmentation cell).
random_eraser(folder_path=input_folder, num_erase=300)
    
    

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

def list_imgs(path, suffix=None):
    """
    List the entries of a folder in a deterministic (sorted) order.

    Parameters:
    - path (str or Path): Folder to list.
    - suffix (str, optional): If given (e.g. '.tif'), keep only entries whose suffix matches exactly.

    Returns:
    - list[Path]: The sorted entry paths.
    """
    # Path.iterdir() yields entries in arbitrary filesystem order; sort so "the first N images" is stable.
    entries = sorted(Path(path).iterdir())
    if suffix is not None:
        entries = [entry for entry in entries if entry.suffix == suffix]
    return entries

def read_img(img_path):
    """
    Read a raster file and return all of its bands as a NumPy array.

    Parameters:
    - img_path (str or Path): Path of the raster file to open with rasterio.

    Returns:
    - numpy.ndarray: The data returned by rasterio's read() with no band index,
      i.e. every band of the file.
    """
    with rasterio.open(img_path) as dataset:
        return dataset.read()

def display_img(img, max_bands=12):
    """
    Plot the leading bands of a multi-band image side by side, one panel per band.

    Parameters:
    - img (numpy.ndarray): Image indexed as img[band, row, col], such as the array returned by read_img.
    - max_bands (int): Maximum number of leading bands to plot (default 12, the previous fixed count).

    Raises:
    - ValueError: if there is no band to plot.
    """
    # Previously exactly 12 bands were assumed, which raised IndexError on images with fewer bands.
    num_bands = min(img.shape[0], max_bands)
    if num_bands < 1:
        raise ValueError('display_img needs at least one band to plot')

    # squeeze=False keeps `axes` 2-D even when only one panel is drawn.
    fig, axes = plt.subplots(1, num_bands, figsize=(15, 15), squeeze=False)
    for band_idx, ax in enumerate(axes[0]):
        ax.imshow(img[band_idx, :, :])
        ax.axis('off')
        ax.set_title('Band {}'.format(band_idx + 1))
    plt.show()
    
imgs = list_imgs('/home/roberto/PythonProjects/S2RAWVessel/mmdetection/data/Venus/classification/ds/Bulk Carrier Augmented')
print('Found {} images'.format(len(imgs)))

# Preview at most three images; slicing avoids an IndexError when the folder holds fewer.
for img_path in imgs[:3]:
    im_data = read_img(img_path)
    print(f'Displaying img: {img_path.name}')
    display_img(im_data)
