In [1]:
import xarray as xr
import os
import netCDF4
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import random
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
has_mps = torch.backends.mps.is_built()
device = "mps" if has_mps else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
X=np.load('/work/sds-lab/Shuochen/climsim/val_input.npy')
y=np.load('/work/sds-lab/Shuochen/climsim/val_target.npy')

X = torch.from_numpy(X).type(torch.float).to(device)
y = torch.from_numpy(y).type(torch.float).to(device)

In [4]:
LEARNING_RATE = 0.01
IN_FEATURES = 124
OUT_FEATURES = 128
RANDOM_SEED = 42

X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2,random_state=RANDOM_SEED)

In [5]:
# Setup data loaders for batch
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [6]:
# Positional Encoding for Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [7]:
# Model definition using Transformer
class TransformerModel(nn.Module):
    def __init__(self, input_dim=IN_FEATURES, output_dim = OUT_FEATURES, d_model=64, nhead=4, num_layers=2, dropout=0.2):
        super(TransformerModel, self).__init__()

        self.encoder = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.decoder = nn.Linear(d_model, OUT_FEATURES)

    def forward(self, x):
        x = self.encoder(x)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = self.decoder(x[:, -1, :])
        return x

model = TransformerModel().to(device)



In [8]:
# Train the model
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3, verbose=True)

epochs = 100
early_stop_count = 0
min_val_loss = float('inf')
f = open("loss_Transformer.txt", "w")

for epoch in range(epochs):
    model.train()
    for batch in train_loader:
        x_batch, y_batch = batch   
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(x_batch)
        
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in test_loader:
            x_batch, y_batch = batch
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            val_losses.append(loss.item())

    val_loss = np.mean(val_losses)
    scheduler.step(val_loss)
    if epoch % 1 == 0:
        print(f"Epoch: {epoch} Test loss: {val_losses[0]}")
        
    # f.write(str(epoch) + '\t' + f"{float(loss):.6f}" + '\t' + f"{float(test_loss):.6f}" + '\n')
    
f.close()




Epoch: 0 Test loss: 0.01355045661330223
Epoch: 1 Test loss: 0.013412756845355034
Epoch: 2 Test loss: 0.013395403511822224
Epoch: 3 Test loss: 0.013490053825080395
Epoch: 4 Test loss: 0.013409774750471115
Epoch: 5 Test loss: 0.01336912252008915
Epoch: 6 Test loss: 0.013655360788106918
Epoch: 7 Test loss: 0.013361834920942783
Epoch: 8 Test loss: 0.013327253051102161
Epoch: 9 Test loss: 0.013392176479101181
Epoch: 10 Test loss: 0.01334233395755291
Epoch: 11 Test loss: 0.013339066877961159
Epoch: 12 Test loss: 0.0133614931255579
Epoch: 13 Test loss: 0.013331624679267406
Epoch: 14 Test loss: 0.013303442858159542
Epoch: 15 Test loss: 0.013338404707610607
Epoch: 16 Test loss: 0.013479919172823429
Epoch: 17 Test loss: 0.013285274617373943
Epoch: 18 Test loss: 0.01331124734133482
Epoch: 19 Test loss: 0.013345104642212391
Epoch: 20 Test loss: 0.013304116204380989
Epoch: 21 Test loss: 0.013317163102328777
Epoch: 22 Test loss: 0.013308174908161163
Epoch: 23 Test loss: 0.01333095133304596
Epoch: 24