In [1]:
import pandas as pd
import os
import numpy as np
import torch
import timm
import matplotlib.pyplot as plt
from dataset_EfficientNetV2 import Dataset_EfficientNetV2


In [2]:

# ---- Paths (fill these in before running) ----
folder_path = '' #path to top folder containing patient directories. 
label_meta_data_path = '' #csv file with patient IDs and corresponding diagnostic labels.
output_folder_path = '' #folder where final trained model will be saved.

# ---- Image type ----
is_ct_image = False  #set to False if MCI, set to True if CT
is_shape = False   #set to True if working with binary shape images. Should only be True if is_ct_image is also True
# Enforce the constraint above before any data is read, so a bad combination fails immediately.
assert is_ct_image or not is_shape, 'is_shape=True is only valid when is_ct_image=True.'

subject_meta = pd.read_csv(label_meta_data_path)

# ---- Reproducibility ----
# Seeds the global torch and numpy RNGs (the DataLoader's shuffle order is drawn from
# torch's global RNG). Note: some CUDA/cuDNN kernels are nondeterministic, so GPU runs
# may still differ slightly between executions.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# ---- Training ----
batch_size_val = 32
step_size = 2  # StepLR: the learning rate is multiplied by gamma once every step_size scheduler steps
num_epochs = 6 # or however many you need

#The following values are default for the Adam optimizer / scheduler in pytorch
learning_rate = 1e-3
beta1 = 0.9
beta2 = 0.999
gamma = 0.1





In [None]:
# Load training dataset: built from the patient folders under folder_path and the
# label table in subject_meta, using the image-type flags from the config cell.
dataset = Dataset_EfficientNetV2(
    folder_path,
    subject_meta,
    is_ct_image=is_ct_image,
    is_shape=is_shape,
)
# Catch an empty dataset here (e.g. a wrong folder_path) with a clear message,
# rather than deep inside the training loop.
assert len(dataset) > 0, 'Dataset_EfficientNetV2 found no samples; check folder_path and label_meta_data_path.'

In [None]:
# Visualize an example montage. MCI montages are shown with cmap='jet' on a fixed
# [-1.5, 1.5] range; CT (and binary shape) images are shown in grayscale with
# matplotlib's automatic intensity scaling.
example_index = min(150, len(dataset.images) - 1)  # 150 by default; clamped so smaller datasets still work
example_image = dataset.images[example_index]

fig, ax = plt.subplots()
if is_ct_image:
    ax.imshow(example_image, cmap='gray')
else:
    ax.imshow(example_image, cmap='jet', vmin=-1.5, vmax=1.5)
ax.set_title(f'Example montage (index {example_index})')
plt.show()

In [None]:
# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [11]:


# Mini-batch loader over the training set; shuffle=True draws a new sample order every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size_val,
)


In [12]:
# Initialize the model: a pretrained EfficientNetV2-S from timm whose stem is built
# for single-channel input (in_chans=1), moved onto the selected device.
# NOTE(review): num_classes is not passed, so the classifier keeps timm's default
# head size (presumably the 1000 ImageNet classes — verify). CrossEntropyLoss still
# trains with integer class labels, but the head is much larger than the number of
# diagnostic classes. If you pass num_classes=<n>, any code that reloads the saved
# state_dict must build the model with the same value.
model = timm.create_model(
    'tf_efficientnetv2_s',
    pretrained=True,
    in_chans=1,
)
model = model.to(device)

# Cross-entropy on the model's raw output scores with integer class labels.
criterion = torch.nn.CrossEntropyLoss()

# Adam with no weight decay. The "decay" in this setup is the learning-rate schedule:
# StepLR multiplies the learning rate by gamma once every step_size scheduler steps.
optimizer = torch.optim.Adam(
    model.parameters(),
    betas=(beta1, beta2),
    lr=learning_rate,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    gamma=gamma,
    step_size=step_size,
)


In [None]:


# Train for num_epochs epochs. Each epoch makes one shuffled pass over train_loader,
# reports the mean per-sample training loss, then advances the StepLR schedule once
# (so the learning rate is multiplied by gamma every step_size epochs).
# Batches are unpacked as (images, labels, <unused>, <unused>); only images and
# labels are used here.
for epoch in range(num_epochs):
    model.train()
    epoch_loss_sum = 0.0
    samples_seen = 0
    for images, labels, _, _ in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns the batch *mean*, so weight it by the batch size;
        # otherwise a smaller final batch would count as much as a full one.
        n_in_batch = images.size(0)
        epoch_loss_sum += loss.item() * n_in_batch
        samples_seen += n_in_batch

    if samples_seen == 0:
        raise RuntimeError('train_loader yielded no batches; check that the dataset is not empty.')
    running_loss = epoch_loss_sum / samples_seen
    scheduler.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {running_loss}')

    # Save the model only after the last epoch to save disk space; change this
    # condition to checkpoint more often.
    if epoch == num_epochs - 1:
        if output_folder_path:
            # torch.save does not create missing directories; create the destination
            # now so a finished training run is not lost at the final write.
            os.makedirs(output_folder_path, exist_ok=True)
        model_filename = 'trained_model_EfficientNetV2_epoch{}.pth'.format(epoch + 1) # name of trained model file to be saved
        model_save_path = os.path.join(output_folder_path, model_filename)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Needed, together with the optimizer state, to resume the LR schedule correctly.
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss,
        }, model_save_path)


