In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import os
from tqdm.notebook import tqdm
from dgl.geometry import farthest_point_sampler

In [None]:
# Directory the input .pth files are read from.
input_path = 'data/val'
# Root directory under which each manipulated copy of the data is written.
output_path = 'data/manipulated_data'
# os.listdir yields entries in arbitrary order; sorting makes the processing
# order (and therefore the consumption of the global RNG) reproducible.
path_list = sorted(f'{input_path}/{f}' for f in os.listdir(input_path) if f.endswith('.pth'))

In [None]:
def remove_color(data, _=None):
    """Replace the second array of a sample with zeros.

    Args:
        data: Mutable sequence of arrays. ``data[1]`` is overwritten; going by
            the function name it is presumably the per-point colour array.
        _: Ignored. It exists so the function can be invoked as
            ``f(data, value)`` like the other manipulations.

    Returns:
        The same ``data`` object, modified in place.
    """
    # zeros_like keeps both shape and dtype. np.zeros(shape) would always
    # produce float64 and silently change the dtype of the stored array.
    data[1] = np.zeros_like(data[1])
    return data

def sample_farthest(data, p):
    """Subsample every array of a sample using farthest point sampling.

    Args:
        data: Sequence of arrays that share their first axis; ``data[0]`` is
            the array handed to ``farthest_point_sampler``.
        p: Fraction of entries to keep; ``int(len(data[0]) * p)`` indices are
            requested from the sampler.

    Returns:
        A new list with each array of ``data`` indexed by the sampled indices.
    """
    target_count = int(len(data[0]) * p)
    # The sampler is fed a batch of one: wrap data[0] in a leading axis and
    # take the first row of the result afterwards.
    batched = torch.from_numpy(np.array([data[0]]))
    keep = farthest_point_sampler(batched, target_count)[0]

    subsampled = []
    for channel in data:
        subsampled.append(channel[keep])
    return subsampled

def sample_random(data, p):
    """Keep each entry independently with probability ``p``.

    One uniform draw in [0, 1) is made per entry of ``data[0]``; entries whose
    draw is below ``p`` are kept. The number of survivors is therefore only
    approximately ``p * len(data[0])``.

    Args:
        data: Sequence of arrays that share their first axis.
        p: Keep probability.

    Returns:
        A new list with the same boolean mask applied to every array.
    """
    n_points = len(data[0])
    draws = np.random.rand(n_points)
    keep = draws < p

    selected = []
    for channel in data:
        selected.append(channel[keep])
    return selected

def add_noise_to_cords(data, s):
    """Perturb the first array of a sample with Gaussian noise.

    Each value of ``data[0]`` is replaced by a draw from a normal distribution
    centred on that value with standard deviation ``s``.

    Args:
        data: Mutable sequence of arrays. ``data[0]`` is overwritten; going by
            the function name it is presumably the point coordinates.
        s: Standard deviation of the noise. ``0`` leaves the values unchanged.

    Returns:
        The same ``data`` object, modified in place.
    """
    coords = np.asarray(data[0])
    noisy = np.random.normal(coords, s)
    # np.random.normal always returns float64. Cast back so floating-point
    # inputs (e.g. float32) keep their dtype; non-float inputs stay float64
    # so the noise is not truncated away.
    if np.issubdtype(coords.dtype, np.floating):
        noisy = noisy.astype(coords.dtype, copy=False)
    data[0] = noisy
    return data

def manipulate_data(f, name: str, values:list=None):
    """Apply a manipulation to every input file and save the results.

    For each entry of ``values`` a directory
    ``{output_path}/{name}_{value}/val`` is created, and every file in the
    module-level ``path_list`` is loaded, passed through ``f`` and written
    there under its original file name.

    Args:
        f: Callable invoked as ``f(sample_as_list, value)``; it must return an
            iterable, which is stored as a tuple.
        name: Prefix of the output directory name.
        values: Parameter values to sweep over. When ``None`` or empty, ``f``
            is run once with ``''`` as the value.
    """
    if not values:
        values = ['']

    for value in tqdm(values):
        out_path = f'{output_path}/{name}_{value}/val'
        os.makedirs(out_path, exist_ok=True)

        for file_path in tqdm(path_list, leave=False):
            # Last path component; entries of path_list are joined with '/'.
            file_name = file_path.rsplit('/', 1)[-1]
            # NOTE(review): torch.load unpickles the file - only use on trusted data.
            loaded = torch.load(file_path)
            # Hand f a list so it may assign to individual positions.
            transformed = f(list(loaded), value)
            torch.save(tuple(transformed), f'{out_path}/{file_name}')
            

In [None]:
# Colour removal takes no parameter, so no value list is passed.
manipulate_data(remove_color, name='no_color')

In [None]:
# Random subsampling, one output set per keep probability.
manipulate_data(
    sample_random,
    name='sample_random',
    values=[1.0, 0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3],
)

In [None]:
# Farthest point sampling, keeping a 0.6 fraction of the entries.
manipulate_data(sample_farthest, name='sample_farthest', values=[0.6])

In [None]:
# Gaussian noise on data[0], one output set per standard deviation
# (0 applies no perturbation to the values).
manipulate_data(
    add_noise_to_cords,
    name='noisy',
    values=[0, 0.1, 0.2, 0.5, 0.8, 1.0],
)