# Example 4: Model Training
First we import the necessary modules

In [11]:
import torch
from torch.utils.data import DataLoader
from pythermondt import transforms as T
from pythermondt.data import ThermoDataset, DataContainer, random_split
from pythermondt.readers import S3Reader
from example_models.defect_classifier import DefectClassifier3DCNN

Next we define some general parameters for the model training (device, number of epochs, batch size and learning rate).

In [12]:
# Seed torch's global RNG so model weight init and DataLoader shuffling are reproducible across runs
seed = 42
torch.manual_seed(seed)

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 1
batch_size = 2
learning_rate = 1e-5

Now we can define all the datasources used for training. Here you could specify multiple datasources (which are later combined by the dataset) if you need that. In this example we only use one datasource.

**Note:** For the S3Reader object we set the `cache_files` flag to `True`. All files are therefore cached in a folder (`.pyThermoNDT_cache`) in the current working directory. This makes training much faster, because the files are only downloaded once instead of every time the datasource is loaded.

In [13]:
# Specify the remote datasource; cache_files=True keeps a local copy so later runs skip the download
s3reader = S3Reader(
    source="s3://ffg-bp/example4_model_training/.hdf5",
    cache_files=True,
)

Downloading Files for S3Reader(source=s3://ffg-bp/example4_model_training/.hdf5):   0%|          | 0/195 [00:0…

Now we can combine these datasources by creating a dataset.

In [14]:
# Wrap the S3 datasource in a ThermoDataset (further datasources could be combined here)
dataset = ThermoDataset(data_source=s3reader)

Afterwards we define the transform pipelines that are applied to the data before it is fed into the model. In this example we use data augmentation techniques such as randomly flipping the images or adding Gaussian noise to them (to simulate the noise-equivalent temperature difference, NETD, of the camera). Therefore we need two different pipelines: one for the training set (with augmentation) and one for the test set (without augmentation).

In [15]:
def build_pipeline(augment: bool) -> T.Compose:
    """Build the transform pipeline for one subset.

    Both pipelines share the same deterministic preprocessing. The training pipeline
    additionally inserts the augmentation steps (noise + random flips) right after the LUT.

    NOTE(review): NonUniformSampling(64) presumably reduces each sequence to 64 frames,
    while collate_fn truncates to 499 frames and the recorded RuntimeError (model expects
    285696 features) suggests the model was built for many more frames — confirm the frame
    count DefectClassifier3DCNN expects.
    """
    steps = [T.ApplyLUT()]
    if augment:
        # Data augmentation, training only
        steps += [
            T.GaussianNoise(std=1e-3),
            T.RandomFlip(p_height=0.3, p_width=0.3),
        ]
    steps += [
        T.SubstractFrame(0),
        T.RemoveFlash(method='excitation_signal'),
        T.NonUniformSampling(64),
        T.MinMaxNormalize(),
    ]
    return T.Compose(steps)


# Training set: preprocessing + augmentation; test set: preprocessing only
train_pipeline = build_pipeline(augment=True)
test_pipeline = build_pipeline(augment=False)

Before we can continue, we need to write a custom collate function. All our readers and datasets always load data in the form of DataContainer objects. However, when training a model the input data needs to be in the form of a tensor. Therefore the collate function extracts the data (and a label derived from the defect mask) from all DataContainer objects in the current batch and stacks them along the batch dimension:

In [16]:
# Custom collate function to extract data and target from the DataContainers in the batch
def collate_fn(batch: list[DataContainer]) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a batch of DataContainers into ``(data, label)`` tensors for the model.

    Args:
        batch: DataContainers produced by the dataset (after the transform pipeline).

    Returns:
        data: The ``/Data/Tdata`` tensors, each truncated to at most 499 entries along
            its third axis, given a leading channel dimension and stacked along the
            batch dimension; float32 on ``device``.
        label: One-hot targets of shape ``(batch, 2)``, float32 on ``device``:
            ``[1, 0]`` if the sample's defect mask has any non-zero entry,
            ``[0, 1]`` if the mask is all zeros.
    """
    tdata = []
    labels = []

    for container in batch:
        # Truncate the third Tdata axis BEFORE adding the channel dim: slicing after
        # unsqueeze(0) shifted every axis by one and cut the wrong dimension.
        # NOTE(review): assumes Tdata is (height, width, frames) — confirm.
        frames = container.get_dataset("/Data/Tdata")[:, :, :499]
        tdata.append(frames.unsqueeze(0))

        # Any non-zero pixel marks a defect. any() works for every mask size and dtype,
        # unlike comparing against a hard-coded 100x100 float zero tensor.
        defect_mask = container.get_dataset("/GroundTruth/DefectMask")
        has_defect = bool(defect_mask.any())
        labels.append(torch.tensor([1, 0]) if has_defect else torch.tensor([0, 1]))

    # Stack the tensors along the batch dimension
    data = torch.stack(tdata).to(device=device, dtype=torch.float32)
    label = torch.stack(labels).to(device=device, dtype=torch.float32)

    return data, label

Next we can split the dataset into a training and validation subset, using the random_split function provided with pyThermoNDT. Afterwards the dataloaders for each of the subsets are created.

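**Note:** Each subset gets its own transformation pipeline: the training subset uses the augmenting `train_pipeline`, the test subset the deterministic `test_pipeline`. If you don't need augmentation, you could also pass the same pipeline for both subsets.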

In [17]:
# Split the dataset into train (80 %) and test (20 %) subsets; the third argument pairs
# each subset with its transform pipeline
train_set, test_set = random_split(dataset, [0.8, 0.2], [train_pipeline, test_pipeline])

# Print the length of the subsets
print(f"Train set length: {len(train_set)}")
print(f"Test set length: {len(test_set)}")

# Create the DataLoaders; use the configured batch_size instead of a hard-coded 2
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

Train set length: 672
Test set length: 168


Before we can start the training we also need to define the model, the loss function and the optimizer:

In [18]:
# Define the model, optimizer and loss function
model = DefectClassifier3DCNN().to(device)
# Use the configured learning_rate (previously hard-coded to 0.001, silently ignoring the config value)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# CrossEntropyLoss also accepts float class-probability targets such as the one-hot labels built in collate_fn
loss_fn = torch.nn.CrossEntropyLoss()

Now we can run the training loop with the following code. The training loop is very simple and only consists of a few lines of code. For real world applications you might want to add more features like logging, early stopping, learning rate scheduling, etc.

**Note:** For demonstration purposes only 1 epoch is run, and both the training and the validation loop are stopped after 30 batches!

In [19]:
# Training loop
# Demo limit: process at most this many batches per phase (set to None for full epochs)
max_batches = 30

for epoch in range(epochs):
    print(f"Epoch {epoch}")

    # --- Training phase ---
    model.train()
    print("Training:")
    for batch_idx, (data, label) in enumerate(train_loader):
        # Zero the gradients, forward pass, loss, backward pass, weight update
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()

        # Print the loss
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item()}")

        # batch_idx is 0-based: the old `batch_idx == 30` check ran 31 batches
        if max_batches is not None and batch_idx + 1 >= max_batches:
            break

    # --- Validation phase ---
    model.eval()
    val_loss_summed = 0.0
    val_batches = 0
    print("Validation:")
    with torch.no_grad():
        for data, label in val_loader:
            output = model(data)
            loss = loss_fn(output, label)
            val_loss_summed += loss.item()
            val_batches += 1

            if max_batches is not None and val_batches >= max_batches:
                break

    # Average over the batches actually evaluated. Dividing by len(val_loader)
    # underestimated the loss whenever the loop stopped early.
    print(f"Validation Loss: {val_loss_summed / max(val_batches, 1)}")

Epoch 0
Training:


RuntimeError: shape '[-1, 285696]' is invalid for input of size 36864