In [1]:
import os
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: /workspace/iscat


In [2]:
import h5py
import numpy as np
particle_data_path ='dataset/brightfield_particles.hdf5'
with h5py.File(particle_data_path , 'r') as f:
    print(f['data'].shape)
    print(np.unique(f['labels'],return_counts=True))

(27282, 16, 201)
(array([0, 1, 2]), array([22076,  5146,    60]))


In [3]:
weights = 1/np.array([22076,  5146])

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
# from torchvision.models import vit_b_16
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from torchvision.models.vision_transformer import VisionTransformer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, f1_score, balanced_accuracy_score
import matplotlib.pyplot as plt

def compute_normalization_stats(h5_path, classes=None):
    """
    Compute mean and standard deviation for z-score normalization.
    
    Args:
        h5_path (str): Path to HDF5 file
        classes (list, optional): List of classes to include in computation
        
    Returns:
        tuple: (mean, std) computed across all data points
    """
    with h5py.File(h5_path, 'r') as h5_file:
        data = h5_file['data'][:]
        labels = h5_file['labels'][:]
        
        if classes is not None:
            # Filter data for selected classes
            mask = np.isin(labels, classes)
            data = data[mask]
        
        # Compute statistics across all dimensions
        mean = np.mean(data)
        std = np.std(data)
        
        print(f"Computed statistics: mean = {mean:.4f}, std = {std:.4f}")
        
        return mean, std
        
class ParticleDataset(Dataset):
    """Custom Dataset for particle data with flexible class selection and normalization."""
    def __init__(self, h5_path, classes=[0, 1], transform=None, mean=None, std=None):
        self.h5_file = h5py.File(h5_path, 'r')
        data = self.h5_file['data'][:]
        labels = self.h5_file['labels'][:]
        
        # Filter data for selected classes
        mask = np.isin(labels, classes)
        self.data = data[mask]
        self.labels = labels[mask]
        
        # Create class mapping to handle non-consecutive class indices
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self.num_classes = len(classes)
        
        # Map original labels to new consecutive indices
        self.labels = np.array([self.class_to_idx[label] for label in self.labels])
        
        self.transform = transform
        self.mean = mean
        self.std = std
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Get particle data
        particle = self.data[idx]  # Shape: (16, 201)
        
        # Apply normalization if mean and std are provided
        if self.mean is not None and self.std is not None:
            particle = (particle - self.mean) / self.std
        
        # First resize to (16, 192) using cubic interpolation
        # Convert to torch tensor for better interpolation
        particle_tensor = torch.FloatTensor(particle).unsqueeze(0)  # Add channel dim
        resized = torch.nn.functional.interpolate(
            particle_tensor.unsqueeze(0),  # Add batch dim
            size=(16, 192),
            mode='bicubic',
            align_corners=True
        ).squeeze(0).squeeze(0)  # Remove batch and channel dims
        
        # Create square tensor with symmetric padding
        target_size = 192
        current_height = resized.shape[0]  # 16
        padding_height = target_size - current_height
        padding_top = padding_height // 2
        padding_bottom = padding_height - padding_top
        
        # Use torch's pad function for symmetric padding
        padded = torch.nn.functional.pad(
            resized,
            (0, 0,                # No padding in width dimension
             padding_top, padding_bottom),  # Padding in height dimension
            mode='constant',
            value=0
        )
        
        # Add single channel dimension
        final_tensor = padded.unsqueeze(0)
        
        if self.transform:
            final_tensor = self.transform(final_tensor)
        
        # Create one-hot encoded label
        label_idx = self.labels[idx]
        label_onehot = torch.zeros(self.num_classes)
        label_onehot[label_idx] = 1
        
        return final_tensor, label_onehot

    def close(self):
        self.h5_file.close()
        
class ModifiedViT(nn.Module):
    def __init__(self, num_classes=2,patch_size=16,num_layers=12,num_heads=12,hidden_dim =768,mlp_dim=3072):
        super().__init__()
        # Load pretrained ViT
        self.vit = VisionTransformer(image_size=192,
        patch_size=16,
        num_classes = num_classes,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
        
        self.vit.conv_proj = nn.Conv2d(
                    in_channels=1, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
                ) 
    def forward(self, x):
        return self.vit(x)
        
def train_model(model, train_loader, val_loader, device, num_epochs=50, weights=None, patience=20):
    criterion = nn.CrossEntropyLoss(weight=torch.Tensor(weights).to(device)) if weights is not None else nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_val_acc = 0.0
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    val_accuracies = []
    val_f1_scores = []
    val_bal_accs = []
    
    early_stopping_counter = 0
    
    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        total_train_loss = 0
        
        for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Training'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        # Validation Phase
        model.eval()
        val_preds = []
        val_true = []
        total_val_loss = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss = criterion(outputs, labels)
                total_val_loss += val_loss.item()
                
                predicted = torch.argmax(outputs, 1)
                labels = torch.argmax(labels, 1)
                val_preds.extend(predicted.cpu().numpy())
                val_true.extend(labels.cpu().numpy())
        
        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        # Calculate metrics
        val_acc = 100 * (sum([p == t for p, t in zip(val_preds, val_true)]) / len(val_true))
        f1 = f1_score(val_true, val_preds, average='weighted')
        bal_acc = balanced_accuracy_score(val_true, val_preds)
        
        # Confusion matrix metrics for class 0
        tp = sum((np.array(val_preds) == 0) & (np.array(val_true) == 0))
        tn = sum((np.array(val_preds) == 1) & (np.array(val_true) == 1))
        fp = sum((np.array(val_preds) == 0) & (np.array(val_true) == 1))
        fn = sum((np.array(val_preds) == 1) & (np.array(val_true) == 0))
        
        val_accuracies.append(val_acc)
        val_f1_scores.append(f1)
        val_bal_accs.append(bal_acc)
        
        # Print epoch details
        print(f'Epoch {epoch+1}:')
        print(f'  Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}')
        print(f'  Val Accuracy = {val_acc:.2f}%, F1-Score = {f1:.4f}, Balanced Acc = {bal_acc:.4f}')
        print(f'  Class 0 Metrics:')
        print(f'    True Positives: {tp}')
        print(f'    True Negatives: {tn}')
        print(f'    False Positives: {fp}')
        print(f'    False Negatives: {fn}')
        print(f'    Precision: {tp/(tp+fp) if (tp+fp) > 0 else 0:.4f}')
        print(f'    Recall: {tp/(tp+fn) if (tp+fn) > 0 else 0:.4f}')
        
        # Early Stopping and Model Saving
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            early_stopping_counter = 0
            torch.save(model.state_dict(), 'best_particle_vit.pth')
        
        # Check for improvement in validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        
        # Early stopping
        if early_stopping_counter >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs')
            break
        
        scheduler.step()
    
    return (train_losses, val_losses, val_accuracies, 
            val_f1_scores, val_bal_accs, val_preds, val_true)

def plot_training_results(train_losses, val_accuracies, val_f1_scores, val_bal_accs, val_preds, val_true):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
    
    ax1.plot(train_losses, label='Loss')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    
    ax2.plot(val_accuracies, label='Accuracy')
    ax2.set_title('Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    
    ax3.plot(val_f1_scores, label='F1-Score')
    ax3.plot(val_bal_accs, label='Balanced Accuracy')
    ax3.set_title('Validation Metrics')
    ax3.set_xlabel('Epoch')
    ax3.legend()
    
    plt.tight_layout()
    plt.savefig('training_plots.png')
    
    cm = confusion_matrix(val_true, val_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')


In [5]:
# Set random seed for reproducibility
torch.manual_seed(42)
DEVICE = "cuda:7"
# Device configuration
device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
classes = [0, 1]
mean, std = compute_normalization_stats('dataset/brightfield_particles.hdf5', classes=classes)
# Load and split dataset
dataset = ParticleDataset(h5_path='dataset/brightfield_particles.hdf5',
                          classes=classes,
                          mean=mean,
                          std=std
                         )
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)
_,c=np.unique(dataset.labels[train_dataset.indices],return_counts=True)
print(c[1]/(c[0]+c[1]))
_,c=np.unique(dataset.labels[val_dataset.indices],return_counts=True)
print(c[1]/(c[0]+c[1]))
# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=32,
    shuffle=True,
    num_workers=6,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=6,
    pin_memory=True
)

# Initialize model
model = ModifiedViT(num_classes=len(classes))
model = model.to(device)

# Train model
train_losses, val_accuracies, val_preds, val_true = train_model(
    model, train_loader, val_loader, device
)


Using device: cuda:7
Computed statistics: mean = 7077.7081, std = 1176.7393
0.1875832300133168
0.1948576675849403


Epoch 1/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.23it/s]


Epoch 1:
  Train Loss = 0.4275, Val Loss = 0.3457
  Val Accuracy = 83.49%, F1-Score = 0.7893, Balanced Acc = 0.5892
  Class 0 Metrics:
    True Positives: 4348
    True Negatives: 198
    False Positives: 863
    False Negatives: 36
    Precision: 0.8344
    Recall: 0.9918


Epoch 2/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:39<00:00,  4.26it/s]


Epoch 2:
  Train Loss = 0.3168, Val Loss = 0.2742
  Val Accuracy = 85.75%, F1-Score = 0.8667, Balanced Acc = 0.8618
  Class 0 Metrics:
    True Positives: 3747
    True Negatives: 922
    False Positives: 139
    False Negatives: 637
    Precision: 0.9642
    Recall: 0.8547


Epoch 3/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 3:
  Train Loss = 0.2209, Val Loss = 0.1614
  Val Accuracy = 94.14%, F1-Score = 0.9419, Balanced Acc = 0.9154
  Class 0 Metrics:
    True Positives: 4200
    True Negatives: 926
    False Positives: 135
    False Negatives: 184
    Precision: 0.9689
    Recall: 0.9580


Epoch 4/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 4:
  Train Loss = 0.1377, Val Loss = 0.1370
  Val Accuracy = 94.88%, F1-Score = 0.9472, Balanced Acc = 0.8910
  Class 0 Metrics:
    True Positives: 4321
    True Negatives: 845
    False Positives: 216
    False Negatives: 63
    Precision: 0.9524
    Recall: 0.9856


Epoch 5/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 5:
  Train Loss = 0.1261, Val Loss = 0.1314
  Val Accuracy = 94.93%, F1-Score = 0.9479, Balanced Acc = 0.8946
  Class 0 Metrics:
    True Positives: 4315
    True Negatives: 854
    False Positives: 207
    False Negatives: 69
    Precision: 0.9542
    Recall: 0.9843


Epoch 6/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 6:
  Train Loss = 0.1323, Val Loss = 0.1218
  Val Accuracy = 95.96%, F1-Score = 0.9596, Balanced Acc = 0.9356
  Class 0 Metrics:
    True Positives: 4274
    True Negatives: 951
    False Positives: 110
    False Negatives: 110
    Precision: 0.9749
    Recall: 0.9749


Epoch 7/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 7:
  Train Loss = 0.1386, Val Loss = 0.1318
  Val Accuracy = 95.28%, F1-Score = 0.9529, Balanced Acc = 0.9271
  Class 0 Metrics:
    True Positives: 4249
    True Negatives: 939
    False Positives: 122
    False Negatives: 135
    Precision: 0.9721
    Recall: 0.9692


Epoch 8/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 8:
  Train Loss = 0.1172, Val Loss = 0.1219
  Val Accuracy = 95.67%, F1-Score = 0.9572, Balanced Acc = 0.9452
  Class 0 Metrics:
    True Positives: 4226
    True Negatives: 983
    False Positives: 78
    False Negatives: 158
    Precision: 0.9819
    Recall: 0.9640


Epoch 9/50 Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 9:
  Train Loss = 0.1154, Val Loss = 0.1184
  Val Accuracy = 95.85%, F1-Score = 0.9580, Balanced Acc = 0.9228
  Class 0 Metrics:
    True Positives: 4302
    True Negatives: 917
    False Positives: 144
    False Negatives: 82
    Precision: 0.9676
    Recall: 0.9813


Epoch 10/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 10:
  Train Loss = 0.1126, Val Loss = 0.1251
  Val Accuracy = 95.50%, F1-Score = 0.9544, Balanced Acc = 0.9153
  Class 0 Metrics:
    True Positives: 4298
    True Negatives: 902
    False Positives: 159
    False Negatives: 86
    Precision: 0.9643
    Recall: 0.9804


Epoch 11/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 11:
  Train Loss = 0.1206, Val Loss = 0.1179
  Val Accuracy = 95.48%, F1-Score = 0.9556, Balanced Acc = 0.9477
  Class 0 Metrics:
    True Positives: 4206
    True Negatives: 993
    False Positives: 68
    False Negatives: 178
    Precision: 0.9841
    Recall: 0.9594


Epoch 12/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 12:
  Train Loss = 0.1033, Val Loss = 0.1161
  Val Accuracy = 95.52%, F1-Score = 0.9539, Balanced Acc = 0.9029
  Class 0 Metrics:
    True Positives: 4334
    True Negatives: 867
    False Positives: 194
    False Negatives: 50
    Precision: 0.9572
    Recall: 0.9886


Epoch 13/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 13:
  Train Loss = 0.1004, Val Loss = 0.1784
  Val Accuracy = 93.92%, F1-Score = 0.9357, Balanced Acc = 0.8547
  Class 0 Metrics:
    True Positives: 4354
    True Negatives: 760
    False Positives: 301
    False Negatives: 30
    Precision: 0.9353
    Recall: 0.9932


Epoch 14/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 14:
  Train Loss = 0.0950, Val Loss = 0.1192
  Val Accuracy = 95.48%, F1-Score = 0.9536, Balanced Acc = 0.9030
  Class 0 Metrics:
    True Positives: 4331
    True Negatives: 868
    False Positives: 193
    False Negatives: 53
    Precision: 0.9573
    Recall: 0.9879


Epoch 15/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.23it/s]


Epoch 15:
  Train Loss = 0.0954, Val Loss = 0.0934
  Val Accuracy = 96.82%, F1-Score = 0.9680, Balanced Acc = 0.9413
  Class 0 Metrics:
    True Positives: 4320
    True Negatives: 952
    False Positives: 109
    False Negatives: 64
    Precision: 0.9754
    Recall: 0.9854


Epoch 16/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.23it/s]


Epoch 16:
  Train Loss = 0.1086, Val Loss = 0.0989
  Val Accuracy = 96.73%, F1-Score = 0.9671, Balanced Acc = 0.9418
  Class 0 Metrics:
    True Positives: 4312
    True Negatives: 955
    False Positives: 106
    False Negatives: 72
    Precision: 0.9760
    Recall: 0.9836


Epoch 17/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 17:
  Train Loss = 0.1063, Val Loss = 0.1153
  Val Accuracy = 95.85%, F1-Score = 0.9577, Balanced Acc = 0.9156
  Class 0 Metrics:
    True Positives: 4322
    True Negatives: 897
    False Positives: 164
    False Negatives: 62
    Precision: 0.9634
    Recall: 0.9859


Epoch 18/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 18:
  Train Loss = 0.1055, Val Loss = 0.1090
  Val Accuracy = 96.07%, F1-Score = 0.9609, Balanced Acc = 0.9427
  Class 0 Metrics:
    True Positives: 4262
    True Negatives: 969
    False Positives: 92
    False Negatives: 122
    Precision: 0.9789
    Recall: 0.9722


Epoch 19/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 19:
  Train Loss = 0.0907, Val Loss = 0.1024
  Val Accuracy = 96.40%, F1-Score = 0.9643, Balanced Acc = 0.9523
  Class 0 Metrics:
    True Positives: 4259
    True Negatives: 990
    False Positives: 71
    False Negatives: 125
    Precision: 0.9836
    Recall: 0.9715


Epoch 20/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 20:
  Train Loss = 0.0869, Val Loss = 0.0921
  Val Accuracy = 96.71%, F1-Score = 0.9669, Balanced Acc = 0.9399
  Class 0 Metrics:
    True Positives: 4316
    True Negatives: 950
    False Positives: 111
    False Negatives: 68
    Precision: 0.9749
    Recall: 0.9845


Epoch 21/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 21:
  Train Loss = 0.0792, Val Loss = 0.0912
  Val Accuracy = 96.88%, F1-Score = 0.9684, Balanced Acc = 0.9395
  Class 0 Metrics:
    True Positives: 4329
    True Negatives: 946
    False Positives: 115
    False Negatives: 55
    Precision: 0.9741
    Recall: 0.9875


Epoch 22/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 22:
  Train Loss = 0.0764, Val Loss = 0.0927
  Val Accuracy = 96.55%, F1-Score = 0.9659, Balanced Acc = 0.9578
  Class 0 Metrics:
    True Positives: 4254
    True Negatives: 1003
    False Positives: 58
    False Negatives: 130
    Precision: 0.9865
    Recall: 0.9703


Epoch 23/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.22it/s]


Epoch 23:
  Train Loss = 0.0732, Val Loss = 0.0891
  Val Accuracy = 96.95%, F1-Score = 0.9697, Balanced Acc = 0.9578
  Class 0 Metrics:
    True Positives: 4283
    True Negatives: 996
    False Positives: 65
    False Negatives: 101
    Precision: 0.9851
    Recall: 0.9770


Epoch 24/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.22it/s]


Epoch 24:
  Train Loss = 0.0729, Val Loss = 0.0900
  Val Accuracy = 97.08%, F1-Score = 0.9706, Balanced Acc = 0.9458
  Class 0 Metrics:
    True Positives: 4326
    True Negatives: 960
    False Positives: 101
    False Negatives: 58
    Precision: 0.9772
    Recall: 0.9868


Epoch 25/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.23it/s]


Epoch 25:
  Train Loss = 0.0690, Val Loss = 0.0827
  Val Accuracy = 97.21%, F1-Score = 0.9719, Balanced Acc = 0.9491
  Class 0 Metrics:
    True Positives: 4326
    True Negatives: 967
    False Positives: 94
    False Negatives: 58
    Precision: 0.9787
    Recall: 0.9868


Epoch 26/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 26:
  Train Loss = 0.0663, Val Loss = 0.0867
  Val Accuracy = 97.19%, F1-Score = 0.9715, Balanced Acc = 0.9411
  Class 0 Metrics:
    True Positives: 4347
    True Negatives: 945
    False Positives: 116
    False Negatives: 37
    Precision: 0.9740
    Recall: 0.9916


Epoch 27/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 27:
  Train Loss = 0.0664, Val Loss = 0.0821
  Val Accuracy = 97.21%, F1-Score = 0.9720, Balanced Acc = 0.9509
  Class 0 Metrics:
    True Positives: 4321
    True Negatives: 972
    False Positives: 89
    False Negatives: 63
    Precision: 0.9798
    Recall: 0.9856


Epoch 28/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:41<00:00,  4.23it/s]


Epoch 28:
  Train Loss = 0.0628, Val Loss = 0.0804
  Val Accuracy = 97.39%, F1-Score = 0.9738, Balanced Acc = 0.9549
  Class 0 Metrics:
    True Positives: 4323
    True Negatives: 980
    False Positives: 81
    False Negatives: 61
    Precision: 0.9816
    Recall: 0.9861


Epoch 29/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.23it/s]


Epoch 29:
  Train Loss = 0.0602, Val Loss = 0.0778
  Val Accuracy = 97.34%, F1-Score = 0.9736, Balanced Acc = 0.9652
  Class 0 Metrics:
    True Positives: 4290
    True Negatives: 1010
    False Positives: 51
    False Negatives: 94
    Precision: 0.9883
    Recall: 0.9786


Epoch 30/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 30:
  Train Loss = 0.0570, Val Loss = 0.0787
  Val Accuracy = 97.17%, F1-Score = 0.9718, Balanced Acc = 0.9589
  Class 0 Metrics:
    True Positives: 4296
    True Negatives: 995
    False Positives: 66
    False Negatives: 88
    Precision: 0.9849
    Recall: 0.9799


Epoch 31/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 31:
  Train Loss = 0.0554, Val Loss = 0.0796
  Val Accuracy = 97.30%, F1-Score = 0.9727, Balanced Acc = 0.9472
  Class 0 Metrics:
    True Positives: 4338
    True Negatives: 960
    False Positives: 101
    False Negatives: 46
    Precision: 0.9772
    Recall: 0.9895


Epoch 32/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 32:
  Train Loss = 0.0522, Val Loss = 0.0728
  Val Accuracy = 97.69%, F1-Score = 0.9768, Balanced Acc = 0.9610
  Class 0 Metrics:
    True Positives: 4327
    True Negatives: 992
    False Positives: 69
    False Negatives: 57
    Precision: 0.9843
    Recall: 0.9870


Epoch 33/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 33:
  Train Loss = 0.0515, Val Loss = 0.0786
  Val Accuracy = 97.37%, F1-Score = 0.9739, Balanced Acc = 0.9662
  Class 0 Metrics:
    True Positives: 4290
    True Negatives: 1012
    False Positives: 49
    False Negatives: 94
    Precision: 0.9887
    Recall: 0.9786


Epoch 34/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 34:
  Train Loss = 0.0471, Val Loss = 0.0739
  Val Accuracy = 97.63%, F1-Score = 0.9763, Balanced Acc = 0.9621
  Class 0 Metrics:
    True Positives: 4320
    True Negatives: 996
    False Positives: 65
    False Negatives: 64
    Precision: 0.9852
    Recall: 0.9854


Epoch 35/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 35:
  Train Loss = 0.0459, Val Loss = 0.0752
  Val Accuracy = 97.67%, F1-Score = 0.9764, Balanced Acc = 0.9534
  Class 0 Metrics:
    True Positives: 4347
    True Negatives: 971
    False Positives: 90
    False Negatives: 37
    Precision: 0.9797
    Recall: 0.9916


Epoch 36/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 36:
  Train Loss = 0.0427, Val Loss = 0.0707
  Val Accuracy = 97.74%, F1-Score = 0.9774, Balanced Acc = 0.9631
  Class 0 Metrics:
    True Positives: 4325
    True Negatives: 997
    False Positives: 64
    False Negatives: 59
    Precision: 0.9854
    Recall: 0.9865


Epoch 37/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 37:
  Train Loss = 0.0403, Val Loss = 0.0705
  Val Accuracy = 97.70%, F1-Score = 0.9770, Balanced Acc = 0.9607
  Class 0 Metrics:
    True Positives: 4329
    True Negatives: 991
    False Positives: 70
    False Negatives: 55
    Precision: 0.9841
    Recall: 0.9875


Epoch 38/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.23it/s]


Epoch 38:
  Train Loss = 0.0388, Val Loss = 0.0799
  Val Accuracy = 97.61%, F1-Score = 0.9759, Balanced Acc = 0.9541
  Class 0 Metrics:
    True Positives: 4341
    True Negatives: 974
    False Positives: 87
    False Negatives: 43
    Precision: 0.9804
    Recall: 0.9902


Epoch 39/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 39:
  Train Loss = 0.0367, Val Loss = 0.0791
  Val Accuracy = 97.74%, F1-Score = 0.9773, Balanced Acc = 0.9599
  Class 0 Metrics:
    True Positives: 4334
    True Negatives: 988
    False Positives: 73
    False Negatives: 50
    Precision: 0.9834
    Recall: 0.9886


Epoch 40/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 40:
  Train Loss = 0.0348, Val Loss = 0.0769
  Val Accuracy = 97.78%, F1-Score = 0.9777, Balanced Acc = 0.9619
  Class 0 Metrics:
    True Positives: 4331
    True Negatives: 993
    False Positives: 68
    False Negatives: 53
    Precision: 0.9845
    Recall: 0.9879


Epoch 41/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 41:
  Train Loss = 0.0329, Val Loss = 0.0797
  Val Accuracy = 97.74%, F1-Score = 0.9772, Balanced Acc = 0.9556
  Class 0 Metrics:
    True Positives: 4346
    True Negatives: 976
    False Positives: 85
    False Negatives: 38
    Precision: 0.9808
    Recall: 0.9913


Epoch 42/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 42:
  Train Loss = 0.0308, Val Loss = 0.0809
  Val Accuracy = 97.67%, F1-Score = 0.9765, Balanced Acc = 0.9544
  Class 0 Metrics:
    True Positives: 4344
    True Negatives: 974
    False Positives: 87
    False Negatives: 40
    Precision: 0.9804
    Recall: 0.9909


Epoch 43/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 43:
  Train Loss = 0.0296, Val Loss = 0.0771
  Val Accuracy = 97.87%, F1-Score = 0.9786, Balanced Acc = 0.9625
  Class 0 Metrics:
    True Positives: 4336
    True Negatives: 993
    False Positives: 68
    False Negatives: 48
    Precision: 0.9846
    Recall: 0.9891


Epoch 44/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 44:
  Train Loss = 0.0280, Val Loss = 0.0832
  Val Accuracy = 97.76%, F1-Score = 0.9775, Balanced Acc = 0.9611
  Class 0 Metrics:
    True Positives: 4332
    True Negatives: 991
    False Positives: 70
    False Negatives: 52
    Precision: 0.9841
    Recall: 0.9881


Epoch 45/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 45:
  Train Loss = 0.0277, Val Loss = 0.0806
  Val Accuracy = 97.81%, F1-Score = 0.9781, Balanced Acc = 0.9654
  Class 0 Metrics:
    True Positives: 4324
    True Negatives: 1002
    False Positives: 59
    False Negatives: 60
    Precision: 0.9865
    Recall: 0.9863


Epoch 46/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 46:
  Train Loss = 0.0264, Val Loss = 0.0796
  Val Accuracy = 97.87%, F1-Score = 0.9787, Balanced Acc = 0.9650
  Class 0 Metrics:
    True Positives: 4329
    True Negatives: 1000
    False Positives: 61
    False Negatives: 55
    Precision: 0.9861
    Recall: 0.9875


Epoch 47/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 47:
  Train Loss = 0.0253, Val Loss = 0.0809
  Val Accuracy = 97.85%, F1-Score = 0.9785, Balanced Acc = 0.9634
  Class 0 Metrics:
    True Positives: 4332
    True Negatives: 996
    False Positives: 65
    False Negatives: 52
    Precision: 0.9852
    Recall: 0.9881


Epoch 48/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.25it/s]


Epoch 48:
  Train Loss = 0.0247, Val Loss = 0.0807
  Val Accuracy = 97.85%, F1-Score = 0.9785, Balanced Acc = 0.9649
  Class 0 Metrics:
    True Positives: 4328
    True Negatives: 1000
    False Positives: 61
    False Negatives: 56
    Precision: 0.9861
    Recall: 0.9872


Epoch 49/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 49:
  Train Loss = 0.0244, Val Loss = 0.0808
  Val Accuracy = 97.89%, F1-Score = 0.9789, Balanced Acc = 0.9651
  Class 0 Metrics:
    True Positives: 4330
    True Negatives: 1000
    False Positives: 61
    False Negatives: 54
    Precision: 0.9861
    Recall: 0.9877


Epoch 50/50 Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 681/681 [02:40<00:00,  4.24it/s]


Epoch 50:
  Train Loss = 0.0241, Val Loss = 0.0809
  Val Accuracy = 97.89%, F1-Score = 0.9789, Balanced Acc = 0.9651
  Class 0 Metrics:
    True Positives: 4330
    True Negatives: 1000
    False Positives: 61
    False Negatives: 54
    Precision: 0.9861
    Recall: 0.9877


ValueError: too many values to unpack (expected 4)

In [6]:
train_losses, val_accuracies, val_preds, val_true

NameError: name 'train_losses' is not defined

In [None]:
# Plot results
plot_training_results(train_losses, val_accuracies, val_preds, val_true)

# Close HDF5 file
dataset.close()

In [None]:
dataset = ParticleDataset('dataset/brightfield_particles.hdf5')

In [None]:
from torchvision.transforms import Resize
import torch
import matplotlib.pyplot as plt
with h5py.File('dataset/brightfield_particles.hdf5', 'r') as f:
    particle_1 = f['data'][f['labels'][:]==0,:][4]
    particle_2 = f['data'][f['labels'][:]==1,:][20]
img = torch.Tensor(particle_1)
print(img.shape)
img = Resize((16,192))(img.unsqueeze(0))
img = img.squeeze(0)
fig, ax = plt.subplots(1, 1, figsize=(30, 5), dpi=300) 
plt.axis('off')
plt.imshow(img,cmap='gray')
plt.show()

In [None]:
img = torch.Tensor(particle_2)
img = Resize((16,192))(img.unsqueeze(0))
img = img.squeeze(0)
fig, ax = plt.subplots(1, 1, figsize=(30, 5), dpi=300) 
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.show()