In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

from model import AttentionModel, XGBoostModel

In [2]:
# Pick the compute device once for the whole notebook: use CUDA when PyTorch
# reports that it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [3]:
class NO2Dataset(Dataset):
    """Sliding-window dataset over per-location daily observations.

    Every row of ``df`` that belongs to a ``(LAT, LON)`` group becomes one
    sample. A sample consists of all rows of the same location whose ``Date``
    lies in the ``max_days``-day window ending at (and including) that row's
    date, left-padded with zero-feature rows when the window holds fewer than
    ``max_days`` rows, together with the row's ``ID_Zindi`` value.

    Args:
        df: DataFrame with at least the columns ``ID_Zindi``, ``Date``,
            ``LAT``, ``LON`` and the measurement columns listed in
            ``FEATURE_COLUMNS``. It is not modified.
        max_days: Length of the look-back window in days (current day included).

    Attributes:
        data: Copy of ``df`` with ``Date`` parsed to datetimes, sorted by
            ``LAT``, ``LON``, ``Date`` and re-indexed ``0..n-1``; the previous
            index is kept as a column and ``global_idx`` mirrors the new one.
        locations: Mapping ``(LAT, LON) -> index labels`` into ``data``.
        location_keys: List of the ``(LAT, LON)`` keys of ``locations``.
        samples: List of ``(location, global_idx, id_zindi)`` tuples.
    """

    # Column order of the feature tensor returned by ``__getitem__``.
    FEATURE_COLUMNS = ['LST', 'AAI', 'CloudFraction', 'Precipitation',
                       'NO2_strat', 'NO2_total', 'NO2_trop',
                       'TropopausePressure', 'LAT', 'LON']

    def __init__(self, df, max_days=15):
        self.max_days = max_days

        # Parse the dates BEFORE sorting: sorting raw strings is only
        # chronological for ISO-formatted dates (e.g. '10/1/2019' sorts before
        # '9/30/2019' as text). ``assign`` returns a new frame, so the caller's
        # DataFrame is left untouched.
        parsed = df.assign(Date=pd.to_datetime(df['Date']))
        self.data = parsed.sort_values(by=['LAT', 'LON', 'Date']).reset_index(drop=False)

        # Expose the new positional index as a regular column as well.
        self.data['global_idx'] = self.data.index

        self.locations = self.data.groupby(['LAT', 'LON']).groups
        self.location_keys = list(self.locations.keys())

        # One sample per row: (location, global index, ID_Zindi).
        # After ``reset_index`` the labels are 0..n-1, so they can index the
        # column array directly instead of doing one ``.loc`` lookup per row.
        id_values = self.data['ID_Zindi'].to_numpy()
        self.samples = [
            (loc, idx, id_values[idx])
            for loc in self.location_keys
            for idx in self.locations[loc]
        ]

    def __len__(self):
        """Return the number of samples (one per grouped row)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the feature window for sample ``idx``.

        Returns:
            tuple: ``(features_tensor, id_string)`` where ``features_tensor`` is
            a ``float32`` tensor of shape ``(seq_len, 10)`` with columns in
            ``FEATURE_COLUMNS`` order and rows in ascending date order, and
            ``id_string`` is the sample's ``ID_Zindi`` value.
        """
        location, global_data_idx, id_string = self.samples[idx]

        # All rows of this location in chronological order.
        loc_data = self.data.loc[self.locations[location]].sort_values(by='Date', kind='stable')

        # Look the sample's row up by its index LABEL. Positional arithmetic
        # such as ``iloc[global_idx - loc_data.index[0]]`` silently selects a
        # different row whenever the date-sorted frame is not in ascending,
        # contiguous label order.
        current_date = loc_data.at[global_data_idx, 'Date']

        # Window covering the last `max_days` days, current date included.
        start_date = current_date - pd.DateOffset(days=self.max_days - 1)
        in_window = loc_data['Date'].between(start_date, current_date)
        past_data = loc_data[in_window]

        # Left-pad with zero-feature rows when the window is short.
        if len(past_data) < self.max_days:
            num_padding_days = self.max_days - len(past_data)
            # Padding dates end the day before the window starts, so they sort
            # ahead of every real observation.
            padding_dates = pd.date_range(end=start_date - pd.DateOffset(days=1), periods=num_padding_days)
            padding_data = pd.DataFrame({
                'Date': padding_dates,
                'LAT': loc_data['LAT'].iloc[0],  # Fill with location LAT
                'LON': loc_data['LON'].iloc[0],  # Fill with location LON
                'LST': 0, 'AAI': 0, 'CloudFraction': 0, 'Precipitation': 0, 'NO2_strat': 0,
                'NO2_total': 0, 'NO2_trop': 0, 'TropopausePressure': 0,
                'index': -1,  # Use -1 to indicate padding rows
                'global_idx': -1  # Same for global_idx
            })
            past_data = pd.concat([padding_data, past_data], ignore_index=True)

        # Ensure chronological order (padding first, current day last).
        past_data = past_data.sort_values(by='Date').reset_index(drop=True)

        features_tensor = torch.tensor(past_data[self.FEATURE_COLUMNS].values, dtype=torch.float32)

        return features_tensor, id_string


def collate_fn(batch):
    """Collate ``(features, id)`` samples into one padded batch.

    Args:
        batch: Sequence of ``(features_tensor, id_string)`` pairs, where each
            ``features_tensor`` has shape ``(seq_len, num_features)``.

    Returns:
        tuple: ``(features_padded, ids)``. ``features_padded`` has shape
        ``(batch_size, max_seq_len, num_features)``; shorter sequences are
        padded at the end by ``pad_sequence``. ``ids`` is a tuple of the sample
        ids in batch order.
    """
    # Transpose the list of pairs into a tuple of tensors and a tuple of ids.
    sequences, sample_ids = tuple(zip(*batch))
    padded = pad_sequence(sequences, batch_first=True)
    return padded, sample_ids

In [4]:
def test_model(model, xgb, test_loader, output_path='test_predictions.csv'):
    """Run inference over ``test_loader`` and write the predictions to a CSV.

    Each batch is passed through ``model``; the model output is handed to
    ``xgb.inference`` and the result is collected together with the batch ids.
    The module-level ``device`` is used for both the model and the inputs.

    Args:
        model: ``torch.nn.Module`` applied to each feature batch.
        xgb: Object exposing ``inference(model_output)``; the return value must
            support ``.cpu().numpy()``.
        test_loader: Iterable yielding ``(features_seq, ids)`` batches.
        output_path: Destination of the CSV with columns ``ID`` and
            ``Predicted_NO2``. Defaults to ``'test_predictions.csv'``.

    Raises:
        ValueError: If the number of predictions differs from the number of ids.
    """
    model.eval()
    model.to(device)
    all_preds = []
    all_ids = []

    with torch.no_grad():
        for i, (features_seq, ids) in enumerate(test_loader):
            features_seq = features_seq.to(device)
            test_outputs = model(features_seq)
            predictions = xgb.inference(test_outputs)
            # Flatten to 1-D so a 0-d result (batch of one, squeezed) or a
            # (batch, 1) column still concatenates into one value per id.
            all_preds.append(np.atleast_1d(predictions.cpu().numpy()).reshape(-1))
            all_ids.extend(ids)

            print(f'Batch {i+1} processed', end='\r')

    # np.concatenate rejects an empty list; an empty loader yields an empty result.
    all_preds = np.concatenate(all_preds) if all_preds else np.empty(0)

    print(f'Number of predictions: {len(all_preds)}')
    print(f'Number of IDs: {len(all_ids)}')

    if len(all_preds) != len(all_ids):
        # Fail here with a clear message instead of warning and then letting
        # the DataFrame constructor raise a less specific error.
        raise ValueError(
            f'The number of predictions ({len(all_preds)}) and IDs '
            f'({len(all_ids)}) do not match!'
        )

    results_df = pd.DataFrame({'ID': all_ids, 'Predicted_NO2': all_preds})
    results_df.to_csv(output_path, index=False)


In [7]:
# Path of the saved state dict to restore into the attention model.
checkpoint_path = 'trained-model-xgboost/best_Att-CNN-LSTM_model_22.pt'

# Build the two project models imported from `model`.
model = AttentionModel()
# NOTE(review): the meaning of the positional argument 0 is defined in
# model.XGBoostModel and is not visible here — confirm.
xgb = XGBoostModel(0)
# Load the checkpoint onto the CPU first so it can be read on machines without
# a GPU; the model is moved to `device` below.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# Last expression of the cell: moves the model to `device` and displays its repr.
model.to(device)

AttentionModel(
  (conv1d): Conv1d(10, 64, kernel_size=(1,), stride=(1,))
  (dropout): Dropout(p=0.5, inplace=False)
  (bilstm): LSTM(64, 128, batch_first=True, bidirectional=True)
  (attention_block): AttentionBlock(
    (multihead_attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (fc): Linear(in_features=256, out_features=15, bias=True)
    (attention_fc): Linear(in_features=15, out_features=256, bias=True)
  )
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=8, bias=True)
  (bn2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [8]:
# Number of samples per inference batch.
batch_size = 128

# Read the test table and wrap it in the windowed dataset.
dataset = pd.read_csv('./Test_Cleaned_KNN.csv')
no2_dataset = NO2Dataset(dataset)

# Keep the dataset order (no shuffling) while batching with the padding collate function.
test_loader = DataLoader(
    no2_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

In [9]:
# Run inference over the whole test loader; test_model prints the prediction/ID
# counts and writes the results to 'test_predictions.csv'.
test_model(model, xgb, test_loader)

Number of predictions: 6576
Number of IDs: 6576
