In [1]:
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms.functional as tf

In [2]:
# Number of Historical Days used as Input!
num_days_history = 14

In [3]:
class CustomDataset(Dataset):
    def __init__(self, start_pt, num_pts):
        self.start_pt = start_pt
        self.num_pts = num_pts
        self.factor = 10

        base = "data/"
        self.input_folder = base + "input/"
        self.output_folder = base + "output/"
        self.input_files = os.listdir(self.input_folder)
        self.output_files = os.listdir(self.output_folder)

        assert (start_pt + num_pts * self.factor) < len(self.output_files)

    def __len__(self):
        return self.num_pts

    def __getitem__(self, idx):
        idx = idx * self.factor
        input = []
        for i in range(num_days_history):
            input.append(torch.load(self.input_folder + self.input_files[self.start_pt + idx + 67 - i]))
        input = torch.cat(input, 0)
        output = torch.load(self.output_folder + self.output_files[self.start_pt + idx])
        return input, output

In [4]:
dataloader = DataLoader(CustomDataset(0, 12), batch_size=4, shuffle=True)
train_features, train_labels = next(iter(dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([4, 84, 676, 407])
Labels batch shape: torch.Size([4, 1, 676, 407])


Model

In [5]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        convs = []
        # conv_chs = [6 * num_days_history, 2048, 512, 128, 32, 8] # Can't execute this on local GPU!
        conv_chs = [6 * num_days_history, 64, 32, 16, 8]
        for i in range(len(conv_chs) - 1):
            convs.append(nn.Conv2d(conv_chs[i], conv_chs[i + 1], kernel_size=(3,3), padding="same"))
        self.convs = nn.ModuleList(convs)
        self.act = nn.ReLU()
        # self.LogSoftmax = nn.LogSoftmax(dim=1)
        self.final_conv = nn.Conv2d(conv_chs[-1], 1, kernel_size=(3,3), padding="same")
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        for conv in self.convs:
            x = self.act(conv(x))
        # x = self.LogSoftmax(x)
        x = self.sigmoid(self.final_conv(x))
        return x

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [7]:
# Building mask to remove invalid areas
file_path = "data/FD/" + os.listdir("data/FD/")[0]
im = Image.open(file_path)
mask = tf.to_tensor(im)
mask[mask >= 0] = True
mask[mask < 0] = False
mask = mask.bool().to(device)
print(f"Number of valid points:, {torch.sum(mask):,}")

Number of valid points:, 152,661


In [8]:
BATCH_SIZE = 4
train_dataloader = DataLoader(CustomDataset(0, 400), batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(CustomDataset(400, 40), batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(CustomDataset(440, 80), batch_size=BATCH_SIZE, shuffle=False)

In [9]:
train_steps = len(train_dataloader.dataset) // BATCH_SIZE
batch_mask = mask.repeat(BATCH_SIZE, 1, 1, 1)

In [10]:
weight = 50
eps = 1e-8
mask_sum = torch.sum(mask)
def weightedBCELoss(output, target):
    output = torch.clamp(output, eps, 1. - eps)
    loss = (weight * (target * torch.log(output))) + ((1 - target) * torch.log(1 - output))
    loss = loss * mask
    return -torch.sum(loss) / mask_sum

model = Net().to(device)
# loss_fn = nn.NLLLoss(torch.Tensor([1, 50]).to(device))
loss_fn = weightedBCELoss
optimizer = optim.Adam(model.parameters(), lr=1e-4)

print(model)

Net(
  (convs): ModuleList(
    (0): Conv2d(84, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (2): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
  )
  (act): ReLU()
  (final_conv): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (sigmoid): Sigmoid()
)


In [11]:
def get_metrics(dataloader):
    model.eval()

    correct_num = 0
    correct_denom = 0
    recall_num = 0
    recall_denom = 0
    precision_num = 0
    precision_denom = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        y_pred = model(inputs)
        pred = y_pred > 0.5

        pred = pred[batch_mask]
        labels = labels[batch_mask]
        
        correct_num += (pred == labels).float().sum().item()
        correct_denom += labels.numel()
        recall_num += (pred[labels > 0] == labels[labels > 0]).float().sum().item()
        recall_denom += labels[labels > 0].numel()
        precision_num += (pred[pred > 0] == labels[pred > 0]).float().sum().item()
        precision_denom += labels[pred > 0].numel()
    
    correct = 0
    if correct_denom > 0:
        correct = correct_num / correct_denom
    recall = 0
    if recall_denom > 0:
        recall = recall_num / recall_denom
    precision = 0
    if precision_denom > 0:
        precision = precision_num / precision_denom
    
    return correct, recall, precision

In [15]:
best_model_value = 0

n_epochs = 200
nan_found = False
for epoch in range(n_epochs):
    model.train()

    total_train_loss = 0
    train_correct = 0
    train_recall = 0
    train_precision = 0
    for inputs, labels in tqdm(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        y_pred = model(inputs)
        loss = loss_fn(y_pred, labels)

        if torch.isnan(loss):
            print(f"NaN detected in loss!")
            torch.save(model.state_dict(), "model_nan.pt")
            nan_found = True
            break
        
        optimizer.zero_grad()
        loss.backward()

        # Check for NaNs in gradients
        for name, param in model.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    print(f"NaN detected in gradients of {name}")
                    nan_found = True
        if nan_found == True:
            break
    
        optimizer.step()

        # Check for NaNs in model weights
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(f"NaN detected in weights of {name}")
                nan_found = True
        if nan_found == True:
            break

        total_train_loss += loss

        # pred = torch.argmax(y_pred, dim=1)
        pred = y_pred > 0.5

        pred = pred[batch_mask]
        labels = labels[batch_mask]
        
        train_correct += ((pred == labels).float().sum().item() / labels.numel())
        train_recall += ((pred[labels > 0] == labels[labels > 0]).float().sum().item() / labels[labels > 0].numel())
        if labels[pred > 0].numel() > 0:
            train_precision += ((pred[pred > 0] == labels[pred > 0]).float().sum().item() / labels[pred > 0].numel())
    
    if nan_found == True:
        break

    avg_train_loss = total_train_loss / train_steps
    train_correct = train_correct / train_steps
    train_recall = train_recall / train_steps
    train_precision = train_precision / train_steps
    
    print("Epoch: %d, Loss: %.4f" % (epoch+1, avg_train_loss))
    print("[Training Metrics]   Accuracy: %.4f, Recall: %.4f, Precision: %.4f" % (100*train_correct, 100*train_recall, 100*train_precision))

    val_correct, val_recall, val_precision = get_metrics(val_dataloader)

    print("[Validation Metrics] Accuracy: %.4f, Recall: %.4f, Precision: %.4f" % (100*val_correct, 100*val_recall, 100*val_precision))

    model_value = 0
    if (val_recall + val_precision) > 1e-4:
        model_value = (val_recall * val_precision) / (val_recall + val_precision) # F1 Score
    print("Model Value:", model_value)
    if model_value > best_model_value:
        print("Saving Model!")
        torch.save(model.state_dict(), "model.pt")
        best_model_value = model_value

100%|██████████| 100/100 [10:19<00:00,  6.20s/it]


Epoch: 1, Loss: 3.7438
[Training Metrics]   Accuracy: 69.8187, Recall: 81.5737, Precision: 5.1075
[Validation Metrics] Accuracy: 69.4475, Recall: 81.6051, Precision: 6.7359
Model Value: 0.062222820837959825


100%|██████████| 100/100 [10:16<00:00,  6.16s/it]


Epoch: 2, Loss: 3.7713
[Training Metrics]   Accuracy: 70.4100, Recall: 81.5707, Precision: 5.1846
[Validation Metrics] Accuracy: 63.5688, Recall: 87.9285, Precision: 6.0846
Model Value: 0.0569082535112298


100%|██████████| 100/100 [10:08<00:00,  6.08s/it]


Epoch: 3, Loss: 3.7434
[Training Metrics]   Accuracy: 70.4403, Recall: 81.6525, Precision: 5.3107
[Validation Metrics] Accuracy: 65.0786, Recall: 87.1413, Precision: 6.2838
Model Value: 0.058611064079227726


100%|██████████| 100/100 [10:05<00:00,  6.06s/it]


Epoch: 4, Loss: 3.6962
[Training Metrics]   Accuracy: 71.3529, Recall: 80.5405, Precision: 5.4123
[Validation Metrics] Accuracy: 60.3570, Recall: 89.6353, Precision: 5.7121
Model Value: 0.053698941684617446


100%|██████████| 100/100 [10:04<00:00,  6.04s/it]


Epoch: 5, Loss: 3.7170
[Training Metrics]   Accuracy: 70.9508, Recall: 80.5467, Precision: 5.3714
[Validation Metrics] Accuracy: 66.6345, Recall: 85.5528, Precision: 6.4562
Model Value: 0.060032088864127635


100%|██████████| 100/100 [10:04<00:00,  6.05s/it]


Epoch: 6, Loss: 3.6930
[Training Metrics]   Accuracy: 71.3105, Recall: 81.6744, Precision: 5.2704
[Validation Metrics] Accuracy: 67.8611, Recall: 84.4757, Precision: 6.6159
Model Value: 0.06135370544624717


100%|██████████| 100/100 [10:16<00:00,  6.16s/it]


Epoch: 7, Loss: 3.6781
[Training Metrics]   Accuracy: 71.6575, Recall: 80.7326, Precision: 5.4015
[Validation Metrics] Accuracy: 58.6048, Recall: 90.9604, Precision: 5.5541
Model Value: 0.05234455522694671


100%|██████████| 100/100 [10:12<00:00,  6.12s/it]


Epoch: 8, Loss: 3.7079
[Training Metrics]   Accuracy: 70.8256, Recall: 79.0039, Precision: 5.1468
[Validation Metrics] Accuracy: 62.1869, Recall: 88.2289, Precision: 5.8912
Model Value: 0.05522482071551988


100%|██████████| 100/100 [10:13<00:00,  6.14s/it]


Epoch: 9, Loss: 3.7162
[Training Metrics]   Accuracy: 71.4990, Recall: 79.6432, Precision: 5.3784
[Validation Metrics] Accuracy: 60.8633, Recall: 90.4668, Precision: 5.8296
Model Value: 0.054766942876720445


100%|██████████| 100/100 [10:02<00:00,  6.03s/it]


Epoch: 10, Loss: 3.7203
[Training Metrics]   Accuracy: 70.1245, Recall: 81.6306, Precision: 5.0847
[Validation Metrics] Accuracy: 74.9329, Recall: 75.6927, Precision: 7.6187
Model Value: 0.06922002309941981
Saving Model!


100%|██████████| 100/100 [09:56<00:00,  5.96s/it]


Epoch: 11, Loss: 3.7552
[Training Metrics]   Accuracy: 70.8304, Recall: 80.2556, Precision: 5.3321
[Validation Metrics] Accuracy: 60.2163, Recall: 89.5214, Precision: 5.6865
Model Value: 0.05346862882627741


100%|██████████| 100/100 [10:07<00:00,  6.07s/it]


Epoch: 12, Loss: 3.6589
[Training Metrics]   Accuracy: 71.6075, Recall: 81.9063, Precision: 5.4326
[Validation Metrics] Accuracy: 64.7708, Recall: 86.3369, Precision: 6.1813
Model Value: 0.05768353768788015


100%|██████████| 100/100 [10:10<00:00,  6.11s/it]


Epoch: 13, Loss: 3.7259
[Training Metrics]   Accuracy: 71.0123, Recall: 80.8578, Precision: 5.4360
[Validation Metrics] Accuracy: 64.9783, Recall: 87.1832, Precision: 6.2693
Model Value: 0.058487645336534146


100%|██████████| 100/100 [10:03<00:00,  6.04s/it]


Epoch: 14, Loss: 3.7515
[Training Metrics]   Accuracy: 71.0375, Recall: 80.8451, Precision: 5.2973
[Validation Metrics] Accuracy: 60.5444, Recall: 90.3862, Precision: 5.7804
Model Value: 0.054329139418055866


100%|██████████| 100/100 [10:01<00:00,  6.02s/it]


Epoch: 15, Loss: 3.6530
[Training Metrics]   Accuracy: 70.6265, Recall: 81.9304, Precision: 5.2712
[Validation Metrics] Accuracy: 59.4925, Recall: 90.4237, Precision: 5.6399
Model Value: 0.053087754964137805


100%|██████████| 100/100 [09:53<00:00,  5.94s/it]


Epoch: 16, Loss: 3.6954
[Training Metrics]   Accuracy: 71.8576, Recall: 79.1929, Precision: 5.3922
[Validation Metrics] Accuracy: 57.5348, Recall: 90.7536, Precision: 5.4100
Model Value: 0.0510561575836145


100%|██████████| 100/100 [09:53<00:00,  5.94s/it]


Epoch: 17, Loss: 3.6106
[Training Metrics]   Accuracy: 72.2625, Recall: 80.4053, Precision: 5.4906
[Validation Metrics] Accuracy: 56.2134, Recall: 90.8878, Precision: 5.2614
Model Value: 0.04973456978480652


100%|██████████| 100/100 [10:08<00:00,  6.08s/it]


Epoch: 18, Loss: 3.6176
[Training Metrics]   Accuracy: 71.9747, Recall: 80.6560, Precision: 5.4386
[Validation Metrics] Accuracy: 59.4983, Recall: 89.7836, Precision: 5.6052
Model Value: 0.05275841042300038


100%|██████████| 100/100 [10:06<00:00,  6.07s/it]


Epoch: 19, Loss: 3.6910
[Training Metrics]   Accuracy: 71.9228, Recall: 79.9785, Precision: 5.4678
[Validation Metrics] Accuracy: 61.3572, Recall: 89.7442, Precision: 5.8586
Model Value: 0.05499558333540777


100%|██████████| 100/100 [10:13<00:00,  6.14s/it]


Epoch: 20, Loss: 3.6606
[Training Metrics]   Accuracy: 71.8312, Recall: 80.2561, Precision: 5.4453
[Validation Metrics] Accuracy: 66.2887, Recall: 86.9616, Precision: 6.4852
Model Value: 0.060350862805048644


100%|██████████| 100/100 [10:03<00:00,  6.03s/it]


Epoch: 21, Loss: 3.6174
[Training Metrics]   Accuracy: 72.3658, Recall: 81.0878, Precision: 5.5746
[Validation Metrics] Accuracy: 63.7953, Recall: 88.4437, Precision: 6.1521
Model Value: 0.057520338546128195


100%|██████████| 100/100 [10:02<00:00,  6.03s/it]


Epoch: 22, Loss: 3.6076
[Training Metrics]   Accuracy: 72.3147, Recall: 80.5472, Precision: 5.4847
[Validation Metrics] Accuracy: 69.3849, Recall: 84.7803, Precision: 6.9481
Model Value: 0.0642176386150592


100%|██████████| 100/100 [09:57<00:00,  5.98s/it]


Epoch: 23, Loss: 3.5963
[Training Metrics]   Accuracy: 72.3397, Recall: 80.9744, Precision: 5.5506
[Validation Metrics] Accuracy: 65.7507, Recall: 85.8839, Precision: 6.3197
Model Value: 0.05886509444468818


100%|██████████| 100/100 [10:10<00:00,  6.10s/it]


Epoch: 24, Loss: 3.6223
[Training Metrics]   Accuracy: 72.8135, Recall: 80.9891, Precision: 5.6128
[Validation Metrics] Accuracy: 68.9911, Recall: 84.4178, Precision: 6.8394
Model Value: 0.06326850122886392


100%|██████████| 100/100 [10:04<00:00,  6.04s/it]


Epoch: 25, Loss: 3.6069
[Training Metrics]   Accuracy: 72.3099, Recall: 81.8873, Precision: 5.5081
[Validation Metrics] Accuracy: 70.8503, Recall: 82.8329, Precision: 7.1332
Model Value: 0.06567634700878892


100%|██████████| 100/100 [10:09<00:00,  6.10s/it]


Epoch: 26, Loss: 3.6902
[Training Metrics]   Accuracy: 72.3580, Recall: 79.0815, Precision: 5.4288
[Validation Metrics] Accuracy: 66.3855, Recall: 87.2380, Precision: 6.5208
Model Value: 0.06067256531100581


100%|██████████| 100/100 [10:04<00:00,  6.04s/it]


Epoch: 27, Loss: 3.5846
[Training Metrics]   Accuracy: 72.7236, Recall: 82.2324, Precision: 5.6971
[Validation Metrics] Accuracy: 65.4245, Recall: 87.7174, Precision: 6.3798
Model Value: 0.05947245208377672


100%|██████████| 100/100 [09:59<00:00,  6.00s/it]


Epoch: 28, Loss: 3.5978
[Training Metrics]   Accuracy: 72.5783, Recall: 81.1969, Precision: 5.4333
[Validation Metrics] Accuracy: 72.7917, Recall: 80.9453, Precision: 7.4639
Model Value: 0.06833758554214364


100%|██████████| 100/100 [10:05<00:00,  6.05s/it]


Epoch: 29, Loss: 3.5817
[Training Metrics]   Accuracy: 73.1534, Recall: 80.7692, Precision: 5.6863
[Validation Metrics] Accuracy: 68.5035, Recall: 84.4012, Precision: 6.7382
Model Value: 0.06240060066209344


 35%|███▌      | 35/100 [03:25<06:27,  5.96s/it]

In [14]:
print("Best Model Value:", model_value)

Best Model Value: 0.06233690708736561


In [19]:
torch.cuda.empty_cache()

In [15]:
checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [16]:
test_correct, test_recall, test_precision = get_metrics(test_dataloader)

print("[Test Metrics] Accuracy: %.4f, Recall: %.4f, Precision: %.4f" % (100*test_correct, 100*test_recall, 100*test_precision))

[Test Metrics] Accuracy: 79.0923, Recall: 75.4540, Precision: 7.6479


In [17]:
val_correct, val_recall, val_precision = get_metrics(val_dataloader)

print("[Validation Metrics] Accuracy: %.4f, Recall: %.4f, Precision: %.4f" % (100*val_correct, 100*val_recall, 100*val_precision))

model_value = 0
if (val_recall + val_precision) > 1e-4:
    model_value = (val_recall * val_precision) / (val_recall + val_precision) # F1 Score
print("Model Value:", model_value)

[Validation Metrics] Accuracy: 74.9329, Recall: 75.6927, Precision: 7.6187
Model Value: 0.06922002309941981


In [14]:
torch.save(model.state_dict(), "model_least_loss.pt")