In [24]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from tqdm import tqdm


In [3]:
class SinteticDataset(torch.utils.data.Dataset):
    """Windows of consecutive time steps cut from ``.npy`` simulation files.

    Every ``.npy`` file in ``directory`` is treated as one block of
    ``BLOCK_SIZE`` time steps. Within a file, the first ``BLOCK_SIZE`` rows are
    read as ``psi1`` and the remaining rows as ``psi2``; each row is reshaped
    to a ``(Ny, Nx)`` grid.
    NOTE(review): this assumes each file holds ``2 * BLOCK_SIZE`` rows of
    ``Ny * Nx`` values -- confirm against whatever produces the files.

    A sample is a window of ``D`` time steps taken every ``skip`` steps that
    lies entirely inside one file.

    Args:
        directory: Folder that contains the ``.npy`` files.
        subwindow: Optional default crop, given as proportions
            ``[[lat_start, lat_end], [lon_start, lon_end]]`` of the grid.
        transform: Stored on the instance as ``self.transform``; it is not
            applied by ``__getitem__``.
        D: Number of time steps returned per sample.
        skip: Interval between the returned time steps (1 = every step,
            2 = every other step, ...).

    Raises:
        ValueError: If ``D`` or ``skip`` is smaller than 1.
    """

    # Number of time steps stored per field in each file.
    BLOCK_SIZE = 150

    def __init__(self, directory, subwindow=None,
                 transform=None, D=1, skip=1):
        if D < 1 or skip < 1:
            raise ValueError(f"D and skip must be >= 1, got D={D}, skip={skip}")

        self.directory = directory
        self.subwindow = subwindow  # Proportion of subwindow
        self.transform = transform
        # Sorted so that the index -> file mapping does not depend on the
        # arbitrary order returned by os.listdir.
        self.files = sorted(f for f in os.listdir(directory) if f.endswith('.npy'))
        self.D = D
        self.skip = skip
        self.Nx = (40*2)*2
        self.Ny = (68*2)*2
        self.valid_index = self.calcular_indices_validos()

    def __len__(self):
        """Return the number of windows that fit inside a single file."""
        return len(self.valid_index)

    def __getitem__(self, idx, subwindow=None):
        """Load one window of both fields.

        Args:
            idx: Position in the list of valid window starts.
            subwindow: Optional crop for this call; falls back to the crop
                given at construction time when ``None``.

        Returns:
            Tensor of shape ``[D, 2, lat, lon]`` where index 0 of the second
            axis is ``psi1`` and index 1 is ``psi2``.

        Raises:
            IndexError: If ``idx`` is outside the list of valid windows.
        """
        start = self.valid_index[idx]
        file_idx, offset = divmod(start, self.BLOCK_SIZE)
        data = np.load(os.path.join(self.directory, self.files[file_idx]))

        # Slice each field separately; slicing `data` itself for both fields
        # would return the psi1 rows twice.
        window = slice(offset, offset + self.D * self.skip, self.skip)
        psi1 = data[:self.BLOCK_SIZE][window]
        psi2 = data[self.BLOCK_SIZE:][window]

        if subwindow is None:
            subwindow = self.subwindow

        # 1-D ranges give the true axis lengths (Ny for latitude, Nx for
        # longitude) when the helper calls len() on them.
        lat_idx, lon_idx = self.get_indices_from_proportion(
            range(self.Ny), range(self.Nx), subwindow)

        psi1 = self._to_cropped_tensor(psi1, lat_idx, lon_idx)
        psi2 = self._to_cropped_tensor(psi2, lat_idx, lon_idx)

        # [D, psis, NY, NX]
        return torch.stack([psi1, psi2], dim=1)

    def _to_cropped_tensor(self, field, lat_idx, lon_idx):
        """Reshape flat rows to ``(steps, Ny, Nx)``, crop, and wrap in a tensor."""
        grid = field.reshape(field.shape[0], self.Ny, self.Nx)
        return torch.tensor(grid[:,
                                 lat_idx[0]:lat_idx[1],
                                 lon_idx[0]:lon_idx[1]])

    def get_indices_from_proportion(self, latitudes, longitudes, subwindow):
        """Convert a proportional crop into integer index ranges.

        Args:
            latitudes: Sized object; its length is the latitude axis size.
            longitudes: Sized object; its length is the longitude axis size.
            subwindow: ``[[lat_start, lat_end], [lon_start, lon_end]]`` as
                proportions of the axis length, or a falsy value for no crop.

        Returns:
            ``(lat_range, lon_range)``, each a two-element list ``[start, end]``.
        """
        n_lat = len(latitudes)
        n_lon = len(longitudes)
        if subwindow:
            lat_range = [int(subwindow[0][0] * n_lat), int(subwindow[0][1] * n_lat)]
            lon_range = [int(subwindow[1][0] * n_lon), int(subwindow[1][1] * n_lon)]
            return lat_range, lon_range
        return [0, n_lat], [0, n_lon]

    def calcular_indices_validos(self):
        """List every global start index whose window stays inside one file.

        A window starting at offset ``o`` of a block reaches offset
        ``o + (D - 1) * skip``, so each block contributes the offsets
        ``0 .. BLOCK_SIZE - (D - 1) * skip - 1`` (none when the window is
        longer than a block).

        Returns:
            Ascending list of valid global start indices.
        """
        block_size = self.BLOCK_SIZE
        span = (self.D - 1) * self.skip
        starts_per_block = max(block_size - span, 0)
        return [block * block_size + offset
                for block in range(len(self.files))
                for offset in range(starts_per_block)]

class ContiguousSinteticDatasetAutoregressive(SinteticDataset):
    """Variant of ``SinteticDataset`` that splits each sample into ``(x, y)``.

    ``x`` is the last entry along the first axis of the parent's tensor.
    ``y`` is a dict whose ``'y'`` entry holds all earlier entries, with every
    leading axis merged into one so that only the last two axes are kept.
    """

    def __getitem__(self, idx, subwindow=None):
        """Return ``(x, {'y': ...})`` built from the parent's sample at ``idx``."""
        sequence = super().__getitem__(idx, subwindow)
        height = sequence.size(-2)
        width = sequence.size(-1)
        earlier = sequence[:-1].reshape(-1, height, width)
        return sequence[-1], {'y': earlier}

In [49]:
def plotField2(ax, psi, Lx, Ly, filename):
    """Draw a thresholded image of ``psi`` on ``ax`` and save it as a PNG.

    The field is rescaled with ``(psi + 2.5) / 5``, binned against five evenly
    spaced levels between 0 and 1, and the transposed bin indices are shown
    with the ``'bwr'`` colormap. The figure that owns ``ax`` is written to
    ``test/<filename>.png`` and then closed.
    NOTE(review): the rescaling presumably expects ``psi`` in [-2.5, 2.5] --
    confirm against the data.

    Args:
        ax: Matplotlib axes to draw on.
        psi: 2-D field to plot.
        Lx: Unused; kept so existing callers keep working.
        Ly: Unused; kept so existing callers keep working.
        filename: Output file name without extension, relative to ``test/``.
    """
    psi = (psi + 2.5)/5
    levels = np.linspace(0, 1, 5)
    thresh_psi = np.digitize(psi.T, levels)

    # Draw on the axes that was passed in rather than on pyplot's implicit
    # "current" axes, which may belong to a different figure.
    ax.imshow(thresh_psi, cmap='bwr')
    ax.axis('off')

    out_path = f"test/{filename}.png"
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig = ax.figure
    fig.savefig(out_path, format='png', bbox_inches='tight', pad_inches=0)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [50]:
# Directory passed to the dataset constructor in a later cell. It is a
# relative path, so it is resolved against the kernel's working directory.
path = "../data/ocean"

In [55]:
# Export every time step of each sample as its own .npy file.
data = ContiguousSinteticDatasetAutoregressive(path, D = 150)

# np.save does not create missing directories.
os.makedirs('../ocean/data/raw', exist_ok=True)

for i, sample in enumerate(data):
    # sample is (x, {'y': ...}); 'y' has psi1 and psi2 interleaved along the
    # first axis, so even rows are psi1 and odd rows are psi2.
    file = sample[1]['y']

    psi1 = file[::2]
    psi2 = file[1::2]

    for j in tqdm(range(len(psi1))):
        # f-strings are required here: without the prefix every iteration
        # writes to the same literal "{i}_psi1_{j}.npy" file.
        np.save(f'../ocean/data/raw/{i}_psi1_{j}.npy', psi1[j].T)
        np.save(f'../ocean/data/raw/{i}_psi2_{j}.npy', psi2[j].T)

100%|██████████| 149/149 [00:00<00:00, 1090.27it/s]
100%|██████████| 149/149 [00:00<00:00, 1115.68it/s]
100%|██████████| 149/149 [00:00<00:00, 809.38it/s]
100%|██████████| 149/149 [00:00<00:00, 1113.96it/s]
100%|██████████| 149/149 [00:00<00:00, 1112.96it/s]
100%|██████████| 149/149 [00:00<00:00, 1111.44it/s]
100%|██████████| 149/149 [00:00<00:00, 1111.84it/s]
100%|██████████| 149/149 [00:00<00:00, 1109.59it/s]
100%|██████████| 149/149 [00:00<00:00, 1111.46it/s]
100%|██████████| 149/149 [00:00<00:00, 1116.47it/s]
100%|██████████| 149/149 [00:00<00:00, 1104.93it/s]
100%|██████████| 149/149 [00:00<00:00, 1108.87it/s]
100%|██████████| 149/149 [00:00<00:00, 1114.51it/s]
100%|██████████| 149/149 [00:00<00:00, 1083.65it/s]
100%|██████████| 149/149 [00:00<00:00, 1095.83it/s]
100%|██████████| 149/149 [00:00<00:00, 904.45it/s] 
100%|██████████| 149/149 [00:00<00:00, 1121.50it/s]
100%|██████████| 149/149 [00:00<00:00, 1115.78it/s]
100%|██████████| 149/149 [00:00<00:00, 1114.32it/s]
100%|████████