In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
from dateutil.relativedelta import relativedelta
from collections import defaultdict
from tqdm import tqdm
from torch.profiler import profile, record_function, ProfilerActivity

# scikit-learn scaler
from sklearn.preprocessing import StandardScaler

def month_range(start, end):
    """
    Return a list of monthly dates (as strings 'YYYY-MM-01')
    from start to end inclusive.
    """
    dates = []
    cur = datetime.strptime(start, "%Y-%m-%d")
    stop = datetime.strptime(end, "%Y-%m-%d")
    while cur <= stop:
        dates.append(cur.strftime("%Y-%m-01"))
        # move forward one month
        cur += relativedelta(months=1)
    return dates

def drop_all_nan_columns_before_cutoff(arr, cutoff_idx, table_name):
    """
    Drop the columns that are all NaNs before or at the cutoff index.
    :param arr: np.array of shape (n, m)
    :return: np.array of shape (n, m') where m' <= m
    """
    mask = np.all(np.isnan(arr[:cutoff_idx, :]), axis=0)
    if mask.sum() > 0:
        print(f"Dropping {mask.sum()} columns with all NaNs before cutoff for table {table_name}")

    return arr[:, ~mask]


In [2]:
from typing import List, Tuple

def read_and_scale_tables(
        csv_file_paths,
        meta_file_paths,
        start_date="1960-01-01",
        end_date="2024-01-01",
        train_cutoff_str="2018-01-01",
        drop_cutoff_str="2004-01-01"
):
    """
    Reads each CSV data file and corresponding CSV meta file,
    aligns data to a monthly timeline, scales the data,
    and returns:
      - table_data_dict: dict[table_name] -> np.array of shape (num_months, k_i)
      - meta_data_dict: dict[table_name] -> dict[col_index -> freq_str]
      - scalers: dict[table_name] -> a fitted StandardScaler
      - monthly_dates: list of monthly date strings

    :param csv_file_paths: list of str (data CSVs, one per table)
    :param meta_file_paths: list of str (metadata CSVs, same length as above)
    :param start_date: str, earliest date
    :param end_date: str, latest date
    :param train_cutoff_str: str, date boundary for training
    :param drop_cutoff_str: str, date boundary for dropping columns with all NaNs
    """
    assert len(csv_file_paths) == len(meta_file_paths), \
        "Must have matching data and meta files"

    monthly_dates = month_range(start_date, end_date)
    date_to_idx = {d: i for i, d in enumerate(monthly_dates)}
    num_months = len(monthly_dates)

    train_cutoff_idx = date_to_idx[train_cutoff_str]

    table_data_dict = {}
    meta_data_dict = {}
    scalers = {}

    for data_path, meta_path in tqdm(zip(csv_file_paths, meta_file_paths), desc="Reading and scaling tables", total=len(csv_file_paths)):
        table_name = data_path.split("/")[-1].replace(".csv", "")

        # Read the data CSV
        df = pd.read_csv(data_path)
        feature_cols = [c for c in df.columns if c != "DATE_PARSED"]

        # Read the meta CSV
        df_meta = pd.read_csv(meta_path)

        # Create a dictionary col_index -> freq_str
        col_to_freq = {}
        for col_index, col_name in enumerate(feature_cols):
            row_meta = df_meta[df_meta['TITLE_FR'] == col_name]
            if len(row_meta) == 0:
                raise ValueError(f"Could not find metadata for column {col_name}")
            else:
                freq_str = row_meta['FREQ'].values[0]
            col_to_freq[col_index] = freq_str

        # Now create array_data (num_months, k_i)
        array_data = np.full((num_months, len(feature_cols)), np.nan, dtype=np.float32)
        for _, row in df.iterrows():
            date_str = str(row["DATE_PARSED"])
            if date_str in date_to_idx:
                idx = date_to_idx[date_str]
                array_data[idx] = row[feature_cols].values.astype(np.float32)

        # Optionally drop columns that are all NaN up to the cutoff, this should have been done in the ETL
        array_data = drop_all_nan_columns_before_cutoff(array_data, train_cutoff_idx, table_name)

        # The dropping above might reduce shape from (num_months, k_i) to fewer columns
        # so we also need to shrink col_to_freq accordingly:
        keep_mask = ~np.all(np.isnan(array_data[:train_cutoff_idx, :]), axis=0)
        # new frequency map
        old_col_indices = np.where(keep_mask)[0]
        new_col_to_freq = {}
        for new_i, old_i in enumerate(old_col_indices):
            new_col_to_freq[new_i] = col_to_freq[old_i]
        col_to_freq = new_col_to_freq

        # Fit StandardScaler on training portion (ignoring NaNs by filling w/ mean)
        scaler = StandardScaler()
        train_data = array_data[:train_cutoff_idx]  # shape (train_cutoff_idx, k_i')

        col_means = np.nanmean(train_data, axis=0)
        train_data_copy = train_data.copy()
        for c in range(train_data_copy.shape[1]):
            np.place(train_data_copy[:, c], np.isnan(train_data_copy[:, c]), col_means[c])

        scaler.fit(train_data_copy)

        # Transform entire array, ignoring NaNs by temp-filling them
        array_copy = array_data.copy()
        nan_mask = np.isnan(array_copy)
        array_copy[nan_mask] = 0.0
        scaled_data = scaler.transform(array_copy)
        scaled_data[nan_mask] = np.nan  # put NaNs back

        # Store results
        table_data_dict[table_name] = scaled_data
        meta_data_dict[table_name] = col_to_freq
        scalers[table_name] = scaler

    return table_data_dict, meta_data_dict, scalers, monthly_dates


In [3]:
def is_expected(freq_str, date_str):
    """
    freq_str: 'A', 'Q', or 'M'
    date_str: e.g. '2021-03-01'
    Returns True if we *should* have a value on this date.
    """
    dt = datetime.strptime(date_str, "%Y-%m-%d")
    m = dt.month
    if freq_str == 'M':
        return True
    elif freq_str == 'Q':
        # let’s say Q covers months 3,6,9,12
        return (m in [3,6,9,12])
    elif freq_str == 'A':
        # annual => only january
        return (m == 1)
    else:
        # default
        return True

In [4]:
def is_expected_mask_for_col(freq_str, months_of_date):
    """
    Given a freq_str in {'A','Q','M'} and an array of month integers,
    return a boolean mask of where we expect a value for that freq.
    E.g. for 'Q', we might only expect months 3,6,9,12; for 'A', only month 1, etc.
    """
    if freq_str == 'M':
        return np.ones_like(months_of_date, dtype=bool)
    elif freq_str == 'T':
        return np.isin(months_of_date, [1, 4, 7, 10])
    elif freq_str == 'A':
        return (months_of_date == 1)
    else:
        # Default: raise error
        raise ValueError(f"Unknown frequency string: {freq_str}")

def generate_expected_and_truly_missing_masks_vectorized(
    table_array,    # shape (L, k_i), might contain np.nan
    freq_dict,      # dict[col_index -> freq_str], e.g. {0:'A',1:'Q',2:'M',...}
    date_list       # list of date strings 'YYYY-MM-01' of length L
):
    """
    Generates two boolean masks of shape (L, k_i):
      - expected_missing_mask: True where the value is *missing* but that is *expected*
                               (due to freq not reporting that month)
      - truly_missing_mask:    True where the value is missing but it *should* have been reported
    """
    L, k_i = table_array.shape

    # 1) Parse months out of each date in date_list
    #    We'll create an array of shape (L,) containing the month number.
    months_of_date = np.array([
        datetime.strptime(d, "%Y-%m-%d").month for d in date_list
    ], dtype=int)  # shape (L,)

    # 2) Build an array freq_array of shape (k_i,) from freq_dict
    #    so freq_array[c] = freq_dict[c]
    freq_array = np.array([freq_dict[c] for c in range(k_i)], dtype=object)  # shape (k_i,)

    # 3) Build an (L, k_i) boolean array "expected_array"
    #    which is True where freq says "we expect a value"
    #    We'll do this by computing a mask for each column, then stacking.
    expected_array = np.zeros((L, k_i), dtype=bool)
    for c in range(k_i):
        freq_str = freq_array[c]
        expected_array[:, c] = is_expected_mask_for_col(freq_str, months_of_date)

    # 4) Now check where table_array is NaN => "missing"
    missing_mask = np.isnan(table_array)  # shape (L, k_i)

    # 5) truly_missing => "missing" AND "expected"
    truly_missing_mask = missing_mask & expected_array
    # 6) expected_missing => "missing" AND "NOT expected"
    expected_missing_mask = missing_mask & ~expected_array

    return expected_missing_mask, truly_missing_mask

In [5]:
import json

with open("Data/all_data.json", "r") as f:
    all_data = json.load(f)

csv_file_paths = [f"Data/{table_name}.csv" for table_name in all_data]
csv_file_paths_meta = [f"Data/{table_name}_meta.csv" for table_name in all_data]

table_data_dict, meta_data_dict, scalers, monthly_dates = read_and_scale_tables(
    csv_file_paths,
    csv_file_paths_meta,
    start_date="1970-01-01",
    end_date="2024-01-01",
    train_cutoff_str="2018-01-01"
)

Reading and scaling tables: 100%|██████████| 76/76 [00:09<00:00,  8.39it/s]


In [6]:
def test_data_integrity(table_data_dict, meta_data_dict, monthly_dates):
    # Test that all arrays have the same number of months as monthly_dates
    num_months = len(monthly_dates)
    for table_name, table_data in table_data_dict.items():
        assert table_data.shape[0] == num_months, \
            f"Table {table_name} has {table_data.shape[0]} months, expected {num_months}"

    # Test that all arrays have the same number of columns as meta_data_dict
    for table_name, table_data in table_data_dict.items():
        assert table_data.shape[1] == len(meta_data_dict[table_name]), \
            f"Table {table_name} has {table_data.shape[1]} columns, expected {len(meta_data_dict[table_name])}"

    # Test that are not null where True mask AND expected mask == 0
    for table_name, table_data in table_data_dict.items():
        expected_mask, truly_missing_mask = generate_expected_and_truly_missing_masks_vectorized(
            table_data, meta_data_dict[table_name], monthly_dates
        )

        place_with_value = np.where(~(expected_mask | truly_missing_mask))

        # Check if there are any nans where we expect a value
        assert not np.isnan(table_data[place_with_value]).any(), \
            f"Table {table_name} has NaNs where we expect a value"

    print("Data integrity test passed!")

print("Testing data integrity...")
test_data_integrity(table_data_dict, meta_data_dict, monthly_dates)

Testing data integrity...
Data integrity test passed!


true_mask is true wherever the value is missing and it should not be (report date is consistent with the presence of a value there)

expected_mask is true wherever the value is missing and it should be (report date is consistent with the absence of a value there)

In [120]:
import numpy as np
import torch
from torch.utils.data import Dataset
from typing import Optional, Dict, List

class EconDataset(Dataset):
    def __init__(self,
                 table_data_dict: Dict[str, np.ndarray],
                 monthly_dates: List[str],
                 expected_missing_dict: Optional[Dict[str, np.ndarray]],
                 true_missing_dict: Optional[Dict[str, np.ndarray]],
                 min_window_length_year: int = 1,
                 max_window_length_year: Optional[int] = None,
                 train: bool = True,
                 test_start_date: str = "2018-01-01",
                 number_of_samples: int = 100_000,
                 # -- Masking probabilities
                 p_1_none: float = 0.1,
                 p_2_uniform: float = 0.2,
                 p_3_last1yr: float = 0.3,
                 p_4_last2yr: float = 0.1,
                 p_5_table: float = 0.3,
                 p_uniform: float = 0.2,     # Probability to mask each cell in uniform masking
                 seed: Optional[int] = None,
                 inference_mode: bool = False):
        """
        table_data_dict: dict[table_name] -> (num_months, k_i) scaled arrays
        monthly_dates: list of str, aligned to the arrays in table_data_dict
        expected_missing_dict: dict[table_name] -> (num_months, k_i) array of expected missing indicators
        true_missing_dict: dict[table_name] -> (num_months, k_i) array of true missing indicators
        min_window_length_year: minimum window length (in years)
        max_window_length_year: maximum window length (in years), if None => no upper bound
        train: whether this dataset is for training or test
        test_start_date: str, e.g. "2018-01-01"
        number_of_samples: total number of random samples (i.e. random time windows) to generate
        p_1_none, p_2_uniform, p_3_last1yr, p_4_last2yr, p_5_table: probabilities for the 5 masking modes
        p_uniform: for the uniform random mask, each cell has this probability of being masked
        seed: optional random seed for reproducibility
        """
        super().__init__()

        # -- Basic checks
        if not inference_mode:
            p_sum = p_1_none + p_2_uniform + p_3_last1yr + p_4_last2yr + p_5_table
            assert abs(p_sum - 1.0) < 1e-7, "Mask probabilities must sum to 1.0!"
        else:
            p_1_none, p_2_uniform, p_3_last1yr, p_4_last2yr, p_5_table = 0.0, 0.0, 0.0, 0.0, 0.0

        self.monthly_dates = monthly_dates

        # -- Build train/test boundary
        self.test_start_idx = self.monthly_dates.index(test_start_date)

        self.num_months = len(monthly_dates)
        self.train = train
        self.min_window_length_months = 12 * min_window_length_year
        if max_window_length_year is not None:
            self.max_window_length_months = 12 * max_window_length_year
        else:
            # If not specified, let’s default to using up to the entire range - 1
            self.max_window_length_months = self.test_start_idx - 1 if self.train else self.num_months - self.test_start_idx - 1

        # -- Masking probabilities
        self.p_1_none = p_1_none
        self.p_2_uniform = p_2_uniform
        self.p_3_last1yr = p_3_last1yr
        self.p_4_last2yr = p_4_last2yr
        self.p_5_table = p_5_table
        self.p_uniform = p_uniform

        # We store the total number of random samples
        self.number_of_samples = number_of_samples

        # Optionally set random seed
        if seed is not None:
            np.random.seed(seed)

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Transform each table from (L, k_i) => (L, 3*k_i)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
        self.table_names = list(table_data_dict.keys())
        new_table_data_dict = {}

        for table_name in self.table_names:
            raw_table = table_data_dict[table_name]             # shape (L, k_i)
            exp_missing = expected_missing_dict[table_name]     # shape (L, k_i), boolean
            true_missing = true_missing_dict[table_name]        # shape (L, k_i)

            # Check where raw_table is NaN
            is_nan = np.isnan(raw_table)

            # Fill out each triple (value, expected_missing_bit, unexpected_missing_bit)
            L, k_i = raw_table.shape
            transformed_table = np.zeros((L, 3 * k_i), dtype=raw_table.dtype)

            # 1) First feature: if not NaN => x, if NaN => 0.0
            transformed_table[:, 0::3] = np.where(is_nan, 0.0, raw_table)

            # 2) Second feature: 1 if NaN & expected == True, else 0
            transformed_table[:, 1::3] = np.where(is_nan & exp_missing, 1.0, 0.0)

            # 3) Third feature: 1 if NaN & expected == False, else 0
            transformed_table[:, 2::3] = np.where(is_nan & true_missing, 1.0, 0.0)

            # Store the transformed array
            new_table_data_dict[table_name] = transformed_table.astype(np.float32)

        # Now replace the original table_data_dict with our new 3*k_i version
        self.table_data_dict = new_table_data_dict
        self.num_tables = len(self.table_names)

    def __len__(self):
        return self.number_of_samples

    def __getitem__(self, idx):
        """
        Returns a dict:
            {
              "full_data": { table_name -> np.ndarray of shape (window_length, 3*k_i) },
              "mask": np.ndarray of shape (window_length, num_tables),
            }
        """
        # 1) Sample a random window length
        window_length = np.random.randint(self.min_window_length_months,
                                          self.max_window_length_months + 1)

        # 2) Sample a random start within either train or test
        if self.train:
            max_start = self.test_start_idx - window_length
            if max_start < 0:
                raise ValueError("Not enough months for the requested window in training set.")
            start_idx = np.random.randint(0, max_start + 1)
        else:
            max_start = self.num_months - window_length
            if max_start < self.test_start_idx:
                raise ValueError("Not enough months for the requested window in test set.")
            start_idx = np.random.randint(self.test_start_idx, max_start + 1)

        end_idx = start_idx + window_length

        # 3) Prepare output data structures
        full_data = {}
        # shape: (window_length, num_tables), default zeros
        mask = np.zeros((window_length, self.num_tables), dtype=np.float32)

        # 4) Decide which masking mode to apply
        r = np.random.rand()
        if r < self.p_1_none:
            mask_mode = "none"
        elif r < self.p_1_none + self.p_2_uniform:
            mask_mode = "uniform"
        elif r < self.p_1_none + self.p_2_uniform + self.p_3_last1yr:
            mask_mode = "last1yr"
        elif r < (self.p_1_none + self.p_2_uniform + self.p_3_last1yr + self.p_4_last2yr):
            mask_mode = "last2yr"
        else:
            mask_mode = "table"

        # 5) Slice out the transformed data (now shape = (L, 3*k_i))
        for i, tn in enumerate(self.table_names):
            # the array is now (L, 3*k_i)
            table_array = self.table_data_dict[tn]
            full_data[tn] = table_array[start_idx:end_idx, :]

        # 6) Fill the mask
        if mask_mode == "none":
            # do nothing
            pass

        elif mask_mode == "uniform":
            # For each (t, i), mask with probability p_uniform
            random_matrix = np.random.rand(window_length, self.num_tables)
            mask[random_matrix < self.p_uniform] = 1.0

        elif mask_mode == "last1yr":
            omit_start = max(0, window_length - 12)
            mask[omit_start:, :] = 1.0

        elif mask_mode == "last2yr":
            omit_start = max(0, window_length - 24)
            mask[omit_start:, :] = 1.0

        elif mask_mode == "table":
            n_mask_tables = np.random.randint(1, self.num_tables + 1)
            table_indices_to_mask = np.random.choice(self.num_tables,
                                                     size=n_mask_tables,
                                                     replace=False)
            mask[:, table_indices_to_mask] = 1.0

        # 7) Return the sample
        return {
            "full_data": full_data,    # dict[str, np.ndarray], each (window_length, 3*k_i)
            "mask": mask,              # np.ndarray of shape (window_length, num_tables)
        }


In [121]:
def create_dataset(table_data_dict: Dict[str, np.ndarray],
                   meta_data_dict: Dict[str, Dict[int, str]],
                   monthly_dates: List[str],
                   train: bool=True,
                   min_window_length_year: int = 1,
                   max_window_length_year: Optional[int] = None,
                   test_start_date: str = "2018-01-01",
                   number_of_samples: int = 100_000,
                   # -- Masking probabilities
                   p_1_none: float = 0.1,
                   p_2_uniform: float = 0.2,
                   p_3_last1yr: float = 0.2,
                   p_4_last2yr: float = 0.2,
                   p_5_table: float = 0.3,
                   p_uniform: float = 0.3,     # Probability to mask each cell in uniform masking
                   seed: Optional[int] = None,
                   inference_mode: bool = False
                ):
    expected_missing_dict, true_missing_dict = {}, {}
    for table_name, table_data in table_data_dict.items():
        expected_missing, true_missing = generate_expected_and_truly_missing_masks_vectorized(table_data_dict[table_name], meta_data_dict[table_name], monthly_dates)
        expected_missing_dict[table_name] = expected_missing
        true_missing_dict[table_name] = true_missing

    dataset = EconDataset(
        table_data_dict=table_data_dict,
        monthly_dates=monthly_dates,
        expected_missing_dict=expected_missing_dict,
        true_missing_dict=true_missing_dict,
        min_window_length_year=min_window_length_year,
        max_window_length_year=max_window_length_year,
        train=train,
        test_start_date=test_start_date,
        number_of_samples=number_of_samples,
        p_1_none=p_1_none,
        p_2_uniform=p_2_uniform,
        p_3_last1yr=p_3_last1yr,
        p_4_last2yr=p_4_last2yr,
        p_5_table=p_5_table,
        p_uniform=p_uniform,
        seed=seed,
        inference_mode=inference_mode
    )

    return dataset

In [36]:
dataset = create_dataset(table_data_dict, meta_data_dict, monthly_dates, train=True, max_window_length_year=6)

In [77]:
import torch
import numpy as np
import time

def econ_collate_fn(batch):
    """
    Collate function for EconDataset.

    Args:
        batch: List of size B, each element is a dict:
            {
              "full_data": { table_name -> np.ndarray (L, 3*k_i) },
              "mask": np.ndarray (L, num_tables),
            }

    Returns a dict:
      {
        "full_data": { table_name -> FloatTensor of shape (B, L_max, 3*k_i) },
        "mask": BoolTensor of shape (B, L_max, N),
        "padding_mask": BoolTensor of shape (B, L_max, N),
      }
    where:
      - B = batch size
      - L_max = maximum sequence length in this batch
      - N = number of tables
      - k_i = dimension of the expanded features per table
    """
    # -------------------------
    # 1. Basic info
    # -------------------------
    batch_size = len(batch)
    lengths = [item["mask"].shape[0] for item in batch]      # each item["mask"] is (L, num_tables)
    L_max = max(lengths)                                     # maximum L among the batch
    num_tables = batch[0]["mask"].shape[1]

    # Build boolean masks
    mask_tensor = torch.zeros((batch_size, L_max, num_tables), dtype=torch.bool)
    padding_mask = torch.ones((batch_size, L_max, num_tables), dtype=torch.bool)

    # Fill them
    for b, item in enumerate(batch):
        L = item['mask'].shape[0]
        mask_tensor[b, :L, :] = torch.from_numpy(item['mask']).bool()
        padding_mask[b, :L, :] = False

    # -------------------------------------------------------
    # 2. Build 'full_data' for each table
    #    We'll create a dict {table_name -> FloatTensor (B, L_max, k_i)}
    # -------------------------------------------------------
    table_names = list(batch[0]["full_data"].keys())
    full_data_dict = {}

    t41 = 0
    t42 = 0
    # For each table, we figure out its feature dimension k_i by looking at the first item
    for tn in table_names:
        # shape of the first item is (L, k_i). We'll use that k_i for the entire batch
        _, k_i = batch[0]["full_data"][tn].shape

        # Allocate a PyTorch tensor (B, L_max, k_i) for this table, filled with 0.0 (pad value)
        s41 = time.time()
        data_np = np.empty((batch_size, L_max, k_i), dtype=np.float32)
        data_np.fill(0.0)
        t41 += time.time() - s41

        s42 = time.time()
        # Fill it with each item’s data
        for b, item in enumerate(batch):
            arr_np = item["full_data"][tn]   # shape = (L, 3*k_i)
            L = arr_np.shape[0]
            data_np[b, :L, :] = arr_np

        data_tensor = torch.from_numpy(data_np)
        t42 += time.time() - s42

        # Store this in our dictionary
        full_data_dict[tn] = data_tensor

    # -------------------------------------------------------
    # 5. Return the collated batch as a dict
    # -------------------------------------------------------
    batch_output = {
        "full_data": full_data_dict,         # {table_name -> FloatTensor (B, L_max, k_i)}
        "mask": mask_tensor,                 # BoolTensor (B, L_max, num_tables)
        "padding_mask": padding_mask,        # BoolTensor (B, L_max, num_tables)
    }

    return batch_output


In [63]:
s = time.time()
for _ in range(500):
    x = torch.zeros(8, 100, 100, 500, dtype=torch.float32)
print("Created 500 tensors in", time.time() - s, "seconds")

s = time.time()
for _ in range(500):
    x = np.zeros((8, 100, 100, 500), dtype=np.float32)
print("Created 500 numpy arrays in", time.time() - s, "seconds")

Created 500 tensors in 4.085897922515869 seconds
Created 500 numpy arrays in 0.008311033248901367 seconds


In [64]:
import time
s = time.time()
for i in range(100*4):
    x = dataset[i]

print("Loaded 400 samples in", time.time() - s, "seconds")

s = time.time()
tt41, tt42 = 0, 0
for i in range(100):
    x, t41, t42 = econ_collate_fn([dataset[4*i + j] for j in range(4)])
    tt41 += t41
    tt42 += t42

print("Collated 100 batches in", time.time() - s, "seconds")
print("Time breakdown:", tt41, tt42)

Loaded 400 samples in 0.009163618087768555 seconds
Collated 100 batches in 0.8471071720123291 seconds
Time breakdown: 0.45110011100769043 0.3578493595123291


In [65]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=econ_collate_fn)

s = time.time()
for i, batch in enumerate(dataloader):
    if i == 100:
        break
    pass
print("Iterated over the dataloader in", time.time() - s, "seconds")

Iterated over the dataloader in 0.8882815837860107 seconds


In [66]:
class TableEmbedding(nn.Module):
    def __init__(self, k_in, embed_dim):
        super().__init__()
        self.l1 = nn.Linear(k_in, embed_dim)
        self.l2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """
        x: (B, L, k_in)
        returns: (B, L, embed_dim)
        """
        x = self.l1(x)
        x = F.relu(x)
        x = self.l2(x)
        return x


In [67]:
def key_padding_mask_to_attention_mask(key_padding_mask: torch.Tensor) -> torch.Tensor:
    """
    Convert a key_padding_mask of shape (B, S) to an attention_mask of shape (B, S, S).

    Args:
        key_padding_mask: Boolean tensor of shape (B, S) where True indicates padding tokens
                         and False indicates actual tokens.

    Returns:
        attention_mask: Boolean tensor of shape (B, S, S) where False indicates allowed
                       attention and True indicates masked (blocked) attention.
    """
    batch_size, seq_len = key_padding_mask.size()

    # First, we need to convert the key_padding_mask to the right shape
    # We want each position to not attend to padding tokens
    # So we expand the key_padding_mask to (B, 1, S) and broadcast it to (B, S, S)
    # Make it contiguous to ensure the memory layout allows setting elements
    expanded_mask = key_padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len).contiguous()

    return expanded_mask

def fix_fully_masked_rows(attn_mask_3d: torch.Tensor, key_padding_mask: torch.Tensor) -> torch.Tensor:
    """
    For any row b where key_padding_mask[b] is all True (i.e., fully masked),
    replace that entire (L, L) block in attn_mask_3d with ~torch.eye(L).
    This makes each token attend only to itself, preventing NaNs.
    """
    B, L, _ = attn_mask_3d.shape
    fully_masked_rows = key_padding_mask.all(dim=1)  # shape (B,)
    attn_mask_3d = attn_mask_3d.clone()
    attn_mask_3d[fully_masked_rows] = ~torch.eye(L, L, dtype=torch.bool, device=attn_mask_3d.device, requires_grad=False)
    return attn_mask_3d

In [68]:
import torch
import torch.nn as nn

class DoubleAttention(nn.Module):
    """
    Perform two-step attention on data of shape (B, L, N, E):
    1) Attention over N dimension (tables)
    2) Attention over L dimension (time)
    """
    def __init__(self, embed_dim, num_heads):
        super(DoubleAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # We'll define two MultiheadAttention modules:
        # - attnN: handles attention across N
        # - attnL: handles attention across L
        self.attnN = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.attnL = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor):
        """
        x: Tensor of shape (B, L, N, E)
        padding_mask: BoolTensor of shape (B, L, N). True will be ignored in attention.
        returns: Tensor of shape (B, L, N, E) after double attention
        """
        B, L, N, E = x.shape

        # Clone padding mask to prevent any modification from affecting its reference
        padding_mask = padding_mask.clone()

        # -- Attention across N --
        # Flatten into (B*L, N, E)
        x_reshape = x.view(B * L, N, E).contiguous()
        padding_mask = padding_mask.view(B * L, N)
        # -----------------
        # This is where some issues can be introduced. We padded sequences to have same length.
        # We do attention across all clues for all pairs of month and elements in the batch.
        # /!\ We transform the padding mask into an attention mask and let padded elements attend to themselves.
        # /!\ We zero-out the attention scores for sequences that are entirely padded.
        # -----------------
        attn_mask = key_padding_mask_to_attention_mask(padding_mask)
        attn_mask = fix_fully_masked_rows(attn_mask, padding_mask) # (B*L, N, N)

        # Expand the mask for the number of heads
        attn_mask = attn_mask.unsqueeze(1).expand(B * L, self.num_heads, N, N).reshape(B * L * self.num_heads, N, N).contiguous()


        # Multi-head attention across N
        attn_outN, _ = self.attnN(x_reshape, x_reshape, x_reshape, attn_mask=attn_mask) # (B*L, N, E)
        attn_outN = attn_outN.masked_fill(padding_mask.all(dim=1).view(-1, 1, 1), 0.0)

        # Reshape back to (B, L, N, E)
        xN = attn_outN.view(B, L, N, E)
        padding_mask = padding_mask.view(B, L, N)

        # -- Attention across L --
        # Permute to (B, N, L, E) and (B, N, L)
        xN = xN.permute(0, 2, 1, 3).contiguous()
        padding_mask = padding_mask.permute(0, 2, 1).contiguous()

        # Flatten into (B*N, L, E)
        xN_reshape = xN.view(B * N, L, E)
        padding_mask = padding_mask.view(B * N, L)

        # Multi-head attention across L
        attn_outL, _ = self.attnL(xN_reshape, xN_reshape, xN_reshape, key_padding_mask=padding_mask)

        # -----------------
        # No issue should arise here as we are only attending across time dimension.
        # Elements replaced by 0.0 are masked elements in the second attention.
        # Each pair of batch element and table should have at least 1 month of data as min_time_window_year > 0.
        # -----------------

        # Reshape back to (B, N, L, E), we discard padding_mask
        xL = attn_outL.view(B, N, L, E)

        # Permute back to (B, L, N, E)
        out = xL.permute(0, 2, 1, 3).contiguous()

        return out

In [69]:
class TableDecoder(nn.Module):
    def __init__(self, embed_dim, k_out):
        super().__init__()
        self.linear = nn.Linear(embed_dim, k_out)

    def forward(self, x):
        """
        x: (B, L, embed_dim)
        -> (B, L, k_out)
        """
        return self.linear(x)


In [70]:
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [71]:
import torch
import torch.nn as nn
import math

class FeedForward(nn.Module):
    """
    Standard transformer feed-forward network with GELU activation
    """
    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class TransformerLayer2D(nn.Module):
    """
    A single transformer layer using DoubleAttention
    Uses pre-norm architecture (norm before attention/FFN)
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.attention = DoubleAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim, dropout)

        # Layer norms before attention and FF
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Dropout for attention output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, padding_mask):
        # Pre-norm architecture
        # Attention block
        normed_x = self.norm1(x)
        attn_out = self.attention(normed_x, padding_mask)
        x = x + self.dropout(attn_out)

        # FFN block
        normed_x = self.norm2(x)
        ff_out = self.ff(normed_x)
        x = x + ff_out

        return x

class PositionalEncoding2D(nn.Module):
    """
    2D positional encoding that handles both time (L) and table (N) dimensions
    """
    def __init__(self, embed_dim: int, max_len: int = 5000):
        super().__init__()
        self.embed_dim = embed_dim

        # Create a standard 1D positional encoding for the time dimension
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )

        # Apply sin to even indices and cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as buffer so it's not a learnable parameter
        self.register_buffer('pe', pe.unsqueeze(0).unsqueeze(2))
        print('Created 2D positional encoding with shape:', self.pe.shape)

    def forward(self, x):
        """
        x: tensor of shape (B, L, N, E)
        returns: same tensor with positional encoding added to time dimension
        """
        B, L, N, E = x.shape
        # Slice the positional embeddings to match L, then broadcast across B and N
        pos_encoding = self.pe[:, :L, :, :]

        # Add positional encoding to x
        return x + pos_encoding

class Transformer2D(nn.Module):
    """
    Complete 2D transformer with multiple layers
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_layers: int,
        ff_dim: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        use_pos_encoding: bool = True
    ):
        super().__init__()

        # Optional positional encoding
        self.pos_encoding = PositionalEncoding2D(embed_dim, max_len) if use_pos_encoding else None

        # Stack of transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer2D(
                embed_dim=embed_dim,
                num_heads=num_heads,
                ff_dim=ff_dim,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

        # Final layer norm (following BERT)
        self.final_norm = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, padding_mask: torch.Tensor):
        """
        x: tensor of shape (B, L, N, E)
        padding_mask: boolean tensor of shape (B, L, N)
        returns: transformed tensor of shape (B, L, N, E)
        """
        # Optional positional encoding
        if self.pos_encoding is not None:
            x = self.pos_encoding(x)

        x = self.dropout(x)

        # Pass through all transformer layers
        for layer in self.layers:
            x = layer(x, padding_mask)


        # Final layer norm
        x = self.final_norm(x)

        return x


In [72]:
class EconModel(nn.Module):
    def __init__(self, table_names, table_shapes,
                 embed_dim=32, n_heads=4, ff_dim=128, num_layers=2,
                 dropout=0.1, use_pos_encoding=True):
        super().__init__()

        self.table_names = table_names
        self.N = len(table_names)

        self.table_embeds = nn.ModuleDict()
        self.table_decoders = nn.ModuleDict()

        # Create embeddings/decoders
        for tn, k_in in zip(table_names, table_shapes):
            self.table_embeds[tn] = TableEmbedding(3*k_in, embed_dim)
            self.table_decoders[tn] = TableDecoder(embed_dim, k_in) # Since we tripled the features for Nan representation

        # 2D Transformer core
        self.core_transformer = Transformer2D(
            embed_dim=embed_dim,
            num_heads=n_heads,
            num_layers=num_layers,
            ff_dim=ff_dim,
            dropout=dropout,
            use_pos_encoding=use_pos_encoding
        )

        # Learned parameter for masking
        self.mask_embedding = nn.Parameter(
            torch.randn(embed_dim) * 0.02  # Shape: (E), scale down following Transformer implementation
        )

        self.embed_dim = embed_dim

    def forward(self, batch_data):
        """
        batch_data:
          {
            "full_data": {tn -> (B, L_max, 3*k_i)},
            "mask": BoolTensor (B, L_max, N)
            "padding_mask": BoolTensor (B, L_max, N)
          }
        Returns a dict {tn -> (B, L, k_i)} of predictions.
        """
        # Embed each table
        embed_list = []

        for tn in self.table_names:
            x = batch_data["full_data"][tn]  # (B, L_max, 3*k_i)

            # embed
            x_emb = self.table_embeds[tn](x)  # -> (B, L, E)

            embed_list.append(x_emb)

        # Stack into (B, L, N, E)
        embed_stack = torch.stack(embed_list, dim=2)

        # Apply the mask with the learned masking vector. Where mask=1, we'll use the learned mask.
        mask = batch_data["mask"]  # (B, L_max, N)
        masked_embedding = torch.where(
            mask.unsqueeze(-1),  # (B, L, N, 1)
            self.mask_embedding,      # Will broadcast to (B, L, N, E)
            embed_stack              # (B, L, N, E)
        )

        # Pass through the transformer
        padding_mask = batch_data["padding_mask"]  # (B, L, N)
        out_2d = self.core_transformer(masked_embedding, padding_mask=padding_mask) # (B, L, N, E)

        # Flatten to get a single tensor for all tables
        out_2d = out_2d[:, :, 0, :] # (B, L, E)

        # 3) decode table by table
        decoded = {}
        for i, tn in enumerate(self.table_names):
            out = self.table_decoders[tn](out_2d)  # (B, L, k_i)
            decoded[tn] = out

        return decoded


In [48]:
table_shapes = [table_data_dict[tn].shape[1] for tn in dataset.table_names]

In [26]:
module_c = EconModel(dataset.table_names, table_shapes, embed_dim=32, n_heads=4, ff_dim=128, num_layers=2, dropout=0.1, use_pos_encoding=True).to(device)
module_c = torch.compile(module_c)

Created 2D positional encoding with shape: torch.Size([1, 5000, 1, 32])


In [82]:
dataloader1 = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=econ_collate_fn, pin_memory=True)

In [84]:
start = time.time()
moving_to_gpu = 0
i = 0
for data in dataloader1:
    if i > 200:
        break
    i += 1
    s = time.time()
    data = {
        "full_data": {tn: v.to(device, non_blocking=True) for tn, v in data["full_data"].items()},
        "mask": data["mask"].to(device, non_blocking=True),
        "padding_mask": data["padding_mask"].to(device, non_blocking=True)
    }
    moving_to_gpu += time.time() - s
    out = module_c(data)

print("Time taken:", time.time() - start)
print("Moving to GPU:", moving_to_gpu)

KeyboardInterrupt: 

In [21]:
module_c = EconModel(dataset.table_names, table_shapes, embed_dim=32, n_heads=4, ff_dim=128, num_layers=2, dropout=0.1, use_pos_encoding=True).to(device)

Created 2D positional encoding with shape: torch.Size([1, 5000, 1, 32])


In [29]:
dataloader1 = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=econ_collate_fn, pin_memory=True)

start = time.time()

i = 0
for data in dataloader1:
    i += 1
    if i > 50:
        break

    for tn in data["full_data"].keys():
        data["full_data"][tn] = data["full_data"][tn].to(device)
    data["mask"] = data["mask"].to(device)
    data["padding_mask"] = data["padding_mask"].to(device)
    out = module(data)

print("Time taken:", time.time() - start)

Time taken: 2.181070566177368


In [73]:
def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.BoolTensor):
    """
    pred: (B, L, k_i)
    target: (B, L, k_i)
    mask: (B, L, k_i)  # 1 where ground truth is valid, 0 where no ground truth
    Returns average MSE over valid entries.
    """
    diff = (pred - target) ** 2
    diff = diff * mask
    valid_count = mask.sum()
    if valid_count > 0:
        return diff.sum() / valid_count
    else:
        return torch.tensor(0.0, device=pred.device)

In [80]:
from tqdm.notebook import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_linear_schedule_with_warmup

torch.autograd.set_detect_anomaly(True)

def train_econ_model(csv_file_paths,
                     meta_file_paths,
                     epochs=5,
                     batch_size=8,
                     min_window_length_year=2,
                     max_window_length_year=25,
                     embed_dim=32,
                     lr=1e-2,
                     num_train_samples=25_000,
                     num_test_samples=500):
    """
    Full pipeline:
    1) Read & scale data with scikit-learn
    2) Create train/test datasets
    3) Model + optimizer
    4) Training loop with masked MSE
    """
    # Read and scale tables
    table_data_dict, meta_data_dict, scalers, monthly_dates = read_and_scale_tables(
        csv_file_paths,
        meta_file_paths,
        start_date="1970-01-01",
        end_date="2024-01-01",
        train_cutoff_str="2018-01-01",
    )

    # Create train and test dataset
    train_dataset = create_dataset(table_data_dict, meta_data_dict, monthly_dates, train=True, min_window_length_year=min_window_length_year, max_window_length_year=max_window_length_year, number_of_samples=num_train_samples)
    test_dataset = create_dataset(table_data_dict, meta_data_dict, monthly_dates, train=False, min_window_length_year=min_window_length_year, max_window_length_year=4, number_of_samples=num_test_samples)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=econ_collate_fn
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=econ_collate_fn
    )

    # 3) model + optimizer
    table_names = list(table_data_dict.keys())
    table_shapes = [table_data_dict[tn].shape[1] for tn in table_names]

    print(table_names[0])
    print(table_shapes[0])
    print(table_data_dict[table_names[0]].shape)

    model = EconModel(
        table_names,
        table_shapes,
        embed_dim=embed_dim,
        n_heads=4,
        ff_dim=128,
        num_layers=4,
        dropout=0.15,
        use_pos_encoding=True
    )

    update_loss_every = 100
    batch_count = 0
    running_loss = 0.0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = torch.compile(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    num_training_steps = len(train_loader) * epochs
    num_warmup_steps = len(train_loader) * 2

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # 4) Training loop
    for ep in range(epochs):
        model.train()
        total_train_loss = 0.0
        # Training loop with tqdm
        for b_idx, batch_data in tqdm(enumerate(train_loader), desc="Training", total=len(train_loader)):
            # Move data to device
            for tn in batch_data["full_data"]:
                batch_data["full_data"][tn] = batch_data["full_data"][tn].to(device)
            batch_data["mask"] = batch_data["mask"].to(device)
            batch_data["padding_mask"] = batch_data["padding_mask"].to(device)

            optimizer.zero_grad()
            outputs = model(batch_data)  # dict {tn -> (B, L, k_i)}

            # Compute loss
            losses = []
            padding_mask = batch_data["padding_mask"][:, :, 0].unsqueeze(-1)  # (B, L, N) => (B, L, 1)
            for tn in table_names:
                pred = outputs[tn]
                tgt = batch_data["full_data"][tn][:, :, 0::3]  # (B, L, 3*k_i) => (B, L, k_i)
                expected_missing_mask = batch_data["full_data"][tn][:, :, 1::3] == 1.0  # (B, L, k_i)
                true_missing_mask = batch_data["full_data"][tn][:, :, 2::3] == 1.0  # (B, L, k_i)
                valid_mask = ~(expected_missing_mask | true_missing_mask | padding_mask)  # (B, L, k_i)
                losses.append(masked_mse_loss(pred, tgt, valid_mask))

            loss_val = torch.stack(losses).mean()
            with record_function("backward"):
                loss_val.backward()
                optimizer.step()
                scheduler.step()

            # Update running and total loss
            running_loss += loss_val.item()
            total_train_loss += loss_val.item()

            batch_count += 1

            # Update tqdm every 50 batches
            if batch_count % update_loss_every == 0:
                current_lr = optimizer.param_groups[0]['lr']
                tqdm.write(f"Average loss (last {update_loss_every} batches): {running_loss / update_loss_every:.4f}, lr: {current_lr:.2e}")
                running_loss = 0.0  # Reset running loss

        avg_train_loss = total_train_loss / (b_idx + 1)

        # Evaluate
        model.eval()
        total_test_loss = 0.0
        with torch.no_grad():
            for b_idx, batch_data in tqdm(enumerate(test_loader), desc="Testing"):
                for tn in batch_data["full_data"]:
                    batch_data["full_data"][tn] = batch_data["full_data"][tn].to(device)
                batch_data["mask"] = batch_data["mask"].to(device)
                batch_data["padding_mask"] = batch_data["padding_mask"].to(device)


                outputs = model(batch_data)
                losses = []
                padding_mask = batch_data["padding_mask"][:, :, 0].unsqueeze(-1)
                for tn in table_names:
                    pred = outputs[tn]
                    tgt = batch_data["full_data"][tn][:, :, 0::3]
                    expected_missing_mask = batch_data["full_data"][tn][:, :, 1::3] == 1.0
                    true_missing_mask = batch_data["full_data"][tn][:, :, 2::3] == 1.0
                    valid_mask = ~(expected_missing_mask | true_missing_mask | padding_mask)
                    losses.append(masked_mse_loss(pred, tgt, valid_mask))

                loss_val = torch.stack(losses).mean()
                total_test_loss += loss_val.item()

            avg_test_loss = total_test_loss / len(test_loader)


        print(f"Epoch {ep+1}/{epochs} - Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

    print("Training complete!")
    return model



In [81]:
import json

with open("Data/all_data.json", "r") as f:
    all_data = json.load(f)

csv_file_paths = [f"Data/{table_name}.csv" for table_name in all_data]
csv_file_paths_meta = [f"Data/{table_name}_meta.csv" for table_name in all_data]

model = train_econ_model(
    csv_file_paths,
    csv_file_paths_meta,
    batch_size=8,
    lr=5e-4,
    min_window_length_year=2,
    max_window_length_year=5,
    num_train_samples=200_000,
    num_test_samples=100,
    epochs=10)

Reading and scaling tables:   0%|          | 0/76 [00:00<?, ?it/s]

BALANCE-PAIEMENTS
193
(649, 193)
Created 2D positional encoding with shape: torch.Size([1, 5000, 1, 32])


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 13.2726, lr: 1.00e-06
Average loss (last 100 batches): 12.4859, lr: 2.00e-06
Average loss (last 100 batches): 12.6743, lr: 3.00e-06
Average loss (last 100 batches): 13.1723, lr: 4.00e-06
Average loss (last 100 batches): 13.5393, lr: 5.00e-06
Average loss (last 100 batches): 13.1617, lr: 6.00e-06
Average loss (last 100 batches): 13.5773, lr: 7.00e-06
Average loss (last 100 batches): 13.1624, lr: 8.00e-06
Average loss (last 100 batches): 12.8970, lr: 9.00e-06
Average loss (last 100 batches): 13.3026, lr: 1.00e-05
Average loss (last 100 batches): 13.8302, lr: 1.10e-05
Average loss (last 100 batches): 12.3023, lr: 1.20e-05
Average loss (last 100 batches): 13.2099, lr: 1.30e-05
Average loss (last 100 batches): 12.5497, lr: 1.40e-05
Average loss (last 100 batches): 11.9360, lr: 1.50e-05
Average loss (last 100 batches): 12.7197, lr: 1.60e-05
Average loss (last 100 batches): 12.4442, lr: 1.70e-05
Average loss (last 100 batches): 12.4517, lr: 1.80e-05
Average lo

Testing: 0it [00:00, ?it/s]

Epoch 1/10 - Train Loss: 4.7426 | Test Loss: 47632388009258.0312


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 1.2096, lr: 2.51e-04
Average loss (last 100 batches): 1.4338, lr: 2.52e-04
Average loss (last 100 batches): 1.2666, lr: 2.53e-04
Average loss (last 100 batches): 1.2669, lr: 2.54e-04
Average loss (last 100 batches): 1.1993, lr: 2.55e-04
Average loss (last 100 batches): 1.2912, lr: 2.56e-04
Average loss (last 100 batches): 1.3233, lr: 2.57e-04
Average loss (last 100 batches): 1.1999, lr: 2.58e-04
Average loss (last 100 batches): 1.4589, lr: 2.59e-04
Average loss (last 100 batches): 1.2408, lr: 2.60e-04
Average loss (last 100 batches): 1.0920, lr: 2.61e-04
Average loss (last 100 batches): 1.2881, lr: 2.62e-04
Average loss (last 100 batches): 1.4501, lr: 2.63e-04
Average loss (last 100 batches): 1.1006, lr: 2.64e-04
Average loss (last 100 batches): 1.3288, lr: 2.65e-04
Average loss (last 100 batches): 1.1885, lr: 2.66e-04
Average loss (last 100 batches): 1.1111, lr: 2.67e-04
Average loss (last 100 batches): 1.2471, lr: 2.68e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 2/10 - Train Loss: 1.0572 | Test Loss: 46683249248206.3203


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.8456, lr: 5.00e-04
Average loss (last 100 batches): 1.1094, lr: 5.00e-04
Average loss (last 100 batches): 0.9049, lr: 4.99e-04
Average loss (last 100 batches): 1.0024, lr: 4.99e-04
Average loss (last 100 batches): 1.0204, lr: 4.99e-04
Average loss (last 100 batches): 1.1085, lr: 4.99e-04
Average loss (last 100 batches): 0.9810, lr: 4.98e-04
Average loss (last 100 batches): 0.9987, lr: 4.98e-04
Average loss (last 100 batches): 0.8545, lr: 4.98e-04
Average loss (last 100 batches): 0.9192, lr: 4.98e-04
Average loss (last 100 batches): 0.8779, lr: 4.97e-04
Average loss (last 100 batches): 0.8613, lr: 4.97e-04
Average loss (last 100 batches): 0.8256, lr: 4.97e-04
Average loss (last 100 batches): 1.0480, lr: 4.96e-04
Average loss (last 100 batches): 0.9077, lr: 4.96e-04
Average loss (last 100 batches): 1.0123, lr: 4.96e-04
Average loss (last 100 batches): 0.8837, lr: 4.96e-04
Average loss (last 100 batches): 0.8094, lr: 4.95e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 3/10 - Train Loss: 0.8500 | Test Loss: 49260364628865.7188


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.5990, lr: 4.37e-04
Average loss (last 100 batches): 0.8301, lr: 4.37e-04
Average loss (last 100 batches): 0.7893, lr: 4.37e-04
Average loss (last 100 batches): 0.6204, lr: 4.36e-04
Average loss (last 100 batches): 1.0170, lr: 4.36e-04
Average loss (last 100 batches): 0.8431, lr: 4.36e-04
Average loss (last 100 batches): 0.8215, lr: 4.36e-04
Average loss (last 100 batches): 0.6496, lr: 4.36e-04
Average loss (last 100 batches): 0.7371, lr: 4.35e-04
Average loss (last 100 batches): 0.9302, lr: 4.35e-04
Average loss (last 100 batches): 0.7298, lr: 4.35e-04
Average loss (last 100 batches): 0.9703, lr: 4.34e-04
Average loss (last 100 batches): 0.7597, lr: 4.34e-04
Average loss (last 100 batches): 0.6552, lr: 4.34e-04
Average loss (last 100 batches): 0.8019, lr: 4.34e-04
Average loss (last 100 batches): 0.6256, lr: 4.34e-04
Average loss (last 100 batches): 0.7226, lr: 4.33e-04
Average loss (last 100 batches): 0.9142, lr: 4.33e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 4/10 - Train Loss: 0.6959 | Test Loss: 43563699629367.8047


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.6437, lr: 3.75e-04
Average loss (last 100 batches): 0.9844, lr: 3.74e-04
Average loss (last 100 batches): 0.5588, lr: 3.74e-04
Average loss (last 100 batches): 0.5997, lr: 3.74e-04
Average loss (last 100 batches): 0.7170, lr: 3.74e-04
Average loss (last 100 batches): 0.5703, lr: 3.74e-04
Average loss (last 100 batches): 0.5274, lr: 3.73e-04
Average loss (last 100 batches): 0.6636, lr: 3.73e-04
Average loss (last 100 batches): 0.5563, lr: 3.73e-04
Average loss (last 100 batches): 0.6481, lr: 3.73e-04
Average loss (last 100 batches): 0.6727, lr: 3.72e-04
Average loss (last 100 batches): 0.5965, lr: 3.72e-04
Average loss (last 100 batches): 0.5959, lr: 3.72e-04
Average loss (last 100 batches): 0.5696, lr: 3.72e-04
Average loss (last 100 batches): 0.6058, lr: 3.71e-04
Average loss (last 100 batches): 0.5895, lr: 3.71e-04
Average loss (last 100 batches): 0.6420, lr: 3.71e-04
Average loss (last 100 batches): 0.5754, lr: 3.71e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 5/10 - Train Loss: 0.6169 | Test Loss: 58964718498220.6562


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.6333, lr: 3.12e-04
Average loss (last 100 batches): 0.5560, lr: 3.12e-04
Average loss (last 100 batches): 0.4952, lr: 3.12e-04
Average loss (last 100 batches): 0.5703, lr: 3.11e-04
Average loss (last 100 batches): 0.4954, lr: 3.11e-04
Average loss (last 100 batches): 0.4736, lr: 3.11e-04
Average loss (last 100 batches): 0.6059, lr: 3.11e-04
Average loss (last 100 batches): 0.5067, lr: 3.11e-04
Average loss (last 100 batches): 0.5794, lr: 3.10e-04
Average loss (last 100 batches): 0.4942, lr: 3.10e-04
Average loss (last 100 batches): 0.6017, lr: 3.10e-04
Average loss (last 100 batches): 0.4988, lr: 3.09e-04
Average loss (last 100 batches): 0.5638, lr: 3.09e-04
Average loss (last 100 batches): 0.5144, lr: 3.09e-04
Average loss (last 100 batches): 0.5041, lr: 3.09e-04
Average loss (last 100 batches): 0.4731, lr: 3.09e-04
Average loss (last 100 batches): 0.5795, lr: 3.08e-04
Average loss (last 100 batches): 0.5745, lr: 3.08e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 6/10 - Train Loss: 0.5512 | Test Loss: 61113542007283.5938


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.4431, lr: 2.50e-04
Average loss (last 100 batches): 0.5502, lr: 2.49e-04
Average loss (last 100 batches): 0.6018, lr: 2.49e-04
Average loss (last 100 batches): 0.7237, lr: 2.49e-04
Average loss (last 100 batches): 0.4408, lr: 2.49e-04
Average loss (last 100 batches): 0.5175, lr: 2.49e-04
Average loss (last 100 batches): 0.5516, lr: 2.48e-04
Average loss (last 100 batches): 0.4839, lr: 2.48e-04
Average loss (last 100 batches): 0.4416, lr: 2.48e-04
Average loss (last 100 batches): 0.5433, lr: 2.47e-04
Average loss (last 100 batches): 0.4475, lr: 2.47e-04
Average loss (last 100 batches): 0.6924, lr: 2.47e-04
Average loss (last 100 batches): 0.5220, lr: 2.47e-04
Average loss (last 100 batches): 0.5299, lr: 2.47e-04
Average loss (last 100 batches): 0.4685, lr: 2.46e-04
Average loss (last 100 batches): 0.4145, lr: 2.46e-04
Average loss (last 100 batches): 0.4635, lr: 2.46e-04
Average loss (last 100 batches): 0.4950, lr: 2.46e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 7/10 - Train Loss: 0.5272 | Test Loss: 56558133547421.3281


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.5639, lr: 1.87e-04
Average loss (last 100 batches): 0.4635, lr: 1.87e-04
Average loss (last 100 batches): 0.4928, lr: 1.87e-04
Average loss (last 100 batches): 0.6635, lr: 1.87e-04
Average loss (last 100 batches): 0.4518, lr: 1.86e-04
Average loss (last 100 batches): 0.4639, lr: 1.86e-04
Average loss (last 100 batches): 0.4106, lr: 1.86e-04
Average loss (last 100 batches): 0.6901, lr: 1.86e-04
Average loss (last 100 batches): 0.4438, lr: 1.85e-04
Average loss (last 100 batches): 0.4066, lr: 1.85e-04
Average loss (last 100 batches): 0.4366, lr: 1.85e-04
Average loss (last 100 batches): 0.4726, lr: 1.85e-04
Average loss (last 100 batches): 0.5555, lr: 1.84e-04
Average loss (last 100 batches): 0.4374, lr: 1.84e-04
Average loss (last 100 batches): 0.3897, lr: 1.84e-04
Average loss (last 100 batches): 0.4939, lr: 1.83e-04
Average loss (last 100 batches): 0.4222, lr: 1.83e-04
Average loss (last 100 batches): 0.4256, lr: 1.83e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 8/10 - Train Loss: 0.4911 | Test Loss: 53727816099541.6562


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.4912, lr: 1.25e-04
Average loss (last 100 batches): 0.4948, lr: 1.24e-04
Average loss (last 100 batches): 0.5184, lr: 1.24e-04
Average loss (last 100 batches): 0.4008, lr: 1.24e-04
Average loss (last 100 batches): 0.5049, lr: 1.24e-04
Average loss (last 100 batches): 0.5571, lr: 1.23e-04
Average loss (last 100 batches): 0.4889, lr: 1.23e-04
Average loss (last 100 batches): 0.4142, lr: 1.23e-04
Average loss (last 100 batches): 0.4339, lr: 1.23e-04
Average loss (last 100 batches): 0.3774, lr: 1.22e-04
Average loss (last 100 batches): 0.4092, lr: 1.22e-04
Average loss (last 100 batches): 0.4790, lr: 1.22e-04
Average loss (last 100 batches): 0.4187, lr: 1.22e-04
Average loss (last 100 batches): 0.4579, lr: 1.21e-04
Average loss (last 100 batches): 0.4248, lr: 1.21e-04
Average loss (last 100 batches): 0.3976, lr: 1.21e-04
Average loss (last 100 batches): 0.4148, lr: 1.21e-04
Average loss (last 100 batches): 0.4143, lr: 1.21e-04
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 9/10 - Train Loss: 0.4703 | Test Loss: 55731426572943.1719


Training:   0%|          | 0/25000 [00:00<?, ?it/s]

Average loss (last 100 batches): 0.4323, lr: 6.22e-05
Average loss (last 100 batches): 0.4001, lr: 6.20e-05
Average loss (last 100 batches): 0.3826, lr: 6.17e-05
Average loss (last 100 batches): 0.3837, lr: 6.15e-05
Average loss (last 100 batches): 0.5509, lr: 6.12e-05
Average loss (last 100 batches): 0.4141, lr: 6.10e-05
Average loss (last 100 batches): 0.5576, lr: 6.07e-05
Average loss (last 100 batches): 0.5329, lr: 6.05e-05
Average loss (last 100 batches): 0.3910, lr: 6.03e-05
Average loss (last 100 batches): 0.5988, lr: 6.00e-05
Average loss (last 100 batches): 0.4587, lr: 5.97e-05
Average loss (last 100 batches): 0.4321, lr: 5.95e-05
Average loss (last 100 batches): 0.5158, lr: 5.92e-05
Average loss (last 100 batches): 0.5583, lr: 5.90e-05
Average loss (last 100 batches): 0.4594, lr: 5.87e-05
Average loss (last 100 batches): 0.3911, lr: 5.85e-05
Average loss (last 100 batches): 0.5189, lr: 5.83e-05
Average loss (last 100 batches): 0.5244, lr: 5.80e-05
Average loss (last 100 batch

Testing: 0it [00:00, ?it/s]

Epoch 10/10 - Train Loss: 0.4741 | Test Loss: 46506442880304.2188
Training complete!


In [115]:
device = torch.device("cuda")

def infer(model, dataset, device, column: str = "CNA-2020-PIB"):
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=econ_collate_fn)
    table_names = list(dataset.table_names)
    data = next(iter(dataloader))
    for tn in data["full_data"]:
        data["full_data"][tn] = data["full_data"][tn].to(device)
    data["mask"] = data["mask"].to(device)
    data["padding_mask"] = data["padding_mask"].to(device)

    with torch.no_grad():
        outputs = model(data)

    targets = {}
    valid_masks = {}
    losses = []
    padding_mask = data["padding_mask"][:, :, 0].unsqueeze(-1)
    for tn in table_names:
        if tn != column:
            continue

        pred = outputs[tn]
        tgt = data["full_data"][tn][:, :, 0::3]
        targets[tn] = tgt
        expected_missing_mask = data["full_data"][tn][:, :, 1::3] == 1.0
        true_missing_mask = data["full_data"][tn][:, :, 2::3] == 1.0
        valid_mask = ~(expected_missing_mask | true_missing_mask | padding_mask)
        valid_masks[tn] = valid_mask
        losses.append(masked_mse_loss(pred, tgt, valid_mask))

    loss_val = torch.stack(losses).mean().item()
    print("Computed loss:", loss_val)
    return outputs, targets, valid_masks

In [131]:
model.eval()
test_dataset = create_dataset(table_data_dict, meta_data_dict, monthly_dates, train=False, min_window_length_year=2, max_window_length_year=4, number_of_samples=100, inference_mode=True)
outputs, targets, valid_masks = infer(model, test_dataset, device)

Computed loss: 68.79239654541016


In [118]:
targets["CNA-2020-PIB"][0, 12, :]

tensor([-2.0836,  1.8250, -1.6541,  2.8923,  2.8842,  0.7365,  0.3575, -1.1170,
         1.6394,  0.8702,  0.5023,  1.0544, -1.6651,  2.7071, -4.1691, -3.3951,
        -3.9235, -3.3062, -4.0105, -1.4309,  3.0046], device='cuda:0')

In [117]:
valid_masks["CNA-2020-PIB"][0, :, :].cpu().numpy()

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False

In [119]:
outputs["CNA-2020-PIB"][0, 12, :]

tensor([-1.2565,  3.4274, -1.9030,  1.6195,  2.3284, -0.3057, -2.3308, -1.3956,
         1.1599,  0.6699,  0.4282,  0.2531,  0.4027,  1.7227, -4.4886, -3.8208,
        -4.2485, -3.7883, -4.3450,  0.1779,  0.6506], device='cuda:0')

In [82]:
torch.save(model.state_dict(), "model_t.pt")

First profiling with N then L

In [232]:
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               backward        25.48%        1.151s        27.41%        1.238s      95.265ms       0.000us         0.00%      10.298ms     792.166us       1.23 Kb        -832 b      -9.40 Gb      -9.43 G

In [None]:
# With compile
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))

In [28]:
import os
import requests
from dotenv import load_dotenv

load_dotenv()
key = os.getenv("NYT_KEY")

result = requests.request("GET", f"https://api.nytimes.com/svc/archive/v1/1970/1.json?api-key={key}")

In [42]:
print(result.json()['response']['docs'][10])
for d in result.json()['response']['docs'][:500]:
    if 'Front Page 2 -- No Title' in d['headline']['main']:
        print(d)
        break

filtered_articles = [d for d in result.json()['response']['docs'] if 'print_page' in d and int(d['print_page']) < 3]
print(len(filtered_articles))


{'abstract': "Prosecution of charges involving ousted Water Supply, Gas and Electricity Dept Comr Marcus continues; M Kaufman indicted for '68 perjury concerning attempt to bribe City Planning Comm member to delay bldg application by competitor S Sommer; is charged with denying bribe in testimony before grand jury probing incident and role of Kaufman, Marcus, informer H Itkin, real estate operator R Elyachar 'and others'; Itkin has testified that city official shared $10,000 payoff with him and Marcus; Elyachar pleaded guilty to perjury last Sept, is reptd cooperating with probe; Kaufman pleads not guilty", 'web_url': 'https://www.nytimes.com/1970/01/01/archives/builder-is-accused-of-perjury-in-cityplanning-bribery-case.html', 'snippet': "Prosecution of charges involving ousted Water Supply, Gas and Electricity Dept Comr Marcus continues; M Kaufman indicted for '68 perjury concerning attempt to bribe City Planning Comm member to delay bldg application by competitor S Sommer; is cha..."

In [12]:
print(len(result.json()['response']['docs']))

6560


In [13]:
headers = {
        'User-Agent': 'Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)',
        'Referer': 'https://www.google.com/'
    }
response = requests.get("https://www.nytimes.com/1979/01/01/archives/israelis-decide-to-continue-talks-with-egypt-on-treaty-israel-would.html", headers=headers)

In [14]:
print(response.text)

<!DOCTYPE html>
<html lang="en" class=" story nytapp-vi-article "  xmlns:og="http://opengraphprotocol.org/schema/">
  <head>
    
    
    <meta charset="utf-8" />
    <title data-rh="true">Israelis Decide to Continue Talks With Egypt on Treaty - The New York Times</title>
    <meta data-rh="true" name="robots" content="noarchive, max-image-preview:large"/><meta data-rh="true" name="description" content="Begins says Israel will continue talks with Egypt; illus (M)"/><meta data-rh="true" property="twitter:url" content="https://www.nytimes.com/1979/01/01/archives/israelis-decide-to-continue-talks-with-egypt-on-treaty-israel-would.html"/><meta data-rh="true" property="twitter:title" content="Israelis Decide to Continue Talks With Egypt on Treaty (Published 1979)"/><meta data-rh="true" property="twitter:description" content="Begins says Israel will continue talks with Egypt; illus (M)"/><meta data-rh="true" property="twitter:image" content="https://static01.nyt.com/newsgraphics/images/icon

In [24]:
from newspaper import Article

for url in [x['web_url'] for x in result.json()['response']['docs']]:
    # Wait for the page to load and get the html
    html = requests.get(url, headers=headers).text
    article = Article(url)
    article.download(input_html=html)
    article.parse()
    article.nlp()
    print("------------------")
    print(article.title)
    print(article.text)
    print()


------------------
Israelis Decide to Continue Talks With Egypt on Treaty
A low point was reached on Dec. 15 when the Israeli Cabinet, in a unanimous vote, appeared to close the door to further negotiations by rejecting all Egyptian proposals to amend the draft treaty. Since then, Mr. Begin has said that Israel is willing to discuss some but not all the Egyptian demands. The Government decision today and Prime Minister Begin's later remarks are part of the gradual movement back to the negotiating table.

Israel Would Review Security

In announcing agreement to resume the talks, Mr. Begin told reporters that Israel was prepared to discuss with Egypt its demand to review security arrangments in the Sinai Peninsula five years after a peace treaty is signed. Sinai, now under Israeli occupation, would be returned to Egypt under the peace treaty, but Egypt wants to eventually renegotiate the size of the military force it can deploy there.

Mr. Begin also said that Israel was prepared to disc

KeyboardInterrupt: 

In [29]:
urls = [x['web_url'] for x in result.json()['response']['docs']]


In [30]:
for i, url in enumerate(urls[:20]):
    print(i, url)

0 https://dealbook.nytimes.com/2006/06/21/facebook-and-that-2-billion/
1 https://www.nytimes.com/1970/01/01/archives/hail-and-farewell.html
2 https://www.nytimes.com/1970/01/01/archives/front-page-2-no-title.html
3 https://www.nytimes.com/1970/01/01/archives/mississippi-adds-3-negroes-to-storm-relief-unit-governor-moves-to.html
4 https://www.nytimes.com/1970/01/01/archives/icc-aide-is-named-as-acting-chairman.html
5 https://www.nytimes.com/1970/01/01/archives/letters-to-the-editor-of-the-times.html
6 https://www.nytimes.com/1970/01/01/archives/traffic-snarled-by-freezing-rain-road-and-rail-facilities-are.html
7 https://www.nytimes.com/1970/01/01/archives/business-tax-bills-signed-by-shafer.html
8 https://www.nytimes.com/1970/01/01/archives/article-4-no-title.html
9 https://www.nytimes.com/1970/01/01/archives/cigarette-maker-loses-courtr-test-set-back-in-dispute-over-tv.html
10 https://www.nytimes.com/1970/01/01/archives/builder-is-accused-of-perjury-in-cityplanning-bribery-case.html
11

In [68]:
import torch
import torch.nn as nn

# For reproducibility
torch.manual_seed(0)

def key_padding_mask_to_attention_mask(key_padding_mask: torch.Tensor) -> torch.Tensor:
    """
    Convert a key_padding_mask of shape (B, S) to an attention_mask of shape (B, S, S).

    Args:
        key_padding_mask: Boolean tensor of shape (B, S) where True indicates padding tokens
                         and False indicates actual tokens.

    Returns:
        attention_mask: Boolean tensor of shape (B, S, S) where False indicates allowed
                       attention and True indicates masked (blocked) attention.
    """
    batch_size, seq_len = key_padding_mask.size()

    # First, we need to convert the key_padding_mask to the right shape
    # We want each position to not attend to padding tokens
    # So we expand the key_padding_mask to (B, 1, S) and broadcast it to (B, S, S)
    expanded_mask = key_padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)

    # The attention_mask should be False (1) where we want to allow attention
    # and True (0) where we want to block it
    attention_mask = expanded_mask

    return attention_mask

def fix_fully_masked_rows(attn_mask_3d: torch.Tensor, key_padding_mask: torch.Tensor) -> torch.Tensor:
    """
    For any row b where key_padding_mask[b] is all True (i.e., fully masked),
    replace that entire (L, L) block in attn_mask_3d with ~torch.eye(L).
    This makes each token attend only to itself, preventing NaNs.
    """
    B, L, _ = attn_mask_3d.shape
    print("DEBUGGING FIX FULLY MASKED ROWS")
    print(attn_mask_3d)
    print(key_padding_mask)
    fully_masked_rows = key_padding_mask.all(dim=1)  # shape (B,)
    print(fully_masked_rows)

    attn_mask_3d[fully_masked_rows] = ~torch.eye(L, L, dtype=torch.bool, device=attn_mask_3d.device)
    print(attn_mask_3d)
    print("END DEBUGGING FIX FULLY MASKED ROWS")
    return attn_mask_3d

class SimpleModel(nn.Module):
    def __init__(self, embed_dim=4, num_heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.proj = nn.Linear(embed_dim, 1)

    def forward(self, x, key_padding_mask):
        """
        x:              Tensor of shape (B, L, E)
        key_padding_mask: shape (B, L), where True indicates "ignore" that token.
        """
        B, L, E = x.shape
        print("x:", x)
        print("key_padding_mask:", key_padding_mask)
        # 1) Build an initial (B, L, L) attn_mask from key_padding_mask
        attn_mask_3d = key_padding_mask_to_attention_mask(key_padding_mask).contiguous()

        # 2) Replace fully-masked rows with the identity diagonal trick
        attn_mask_3d = fix_fully_masked_rows(attn_mask_3d, key_padding_mask)

        # 3) Forward pass with attn_mask only (no separate key_padding_mask!)
        attn_out, _ = self.attn(
            x, x, x,
            attn_mask=attn_mask_3d
        )

        # 4) Zero out entire batch rows that are fully masked.
        fully_masked_rows = key_padding_mask.all(dim=1)
        print("Fully masked rows:", fully_masked_rows)
        attn_out[fully_masked_rows] = 0.0

        # 5) Projection -> scalar
        out = self.proj(attn_out).view(B, L)  # shape (B, L)
        print(out)
        print(~fully_masked_rows.unsqueeze(-1).expand(B, L))
        # average for non masked elements
        out_filtered = out * (~fully_masked_rows.unsqueeze(-1).expand(B, L))
        return out_filtered.sum() / (~fully_masked_rows.unsqueeze(-1).expand(B, L)).sum()


# -----------------------------------------------------------------------
# DEMO: B=2, second row fully masked
# -----------------------------------------------------------------------
print("==== Test 1: B=2, second row fully masked ====")
model = SimpleModel(embed_dim=4, num_heads=1)

x = torch.randn(2, 5, 4, requires_grad=True)
x2 = x.clone().detach()
x2.requires_grad = True

padding_mask = torch.zeros(2, 5, dtype=torch.bool)
padding_mask[1, :] = True  # fully mask the second row

# -- Forward/backward pass #1 --
out = model(x, padding_mask)
print("Forward output #1:", out.item())
out.backward()

# Store gradients after the first pass
grads_pass1 = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        grads_pass1[name] = param.grad.clone().detach()

# Clear gradients
model.zero_grad()

# -----------------------------------------------------------------------
# DEMO: B=1, single row fully masked
# -----------------------------------------------------------------------
print("\n==== Test 2: B=1, single row fully masked ====")
x_single = x2[:1, :, :]  # (2, 5, 4) => (1, 5, 4)
padding_mask_single = torch.zeros(1, 5, dtype=torch.bool)  # all True => fully masked

# -- Forward/backward #1 --
out_single_1 = model(x_single, padding_mask_single)
print("Forward output #2:", out_single_1.item())
out_single_1.backward()

# Grab gradients from pass #1
grads_pass2 = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        grads_pass2[name] = param.grad.clone().detach()

model.zero_grad()
# Compare
for name in grads_pass2:
    print(f"Param: {name}")
    print(f"  Grad pass1 norm = {grads_pass2[name].norm(2).item():.4f}")
    print(f"  Grad pass2 norm = {grads_pass1[name].norm(2).item():.4f}")
    # Check for NaNs
    if torch.isnan(grads_pass2[name]).any() or torch.isnan(grads_pass2[name]).any():
        print("    --> Found NaN in gradients!\n")

print("\nDone! We expect no NaNs and no errors in backward.")


==== Test 1: B=2, second row fully masked ====
x: tensor([[[ 2.0820,  1.7067,  2.3804, -1.1256],
         [-0.3170, -1.0925, -0.0852, -0.0933],
         [-0.7607, -1.5991,  0.0185, -0.7504],
         [ 0.1854,  0.6211,  0.6382, -0.2460],
         [-0.5344,  1.1687,  0.3945,  1.9415]],

        [[ 0.7915, -0.0203, -0.4372,  1.6459],
         [-2.4351, -0.0729, -0.0340,  0.9625],
         [ 0.3492, -0.9215, -0.0562, -0.7015],
         [-0.4637,  1.9218, -0.4025,  0.1239],
         [ 1.1648,  0.9234,  1.3873,  1.3750]]], requires_grad=True)
key_padding_mask: tensor([[False, False, False, False, False],
        [ True,  True,  True,  True,  True]])
DEBUGGING FIX FULLY MASKED ROWS
tensor([[[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,