In [None]:
import torch
import torch.nn as nn
import torchvision
from model_file import model
from vce_dataloader import getBinaryDataLoader, getAllDataLoader, visualize_batch

In [None]:
# Build the binary data loaders for the training and validation splits.
# Both calls share the image size, target class and batch size; only the
# dataset path differs.
# NOTE(review): presumably "Normal" is the positive class and every other
# class is negative — confirm against vce_dataloader.getBinaryDataLoader.
binaryDL_train = getBinaryDataLoader(
    image_size=(224, 224),
    target_class_name="Normal",
    path_to_dataset="/kaggle/input/vce-dataset/training",
    batch_size=32,
)
binaryDL_val = getBinaryDataLoader(
    image_size=(224, 224),
    target_class_name="Normal",
    path_to_dataset="/kaggle/input/vce-dataset/validation",
    batch_size=32,
)

# Visual sanity check of the training loader (see vce_dataloader.visualize_batch
# for what exactly is drawn).
visualize_batch(binaryDL_train, nrow=8)

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Peek at the first training batch only, to sanity-check tensor shapes.
# An empty loader yields nothing, in which case nothing is printed.
first_batch = next(iter(binaryDL_train), None)
if first_batch is not None:
    images, labels = first_batch
    # NOTE(review): presumably (batch_size, channels, 224, 224), given
    # image_size=(224, 224) when the loader was built — confirm the channel
    # count and layout against vce_dataloader.
    print(images.shape)
    # Presumably (batch_size,) — one label per image; confirm against the loader.
    print(labels.shape)

In [None]:
# Replace the classifier head with a single-output layer (one logit for the
# binary task), keeping the input width of the existing head.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=1)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so the model must emit raw
# logits. The closest TensorFlow analogue is binary cross-entropy with
# from_logits=True (not sparse categorical cross-entropy).
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters, including the freshly created head.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Number of full passes over the training loader.
num_epochs: int = 10

In [None]:
# Train the binary classifier for `num_epochs` epochs. After every epoch the
# model is evaluated on the validation loader and the mean per-batch training
# loss, mean per-batch validation loss and validation accuracy are printed.
for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()  # enable training-mode behaviour (dropout, batch-norm updates)
    running_loss = 0.0
    train_batches = 0  # counted while iterating so the average never divides by zero

    for images, labels in binaryDL_train:
        images = images.to(device)
        labels = labels.to(device).float()  # BCEWithLogitsLoss requires float targets

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass. Flatten the logits to 1-D explicitly: a bare .squeeze()
        # also removes the batch dimension when a batch holds a single sample,
        # and the loss then rejects the input/target size mismatch.
        outputs = model(images)
        logits = outputs.reshape(-1)
        loss = criterion(logits, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_batches += 1
        print("#", end = "")  # one progress mark per training batch

    # ---- Validation phase (uses the validation loader) ----
    model.eval()  # switch to inference-mode behaviour
    val_loss = 0.0
    val_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in binaryDL_val:
            images = images.to(device)
            labels = labels.to(device).float()

            outputs = model(images)
            logits = outputs.reshape(-1)  # same explicit flattening as in training
            loss = criterion(logits, labels)
            val_loss += loss.item()
            val_batches += 1

            # Logits -> probabilities -> hard 0/1 predictions at a 0.5 threshold.
            predicted = (torch.sigmoid(logits) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard the divisors so an empty loader reports zeros instead of raising
    # ZeroDivisionError.
    mean_train_loss = running_loss / max(train_batches, 1)
    mean_val_loss = val_loss / max(val_batches, 1)
    val_accuracy = 100 * correct / max(total, 1)

    # Epoch summary.
    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Training Loss: {mean_train_loss:.4f}, "
          f"Validation Loss: {mean_val_loss:.4f}, "
          f"Validation Accuracy: {val_accuracy:.2f}%")

In [None]:
# Persist a checkpoint: the zero-based index of the final epoch
# (num_epochs - 1) together with the model and optimizer state dicts.
PATH = "Binary10.pt"
checkpoint = {
    'epoch': num_epochs - 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, PATH)