In [1]:
#import libraries
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset,DataLoader,random_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import os
from PIL import Image,ImageFile
import torch.optim as optim
import pandas as pd
import copy
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transforms
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from captum.attr import IntegratedGradients
import timm
from torchvision.utils import make_grid
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

2024-11-20 16:48:44.628737: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-20 16:48:44.645449: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


cuda


In [2]:
seed_val = 420
torch.manual_seed(seed_val)

<torch._C.Generator at 0x76f6d833ee30>

In [3]:
base_data_dir = '../data/artifact'
fake_dir = os.path.join(base_data_dir, 'generated')  # Fake artworks directory
real_dir = os.path.join(base_data_dir, 'real')  # Real artworks directory

ARTWORK DATASET

In [4]:
#Class to manage artworks with respect to their authenticity

class ArtworkDataset(Dataset):
  def __init__(self,links,transform):
      self.data = links
      self.transform = transform

  def __len__(self):
    return self.data.index.shape[0]
    
  def __getitem__(self,idx):
        img = Image.open(self.data.iloc[idx,0])
        label_index = self.data.iloc[idx, 1]
        if (img.mode != 'RGB'):
            img = img.convert('RGB')

        if self.transform:
          img = self.transform(img)
        return img, label_index


In [5]:
# Create a CSV file with paths to artwork images and their labels (real or fake)
data = [] 

# Iterate over the fake artworks and add their paths and labels to the list
for dirpath, dirnames, filenames in os.walk(fake_dir):
    for filename in filenames:
        if filename.endswith(".jpg"): # only consider jpg files
            filepath = os.path.join(dirpath, filename)
            data.append((filepath, "0"))


# Iterate over the real artworks and add their paths and labels to the list
for dirpath, dirnames, filenames in os.walk(real_dir):
    for filename in filenames:
        if filename.endswith(".jpg"):
            filepath = os.path.join(dirpath, filename)
            data.append((filepath, "1"))  # Label 1 for real artworks

# Convert the list "data" to a pandas dataframe
df = pd.DataFrame(data, columns=["path", "label"])

# Save the dataframe to a CSV file
csv_output_path = os.path.join(base_data_dir, "image_labels.csv")
df.to_csv(csv_output_path, index=False)
print(f"CSV file saved at {csv_output_path}")


CSV file saved at ../data/artifact/image_labels.csv


LOAD PRETRAINED MODEL

In [6]:
model = timm.create_model('convnext_base',pretrained=True, num_classes=2)

model = model.to(device)
print(model)

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (fc2): Linear(in_features=512, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elemen

In [7]:
#Setting the model weights to non-trainable
for param in model.parameters(): 
    param.requires_grad = False

In [8]:
#Make the last layer of the model trainable
for p in model.head.parameters(): #instead of fc, we use head
    p.requires_grad=True


SPLIT IN TRAINING AND VALIDATION SET

In [9]:
dataset = df
dataset['label'] = dataset['label'].astype(int)
dataset

Unnamed: 0,path,label
0,../data/artifact/generated/stylegan3-t-metface...,0
1,../data/artifact/generated/stylegan3-t-metface...,0
2,../data/artifact/generated/stylegan3-t-metface...,0
3,../data/artifact/generated/stylegan3-t-metface...,0
4,../data/artifact/generated/stylegan3-t-metface...,0
...,...,...
88146,../data/artifact/real/paul-albert-besnard_robe...,1
88147,../data/artifact/real/nikolai-ge_christ-and-th...,1
88148,../data/artifact/real/paul-bril_view-of-a-port...,1
88149,../data/artifact/real/felix-vallotton_the-port...,1


In [10]:
#train, validation = train_test_split(dataset.values, stratify=dataset.values[:, 1], test_size=.3, random_state = 1) 

# Split the dataset into two parts (train + validation, and test)
train_val_data, test = train_test_split(dataset.values, test_size=0.1, random_state=seed_val)

# Split the train + validation part into training and validation sets
train, validation = train_test_split(train_val_data, test_size=0.1, random_state=seed_val)

In [11]:
print(f"Test set size: {len(validation)}")

Test set size: 7934


In [12]:
train_links = pd.DataFrame(train, columns = dataset.columns)
validation_links = pd.DataFrame(validation, columns = dataset.columns)
test_links = pd.DataFrame(test, columns = dataset.columns)

In [13]:
train_links[:5]

Unnamed: 0,path,label
0,../data/artifact/real/felipe-de-vicente_art-67...,1
1,../data/artifact/generated/stylegan3-r-metface...,0
2,../data/artifact/real/johannes-itten_der-bachs...,1
3,../data/artifact/generated/stylegan3-t-metface...,0
4,../data/artifact/real/stanley-spencer_kit-insp...,1


#### DATA LOADERS

In [14]:
# old
data_transforms = transforms.Compose([
                                transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# new
train_transforms = transforms.Compose([
    transforms.Resize(256),                    # Resize to allow cropping
    transforms.RandomCrop(224),                # Crop to 224x224
    transforms.RandomHorizontalFlip(p=0.5),    # Random horizontal flip -- half of the images are mirrored; prevent the model from learning exact position-based cues.
    transforms.RandomRotation(10),             # Small rotation for variability -- helps the model generalize across minor rotations, as real-world artworks and digital images might not always be perfectly aligned.
    transforms.ColorJitter(brightness=0.1, contrast=0.1), # Minimal brightness/contrast -- A brightness and contrast jitter of ±10% introduces slight lighting variability without drastically changing color tones. This adjustment accounts for small lighting and contrast differences that can naturally occur in photographed or scanned artworks, which could otherwise become confounding factors.
    transforms.ToTensor(),                     # Convert to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Standard normalization
])

# Validation/test transformations (no augmentation)
val_test_transforms = transforms.Compose([
    transforms.Resize(224),                    # Resize directly to 224x224
    transforms.CenterCrop(224),                # Center crop to ensure consistency
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image
])

batch_size = 32

train_set = ArtworkDataset(train_links[:1000], train_transforms)

validation_set = ArtworkDataset(validation_links[:1000], val_test_transforms)

test_set = ArtworkDataset(test_links[:1000], val_test_transforms)


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                               drop_last=False,num_workers=6, pin_memory = True)

validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False, 
                               drop_last=False,num_workers=6, pin_memory = True)

test_loader = DataLoader(test_set,batch_size=batch_size, shuffle = False,
                              drop_last=False,num_workers=6, pin_memory = True)

**Early Stopping**

In [15]:
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    """
    def __init__(self, patience=5, min_delta=0.001):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best_loss = None
        self.early_stop = False
                
    def __call__(self, current_loss):
        if self.best_loss is None or (current_loss - self.best_loss) < -self.min_delta:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            print(f"INFO: Early stopping counter {self.wait} of {self.patience}")
            if self.wait >= self.patience:
                self.early_stop = True

**Functions for visualisation and logging**

In [16]:
# functions for logging the misclassified images in tensorboard
def denormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Denormalizes a tensor image by applying the inverse of the normalization transform."""
    # Reshape mean and std to match the (C, H, W) shape of the tensor
    if tensor.dim() == 4:
        tensor = tensor[0]
    mean = torch.tensor(mean).view(3, 1, 1).to(tensor.device)  # Match device as well
    std = torch.tensor(std).view(3, 1, 1).to(tensor.device)
    denormalized = tensor * std + mean
    denormalized = denormalized.clamp(0, 1) # convert to [H, W, C] for matplotlib visualisation

    return denormalized.permute(1, 2, 0).cpu().numpy()

def add_labels_to_image(image, true_label, pred_label):
    """Add true and predicted labels to the image."""
    # Convert the image to uint8 if it's in float format
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)

    # Convert from RGB to BGR for OpenCV
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Prepare text without brackets
    text = f'True: {int(true_label)}, Pred: {int(pred_label)}'
    cv2.putText(bgr_image, text, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    # Convert back to RGB before returning
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


In [17]:
def log_misclassified_images(misclassified_images, misclassified_preds, misclassified_true, epoch, writer):
    print("Logging misclassified images...")
    print(len(misclassified_images))
    labeled_images = []

    for i, img in enumerate(misclassified_images):
        img_np = denormalize_image(img)  # Denormalize the image      

        # Ensure the image is in the right format for TensorBoard
        if img_np.shape[-1] == 1:  # Check if the image is grayscale
            img_np = img_np.squeeze(axis=2)  # Remove the channel dimension if it is 1
        elif img_np.shape[0] == 1:  # If the shape is (1, H, W)
            img_np = img_np.squeeze(axis=0)  # Remove the batch dimension
        img_np = np.clip(img_np, 0, 1)  # Ensure the values are between 0 and 1
        img_np = (img_np * 255).astype(np.uint8)     # Scale to 0-255

        true_label = misclassified_true[i]
        pred_label = misclassified_preds[i]

        labeled_image = add_labels_to_image(img_np, true_label, pred_label)
        # labeled_image = labeled_image.transpose(2, 0, 1)
        labeled_images.append(labeled_image)

    # Log image to TensorBoard
    writer.add_images(
        f'Misclassified_Epoch_{epoch}',
        np.array(labeled_images),
        global_step=epoch,
        dataformats='NHWC'
    )


def log_confusion_matrix(all_labels, all_preds, epoch, log_name, writer):
    """Log confusion matrix to TensorBoard."""
    cm = confusion_matrix(all_labels, all_preds)

    # Plot confusion matrix using Seaborn
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=['Fake', 'True'], yticklabels=['Fake', 'True'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - Epoch {epoch}')
    plt.tight_layout()

    # Save to buffer and log
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    image = np.array(Image.open(buf))
    writer.add_image(f'Confusion_Matrix', image, global_step=epoch, dataformats='HWC')
    buf.close()
    plt.close()

##### GRADCAM and Integrated Gradients

In [18]:
print(model.stages[-1].blocks[-1])

ConvNeXtBlock(
  (conv_dw): Conv2d(1024, 1024, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=1024)
  (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
    (act): GELU()
    (drop1): Dropout(p=0.0, inplace=False)
    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    (drop2): Dropout(p=0.0, inplace=False)
  )
  (drop_path): Identity()
)


In [19]:
def log_integrated_gradients_old(model, inputs, labels, device, epoch, writer):
    """
    Applies Integrated Gradients to a single batch of inputs and logs a side-by-side plot (input + IG) to TensorBoard.
    """
    model.eval()
    ig = IntegratedGradients(model)

    inputs = inputs.to(device)
    labels = labels.to(device)

    # Perform IG for a single input (e.g., the first image in the batch)
    input_example = inputs[0].unsqueeze(0)  # Add batch dimension
    target_label = labels[0].item()  # Ground truth class

    # Compute attributions
    attributions = ig.attribute(input_example, target=target_label, n_steps=50)

    # Convert tensors to numpy for visualization
    input_example = input_example.squeeze().cpu().detach().numpy()
    attributions = attributions.squeeze().cpu().detach().numpy()

    # Normalize attributions for visualization
    attributions = (attributions - attributions.min()) / (attributions.max() - attributions.min())

    # Reshape input and attribution for RGB or Grayscale
    input_example = input_example.transpose(1, 2, 0)  # CHW to HWC for RGB
    if input_example.shape[2] != 3:  # Handle grayscale
        input_example = np.repeat(input_example[:, :, np.newaxis], 3, axis=2)
    attribution_overlay = attributions.transpose(1, 2, 0)

    # Create a side-by-side plot
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(input_example)
    axes[0].set_title("Input Image")
    axes[0].axis("off")
    axes[1].imshow(attribution_overlay, cmap="viridis")
    axes[1].set_title("Integrated Gradients")
    axes[1].axis("off")

    # Save the plot to a buffer
    fig.canvas.draw()
    buf = fig.canvas.tostring_rgb()
    width, height = fig.canvas.get_width_height()
    image = np.frombuffer(buf, dtype=np.uint8).reshape(height, width, 3).transpose(2, 0, 1)  # HWC to CHW

    plt.close(fig)  # Close the plot to free memory

    # Log the combined image to TensorBoard
    writer.add_image(f"IG/Input_and_IG", image, epoch)

    print(f"Integrated Gradients side-by-side plot logged for epoch {epoch}")

In [20]:
def apply_grad_cam(model, images, preds, true_labels, writer, log_name="Validation"):
    # Assume 'model' is in eval mode and the images are pre-processed
    target_layer = model.stages[-1].blocks[-1].conv_dw
    gradcam = GradCAMPlusPlus(model, target_layers=[target_layer])
    for i, image in enumerate(images):
        image_tensor = transforms.ToTensor()(image).unsqueeze(0).cuda()
        true_label = true_labels[i]
        
        # Get Grad-CAM++ heatmap
        cam = gradcam(image_tensor)
        cam = cam.squeeze().cpu().detach().numpy()  # Remove batch dimension
        
        # Normalize the CAM output
        cam = np.maximum(cam, 0)  # Set negative values to 0
        cam = cam / cam.max()  # Normalize to [0, 1]

        denormalized_image = denormalize_image(image_tensor.unsqueeze(0))
        
        # Overlay heatmap on the original image
        cam = np.uint8(255 * cam)  # Convert to 0-255 range
        cam = np.expand_dims(cam, axis=-1)  # Make it a 3-channel image
        heatmap = np.repeat(cam, 3, axis=-1)
        overlayed_image = np.uint8(denormalized_image * 0.5 + heatmap * 0.5)  # Blend the original image with heatmap

        # Convert image to TensorBoard format
        figure = plt.figure(figsize=(10, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(denormalized_image)
        plt.title(f"Original - True Label: {true_label}, Pred: {preds[i]}")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(overlayed_image)
        plt.title(f"Grad-CAM++ Overlay")
        plt.axis('off')

        writer.add_figure(f'{log_name}/Grad-CAM++_{i}', figure)

In [21]:
def log_gradcam(model, image_tensor, true_label, class_idx, writer, log_name="Validation"):
    """
    Apply Grad-CAM++ to visualize class-specific saliency maps and overlay them on the input image.
    
    Args:
        model: The PyTorch model.
        target_layer: The name of the layer to use for Grad-CAM++.
        image_tensor: The input image tensor (normalized, shape CxHxW).
        true_label: The true label of the image (int).
        class_idx: The predicted class index to visualize (int).
        writer: TensorBoard writer for logging.
        step: The current step (e.g., batch index or global step).
        denormalize: A function to denormalize the image tensor for visualization.
    """
    # Move model to eval mode
    model.eval()
    image_tensor.requires_grad = True
    target_layer = model.stages[-1].blocks[-1].conv_dw

    # Create Grad-CAM++ object
    cam_extractor = SmoothGradCAMpp(model, target_layer)

    # Detach the image for visualization
    img = image_tensor.detach()

    model(image_tensor)
    
    # Denormalize for visual output
    img_np = denormalize_image(img)
    img_np = (img_np * 255).astype('uint8')
    pil_image = Image.fromarray(img_np)

    if not isinstance(class_idx, (list, np.ndarray)):
        class_idx = [class_idx]

    # Generate Grad-CAM++ activation maps
    cams = cam_extractor(class_idx, img.unsqueeze(0))
    cam_map = cams[0].squeeze(0).cpu()

    # Overlay Grad-CAM++ heatmap on the original image
    heatmap = overlay_mask(pil_image, Image.fromarray((cam_map.numpy() * 255).astype('uint8')), alpha=0.5)

    # Plot original image and overlayed heatmap
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(pil_image)
    ax[0].set_title(f"Original Image (Label: {true_label})")
    ax[0].axis("off")
    ax[1].imshow(heatmap)
    ax[1].set_title(f"Grad-CAM++ (Class: {class_idx})")
    ax[1].axis("off")
    plt.tight_layout()

    # Log images to TensorBoard
    writer.add_figure(f"GradCAM++/{log_name}", fig)

    plt.close(fig)  # Close the plot to free memory

In [22]:
def log_integrated_gradients(
    model, input_tensor, target_label, baseline=None, writer=None, log_name = "Validation"):
    """
    Visualizes Integrated Gradients (IG) for a single input image.
    
    Args:
        model: Trained PyTorch model.
        input_tensor: Input image tensor (C, H, W) normalized for the model.
        target_label: Target label for IG attribution.
        baseline: Baseline for IG. Defaults to zero baseline.
        writer: TensorBoard SummaryWriter object.
        log_name: Name for logging the figure (default: "Validation").
    """

    # Ensure the target_label is a PyTorch tensor of type long
    target_label = torch.tensor(target_label, dtype=torch.long).to(input_tensor.device)

    # Set the model to evaluation mode
    model.eval()
    for param in model.parameters():
        param.requires_grad = True
    
    # Ensure that the input tensor requires gradients
    if input_tensor.dim() == 3:  # Shape: [C, H, W]
        input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension
    input_tensor.requires_grad_()  # Enable gradients for backpropagation


    input_tensor.requires_grad_()  # Ensure the input tensor requires gradients

    # Prepare baseline (default is zero image)
    if baseline is None:
        baseline = torch.zeros_like(input_tensor).to(input_tensor.device)
    baseline.requires_grad = False

    
    # Compute attributions using Integrated Gradients
    ig = IntegratedGradients(model)
    attributions, _ = ig.attribute(
        inputs=input_tensor, 
        baselines=baseline, 
        target=target_label, 
        return_convergence_delta=True
    )

    print(f"Attributions shape before processing: {attributions.shape}")

    
    # Process the attributions
    attributions = attributions.squeeze(0).cpu().detach()  # Shape: (C, H, W)
    print(f"Attributions shape after squeeze: {attributions.shape}")

    
    # Normalize IG attributions to [0, 1] for visualization
    attributions_normalized = (attributions - attributions.min()) / (attributions.max() - attributions.min())
    attributions_normalized = attributions_normalized.permute(1, 2, 0).numpy()  # Shape: (H, W, C)
    
    # Denormalize the input image for visualization
    input_image_denormalized = denormalize_image(input_tensor.squeeze(0)).permute(1, 2, 0).cpu().numpy()
    
    # Plotting the results
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(input_image_denormalized)
    ax[0].axis('off')
    ax[0].set_title('Denormalized Input Image')

    ax[1].imshow(input_image_denormalized, alpha=0.6)  # Overlay original image
    ax[1].imshow(attributions_normalized, cmap='hot', alpha=0.4)  # Overlay IG heatmap
    ax[1].axis('off')
    ax[1].set_title('Integrated Gradients Attributions')
    
    plt.tight_layout()
    
    # Save figure to TensorBoard if writer is provided
    plt.show()
    writer.add_figure(f"Integrated Gradients/{log_name}", fig)
    plt.close(fig)


#### TRAINING AND VALIDATION

In [23]:
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_name, writer):
    model.train()  # Set model to training mode
    running_loss = 0.0
    running_corrects = 0

    # Iterate over training data
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch} Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        outputs = model(inputs)
        outputs = nn.Softmax(dim=1)(outputs)  # Apply softmax for classification
        preds = torch.argmax(outputs, dim = 1)

        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    # Compute loss and accuracy for the epoch
    epoch_loss = running_loss / (len(train_loader) * train_loader.batch_size)
    epoch_acc = running_corrects.double() / (len(train_loader) * train_loader.batch_size)

    # Log statistics
    writer.add_scalar(f'Loss/Train/{log_name}', epoch_loss, epoch)
    writer.add_scalar(f'Accuracy/Train/{log_name}', epoch_acc, epoch)

    # Step the scheduler
    scheduler.step()

    print(f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
    return epoch_acc, epoch_loss


In [24]:
def validate_one_epoch(model, validation_loader, criterion, device, epoch, num_epochs, log_name, writer):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    running_corrects = 0
    all_preds = []
    all_labels = []
    misclassified_images = []
    misclassified_preds = []
    misclassified_true = []

    # Disable gradient computation for validation
    with torch.no_grad():
        for inputs, labels in tqdm(validation_loader, desc=f"Epoch {epoch} Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            outputs = nn.Softmax(dim=1)(outputs)  # Apply softmax for classification
            #_, preds = torch.max(outputs, 1)
            preds = torch.argmax(outputs, dim = 1)
        
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Collect all predictions and labels for confusion matrix
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
         
            # Collect misclassified images
            misclassified = preds != labels.data
            if misclassified.any():
                misclassified_images.extend(inputs[misclassified].cpu())
                misclassified_preds.extend(preds[misclassified].cpu().numpy())
                misclassified_true.extend(labels[misclassified].cpu().numpy())

    # Compute loss and accuracy for the epoch
    epoch_loss = running_loss / (len(validation_loader) * validation_loader.batch_size)
    epoch_acc = running_corrects.double() / (len(validation_loader) * validation_loader.batch_size)

    # Log statistics
    writer.add_scalar(f'Loss/Validation/{log_name}', epoch_loss, epoch)
    writer.add_scalar(f'Accuracy/Validation/{log_name}', epoch_acc, epoch)

    print(f"Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    
    # Confusion Matrix
    log_confusion_matrix(all_labels, all_preds, epoch, log_name, writer)
    
    # Log misclassified images (with confusion matrix info)
    # log_misclassified_images(misclassified_images, misclassified_preds, misclassified_true, epoch, writer)

    # Apply Grad-CAM and IG to a subset of misclassified images
    subset_size = min(50, len(misclassified_images))
    if subset_size > 0:
        subset_indices = np.random.choice(len(misclassified_images), size=subset_size, replace=False)
        subset_images = torch.stack([misclassified_images[i] for i in subset_indices])
        subset_preds = [misclassified_preds[i] for i in subset_indices]
        subset_true = [misclassified_true[i] for i in subset_indices]

        for i, (image, pred, true) in enumerate(zip(subset_images, subset_preds, subset_true)):
            log_gradcam(model, image.unsqueeze(0).to(device), pred, true, writer, "Validation")
            # apply_grad_cam(model, image.unsqueeze(0).to(device), pred, true, writer, "Validation")

            log_integrated_gradients(model, image.to(device), true, None, writer, "Validation")



    return epoch_acc, epoch_loss


In [25]:
def fine_tune(model, train_loader, validation_loader, criterion, optimizer, scheduler, early_stop, num_epochs=100, 
              log_name='run1', device='cuda', writer=None):
    best_model = copy.deepcopy(model)
    best_acc = 0.0
    best_epoch = 0
    stop = False

    for epoch in range(1, num_epochs + 1):
        if stop:
            break
        print(f'Epoch {epoch}/{num_epochs}')
        print('-'*120)

        # Training phase
        train_acc, train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch, log_name, writer)

        # Validation phase
        val_acc, val_loss = validate_one_epoch(model, validation_loader, criterion, device, epoch, num_epochs, log_name, writer)

        # Early stopping
        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_model = copy.deepcopy(model)

        early_stop(val_loss)
        stop = early_stop.early_stop

        print('-'*120)

    print(f'Best val Acc: {best_acc:4f}')
    print(f'Best epoch: {best_epoch:03d}')
    
    # Save the best model state when training stops
    model_path = os.path.join('saved_models/convnext_model', 'RealArt_vs_FakeArt_convnext_base.pt')
    # torch.save(best_model.state_dict(), model_path)
    print(f"Model saved to: {model_path}")

    return best_model

In [26]:
model_path = os.path.join('saved_models/convnext_model', 'RealArt_vs_FakeArt_convnext_base1.pt')
writer = SummaryWriter("runs/gradcam")

if not os.path.exists(model_path):
   print("Training a new model...")
   criterion = nn.CrossEntropyLoss()
   optimizer = optim.Adam(model.parameters(), lr=1e-3)
   scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
   early_stop = EarlyStopping(patience = 3, min_delta = 0.001)
   ImageFile.LOAD_TRUNCATED_IMAGES = True

   
   # Train and save the best model
   best_model_head = fine_tune(model, train_loader, validation_loader, criterion, optimizer, scheduler, 
                               early_stop, num_epochs = 30, writer = writer)
else: 
     # Model already exists; load it
     print("Loading pre-trained model...")
     model = timm.create_model('convnext_base',pretrained=True, num_classes=2) 
     model.load_state_dict(torch.load(model_path))
     model.to(device)
     print("Model loaded successfully from:", model_path)

Training a new model...
Epoch 1/30
------------------------------------------------------------------------------------------------------------------------


Epoch 1 Training:   0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 32/32 [00:03<00:00, 10.62it/s]


Train Loss: 0.5469 Acc: 0.8584


Epoch 1 Validation: 100%|██████████| 32/32 [00:02<00:00, 11.64it/s]


Validation Loss: 0.4760 Acc: 0.8750


OutOfMemoryError: CUDA out of memory. Tried to allocate 78.00 MiB. GPU 0 has a total capacity of 11.75 GiB of which 74.50 MiB is free. Including non-PyTorch memory, this process has 11.53 GiB memory in use. Of the allocated memory 11.06 GiB is allocated by PyTorch, and 154.30 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

TESTING

In [None]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# testing function
def test_model(model, test_loader):
    model.eval() # evaluation mode
    test_loss = 0
    correct = 0
    pred_list = []
    true_list = []
    misclassified_images = []
    misclassified_preds = []
    misclassified_true = []


    # progress bar
    pbar = tqdm(total=len(test_loader))
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item() # sommo il loss di ogni batch
            
            pred = output.argmax(dim=1, keepdim=True) # ottengo la predizione del modello
            pred_list.extend(pred.cpu().numpy()) # aggiungo la predizione alla lista
            true_list.extend(target.cpu().numpy()) # aggiungo il target alla lista
            
            correct += pred.eq(target.view_as(pred)).sum().item() # aggiorno il contatore di classificazioni corrette

            misclassified = ~pred.eq(target.view_as(pred)).squeeze()
            if misclassified.any():
                misclassified_images.extend(data[misclassified].cpu())
                misclassified_preds.extend(pred[misclassified].cpu().numpy())
                misclassified_true.extend(target[misclassified].cpu().numpy())

            # update progress bar
            pbar.update(1)
            
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    recall = recall_score(true_list, pred_list, average='macro')
    precision = precision_score(true_list, pred_list, average='macro') 
    f1 = f1_score(true_list, pred_list, average='macro') 
    auc = roc_auc_score(true_list, pred_list) 
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), Recall: {:.2f}%, Precision: {:.2f}%, F1: {:.2f}%, AUC: {:.2f}%\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy, recall*100, precision*100, f1*100, auc*100))
    
    log_misclassified_images(misclassified_images, misclassified_preds, misclassified_true, None, writer)

    if len(misclassified_images) > 0:
            for i, (image, pred, true) in enumerate(zip(misclassified_images, misclassified_preds, misclassified_true)):
                # Apply Grad-CAM
                apply_grad_cam(model, image.unsqueeze(0).to(device), pred, true, writer, "Testing")
                # Log Integrated Gradients
                log_integrated_gradients(model, image.to(device), target_label=true, baseline = None,writer=writer, log_name="Testing")

    writer.flush()
    return accuracy, recall, precision, f1, auc


100%|██████████| 276/276 [00:40<00:00, 12.20it/s]

use this for logs:
tensorboard --logdir=./runs

In [None]:
accuracy,recall,precision,f1,auc = test_model(model,test_loader)
writer.close()

  text = f'True: {int(true_label)}, Pred: {int(pred_label)}'



Test set: Average loss: 0.0752, Accuracy: 8643/8816 (98.04%), Recall: 97.90%, Precision: 98.10%, F1: 98.00%, AUC: 97.90%

Logging misclassified images...
173


100%|██████████| 276/276 [00:25<00:00, 10.63it/s]
