In [2]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [3]:
# Select the compute device once: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class Adapted_loss(nn.Module):
    """Class-weighted cross-entropy that skips elements carrying an ignore label.

    Target elements equal to ``ignore_label`` are removed (together with the
    matching prediction vectors) before the weighted ``nn.CrossEntropyLoss``
    is evaluated, so they contribute neither to the loss nor to its gradient.

    Args:
        class_weights: per-class weights handed to ``nn.CrossEntropyLoss``.
            Defaults to ``(1/13, 6/13, 6/13)``.
        ignore_label: target value marking elements to skip. Defaults to ``3``
            (the value ``TrainDataset`` in this file substitutes for NaN labels).
    """

    def __init__(self, class_weights=(1/13, 6/13, 6/13), ignore_label=3):
        super().__init__()

        self.ignore_label = ignore_label
        # The weight is registered as a buffer of the inner loss module, so
        # ``.to(device)`` on this module moves it along with everything else.
        self.loss = nn.CrossEntropyLoss(
            weight=torch.as_tensor(class_weights, dtype=torch.float32)
        )

    def forward(self, prediction, target):
        """Return the weighted cross-entropy over the non-ignored elements.

        Args:
            prediction: raw class scores with the class axis at dim 1,
                e.g. ``(N, C, H, W)`` or ``(N, C)``.
            target: integer class indices shaped like ``prediction`` without
                its class axis; entries equal to ``ignore_label`` are skipped.

        Returns:
            A scalar tensor. When every target element is ignored, a zero that
            is still attached to ``prediction``'s graph is returned instead of
            the NaN a mean over an empty selection would produce.
        """
        mask = target != self.ignore_label
        if not mask.any():
            # Nothing to score: keep the graph alive so ``backward()`` works
            # and yields zero gradients rather than propagating NaN.
            return prediction.sum() * 0.0

        y = target[mask]
        # Move the class axis last so boolean indexing with the target-shaped
        # mask yields one (C,) score vector per kept element: shape (K, C).
        ypred = prediction.movedim(1, -1)[mask]
        return self.loss(ypred, y)

In [4]:
class TrainDataset(Dataset):
    """Map-style dataset built from two NetCDF files, with one augmented copy.

    On construction the four input fields (``vomecrtyT``, ``vozocrtxT``,
    ``sossheig``, ``votemper``) are stacked into a ``(sample, channel, H, W)``
    array, NaN pixels bordering valid data are filled with the average of their
    valid 3x3 neighbours, remaining NaNs become 0, and every channel is min-max
    scaled. NaN labels are encoded as class ``3``. The inputs and labels are
    then passed together through the module-level ``image_transforms`` callable
    and the transformed copy is appended to the originals.

    NOTE(review): ``image_transforms`` is read from the enclosing module's
    globals and is not defined in the visible source; it must exist before an
    instance is created.
    """

    # Sample / channel whose NaN pattern is used to locate boundary pixels.
    _MASK_SAMPLE = 32
    _MASK_CHANNEL = 3

    def __init__(self, path):
        """Load, clean, normalise and augment the training data.

        Args:
            path: currently unused; the two NetCDF file names are hard-coded
                and resolved against the working directory.
                NOTE(review): confirm with callers whether this was meant to
                be the data directory.
        """
        # Context managers close the NetCDF handles once the arrays are loaded.
        with xr.open_dataset('eddies_train.nc') as eddies_train:
            y = eddies_train.eddies.values

        with xr.open_dataset('OSSE_U_V_SLA_SST_train.nc') as X_train:
            X_verti = X_train.vomecrtyT.values
            X_hori = X_train.vozocrtxT.values
            X_SSH = X_train.sossheig.values
            X_SST = X_train.votemper.values

        ## Transformation
        # Stack as (channel, sample, H, W), then put the sample axis first.
        X = np.array([X_verti, X_hori, X_SSH, X_SST])
        X = X.transpose((1, 0, 2, 3))

        # NaN labels become the extra class 3.
        y = np.nan_to_num(y, nan=3)

        # Record the indices of the boundary pixels, then fill them in place.
        edges_index = self._find_edge_pixels(X)
        self._fill_edge_pixels(X, edges_index)

        X = np.nan_to_num(X, nan=0)

        ## Normalisation
        X = self._min_max_normalise(X)

        ## Augmentation
        X = torch.tensor(X)
        y = torch.tensor(y, dtype=torch.long).reshape(y.shape[0], 1, y.shape[1], y.shape[2])
        # Inputs and labels are concatenated so the transform is applied to
        # both identically; the labels ride along as the last channel.
        data = torch.cat((X, y), dim=1)

        data_aug = image_transforms(data)

        X_aug = data_aug[:, :-1, :, :]
        y_aug = data_aug[:, -1, :, :]
        # Concatenating with the float inputs promotes the labels to float;
        # bring them back to integer class indices so ``y_train`` stays long.
        # NOTE(review): rounding assumes the transform does not blend labels
        # (e.g. flips / 90-degree rotations) -- confirm.
        if y_aug.is_floating_point():
            y_aug = y_aug.round()
        y_aug = y_aug.to(torch.long)

        self.X_train = torch.cat((X, X_aug), dim=0)
        self.y_train = torch.cat((y.reshape(y.shape[0], y.shape[2], y.shape[3]), y_aug), dim=0)

        self.image_transforms = image_transforms

    @classmethod
    def _find_edge_pixels(cls, X):
        """Return interior (i, j) positions that are NaN but have a non-NaN 3x3 neighbour."""
        # Clamp so datasets with fewer samples than the reference index still work.
        sample = min(cls._MASK_SAMPLE, X.shape[0] - 1)
        missing = np.isnan(X[sample, cls._MASK_CHANNEL])
        rows, cols = missing.shape
        return [
            (i, j)
            for i in range(1, rows - 1)
            for j in range(1, cols - 1)
            if missing[i, j] and not missing[i - 1:i + 2, j - 1:j + 2].all()
        ]

    @staticmethod
    def _fill_edge_pixels(X, edges_index):
        """Overwrite each listed pixel with the mean of the valid values in its 3x3 window.

        NaNs are excluded from the average (a plain mean would always be NaN
        because the centre pixel itself is NaN). All fill values are computed
        from the untouched data first and written afterwards, so the result
        does not depend on the order of ``edges_index``. Windows without any
        valid value stay NaN.
        """
        if not edges_index:
            return X

        fills = np.empty((len(edges_index),) + X.shape[:2], dtype=X.dtype)
        for k, (i, j) in enumerate(tqdm(edges_index)):
            window = X[:, :, i - 1:i + 2, j - 1:j + 2]
            valid = ~np.isnan(window)
            count = valid.sum(axis=(2, 3))
            total = np.where(valid, window, 0).sum(axis=(2, 3))
            fills[k] = np.where(count > 0, total / np.maximum(count, 1), np.nan)

        for k, (i, j) in enumerate(edges_index):
            X[:, :, i, j] = fills[k]
        return X

    @staticmethod
    def _min_max_normalise(X):
        """Scale every channel (axis 1) to [0, 1] using its global min and max."""
        lowest = np.min(X, axis=(0, 2, 3), keepdims=True)
        highest = np.max(X, axis=(0, 2, 3), keepdims=True)
        span = highest - lowest
        # A constant channel would divide by zero; map it to all zeros instead.
        span[span == 0] = 1
        return (X - lowest) / span

    def __len__(self):
        """Number of samples (originals plus augmented copies)."""
        return len(self.X_train)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at position ``idx``."""
        return self.X_train[idx], self.y_train[idx]