In [1]:
%matplotlib inline
import os, sys, re
import glob

import pandas as pd
import numpy as np
import torch
import torch.utils.data
import torch.nn

from random import randrange
from PIL import Image
import matplotlib.pyplot as plt

!pip install opencv-python -qqq
!pip install wandb -qqq
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvalenetjong[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import argparse
""" Training and hyperparameter search configurations """
curr_dir = os.getcwd()

parser = argparse.ArgumentParser(description='Final')
parser.add_argument('--img_dir', type=str, default='/Users/valenetjong/alzheimer-classification/data',
                    help='directory for image storage')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num_classes', type=int, default=3,
                    help='number of classes')
parser.add_argument('--loss', type=str, default="cross entropy",
                    help='cross entropy, focal')
parser.add_argument('--download_oasis', type=bool, default=False,
                    help="download oasis dataset from links if True, use already extracted files, if False")
parser.add_argument('--process_flag', type=bool, default=False,
                    help="extract files from disk if True, use already extracted files, if False")
parser.add_argument('--create_dataset', type=bool, default=False,
                    help="create dataset from scratch if True, load in processed dataset if False")
parser.add_argument('--transforms', type=str, default='all',
                    help='transforms for data augmentation')
parser.add_argument('--threshold', type=float, default=3e-4,
                    help='early stopping criterion')
args = parser.parse_args('')
# Set random seed to reproduce results
torch.manual_seed(args.seed)

<torch._C.Generator at 0x2332a063810>

In [3]:
""" Set-up wandb """
sweep_config = {
    'method': 'bayes'
    }

metric = {
    'name': 'max val acc',
    'goal': 'maximize'   
    }

sweep_config['metric'] = metric

params = {
    'max_epochs': {
        'value': 250
        },
    'hidden_size': {
        'values': [8, 16],
        },
    'fc_size': {
        'values': [32, 64, 128, 256, 512]
        },
    'conv_in_size': {
        'values': [32, 64, 128, 256]
        },
    'conv_hid_size': {
        'values': [8, 16, 32]
        },
    'conv_out_size': {
        'values': [8, 16, 32]
        },
    'dropout': {
          'values': [0.15, 0.2, 0.25, 0.3]
        },
    'batch_size': {
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 8,
        'max': 64,
        },
    'lr': {
        'values': [1e-3, 1e-4, 1e-5]
        },
    }

sweep_config['parameters'] = params
sweep_id = wandb.sweep(sweep_config, project="2D-masked-imgs")

Create sweep with ID: 476ov6p8
Sweep URL: https://wandb.ai/valenetjong/2D-masked-imgs/sweeps/476ov6p8


### Download Files

In [4]:
import requests
import os
import tarfile

def download_file(url, local_filename):
    """
    Downloads a file from a given URL and saves it to a local path.
    """
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_filename

def download_oasis1(base_dir="/Users/valenetjong/Downloads/"):
    base_url = "https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc"
    total_disks = 12

    for i in range(1, total_disks + 1):
        url = f"{base_url}{i}.tar.gz"
        local_filename = f"oasis_cross-sectional_disc{i}.tar.gz"
        full_file_path = os.path.join(base_dir, local_filename)

        # Check if the file already exists
        if os.path.exists(full_file_path):
            print(f"File {local_filename} already exists. Skipping download.")
            continue

        print(f"Downloading: {url}")
        
        try:
            download_file(url, full_file_path)
            print(f"Downloaded {local_filename}")
        except Exception as e:
            print(f"Failed to download {local_filename}: {e}")

def extract_tar_gz(tar_path, extract_to_path):
    """
    Extracts a .tar.gz file to a specified directory.
    """
    with tarfile.open(tar_path, 'r:gz') as tar:
        tar.extractall(path=extract_to_path)
        print(f"Extracted {tar_path} to {extract_to_path}")

def extract_all_discs(base_disc_path="/Users/valenetjong/Downloads/", 
                    extract_to_path="/Users/valenetjong/Downloads/"):
    total_disks = 12

    for i in range(1, total_disks + 1):
        if os.path.exists(extract_to_path + f"/disc{i}") and os.path.isdir(extract_to_path + f"/disc{i}"):
            print(f"Folder for disc{i} already exists. Skipping extraction.")
            continue
        tar_path = os.path.join(base_disc_path, f"oasis_cross-sectional_disc{i}.tar.gz")
        os.makedirs(extract_to_path, exist_ok=True)
        extract_tar_gz(tar_path, extract_to_path)

        # Remove the tar.gz file after extraction
        # os.remove(tar_path)
        # print(f"Removed the archive: {tar_path}")

In [5]:
if args.download_oasis:
    download_oasis1()

In [6]:
if args.process_flag:
    extract_all_discs()

### Pre-processing

In [7]:
import cv2 as cv
import tempfile
import shutil

""" Pre-processing Functions """

DEMENTIA_MAP = {
    '0.0': "nondemented",
    '0.5': "mildly demented",
    '1.0': 'moderately demented',
}

# Pre-determined max dimensions of cropped images
CONV_WIDTH = 137
CONV_HEIGHT = 167

def normalize_intensity(img):
    """
    Normalizes the intensity of an image to the range [0, 255].

    Parameters:
    img: The image to be normalized.

    Returns:
    Normalized image.
    """
    img_min = img.min()
    img_max = img.max()
    normalized_img = (img - img_min) / (img_max - img_min) * 255
    return normalized_img.astype(np.uint8)

def pad_image_to_size(img, width, height):
    """
    Pads an image with zeros to the specified width and height.

    Parameters:
    img: The image to be padded.
    width: The desired width.
    height: The desired height.

    Returns:
    Padded image.
    """
    padded_img = np.zeros((height, width), dtype=img.dtype)
    y_offset = (height - img.shape[0]) // 2
    x_offset = (width - img.shape[1]) // 2
    padded_img[y_offset:y_offset+img.shape[0], x_offset:x_offset+img.shape[1]] = img
    return padded_img

def crop_black_boundary(mri_image):
    """
    Crops the black boundary from an MRI image.

    Parameters:
    mri_image: Input MRI image.

    Returns:
    Cropped MRI image with black boundaries removed.
    """
    _, thresh = cv.threshold(mri_image, 1, 255, cv.THRESH_BINARY)
    contours, _ = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv.contourArea)
    x, y, w, h = cv.boundingRect(largest_contour)
    cropped_image = mri_image[y:y+h, x:x+w]
    return cropped_image

def extract_files(base_dir, target_dir, oasis_csv_path):
    """
    Extracts and processes MRI files from a given directory.

    Parameters:
    base_dir: Directory containing MRI files.
    target_dir: Directory where processed files will be saved.
    oasis_csv_path: Path to the CSV file containing metadata.
    """
    oasis_df = pd.read_csv(oasis_csv_path)

    for subdir in filter(lambda d: d != '.DS_Store', os.listdir(base_dir)):
        source_dir = os.path.join(base_dir, subdir, "FSL_SEG")
        print("source_dir", source_dir)
        num = subdir.split('_')[1]
        id = f'OAS1_{num}_MR1'
        num = int(num)
        row = oasis_df.loc[oasis_df['ID'] == id]
        dementia_type = row['CDR'].item()
        
        if pd.isna(dementia_type):
            continue

        for n_suffix in ['n3', 'n4']:
            fn = os.path.join(source_dir, f"{subdir}_mpr_{n_suffix}_anon_"
                                  f"111_t88_masked_gfc_fseg_tra_90.gif")
            if os.path.exists(fn):
                process_image(fn, target_dir, dementia_type, id)

def process_image(fn, target_dir, dementia_type, id):
    """
    Processes a single MRI image file and saves it to the target directory.

    Parameters:
    fn: Path of the file to be processed.
    target_dir: Directory where the processed file will be saved.
    dementia_type: Type of dementia associated with the image.
    id: Patient identifier associated with the image.
    """
    with Image.open(fn) as img:
        img = np.array(img.convert('RGB'))
        img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    img = crop_black_boundary(img)
    img = normalize_intensity(img)
    img = pad_image_to_size(img, CONV_WIDTH, CONV_HEIGHT)

    target_subdir = os.path.join(target_dir, DEMENTIA_MAP[str(dementia_type)])
    os.makedirs(target_subdir, exist_ok=True)
    target_path = os.path.join(target_subdir, f"{id}.png")
    cv.imwrite(target_path, img)

def process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path):
    """
    Processes all discs found in the base directory.

    Parameters:
    base_disc_path: Base path where the discs are located.
    base_extraction_path: Base path where processed data will be saved.
    oasis_csv_path: Path to the OASIS CSV file.
    """
    total_disks = 12

    for i in range(1, total_disks + 1):
        disc_path = f'{base_disc_path}/disc{i}'
        if not os.path.exists(disc_path):
            print(f"Disc {i} does not exist at path {disc_path}. Skipping.")
            continue
        extract_files(disc_path, base_extraction_path, oasis_csv_path)
        print(f"Processed Disc {i}")

def cleanup_directory(path):
    """
    Deletes a directory and all of its contents.

    Parameters:
    path: Path of the directory to be deleted.
    """
    try:
        shutil.rmtree(path)
        print(f"Cleaned up and deleted the directory: {path}")
    except OSError as e:
        print(f"Error: {e.filename} - {e.strerror}")

In [8]:
if args.process_flag:
    base_disc_path = '/Users/valenetjong/Downloads'
    base_extraction_path = '/Users/valenetjong/alzheimer-classification/data'
    oasis_csv_path = '/Users/valenetjong/alzheimer-classification/datacsv/oasis_cross-sectional.csv'
    process_all_discs(base_disc_path, base_extraction_path, oasis_csv_path)

In [9]:
import os
import torch
from torchvision import transforms
from PIL import Image
from collections import Counter

LABEL_MAP = {
    "nondemented": 0,
    "mildly demented": 1,
    'moderately demented': 1 if args.num_classes == 2 else 2
}

def load_dataset(base_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])    
    all_images = []
    all_labels = []
    class_counts = Counter()

    # Automatically find all subdirectories in base_dir
    for folder_name in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, folder_name)
        if os.path.isdir(folder_path):  # Check if it's a directory
            class_label = LABEL_MAP[folder_name]
            for image_file in os.listdir(folder_path):
                image_path = os.path.join(folder_path, image_file)
                if os.path.isfile(image_path):
                    with Image.open(image_path) as img:
                        img_tensor = transform(img)
                        all_images.append(img_tensor)
                        all_labels.append(class_label)
                        class_counts[folder_name] += 1

    X = torch.stack(all_images)
    y = torch.tensor(all_labels, dtype=torch.long)  # Changed to long for integer labels
    return X, y, class_counts

if args.create_dataset:
    X, y, class_counts = load_dataset(args.img_dir)

    print(f"Combined Tensor Size: {X.size()}")
    print(f"Labels Tensor Size: {y.size()}")
    print(f"Class Counts: {class_counts}")

In [10]:
import torch
from sklearn.model_selection import train_test_split

def train_val_split(X, y, test_size=0.2, random_state=42, stratified=True):
    # Convert X and y to numpy arrays if they are torch tensors
    X_np = X.numpy() if isinstance(X, torch.Tensor) else X
    y_np = y.numpy() if isinstance(y, torch.Tensor) else y

    # Stratified split
    if stratified:
        X_train, X_val, y_train, y_val = train_test_split(
            X_np, y_np, test_size=test_size, random_state=random_state, stratify=y_np
        )
    # Random split
    else:
        X_train, X_val, y_train, y_val = train_test_split(
            X_np, y_np, test_size=test_size, random_state=random_state
        )

    # Convert numpy arrays back to torch tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long)

    return X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor

if args.create_dataset:
    X_train, X_val, y_train, y_val = train_val_split(X, y, test_size=0.2)

    print(f'Training set size: {X_train.shape[0]}')
    print(f'Validation set size: {X_val.shape[0]}')

In [11]:
if args.create_dataset:
    print(f"Number of nondemented in train dataset as percentage: {((y_train == 0).sum() / (X_train.shape[0])) * 100:0.2f}%")
    print(f"Number of mildly demented in train dataset as percentage: {((y_train == 1).sum() / (X_train.shape[0])) * 100:0.2f}%")
    print(f"Number of moderately demented in train dataset as percentage: {((y_train == 2).sum() / (X_train.shape[0])) * 100:0.2f}%")

In [12]:
if args.create_dataset:
    print(f"Number of nondemented in train dataset as percentage: {((y_val == 0).sum() / (X_val.shape[0])) * 100:0.2f}%")
    print(f"Number of mildly demented in train dataset as percentage: {((y_val == 1).sum() / (X_val.shape[0])) * 100:0.2f}%")
    print(f"Number of moderately demented in train dataset as percentage: {((y_val == 2).sum() / (X_val.shape[0])) * 100:0.2f}%")

In [13]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
import random

""" Transforms w/ probability """
def custom_random_rotation(image, probability=0.25, min_degree=20, max_degree=40):
    if random.random() < probability:
        degrees = random.randint(min_degree, max_degree)
        return transforms.RandomRotation(degrees=degrees)(image)
    return image

def custom_random_resized_crop(image, probability=0.25, size=(CONV_HEIGHT, CONV_WIDTH), scale=(0.9, 1.0)):
    if random.random() < probability:
        return transforms.RandomResizedCrop(size=size, scale=scale)(image)
    return image

def custom_random_horizontal_flip(image, probability=0.25):
    if random.random() < probability:
        return transforms.RandomHorizontalFlip()(image)
    return image

def custom_random_affine(image, probability=0.25, translate=(0.1, 0.1), scale=None, shear=10):
    if random.random() < probability:
        return transforms.RandomAffine(degrees=0, translate=translate, scale=scale, shear=shear)(image)
    return image

def custom_color_jitter(image, probability=0.25, brightness=0.2, contrast=0.2):
    if random.random() < probability:
        return transforms.ColorJitter(brightness=brightness, contrast=contrast)(image)
    return image

def apply_transforms(X):
    transformed_data = []
    for x in X:
        x = custom_random_rotation(x)
        x = custom_random_resized_crop(x)
        x = custom_random_horizontal_flip(x)
        x = custom_random_affine(x)
        x = custom_color_jitter(x)
        transformed_data.append(x)
    return torch.stack(transformed_data)

def apply_all_transforms(X, transform):
    transformed_data = []
    for x in X:
        x = transform(x)  # Apply the transformation
        transformed_data.append(x)
    return torch.stack(transformed_data)

all_train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=20),
    transforms.RandomResizedCrop(size=(CONV_HEIGHT, CONV_WIDTH), scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2), # You can adjust the values for brightness and contrast/
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=None, shear=10),
])

In [14]:
if args.create_dataset:
    print(X_train.shape)

In [15]:
if args.create_dataset:
    num_augment = 1
    # augment_list += [apply_transforms(X_train) for _ in range(num_augment-1)]
    X_augmented = apply_all_transforms(X_train, all_train_transform)
    y_augmented = y_train

    print(X_augmented.shape)
    print(y_augmented.shape)

    torch.save(X_augmented, 'X_augmented.pt')
    torch.save(y_augmented, 'y_augmented.pt')
    torch.save(X_val, 'X_val.pt')
    torch.save(y_val, 'y_val.pt')

### Handle Disproportionate Classes

In [16]:
import torch
import torch.nn as nn
from collections import Counter

def calculate_class_weights(y_train):
    # Count the frequency of each class
    class_counts = Counter(y_train.numpy())
    total_samples = sum(class_counts.values())

    # Calculate weights: Inverse of frequency
    weights = {class_id: total_samples/class_counts[class_id] for class_id in class_counts}

    # Convert to a list in the order of class ids
    weights_list = [weights[i] for i in sorted(weights)]
    return torch.tensor(weights_list, dtype=torch.float32)

#### Define CNN Model

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Add skip connections 
# Number of conv. features should be correlated to number of segments
# Other transformation types ()

class DeepCNNModel(nn.Module):
    def __init__(self, fc_size, conv_in_size, conv_hid_size, conv_out_size, dropout, num_classes=3):
        super(DeepCNNModel, self).__init__()
        
        # Convolutional Block 1
        self.conv1 = nn.Conv2d(1, conv_in_size, kernel_size=3, padding=1)  
        self.bn1 = nn.BatchNorm2d(conv_in_size)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 2
        self.conv2 = nn.Conv2d(conv_in_size, conv_hid_size, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(conv_hid_size)
        self.pool2 = nn.MaxPool2d(kernel_size=3)
        
        # Convolutional Block 3
        self.conv3 = nn.Conv2d(conv_hid_size, conv_hid_size, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(conv_hid_size)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        
        # Convolutional Block 4
        self.conv4 = nn.Conv2d(conv_hid_size, conv_out_size, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(conv_out_size)
        self.pool4 = nn.MaxPool2d(kernel_size=3)

        # Compute the flattened size for the fully connected layer
        self._to_linear = None
        self._forward_conv(torch.randn(1, 1, 137, 167))

        # Fully connected layers
        self.fc1 = nn.Linear(self._to_linear, fc_size)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(fc_size, num_classes)
        self.dropout2 = nn.Dropout(p=dropout)

    def _forward_conv(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self._forward_conv(x)
        x = x.view(-1, self._to_linear)  # Flatten the output for the fully connected layers
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return F.log_softmax(x, dim=1)

In [18]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', num_classes=args.num_classes):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        else:
            if isinstance(alpha, (float, int)):
                self.alpha = torch.ones(num_classes) * alpha
            else:
                self.alpha = torch.tensor(alpha)
        self.alpha = self.alpha / self.alpha.sum()
        self.num_classes = num_classes

    def forward(self, inputs, targets):
        # Convert targets to one-hot
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).to(inputs.device)

        # Compute the log softmax
        log_softmax = F.log_softmax(inputs, dim=1)

        # Compute the loss per class
        loss_per_class = -targets_one_hot * log_softmax

        # Compute the focal loss factors
        softmax_probs = torch.exp(log_softmax)
        focal_factors = (1 - softmax_probs) ** self.gamma

        # Apply alpha weighting and focal factors
        alpha_factors = self.alpha.to(inputs.device).unsqueeze(0)
        loss = alpha_factors * focal_factors * loss_per_class

        # Sum over classes and compute the final loss based on reduction
        loss = loss.sum(dim=1)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [19]:
import logging

class WandbModelCheckpoint:
    def __init__(self, dirpath, decreasing=False, top_n=1):
        """
        dirpath: Directory path where to store all model weights 
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        if not os.path.exists(dirpath): os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n 
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.Inf if decreasing else -np.Inf
        
    def __call__(self, model, epoch, metric_val):
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        save = metric_val<self.best_metric_val if self.decreasing else metric_val>self.best_metric_val
        if save: 
            logging.info(f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
        if len(self.top_model_paths)>self.top_n: 
            self.cleanup()
    
    def log_artifact(self, filename, model_path, metric_val):
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)        
    
    def cleanup(self):
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]

In [20]:
checkpoint_dir = "./model_checkpoints"
checkpoint = WandbModelCheckpoint(checkpoint_dir, decreasing=False, top_n=1)

### Training and Validation

In [21]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn.functional as F

# Training Function
def train_model(config=None):
    with wandb.init(config=config):
        config = wandb.config
        # wandb.define_metric("epoch")
        # wandb.define_metric("val acc", step_metric="epoch")
        model = DeepCNNModel(config.fc_size, config.conv_in_size, config.conv_hid_size, config.conv_out_size, config.dropout, num_classes=3)
        optimizer = optim.Adam(model.parameters(), config.lr, weight_decay=0.0001)   
        batch_size = config.batch_size
        
        X_augmented = torch.load('X_augmented.pt')
        y_augmented = torch.load('y_augmented.pt')
        train_data = TensorDataset(X_augmented, y_augmented)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

        X_val = torch.load('X_val.pt')
        y_val = torch.load('y_val.pt')
        val_data = TensorDataset(X_val, y_val)
        val_loader = DataLoader(val_data, batch_size=batch_size)

        class_weights = calculate_class_weights(y_augmented)
        loss_function = nn.CrossEntropyLoss(weight=class_weights) if args.loss == 'cross entropy' else FocalLoss(alpha=class_weights)
        
        max_acc = 0
        for epoch in range(config.max_epochs):
            model.train()
            total_loss = 0
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                output = model(X_batch)
                loss = loss_function(output, y_batch)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                total_loss += batch_loss
            
            wandb.log({"batch loss": loss.item()})
            # Validation
            model.eval()
            with torch.no_grad():
                correct = 0
                total = 0
                for X_batch, y_batch in val_loader:
                    output = model(X_batch)
                    _, predicted = torch.max(output.data, 1)
                    total += y_batch.size(0)
                    correct += (predicted == y_batch).sum().item()
                    loss = loss_function(output, y_batch)
                    wandb.log({"val loss": loss.item()})
                    
                acc = 100 * correct / total
                
                wandb.log({"val acc": acc})
                if acc >= max_acc:
                    max_acc = acc
                    wandb.log({"max val acc": max_acc})
                    checkpoint(model, epoch, acc)

In [22]:
# Run training
wandb.agent(sweep_id, train_model, count=50)

[34m[1mwandb[0m: Agent Starting Run: 7kag5ezs with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	conv_hid_size: 32
[34m[1mwandb[0m: 	conv_in_size: 32
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 512
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.861 MB of 0.875 MB uploaded\r'), FloatProgress(value=0.9843679834676874, max=1.0…

0,1
batch loss,█▅▂▂▂▁▁▁▁▂▁▁▁▁▂▁▂▁▁▂▂▁▁▁▂▁▁▁▁▁▁▁▂▃▂▁▁▂▂▁
max val acc,▁▁▁▂██
val acc,▅█▁▄▅▅▅▅▅▅▅▄▅▅▄▄▅▄▅▃▅▅▅▃▄▄▄▄▄▅▅▅▅▄▄▄▅▅▄▅
val loss,▂▁▁▂▂▃▃▂▅▄▄▄▅▅▄▄▄▅▆▄▆▇▆▅▆▆▆▆▆▆▇▇▆▆▇▆▆█▆▇

0,1
batch loss,0.11038
max val acc,68.08511
val acc,57.44681
val loss,2.63053


[34m[1mwandb[0m: Agent Starting Run: ln89unni with config:
[34m[1mwandb[0m: 	batch_size: 48
[34m[1mwandb[0m: 	conv_hid_size: 32
[34m[1mwandb[0m: 	conv_in_size: 64
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 512
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch loss,█▅▃▂▁▂▁▁▁▁▁▁▁▁▁▂▁▂▁▂▁▂▁▁▁▁▂▁▂▂▁▂▁▁▁▁▁▁▁▁
max val acc,▁▁▁▂▂▄▄▅▇███
val acc,▅▂▁▄▇▆▇▇▇▇▇▇▇█▇▇▇▇▇▆▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇
val loss,▁▁▁▆▃▂▄▄▄▄▅▆▆▇▇▆▇▇▇██▇▇▇▇█▇█▇▇▇▇▇█▇█▇█▇▇

0,1
batch loss,0.09813
max val acc,68.08511
val acc,59.57447
val loss,2.47455


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 579o0yr3 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 256
[34m[1mwandb[0m: 	conv_out_size: 32
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	fc_size: 256
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08970942773759676, max=1.…

0,1
batch loss,█▇▆▆▅▅▃▃▂▂▂▂▁▁▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁
max val acc,▁▁▁▁▁▁▁▁▁▃████
val acc,▆▆▁▃▃▁▆▅█▆▆█▇▅█▃▇▃▇▇▇▃█▆▇▅▆▅█▇▇▃▇▆▇▇█▇▇▅
val loss,▂▂▁▁▁▁▁▁▁▂▂▂▂▂▂▄▂▃▃▃▂▄▃▂▃█▃▄▄▄▃▅▅▃▄▃▄▄▃▅

0,1
batch loss,0.11386
max val acc,63.82979
val acc,59.57447
val loss,1.70018


[34m[1mwandb[0m: Agent Starting Run: wshencnk with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	conv_hid_size: 32
[34m[1mwandb[0m: 	conv_in_size: 64
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 512
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.0897549828614955, max=1.0…

0,1
batch loss,█▅▂▁▁▁▂▁▂▂▂▁▂▁▁▁▁▁▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▂▂▁▁▂▁▁
max val acc,▁▁▁▁▁▇▇█████
val acc,▁▆█▆▅▆▆▅▆▆▆▆▆▅▆▅▆▆▆▆▆▅▅▅▆▆▅▅▆▅▅▆▆▆▆▅▅▆▆▅
val loss,▁▂▂▄▄▅▅▅▆▆▆▇▇▇▇▇▇▇▇█▇▇██████▇███▇█▇█████

0,1
batch loss,0.13032
max val acc,59.57447
val acc,51.06383
val loss,4.38878


[34m[1mwandb[0m: Agent Starting Run: cqbkmkm9 with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	conv_hid_size: 32
[34m[1mwandb[0m: 	conv_in_size: 256
[34m[1mwandb[0m: 	conv_out_size: 16
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	fc_size: 32
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.780 MB of 0.780 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch loss,█▇▅▃▃▂▂▂▂▃▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁
max val acc,▁▁▂▆▆▇▇▇███████
val acc,▂▆▆▇█▇█▃▇▁▇▆▇▇▆█▇▇▇▇█▅▆█▆█▇▆▇█▆█▇▇█▇▇█▅▇
val loss,▁▁▁▁▁▂▂▂▃▃▁▃▃▃▃▅▃▃▃▄▃▅▃▄▄▅▆▅▆▇▃▆▅▅▆▃▅▄█▄

0,1
batch loss,0.01558
max val acc,72.34043
val acc,63.82979
val loss,2.81182


[34m[1mwandb[0m: Agent Starting Run: famxyq07 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 32
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	fc_size: 64
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08976068050530057, max=1.…

0,1
batch loss,█▅▃▂▂▁▁▂▂▂▁▁▁▁▁▂▂▁▂▁▄▂▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▂
max val acc,▁▆█████
val acc,▅▆▇▇▇▇▇▇▆▆▇▇▇▇▇▇▇▇▇▆▁▇█▇▅▅▄▆▆▅▆▆▆▆▅▇▆▆▆▅
val loss,▁▁▂▃▃▂▂▅▅▆▂▅▆▆▇▃▆▇█▂▄▃▃▄▃▅▆▆▅▄▇▇▆▅▅▇▆▆▅▄

0,1
batch loss,0.07549
max val acc,63.82979
val acc,42.55319
val loss,2.60866


[34m[1mwandb[0m: Agent Starting Run: 9k9ojb90 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 256
[34m[1mwandb[0m: 	conv_out_size: 32
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	fc_size: 512
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08970373659836325, max=1.…

0,1
batch loss,█▇▆▅▄▅▃▃▃▂▁▃▂▂▂▂▁▁▂▁▂▂▁▂▃▁▂▂▂▁▂▂▂▁▃▂▁▂▂▂
max val acc,▁▁▁▂▂▇██████
val acc,▅▂▁▇▆▃▆█▇▇▆▇▇▆▄▆▅▆▇▇▆▅▅▅▅▆▄▅▇▆▆▅▄▆▅▅▄▅▅▅
val loss,▄▃▃▃▃▂▁▄▆▅▁▅▃▆▆▂▄▅▆▂▂▅▆▇▂▆▆█▇▃▆▇▇▃▃█▇█▃▄

0,1
batch loss,0.03535
max val acc,68.08511
val acc,61.70213
val loss,1.07914


[34m[1mwandb[0m: Agent Starting Run: 7f9mt1tu with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 32
[34m[1mwandb[0m: 	conv_in_size: 32
[34m[1mwandb[0m: 	conv_out_size: 16
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	fc_size: 128
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 1e-05
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08971511959901021, max=1.…

0,1
batch loss,█▇▇▇█▆▆▆▇▆▆▅▅▄▆▆▅▅▃▅▄▆▄▄▃▂▃▃▂▂▃▂▃▂▂▂▂▁▂▂
max val acc,▁▃▄▄▄▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██████████████
val acc,▂▁▁▃▃▅▅▅▇▆▆▆▆▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█▇▇███████▇█
val loss,▆█▆▆▆▆▆▅▆▆▄▅▅▆▆▄▄▄▆▃▃▄▆▆▂▅▅▆▆▂▅▅▇▂▁▅▇▇▁▁

0,1
batch loss,0.31948
max val acc,61.70213
val acc,55.31915
val loss,0.95443


[34m[1mwandb[0m: Agent Starting Run: 9xox0n0o with config:
[34m[1mwandb[0m: 	batch_size: 24
[34m[1mwandb[0m: 	conv_hid_size: 8
[34m[1mwandb[0m: 	conv_in_size: 128
[34m[1mwandb[0m: 	conv_out_size: 16
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	fc_size: 64
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08978996129196015, max=1.…

0,1
batch loss,██▇▅▃▄▂▁▃▂▂▂▂▁▂▁▃▁▂▂▄▂▂▁▂▂▂▂▂▁▁▂▁▁▂▁▂▂▁▂
max val acc,▁▁▁▃▆█
val acc,▇▆▇▇▆██▅▇▅▇▇▇▇██▇▄▆▇▅▇▆█▆▇▇▇▆▇▆▆▇▃▆▇█▇▆▁
val loss,▁▁▁▁▁▁▂▂▂▃▄▇▅▇▄▅▅▃▇█▃▆▅▄▄▄▅▄▅▆▆▆▆▄▆▆▆▆▇▆

0,1
batch loss,0.11301
max val acc,65.95745
val acc,53.19149
val loss,3.63193


[34m[1mwandb[0m: Agent Starting Run: 2wfxz4p1 with config:
[34m[1mwandb[0m: 	batch_size: 24
[34m[1mwandb[0m: 	conv_hid_size: 8
[34m[1mwandb[0m: 	conv_in_size: 32
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 32
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch loss,█▇▆▅▃▃▁▂▁▃▁▂▁▂▂▂▁▂▁▂▁▁▁▂▁▂▁▂▁▁▁▁▁▆▂▂▂▁▁▂
max val acc,▁▁▁▁▃▇▇▇▇▇▇▇▇████████
val acc,▁▅▃▄▅▆▆▅▅▇▇▇▇▆▇▅▅▇▇▇▇▄▆▇▆▇▆▆▇▅▇▇█▆▇█▇▇██
val loss,▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▂▃▂█▁▁▁▁▁▂

0,1
batch loss,0.0232
max val acc,70.21277
val acc,70.21277
val loss,1.73849


[34m[1mwandb[0m: Agent Starting Run: mlojapjx with config:
[34m[1mwandb[0m: 	batch_size: 40
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 256
[34m[1mwandb[0m: 	conv_out_size: 16
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 64
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch loss,█▇▇▇▇▆▅▄▅▃▄▃▂▂▃▂▂▂▃▂▂▂▂▂▁▂▂▁▁▁▁▁▂▂▂▂▂▂▂▁
max val acc,▁▂▅▅▆▆▆▆▇▇█████
val acc,▁▆▆▆▇▇▇▇▇▇█▅▇▆▇▇▇▇▇▇▆▆▇▇▆▇▇▇▇▇▆▇▇▇▇▇▇▇▇▆
val loss,▁▁▁▁▁▁▁▁▁▁▂▁▂▁▂▃▃▃▃▃▃▃▄▅▄▄▄▄▅▄▆▆▆▆▅▆▆▆▆█

0,1
batch loss,0.06285
max val acc,65.95745
val acc,55.31915
val loss,2.77754


[34m[1mwandb[0m: Agent Starting Run: ifkth8we with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 64
[34m[1mwandb[0m: 	conv_out_size: 32
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	fc_size: 256
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 1e-05
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.09110237742413504, max=1.…

0,1
batch loss,█▆▆▇▅▅█▆▄▅▅▅▃▃▅▆▇▄▅▃▅▃▃▄▄▃▂▂▃▅▂▄▄▁▃▅▃▄▂▃
max val acc,▁▁▁
val acc,█▁▁▁▁▃▂▃▃▃▃▄▅▅▅▄▅▆▅▅▆▆▆▅▆▅▅▆▆▆▆▅▅▅▇▇▇▆▆▇
val loss,▅▅▅▄▇▅▅▄▇▇▄▄▃▇▇▄▃▃▇▃▃▃▇▇▂▃▃▇▇▂▃▃█▂▂▃██▁▁

0,1
batch loss,0.56657
max val acc,57.44681
val acc,53.19149
val loss,0.92081


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: d709q75u with config:
[34m[1mwandb[0m: 	batch_size: 24
[34m[1mwandb[0m: 	conv_hid_size: 8
[34m[1mwandb[0m: 	conv_in_size: 128
[34m[1mwandb[0m: 	conv_out_size: 32
[34m[1mwandb[0m: 	dropout: 0.15
[34m[1mwandb[0m: 	fc_size: 512
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 1e-05
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch loss,██▇███▇▆▆▆▆▄▅▆▅▅▅▆▆▄▅▅▄▄▃▃▄▄▂▃▄▃▂▂▂▃▃▂▁▂
max val acc,▁▁▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
val acc,▁▅▅▆▅▆▇▇▅▆▆▆▇▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇█▇▇█▇█▇██▇█▇
val loss,▇███▇█▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▁▁▁▁

0,1
batch loss,0.60246
max val acc,68.08511
val acc,59.57447
val loss,0.99933


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: lj5td91v with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 8
[34m[1mwandb[0m: 	conv_in_size: 256
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 32
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.0001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08979565934763295, max=1.…

0,1
batch loss,███▇▆▆▇▆▅▅▄▄▃▄▅▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▃▃▂▂▁▁▁▂▂
max val acc,▁▁▁▁
val acc,█▃▁▄▄▃▄▇▂▄▆▆▆▅▅▆▅▅▅▆▇▆▇▇▇▆▇▇▆▇▇▇▆▆█▆▇▇▇▅
val loss,▁▁▁▁▁▁▂▁▁▂▂▁▂▃▃▁▂▂▃▂▃▄▄▄▂▃▃▅▅▄▃▃▆▅▃▅██▅▆

0,1
batch loss,0.083
max val acc,57.44681
val acc,46.80851
val loss,1.86527


[34m[1mwandb[0m: Agent Starting Run: 9ko0ous3 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 128
[34m[1mwandb[0m: 	conv_out_size: 16
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	fc_size: 256
[34m[1mwandb[0m: 	hidden_size: 8
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.08977856734978745, max=1.…

0,1
batch loss,█▇▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▂▁▂▁▁▁▁▃▁
max val acc,▁▂▄▅▅▅▅▅▅▅▅▇▇▇▇▇███
val acc,▆▄▁▇█▇▇▇▇█▇▆▇████▇▇▆█▇▇██▇▇▇▇█▇▇▇▇▇█▇█▇▇
val loss,▂▁▂▂▅▁▁▃▅▄▂▅▅▆▆▂▄▅▆▂▂▄▆▇▂▅▅█▇▁▆▅█▂▂▅▇█▃▂

0,1
batch loss,0.08021
max val acc,70.21277
val acc,63.82979
val loss,2.13149


[34m[1mwandb[0m: Agent Starting Run: 3zmu3s19 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	conv_hid_size: 16
[34m[1mwandb[0m: 	conv_in_size: 64
[34m[1mwandb[0m: 	conv_out_size: 8
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	fc_size: 128
[34m[1mwandb[0m: 	hidden_size: 16
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	max_epochs: 250
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Traceback (most recent call last):
  File "C:\Users\valen\AppData\Local\Temp\ipykernel_19500\35702409.py", line 42, in train_model
    wandb.log({"batch loss": loss.item()})
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wandb\sdk\wandb_run.py", line 420, in wrapper
    return func(self, *args, **kwargs)
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wandb\sdk\wandb_run.py", line 371, in wrapper_fn
    return func(self, *args, **kwargs)
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wandb\sdk\wandb_run.py", line 361, in wrapper
    return func(self, *args, **kwargs)
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wandb\sdk\wandb_run.py", line 1820, in log
    self._log(data=data, step=step, commit=commit)
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wandb\sdk\wandb_run.py", line 1595, in _log
    self._partial_history_callback(data, step, commit)
  File "c:\Users\valen\anaconda3\envs\nlp-m\lib\site-packages\wan

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x00000233385F9A30>> (for post_run_cell):


ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

Naive CNN performance achieves ~70% validation accuracy. We stop early when the validation accuracy is achieved.

### Next Steps
- Explore different ConvNet architectures
- Figure out why number of samples is so much less than actual number
- Figure out how to deal with the inconsistent classes
- Try ResNet (PyTorch has models)

In [None]:
# import torch
# import torchvision.models as models
# import torchvision.transforms as transforms
# from torch.utils.data import DataLoader
# from torchvision.datasets import ImageFolder

# resnet18 = models.resnet18(pretrained=True)  # For ResNet18
# resnet50 = models.resnet50(pretrained=True)  # For ResNet50

In [None]:
# import torchvision.transforms as transforms

# class GrayscaleToRGBTransform:
#     def __call__(self, tensor):
#         # Check if the tensor has one channel (grayscale)
#         if tensor.shape[0] == 1:
#             # Repeat the tensor across 3 channels
#             tensor = tensor.repeat(3, 1, 1)
#         return tensor

# res_transform = transforms.Compose([
#     GrayscaleToRGBTransform(),
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# def apply_all_transforms(X, transform):
#     transformed_data = []
#     for x in X:
#         x = transform(x)  # Apply the transformation
#         transformed_data.append(x)
#     return torch.stack(transformed_data)
    
# X_train_resnet = apply_all_transforms(X_train, transform=res_transform)
# train_resnet = TensorDataset(X_train_resnet, y_train)
# trainloader_resnet = DataLoader(train_resnet, batch_size=32, shuffle=True)

# X_val_resnet = apply_all_transforms(X_val, transform=res_transform)
# val_resnet = TensorDataset(X_val_resnet, y_val)
# valloader_resnet = DataLoader(val_resnet, batch_size=32, shuffle=True)

In [None]:
# num_epochs = 100  # Set the number of epochs
# num_ftrs = resnet18.fc.in_features
# resnet18.fc = torch.nn.Linear(num_ftrs, args.num_classes) 

# # Define a loss function and optimizer
# criterion = torch.nn.CrossEntropyLoss(class_weights)
# optimizer = torch.optim.SGD(resnet18.parameters(), lr=0.001, momentum=0.9)
# train_model(num_epochs, resnet18, criterion, optimizer, trainloader_resnet, valloader_resnet, stop_acc=70)