In [1]:
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from data_manager import create_modified_crop_labels, filter_balanced_patches, setup_training_loader
from tqdm import tqdm

In [2]:
# Reproducibility: seed the RNGs *before* any random weight initialisation
# (the replacement conv1 below is randomly initialised).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of per-pixel input bands fed to the network.
# NOTE(review): presumably the feature bands of each patch (the loaders are built
# with crop_band_index=18, i.e. the crop band follows bands 0..17) — confirm.
NUM_INPUT_CHANNELS = 18

# NOTE(review): `weights` is defined but NOT passed to the constructor below, so no
# segmentation-pretrained weights are loaded. Passing it would also require the
# model's class count to match those weights rather than num_classes=2.
weights = deeplabv3.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = deeplabv3_resnet50(num_classes=2)

# Replace the backbone's stem convolution so it accepts NUM_INPUT_CHANNELS bands
# instead of original_conv.in_channels; every other conv hyper-parameter is copied
# from the original so the rest of the backbone sees the same spatial shapes.
original_conv = model.backbone.conv1
new_conv = torch.nn.Conv2d(
    in_channels=NUM_INPUT_CHANNELS,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    dilation=original_conv.dilation,
    groups=original_conv.groups,
    bias=original_conv.bias is not None,
)

model.backbone.conv1 = new_conv


In [3]:
TARGET_CROP = -1  # The crop ID we're training to detect
UNCHANGED_CROPS = [1, 5, 23, 176]  # List of unchanged crops

# Settings shared by the train / val / test loaders. They are forwarded unchanged
# to data_manager.setup_training_loader; the meanings noted here are inferred from
# the parameter names — confirm against data_manager.
DATA_DIR = './training_data'
BATCH_SIZE = 16
CROP_BAND_INDEX = 18   # forwarded as crop_band_index (band holding crop IDs, per the name)
MIN_RATIO = 0.1        # forwarded as min_ratio (lower bound of the patch filter)
MAX_RATIO = 0.9        # forwarded as max_ratio (upper bound of the patch filter)
# Previously hard-coded to 'cuda'; fall back to CPU so the notebook also runs on
# machines without a GPU (identical behaviour whenever CUDA is available).
LOADER_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def make_split_loader(split):
    """Build the data loader for one split of the dataset.

    Args:
        split: file-name prefix of the split, e.g. ``'train'``, ``'val'`` or
            ``'test'``; the data is read from ``<DATA_DIR>/<split>_patches.npy``.

    Returns:
        Whatever ``setup_training_loader`` returns; the training cells iterate it
        as ``(images, labels)`` batches.

    NOTE(review): val/test go through the same min/max-ratio patch filter as the
    training split (the captured output shows patches being filtered for all
    three), so held-out metrics only cover the filtered patches — confirm that
    is intended.
    """
    return setup_training_loader(
        path_to_train_data=os.path.join(DATA_DIR, f'{split}_patches.npy'),
        unchanged_crops=UNCHANGED_CROPS,
        target_crops=[TARGET_CROP],
        train_batch_size=BATCH_SIZE,
        crop_band_index=CROP_BAND_INDEX,
        device=LOADER_DEVICE,
        ignore_crops=None,
        min_ratio=MIN_RATIO,
        max_ratio=MAX_RATIO,
    )


train_loader = make_split_loader('train')
val_loader = make_split_loader('val')
test_loader = make_split_loader('test')


Filtered 1074 patches to 951 good patches (88.55%)
Dataset loaded with 951 patches
Total pixels: 47717376
Positive pixels (+1): 18198112
Negative pixels (-1): 29519264
Filtered 231 patches to 207 good patches (89.61%)
Dataset loaded with 207 patches
Total pixels: 10386432
Positive pixels (+1): 3942491
Negative pixels (-1): 6443941
Filtered 231 patches to 209 good patches (90.48%)
Dataset loaded with 209 patches
Total pixels: 10486784
Positive pixels (+1): 4125291
Negative pixels (-1): 6361493


In [21]:
# The loaders yield labels encoded as -1 / +1, but CrossEntropyLoss expects class
# indices, so they are remapped to 0 / 1 before use.
def transform_labels(labels):
    """Remap -1/+1 labels to 0/1 class indices.

    The tensor is shifted by one, halved with true division and truncated toward
    zero by ``.long()``: -1 -> 0 and +1 -> 1 (any other value follows the same
    arithmetic, e.g. 0 -> 0).

    Returns:
        A new int64 tensor with the same shape as ``labels``.
    """
    shifted = labels + 1
    halved = shifted / 2
    return halved.long()

# --- Training setup ----------------------------------------------------------
# Optimisation hyper-parameters.
LEARNING_RATE = 0.001
LR_DECAY_EVERY = 20    # scheduler steps (epochs) between learning-rate decays
LR_DECAY_FACTOR = 0.5  # multiplier applied to the learning rate at each decay

# Run on the GPU when one is visible, otherwise on the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')
model = model.to(device)

# Per-pixel cross-entropy over the model's class logits, optimised with Adam
# and a step-decay learning-rate schedule.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_DECAY_EVERY,
    gamma=LR_DECAY_FACTOR,
)

# Function to calculate precision, recall, and F1-score for binary classification
def calculate_metrics(outputs, labels):
    _, predicted = torch.max(outputs, 1)
    true_positive = ((predicted == 1) & (labels == 1)).sum().item()
    false_positive = ((predicted == 1) & (labels == 0)).sum().item()
    false_negative = ((predicted == 0) & (labels == 1)).sum().item()
    true_negative = ((predicted == 0) & (labels == 0)).sum().item()
    
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    accuracy = (true_positive + true_negative) / (true_positive + true_negative + false_positive + false_negative)
    
    return accuracy, precision, recall, f1

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    batches = 0
    
    # Add tqdm progress bar
    pbar = tqdm(train_loader, desc='Training')
    for images, labels in pbar:
        # Move data to device
        images = images.permute(0, 3, 1, 2).to(device)  # Change to (B, C, H, W)
        labels = transform_labels(labels).to(device)
        
        # Forward pass
        outputs = model(images)['out']
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate metrics
        accuracy, precision, recall, f1 = calculate_metrics(outputs, labels)
        
        total_loss += loss.item()
        total_accuracy += accuracy
        total_precision += precision
        total_recall += recall
        total_f1 += f1
        batches += 1
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{accuracy:.4f}',
            'f1': f'{f1:.4f}'
        })
    
    return (total_loss / batches, total_accuracy / batches, 
            total_precision / batches, total_recall / batches, total_f1 / batches)

# Validation function
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    batches = 0
    
    # Add tqdm progress bar
    pbar = tqdm(val_loader, desc='Validation')
    with torch.no_grad():
        for images, labels in pbar:
            # Move data to device
            images = images.permute(0, 3, 1, 2).to(device)  # Change to (B, C, H, W)
            labels = transform_labels(labels).to(device)
            
            # Forward pass
            outputs = model(images)['out']
            
            # Calculate loss
            loss = criterion(outputs, labels)
            
            # Calculate metrics
            accuracy, precision, recall, f1 = calculate_metrics(outputs, labels)
            
            total_loss += loss.item()
            total_accuracy += accuracy
            total_precision += precision
            total_recall += recall
            total_f1 += f1
            batches += 1
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{accuracy:.4f}',
                'f1': f'{f1:.4f}'
            })
    
    return (total_loss / batches, total_accuracy / batches, 
            total_precision / batches, total_recall / batches, total_f1 / batches)

In [None]:
# --- Training loop -----------------------------------------------------------
num_epochs = 100
best_model_path = 'best_model_binary.pth'
# Start below any achievable F1 so the first epoch always writes a checkpoint.
# Starting at 0.0 meant a run whose F1 never rose above 0 saved nothing, and the
# load after the loop would then crash or pick up a stale file from an earlier run.
best_val_f1 = float('-inf')

epoch_pbar = tqdm(range(num_epochs), desc='Epochs')
for epoch in epoch_pbar:
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_prec, val_rec, val_f1 = validate(model, val_loader, criterion, device)

    # StepLR counts epochs, so step once per epoch after the optimiser updates.
    scheduler.step()

    epoch_pbar.set_postfix({
        'train_loss': f'{train_loss:.4f}',
        'train_acc': f'{train_acc:.4f}',
        'train_f1': f'{train_f1:.4f}',
        'val_loss': f'{val_loss:.4f}',
        'val_acc': f'{val_acc:.4f}',
        'val_f1': f'{val_f1:.4f}'
    })

    # Checkpoint the model with the best validation F1 seen so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_model_path)
        # tqdm.write prints above the active bars instead of tearing them.
        tqdm.write(f'\nNew best model saved with validation F1-score: {val_f1:.4f}')
        tqdm.write(f'Validation metrics - Accuracy: {val_acc:.4f}, Precision: {val_prec:.4f}, Recall: {val_rec:.4f}')

# Restore the best checkpoint; map_location keeps the load valid if the file was
# written on a different device than the one currently in use.
model.load_state_dict(torch.load(best_model_path, map_location=device))

# Final, held-out evaluation of the selected model.
test_loss, test_acc, test_prec, test_rec, test_f1 = validate(model, test_loader, criterion, device)
print('\nTest Results:')
print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')
print(f'Test Precision: {test_prec:.4f}')
print(f'Test Recall: {test_rec:.4f}')
print(f'Test F1-score: {test_f1:.4f}')

Epochs:   0%|                 | 0/100 [00:00<?, ?it/s]
Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.6835, a[A
Training:   3%| | 1/32 [00:02<01:19,  2.56s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:19,  2.56s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:11,  2.38s/it, loss=0[A
Training:   6%| | 2/32 [00:07<01:11,  2.38s/it, loss=1[A
Training:   9%| | 3/32 [00:07<01:07,  2.32s/it, loss=1[A
Training:   9%| | 3/32 [00:09<01:07,  2.32s/it, loss=1[A
Training:  12%|▏| 4/32 [00:09<01:04,  2.29s/it, loss=1[A
Training:  12%|▏| 4/32 [00:11<01:04,  2.29s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:01,  2.27s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:01,  2.27s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.26s/it, loss=0[A
Training:  19%|▏| 6/32 [00:16<00:58,  2.26s/it, loss=0[A
Training:  22%|▏| 7/32 [00:16<00:56,  2.26s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.26s/it, loss=0[A
Training:  25%|▎|


New best model saved with validation F1-score: 0.0068
Validation metrics - Accuracy: 0.6783, Precision: 0.4296, Recall: 0.0034



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.5006, a[A
Training:   3%| | 1/32 [00:02<01:09,  2.24s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:09,  2.24s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.24s/it, loss=0[A
Training:   9%| | 3/32 [00:08<01:05,  2.24s/it, loss=0[A
Training:  12%|▏| 4/32 [00:08<01:02,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:02,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:17<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:17<00:53,  2.25s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.7720
Validation metrics - Accuracy: 0.8465, Precision: 0.7329, Recall: 0.8166



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.3739, a[A
Training:   3%| | 1/32 [00:02<01:09,  2.26s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:09,  2.26s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.25s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.8141
Validation metrics - Accuracy: 0.8681, Precision: 0.7382, Recall: 0.9084



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.2813, a[A
Training:   3%| | 1/32 [00:02<01:10,  2.26s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:10,  2.26s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.26s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.26s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.25s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.8498
Validation metrics - Accuracy: 0.9039, Precision: 0.8416, Recall: 0.8584



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.2484, a[A
Training:   3%| | 1/32 [00:02<01:09,  2.25s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:09,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.25s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.8519
Validation metrics - Accuracy: 0.9022, Precision: 0.8209, Recall: 0.8856



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.2195, a[A
Training:   3%| | 1/32 [00:02<01:09,  2.25s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:09,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.26s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.8681
Validation metrics - Accuracy: 0.9163, Precision: 0.8680, Recall: 0.8683



Training:   0%|                | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.2139, a[A
Training:   3%| | 1/32 [00:02<01:09,  2.26s/it, loss=0[A
Training:   3%| | 1/32 [00:04<01:09,  2.26s/it, loss=0[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it, loss=0[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.25s/it, loss=0[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.25s/it, loss=0[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.25s/it, loss=0[A
Training:  25


New best model saved with validation F1-score: 0.8748
Validation metrics - Accuracy: 0.9216, Precision: 0.8890, Recall: 0.8612



Training:   0%|                               | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.1934, acc=0.9252, f1=0[A
Training:   3%| | 1/32 [00:02<01:09,  2.26s/it, loss=0.1934, acc=0.92[A
Training:   3%| | 1/32 [00:04<01:09,  2.26s/it, loss=0.2109, acc=0.91[A
Training:   6%| | 2/32 [00:04<01:07,  2.26s/it, loss=0.2109, acc=0.91[A
Training:   6%| | 2/32 [00:06<01:07,  2.26s/it, loss=0.2135, acc=0.91[A
Training:   9%| | 3/32 [00:06<01:05,  2.26s/it, loss=0.2135, acc=0.91[A
Training:   9%| | 3/32 [00:09<01:05,  2.26s/it, loss=0.1928, acc=0.92[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0.1928, acc=0.92[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0.1893, acc=0.92[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0.1893, acc=0.92[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0.1919, acc=0.92[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0.1919, acc=0.92[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, lo


New best model saved with validation F1-score: 0.8769
Validation metrics - Accuracy: 0.9242, Precision: 0.9080, Recall: 0.8479



Training:   0%|                               | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.1806, acc=0.9299, f1=0[A
Training:   3%| | 1/32 [00:02<01:09,  2.25s/it, loss=0.1806, acc=0.92[A
Training:   3%| | 1/32 [00:04<01:09,  2.25s/it, loss=0.1943, acc=0.92[A
Training:   6%| | 2/32 [00:04<01:07,  2.26s/it, loss=0.1943, acc=0.92[A
Training:   6%| | 2/32 [00:06<01:07,  2.26s/it, loss=0.1957, acc=0.92[A
Training:   9%| | 3/32 [00:06<01:05,  2.25s/it, loss=0.1957, acc=0.92[A
Training:   9%| | 3/32 [00:09<01:05,  2.25s/it, loss=0.1808, acc=0.92[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it, loss=0.1808, acc=0.92[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it, loss=0.1770, acc=0.93[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it, loss=0.1770, acc=0.93[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it, loss=0.1725, acc=0.93[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.25s/it, loss=0.1725, acc=0.93[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.25s/it, lo


New best model saved with validation F1-score: 0.8819
Validation metrics - Accuracy: 0.9266, Precision: 0.9035, Recall: 0.8613



Training:   0%|         | 0/32 [00:00<?, ?it/s][A
Training:   0%| | 0/32 [00:02<?, ?it/s, loss=0.[A
Training:   3%| | 1/32 [00:02<01:09,  2.25s/it,[A
Training:   3%| | 1/32 [00:04<01:09,  2.25s/it,[A
Training:   6%| | 2/32 [00:04<01:07,  2.25s/it,[A
Training:   6%| | 2/32 [00:06<01:07,  2.25s/it,[A
Training:   9%| | 3/32 [00:06<01:05,  2.26s/it,[A
Training:   9%| | 3/32 [00:09<01:05,  2.26s/it,[A
Training:  12%|▏| 4/32 [00:09<01:03,  2.25s/it,[A
Training:  12%|▏| 4/32 [00:11<01:03,  2.25s/it,[A
Training:  16%|▏| 5/32 [00:11<01:00,  2.25s/it,[A
Training:  16%|▏| 5/32 [00:13<01:00,  2.25s/it,[A
Training:  19%|▏| 6/32 [00:13<00:58,  2.26s/it,[A
Training:  19%|▏| 6/32 [00:15<00:58,  2.26s/it,[A
Training:  22%|▏| 7/32 [00:15<00:56,  2.26s/it,[A
Training:  22%|▏| 7/32 [00:18<00:56,  2.26s/it,[A
Training:  25%|▎| 8/32 [00:18<00:54,  2.25s/it,[A
Training:  25%|▎| 8/32 [00:20<00:54,  2.25s/it,[A
Training:  28%|▎| 9/32 [00:20<00:51,  2.25s/it,[A
Training:  28%|▎| 9/32 [00:22<