In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import h5py
import os
import pandas as pd

# Report GPU availability up front so it is obvious which device training will use.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# If available, print the name of the first GPU and how many GPUs are visible
if cuda_available:
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    print(f"Device count: {torch.cuda.device_count()}")

# Root of the project data. Defaults to the original cluster path, but can be
# overridden with the SICPRED_DATA_DIRECTORY environment variable so the notebook
# runs on other machines without editing code.
DATA_DIRECTORY = os.environ.get('SICPRED_DATA_DIRECTORY', '/oak/stanford/groups/earlew/yuchen')


CUDA available: True
Device name: Tesla V100-PCIE-32GB
Device count: 1


In [2]:
class SeaIceDataset(Dataset):
    """Map-style dataset of (input, target) sample pairs read from two HDF5 files.

    Inputs come from ``{data_directory}/inputs_{configuration}.h5`` (dataset
    ``inputs_{configuration}``, 4-D: samples x channels x y x x). Targets come from
    ``{data_directory}/targets.h5`` (dataset ``targets_sea_ice_only``). Only samples
    whose label in ``split_array`` equals ``split_type`` are exposed. Each returned
    sample is zero-padded at the bottom/right of its last two axes up to
    ``target_shape`` and converted to a float32 tensor.

    HDF5 handles are reopened lazily in any process other than the one that opened
    them (forked DataLoader workers) and are dropped when the dataset is pickled
    (spawned workers), because h5py file objects must not be shared across processes.

    Args:
        data_directory: Directory containing the two HDF5 files.
        configuration: Suffix selecting the inputs file and dataset name.
        split_array: Per-sample split labels aligned with the first axis of the
            HDF5 datasets (e.g. 'train' / 'val' / 'test').
        split_type: Which split label this dataset exposes.
        target_shape: (height, width) that every sample is padded to.

    Raises:
        ValueError: if ``split_array`` does not have one label per stored sample,
            or if stored inputs/targets are spatially larger than ``target_shape``.
    """

    def __init__(self, data_directory, configuration, split_array, split_type='train', target_shape=(336, 320)):
        self.data_directory = data_directory
        self.configuration = configuration
        self.split_array = split_array
        self.split_type = split_type
        self.target_shape = target_shape

        self._inputs_path = f"{data_directory}/inputs_{configuration}.h5"
        self._targets_path = f"{data_directory}/targets.h5"

        # Initialise handles before opening so close()/__del__ are safe even if opening fails.
        self.inputs_file = None
        self.targets_file = None
        self.inputs = None
        self.targets = None
        self._open_files()

        self.n_samples, self.n_channels, self.n_y, self.n_x = self.inputs.shape

        # A length mismatch would silently pair samples with the wrong split labels.
        if len(self.split_array) != self.n_samples:
            self.close()
            raise ValueError(
                f"split_array has {len(self.split_array)} labels but {self._inputs_path} "
                f"holds {self.n_samples} samples"
            )

        # Fail early with a clear message instead of np.pad rejecting a negative pad width later.
        for name, shape in (("inputs", self.inputs.shape), ("targets", self.targets.shape)):
            if shape[-2] > self.target_shape[0] or shape[-1] > self.target_shape[1]:
                self.close()
                raise ValueError(
                    f"{name} spatial size {tuple(shape[-2:])} exceeds target_shape {tuple(self.target_shape)}"
                )

        # Get indices for the specified split type (asarray so plain lists compare elementwise)
        self.indices = np.where(np.asarray(self.split_array) == split_type)[0]

    def _open_files(self):
        """(Re)open both HDF5 files in the current process and bind their datasets."""
        self.inputs_file = h5py.File(self._inputs_path, 'r')
        self.targets_file = h5py.File(self._targets_path, 'r')
        self.inputs = self.inputs_file[f"inputs_{self.configuration}"]
        self.targets = self.targets_file['targets_sea_ice_only']
        self._owner_pid = os.getpid()

    def _pad_to_target(self, array):
        """Zero-pad the last two axes of ``array`` (bottom/right only) up to ``self.target_shape``."""
        pad_y = self.target_shape[0] - array.shape[-2]
        pad_x = self.target_shape[1] - array.shape[-1]
        pad_width = [(0, 0)] * (array.ndim - 2) + [(0, pad_y), (0, pad_x)]
        return np.pad(array, pad_width, mode='constant', constant_values=0)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Handles inherited through fork() (or absent after unpickling) are replaced
        # with ones opened by this process before the first read.
        if self.inputs is None or getattr(self, '_owner_pid', None) != os.getpid():
            self._open_files()

        actual_idx = self.indices[idx]
        # Pad each array by its own spatial size so inputs and targets need not match.
        input_data = self._pad_to_target(self.inputs[actual_idx])
        target_data = self._pad_to_target(self.targets[actual_idx])

        input_tensor = torch.tensor(input_data, dtype=torch.float32)
        target_tensor = torch.tensor(target_data, dtype=torch.float32)
        return input_tensor, target_tensor

    def __getstate__(self):
        """Drop h5py handles when pickled; they cannot be pickled and are reopened lazily."""
        state = self.__dict__.copy()
        for key in ('inputs_file', 'targets_file', 'inputs', 'targets'):
            state[key] = None
        state.pop('_owner_pid', None)
        return state

    def close(self):
        """Close both HDF5 files if open. Safe to call more than once."""
        for attr in ('inputs_file', 'targets_file'):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)
        self.inputs = None
        self.targets = None

    def __del__(self):
        self.close()

class UNet(nn.Module):
    """Three-level U-Net with a sigmoid output head.

    Encoder: three conv blocks, each followed by 2x2 max-pooling, then a bottleneck
    conv block. Decoder: nearest-neighbour upsampling, concatenation with the
    matching encoder features, and a conv block per level. A 1x1 conv followed by
    a sigmoid produces ``out_channels`` maps with values in [0, 1].

    Spatial sizes need not be divisible by 8. When pooling floors an odd size, the
    decoder interpolates back to the exact size of the skip connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n_filters_factor: Multiplier applied to the base widths 64/128/256/512.
        filter_size: Conv kernel size; padding is ``filter_size // 2``, which gives
            "same" output size for odd kernels.
    """

    def __init__(self, in_channels, out_channels, n_filters_factor=1, filter_size=3):
        super(UNet, self).__init__()
        # Channel widths per level, scaled once instead of recomputing inline.
        c1, c2, c3, c4 = (int(base * n_filters_factor) for base in (64, 128, 256, 512))

        self.encoder1 = self.conv_block(in_channels, c1, filter_size)
        self.encoder2 = self.conv_block(c1, c2, filter_size)
        self.encoder3 = self.conv_block(c2, c3, filter_size)
        self.bottleneck = self.conv_block(c3, c4, filter_size)

        # Decoder inputs are upsampled features concatenated with the skip connection.
        self.decoder1 = self.conv_block(c4 + c3, c3, filter_size)
        self.decoder2 = self.conv_block(c3 + c2, c2, filter_size)
        self.decoder3 = self.conv_block(c2 + c1, c1, filter_size)

        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def conv_block(self, in_channels, out_channels, filter_size):
        """Two padded convs, each followed by ReLU, then BatchNorm (same ordering as the Keras model)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=filter_size, padding=filter_size//2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=filter_size, padding=filter_size//2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def _upsample_like(self, x, skip):
        """Upsample ``x`` to the spatial size of ``skip``.

        Uses the 2x nearest upsampler when that already matches. Otherwise it
        interpolates directly to the skip size, so inputs whose dimensions were
        floored by max-pooling still line up for concatenation.
        """
        target_h, target_w = skip.shape[-2], skip.shape[-1]
        if x.shape[-2] * 2 == target_h and x.shape[-1] * 2 == target_w:
            return self.upsample(x)
        return F.interpolate(x, size=(target_h, target_w), mode='nearest')

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        bottleneck = self.bottleneck(self.pool(enc3))

        dec1 = self.decoder1(torch.cat((enc3, self._upsample_like(bottleneck, enc3)), dim=1))
        dec2 = self.decoder2(torch.cat((enc2, self._upsample_like(dec1, enc2)), dim=1))
        dec3 = self.decoder3(torch.cat((enc1, self._upsample_like(dec2, enc1)), dim=1))

        return torch.sigmoid(self.final_conv(dec3))


In [3]:
# Calendar boundaries (month starts) of the three chronological splits.
TRAIN_MONTHS = pd.date_range(start='1981-01-01', end='2014-12-01', freq='MS')
VAL_MONTHS = pd.date_range(start='2015-01-01', end='2018-12-01', freq='MS')
TEST_MONTHS = pd.date_range(start='2019-01-01', end='2024-06-01', freq='MS')

# Start months of the data pairs. The sequence has a hole around 1987-1988 (missing
# data): the first piece stops 7 months before 1987-12 and the second resumes 13 months
# after 1988-01 (presumably the input/lead window of each pair — confirm).
first_range = pd.date_range('1981-01', pd.Timestamp('1987-12') - pd.DateOffset(months=6+1), freq='MS')
second_range = pd.date_range(pd.Timestamp('1988-01') + pd.DateOffset(months=12+1), '2024-01', freq='MS')
start_prediction_months = first_range.append(second_range)

# Label each start month with its split using vectorised membership masks. Labels are
# written in train -> val -> test order, so a later split wins on any overlap; months in
# no split keep the default None.
split_array = np.full(len(start_prediction_months), None, dtype=object)
for label, months in (("train", TRAIN_MONTHS), ("val", VAL_MONTHS), ("test", TEST_MONTHS)):
    split_array[start_prediction_months.isin(months)] = label

def print_split_stats(split_array):
    """Print how many samples fall in each of the train/val/test splits.

    Args:
        split_array: Sequence (list or array) of split labels. Entries other than
            'train', 'val' or 'test' (e.g. None) count toward the total but not
            toward any split.

    Prints one line per split, ``"<split> samples: <count> (<fraction>)"``. The
    fraction is relative to all entries, rounded to 2 decimals, and is 0.0 for an
    empty input.
    """
    # Object dtype gives elementwise comparison for lists and arrays that mix str and None.
    labels = np.asarray(split_array, dtype=object)
    total = len(labels)
    for split in ('train', 'val', 'test'):
        count = int(np.count_nonzero(labels == split))
        fraction = round(count / total, 2) if total else 0.0
        print(f"{split} samples: {count} ({fraction})")

# Seed torch's global RNG before building the model so weight initialisation, and
# later draws from the same generator (e.g. DataLoader shuffling), are reproducible.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 23 input / 6 output channels match the batch shapes printed from train_loader.
model = UNet(in_channels=23, out_channels=6, n_filters_factor=1, filter_size=3).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

print_split_stats(split_array)

train samples: 388 (0.78) 
     val samples: 48 (0.1) 
     test samples: 61 (0.12)


In [4]:
data_directory = os.path.join(DATA_DIRECTORY, 'sicpred/data_pairs_npy')
configuration = 'simple'
batch_size = 32

# Every split is padded to the same spatial size so samples stack into batches.
PADDED_SHAPE = (336, 320)

# One dataset per split, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    SeaIceDataset(data_directory, configuration, split_array, split_type=split, target_shape=PADDED_SHAPE)
    for split in ('train', 'val', 'test')
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    for eval_dataset in (val_dataset, test_dataset)
)


In [36]:
# Smoke test: one full pass over train_loader, printing each batch's input and
# target shapes to confirm padding and batching give consistent tensors.
for batch_inputs, batch_targets in train_loader:
    print(batch_inputs.shape, batch_targets.shape)


torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([32, 23, 336, 320]) torch.Size([32, 6, 336, 320])
torch.Size([4, 23, 336, 320]) torch.Size([4, 6, 336, 320])


In [5]:
num_epochs = 20
model_path = f'{DATA_DIRECTORY}/sicpred/prelim_test_unet_model.pth'
# Best-validation weights are written as training progresses, so an interrupted run
# (e.g. a killed kernel) still leaves a usable checkpoint.
best_model_path = f'{DATA_DIRECTORY}/sicpred/prelim_test_unet_model_best.pth'


def evaluate_loss(loader):
    """Return the per-sample mean of ``criterion`` for ``model`` over ``loader``.

    Runs in eval mode without gradients. Returns NaN for an empty dataset instead of
    dividing by zero.
    """
    n_samples = len(loader.dataset)
    if n_samples == 0:
        return float('nan')
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total += criterion(outputs, targets).item() * inputs.size(0)
    return total / n_samples


best_val_loss = float('inf')

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # MSELoss averages over the batch; re-weight by batch size so the epoch mean is
        # per-sample even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    val_loss = evaluate_loss(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

# Testing loop (optional) — evaluates the final-epoch weights
test_loss = evaluate_loss(test_loader)
print(f"Test Loss: {test_loss:.4f}")

# Save the final trained model
torch.save(model.state_dict(), model_path)


Epoch [1/20], Loss: 0.7887
Validation Loss: 1.0259
Epoch [2/20], Loss: 0.6099
Validation Loss: 0.8414
Epoch [3/20], Loss: 0.5776
Validation Loss: 0.6273
Epoch [4/20], Loss: 0.5641
Validation Loss: 0.5716
Epoch [5/20], Loss: 0.5561
Validation Loss: 0.5571
Epoch [6/20], Loss: 0.5505
Validation Loss: 0.5514
Epoch [7/20], Loss: 0.5457
Validation Loss: 0.5478
Epoch [8/20], Loss: 0.5405
Validation Loss: 0.5434
Epoch [9/20], Loss: 0.5366
Validation Loss: 0.5426
Epoch [10/20], Loss: 0.5334
Validation Loss: 0.5397
Epoch [11/20], Loss: 0.5304
Validation Loss: 0.5373
Epoch [12/20], Loss: 0.5287
Validation Loss: 0.6008
Epoch [13/20], Loss: 0.5299
Validation Loss: 0.5377
Epoch [14/20], Loss: 0.5234
Validation Loss: 0.5308
Epoch [15/20], Loss: 0.5211
Validation Loss: 0.5272
Epoch [16/20], Loss: 0.5192
Validation Loss: 0.5272
Epoch [17/20], Loss: 0.5175
Validation Loss: 0.5261
Epoch [18/20], Loss: 0.5161
Validation Loss: 0.5255
Epoch [19/20], Loss: 0.5147
Validation Loss: 0.5250
Epoch [20/20], Loss: 

In [2]:
# Load the sea-ice-only input pairs stored as a single .npy array.
inputs_simple_path = os.path.join(DATA_DIRECTORY, 'sicpred', 'data_pairs_npy', 'inputs_sea_ice_only.npy')
inputs_simple = np.load(inputs_simple_path)


In [3]:
# Saved output shows (497, 332, 316, 14): presumably (samples, y, x, channels),
# i.e. channels-last, unlike the channels-first HDF5 inputs — TODO confirm.
inputs_simple.shape

(497, 332, 316, 14)

In [None]:
def unet(input_shape, loss, learning_rate=1e-4, n_filters_factor=1, filter_size=3, out_channels=1):
    """Build and compile a Keras U-Net with three pooling levels and a sigmoid head.

    Each encoder level has two ReLU convs, BatchNorm, and 2x2 max-pooling. The
    bottleneck has two convs plus BatchNorm. Each decoder level does 2x nearest
    upsampling, a 2x2 conv, concatenation with the encoder's BatchNorm output, and
    two convs plus BatchNorm. The last level instead has three convs and no
    BatchNorm. Because there are three 2x poolings, the spatial dims of
    ``input_shape`` must be divisible by 8 for the concatenations to line up.

    Args:
        input_shape: Shape of one sample without the batch axis, channels-last
            (BatchNormalization uses axis=-1, concatenation uses axis=3).
        loss: Loss passed to ``model.compile``.
        learning_rate: Adam learning rate.
        n_filters_factor: Multiplier applied to the base widths 64/128/256/512.
        filter_size: Kernel size of the main 'same'-padded convolutions.
        out_channels: Number of sigmoid output maps (default 1, the original behaviour).

    Returns:
        The compiled Keras ``Model``.
    """
    f1, f2, f3, f4 = (int(base * n_filters_factor) for base in (64, 128, 256, 512))

    def conv(filters, x, kernel_size=filter_size):
        # The ReLU / 'same' / he_normal convolution used throughout the network.
        return Conv2D(filters, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal')(x)

    def up(filters, x):
        # 2x nearest-neighbour upsampling followed by a 2x2 conv.
        return conv(filters, UpSampling2D(size=(2, 2), interpolation='nearest')(x), kernel_size=2)

    inputs = Input(shape=input_shape)

    # Encoder
    bn1 = BatchNormalization(axis=-1)(conv(f1, conv(f1, inputs)))
    pool1 = MaxPooling2D(pool_size=(2, 2))(bn1)

    bn2 = BatchNormalization(axis=-1)(conv(f2, conv(f2, pool1)))
    pool2 = MaxPooling2D(pool_size=(2, 2))(bn2)

    bn3 = BatchNormalization(axis=-1)(conv(f3, conv(f3, pool2)))
    pool3 = MaxPooling2D(pool_size=(2, 2))(bn3)

    # Bottleneck
    bn4 = BatchNormalization(axis=-1)(conv(f4, conv(f4, pool3)))

    # Decoder
    merge5 = concatenate([bn3, up(f3, bn4)], axis=3)
    bn5 = BatchNormalization(axis=-1)(conv(f3, conv(f3, merge5)))

    merge6 = concatenate([bn2, up(f2, bn5)], axis=3)
    bn6 = BatchNormalization(axis=-1)(conv(f2, conv(f2, merge6)))

    merge7 = concatenate([bn1, up(f1, bn6)], axis=3)
    conv7 = conv(f1, conv(f1, conv(f1, merge7)))

    final_layer = Conv2D(out_channels, (1, 1), activation='sigmoid', kernel_initializer='he_normal')(conv7)

    model = Model(inputs, final_layer)
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss)
    return model


In [19]:
# NOTE(review): `inputs_cropped` is not defined in any visible cell. This cell's
# execution count (19) is out of sequence with the cells before it, so the array
# presumably came from a since-deleted cropping cell. Restart & Run All will fail
# here until that cell is restored.
cropped_input_shape = np.shape(inputs_cropped)[1:]
unet(cropped_input_shape, MeanSquaredError)

<Functional name=functional_3, built=True>