In [4]:
import cv2 as cv
import cv2
import csv
import pandas as pd
import numpy as np
from tifffile import imread
import matplotlib.pyplot as plt
import os
import random
import numpy as np
from PIL import Image
from shutil import copyfile
from sklearn.model_selection import train_test_split
import time
from collections import OrderedDict
from sklearn.metrics import confusion_matrix
import nbformat
import ast

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchvision.transforms import ToTensor

In [None]:
def print_functions_and_variables(filename):
    """Print every top-level function defined in a notebook's code cells.

    Despite the name, only top-level ``def`` statements and their positional
    parameter names are reported; variables are not listed. Cells that fail to
    parse as plain Python are skipped with a message.

    Args:
        filename: Path to an ``.ipynb`` file readable by ``nbformat``.
    """
    with open(filename) as handle:
        notebook = nbformat.read(handle, as_version=4)

    total = 0
    code_cells = (c for c in notebook.cells if c.cell_type == "code")
    for cell in code_cells:
        try:
            tree = ast.parse(cell.source)
        except SyntaxError:
            # such as cells using IPython-only syntax
            print("SyntaxError when parsing cell. Skipping...")
            continue

        defs = [n for n in tree.body if isinstance(n, ast.FunctionDef)]
        total += len(defs)
        for fn in defs:
            names = [a.arg for a in fn.args.args]
            print(f"- Function: {fn.name}, Parameters: {names}")

    print("\nTotal number of functions:", total)

In [None]:
def obtain_tifs(in_dir):
    """Return the paths of all ``.tif`` files directly inside ``in_dir``.

    Each path is ``os.path.join(in_dir, name)``. The order follows
    ``os.listdir`` (unsorted). Only names ending in exactly ``.tif`` match.
    """
    return [
        os.path.join(in_dir, name)
        for name in os.listdir(in_dir)
        if name.endswith('.tif')
    ]

In [31]:
def plot_median_intensity_histogram(directory, threshold, binsize, title):
    """Plot a histogram of per-image median intensities for the PNGs in a directory.

    Each ``.png`` is read as 8-bit grayscale. Pixels below ``threshold`` are
    discarded, and the median of the remaining pixels is recorded.

    Args:
        directory: Folder containing the ``.png`` images.
        threshold: Minimum pixel value kept when computing the median.
        binsize: Passed to ``plt.hist`` as ``bins``.
        title: Plot title.

    Two kinds of file are skipped:
    - files OpenCV cannot decode, since ``cv2.imread`` returns None for them;
    - images with no pixel at or above ``threshold``, whose median would be NaN
      and can make ``plt.hist`` fail.
    """
    median_intensities = []

    for filename in os.listdir(directory):
        if not filename.endswith(".png"):
            continue
        img_path = os.path.join(directory, filename)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print("Could not read", img_path, "- skipped")
            continue

        # Remove intensity values less than the threshold
        foreground = img[img >= threshold]
        if foreground.size == 0:
            continue

        median_intensities.append(np.median(foreground))

    plt.hist(median_intensities, bins=binsize)
    plt.title(title)
    plt.xlabel("Median Intensity")
    plt.ylabel("Frequency")
    plt.show()

In [32]:
# Segments cells based on threshold - 25 is good, images are normalized and scaled
def segment_binary_cells(image_input, channel, threshold):
    """Select one channel, rescale it to 0-255 uint8, and binarize it.

    Args:
        image_input: Channel-first array, indexed as ``image_input[channel, :, :]``.
        channel: Index of the channel to segment.
        threshold: Pixels strictly greater than this become 255 and the rest
            become 0 (``cv2.THRESH_BINARY``).

    Returns:
        ``(img_array, segmented_img)``: the rescaled uint8 channel and its binary mask.

    A channel whose maximum is 0 is returned as all zeros. Dividing by that
    zero maximum would otherwise give NaNs, and casting NaN to uint8 is undefined.
    """
    # select channel of image
    image = image_input[channel, :, :]

    # Normalize and scale to 0-255 range, then convert to uint8
    peak = image.max()
    if peak == 0:
        img_array = np.zeros(image.shape, dtype=np.uint8)
    else:
        img_array = ((image / peak) * 255).astype(np.uint8)

    # segment image
    _, segmented_img = cv2.threshold(img_array, threshold, 255, cv2.THRESH_BINARY)
    return img_array, segmented_img

In [None]:
def segment_dilate_binary_cells(image_input, channel, threshold, kernel_size, num_dilation):
    """Binarize one channel like ``segment_binary_cells``, then close and dilate the mask.

    Args:
        image_input: Channel-first image array, indexed ``image_input[channel, :, :]``.
        channel: Channel to segment.
        threshold: Binary threshold applied to the 0-255 rescaled channel.
        kernel_size: Side length of the square structuring element.
        num_dilation: Number of dilation iterations.

    Returns:
        ``(img_array, dilated)``: the rescaled uint8 channel and the binary mask
        after closing and dilation.
    """
    # Normalization and thresholding are exactly what segment_binary_cells does.
    # Reuse it instead of keeping a second copy of that logic in sync.
    img_array, segmented_img = segment_binary_cells(image_input, channel, threshold)

    # Square structuring element for the morphological operations
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # Closing operation to fill small holes
    closing = cv2.morphologyEx(segmented_img, cv2.MORPH_CLOSE, kernel)

    # Dilation operation to merge nearby contours
    dilated = cv2.dilate(closing, kernel, iterations=num_dilation)

    return img_array, dilated

In [2]:
def visualize_segmentation_binary(image_input, channel, threshold):
    """Segment one channel and show the rescaled image beside its binary mask.

    Returns the same ``(img_array, segmented_img)`` pair as ``segment_binary_cells``.
    """
    result = segment_binary_cells(image_input, channel, threshold)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panel_titles = ('Original Image', 'Binary Segmented Image')
    for panel_ax, panel_img, panel_title in zip(axes, result, panel_titles):
        panel_ax.imshow(panel_img, cmap='gray')
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    plt.show()

    return result

In [34]:
# Find contours in the binary image and filter for size and touching image border
def find_contours(thresh_img, contour_size_limit):
    """Return external contours that are large enough and do not touch the image border.

    Args:
        thresh_img: Single-channel binary image, such as the mask from ``cv2.threshold``.
        contour_size_limit: Contours with ``cv2.contourArea`` <= this value are dropped.

    Returns:
        List of the contours (OpenCV point arrays) that pass both filters.
    """
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; index -2 is the contours in both.
    contours = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    height, width = thresh_img.shape[:2]
    filtered_contours = []
    for contour in contours:
        points = contour.reshape(-1, 2)  # one (x, y) row per point
        xs, ys = points[:, 0], points[:, 1]

        # Exclude contours touching the edge of the image
        touches_border = (
            (xs == 0).any() or (ys == 0).any()
            or (xs == width - 1).any() or (ys == height - 1).any()
        )
        if touches_border:
            continue

        # Exclude contours whose area does not exceed contour_size_limit
        if cv2.contourArea(contour) > contour_size_limit:
            filtered_contours.append(contour)

    return filtered_contours

In [35]:
# Function to calculate the maximum convexity defect depth for a contour
def max_convexity_defect_depth(contour):
    """Return the largest convexity-defect depth of ``contour``.

    The value is the raw depth column (index 3) of ``cv2.convexityDefects``.
    Return values:
    - 0 when OpenCV reports no defects;
    - ``float('inf')`` when computing the hull or the defects raises, so any
      ``depth < threshold`` test rejects such contours.
    """
    # some contours trigger errors in calculating defects
    try:
        hull_indices = cv2.convexHull(contour, returnPoints=False)
        defects = cv2.convexityDefects(contour, hull_indices)
    except Exception:
        print("Convexity error occured, contour is skipped")
        return float('inf')

    if defects is None:
        return 0

    return max(row[0][3] for row in defects)

In [36]:
# Function to calculate the aspect ratio of the fitted ellipse for a contour
def aspect_ratio(contour):
    """Return the major/minor axis ratio (>= 1) of the ellipse fitted to ``contour``.

    ``cv2.fitEllipse`` returns ``((cx, cy), (axis_a, axis_b), angle)``. The two
    axis lengths are not ordered, so they are sorted here. OpenCV requires at
    least 5 points and raises otherwise.

    Returns ``float('inf')`` for a degenerate ellipse whose minor axis is 0.
    """
    ellipse = cv2.fitEllipse(contour)

    # sorted() is ascending, so the SMALLER axis comes first. Unpacking it as
    # (major, minor) would return minor/major, the inverse of the intended ratio.
    minor_axis, major_axis = sorted(ellipse[1])

    if minor_axis == 0:
        return float('inf')
    return major_axis / minor_axis

In [37]:
# Function to calculate the circularity of a closed contour
def circularity(contour):
    """Return the circularity ``4*pi*area / perimeter**2`` of a closed contour.

    The value is close to 1.0 for a circle and smaller for elongated or
    irregular shapes. Degenerate contours (zero area or zero perimeter) return
    0.0. Previously the result was assigned only inside the ``if``, so these
    inputs raised ``UnboundLocalError``.
    """
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)

    if area == 0 or perimeter == 0:
        return 0.0

    return 4 * np.pi * area / (perimeter * perimeter)

In [38]:
# Function to extract individual cells as images
def extract_cell_image(image, contour, padding=2):
    """Crop the bounding box of ``contour`` from ``image``, with a small margin.

    Args:
        image: Array the contour was found in; rows are indexed first, then columns.
        contour: OpenCV contour (point array).
        padding: Extra pixels added on each side of the bounding box. The
            default of 2 is the previously hard-coded margin.

    Returns:
        The cropped region, as a numpy slice of ``image``.

    The start of the crop is clamped at 0. Without the clamp, a contour within
    ``padding`` pixels of the top or left edge gives a negative start index.
    Numpy counts a negative index from the far end, so the slice usually comes
    out empty.
    """
    x, y, w, h = cv2.boundingRect(contour)

    top = max(y - padding, 0)
    left = max(x - padding, 0)
    # End indices past the image size are clipped by slicing itself.
    return image[top:y + h + padding, left:x + w + padding]


In [39]:
# Iterates over all images to produce single images
# To prevent non specific cell images, only DAPI is used for drawing contours
def extract_cells(image_paths, segment_threshold, max_depth_threshold, min_size, out_dir_dapi, out_dir_ab):
    """Crop single cells from multi-channel TIFFs and write one PNG per channel.

    Contours are found on channel 0 (DAPI) only and applied to both channel 0
    and channel 1 (antibody). A contour is treated as a single cell when its
    deepest convexity defect is below ``max_depth_threshold``. Otherwise it is
    counted as a multi-cell clump and skipped.

    Args:
        image_paths: Paths of the input TIFF files (read with ``tifffile.imread``).
        segment_threshold: Threshold passed to ``segment_binary_cells``.
        max_depth_threshold: Maximum convexity-defect depth for a single cell.
        min_size: Minimum contour area passed to ``find_contours``.
        out_dir_dapi: Output folder for the DAPI crops, created if missing.
        out_dir_ab: Output folder for the antibody crops, created if missing.

    Output files are named ``"<image index>.<cell index>.png"``. The same name
    is used in both folders. Totals are printed at the end.
    """
    # for counting written cells and rejected clumps
    total_sc_count = 0
    total_mc_count = 0

    os.makedirs(out_dir_dapi, exist_ok=True)
    os.makedirs(out_dir_ab, exist_ok=True)

    img_count = 0
    for img_path in image_paths:
        img = imread(img_path)

        img_array_DAPI, segmented_DAPI = segment_binary_cells(img, 0, segment_threshold)
        # Only the rescaled antibody channel is needed; its mask is unused.
        img_array_ab, _ = segment_binary_cells(img, 1, segment_threshold)

        contours_DAPI = find_contours(segmented_DAPI, min_size)

        # filtering based on convexity defect magnitude and writing to dir
        sc_count = 0
        mc_count = 0
        for contour in contours_DAPI:
            max_depth = max_convexity_defect_depth(contour)

            if max_depth < max_depth_threshold:
                filename = f'{img_count}.{sc_count+1}.png'

                single_cell_dapi = extract_cell_image(img_array_DAPI, contour)
                single_cell_ab = extract_cell_image(img_array_ab, contour)
                output_path_dapi = os.path.join(out_dir_dapi, filename)
                output_path_ab = os.path.join(out_dir_ab, filename)

                # Two failure modes count as "not written":
                # - an empty crop, which cannot be encoded;
                # - cv2.imwrite returning False, since it can fail without raising.
                try:
                    written = (
                        single_cell_dapi.size > 0
                        and single_cell_ab.size > 0
                        and cv2.imwrite(output_path_dapi, single_cell_dapi)
                        and cv2.imwrite(output_path_ab, single_cell_ab)
                    )
                except Exception:
                    written = False

                if written:
                    sc_count += 1
                else:
                    print("Unable to write cell image from", img_path)
            else:
                mc_count += 1

        total_sc_count += sc_count
        total_mc_count += mc_count
        img_count += 1
    print('Finished writing ', total_sc_count, ' single cell images from ', img_count, ' images')
    print('Removed ',  total_mc_count, ' multi cell images')

In [40]:
def extract_maskedbg_cell(image_paths, segment_threshold, max_depth_threshold, min_contour_size, out_dir_dapi, out_dir_ab):
    """Crop single cells with the background masked out and write one PNG per channel.

    Contours are found on channel 0 (DAPI). For every contour whose deepest
    convexity defect is below ``max_depth_threshold``:
    1. A filled mask of just that contour is built.
    2. The mask is applied to the whole channel-0 and channel-1 images, which
       zeroes every pixel outside the contour, including neighbouring cells
       inside the bounding box.
    3. The masked images are cropped to the contour's bounding box.
    Deeper contours are counted as multi-cell clumps and skipped.

    Args:
        image_paths: Input TIFF paths (read with ``tifffile.imread``).
        segment_threshold: Threshold passed to ``segment_binary_cells``.
        max_depth_threshold: Maximum convexity-defect depth for a single cell.
        min_contour_size: Minimum contour area passed to ``find_contours``.
        out_dir_dapi: Output folder for the channel-0 crops, created if missing.
        out_dir_ab: Output folder for the channel-1 crops, created if missing.

    Files are named ``"<image index>.<cell index>.png"`` in both folders.
    """
    total_sc_count = 0
    total_mc_count = 0

    os.makedirs(out_dir_dapi, exist_ok=True)
    os.makedirs(out_dir_ab, exist_ok=True)

    img_count = 0
    for img_path in image_paths:
        img = imread(img_path)

        img_array_DAPI, segmented_DAPI = segment_binary_cells(img, 0, segment_threshold)
        # Only the rescaled antibody channel is needed; its mask is unused.
        img_array_ab, _ = segment_binary_cells(img, 1, segment_threshold)

        contours_DAPI = find_contours(segmented_DAPI, min_contour_size)

        sc_count = 0
        mc_count = 0
        for contour in contours_DAPI:
            max_depth = max_convexity_defect_depth(contour)

            if max_depth < max_depth_threshold:
                filename = f'{img_count}.{sc_count+1}.png'

                # Filled mask of only this contour (thickness=-1 fills the interior)
                mask = np.zeros_like(img_array_DAPI)
                cv2.drawContours(mask, [contour], 0, 255, thickness=-1)

                # Zero everything outside the contour, then crop to its bounding box
                whole_img_mask_dapi = cv2.bitwise_and(img_array_DAPI, mask)
                whole_img_mask_ab = cv2.bitwise_and(img_array_ab, mask)
                masked_single_cell_dapi = extract_cell_image(whole_img_mask_dapi, contour)
                masked_single_cell_ab = extract_cell_image(whole_img_mask_ab, contour)

                output_path_dapi = os.path.join(out_dir_dapi, filename)
                output_path_ab = os.path.join(out_dir_ab, filename)

                # Two failure modes count as "not written":
                # - an empty crop, which cannot be encoded;
                # - cv2.imwrite returning False, since it can fail without raising.
                try:
                    written = (
                        masked_single_cell_dapi.size > 0
                        and masked_single_cell_ab.size > 0
                        and cv2.imwrite(output_path_dapi, masked_single_cell_dapi)
                        and cv2.imwrite(output_path_ab, masked_single_cell_ab)
                    )
                except Exception:
                    written = False

                if written:
                    sc_count += 1
                else:
                    print("Unable to write cell image from", img_path)
            else:
                mc_count += 1

        total_sc_count += sc_count
        total_mc_count += mc_count
        img_count += 1
    print('Finished writing ', total_sc_count, ' single cell images with BG mask from ', img_count, ' images')
    print('Removed ',  total_mc_count, ' multi cell images')

In [41]:
def extract_maskedbg_cell_merge(image_paths, segment_threshold, max_depth_threshold, out_dir_combined, min_contour_size=500):
    """Crop background-masked single cells and save both channels in one ``.npy`` file.

    Selection and masking follow the per-channel variant: contours come from
    channel 0 (DAPI), and everything outside the contour is zeroed. The
    channel-0 and channel-1 crops are then stacked on the last axis into one
    ``(H, W, 2)`` array.

    Args:
        image_paths: Input TIFF paths (read with ``tifffile.imread``).
        segment_threshold: Threshold passed to ``segment_binary_cells``.
        max_depth_threshold: Maximum convexity-defect depth for a single cell.
        out_dir_combined: Output folder, created if missing.
        min_contour_size: Minimum contour area passed to ``find_contours``. The
            default of 500 is the previously hard-coded value.

    ``np.save`` appends ``.npy``, so files are named ``"<image index>.<cell index>.npy"``.
    """
    total_sc_count = 0
    total_mc_count = 0

    os.makedirs(out_dir_combined, exist_ok=True)

    img_count = 0
    for img_path in image_paths:
        img = imread(img_path)

        img_array_DAPI, segmented_DAPI = segment_binary_cells(img, 0, segment_threshold)
        # Only the rescaled antibody channel is needed; its mask is unused.
        img_array_ab, _ = segment_binary_cells(img, 1, segment_threshold)

        contours_DAPI = find_contours(segmented_DAPI, min_contour_size)

        sc_count = 0
        mc_count = 0
        for contour in contours_DAPI:
            max_depth = max_convexity_defect_depth(contour)
            if max_depth < max_depth_threshold:
                filename = f'{img_count}.{sc_count+1}'

                # Filled mask of only this contour (thickness=-1 fills the interior)
                mask = np.zeros_like(img_array_DAPI)
                cv2.drawContours(mask, [contour], 0, 255, thickness=-1)

                # Zero everything outside the contour, then crop to its bounding box
                whole_img_mask_dapi = cv2.bitwise_and(img_array_DAPI, mask)
                whole_img_mask_ab = cv2.bitwise_and(img_array_ab, mask)
                masked_single_cell_dapi = extract_cell_image(whole_img_mask_dapi, contour)
                masked_single_cell_ab = extract_cell_image(whole_img_mask_ab, contour)

                # Both crops come from same-shaped images and the same box, so they stack
                combined_image = np.stack((masked_single_cell_dapi, masked_single_cell_ab), axis=-1)

                if combined_image.size == 0:
                    # np.save would happily store an empty array; treat it as a failure.
                    print("Unable to write cell image from", img_path)
                    continue

                output_path_combined = os.path.join(out_dir_combined, filename)
                try:
                    np.save(output_path_combined, combined_image)
                    sc_count += 1
                except Exception:
                    print("Unable to write cell image from", img_path)
            else:
                mc_count += 1

        total_sc_count += sc_count
        total_mc_count += mc_count
        img_count += 1

    print('Finished writing ', total_sc_count, ' single cell images with BG mask from ', img_count, ' images')
    print('Removed ',  total_mc_count, ' multi cell images')

In [42]:
# shows n random images from the specified directory and prints the total number of images
def show_n_random_images(img_dir, n, row_length):
    """Display ``n`` randomly chosen PNGs from ``img_dir`` in a grid.

    Args:
        img_dir: Folder to sample ``.png`` files from.
        n: Number of images to show. It is clamped to the number of PNGs
            available, because ``random.sample`` raises when asked for more.
        row_length: Number of images per grid row.

    Returns:
        The sampled file names, relative to ``img_dir``.
    """
    png_files = [f for f in os.listdir(img_dir) if f.endswith('.png')]

    n = min(n, len(png_files))
    random_png_files = random.sample(png_files, n)

    print('Displaying ', n, ' images out of ', len(png_files), ' in ', img_dir)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows
        return random_png_files

    nrows = (n + row_length - 1) // row_length
    # squeeze=False keeps a 2-D array of Axes even for a 1x1 grid, where
    # plt.subplots would otherwise return a single Axes without .ravel()
    fig, axs = plt.subplots(nrows, row_length, figsize=(row_length*2.5, 15), squeeze=False)
    axs = axs.ravel()

    for ax, file in zip(axs, random_png_files):
        # Context manager closes the file handle that Image.open keeps open
        with Image.open(os.path.join(img_dir, file)) as img:
            pixels = np.array(img)
        ax.imshow(pixels, cmap='gray')
        ax.set_title(file)
        ax.axis('off')

    # Hide the empty grid cells left when n is not a multiple of row_length
    for ax in axs[n:]:
        ax.axis('off')

    plt.show()
    return random_png_files

In [43]:
def show_images(img_dir, filepaths):
    """Display the given images in a 5-column grid.

    Args:
        img_dir: Folder the entries of ``filepaths`` are relative to.
        filepaths: Image file names or relative paths to display.
    """
    n = len(filepaths)

    print('Displaying ', n, ' images in ', img_dir)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows
        return

    fig, axs = plt.subplots((n+4)//5, 5, figsize=(15,15), squeeze=False)
    axs = axs.ravel()

    for ax, file in zip(axs, filepaths):
        # Context manager closes the file handle that Image.open keeps open
        with Image.open(os.path.join(img_dir, file)) as img:
            pixels = np.array(img)
        ax.imshow(pixels, cmap='gray')
        ax.set_title(file)
        ax.axis('off')

    # Hide the empty grid cells in the last row
    for ax in axs[n:]:
        ax.axis('off')

    plt.show()
    

In [44]:
def display_batch(train_loader):
    """Show the first batch of ``train_loader`` in a grid of at most 8 columns.

    Images are transposed from channel-first to channel-last for display.
    Values are clipped to [0, 1] before scaling to uint8, so normalized inputs
    saturate instead of wrapping around in the integer cast. Single-channel
    images are shown as 2-D grayscale, because ``imshow`` rejects an (H, W, 1) array.
    """
    images, _ = next(iter(train_loader))
    num_images = images.size(0)

    nrows = (num_images + 7) // 8
    ncols = min(num_images, 8)

    # squeeze=False always gives a 2-D Axes array, including the 1x1 case
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)

    for idx in range(num_images):
        ax = axes[idx // 8, idx % 8]
        img = images[idx].numpy().transpose((1, 2, 0))
        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        if img.shape[-1] == 1:
            img = img[..., 0]
        ax.imshow(img, cmap='gray')  # cmap only affects 2-D images
        ax.axis('off')

    # Hide unused cells in the last row (only possible when ncols == 8)
    for idx in range(num_images, nrows * ncols):
        axes[idx // 8, idx % 8].axis('off')

    plt.show()

In [45]:
def make_train_test_data(train_dir, test_dir, data_dir, file_num_limit):
    """Split the PNGs in ``data_dir`` 90/10 and copy them into per-class folders.

    Files are copied to ``<train_dir>/<name of data_dir>`` and
    ``<test_dir>/<name of data_dir>``, so the folder name of ``data_dir``
    becomes the class sub-folder. This is the layout ``datasets.ImageFolder``
    reads.

    Args:
        train_dir: Root of the training set.
        test_dir: Root of the test set.
        data_dir: Folder holding the ``.png`` files of one class.
        file_num_limit: Maximum number of training files to copy, drawn as a
            random subset. If it exceeds the size of the training split, every
            training file is copied instead of ``random.sample`` raising
            ``ValueError``.

    The test split always contains the full 10% held out.
    """
    # normpath strips a trailing separator, whose basename would otherwise be "".
    base_data_dir = os.path.basename(os.path.normpath(data_dir))
    # Sort so the seeded split does not depend on the arbitrary os.listdir order.
    image_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.png'))

    # Split the images into training and test sets
    train_files, test_files = train_test_split(image_files, test_size=0.1, random_state=42)
    train_files_subset = random.sample(train_files, min(file_num_limit, len(train_files)))

    # Create the class folders inside the training and test roots
    train_dir = os.path.join(train_dir, base_data_dir)
    test_dir = os.path.join(test_dir, base_data_dir)
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Entries from os.listdir are already bare file names
    for f in train_files_subset:
        copyfile(os.path.join(data_dir, f), os.path.join(train_dir, f))

    for f in test_files:
        copyfile(os.path.join(data_dir, f), os.path.join(test_dir, f))

    print('done')

In [46]:
def transform_load_data(train_dir, test_dir, batchsize, transform_x, transform_y, mean, std):
    """Build train/test ``DataLoader``s over ImageFolder directories, with normalization.

    Both splits are resized with ``Resize((transform_x, transform_y))``, which
    torchvision reads as (height, width). They are then converted with
    ``ToTensor`` and normalized with ``mean``/``std``. The training split is
    also augmented with random horizontal flips and ``RandomRotation(10)``.

    Args:
        train_dir, test_dir: ImageFolder roots with one sub-folder per class.
        batchsize: Batch size for both loaders.
        transform_x, transform_y: Target size passed to ``Resize``.
        mean, std: Per-channel statistics for ``Normalize``.

    Returns:
        ``(train_loader, test_loader)``. Both loaders shuffle every epoch.
    """
    train_transforms = transforms.Compose([
        transforms.Resize((transform_x, transform_y)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # No augmentation for evaluation
    test_transforms = transforms.Compose([
        transforms.Resize((transform_x, transform_y)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
    test_data = datasets.ImageFolder(test_dir, transform=test_transforms)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batchsize, shuffle=True)

    return train_loader, test_loader

In [47]:
def transform_load_data_nonorm(train_dir, test_dir, batchsize, transform_x, transform_y):
    """Build train/test ``DataLoader``s over ImageFolder directories, without ``Normalize``.

    Both splits are resized to ``(transform_x, transform_y)`` and converted with
    ``ToTensor``. The training split is also randomly flipped horizontally and
    rotated with ``RandomRotation(10)``.

    Returns:
        ``(train_loader, test_loader)``, both with ``shuffle=True``.
    """
    size = (transform_x, transform_y)

    train_pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    test_pipeline = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    loaders = []
    for folder, pipeline in ((train_dir, train_pipeline), (test_dir, test_pipeline)):
        dataset = datasets.ImageFolder(folder, transform=pipeline)
        loaders.append(torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=True))

    train_loader, test_loader = loaders
    return train_loader, test_loader

In [48]:
def calculate_mean_and_std(train_dir, resize=(224, 224), batch_size=64):
    """Compute the per-channel pixel mean and standard deviation of an ImageFolder dataset.

    Images are resized to ``resize`` and converted with ``ToTensor``. The
    statistics are then accumulated over every pixel of every image.

    Args:
        train_dir: ImageFolder root.
        resize: Size passed to ``transforms.Resize``. Default is the previously
            hard-coded 224x224.
        batch_size: Loader batch size; it only affects speed and memory.

    Returns:
        ``(mean, std)`` as 1-D float32 tensors with one entry per channel.

    Raises:
        ValueError: if the loader yields no pixels.

    The previous version averaged per-image standard deviations. That ignores
    variation *between* images, so it underestimates the dataset-wide std.
    Here the std comes from global sums of x and x**2, accumulated in float64.
    """
    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor()
    ])

    train_data = datasets.ImageFolder(train_dir, transform=transform)
    dataloader = DataLoader(dataset=train_data, batch_size=batch_size)

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for images, _ in dataloader:
        # (N, C, H, W) -> (C, N*H*W): each row holds every pixel of one channel
        flat = images.transpose(0, 1).reshape(images.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(0), dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        pixel_count += flat.size(1)

    if pixel_count == 0:
        raise ValueError(f"No images found in {train_dir}")

    mean = channel_sum / pixel_count
    # Var = E[x^2] - E[x]^2, clamped at 0 to absorb floating-point round-off
    var = (channel_sq_sum / pixel_count - mean * mean).clamp(min=0)
    return mean.float(), var.sqrt().float()

In [5]:
def _train_model_run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: training if ``optimizer`` is given, validation otherwise.

    Returns ``(mean per-batch loss, accuracy)``.
    """
    training = optimizer is not None
    # train()/eval() switch layers such as dropout and batch-norm between their
    # training and inference behaviour; validation must run in eval mode.
    model.train(training)

    running_loss = 0.0
    correct = 0
    seen = 0
    # Gradients are only needed for the training pass.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # index of the highest score per sample = predicted class
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            running_loss += loss.item()
            seen += labels.size(0)

    return running_loss / len(loader), correct / seen


# Trains the specified model for a set number of epochs
def train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs, csv_name, save_model_name=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: ``torch.nn.Module`` to train; it is updated in place. Batches are
            moved to the device of its first parameter.
        train_loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        test_loader: Same kind of iterable, used for validation.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        num_epochs: Number of passes over ``train_loader``.
        csv_name: CSV path, rewritten after every epoch with the history so far.
        save_model_name: Optional path ending in ``.pt``. The ``state_dict`` is
            saved there whenever the validation loss reaches a new minimum.

    Returns:
        ``(train_losses, test_losses, train_acc, test_acc, time_list)``, with one
        entry per epoch. Losses are mean per-batch losses. ``time_list`` holds the
        cumulative seconds since training started.
    """
    # Move batches to wherever the model lives (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Validate the save path once, before training, rather than once per epoch.
    can_save = save_model_name is not None
    if can_save and not (isinstance(save_model_name, (str, os.PathLike))
                         and os.fspath(save_model_name).endswith('.pt')):
        print("Model name is in wrong form (must end with .pt); the model will not be saved")
        can_save = False

    # best validation loss seen so far among saved models, and its epoch
    valid_test = float('inf')
    best_valid_epoch = 0

    t1 = time.perf_counter()
    print('Begin model training for {} epochs'.format(num_epochs))

    train_losses = []
    test_losses = []
    train_acc = []
    test_acc = []
    time_list = []

    for epoch in range(num_epochs):
        new_train_loss, t_acc = _train_model_run_epoch(model, train_loader, criterion, device, optimizer)
        new_valid_loss, v_acc = _train_model_run_epoch(model, test_loader, criterion, device)

        # t2 - t1 is the time elapsed since training started (cumulative)
        t2 = time.perf_counter()
        print('Training loss for Epoch {} is {:.4f} and Training accuracy is {:.3f}'.format(epoch + 1, new_train_loss, t_acc))
        print('Validation loss for Epoch {} is {:.4f} and Validation accuracy is {:.3f}'.format(epoch + 1, new_valid_loss, v_acc))
        print('Completed Epoch {} in {:.1f} seconds'.format(epoch + 1, t2-t1))

        train_acc.append(t_acc)
        test_acc.append(v_acc)
        train_losses.append(new_train_loss)
        test_losses.append(new_valid_loss)
        time_list.append(t2-t1)

        # Save the model if the validation loss has decreased
        if can_save:
            if new_valid_loss < valid_test:
                torch.save(model.state_dict(), save_model_name)
                print('Test loss improvement ({:.4f} -----> {:.4f}), model saved as {}'.format(valid_test, new_valid_loss, save_model_name))
                valid_test = new_valid_loss
                best_valid_epoch = epoch + 1
            else:
                # .format must be applied to the string, not to print()'s None result
                print('No improvement in test loss, best model saved at Epoch {} with validation loss of {:.4f}'.format(best_valid_epoch, valid_test))

        # Rewrite the full history each epoch so the CSV stays current if training is interrupted
        with open(csv_name, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Training Loss', 'Test Loss', 'Train Accuracy', 'Test Accuracy', 'Time'])
            writer.writerows(zip(train_losses, test_losses, train_acc, test_acc, time_list))

    return train_losses, test_losses, train_acc, test_acc, time_list

In [50]:
# Adapted from a ChatGPT suggestion
def show_images_from_loader(loader, labels=None, num_batches=1, batch_size=20):
    """Plot images from the first ``num_batches`` batches of ``loader``, ``batch_size`` per row.

    Args:
        loader: Iterable yielding ``(image_batch, label_batch)`` tensors, with
            channel-first images.
        labels: Optional iterable of label values; only images with these labels are shown.
        num_batches: Number of batches to pull from ``loader``.
        batch_size: Images per plotted row. At most ``batch_size * num_batches``
            images are shown. This does not need to match the loader's own batch size.
    """
    images = []
    actual_labels = []
    for i, (image_batch, label_batch) in enumerate(loader):
        if i == num_batches:
            break
        images.append(image_batch)
        actual_labels.append(label_batch)

    if not images:
        # torch.cat cannot concatenate an empty list
        print('Loader yielded no batches')
        return

    images = torch.cat(images, dim=0)
    actual_labels = torch.cat(actual_labels, dim=0)

    # Keep only images whose label is one of the requested labels
    if labels is not None:
        mask = torch.zeros_like(actual_labels, dtype=torch.bool)
        for label in labels:
            mask = mask | (actual_labels == label)
        images = images[mask]
        actual_labels = actual_labels[mask]

    num_images = min(batch_size*num_batches, len(images))
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows
        print('No images to display')
        return

    num_rows = (num_images + batch_size - 1) // batch_size
    # squeeze=False guarantees an ndarray of Axes (a 1x1 grid would otherwise be a bare Axes)
    fig, axes = plt.subplots(nrows=num_rows, ncols=batch_size, figsize=(12, 6*num_rows),
                             subplot_kw={'xticks': [], 'yticks': []}, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_images:
            img = images[i].permute(1, 2, 0).cpu().numpy()
            if img.shape[-1] == 1:
                # imshow rejects (H, W, 1); show single-channel images as 2-D grayscale
                img = img[..., 0]
            ax.imshow(img, cmap='gray')  # cmap only affects 2-D images
            ax.set_title(f'Label: {actual_labels[i].item()}')

    plt.show()


In [51]:
def plot_confusion_matrix(model, device, test_loader, classes, criterion=None):
    """Evaluate ``model`` on ``test_loader`` and plot a row-normalized confusion matrix.

    Args:
        model: Classifier producing one score per class along dim 1.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        classes: Class names in label-index order, used as tick labels.
        criterion: Optional loss function. When given, the mean per-batch test
            loss is printed too.

    Prints the accuracy, the raw confusion matrix and the per-class (row)
    totals. Cell colours show the natural log of the row-normalized values,
    while the annotations show the row-normalized values themselves. ``plt.show``
    is not called.
    """
    correct = 0
    total = 0
    test_loss = 0.0

    # Evaluate the model on the test set
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if criterion is not None:
                test_loss += criterion(outputs, labels).item()
            _, preds = torch.max(outputs, 1)
            y_true += labels.tolist()
            y_pred += preds.tolist()
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
    # A loss can only be reported when there is a criterion to compute it
    # ('%d' % None raised TypeError here before).
    if criterion is not None:
        print('Test loss of the network on the test images is:', test_loss / len(test_loader))

    # Confusion matrix over label indices 0 .. len(classes) - 1
    labels = list(range(len(classes)))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    print(cm)

    sum_by_row = cm.sum(axis=1).tolist()
    print("Sum by row:", sum_by_row)

    # Row-normalize. A class with no true samples keeps an all-zero row
    # instead of 0/0 = NaN.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    # Log-scaled colours spread out small off-diagonal fractions. log(0) = -inf
    # is an invalid value, which imshow draws in the colormap's "bad" colour.
    with np.errstate(divide='ignore'):
        log_cm = np.log(cm_norm)
    plt.imshow(log_cm, interpolation='nearest', cmap='Blues')
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'  # 2 decimal places
    thresh = cm_norm.max() / 3.
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            plt.text(j, i, format(cm_norm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm_norm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


In [52]:
def eval_model(model, device, test_loader, num_classes):
    """Evaluate a classifier on ``test_loader`` and return its confusion matrix.

    Prints three things:
    - the accuracy, truncated to a whole percent;
    - the confusion matrix over labels ``0 .. num_classes - 1``;
    - the number of true samples per class.

    ``model`` is switched to eval mode and left there.

    Returns:
        The array produced by ``confusion_matrix``.
    """
    model.eval()
    truth, predicted = [], []
    n_correct = n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)
            _, batch_preds = torch.max(model(batch_inputs), 1)
            truth.extend(batch_labels.tolist())
            predicted.extend(batch_preds.tolist())
            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()

    print('Accuracy of the network on the test images: %d %%' % (100 * n_correct / n_seen))

    cm = confusion_matrix(truth, predicted, labels=list(range(num_classes)))
    print(cm)

    row_totals = cm.sum(axis=1).tolist()
    per_class = OrderedDict(
        ('True Label ' + str(k), row_total)
        for k, row_total in zip(range(num_classes), row_totals)
    )
    print("Sum by row:", per_class)

    return cm

In [53]:
# Validation: evaluation metrics plus the confusion-matrix plot in one call
def eval_and_plot(model, criterion, device, test_loader, num_classes, class_labels):
    """Evaluate ``model`` on ``test_loader``, print metrics and plot a confusion matrix.

    Args:
        model: Classifier producing one score per class along dim 1.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        num_classes: Number of classes. Matrix rows and columns are label indices
            ``0 .. num_classes - 1``.
        class_labels: Display names for those indices, in index order. They are
            used for the ticks and the "Sum by row" keys.

    The matrix is built over ``range(num_classes)`` rather than ``class_labels``.
    ``confusion_matrix`` matches its ``labels`` against the integer targets, so
    class *names* there fail, and ``num_classes`` used to go unused. When
    ``class_labels == list(range(num_classes))`` the output is unchanged.
    Colours show the log of the row-normalized values; annotations show the
    values themselves.
    """
    correct = 0
    total = 0
    test_loss = 0

    # Evaluate the model on the test set
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            y_true += labels.tolist()
            y_pred += preds.tolist()
            test_loss += loss.item()
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    test_loss = test_loss / len(test_loader)
    print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
    print('Test loss of the network on the test images is:', test_loss)

    # Rows/columns are label indices; class_labels only supplies display names
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    print(cm)

    row_labels = ['True Label ' + str(i) for i in class_labels]
    sum_by_row = cm.sum(axis=1).tolist()
    row_dict = OrderedDict(zip(row_labels, sum_by_row))
    print("Sum by row:", row_dict)

    # Row-normalize. A class with no true samples keeps an all-zero row
    # instead of 0/0 = NaN.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    # Log-scaled colours. log(0) = -inf is an invalid value, which imshow draws
    # in the colormap's "bad" colour.
    with np.errstate(divide='ignore'):
        log_cm = np.log(cm_norm)
    plt.imshow(log_cm, interpolation='nearest', cmap='Blues')
    plt.colorbar()

    tick_marks = np.arange(len(class_labels))
    plt.xticks(tick_marks, class_labels, rotation=45)
    plt.yticks(tick_marks, class_labels)

    fmt = '.2f'
    thresh = cm_norm.max() / 3.
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            plt.text(j, i, format(cm_norm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm_norm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
