In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import v2 as transforms
from torch.utils.data import DataLoader
import time
from sklearn import metrics
import wandb

In [4]:
## Import dataset modules and model
from dataset_classification_vindr import MakeDataset_VinDr_classification
from models.mvswintransformer import MVSwinTransformer

In [3]:
## Devices
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Available GPUs: ", torch.cuda.device_count())
print("Current device ID: ", torch.cuda.current_device())

Available GPUs:  2
Current device ID:  0
NVIDIA A30
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
NVIDIA A30
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [5]:
## Configuration
extension = ".png"
target_size = 384
window_size = 12

In [6]:
# Define hyperparameters
batch_size = 16
learning_rate = 1e-4
epochs = 2
threshold = 0.5
weight_decay = 1e-3

In [8]:
## Data Paths
image_dir = "./dataset/VinDr_Mammo/Images_Processed_CLAHE"
label_dir_csv ="./dataset/VinDr_Mammo/breast-level_annotations.csv"

In [None]:
## Data Loaders

transform = transforms.Compose([transforms.Resize((target_size, target_size)), transforms.ToTensor()])

train_dataloader = MakeDataset_VinDr_classification(image_dir = image_dir,
                                                        label_dir_csv = label_dir_csv,
                                                        transform=transform,
                                                        mode='train',
                                                        split_size= 0.2,
                                                        target_size= target_size)

val_dataloader = MakeDataset_VinDr_classification(image_dir = image_dir,
                                                        label_dir_csv = label_dir_csv,
                                                        transform=transform,
                                                        mode='val',
                                                        split_size= 0.2,
                                                        target_size= target_size)

test_dataloader = MakeDataset_VinDr_classification(image_dir = image_dir,
                                                        label_dir_csv = label_dir_csv,
                                                        transform=transform,
                                                        mode='test',
                                                        split_size= None,
                                                        target_size= target_size)

train_loader = DataLoader(train_dataloader, batch_size=batch_size, shuffle=True, num_workers=64)
val_loader = DataLoader(val_dataloader, batch_size=batch_size, num_workers=64)
test_loader = DataLoader(test_dataloader, batch_size=batch_size, num_workers=64)

In [None]:
# Create an instance of your model
model = MVSwinTransformer(img_size= target_size, window_size= window_size).to(device)

In [12]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters: ", pytorch_total_params // 10 ** 6, " mil")
print("Total number of trainable parameters: ", pytorch_total_trainable_params // 10 ** 6, " mil")

Total number of parameters:  29  mil
Total number of trainable parameters:  29  mil


In [13]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [14]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience= 20)
criterion = nn.BCELoss()
early_stopper = EarlyStopper(patience=75, min_delta=0)
curr_best_val_acc = 0.0

In [16]:
# Training loop
for epoch in range(1, epochs + 1):
    since = time.time()
    print('-' * 10)
    model.train()  # Set the model to training mode
    running_loss = 0.0
    print("#########Epoch: ", epoch)
    total = 0
    correct = 0
    for i, data in enumerate(train_loader):
        inputs_cc, inputs_mlo, labels = data
        inputs_cc, inputs_mlo, labels = inputs_cc.float().to(device), inputs_mlo.float().to(device), labels.float().to(device)

        labels = labels.unsqueeze(1)
        predicted = model(inputs_cc, inputs_mlo)
        
        total_loss = criterion(predicted, labels)
        predicted = (predicted > threshold).float()
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
      
        running_loss += total_loss.item()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Print the average loss for this epoch
    running_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch}, Loss: {running_loss}")
    
    time_elapsed = time.time() - since
    curr_lr = optimizer.param_groups[0]["lr"]
    print('Training complete in {:.0f}m {:.0f}s and current learning rate is {}.'.format(time_elapsed // 60, time_elapsed % 60, curr_lr))
    
    # Calculate accuracy for training
    train_acc = 100 * correct / total
    print(f'Accuracy of the network on the train images: {train_acc:.3f} %')

    # Validation loop
    model.eval()  # Set the model to evaluation mode
    total = 0
    correct = 0
    running_val_loss = 0.0
    predicted_prob_val = []
    true_labels_val = []

    with torch.no_grad():
        for data in val_loader:  # Use val_loader for validation
            inputs_cc, inputs_mlo, labels = data
            inputs_cc, inputs_mlo, labels = inputs_cc.float().to(device), inputs_mlo.float().to(device), labels.float().to(device)
            
            labels = labels.unsqueeze(1)
            predicted = model(inputs_cc, inputs_mlo)
            
            total_loss = criterion(predicted, labels)
            predicted = (predicted > threshold).float()
            
            running_val_loss += total_loss.item()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Calculate accuracy for validation
    val_acc = 100 * correct / total
    running_val_loss = running_val_loss / len(val_loader)
    print(f'Accuracy of the network on the val images: {val_acc:.3f} % and val loss: {running_val_loss:0.5f}')
    
    scheduler.step(running_val_loss)
    

----------
#########Epoch:  1
Epoch 1, Loss: 0.20529883645428346
Training complete in 5m 8s and current learning rate is 0.0001.
Accuracy of the network on the train images: 94.924 %
Accuracy of the network on the val images: 95.050 % and val loss: 4.95000
----------
#########Epoch:  2
Epoch 2, Loss: 0.198341371409595
Training complete in 5m 9s and current learning rate is 0.0001.
Accuracy of the network on the train images: 95.062 %
Accuracy of the network on the val images: 95.050 % and val loss: 4.95000
