In [1]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, CosineAnnealingLR
import torch.nn.functional as F

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# When running on CUDA, also report the name of device 0 and its current
# memory figures (converted from bytes to GiB, rounded to one decimal place).
_BYTES_PER_GIB = 1024**3
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_gib = torch.cuda.memory_allocated(0) / _BYTES_PER_GIB
    print('Allocated:', round(allocated_gib, 1), 'GB')
    reserved_gib = torch.cuda.memory_reserved(0) / _BYTES_PER_GIB
    print('Cached:   ', round(reserved_gib, 1), 'GB')


exp_name = 'Transformer'

# Seed PyTorch's random number generator so that the random dummy data, the
# model initialisation and the DataLoader shuffling are repeatable after a
# kernel restart.
SEED = 42
_ = torch.manual_seed(SEED)  # assigned to `_` so the Generator repr is never echoed

# Sequences are assumed to be non-overlapping, so SEQUENCE_SIZE is fixed.
SEQUENCE_SIZE = 5
LEARNING_RATE = 1e-4
# BATCH_SIZE is set to 32 for demonstration; in climsim it is 512.
BATCH_SIZE = 32
# Run for 10 epochs for demonstration; the actual number of epochs is 200.
EPOCHS = 10
# Width of one input row (60*4 variables) and of one output row (60 variables).
IN_FEATURES = 60*4
OUT_FEATURES = 60
# Model hyperparameters, passed to the model constructor as d_model / nhead /
# num_layers / max_len.
D_MODEL = 256
# DROPOUT = 0.2
N_HEAD = 4
N_LAYER = 6
# Maximum sequence length is tied to the fixed sequence size.
MAX_LEN = SEQUENCE_SIZE

Using device: cpu


In [2]:
# Random stand-in data: each input row has IN_FEATURES (60*4) variables and each
# target row has OUT_FEATURES (60) variables.
# Training covers 5 years (365*5 rows); testing covers 1 year (365 rows).
X_train, y_train = torch.randn(365*5, IN_FEATURES), torch.randn(365*5, OUT_FEATURES)
X_test, y_test = torch.randn(365, IN_FEATURES), torch.randn(365, OUT_FEATURES)

def create_non_overlapping(tensor, features, sequence_size=None):
    """Reshape a tensor into consecutive, non-overlapping sequences.

    Args:
        tensor: Tensor whose first dimension is the row axis. Its rows are
            grouped, in order, into blocks of ``sequence_size`` rows.
        features: Size of the last dimension of the result.
        sequence_size: Number of rows per sequence. When omitted, the
            notebook-level ``SEQUENCE_SIZE`` is used.

    Returns:
        A tensor of shape
        ``(tensor.shape[0] // sequence_size, sequence_size, features)``.

    Raises:
        ValueError: If ``sequence_size`` is not positive, or if the number of
            rows is not an exact multiple of ``sequence_size``.
    """
    if sequence_size is None:
        sequence_size = SEQUENCE_SIZE
    if sequence_size <= 0:
        raise ValueError(f"sequence_size must be positive, got {sequence_size}")

    n_rows = tensor.shape[0]
    # Fail early with a readable message instead of an opaque reshape error.
    if n_rows % sequence_size != 0:
        raise ValueError(
            f"number of rows ({n_rows}) is not a multiple of "
            f"sequence_size ({sequence_size})"
        )

    # The leading dimension is spelled out (rather than -1) so that a wrong
    # `features` value still raises instead of silently changing the layout.
    return tensor.reshape(n_rows // sequence_size, sequence_size, features)

# Fold every flat table into (num_sequences, SEQUENCE_SIZE, features) chunks.
X_train, y_train = (
    create_non_overlapping(X_train, IN_FEATURES),
    create_non_overlapping(y_train, OUT_FEATURES),
)
X_test, y_test = (
    create_non_overlapping(X_test, IN_FEATURES),
    create_non_overlapping(y_test, OUT_FEATURES),
)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# Pair inputs with their targets.
training_set, testing_set = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)

# Batch iterators: the training set is reshuffled every epoch, while the test
# set keeps its original order.
train_dataloader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(testing_set, batch_size=BATCH_SIZE, shuffle=False)

torch.Size([365, 5, 240]) torch.Size([365, 5, 60]) torch.Size([73, 5, 240]) torch.Size([73, 5, 60])


In [3]:
from models import TransformerModel

# Build the model from the notebook-level hyperparameters, then move it onto
# the selected device.
model = TransformerModel(
    input_dim=IN_FEATURES,
    output_dim=OUT_FEATURES,
    d_model=D_MODEL,
    nhead=N_HEAD,
    num_layers=N_LAYER,
    max_len=MAX_LEN,
)
model = model.to(device)

# Mean-squared-error objective, optimised with AdamW at LEARNING_RATE.
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# The scheduler is created here but is not used in this example.
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=40)

In [4]:
# Mean loss per epoch, stored as plain Python floats.
train_loss_list = []
test_loss_list = []

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()  # set once per epoch; eval() below switches it back
    train_loss = 0.0
    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float, not the tensor: summing `loss` itself
        # would keep every batch's autograd history alive until epoch end.
        train_loss += loss.item()

    # Divide total train loss by the number of batches (average loss per batch per epoch)
    train_loss /= len(train_dataloader)
    train_loss_list.append(train_loss)

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            test_pred = model(X)
            test_loss += loss_fn(test_pred, y).item()

    # Divide total test loss by the number of batches (average loss per batch)
    test_loss /= len(test_dataloader)
    test_loss_list.append(test_loss)

    print(f"Epoch: {epoch} | Train loss: {train_loss:.5f} | Test loss: {test_loss:.5f}")

Epoch: 0 | Train loss: 1.07093 | Test loss: 1.00268
Epoch: 1 | Train loss: 1.01239 | Test loss: 0.99753
Epoch: 2 | Train loss: 1.01042 | Test loss: 0.99500
Epoch: 3 | Train loss: 1.01244 | Test loss: 0.99544
Epoch: 4 | Train loss: 1.00761 | Test loss: 0.99629


KeyboardInterrupt: 