In [1]:
# Import necessary libraries
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from segrnn import SegRNNModel
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import os
from utils import preprocess_and_save_data

In [2]:
preprocess_and_save_data('./csvs/ROC.csv', normalize=True)
preprocess_and_save_data('./csvs/JRB.csv', normalize=True)
preprocess_and_save_data('./csvs/BGM.csv', normalize=True)
preprocess_and_save_data('./csvs/PEO.csv', normalize=True)
preprocess_and_save_data('./csvs/RME.csv', normalize=True)
preprocess_and_save_data('./csvs/MSS.csv', normalize=True)

Missing values in continuous columns before processing:
tmpf                     2
dwpf                     3
relh                     3
feel                    14
drct                   634
sknt                    29
gust                 44180
peak_wind_gust       48793
peak_wind_drct       48793
alti                     1
mslp                 10049
vsby                     0
p01i                 10614
ice_accretion_1hr    52974
ice_accretion_3hr    53018
ice_accretion_6hr    53008
skyl1                 3687
skyl2                21406
skyl3                37709
skyl4                51113
snowdepth            52315
peak_wind_time       53032
dtype: int64
nan thresh is 26516.0
bad columns are ['gust', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in continuous

  df[continuous_cols] = df[continuous_cols].replace(placeholders, np.nan).astype(str)


Missing values in continuous columns before processing:
tmpf                  5037
dwpf                  5198
relh                  5198
feel                  5198
drct                  9272
sknt                   127
gust                 41371
peak_wind_gust       44018
peak_wind_drct       44018
alti                    10
mslp                 12991
vsby                  1258
p01i                  3098
ice_accretion_1hr    47642
ice_accretion_3hr    47642
ice_accretion_6hr    47642
skyl1                20500
skyl2                37465
skyl3                43766
skyl4                47642
snowdepth            47642
peak_wind_time       47642
dtype: int64
nan thresh is 23821.0
bad columns are ['gust', 'skyl2', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in c

  df[continuous_cols] = df[continuous_cols].replace(placeholders, np.nan).astype(str)


Missing values in continuous columns before processing:
tmpf                    44
dwpf                    44
relh                    44
feel                    76
drct                  1377
sknt                   180
gust                 52379
peak_wind_gust       60197
peak_wind_drct       60197
alti                     1
mslp                 21145
vsby                    59
p01i                 10097
ice_accretion_1hr    63161
ice_accretion_3hr    63552
ice_accretion_6hr    63518
skyl1                16351
skyl2                38243
skyl3                51742
skyl4                63611
snowdepth            63611
peak_wind_time       63611
dtype: int64
nan thresh is 31805.5
bad columns are ['gust', 'skyl2', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in c

  df[continuous_cols] = df[continuous_cols].replace(placeholders, np.nan).astype(str)


Missing values in continuous columns before processing:
tmpf                     9
dwpf                   455
relh                   455
feel                   553
drct                  4971
sknt                   328
gust                 48691
peak_wind_gust       56201
peak_wind_drct       56201
alti                     1
mslp                 16369
vsby                    25
p01i                 10510
ice_accretion_1hr    58969
ice_accretion_3hr    59122
ice_accretion_6hr    59109
skyl1                17488
skyl2                38038
skyl3                49721
skyl4                59152
snowdepth            59152
peak_wind_time       59152
dtype: int64
nan thresh is 29576.0
bad columns are ['gust', 'skyl2', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in c

  df[continuous_cols] = df[continuous_cols].replace(placeholders, np.nan).astype(str)


Missing values in continuous columns before processing:
tmpf                   115
dwpf                   139
relh                   139
feel                   149
drct                  1325
sknt                    58
gust                 53522
peak_wind_gust       56820
peak_wind_drct       56820
alti                     1
mslp                 16297
vsby                    13
p01i                  9421
ice_accretion_1hr    58874
ice_accretion_3hr    59079
ice_accretion_6hr    59060
skyl1                16377
skyl2                36364
skyl3                48795
skyl4                59111
snowdepth            59111
peak_wind_time       59111
dtype: int64
nan thresh is 29555.5
bad columns are ['gust', 'skyl2', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in c

  df[continuous_cols] = df[continuous_cols].replace(placeholders, np.nan).astype(str)


Missing values in continuous columns before processing:
tmpf                   142
dwpf                   290
relh                   290
feel                   320
drct                   994
sknt                   254
gust                 50746
peak_wind_gust       55448
peak_wind_drct       55448
alti                     0
mslp                 16471
vsby                   192
p01i                  9366
ice_accretion_1hr    57798
ice_accretion_3hr    58255
ice_accretion_6hr    58204
skyl1                18803
skyl2                39975
skyl3                50875
skyl4                58320
snowdepth            58320
peak_wind_time       58320
dtype: int64
nan thresh is 29160.0
bad columns are ['gust', 'skyl2', 'skyl3', 'skyl4', 'ice_accretion_1hr', 'ice_accretion_3hr', 'ice_accretion_6hr', 'peak_wind_gust', 'peak_wind_drct', 'peak_wind_time', 'snowdepth']
10 remaining continuous columns: ['mslp', 'vsby', 'dwpf', 'tmpf', 'p01i', 'alti', 'sknt', 'feel', 'drct', 'relh']
Missing values in c

['mslp',
 'vsby',
 'dwpf',
 'tmpf',
 'p01i',
 'alti',
 'sknt',
 'feel',
 'drct',
 'relh']

In [3]:
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

In [4]:
nylocations = './csvs/_nylocations.csv'
latlongs = pd.read_csv(nylocations)
latlongs

Unnamed: 0,stid,station_name,lat,lon,elev,begints,endts,iem_network
0,6B9,Skaneateles,42.914,-76.4408,304.32452,2016-07-22 00:00,,NY_ASOS
1,ALB,ALBANY COUNTY ARPT,42.7576,-73.8036,89.0,1945-01-01 00:00,,NY_ASOS
2,ART,WATERTOWN INTL ARPT,43.9888,-76.0262,99.0,1949-04-30 00:00,,NY_ASOS
3,BGM,BINGHAMTON/BROOME,42.2086,-75.9797,497.0,1948-01-01 00:00,,NY_ASOS
4,BUF,BUFFALO INTL ARPT,42.9408,-78.7358,215.0,1942-01-31 00:00,,NY_ASOS
5,DKK,DUNKIRK AIRPORT,42.4933,-79.272,203.0,1948-12-31 00:00,,NY_ASOS
6,DSV,DANSVILLE MUNICIPAL,42.5709,-77.713,209.0,1948-12-31 00:00,,NY_ASOS
7,ELM,Elmira / Corning,42.1571,-76.8994,287.12537,1949-02-01 00:00,,NY_ASOS
8,ELZ,Wellsville Municipal,42.1078,-77.9842,639.0,1978-06-13 00:00,,NY_ASOS
9,FOK,WESTHAMPTON BEACH,40.8436,-72.6318,20.0,1943-07-18 00:00,,NY_ASOS


In [5]:
roc = './csvs/ROC_processed.csv'
roc_p = pd.read_csv(roc)
roc_p

Unnamed: 0,station,valid,tmpf,dwpf,relh,drct,sknt,p01i,alti,mslp,vsby,feel
0,ROC,2020-01-01 00:54:00,-0.942380,-0.872567,0.120999,0.575668,0.899204,-0.207339,-1.770322,-1.688849,0.490278,-1.070871
1,ROC,2020-01-01 01:54:00,-0.942380,-1.050176,-0.406298,0.575668,1.530761,-0.207339,-1.770322,-1.701429,0.490278,-1.125387
2,ROC,2020-01-01 02:54:00,-0.942380,-1.050176,-0.406298,0.670413,0.688685,-0.207339,-1.727055,-1.663691,0.490278,-1.049869
3,ROC,2020-01-01 03:54:00,-1.001766,-1.101739,-0.382094,0.670413,1.320242,-0.207339,-1.727055,-1.651112,0.490278,-1.171859
4,ROC,2020-01-01 04:54:00,-1.001766,-1.101739,-0.382094,0.575668,1.109723,-0.207339,-1.727055,-1.651112,0.490278,-1.153091
...,...,...,...,...,...,...,...,...,...,...,...,...
53027,ROC,2024-11-28 19:54:00,-0.564467,-0.299635,0.717451,1.144142,0.267646,0.046903,-1.207852,-1.185684,0.490278,-0.614192
53028,ROC,2024-11-28 20:54:00,-0.564467,-0.299635,0.717451,1.049396,-0.574430,-0.080218,-1.078051,-1.047313,0.490278,-0.502926
53029,ROC,2024-11-28 21:54:00,-0.618454,-0.356928,0.713417,0.575668,-0.363911,-0.207339,-0.948250,-0.896363,0.490278,-0.590062
53030,ROC,2024-11-28 22:54:00,-0.618454,-0.356928,0.713417,0.386176,-0.153392,-0.207339,-0.861716,-0.833468,0.490278,-0.620001


In [6]:
required_stations = ['JRB', 'ROC', 'BGM', 'MSS', 'PEO', 'RME']

### first get the locations of stations
# Filter the DataFrame for the required stations
stations_df = latlongs[latlongs['stid'].isin(required_stations)]
# Create a dictionary from the filtered DataFrame
stations_latlong = stations_df.set_index('stid')[['lat', 'lon']].T.to_dict()

print(stations_latlong)

### now get station data itself
processed_data_paths = {station:f'./csvs/{station}_processed.csv' for station in required_stations}
print(processed_data_paths)

station_features = []
normalized_latlongs = []
for stid in required_stations:
   station_data = pd.read_csv(processed_data_paths[stid])
   features = torch.tensor(station_data.drop(columns=['station', 'valid']).values, dtype=torch.float)
   mean_features = features.mean(dim=0)  # Mean of all the features to get "average weather"

   # Append latitude and longitude to the feature vector
   lat, long = stations_latlong[stid]['lat'], stations_latlong[stid]['lon']
   nlat = (lat+90)/ (180)
   nlong = (long+180)/ (360) 
   lat_long = torch.tensor([nlat,nlong], dtype=torch.float)
   print(f"latlong is {lat_long}")
   combined_features = torch.cat((mean_features, lat_long))  # Concatenate features with lat/lon

   station_features.append(combined_features)
   normalized_latlongs.append([nlat, nlong])

print(f'n latlongs: {normalized_latlongs}')

node_features = torch.stack(station_features)
node_features += torch.randn_like(node_features) * 0.01  # noise
print(node_features.shape)
print(node_features)
print(torch.var(node_features, dim=0))

def calculate_distances(latlongs):
   num_stations = len(latlongs)
   distances = np.zeros((num_stations, num_stations))
   for i, coord1 in enumerate(latlongs):
      for j, coord2 in enumerate(latlongs):
         # Calculate Euclidean distance for normalized coordinates
         distances[i, j] = np.linalg.norm(np.array(coord1) - np.array(coord2))
   return distances

distances = calculate_distances(normalized_latlongs)
print("Distance Matrix:")
print(distances)

distance_threshold = 0.015  # Adjust this threshold based on your scale and data
edges = []

for i in range(len(distances)):
   for j in range(len(distances)):
      if i != j and distances[i, j] <= distance_threshold:
         edges.append((i, j))

# Define edges 
edge_index = torch.tensor(edges, dtype=torch.long).T  # Transpose to match edge_index format
print("Edge Index:")
print(edge_index)

# Create the graph data
data = Data(x=node_features, edge_index=edge_index)
print(data)

{'BGM': {'lat': 42.2086, 'lon': -75.9797}, 'JRB': {'lat': 40.7012, 'lon': -74.009}, 'MSS': {'lat': 44.9358, 'lon': -74.8456}, 'PEO': {'lat': 42.6441, 'lon': -77.0529}, 'RME': {'lat': 43.2239, 'lon': -75.3953}, 'ROC': {'lat': 43.1167, 'lon': -77.6767}}
{'JRB': './csvs/JRB_processed.csv', 'ROC': './csvs/ROC_processed.csv', 'BGM': './csvs/BGM_processed.csv', 'MSS': './csvs/MSS_processed.csv', 'PEO': './csvs/PEO_processed.csv', 'RME': './csvs/RME_processed.csv'}
latlong is tensor([0.7261, 0.2944])
latlong is tensor([0.7395, 0.2842])
latlong is tensor([0.7345, 0.2889])
latlong is tensor([0.7496, 0.2921])
latlong is tensor([0.7369, 0.2860])
latlong is tensor([0.7401, 0.2906])
n latlongs: [[0.7261177777777778, 0.29441944444444446], [0.7395372222222223, 0.2842313888888889], [0.7344922222222222, 0.2889452777777778], [0.7496433333333333, 0.29209555555555555], [0.7369116666666667, 0.2859641666666667], [0.7401327777777779, 0.2905686111111111]]
torch.Size([6, 12])
tensor([[-1.0023e-03,  6.0683e-03,

In [7]:
class SimpleGNN(torch.nn.Module):
   def __init__(self, input_dim, hidden_dim, output_dim):
      super(SimpleGNN, self).__init__()
      self.conv1 = GCNConv(input_dim, hidden_dim)
      self.conv2 = GCNConv(hidden_dim, output_dim)

   def forward(self, data):
      x, edge_index = data.x, data.edge_index
      x = self.conv1(x, edge_index)
      x = F.relu(x)
      x = self.conv2(x, edge_index)
      return x  # Embeddings for each node

# Initialize the GNN
input_dim = node_features.shape[1]  # Latitude and longitude plus features
hidden_dim = 14
output_dim = 12  # Embedding size
gnn = SimpleGNN(input_dim, hidden_dim, output_dim)

In [8]:
embeddings = gnn(data)

# Print embeddings
print("initial embeddings for each station:")
print(embeddings)

initial embeddings for each station:
tensor([[-0.0010, -0.0639,  0.0999, -0.0328,  0.0132,  0.1274, -0.0606, -0.2108,
          0.1127, -0.2048, -0.1221, -0.1166],
        [-0.0020, -0.0724,  0.1121, -0.0349,  0.0146,  0.1426, -0.0685, -0.2363,
          0.1258, -0.2284, -0.1363, -0.1286],
        [-0.0014, -0.0717,  0.1119, -0.0359,  0.0151,  0.1422, -0.0677, -0.2358,
          0.1255, -0.2284, -0.1365, -0.1298],
        [-0.0018, -0.0651,  0.1005, -0.0310,  0.0131,  0.1278, -0.0614, -0.2120,
          0.1128, -0.2046, -0.1221, -0.1148],
        [-0.0017, -0.0788,  0.1224, -0.0388,  0.0162,  0.1555, -0.0743, -0.2580,
          0.1375, -0.2497, -0.1489, -0.1411],
        [-0.0017, -0.0788,  0.1224, -0.0388,  0.0162,  0.1555, -0.0743, -0.2580,
          0.1375, -0.2497, -0.1489, -0.1411]], grad_fn=<AddBackward0>)


Now train the gnn

In [9]:
from torch.nn.functional import cosine_similarity

# Compute similarity matrix from node features
similarity_matrix = torch.mm(node_features, node_features.T)
print(f'similarity matrix is {similarity_matrix}')

# Define unsupervised loss function with regularization
def contrastive_loss(embeddings, similarity_matrix, distance_matrix, lambda_diversity=0.1, lambda_distance=0.1):
   pred_similarity = torch.mm(embeddings, embeddings.T)
   mse_loss = torch.nn.functional.mse_loss(pred_similarity, similarity_matrix)

   diversity_loss = -torch.var(embeddings, dim=0).mean()  # Penalize low variance

   pred_distance_matrix = torch.cdist(embeddings, embeddings, p=2)
   distance_loss = torch.nn.functional.mse_loss(pred_distance_matrix, distance_matrix)

   total_loss = mse_loss + lambda_diversity * diversity_loss + lambda_distance * distance_loss
   return total_loss

# Training loop
epochs = 50
optimizer = torch.optim.Adam(gnn.parameters(), lr=0.001)

gnn.train()
for epoch in range(epochs):
   optimizer.zero_grad()  # Reset gradients
   embeddings = gnn(data)  # Forward pass

   # Compute contrastive loss
   loss = contrastive_loss(embeddings, similarity_matrix, torch.Tensor(distances))
   
   # Backward pass and optimization
   loss.backward()
   optimizer.step()

   # Print loss and gradient information
   print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
   # for name, param in gnn.named_parameters():
   #    if param.grad is not None:
   #       print(f"Gradient for {name}: {param.grad.abs().mean().item():.6f}")
   #    else:
   #       print(f"No gradient for {name}.")


similarity matrix is tensor([[0.6201, 0.6285, 0.6187, 0.6374, 0.6106, 0.6101],
        [0.6285, 0.6404, 0.6291, 0.6475, 0.6200, 0.6209],
        [0.6187, 0.6291, 0.6218, 0.6376, 0.6118, 0.6113],
        [0.6374, 0.6475, 0.6376, 0.6573, 0.6291, 0.6290],
        [0.6106, 0.6200, 0.6118, 0.6291, 0.6039, 0.6023],
        [0.6101, 0.6209, 0.6113, 0.6290, 0.6023, 0.6035]])
Epoch 1, Loss: 0.1791
Epoch 2, Loss: 0.1702
Epoch 3, Loss: 0.1612
Epoch 4, Loss: 0.1521
Epoch 5, Loss: 0.1430
Epoch 6, Loss: 0.1339
Epoch 7, Loss: 0.1248
Epoch 8, Loss: 0.1158
Epoch 9, Loss: 0.1068
Epoch 10, Loss: 0.0980
Epoch 11, Loss: 0.0894
Epoch 12, Loss: 0.0809
Epoch 13, Loss: 0.0728
Epoch 14, Loss: 0.0649
Epoch 15, Loss: 0.0574
Epoch 16, Loss: 0.0502
Epoch 17, Loss: 0.0435
Epoch 18, Loss: 0.0373
Epoch 19, Loss: 0.0316
Epoch 20, Loss: 0.0264
Epoch 21, Loss: 0.0218
Epoch 22, Loss: 0.0178
Epoch 23, Loss: 0.0144
Epoch 24, Loss: 0.0117
Epoch 25, Loss: 0.0095
Epoch 26, Loss: 0.0080
Epoch 27, Loss: 0.0069
Epoch 28, Loss: 0.

In [10]:
learned_embeddings = gnn(data)  # Get final embeddings
print("Learned embeddings:")
print(learned_embeddings)

Learned embeddings:
tensor([[-0.1468, -0.1550,  0.2000, -0.0575,  0.1047,  0.2262, -0.1686, -0.3321,
          0.1651, -0.3363, -0.2194, -0.2050],
        [-0.1621, -0.1709,  0.2205, -0.0594,  0.1135,  0.2501, -0.1864, -0.3685,
          0.1810, -0.3727, -0.2421, -0.2246],
        [-0.1610, -0.1701,  0.2202, -0.0604,  0.1138,  0.2493, -0.1853, -0.3679,
          0.1810, -0.3723, -0.2420, -0.2255],
        [-0.1482, -0.1560,  0.2003, -0.0558,  0.1047,  0.2270, -0.1701, -0.3329,
          0.1649, -0.3366, -0.2198, -0.2038],
        [-0.1748, -0.1845,  0.2390, -0.0627,  0.1220,  0.2706, -0.2010, -0.4010,
          0.1959, -0.4057, -0.2622, -0.2437],
        [-0.1748, -0.1845,  0.2390, -0.0627,  0.1220,  0.2706, -0.2010, -0.4010,
          0.1959, -0.4057, -0.2622, -0.2437]], grad_fn=<AddBackward0>)


In [11]:
for r in required_stations:
   df = pd.read_csv(f'./csvs/{r}_processed.csv')
   print(r)
   print(df.shape)

JRB
(47642, 12)
ROC
(53032, 12)
BGM
(63611, 12)
MSS
(58320, 12)
PEO
(59152, 12)
RME
(59111, 12)


for each station, add on the static embedding that we just learned

In [12]:
all_data = []

for idx, (station, path) in enumerate(processed_data_paths.items()):
   # Load the CSV into a DataFrame
   df = pd.read_csv(path)
   print(df.shape)
   
   # Get the corresponding embedding for this station
   embedding = learned_embeddings[idx].detach().numpy()
   
   # Add the embedding as new columns to the DataFrame
   for i, value in enumerate(embedding):
      df[f'embedding_{i}'] = value

   print(df.shape)
   
   # Append to the list of all data
   all_data.append(df)

# Concatenate all DataFrames into one
combined_df = pd.concat(all_data, ignore_index=True)

# Display the result
print(combined_df.shape)
combined_df.head()


(47642, 12)
(47642, 24)
(53032, 12)
(53032, 24)
(63611, 12)
(63611, 24)
(58320, 12)
(58320, 24)
(59152, 12)
(59152, 24)
(59111, 12)
(59111, 24)
(340868, 24)


Unnamed: 0,station,valid,tmpf,dwpf,relh,drct,sknt,p01i,alti,mslp,...,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,embedding_10,embedding_11
0,JRB,2020-01-01 00:56:00,-1.052667,-0.58335,0.657509,1.118256,0.310933,-0.119089,-1.621235,-1.705313,...,0.199959,-0.057479,0.104688,0.22618,-0.168643,-0.332092,0.165093,-0.336337,-0.219415,-0.205045
1,JRB,2020-01-01 01:56:00,-1.118122,-0.704325,0.523108,1.118256,0.836536,-0.119089,-1.534873,-1.63939,...,0.199959,-0.057479,0.104688,0.22618,-0.168643,-0.332092,0.165093,-0.336337,-0.219415,-0.205045
2,JRB,2020-01-01 02:56:00,-1.183577,-0.764813,0.518875,1.118256,1.099338,-0.119089,-1.578054,-1.678944,...,0.199959,-0.057479,0.104688,0.22618,-0.168643,-0.332092,0.165093,-0.336337,-0.219415,-0.205045
3,JRB,2020-01-01 03:56:00,-1.183577,-0.8253,0.346377,1.118256,0.573735,-0.119089,-1.578054,-1.652575,...,0.199959,-0.057479,0.104688,0.22618,-0.168643,-0.332092,0.165093,-0.336337,-0.219415,-0.205045
4,JRB,2020-01-01 04:56:00,-1.249032,-1.012812,0.042123,1.118256,1.624941,-0.119089,-1.534873,-1.613021,...,0.199959,-0.057479,0.104688,0.22618,-0.168643,-0.332092,0.165093,-0.336337,-0.219415,-0.205045


In [13]:
# Step 6: Prepare sequences for LSTM input
# Assuming we are predicting 'tmpf' (temperature) as the target variable
# and using previous 24 time steps/8 hours (n_steps_in) to predict the next time step/20 minutes from now (n_steps_out)
# create sliding window sequences X: (114640, 24, 10), y: (114640, 10)
feature_cols = list(set(df.columns) - set(['station', 'valid']))

n_steps_in = 24  # Number of past time steps
n_steps_out = 1  # Number of future time steps to predict

# We'll create sequences for each station separately
def create_sequences(data, n_steps_in, n_steps_out):
   X, y = [], []
   for i in range(len(data) - n_steps_in - n_steps_out + 1):
      X.append(data[i:(i + n_steps_in), :])
      y.append(data[(i + n_steps_in):(i + n_steps_in + n_steps_out), :])
   return np.array(X), np.array(y)

# Prepare data for each station
X_list = []
y_list = []
stations = df['station'].unique()

for station in stations:
   station_data = df[df['station'] == station]
   station_data = station_data.reset_index(drop=True)
   data_values = station_data[feature_cols].values
   # target_col_index = feature_cols.index('tmpf')  # Index of target variable in features

   X_station, y_station = create_sequences(data_values, n_steps_in, n_steps_out)
   X_list.append(X_station)
   y_list.append(y_station)


# Concatenate data from all stations
X = np.concatenate(X_list, axis=0)
y = np.concatenate(y_list, axis=0)


if n_steps_out == 1:
   y = y.squeeze(1)  # Shape becomes (num_samples, num_features) = (114640, 10) for JRB


print(X.shape)
print(y.shape)

# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

(59087, 24, 22)
(59087, 22)


In [14]:
# Step 7: Split the data into training, validation, and testing sets
# Since it's time-series data, we'll use the first 70% for training, next 10% for validation, and the rest for testing
train_size = int(len(X) * 0.7)
val_size = int(len(X) * 0.1)
test_size = len(X) - train_size - val_size

X_train, X_val, X_test = X[:train_size], X[train_size:train_size + val_size], X[train_size + val_size:]
y_train, y_val, y_test = y[:train_size], y[train_size:train_size + val_size], y[train_size + val_size:]

# Now the data is ready for training the LSTM model

# Define a PyTorch Dataset
class WeatherDataset(Dataset):
   def __init__(self, X, y):
      self.X = X
      self.y = y
   def __len__(self):
      return len(self.X)
   def __getitem__(self, idx):
      return self.X[idx], self.y[idx]

# Create Dataset objects for training, validation, and testing
train_dataset = WeatherDataset(X_train, y_train)
val_dataset = WeatherDataset(X_val, y_val)
test_dataset = WeatherDataset(X_test, y_test)

# Example to check shapes
print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")

Train size: 41360, Validation size: 5908, Test size: 11819


In [15]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

# Create DataLoaders
batch_size = 32
train_dataset = WeatherDataset(X_train, y_train)
val_dataset = WeatherDataset(X_val, y_val)
test_dataset = WeatherDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Hyperparameters for SegRNN
input_size = X.shape[2]  # Number of features
hidden_size = 512  # Based on the SEGRNN paper
output_size = X.shape[2]  # Predict all features
segment_length = 8  # Based on the SEGRNN paper
learning_rate = 0.001

# Initialize SegRNNModel
model = SegRNNModel(
   input_size=input_size,
   hidden_size=hidden_size,
   output_size=output_size,
   segment_length=segment_length,
   learning_rate=learning_rate
)

# Logger
logger = TensorBoardLogger("logs", name="segrnn_experiment")

# Checkpoint callback
checkpoint_callback = ModelCheckpoint(
   dirpath="checkpoints/",
   filename="segrnn-{epoch:02d}-{val_loss:.4f}",
   save_top_k=1,
   monitor="val_loss",  # Monitor validation loss
   mode="min"
)

# Trainer with logging and checkpointing
trainer = Trainer(
   max_epochs=25,
   accelerator="gpu" if torch.cuda.is_available() else "cpu",
   devices=1,
   logger=logger,
   callbacks=[checkpoint_callback]
)

# Train the model, including validation loader
trainer.fit(model, train_loader, val_loader)

# Optional: Evaluate on the test set
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type    | Params | Mode 
-----------------------------------------------
0 | model      | SegRNN  | 2.4 M  | train
1 | criterion  | MSELoss | 0      | train
2 | criterion2 | L1Loss  | 0      | train
-----------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.642     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode
SLURM auto-reque

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.516322135925293
Validation Loss: 0.3555587828159332


/global/homes/p/parshvam/.local/perlmutter/python-3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
/global/homes/p/parshvam/.local/perlmutter/python-3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.02118496783077717
Validation Loss: 0.049162041395902634
Validation Loss: 0.03336353600025177
Validation Loss: 0.028187468647956848
Validation Loss: 0.059936702251434326
Validation Loss: 0.02290920913219452
Validation Loss: 0.019369736313819885
Validation Loss: 0.015387424267828465
Validation Loss: 0.027209624648094177
Validation Loss: 0.04243132472038269
Validation Loss: 0.0678059458732605
Validation Loss: 0.09146402031183243
Validation Loss: 0.08311328291893005
Validation Loss: 0.030933771282434464
Validation Loss: 0.10519303381443024
Validation Loss: 0.026902299374341965
Validation Loss: 0.06371869146823883
Validation Loss: 0.040127743035554886
Validation Loss: 0.040692444890737534
Validation Loss: 0.033423569053411484
Validation Loss: 0.0878966823220253
Validation Loss: 0.06840655952692032
Validation Loss: 0.04355301707983017
Validation Loss: 0.030622180551290512
Validation Loss: 0.08760601282119751
Validation Loss: 0.048805296421051025
Validation Loss: 0.02814120

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.031127803027629852
Validation Loss: 0.044886644929647446
Validation Loss: 0.04052957892417908
Validation Loss: 0.03270455077290535
Validation Loss: 0.06343920528888702
Validation Loss: 0.0238400436937809
Validation Loss: 0.02250540815293789
Validation Loss: 0.01728300377726555
Validation Loss: 0.030804185196757317
Validation Loss: 0.04657330363988876
Validation Loss: 0.0715823769569397
Validation Loss: 0.10105163604021072
Validation Loss: 0.11575150489807129
Validation Loss: 0.040590133517980576
Validation Loss: 0.10888384282588959
Validation Loss: 0.0362447127699852
Validation Loss: 0.05929860472679138
Validation Loss: 0.042324285954236984
Validation Loss: 0.04969281703233719
Validation Loss: 0.04442274942994118
Validation Loss: 0.09953875094652176
Validation Loss: 0.07252638787031174
Validation Loss: 0.04939959570765495
Validation Loss: 0.032573774456977844
Validation Loss: 0.07782281935214996
Validation Loss: 0.05157085135579109
Validation Loss: 0.0362521894276142

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.021894650533795357
Validation Loss: 0.043388962745666504
Validation Loss: 0.03544503450393677
Validation Loss: 0.03370160609483719
Validation Loss: 0.0566285103559494
Validation Loss: 0.017940545454621315
Validation Loss: 0.020882513374090195
Validation Loss: 0.01490133535116911
Validation Loss: 0.028458600863814354
Validation Loss: 0.03430856019258499
Validation Loss: 0.06559798866510391
Validation Loss: 0.08381196111440659
Validation Loss: 0.1002800464630127
Validation Loss: 0.027035024017095566
Validation Loss: 0.10000082105398178
Validation Loss: 0.03344879671931267
Validation Loss: 0.062412068247795105
Validation Loss: 0.034084778279066086
Validation Loss: 0.04329141974449158
Validation Loss: 0.03896281123161316
Validation Loss: 0.0826081782579422
Validation Loss: 0.05864090472459793
Validation Loss: 0.042500752955675125
Validation Loss: 0.025688795372843742
Validation Loss: 0.07642205059528351
Validation Loss: 0.04721487686038017
Validation Loss: 0.029061302542

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.01804862916469574
Validation Loss: 0.046025462448596954
Validation Loss: 0.03636140376329422
Validation Loss: 0.025744853541254997
Validation Loss: 0.058799874037504196
Validation Loss: 0.020060067996382713
Validation Loss: 0.020537665113806725
Validation Loss: 0.01594615913927555
Validation Loss: 0.02952362224459648
Validation Loss: 0.03425747901201248
Validation Loss: 0.06244624778628349
Validation Loss: 0.0833880677819252
Validation Loss: 0.09671323001384735
Validation Loss: 0.02878391370177269
Validation Loss: 0.09183213114738464
Validation Loss: 0.032710250467061996
Validation Loss: 0.07029068470001221
Validation Loss: 0.04273078963160515
Validation Loss: 0.04701032489538193
Validation Loss: 0.03806116059422493
Validation Loss: 0.08073365688323975
Validation Loss: 0.06024584174156189
Validation Loss: 0.04003274068236351
Validation Loss: 0.0277983658015728
Validation Loss: 0.08370191603899002
Validation Loss: 0.04062407836318016
Validation Loss: 0.027214121073484

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.018322257325053215
Validation Loss: 0.04162782430648804
Validation Loss: 0.031317658722400665
Validation Loss: 0.031012866646051407
Validation Loss: 0.055236779153347015
Validation Loss: 0.01750394143164158
Validation Loss: 0.020173106342554092
Validation Loss: 0.01577604003250599
Validation Loss: 0.028544645756483078
Validation Loss: 0.031616371124982834
Validation Loss: 0.059890005737543106
Validation Loss: 0.08164022117853165
Validation Loss: 0.07734458148479462
Validation Loss: 0.026711102575063705
Validation Loss: 0.09472515434026718
Validation Loss: 0.026841603219509125
Validation Loss: 0.055669840425252914
Validation Loss: 0.03190189599990845
Validation Loss: 0.03747944161295891
Validation Loss: 0.034260571002960205
Validation Loss: 0.08067076653242111
Validation Loss: 0.0559365339577198
Validation Loss: 0.04254445806145668
Validation Loss: 0.02584470808506012
Validation Loss: 0.07872568815946579
Validation Loss: 0.04502083733677864
Validation Loss: 0.03140450

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.02033158205449581
Validation Loss: 0.04521290585398674
Validation Loss: 0.035893648862838745
Validation Loss: 0.032956186681985855
Validation Loss: 0.05494484305381775
Validation Loss: 0.01867855153977871
Validation Loss: 0.018217748031020164
Validation Loss: 0.013425939716398716
Validation Loss: 0.0256867203861475
Validation Loss: 0.03490762785077095
Validation Loss: 0.05831078812479973
Validation Loss: 0.07980229705572128
Validation Loss: 0.10015231370925903
Validation Loss: 0.02408963441848755
Validation Loss: 0.09394349902868271
Validation Loss: 0.02975780889391899
Validation Loss: 0.05413131043314934
Validation Loss: 0.03780997917056084
Validation Loss: 0.0390956774353981
Validation Loss: 0.03573887050151825
Validation Loss: 0.0780981034040451
Validation Loss: 0.0610479936003685
Validation Loss: 0.04037787765264511
Validation Loss: 0.02426040731370449
Validation Loss: 0.07321903854608536
Validation Loss: 0.04189622402191162
Validation Loss: 0.03194581717252731
V

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.020878935232758522
Validation Loss: 0.04313082993030548
Validation Loss: 0.029452655464410782
Validation Loss: 0.029520761221647263
Validation Loss: 0.05478354170918465
Validation Loss: 0.01722428947687149
Validation Loss: 0.019079433754086494
Validation Loss: 0.015356233343482018
Validation Loss: 0.025021910667419434
Validation Loss: 0.03176512569189072
Validation Loss: 0.057947490364313126
Validation Loss: 0.08087872713804245
Validation Loss: 0.07655651867389679
Validation Loss: 0.0252786036580801
Validation Loss: 0.09124346822500229
Validation Loss: 0.037098824977874756
Validation Loss: 0.05913645774126053
Validation Loss: 0.040358029305934906
Validation Loss: 0.05032680556178093
Validation Loss: 0.04386947676539421
Validation Loss: 0.07815820723772049
Validation Loss: 0.0604778453707695
Validation Loss: 0.044741977006196976
Validation Loss: 0.027599118649959564
Validation Loss: 0.08953861892223358
Validation Loss: 0.045544613152742386
Validation Loss: 0.034810408

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.024106621742248535
Validation Loss: 0.03961963206529617
Validation Loss: 0.03090541623532772
Validation Loss: 0.03172551468014717
Validation Loss: 0.0602639839053154
Validation Loss: 0.017969205975532532
Validation Loss: 0.01855810359120369
Validation Loss: 0.013865792192518711
Validation Loss: 0.028426047414541245
Validation Loss: 0.037408336997032166
Validation Loss: 0.059563279151916504
Validation Loss: 0.08627539873123169
Validation Loss: 0.08768638223409653
Validation Loss: 0.02495630830526352
Validation Loss: 0.09053363651037216
Validation Loss: 0.028415702283382416
Validation Loss: 0.05743671581149101
Validation Loss: 0.038699228316545486
Validation Loss: 0.040266282856464386
Validation Loss: 0.0370357409119606
Validation Loss: 0.0803745687007904
Validation Loss: 0.06520457565784454
Validation Loss: 0.039317816495895386
Validation Loss: 0.02583247236907482
Validation Loss: 0.07477306574583054
Validation Loss: 0.03945083171129227
Validation Loss: 0.032844815403

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.022557513788342476
Validation Loss: 0.04516449570655823
Validation Loss: 0.024984875693917274
Validation Loss: 0.026951363310217857
Validation Loss: 0.06168060004711151
Validation Loss: 0.018684295937418938
Validation Loss: 0.01913115754723549
Validation Loss: 0.01569516584277153
Validation Loss: 0.02808341197669506
Validation Loss: 0.03230183571577072
Validation Loss: 0.057559143751859665
Validation Loss: 0.086514413356781
Validation Loss: 0.11133988946676254
Validation Loss: 0.02524111047387123
Validation Loss: 0.0940365195274353
Validation Loss: 0.028848797082901
Validation Loss: 0.05415705218911171
Validation Loss: 0.03580300882458687
Validation Loss: 0.0433744341135025
Validation Loss: 0.040093932300806046
Validation Loss: 0.08163733035326004
Validation Loss: 0.06495898962020874
Validation Loss: 0.040473468601703644
Validation Loss: 0.02635749988257885
Validation Loss: 0.06885123252868652
Validation Loss: 0.04002239182591438
Validation Loss: 0.03546563535928726


Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss: 0.01991906203329563
Validation Loss: 0.04703469201922417
Validation Loss: 0.037046097218990326
Validation Loss: 0.03282788768410683
Validation Loss: 0.05853511020541191
Validation Loss: 0.01730584353208542
Validation Loss: 0.01995965838432312
Validation Loss: 0.015467851422727108
Validation Loss: 0.027159377932548523
Validation Loss: 0.03776974976062775
Validation Loss: 0.05938779562711716
Validation Loss: 0.08037101477384567
Validation Loss: 0.07629185914993286
Validation Loss: 0.024963991716504097
Validation Loss: 0.09735530614852905
Validation Loss: 0.03880855813622475
Validation Loss: 0.057613808661699295
Validation Loss: 0.04448091611266136
Validation Loss: 0.047008953988552094
Validation Loss: 0.04185080528259277
Validation Loss: 0.07893676310777664
Validation Loss: 0.06365572661161423
Validation Loss: 0.039560750126838684
Validation Loss: 0.027693048119544983
Validation Loss: 0.07384626567363739
Validation Loss: 0.03956305980682373
Validation Loss: 0.03958076238

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
SLURM auto-requeueing enabled. Setting signal handlers.


Validation Loss: 0.0241930540651083
Validation Loss: 0.0407274067401886
Validation Loss: 0.023255830630660057
Validation Loss: 0.026588086038827896
Validation Loss: 0.030264703556895256
Validation Loss: 0.01730334199965
Validation Loss: 0.03086857497692108
Validation Loss: 0.03664743900299072
Validation Loss: 0.09494487196207047
Validation Loss: 0.06236562132835388
Validation Loss: 0.03820749744772911
Validation Loss: 0.03410409390926361
Validation Loss: 0.031864166259765625
Validation Loss: 0.022941075265407562
Validation Loss: 0.21155264973640442
Validation Loss: 0.3739306926727295
Validation Loss: 0.08282119035720825
Validation Loss: 0.02928859367966652
Validation Loss: 0.011288709938526154
Validation Loss: 0.028414247557520866
Validation Loss: 0.0407157726585865
Validation Loss: 0.03598395735025406
Validation Loss: 0.09035992622375488
Validation Loss: 0.039944637566804886
Validation Loss: 0.07138731330633163
Validation Loss: 0.017857879400253296
Validation Loss: 0.02252500876784324

/global/homes/p/parshvam/.local/perlmutter/python-3.11/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.0693124383687973
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.0693124383687973}]

In [16]:
%load_ext tensorboard

In [17]:
%tensorboard --logdir logs --port=6006