In [1]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd

In [2]:
# Model architecture

class Unet(nn.Module):
    """
    Modified U-Net producing one score map per output channel for every pixel.

    Differences from the textbook U-Net:
      * every 3x3 convolution uses ``padding='same'`` with replicate padding,
        so convolutions never shrink the feature maps;
      * a ``BatchNorm2d`` follows every ``Conv2d``;
      * three pooling levels (64 -> 128 -> 256 -> 512 channels).

    The transposed-convolution kernel sizes (3, 2, 3) make every upsampled map
    land exactly on the size of its skip connection when height and width are
    both of the form ``8*k + 5`` with ``k >= 1`` (e.g. 357 x 717). For other
    sizes (height and width must be at least 8) the decoder output is cropped
    or zero-padded to the skip size before concatenation. That alignment is a
    no-op for the ``8*k + 5`` sizes, so the output always has the input's
    spatial size.

    Parameters
    ----------
    in_channels : int, default 4
        Number of input channels.
    n_classes : int, default 3
        Number of output channels (the final 1x1 convolution's width).

    Submodule names and the layer order inside each ``nn.Sequential`` are part
    of the checkpoint format (``state_dict`` keys such as
    ``upblock2.6.weight``). Do not rename, insert or reorder layers without
    re-saving the weights.
    """

    @staticmethod
    def _double_conv(in_ch: int, out_ch: int) -> list:
        """Return the layers of two (3x3 conv -> BatchNorm -> ReLU) stages."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding='same', padding_mode='replicate'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding='same', padding_mode='replicate'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    @staticmethod
    def _match_size(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Crop or zero-pad the last two dims of ``x`` to those of ``ref``.

        The difference is split as evenly as possible between both sides.
        ``x`` is returned unchanged when the sizes already agree.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        # Negative pad amounts crop, positive ones pad with zeros.
        return nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def __init__(self, in_channels: int = 4, n_classes: int = 3):
        super().__init__()

        # No dropout: Dropout2d(0.2) was tried and left disabled. Inserting any
        # module would also shift the Sequential indices stored in checkpoints.
        # Encoder.
        self.downblock1 = nn.Sequential(*self._double_conv(in_channels, 64))
        self.downblock2 = nn.Sequential(nn.MaxPool2d(kernel_size=2), *self._double_conv(64, 128))
        self.downblock3 = nn.Sequential(nn.MaxPool2d(kernel_size=2), *self._double_conv(128, 256))

        # Bottleneck, ending with the first upsampling step.
        self.middleU = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            *self._double_conv(256, 512),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2),
        )

        # Decoder: each block consumes (skip ++ upsampled) channels.
        self.upblock1 = nn.Sequential(
            *self._double_conv(512, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
        )
        self.upblock2 = nn.Sequential(
            *self._double_conv(256, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2),
        )
        self.upblock3 = nn.Sequential(
            *self._double_conv(128, 64),
            nn.Conv2d(64, n_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map x of shape (N, in_channels, H, W) to (N, n_classes, H, W)."""
        # Encoder: keep each level's output for the skip connections.
        x1 = self.downblock1(x)
        x2 = self.downblock2(x1)
        x3 = self.downblock3(x2)

        # Bottleneck and decoder. Each upsampled map is aligned to its skip map,
        # then concatenated with the skip map first (the original channel order).
        xmiddle = self._match_size(self.middleU(x3), x3)
        xup1 = self._match_size(self.upblock1(torch.cat((x3, xmiddle), dim=1)), x2)
        xup2 = self._match_size(self.upblock2(torch.cat((x2, xup1), dim=1)), x1)
        return self.upblock3(torch.cat((x1, xup2), dim=1))


In [3]:
# Instantiate a new empty model
model = Unet()

print(model)

# Load the trained weights. map_location="cpu" lets a checkpoint saved from a
# GPU load on a CPU-only machine too; the parameters are copied into the
# (CPU-resident) model either way, and a later cell moves the model to the
# selected device.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on PyTorch versions that support it).
checkpoint_path = "UNetV1.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

print("Model Loaded")

Unet(
  (downblock1): Sequential(
    (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=replicate)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=replicate)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (downblock2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=replicate)
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=replicate)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inpl

In [4]:
class TestDataset():
    """Test split of the OSSE U/V/SSH/SST data, preprocessed for the U-Net.

    Reads ``<path>/OSSE_U_V_SLA_SST_test.nc`` and stacks four variables as
    channels, in the order (vomecrtyT, vozocrtxT, sossheig, votemper). This
    gives an array of shape (n_samples, 4, H, W), assuming each variable is
    (n_samples, H, W).

    Preprocessing (presumably mirrors the training pipeline; confirm):
      1. ``land_and_sea`` mask: 1 where channel 0 is NaN, 0 elsewhere.
      2. "Edge" pixels are found once on a reference image/channel
         (``EDGE_REF_IMAGE``, ``EDGE_REF_CHANNEL``). They are interior NaN
         pixels with at least one non-NaN pixel in their 3x3 window. In every
         image, each edge pixel is overwritten in place, all channels at once,
         with the ``np.mean`` of its 3x3 window.
      3. Remaining NaNs are replaced by 0.
      4. Each channel is min-max scaled to [0, 1] over all samples and pixels.

    NOTE(review): ``np.mean`` propagates NaN. An edge pixel therefore only gets
    a finite average when its whole 3x3 window is NaN-free in the image/channel
    being processed; otherwise it becomes (or stays) NaN and is zeroed in
    step 3. With a land mask identical to the reference one, step 2 is a no-op.
    ``np.nanmean`` may have been intended, but change it only together with
    the training pipeline.
    NOTE(review): step 4 uses this test set's own min/max, not training
    statistics. A constant channel would divide by zero.

    Items are ``(X_test[idx], land_and_sea[idx])``.
    """

    # Reference image / channel whose NaN pattern defines the "edge" pixels
    # (requires at least EDGE_REF_IMAGE + 1 samples).
    EDGE_REF_IMAGE = 32
    EDGE_REF_CHANNEL = 3

    def __init__(self, path):
        # Context manager so the netCDF file handle is released; ``.values``
        # materialises each variable as a NumPy array before the file closes.
        with xr.open_dataset(path + '/OSSE_U_V_SLA_SST_test.nc') as X_test:
            X_verti = X_test.vomecrtyT.values
            X_hori = X_test.vozocrtxT.values
            X_SSH = X_test.sossheig.values
            X_SST = X_test.votemper.values

        # Stack as channels: (4, n_samples, H, W) -> (n_samples, 4, H, W).
        X = np.array([X_verti, X_hori, X_SSH, X_SST])
        X = X.transpose((1, 0, 2, 3))

        # 1 where channel 0 is NaN, 0 elsewhere. This is computed before any
        # NaN is filled, and is the vectorised form of a per-pixel loop.
        land_and_sea = np.isnan(X[:, 0]).astype(np.float64)

        # Record the edge indices: interior NaN pixels of the reference
        # image/channel with at least one non-NaN pixel in their 3x3 window.
        # Row-major order is kept because the smoothing below is in place and
        # therefore order-sensitive.
        ref_nan = np.isnan(X[self.EDGE_REF_IMAGE, self.EDGE_REF_CHANNEL])
        ref_valid = ~ref_nan
        height, width = ref_nan.shape
        inner_h, inner_w = max(height - 2, 0), max(width - 2, 0)
        has_valid_neighbour = np.zeros((inner_h, inner_w), dtype=bool)
        for di in range(3):
            for dj in range(3):
                has_valid_neighbour |= ref_valid[di:di + inner_h, dj:dj + inner_w]
        edge_mask = ref_nan[1:-1, 1:-1] & has_valid_neighbour
        edges_index = [(int(i) + 1, int(j) + 1) for i, j in np.argwhere(edge_mask)]

        # Smooth the edges at the selected indices (see NOTE in the class
        # docstring about NaN propagation).
        for img_index in tqdm(range(X.shape[0])):
            for i, j in edges_index:
                X[img_index, :, i, j] = np.mean(X[img_index, :, i-1:i+2, j-1:j+2], axis=(1, 2))

        X = np.nan_to_num(X, nan=0)

        # Normalisation: per-channel min-max scaling over all samples and pixels.
        X_min = np.min(X, axis=(0, 2, 3), keepdims=True)
        X_max = np.max(X, axis=(0, 2, 3), keepdims=True)
        X = (X - X_min) / (X_max - X_min)

        self.X_test = torch.tensor(X)
        self.land_and_sea = torch.Tensor(land_and_sea)

    def __len__(self):
        """Number of samples (first dimension of ``X_test``)."""
        return len(self.X_test)

    def __getitem__(self, idx):
        """Return ``(input tensor, land/sea mask)`` for sample ``idx``."""
        return self.X_test[idx], self.land_and_sea[idx]

In [5]:
# Build the preprocessed test set from the local data directory.
path = '../Data'  # local
test_dataset = TestDataset(path=path)

100%|███████████████████████████████████████████| 72/72 [00:01<00:00, 48.87it/s]


In [6]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU,
# and move the model there.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
model = model.to(device)

# Shown as the cell output: (n_samples, 4, H, W).
test_dataset.X_test.shape

cuda:0


torch.Size([72, 4, 357, 717])

In [8]:
# Predict a class for every pixel of every test image, in mini-batches, and
# collect the flattened predictions (image-major, then row, then column).
BATCH_SIZE = 9
LAND_LABEL = 999  # value written wherever land_and_sea == 1

batch_preds = []
model.eval()
with torch.no_grad():
    X_full, land_and_sea_full = test_dataset.X_test, test_dataset.land_and_sea
    n_images = len(X_full)

    # Walk the whole set rather than a hard-coded 72 // 9 batches, so a set
    # whose size is not 72 (or not a multiple of BATCH_SIZE) is fully covered.
    for start in tqdm(range(0, n_images, BATCH_SIZE)):
        stop = start + BATCH_SIZE
        X = X_full[start:stop].to(device)
        land_and_sea = land_and_sea_full[start:stop]

        # Class index per pixel: argmax over the channel dimension.
        pred = model(X).argmax(dim=1).cpu()
        pred[land_and_sea == 1] = LAND_LABEL

        batch_preds.append(pd.Series(pred.numpy().ravel()))

# ignore_index=True gives one clean 0..N-1 index instead of repeating each
# batch's own 0..k-1 index.
df = pd.concat(batch_preds, axis=0, ignore_index=True)

100%|█████████████████████████████████████████████| 8/8 [00:01<00:00,  4.03it/s]


In [None]:
# Build the submission Ids "<image>_<row>_<col>" in the same order as the
# flattened predictions (image-major, then row, then column). Sizes come from
# the dataset instead of the hard-coded 72 / 357 / 717.
n_images, height, width = test_dataset.land_and_sea.shape
list_of_id = [
    f"{img}_{row}_{col}"
    for img in tqdm(range(n_images))
    for row in range(height)
    for col in range(width)
]

In [12]:
# Assemble the submission table. `df` holds the concatenated per-pixel
# predictions. Its values are taken positionally so the frame gets a clean
# 0..N-1 index regardless of how the prediction Series was indexed.
assert len(list_of_id) == len(df), "Id list and predictions have different lengths"
df = pd.DataFrame({"Id": list_of_id, "Predicted": df.to_numpy()})

In [14]:
# Write the submission table to CSV (the row index is not written).
file_name = "result10.csv"
df.to_csv(path_or_buf=file_name, index=False)