In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import xrft
import dask
import time
import torch.optim.lr_scheduler as lr_scheduler
import random
from torchvision.models import vgg16
from sklearn.metrics import r2_score, mean_squared_error
import zarr

In [2]:
og_data_x = np.squeeze(xr.open_zarr('gs://leap-persistent/funky-user/ds1_downscale.zarr').to_array(), axis=0)

In [3]:
og_data_y = np.squeeze(xr.open_zarr('gs://leap-persistent/funky-user/ds1_norm.zarr').to_array(), axis=0)

In [4]:
x_train_0 = og_data_x.isel(dataset1_filt=slice(0, 34))
x_val_0 = og_data_x.isel(dataset1_filt=slice(34, 67))
x_test_0 = og_data_x.isel(dataset1_filt=slice(67, 100))

In [35]:
x_train_0_reshape = x_train_0.stack(new_dim=('dataset1_filt', 'time')).expand_dims(dim={'channel': 1})
x_val_0_reshape = x_val_0.stack(new_dim=('dataset1_filt', 'time')).expand_dims(dim={'channel': 1})
x_test_0_reshape = x_test_0.stack(new_dim=('dataset1_filt', 'time')).expand_dims(dim={'channel': 1})

x_train_0_reshape = x_train_0_reshape.transpose('new_dim','channel','j','i')
x_val_0_reshape = x_val_0_reshape.transpose('new_dim','channel','j','i')
x_test_0_reshape = x_test_0_reshape.transpose('new_dim','channel','j','i')

In [5]:
y_train_0 = og_data_y.isel(datasets1=slice(0, 34))
y_val_0 = og_data_y.isel(datasets1=slice(34, 67))
y_test_0 = og_data_y.isel(datasets1=slice(67, 100))

In [36]:
y_train_0_reshape = y_train_0.stack(new_dim=('datasets1', 'time')).expand_dims(dim={'channel': 1})
y_val_0_reshape = y_val_0.stack(new_dim=('datasets1', 'time')).expand_dims(dim={'channel': 1})
y_test_0_reshape = y_test_0.stack(new_dim=('datasets1', 'time')).expand_dims(dim={'channel': 1})

y_train_0_reshape = y_train_0_reshape.transpose('new_dim','channel','j','i')
y_val_0_reshape = y_val_0_reshape.transpose('new_dim','channel','j','i')
y_test_0_reshape = y_test_0_reshape.transpose('new_dim','channel','j','i')

In [38]:
x_train = torch.tensor(x_train_0_reshape.values, dtype=torch.float32)
x_val = torch.tensor(x_val_0_reshape.values, dtype=torch.float32)
x_test = torch.tensor(x_test_0_reshape.values, dtype=torch.float32)

In [39]:
y_train = torch.tensor(y_train_0_reshape.values, dtype=torch.float32)
y_val = torch.tensor(y_val_0_reshape.values, dtype=torch.float32)
y_test = torch.tensor(y_test_0_reshape.values, dtype=torch.float32)

In [41]:
torch.save(y_train, 'y_train.pt')
torch.save(y_val, 'y_val.pt')
torch.save(y_test, 'y_test.pt')
torch.save(x_train, 'x_train.pt')
torch.save(x_val, 'x_val.pt')
torch.save(x_test, 'x_test.pt')

In [None]:
y_train = torch.load('y_train.pt')
y_

In [42]:
train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)
test_dataset = TensorDataset(x_test, y_test)
# Create DataLoaders
batch_size = 34
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [43]:
#Loss function with two terms: usual MSE and gardient loss

class CombinedLoss(nn.Module):
    def __init__(self, alpha=100):
        super(CombinedLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.alpha = alpha

    def compute_gradient(self, img):
        sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32, device=img.device).unsqueeze(0).unsqueeze(0)
        sobel_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32, device=img.device).unsqueeze(0).unsqueeze(0)
        
        grad_x = F.conv2d(img, sobel_x, padding=1, groups=img.shape[1])
        grad_y = F.conv2d(img, sobel_y, padding=1, groups=img.shape[1])
   
        return grad_x, grad_y
  
    def forward(self, output, target):
        mse_loss = self.mse_loss(output, target)
    
        output_grad_x, output_grad_y = self.compute_gradient(output)
        target_grad_x, target_grad_y = self.compute_gradient(target)
        grad_loss_x = self.mse_loss(output_grad_x, target_grad_x)
        grad_loss_y = self.mse_loss(output_grad_y, target_grad_y)
        grad_loss = grad_loss_x + grad_loss_y
  
    # Combine losses
        combined_loss = mse_loss + self.alpha * grad_loss
        return combined_loss

In [44]:
def print_model_summary(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total Parameters: {total_params}')
    print(f'Trainable Parameters: {trainable_params}')

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25):
    print_model_summary(model)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    early_stopping_counter = 0
    patience = 6

    for epoch in range(num_epochs):
        start_time = time.time()

        model.train()
        running_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            # Forward pass
            outputs = model(batch_x)
            
            loss = criterion(outputs, batch_y)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * batch_x.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

                val_running_loss += loss.item() * batch_x.size(0)

        val_loss = val_running_loss / len(val_loader.dataset)
        val_losses.append(val_loss)

        scheduler.step(val_loss)  # Adjust learning rate based on the validation loss

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print('Early stopping triggered')
                break

        # Calculate epoch duration
        end_time = time.time()
        epoch_duration = end_time - start_time

        # Check peak memory usage
        peak_memory = torch.cuda.max_memory_allocated(device=device) / (1024 ** 2)

        # Print epoch summary
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Epoch Time: {epoch_duration:.2f}s, Peak Memory Usage: {peak_memory:.2f}MB')

    print('Training complete')

    # Plotting
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.legend()
    plt.grid(True)
    plt.show()

In [45]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Encoder, strip out for now
        self.enc_conv1 = self.conv_block(in_channels=1, out_channels=16)
        self.enc_conv1_1 = self.conv_block(in_channels=16, out_channels=16)
        self.enc_conv2 = self.conv_block(in_channels=16, out_channels=32)
        self.enc_conv2_2 = self.conv_block(in_channels=32, out_channels=32)
        self.enc_conv3 = self.conv_block(in_channels=32, out_channels=64)
        self.enc_conv3_3 = self.conv_block(in_channels=64, out_channels=64)
        self.enc_conv4 = self.conv_block(in_channels=64, out_channels=128)
        self.enc_conv4_4 = self.conv_block(in_channels=128, out_channels=128)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck_conv0 = self.conv_block(in_channels=128, out_channels=256)
        self.bottleneck_conv_0 = self.conv_block(in_channels=256, out_channels=256)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
        self.dec_conv4 = self.conv_block(in_channels=256, out_channels=256)
        self.dec_conv4_4 = self.conv_block(in_channels=256, out_channels=128)
        self.upconv3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)
        self.dec_conv3 = self.conv_block(in_channels=128, out_channels=128)
        self.dec_conv3_3 = self.conv_block(in_channels=128, out_channels=64)
        self.upconv2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.dec_conv2 = self.conv_block(in_channels=64, out_channels=64)
        self.dec_conv2_2 = self.conv_block(in_channels=64, out_channels=32)
        self.upconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2)
        self.dec_conv1 = self.conv_block(in_channels=32, out_channels=32)
        self.dec_conv1_1 = self.conv_block(in_channels=32, out_channels=16)

        self.final_conv = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        return block

    def forward(self, x):
        # Encoder
        x1 = self.enc_conv1_1(self.enc_conv1(x))
        x2 = self.pool(x1)
        x3 = self.enc_conv2_2(self.enc_conv2(x2))
        x4 = self.pool(x3)
        x5 = self.enc_conv3_3(self.enc_conv3(x4))
        x6 = self.pool(x5)
        x7 = self.enc_conv4_4(self.enc_conv4(x6))
        x8 = self.pool(x7)

        # Bottleneck
        p = self.bottleneck_conv_0(self.bottleneck_conv0(x8))
        return(p)
        # Decoder
        u4 = self.upconv4(p)
        u4 = self.dec_conv4(u4)
        u4 = F.interpolate(u4, size=x7.size()[2:], mode='bilinear', align_corners=True)
        c4 = torch.cat([u4, x7], dim=1)
        c4 = self.dec_conv4_4(c4)

        u3 = self.upconv3(c4)
        u3 = self.dec_conv3(u3)
        u3 = F.interpolate(u3, size=x5.size()[2:], mode='bilinear', align_corners=True)
        c3 = torch.cat([u3, x5], dim=1)
        c3 = self.dec_conv3_3(c3)

        u2 = self.upconv2(c3)
        u2 = self.dec_conv2(u2)
        u2 = F.interpolate(u2, size=x3.size()[2:], mode='bilinear', align_corners=True)
        c2 = torch.cat([u2, x3], dim=1)
        c2 = self.dec_conv2_2(c2)

        u1 = self.upconv1(c2)
        u1 = self.dec_conv1(u1)
        u1 = F.interpolate(u1, size=x1.size()[2:], mode='bilinear', align_corners=True)
        c1 = torch.cat([u1, x1], dim=1)
        c1 = self.dec_conv1_1(c1)

        out = self.final_conv(c1)

        return out

In [46]:
# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_unet = UNet()
model_unet.to(device)

criterion = CombinedLoss(alpha=10)
optimizer = torch.optim.Adam(model_unet.parameters(), lr=0.001)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=2, verbose=True)

# Training and validation
train_model(model_unet, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=100)



Total Parameters: 2706321
Trainable Parameters: 2706321


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: Given groups=256, expected weight to be at least 256 at dimension 0, but got weight of size [1, 1, 3, 3] instead