In [1]:
import datasets
import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import tqdm
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE IS ... ", device)

DEVICE IS ...  cuda


In [14]:
class MultiTaskNet(nn.Module):
    def __init__(self, embed_dim=11348, layer_sizes=[2048, 500, 2048, 500]):
        super().__init__()

        self.embedding_dim = embed_dim

        self.mlp_net = nn.Sequential(
            nn.Linear(layer_sizes[0], layer_sizes[1]),  ## 96x64
            nn.ReLU(),
            nn.Linear(layer_sizes[1], layer_sizes[2]),  ## 64x1
            nn.ReLU(),
            nn.Linear(layer_sizes[2], layer_sizes[3]),
        )

        self.last_layer = nn.Linear(
            layer_sizes[3], 1
        )  ## change if we need classification or softmax

    def forward(self, x):
        x = self.mlp_net(x)

        out_x = self.last_layer(x)

        return out_x

In [7]:
epochs = 20
lr = 1e-4
batch_size = 16


In [4]:
with open("data/test_random_proj.pt", "rb") as f:
    dev = torch.load(f)

dev_labels= pd.read_parquet("data/test_20221130.parquet.gzip", columns=["Mean_BMI", "Under5_Mortality_Rate"])

In [37]:
def collator_fn(data):
    x, y_df = data
    x_inp = x.to(device)
    y_bmi = torch.tensor(y_df["Mean_BMI"].values, dtype=torch.float32, device=device)
    y_cmr = torch.tensor(
        y_df["Under5_Mortality_Rate"].values, dtype=torch.float32, device=device
    )
    return x_inp, y_bmi, y_cmr

# dataloader = DataLoader((dev,dev_labels), batch_size=1, collate_fn=collator_fn)

In [40]:
dataloader = DataLoader(TensorDataset(*collator_fn((dev,dev_labels))), batch_size=1)

In [58]:
for idx, batch in enumerate(dataloader):
    x, y_bmi, y_cmr = batch
    break

In [59]:
y_bmi.shape

torch.Size([16])

In [19]:
def masked_mse(output, target):
    mse_loss = nn.MSELoss()
    mask = torch.isnan(target)
    target = torch.where(mask, 0.0, target)
    output = torch.where(mask, 0.0, output)
    return mse_loss(target, output)

def r2_loss(output, target):
    target_mean = torch.nanmean(target)
    ss_tot = torch.nansum(((target - target_mean)) ** 2)
    ss_res = torch.nansum(((target - output)) ** 2)
    r2 = 1 - ss_res / ss_tot
    return r2

In [15]:
print("Model loading")
model = MultiTaskNet().to(device)
loss_fn = masked_mse
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

Model loading


In [17]:
model.load_state_dict(torch.load("outputs/best_bmi.pth")["model_state_dict"])

<All keys matched successfully>

In [42]:
def evaluate_model(model, dataloader):
    mse_loss = []
    r2_losses = []

    for idx, batch in enumerate(dataloader):
        x, y_bmi, y_cmr = batch
        with torch.no_grad():
            outs = model(x).squeeze()
            loss = loss_fn(outs, y_bmi)
            r2_val_loss = r2_loss(outs, y_bmi)

            if np.isnan(loss.item()):
                print(idx ,outs, y_bmi)

            mse_loss.append(loss.item())
            r2_losses.append(r2_val_loss.item())
    # print(np.sum())
    mse_loss_avg = np.array(mse_loss).mean()
    r2_losses_avg = np.array(r2_losses).mean()

    return mse_loss_avg, r2_losses_avg

In [43]:
evaluate_model(model, dataloader)

10324 tensor(nan, device='cuda:0') tensor([22.4700], device='cuda:0')
10328 tensor(nan, device='cuda:0') tensor([23.5900], device='cuda:0')
10330 tensor(nan, device='cuda:0') tensor([23.1000], device='cuda:0')
10333 tensor(nan, device='cuda:0') tensor([24.1100], device='cuda:0')
10358 tensor(nan, device='cuda:0') tensor([28.9500], device='cuda:0')
10375 tensor(nan, device='cuda:0') tensor([23.2200], device='cuda:0')
10376 tensor(nan, device='cuda:0') tensor([23.3700], device='cuda:0')
10387 tensor(nan, device='cuda:0') tensor([28.4600], device='cuda:0')


(nan, nan)

In [44]:
dev[10324]

tensor([-inf, -inf, -inf,  ..., inf, -inf, inf])

In [48]:
with open("data/drop_cols.pickle", "rb") as f:
    drop = pickle.load(f)

In [45]:
test_df = pd.read_parquet("data/test_20221130.parquet.gzip")

In [49]:
test_df.drop(drop, errors='ignore', inplace=True, axis = 1)

In [54]:
df = test_df.iloc[10324]

In [52]:
tmp[tmp.isna()]

Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_median@MODIS/061/MOD08_M3&timestamped             NaN
Deep_Blue_Single_Scattering_Albedo_Land_Std_Deviation_Mean_412_median@MODIS/061/MOD08_M3&timestamped    NaN
Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_mean@MODIS/061/MOD08_M3&timestamped               NaN
Deep_Blue_Single_Scattering_Albedo_Land_Std_Deviation_Mean_412_mean@MODIS/061/MOD08_M3&timestamped      NaN
Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_max_max@MODIS/061/MOD08_M3&timestamped            NaN
Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_max_min@MODIS/061/MOD08_M3&timestamped            NaN
Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_min_max@MODIS/061/MOD08_M3&timestamped            NaN
Deep_Blue_Single_Scattering_Albedo_Land_Mean_Mean_412_min_min@MODIS/061/MOD08_M3&timestamped            NaN
Deep_Blue_Single_Scattering_Albedo_Land_Std_Deviation_Mean_412_max_max@MODIS/061/MOD08_M3&timestamped   NaN
Deep_Blue_Single_Scattering_

In [53]:
df_mean_std = pd.read_parquet("data/train_means_std.parquet.gzip")

In [55]:
df_norm = (df-df_mean_std["means"])/df_mean_std["stds"]

In [79]:
df_mean_std["stds"][11305]

0.0

In [78]:
df[11305]

0.0

In [77]:
df_norm[11305]

-inf

In [62]:
df_norm.fillna(value=0, inplace=True)
x_inp = torch.tensor(df_norm.values, dtype=torch.float32)

In [76]:
x_inp[11305]

tensor(-inf)

In [57]:
with open("data/rand_proj.pt", "rb") as f:
    rand_proj = torch.load(f)


In [75]:
torch.argmin(x_inp)

tensor(11305)

In [63]:
x_proj = x_inp @ rand_proj

In [64]:
x_proj

tensor([-inf, -inf, -inf,  ..., inf, -inf, inf])

In [60]:
rand_proj

tensor([[ 0.0118,  0.0084, -0.0104,  ...,  0.0082, -0.0079, -0.0010],
        [-0.0092,  0.0034, -0.0022,  ..., -0.0015, -0.0187,  0.0070],
        [ 0.0083,  0.0114, -0.0076,  ...,  0.0016, -0.0196,  0.0105],
        ...,
        [ 0.0082,  0.0071,  0.0028,  ...,  0.0110, -0.0117,  0.0114],
        [ 0.0026, -0.0003,  0.0116,  ..., -0.0170,  0.0168,  0.0071],
        [-0.0044, -0.0087,  0.0202,  ..., -0.0058, -0.0088, -0.0183]])