In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from tfrecord.torch.dataset import TFRecordDataset
import matplotlib.pyplot as plt
import pandas as pd

import torch.nn.functional as F
import torch.nn as nn

import torch.optim as optim
from tqdm import tqdm
import copy

from utils import load
from attrdict import AttrDict
import json

### Load dataset and metadata

In [2]:
lf_meta = pd.read_csv('fuel_autoencoder/landfire_metadata.csv')
lf_meta.drop(index = 0,inplace= True)

landfire_fuel_classes = dict(zip(lf_meta['VALUE'],lf_meta['FBFM40']))

kys = list(landfire_fuel_classes.keys())
fuel_classes = {i:landfire_fuel_classes[x] for i,x in enumerate(kys)}
fuel_class_map = {x:i for i,x in enumerate(kys)}

In [3]:
viz = {row['FBFM40']:(row['R'],row['G'],row['B']) for i,row in lf_meta.iterrows()}

In [4]:
modes = ['train','test','val']
data_loaders = {}
for mode in modes:
    tfrecord_path = f"dataset/conus_west_fbfm40_{mode}.tfrecord"
    index_path = None
    description = {"fbfm": "float"}
    dataset = TFRecordDataset(tfrecord_path, index_path, description)
    loader = torch.utils.data.DataLoader(dataset, batch_size=2048)
    data_loaders[mode] = loader
    data_batch = next(iter(loader))

### data processing funcs

In [5]:
def replace_categories(value):
    try:
        return fuel_class_map[value]
    except KeyError:
        return float('nan')

def get_pcts(arr): # takes (B,H,W) array
    B,H,W = arr.shape
    n_cats = len(kys)
    res = torch.zeros((B,n_cats))
    for i,img in enumerate(arr):
        for k,cat in enumerate(kys):
            res[i,k] = torch.sum(arr[i] == cat).item()/(H*W)
    return res

def process_batch(data_batch,onehot = False):
    data = data_batch['fbfm'].clone()
    data = data.reshape((data.shape[0],16,16))
    labels = get_pcts(data)
    data.apply_(replace_categories)
    if onehot:
        data = F.one_hot(data.long(),num_classes = len(kys)).float()
    return data,labels

## Model development

In [9]:
data,labels = process_batch(data_batch,onehot=True)

In [10]:
data[0].unique()

tensor([0., 1.])

### Model definition/development workspace

In [12]:
class Encoder(nn.Module):
    def __init__(
        self,
        input_channels=45,
        input_height=16,
        input_width=16,
        conv_channels=[64, 32,32],
        kernel_sizes=[3, 3,3],
        strides=[1, 1,1],
        paddings=[1, 1,1],
        pooling='max',
        pool_kernels=[2, 2,2],
        fc_hidden_dims=[32,16,8],
        n_H=4,
        activation=nn.ReLU()
    ):
        super(Encoder, self).__init__()

        self.activation = activation
        self.conv_layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        in_channels = input_channels
        H_in, W_in = input_height, input_width

        # Define convolutional and pooling layers
        for i, (out_channels, kernel_size, stride, padding) in enumerate(
            zip(conv_channels, kernel_sizes, strides, paddings)
        ):
            # Convolutional layer
            conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding
            )
            self.conv_layers.append(conv)

            # Update spatial dimensions after convolution
            H_in = self._compute_output_dim(H_in, kernel_size, stride, padding)
            W_in = self._compute_output_dim(W_in, kernel_size, stride, padding)

            # Pooling layer
            if pooling == 'max':
                pool = nn.MaxPool2d(kernel_size=pool_kernels[i])
            elif pooling == 'avg':
                pool = nn.AvgPool2d(kernel_size=pool_kernels[i])
            else:
                pool = None
            self.pool_layers.append(pool)

            # Update spatial dimensions after pooling
            if pool is not None:
                H_in = self._compute_output_dim(H_in, pool_kernels[i], pool_kernels[i], 0)
                W_in = self._compute_output_dim(W_in, pool_kernels[i], pool_kernels[i], 0)

            in_channels = out_channels

        # Calculate the flattened feature dimension after convolutions
        self.feature_dim = H_in * W_in * in_channels

        # Define fully connected layers
        fc_dims = [self.feature_dim] + fc_hidden_dims + [n_H]
        self.fc_layers = nn.ModuleList()
        self.fc_batch_norms = nn.ModuleList()
        for i in range(len(fc_dims) - 1):
            in_dim = fc_dims[i]
            out_dim = fc_dims[i + 1]
            self.fc_layers.append(nn.Linear(in_dim, out_dim))
            if i < len(fc_dims) - 2:
                # Add BatchNorm1d layer for all but the last FC layer
                self.fc_batch_norms.append(nn.BatchNorm1d(out_dim))

    def _compute_output_dim(self, size, kernel_size, stride, padding):
        return (size + 2 * padding - kernel_size) // stride + 1

    def forward(self, x):
        # Permute input to (B, C, H, W)
        x = x.permute(0, 3, 1, 2)

        # Apply convolutional and pooling layers
        for conv_layer, pool_layer in zip(self.conv_layers, self.pool_layers):
            x = self.activation(conv_layer(x))
            if pool_layer is not None:
                x = pool_layer(x)

        # Flatten the output from convolutional layers
        x = x.view(x.size(0), -1)

        # Apply fully connected layers with batch normalization
        for i, fc_layer in enumerate(self.fc_layers[:-1]):
            x = fc_layer(x)
            x = self.activation(x)
            x = self.fc_batch_norms[i](x)

        # Output layer without batch normalization and activation
        x = self.fc_layers[-1](x)

        return x


class Decoder(nn.Module):
    def __init__(
        self,
        n_H=4,
        fc_hidden_dims=[8,16,32],
        output_dim=45,
        activation=nn.ReLU()
    ):
        super(Decoder, self).__init__()

        self.activation = activation
        self.output_dim = output_dim

        # Define fully connected layers
        fc_dims = [n_H] + fc_hidden_dims + [self.output_dim]
        self.fc_layers = nn.ModuleList()
        self.fc_batch_norms = nn.ModuleList()
        for i in range(len(fc_dims) - 1):
            in_dim = fc_dims[i]
            out_dim = fc_dims[i + 1]
            self.fc_layers.append(nn.Linear(in_dim, out_dim))
            if i < len(fc_dims) - 2:
                # Add BatchNorm1d layer for all but the last FC layer
                self.fc_batch_norms.append(nn.BatchNorm1d(out_dim))

    def forward(self, x):
        # x is of shape (B, n_H)
        # Apply fully connected layers with batch normalization
        for i, fc_layer in enumerate(self.fc_layers[:-1]):
            x = fc_layer(x)
            x = self.activation(x)
            x = self.fc_batch_norms[i](x)
            
        # Output layer without batch normalization and activation
        x = self.fc_layers[-1](x)

        return x

class fuel_autoencoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(fuel_autoencoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        latent = self.encoder(x)
        output_vector = self.decoder(latent)
        return output_vector

### Training loop

In [14]:
run_name = 'example_run'

#### Define and load the model from its configuration file

* note to self: this final model was trained for 25 + 10 + 10 epochs, annealing LR to 1e-3 -> 1e-4 after the 20

In [15]:
config = AttrDict({
    "conv_channels":[32, 32],
    "kernel_sizes":[3, 3,3],
    "strides":[1, 1,1],
    "paddings":[1, 1,1],
    "pooling":'max',
    "pool_kernels":[2, 2,2],
    "encoder_hidden_dims":[32,16],
    "latent_dim":3,
    "decoder_hidden_dims":[8,16,32],
    "weights_seed":123,
    "learning_rate":1e-3
})

model_config_file = 'fuel_autoencoder_config.py'
autoencoder,name = load(model_config_file,config)

with open(f'models/{run_name}.json','w') as json_file:
    json.dump(config,json_file)

Loading 'fuel_autoencoder_config' from /Users/gorg/Documents/ndws/fuel_embedding/fuel_autoencoder_config.py


In [16]:
autoencoder

fuel_autoencoder(
  (encoder): Encoder(
    (activation): ReLU()
    (conv_layers): ModuleList(
      (0): Conv2d(45, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (pool_layers): ModuleList(
      (0-1): 2 x MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (fc_layers): ModuleList(
      (0): Linear(in_features=512, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=16, bias=True)
      (2): Linear(in_features=16, out_features=3, bias=True)
    )
    (fc_batch_norms): ModuleList(
      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (decoder): Decoder(
    (activation): ReLU()
    (fc_layers): ModuleList(
      (0): Linear(in_features=3, out_features=8, bias=True)
      (1): Linear(in_features=8,

### Model testing

In [28]:
# Define the optimizer
autoencoder.load_state_dict(best_model_wts)

optimizer = optim.AdamW(autoencoder.parameters(), lr=1e-4)

# Define the loss function
criterion = nn.KLDivLoss(reduction='batchmean')

train_losses = []
val_losses = []

# Number of epochs
num_epochs = 10

# Device configuration
device = torch.device('mps')
autoencoder.to(device)

best_model_wts = copy.deepcopy(autoencoder.state_dict())
best_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            autoencoder.train()  # Set model to training mode
        else:
            autoencoder.eval()   # Set model to evaluate mode

        running_loss = 0.0

        # Iterate over data
        with tqdm(data_loaders[phase], unit="batch") as tepoch:
            for batch in tepoch:
                tepoch.set_description(f"{phase.capitalize()} Epoch {epoch+1}")

                inputs,targets = process_batch(batch,onehot = True)
                
                inputs = inputs.to(device)  # Shape: (B, 16, 16, 45)
                targets = targets.to(device)  # Shape: (B, 45)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = autoencoder(inputs)  # Shape: (B, 45)

                    # Apply log_softmax to outputs
                    log_probs = F.log_softmax(outputs, dim=-1)

                    # Compute loss
                    loss = criterion(log_probs, targets)

                    # Backward pass and optimization
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                tepoch.set_postfix(loss=loss.item())

        epoch_loss = running_loss / 911360 # batch 2048 * 445 batches = 911360 

        if phase == 'train':
            train_losses.append(epoch_loss)
        else:
            val_losses.append(epoch_loss)

            # Deep copy the model if it has better validation loss
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(autoencoder.state_dict())

        print(f'{phase.capitalize()} Loss: {epoch_loss:.4f}')

print('Training complete')
print(f'Best Validation Loss: {best_loss:.4f}')

# Load best model weights
autoencoder.load_state_dict(best_model_wts)

Epoch 1/10
----------


Train Epoch 1: : 445batch [10:22,  1.40s/batch, loss=0.161]


Train Loss: 0.1451


Val Epoch 1: : 134batch [02:59,  1.34s/batch, loss=0.152]


Val Loss: 0.0433
Epoch 2/10
----------


Train Epoch 2: : 445batch [10:16,  1.38s/batch, loss=0.161]


Train Loss: 0.1447


Val Epoch 2: : 134batch [03:01,  1.36s/batch, loss=0.151]


Val Loss: 0.0431
Epoch 3/10
----------


Train Epoch 3: : 445batch [10:03,  1.36s/batch, loss=0.161]


Train Loss: 0.1444


Val Epoch 3: : 134batch [02:57,  1.32s/batch, loss=0.151]


Val Loss: 0.0431
Epoch 4/10
----------


Train Epoch 4: : 445batch [10:07,  1.37s/batch, loss=0.16] 


Train Loss: 0.1441


Val Epoch 4: : 134batch [03:02,  1.36s/batch, loss=0.151]


Val Loss: 0.0429
Epoch 5/10
----------


Train Epoch 5: : 445batch [10:14,  1.38s/batch, loss=0.16] 


Train Loss: 0.1438


Val Epoch 5: : 134batch [02:58,  1.33s/batch, loss=0.15] 


Val Loss: 0.0429
Epoch 6/10
----------


Train Epoch 6: : 445batch [10:12,  1.38s/batch, loss=0.159]


Train Loss: 0.1435


Val Epoch 6: : 134batch [03:00,  1.35s/batch, loss=0.15] 


Val Loss: 0.0427
Epoch 7/10
----------


Train Epoch 7: : 445batch [10:05,  1.36s/batch, loss=0.159]


Train Loss: 0.1432


Val Epoch 7: : 134batch [02:59,  1.34s/batch, loss=0.149]


Val Loss: 0.0425
Epoch 8/10
----------


Train Epoch 8: : 445batch [10:08,  1.37s/batch, loss=0.159]


Train Loss: 0.1429


Val Epoch 8: : 134batch [03:00,  1.35s/batch, loss=0.149]


Val Loss: 0.0426
Epoch 9/10
----------


Train Epoch 9: : 445batch [10:06,  1.36s/batch, loss=0.158]


Train Loss: 0.1426


Val Epoch 9: : 134batch [03:00,  1.34s/batch, loss=0.149]


Val Loss: 0.0425
Epoch 10/10
----------


Train Epoch 10: : 445batch [10:05,  1.36s/batch, loss=0.158]


Train Loss: 0.1424


Val Epoch 10: : 134batch [02:57,  1.33s/batch, loss=0.149]

Val Loss: 0.0424
Training complete
Best Validation Loss: 0.0424





<All keys matched successfully>

In [29]:
torch.save(best_model_wts,f'models/{run_name}.pth')

In [17]:
plt.plot(train_losses,label = 'train')
plt.plot(val_losses,label = 'test')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.yscale('log')

NameError: name 'train_losses' is not defined