In [20]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
import sys
sys.path.append('..')

import logging
from src.utils import setup_logging
setup_logging(console=True, file=True, debug=True, file_basename="k_sweep_BIG")
logger = logging.getLogger(__name__)

import pandas as pd
from data.tempdata import TempData
import matplotlib.pyplot as plt

from src.plot import plot_predictive
from src.informed_np import InformedNeuralProcess
from tqdm import tqdm
# from src.loss import ELBOLoss
from src.loss import ELBOLoss
from src.train import train  
from src.plot import plot_predictive

[INFO]: Logging setup completed at 15-08-2024-100551        (utils.py:62 [10:05:51])
[INFO]: Logging setup completed at 15-08-2024-100551        (utils.py:62 [10:05:51])


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logging.info(f'Using DEVICE: {DEVICE}')

x_dim = 1
y_dim = 1
determ_dim = 128  # Dimension of representation of context points
latent_dim = 128  # Dimension of sampled latent variable
hidden_dim = 128  # Dimension of hidden layers in encoder and decoder

USE_KNOWLEDGE = True

args = dict(
            x_dim=x_dim,
            y_dim=y_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            determ_dim=determ_dim,
            knowledge_dim=128,
            mlps_activation=nn.GELU(),
            x_proj_dim=1,
            n_h_layers_x_proj=0,
            n_h_layers_decoder=4,
            n_h_layers_latent_xy_encoder=3,
            n_h_layers_film_latent_encoder=3,
            path='latent',
            train_num_z_samples=4,
            test_num_z_samples=32,
            use_bias=True,
            use_context_in_target=True, # TODO investigate
            use_latent_self_attn=True,
            # use_determ_self_attn=True,
            # use_determ_cross_attn=True,
            use_knowledge=USE_KNOWLEDGE,
            knowledge_dropout=0.4,
            roberta_return_cls=True,
            tune_llm_layer_norms=True,
            freeze_llm=True,
            knowledge_projection_n_h_layers=0,
            knowledge_aggregation_method='FiLM+MLP',
            device='cuda',
            beta=1.0
        )
# assert "use_knowledge" not in args
# assert "knowledge_dropout" not in args

data_path = '../data/data_with_desc.csv'
data_df = pd.read_csv(data_path, header=None)


AVG_LOSS_PRINT_INTERVAL = 250
PLOT_SAMPLE_INTERVAL = 1000
MAX_ITERS = 10000

LEARNING_RATE = 1e-3
loss_function = ELBOLoss(beta=1, reduction='mean')
random_states = [85, 98, 87]

MAX_NUM_CONTEXT = 10



[INFO]: Using DEVICE: cuda:0        (2661872398.py:2 [10:05:53])
[INFO]: Using DEVICE: cuda:0        (2661872398.py:2 [10:05:53])


In [22]:
data = TempData(data=data_df , max_num_context=MAX_NUM_CONTEXT, device=DEVICE, random_state=85)

inp_model = InformedNeuralProcess(
        **args
).to(DEVICE)

inp_model.load_state_dict(torch.load('../exp/sweep/inp-kdropsweep-0.4-rs-85_iter4000.pt'))

[DEBUG]: XYEncoder has x_dim=1 and y_dim=1        (xy_encoders.py:57 [10:05:55])
[DEBUG]: XYEncoder has x_dim=1 and y_dim=1        (xy_encoders.py:57 [10:05:55])
[DEBUG]: Resetting dropped connection: huggingface.co        (connectionpool.py:291 [10:05:56])
[DEBUG]: Resetting dropped connection: huggingface.co        (connectionpool.py:291 [10:05:56])
[DEBUG]: https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0        (connectionpool.py:474 [10:05:56])
[DEBUG]: https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0        (connectionpool.py:474 [10:05:56])
[DEBUG]: https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0        (connectionpool.py:474 [10:05:56])
[DEBUG]: https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0        (connectionpool.py:474 [10:05:56])
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-bas

<All keys matched successfully>

In [5]:
batch_size = 32
batch = data.generate_batch(batch_size=batch_size,
                                        device=DEVICE,
                                        return_knowledge=USE_KNOWLEDGE,
                                        split='test')

print(batch.knowledge)

['The night will start off cold with temperatures falling to a chilly -13.8°C in the early morning, but as we move into the afternoon, temperatures will gradually rise to a slightly less frosty -5.9°C. Expect the evening to remain consistently cold, hovering around -6°C.', 'The night will start off slightly chilly, gradually getting colder towards early morning with the lowest temperature at -1.9°C. The day will continue to be cold, reaching a severe low of -7.3°C in the late evening.', 'The night will start off cool with temperatures around 8.9°C, gradually decreasing to 8.4°C by early morning. Temperatures will rise during the day, peaking at 10.4°C in the afternoon, before cooling off to 8.7°C by late evening.', 'The night will start off chilly, gradually falling to a low of -3.0°C in the early morning, before temperatures slowly climb, reaching a high of 3.7°C in the late evening. Expect a cold day with a slight warming trend towards the end of the day.', 'The night will start off 

In [6]:
plt.style.use('../figures/mplstyles/thesis.mplstyle')
num_trajectories = 2

inp_model.training = False
with torch.no_grad():
    # make device be the device of the model
    p_y_pred, q_z_context, q_z_target = inp_model(batch.x_context,
                                              batch.y_context,
                                              batch.x_target,
                                              batch.y_target,
                                              batch.knowledge)

    print(f"NLL: {loss_function(p_y_pred, q_z_context, None, batch.y_target)['loss'].item()}")
    
    mu = p_y_pred.mean  # Shape [num_z_samples, batch_size, num_target_points, y_dim=1]
    sigma = p_y_pred.stddev


LINE_COLOURS = ['#1f77b4', '#ff7f0e', '#2ca02c',]
FILL_COLOURS = ['#A6CEE3', '#f0c39c', '#b8deb8']


batch_size, _, _ = batch.x_context.shape
assert batch_size <= 3, 'Batch size should be <= 3 for plot clarity'


x_context, y_context = batch.x_context.cpu(), batch.y_context.cpu()
x_target, y_target = batch.x_target.cpu(), batch.y_target.cpu()
mu, sigma = mu.cpu(), sigma.cpu()

plt.figure(figsize=(6, 3))
for i in range(batch_size):
    plt.plot(x_target[i].flatten(), y_target[i].flatten(), 'k:')  # Plot ground truth GP
    

    num_z_samples = mu.shape[0]
    assert num_trajectories <= num_z_samples, "num_trajectories must be less than num_z_samples"
    z_sample_idx = np.random.choice(num_z_samples, size=num_trajectories)
    for j in z_sample_idx:
        plt.plot(x_target[i].flatten(), mu[j, i].flatten(), color=LINE_COLOURS[i])
        plt.fill_between(
            x_target[i].flatten(),
            mu[j, i].flatten() - sigma[j, i].flatten(),
            mu[j, i].flatten() + sigma[j, i].flatten(),
            alpha=0.3,
            facecolor=FILL_COLOURS[i],
            interpolate=True)
    plt.scatter(x_context[i].flatten(), y_context[i].flatten(), c='k')  # Plot context points
#plt.ylim(-4, 4)

plt.xlim(-2, 2)
# print(list(x_target[0::36])+ [2.0])

# Formatting the x-axis to display time in "HHMM" format
plt.xticks(list(x_target[0].flatten()[::36])+ [2.0], labels=["0000", "0300", "0600", "0900", "1200", "1500", "1800", "2100", "2400"])

# Label axes
plt.xlabel('Time (HHMM)')
plt.ylabel('Temperature (°C)')
plt.ylabel('Temperature (°C)')
# plt.savefig(f'{save_name}.png', dpi=300)
plt.show()

NLL: 484.36859130859375


AssertionError: Batch size should be <= 3 for plot clarity

In [None]:
import re
import random

n_values = [-30, -15, -8, -5, -3, -2, -1, 0, 1, 2, 3, 5, 8, 15, 30]

def perturb_knowledge(knowledge_string, n):
    def perturb_number(match):
        number = int(match.group())
        new_number = number + n
        return str(new_number)
    
    return re.sub(r'\d+', perturb_number, knowledge_string)

perturbed_results = []

for knowledge_string in batch.knowledge:
    print(f"\nOriginal: {knowledge_string}")
    perturbed_list = []
    for n in n_values:
        perturbed = perturb_knowledge(knowledge_string, n)
        perturbed_list.append(perturbed)
        print(f"Perturbed (n={n}): {perturbed}")
    perturbed_results.append(perturbed_list)

print("\nAccessing results:")
for i, original in enumerate(batch.knowledge):
    print(f"\nOriginal: {original}")
    for j, n in enumerate(n_values):
        print(f"Perturbed (n={n}): {perturbed_results[i][j]}")

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

plt.style.use('../figures/mplstyles/thesis.mplstyle')
num_trajectories = 2
inp_model.training = False

LINE_COLOURS = ['#1f77b4', '#ff7f0e', '#2ca02c']
FILL_COLOURS = ['#A6CEE3', '#f0c39c', '#b8deb8']

def run_experiment(original_knowledge, changed_knowledge, n_value):
    with torch.no_grad():
        p_y_pred, q_z_context, q_z_target = inp_model(batch.x_context,
                                                      batch.y_context,
                                                      batch.x_target,
                                                      batch.y_target,
                                                      [changed_knowledge])
        nll = loss_function(p_y_pred, q_z_context, None, batch.y_target)['loss'].item()
        print(f"Original: {original_knowledge}")
        print(f"Perturbed (n={n_value}): {changed_knowledge}")
        print(f"NLL: {nll}")
        print("-" * 50)
        
        mu = p_y_pred.mean
        sigma = p_y_pred.stddev
        
    '''batch_size, _, _ = batch.x_context.shape
    assert batch_size <= 3, 'Batch size should be <= 3 for plot clarity'
    x_context, y_context = batch.x_context.cpu(), batch.y_context.cpu()
    x_target, y_target = batch.x_target.cpu(), batch.y_target.cpu()
    mu, sigma = mu.cpu(), sigma.cpu()
    
    plt.figure(figsize=(6, 3))
    for i in range(batch_size):
        plt.plot(x_target[i].flatten(), y_target[i].flatten(), 'k:')  # Plot ground truth GP
        
        num_z_samples = mu.shape[0]
        assert num_trajectories <= num_z_samples, "num_trajectories must be less than num_z_samples"
        z_sample_idx = np.random.choice(num_z_samples, size=num_trajectories)
        for j in z_sample_idx:
            plt.plot(x_target[i].flatten(), mu[j, i].flatten(), color=LINE_COLOURS[i])
            plt.fill_between(
                x_target[i].flatten(),
                mu[j, i].flatten() - sigma[j, i].flatten(),
                mu[j, i].flatten() + sigma[j, i].flatten(),
                alpha=0.3,
                facecolor=FILL_COLOURS[i],
                interpolate=True)
        plt.scatter(x_context[i].flatten(), y_context[i].flatten(), c='k')  # Plot context points
    
    plt.xlim(-2, 2)
    plt.xticks(list(x_target[0].flatten()[::36]) + [2.0], labels=["0000", "0300", "0600", "0900", "1200", "1500", "1800", "2100", "2400"])
    plt.xlabel('Time (HHMM)')
    plt.ylabel('Temperature (°C)')
    plt.title(f"n={n_value}, NLL={nll:.4f}")
    plt.tight_layout()
    plt.show()'''

for original, perturbed_list, n_values in zip(batch.knowledge, perturbed_results, [n_values] * len(batch.knowledge)):
    for perturbed, n in zip(perturbed_list, n_values):
        run_experiment(original, perturbed, n)

In [None]:
from scipy import stats

plt.style.use('../figures/mplstyles/thesis.mplstyle')
num_trajectories = 2
inp_model.training = False

def calculate_nll(knowledge):
    with torch.no_grad():
        p_y_pred, q_z_context, q_z_target = inp_model(batch.x_context,
                                                      batch.y_context,
                                                      batch.x_target,
                                                      batch.y_target,
                                                      [knowledge])
        nll = loss_function(p_y_pred, q_z_context, None, batch.y_target)['loss'].item()
    return nll

# Calculate NLL for unperturbed knowledge
unperturbed_nlls = [calculate_nll(k) for k in batch.knowledge]

# Calculate NLL changes for perturbed knowledge
nll_changes = []
for original, perturbed_list, unperturbed_nll in zip(batch.knowledge, perturbed_results, unperturbed_nlls):
    changes = []
    for perturbed in perturbed_list:
        nll = calculate_nll(perturbed)
        change = abs(nll - unperturbed_nll)
        changes.append(change)
    nll_changes.append(changes)

# Convert to numpy array for easier manipulation
nll_changes = np.array(nll_changes)

print(nll_changes)

# Calculate mean and standard error of NLL changes
mean_changes = np.mean(nll_changes, axis=0)
se_changes = stats.sem(nll_changes, axis=0)

# Plot
plt.figure(figsize=(6, 3))
plt.errorbar(n_values, mean_changes, yerr=se_changes, fmt='-o', capsize=2.5, markersize=4,  mfc='white')
plt.xlabel('Number added to numbers in knowledge string')
plt.ylabel(r'Average $|\Delta\text{NLL}|$')
plt.title('Effect of changing the numbers in the knowledge')
plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig('../figures/number_robustness_graph.png', dpi=600)
plt.show()

# Print detailed results
print("Detailed Results:")
print("n\tMean Change\tStandard Error")
for n, mean, se in zip(n_values, mean_changes, se_changes):
    print(f"{n}\t{mean:.4f}\t\t{se:.4f}")

In [None]:
['The night will be bitterly cold with temperatures hovering around -14 degrees, slightly warming up to -11 degrees in the afternoon. Expect a chilly evening as temperatures continue to remain below freezing.']