In [617]:
#import libraries
import torch
import torch.nn as nn
import shutil
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset,DataLoader,random_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import os
from PIL import Image,ImageFile
import torch.optim as optim
import pandas as pd
import copy
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transforms
import albumentations as augment
from albumentations.pytorch import ToTensorV2
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [618]:
seed_val = 420
torch.manual_seed(seed_val)

<torch._C.Generator at 0x7ef72e103750>

In [619]:
base_data_dir = '../data/artifact'
fake_dir = os.path.join(base_data_dir, 'generated')  # Fake artworks directory
real_dir = os.path.join(base_data_dir, 'real')  # Real artworks directory

ARTWORK DATASET

In [620]:
#Class to manage artworks with respect to their authenticity

class ArtworkDataset(Dataset):
  def __init__(self,links,transform):
      self.data = links
      self.transform = transform

  def __len__(self):
    return self.data.index.shape[0]
    
  def __getitem__(self,idx):
        img = Image.open(self.data.iloc[idx,0])
        label_index = self.data.iloc[idx, 1]
        if (img.mode != 'RGB'):
            img = img.convert('RGB')

        if self.transform:
          img = self.transform(img)
        return img, label_index


In [621]:
# Create a CSV file with paths to artwork images and their labels (real or fake)
data = [] 

# Iterate over the fake artworks and add their paths and labels to the list
for dirpath, dirnames, filenames in os.walk(fake_dir):
    for filename in filenames:
        if filename.endswith(".jpg"): # only consider jpg files
            filepath = os.path.join(dirpath, filename)
            data.append((filepath, "0"))


# Iterate over the real artworks and add their paths and labels to the list
for dirpath, dirnames, filenames in os.walk(real_dir):
    for filename in filenames:
        if filename.endswith(".jpg"):
            filepath = os.path.join(dirpath, filename)
            data.append((filepath, "1"))  # Label 1 for real artworks

# Convert the list "data" to a pandas dataframe
df = pd.DataFrame(data, columns=["path", "label"])

# Save the dataframe to a CSV file
csv_output_path = os.path.join(base_data_dir, "image_labels.csv")
df.to_csv(csv_output_path, index=False)
print(f"CSV file saved at {csv_output_path}")


CSV file saved at ../data/artifact/image_labels.csv


LOAD PRETRAINED MODEL

In [622]:
%pip install timm
import timm

model = timm.create_model('convnext_base',pretrained=True, num_classes=2)

model = model.to(device)
print(model)

Note: you may need to restart the kernel to use updated packages.
ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=512, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Con

In [623]:
#Setting the model weights to non-trainable
for param in model.parameters(): 
    param.requires_grad = False

In [624]:
#Make the last layer of the model trainable
for p in model.head.parameters(): #instead of fc, we use head
    p.requires_grad=True


SPLIT IN TRAINING AND VALIDATION SET

In [625]:
dataset = df
dataset['label'] = dataset['label'].astype(int)
dataset

Unnamed: 0,path,label
0,../data/artifact/generated/stylegan3-t-metface...,0
1,../data/artifact/generated/stylegan3-t-metface...,0
2,../data/artifact/generated/stylegan3-t-metface...,0
3,../data/artifact/generated/stylegan3-t-metface...,0
4,../data/artifact/generated/stylegan3-t-metface...,0
...,...,...
88146,../data/artifact/real/paul-albert-besnard_robe...,1
88147,../data/artifact/real/nikolai-ge_christ-and-th...,1
88148,../data/artifact/real/paul-bril_view-of-a-port...,1
88149,../data/artifact/real/felix-vallotton_the-port...,1


In [626]:
#train, validation = train_test_split(dataset.values, stratify=dataset.values[:, 1], test_size=.3, random_state = 1) 

# Split the dataset into two parts (train + validation, and test)
train_val_data, test = train_test_split(dataset.values, test_size=0.1, random_state=seed_val)

# Split the train + validation part into training and validation sets
train, validation = train_test_split(train_val_data, test_size=0.1, random_state=seed_val)

In [627]:
print(f"Test set size: {len(validation)}")

Test set size: 7934


In [628]:
train_links = pd.DataFrame(train, columns = dataset.columns)
validation_links = pd.DataFrame(validation, columns = dataset.columns)
test_links = pd.DataFrame(test, columns = dataset.columns)

In [629]:
train_links[:5]

Unnamed: 0,path,label
0,../data/artifact/real/felipe-de-vicente_art-67...,1
1,../data/artifact/generated/stylegan3-r-metface...,0
2,../data/artifact/real/johannes-itten_der-bachs...,1
3,../data/artifact/generated/stylegan3-t-metface...,0
4,../data/artifact/real/stanley-spencer_kit-insp...,1


BUILDING DATA LOADERS

In [630]:
# old
data_transforms = transforms.Compose([
                                transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# new
train_transforms = transforms.Compose([
    transforms.Resize(256),                    # Resize to allow cropping
    transforms.RandomCrop(224),                # Crop to 224x224
    transforms.RandomHorizontalFlip(p=0.5),    # Random horizontal flip -- half of the images are mirrored; prevent the model from learning exact position-based cues.
    transforms.RandomRotation(10),             # Small rotation for variability -- helps the model generalize across minor rotations, as real-world artworks and digital images might not always be perfectly aligned.
    transforms.ColorJitter(brightness=0.1, contrast=0.1), # Minimal brightness/contrast -- A brightness and contrast jitter of ±10% introduces slight lighting variability without drastically changing color tones. This adjustment accounts for small lighting and contrast differences that can naturally occur in photographed or scanned artworks, which could otherwise become confounding factors.
    transforms.ToTensor(),                     # Convert to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Standard normalization
])

# Validation/test transformations (no augmentation)
val_test_transforms = transforms.Compose([
    transforms.Resize(224),                    # Resize directly to 224x224
    transforms.CenterCrop(224),                # Center crop to ensure consistency
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image
])

batch_size = 32

train_set = ArtworkDataset(train_links, train_transforms)

validation_set = ArtworkDataset(validation_links, val_test_transforms)

test_set = ArtworkDataset(test_links, val_test_transforms)


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                               drop_last=False,num_workers=2)

validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, 
                               drop_last=False,num_workers=2)

test_loader = DataLoader(test_set,batch_size=batch_size, shuffle = False,
                              drop_last=False,num_workers=2)

TRAINING

In [631]:
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    """
    def __init__(self, patience=5, min_delta=0.001):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, current_loss):
        if self.best_loss == None:
            self.best_loss = current_loss
        elif (current_loss - self.best_loss) < -self.min_delta:
            self.best_loss = current_loss
            self.wait = 0
            
             # Save the model state
            model_path = os.path.join(self.model_dir, 'RealArt_vs_FakeArt_convnext_base.pt')

            # Remove the old model if it exists in the same directory
            if os.path.exists(model_path):
                os.remove(model_path)
            
            # save the new model 
            torch.save(model.state_dict(), model_path)
        else:
            self.wait = self.wait + 1
            print(f"INFO: Early stopping counter {self.wait} of {self.patience}")
            if self.wait >= self.patience:
                self.early_stop = True
                
    def __call__(self, current_loss):
        if self.best_loss is None or (current_loss - self.best_loss) < -self.min_delta:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            print(f"INFO: Early stopping counter {self.wait} of {self.patience}")
            if self.wait >= self.patience:
                self.early_stop = True

In [632]:
# functions for logging the misclassified images in tensorboard
def denormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Denormalizes a tensor image by applying the inverse of the normalization transform."""
    # Reshape mean and std to match the (C, H, W) shape of the tensor
    mean = torch.tensor(mean).view(3, 1, 1).to(tensor.device)  # Match device as well
    std = torch.tensor(std).view(3, 1, 1).to(tensor.device)
    denormalized = tensor * std + mean
    return denormalized.clamp(0, 1)

def add_labels_to_image(image, true_label, pred_label):
    """Add true and predicted labels to the image."""
    # Convert the image to uint8 if it's in float format
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    # Convert from RGB to BGR for OpenCV
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Prepare text without brackets
    text = f'True: {int(true_label)}, Pred: {int(pred_label)}'
    cv2.putText(bgr_image, text, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    # Convert back to RGB before returning
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


In [633]:
def log_misclassified_images(misclassified_images, misclassified_preds, misclassified_true, epoch, writer):
    print("Logging misclassified images...")
    print(len(misclassified_images))
    labeled_images = []

    for i, img in enumerate(misclassified_images):
        img = denormalize_image(img)  # Denormalize the image
        img_np = img.permute(1, 2, 0).numpy()  # Convert (C, H, W) -> (H, W, C)
      

        # Ensure the image is in the right format for TensorBoard
        if img_np.shape[-1] == 1:  # Check if the image is grayscale
            img_np = img_np.squeeze(axis=2)  # Remove the channel dimension if it is 1
        elif img_np.shape[0] == 1:  # If the shape is (1, H, W)
            img_np = img_np.squeeze(axis=0)  # Remove the batch dimension
        img_np = np.clip(img_np, 0, 1)  # Ensure the values are between 0 and 1
        img_np = (img_np * 255).astype(np.uint8)     # Scale to 0-255

        true_label = misclassified_true[i]
        pred_label = misclassified_preds[i]

        labeled_image = add_labels_to_image(img_np, true_label, pred_label)
        # labeled_image = labeled_image.transpose(2, 0, 1)
        labeled_images.append(labeled_image)

    # Log image to TensorBoard
    writer.add_images(
        f'Misclassified_Epoch_{epoch}',
        np.array(labeled_images),
        global_step=epoch,
        dataformats='NHWC'
    )

In [634]:
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_name, writer):
    model.train()  # Set model to training mode
    running_loss = 0.0
    running_corrects = 0

    # Iterate over training data
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch} Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        outputs = model(inputs)
        outputs = nn.Softmax(dim=1)(outputs)  # Apply softmax for classification
        #_, preds = torch.max(outputs, 1)
        #print(torch.argmax(outputs, dim = 1))
        preds = torch.argmax(outputs, dim = 1)

        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    # Compute loss and accuracy for the epoch
    epoch_loss = running_loss / (len(train_loader) * train_loader.batch_size)
    epoch_acc = running_corrects.double() / (len(train_loader) * train_loader.batch_size)

    # Log statistics
    writer.add_scalar(f'Loss/Train/{log_name}', epoch_loss, epoch)
    writer.add_scalar(f'Accuracy/Train/{log_name}', epoch_acc, epoch)

    # Step the scheduler
    scheduler.step()

    print(f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
    return epoch_acc, epoch_loss


In [635]:
def validate_one_epoch(model, validation_loader, criterion, device, epoch, num_epochs, log_name, writer):
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    running_corrects = 0
    all_preds = []
    all_labels = []
    misclassified_images = []
    misclassified_preds = []
    misclassified_true = []

    # Disable gradient computation for validation
    with torch.no_grad():
        for inputs, labels in tqdm(validation_loader, desc=f"Epoch {epoch} Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            outputs = nn.Softmax(dim=1)(outputs)  # Apply softmax for classification
            #_, preds = torch.max(outputs, 1)
            preds = torch.argmax(outputs, dim = 1)
        
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Collect all predictions and labels for confusion matrix
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
         
            # Collect misclassified images
            misclassified = preds != labels.data
            if misclassified.any():
                misclassified_images.extend(inputs[misclassified].cpu())
                misclassified_preds.extend(preds[misclassified].cpu().numpy())
                misclassified_true.extend(labels[misclassified].cpu().numpy())

    # Compute loss and accuracy for the epoch
    epoch_loss = running_loss / (len(validation_loader) * validation_loader.batch_size)
    epoch_acc = running_corrects.double() / (len(validation_loader) * validation_loader.batch_size)

    # Log statistics
    writer.add_scalar(f'Loss/Validation/{log_name}', epoch_loss, epoch)
    writer.add_scalar(f'Accuracy/Validation/{log_name}', epoch_acc, epoch)

    print(f"Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    # Calculate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Plot confusion matrix using Seaborn
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Fake', 'True'], yticklabels=['Fake', 'True'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - Epoch {epoch}')
    plt.tight_layout()

    # Convert the plot to a NumPy array and log it to TensorBoard
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    image = np.array(Image.open(buf))
    writer.add_image(f'Confusion_Matrix', image, global_step=epoch, dataformats='HWC')
    buf.close()
    plt.close()

    # Log misclassified images (with confusion matrix info)
    log_misclassified_images(misclassified_images, misclassified_preds, misclassified_true, epoch, writer)

    return epoch_acc, epoch_loss

In [636]:
def fine_tune(model, train_loader, validation_loader, criterion, optimizer, scheduler, early_stop, num_epochs=100, 
              log_name='run1', device='cuda', writer=None):
    best_model = copy.deepcopy(model)
    best_acc = 0.0
    best_epoch = 0
    stop = False

    for epoch in range(1, num_epochs + 1):
        if stop:
            break
        print(f'Epoch {epoch}/{num_epochs}')
        print('-'*120)

        # Training phase
        train_acc, train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_name, writer)

        # Validation phase
        val_acc, val_loss = validate_one_epoch(model, validation_loader, criterion, device, epoch, num_epochs, log_name, writer)

        # Early stopping
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_model = copy.deepcopy(model)

        early_stop(val_loss)
        stop = early_stop.early_stop

        print('-'*120)

    print(f'Best val Acc: {best_acc:4f}')
    print(f'Best epoch: {best_epoch:03d}')
    
    # Save the best model state when training stops
    model_path = os.path.join('saved_models/convnext_model', 'RealArt_vs_FakeArt_convnext_base.pt')
    # torch.save(best_model.state_dict(), model_path)
    print(f"Model saved to: {model_path}")

    return best_model

In [637]:
model_path = os.path.join('saved_models/convnext_model', 'RealArt_vs_FakeArt_convnext_base.pt')
writer = SummaryWriter("runs/confusion_matrix_all")

if not os.path.exists(model_path):
   print("Training a new model...")
   criterion = nn.CrossEntropyLoss()
   optimizer = optim.Adam(model.parameters(), lr=1e-3)
   scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
   early_stop = EarlyStopping(patience = 3, min_delta = 0.001)
   ImageFile.LOAD_TRUNCATED_IMAGES = True

   
   # Train and save the best model
   best_model_head = fine_tune(model, train_loader, validation_loader, criterion, optimizer, scheduler, 
                               early_stop, num_epochs = 30, log_name = 'augmented', writer = writer)
else: 
     # Model already exists; load it
     print("Loading pre-trained model...")
     model = timm.create_model('convnext_base',pretrained=True, num_classes=2) 
   #   model = torch.load(model_path)  # Load the full model
     model.load_state_dict(torch.load(model_path))
     print("Model loaded successfully from:", model_path)

Training a new model...
Epoch 1/30
------------------------------------------------------------------------------------------------------------------------


Epoch 1 Training:   0%|          | 0/2232 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 2232/2232 [03:11<00:00, 11.63it/s]


Train Loss: 0.3562 Acc: 0.9581


Epoch 1 Validation: 100%|██████████| 248/248 [00:21<00:00, 11.41it/s]


Validation Loss: 0.3421 Acc: 0.9703
Logging misclassified images...
234
------------------------------------------------------------------------------------------------------------------------
Epoch 2/30
------------------------------------------------------------------------------------------------------------------------


Epoch 2 Training: 100%|██████████| 2232/2232 [03:13<00:00, 11.52it/s]


Train Loss: 0.3445 Acc: 0.9678


Epoch 2 Validation: 100%|██████████| 248/248 [00:21<00:00, 11.37it/s]


Validation Loss: 0.3435 Acc: 0.9688
Logging misclassified images...
246
INFO: Early stopping counter 1 of 3
------------------------------------------------------------------------------------------------------------------------
Epoch 3/30
------------------------------------------------------------------------------------------------------------------------


Epoch 3 Training: 100%|██████████| 2232/2232 [03:15<00:00, 11.40it/s]


Train Loss: 0.3419 Acc: 0.9705


Epoch 3 Validation: 100%|██████████| 248/248 [00:21<00:00, 11.30it/s]


Validation Loss: 0.3391 Acc: 0.9738
Logging misclassified images...
206
------------------------------------------------------------------------------------------------------------------------
Epoch 4/30
------------------------------------------------------------------------------------------------------------------------


Epoch 4 Training: 100%|██████████| 2232/2232 [03:15<00:00, 11.41it/s]


Train Loss: 0.3407 Acc: 0.9712


Epoch 4 Validation: 100%|██████████| 248/248 [00:21<00:00, 11.30it/s]


Validation Loss: 0.3385 Acc: 0.9734
Logging misclassified images...
209
INFO: Early stopping counter 1 of 3
------------------------------------------------------------------------------------------------------------------------
Epoch 5/30
------------------------------------------------------------------------------------------------------------------------


Epoch 5 Training: 100%|██████████| 2232/2232 [03:15<00:00, 11.39it/s]


Train Loss: 0.3398 Acc: 0.9725


Epoch 5 Validation: 100%|██████████| 248/248 [00:21<00:00, 11.27it/s]


Validation Loss: 0.3398 Acc: 0.9730
Logging misclassified images...
212
INFO: Early stopping counter 2 of 3
------------------------------------------------------------------------------------------------------------------------
Epoch 6/30
------------------------------------------------------------------------------------------------------------------------


Epoch 6 Training: 100%|██████████| 2232/2232 [03:16<00:00, 11.37it/s]


Train Loss: 0.3391 Acc: 0.9731


Epoch 6 Validation: 100%|██████████| 248/248 [00:22<00:00, 11.24it/s]


Validation Loss: 0.3386 Acc: 0.9739
Logging misclassified images...
205
INFO: Early stopping counter 3 of 3
------------------------------------------------------------------------------------------------------------------------
Best val Acc: 0.973916
Best epoch: 006
Model saved to: saved_models/convnext_model/RealArt_vs_FakeArt_convnext_base.pt


TESTING

In [638]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# testing function
def test_model(model, test_loader):
    model.eval() # evaluation mode
    test_loss = 0
    correct = 0
    pred_list = []
    true_list = []
    misclassified_images = []
    misclassified_preds = []
    misclassified_true = []


    # progress bar
    pbar = tqdm(total=len(test_loader))
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item() # sommo il loss di ogni batch
            
            pred = output.argmax(dim=1, keepdim=True) # ottengo la predizione del modello
            pred_list.extend(pred.cpu().numpy()) # aggiungo la predizione alla lista
            true_list.extend(target.cpu().numpy()) # aggiungo il target alla lista
            
            correct += pred.eq(target.view_as(pred)).sum().item() # aggiorno il contatore di classificazioni corrette

            misclassified = ~pred.eq(target.view_as(pred)).squeeze()
            if misclassified.any():
                misclassified_images.extend(data[misclassified].cpu())
                misclassified_preds.extend(pred[misclassified].cpu().numpy())
                misclassified_true.extend(target[misclassified].cpu().numpy())

            # update progress bar
            pbar.update(1)
            
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    recall = recall_score(true_list, pred_list, average='macro')
    precision = precision_score(true_list, pred_list, average='macro') 
    f1 = f1_score(true_list, pred_list, average='macro') 
    auc = roc_auc_score(true_list, pred_list) 
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), Recall: {:.2f}%, Precision: {:.2f}%, F1: {:.2f}%, AUC: {:.2f}%\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy, recall*100, precision*100, f1*100, auc*100))
    
    #Logging the misclassified images
    for i, img in enumerate(misclassified_images):
        img_denorm = denormalize_image(img.clone())
        img_np = img_denorm.permute(1, 2, 0).numpy()  # Convert tensor to numpy (C, H, W) to (H, W, C)
        
        # Ensure the image is in the right format for TensorBoard
        if img_np.shape[-1] == 1:  # Check if the image is grayscale
            img_np = img_np.squeeze(axis=2)  # Remove the channel dimension if it is 1
        elif img_np.shape[0] == 1:  # If the shape is (1, H, W)
            img_np = img_np.squeeze(axis=0)  # Remove the batch dimension

        img_np = np.clip(img_np, 0, 1)  # Ensure the values are between 0 and 1
        img_np = (img_np * 255).astype(np.uint8)  # Scale to 0-255 and convert to uint8

        labeled_image = add_labels_to_image(img_np, misclassified_true[i], misclassified_preds[i])

        # Log the image to TensorBoard
        writer.add_image(f'Misclassified_True_{misclassified_true[i]}_Pred_{misclassified_preds[i]}_Index_{i}', labeled_image, dataformats='HWC')

    writer.flush()
    return accuracy, recall, precision, f1, auc

use this for logs:
tensorboard --logdir=./runs

In [639]:
# accuracy,recall,precision,f1,auc = test_model(model,test_loader)
writer.close()