In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FullyConnectedNetwork(nn.Module):
    def __init__(self):
        super(FullyConnectedNetwork, self).__init__()
        
        # First fully connected layer
        self.fc1 = nn.Linear(25 * 25, 512)
        self.dropout1 = nn.Dropout(0.5)

        # Second fully connected layer
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.5)

        # Third fully connected layer
        self.fc3 = nn.Linear(256, 128)

        # Output layer
        self.fc4 = nn.Linear(128, 15)

    def forward(self, x):
        # Flatten the image input
        x = x.view(x.size(0), -1)

        # Pass through the first layer, apply ReLU and dropout
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)

        # Second layer
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        # Third layer
        x = F.relu(self.fc3(x))

        # Output layer with softmax activation
        x = self.fc4(x)
        x = F.log_softmax(x, dim=1)
        return x

# Create the neural network
model = FullyConnectedNetwork()

# Print the model structure
print(model)


FullyConnectedNetwork(
  (fc1): Linear(in_features=625, out_features=512, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=15, bias=True)
)


In [2]:
class StormShadow(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx, :, :]
        label = self.labels[idx, :]
        return sample, label

# Load the data

In [6]:
import xarray as xr 
vil_ds = xr.open_zarr('../hrrr17x.zarr/')
vil_dr = vil_ds.VIL
# Get all the file names in bystormlabels folder
from tqdm import tqdm
import numpy as np 
import os

labels = np.zeros((5625, 15))

for date_str in tqdm(vil_dr.attrs['date']):
    # Replace : with _ in date_str
    filename_to_open = date_str.replace(':', '_')
    # check if the file exists
    if not os.path.isfile('bystormlabels/' + filename_to_open + '.npz'):
        continue
    labell = np.load('bystormlabels/' + filename_to_open + '.npz')['arr_0']
    frequency_vector = np.array([(labell==i).sum() for i in range(15)])
    if frequency_vector.sum() == 0:
        continue
    else:
        print(vil_dr.attrs['date'].index(date_str))
    labels[vil_dr.attrs['date'].index(date_str)] = frequency_vector
vil_input = vil_dr.values

100%|██████████| 5625/5625 [00:00<00:00, 312448.15it/s]

3414





In [7]:
# Convert to tensor
vil_tensor = torch.from_numpy(vil_input)
labels_tensor = torch.from_numpy(labels)

In [9]:
# Create dataset
dataset = StormShadow(vil_tensor, labels_tensor)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Define the model
model = FullyConnectedNetwork()

# Define the loss
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of epochs to train the model
n_epochs = 10

# Training loop
for epoch in range(n_epochs):
    for data, target in train_loader:
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(data)
        
        # Compute the loss
        loss = criterion(output, target)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {loss.item()}')

Epoch 1/10, Loss: -0.0
Epoch 2/10, Loss: -0.0
Epoch 3/10, Loss: -0.0
Epoch 4/10, Loss: -0.0
Epoch 5/10, Loss: -0.0
Epoch 6/10, Loss: -0.0
Epoch 7/10, Loss: -0.0
Epoch 8/10, Loss: -0.0
Epoch 9/10, Loss: -0.0
Epoch 10/10, Loss: -0.0


In [10]:
model(torch.randn(1, 25, 25))

tensor([[-2.8590, -2.8626, -2.7978, -2.5336, -2.9397, -1.9958, -2.6702, -2.7409,
         -2.7764, -2.9866, -2.8852, -2.8088, -2.9817, -2.6740, -2.6013]],
       grad_fn=<LogSoftmaxBackward0>)