# GP

What MSE does the GP achieve

In [1]:
import torch

import pyro
import pyro.contrib.gp as gp

from os.path import join

import sys
sys.path.append('../')
import utils

In [2]:
assert torch.cuda.is_available(), 'CUDA is not available.'

## Load Data

In [3]:
fname = join('/home/squirt/Documents/data/weather_data/', 'all_data.h5')

In [4]:
split = 0.5
train_data, test_data = utils.get_data(fname, split)

In [5]:
def combine_data(data_tuple:tuple[torch.tensor]) -> tuple[torch.tensor]:
    '''
    Add Landmass to x data. Return x,y tensors
    Input:
        - data_tuple (tuple[torch.tensor]): tuple of landmass,x,y tensors
    '''
    l,x,y = data_tuple

    # Combine
    l = l.unsqueeze(1)
    x = torch.cat((l, x), 1)
    x = x.contiguous()
    y = y.contiguous()
    return (x,y)

In [6]:
train_data = combine_data(train_data)
test_data = combine_data(test_data)

## Define GP

Using Pyro for 1st try

In [7]:
x,y = train_data
x = x.view(-1, 71*3*2*2)
y = y.view(-1, 70*2*2*2)
y = y.transpose(0,1)

x = x.cuda()
y = y.cuda()

In [8]:
x_test, y_test = test_data
x_test = x_test.view(-1, 71*3*2*2)
y_test = y_test.view(-1, 70*2*2*2)
y_test = y_test.transpose(0,1)

x_test = x_test.cuda()
y_test = y_test.cuda()

In [9]:
print(x.shape)
print(y.shape)

torch.Size([9600, 852])
torch.Size([560, 9600])


In [10]:
# Initialize model
kernel = gp.kernels.RBF(input_dim=x.size(1))
gpr = gp.models.GPRegression(x, y, kernel)
gpr = gpr.cuda()

In [13]:
# Train Model
optimizer = torch.optim.Adam(gpr.parameters(), lr=0.001)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

for i in range(10):
    # Train Model
    optimizer.zero_grad() 
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()

    # Test Model (MSE Loss)
    with torch.no_grad():
        y_pred = gpr(x_test)[0]
        #mse = torch.mean((y_test - y_pred)**2)
        #print(f'Epoch {i} Eval Loss: {mse.item()}')
        loss = torch.nn.functional.mse_loss(y_pred, y_test)
        print(f'Epoch {i} Eval Loss: {loss.item()}')



Epoch 0 Eval Loss: 0.687798448864547
Epoch 1 Eval Loss: 0.687798448864547


KeyboardInterrupt: 

In [12]:
y_pred.shape

torch.Size([9600, 560])