# Example 4: Model Training
First we import the necessary modules

In [34]:
import torch
from torch.utils.data import DataLoader, random_split
from pythermondt import transforms as T
from pythermondt.data import ThermoDataset, DataContainer
from pythermondt.readers import S3Reader
from example_models.defect_classifier import DefectClassifier3DCNN

First we define some general parameters for the model training

In [35]:
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 1
batch_size = 2
learning_rate = 1e-5

Now we can the define all the datasources and the transform pipeline used for training. Here you could specifiy multiple datasources (which are later combined usign the dataset) if you need that. In this example we only use one datasource.

**Note**: For the S3Reader object we set the cache_files flag to true. Therefore all the files are cached to a folder (.pyThermoNDT_cache) in the current working directory. This makes training way faster, because the files are now only downloaded once and not every time the datasource is loaded. 

In [36]:
# Specifiy the datasource
s3reader = S3Reader(
    source='s3://ffg-bp/example4_model_training/.hdf5',
    cache_files=True
)

# Setup transform pipeline
pipeline = T.Compose([
    T.ApplyLUT(),
    T.MinMaxNormalize(),
    T.SubstractFrame(),
])

Now we can combine these datasources and the transform pipeline by creating a dataset.

In [37]:
# Create a Dataset
dataset = ThermoDataset(data_source=s3reader, transform=pipeline)

Before we can continue, we first need to write a custom collate function. All our readers and datasets always load data in form of Datacontainer objects. However, when training a model the input data needs to be in form of a tensor. Therefore the collate function extracts the data from all the Datacontainer objects in the current batch and stacks them along the batch dimension:

In [38]:
# Custom collate function to extract data and target from the DataContainers in the batch
def collate_fn(batch: list[DataContainer]) -> tuple[torch.Tensor, torch.Tensor]:
    # Extract data and target from the DataContainers
    tdata = []
    mask = []

    # Extract the data and label from the DataContainer
    for container in batch:
        tdata.append(container.get_dataset("/Data/Tdata").unsqueeze(0)[:,:,0:499])
        mask.append(torch.tensor([0, 1]) if container.get_dataset("/GroundTruth/DefectMask").equal(torch.zeros(100,100)) else torch.tensor([1, 0]))

    # Stack the tensors along the batch dimension
    data = torch.stack(tdata).to(device=device, dtype=torch.float32)
    label = torch.stack(mask).to(device=device, dtype=torch.float32)

    return data, label

Now we can split the dataset into a training and validation dataset and create the dataloaders for both datasets. 

In [39]:
# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create a DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

Before we can start the training we also need to define the model, the loss function and the optimizer:

In [40]:
# Define the model and optimizer
model = DefectClassifier3DCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

Now we can run the training loop with the following code. The training loop is very simple and only consists of a few lines of code. For real world applications you might want to add more features like logging, early stopping, learning rate scheduling, etc.

**Note:** The training loop is stopped after 30 batches and only runs 1 epoch for demonstration purposes! 

In [41]:
# Training loop
for epoch in range(epochs):
    print(f"Epoch {epoch}")

    # Set the model to training mode
    model.train()
    print("Training:")
    for batch_idx, (data, label) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Compute the loss
        loss = loss_fn(output, label)

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

        # Print the loss
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item()}")
        
        # Stop after 30 batches
        if batch_idx == 30:
            break

    # Evaluate the model
    with torch.no_grad():
        model.eval()
        val_loss_summed = 0
        print("Validation:")
        for batch_idx, (data, label) in enumerate(val_loader):
            # Forward pass
            output = model(data)

            # Compute the loss
            loss = loss_fn(output, label)
            val_loss_summed += loss.item()

            # Stop after 30 batches
            if batch_idx == 30:
                break
        
        
        print(f"Validation Loss: {val_loss_summed / len(val_loader)}")

Epoch 0
Training:
Batch 0, Loss: 0.6381669044494629
Batch 10, Loss: 0.7491753101348877
Batch 20, Loss: 0.696014404296875
Batch 30, Loss: 0.6932556629180908
Validation:
Validation Loss: 0.14315616210301718
