# Training and Testing

In [None]:
import pandas as pd
import shutil
import random
from DicomRTTool.ReaderWriter import DicomReaderWriter, ROIAssociationClass
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset, TensorDataset
import pickle

## Load dataset

In [None]:
class CTMaskDataset(Dataset):
    """Dataset yielding ``(image, mask, label)`` triples read via DicomReaderWriter.

    On construction the reader walks ``ct_path`` and is configured to look
    for a single ROI named ``'lung'`` (with ``'lungs'`` and ``'whole lung'``
    registered as associations).  Only the reader indexes reported by
    ``which_indexes_have_all_rois()`` are exposed as dataset items.

    Attributes:
        ct_path: The path argument handed to ``walk_through_folders``.
        label: Value returned unchanged as the third element of every item.
        Dicom_reader: The configured ``DicomReaderWriter`` instance.
        indexes: Reader indexes that contain every requested ROI.
    """

    def __init__(self, ct_path, label):
        """Scan ``ct_path`` and record which reader indexes have the lung ROI.

        Args:
            ct_path: Location passed to ``DicomReaderWriter.walk_through_folders``.
            label: Label stored on the dataset and returned with every item.
        """
        self.ct_path = ct_path
        self.label = label

        reader = DicomReaderWriter(description='Examples', arg_max=True)
        reader.walk_through_folders(ct_path)

        # ROIs to extract, plus the alternative spellings mapped onto them.
        roi_names = ['lung']
        roi_aliases = [ROIAssociationClass('lung', ['lungs', 'whole lung'])]
        reader.set_contour_names_and_associations(
            contour_names=roi_names,
            associations=roi_aliases,
        )

        self.Dicom_reader = reader
        self.indexes = reader.which_indexes_have_all_rois()

    def __len__(self):
        """Return the number of reader indexes that have all requested ROIs."""
        return len(self.indexes)

    def __getitem__(self, idx):
        """Load and return ``(image, mask, label)`` for dataset position ``idx``.

        Args:
            idx: Position into ``self.indexes`` (not a raw reader index).

        Returns:
            Tuple of the reader's ``ArrayDicom``, its ``mask``, and ``self.label``.
        """
        reader = self.Dicom_reader

        # Point the reader at the requested entry, then load its data; the
        # arrays are only read from the reader after this load has run.
        reader.set_index(self.indexes[idx])
        reader.get_images_and_mask()

        return reader.ArrayDicom, reader.mask, self.label

In [None]:
# Deserialize the contents of train_dataset.pkl.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load pickles that come from a trusted source.
with open('train_dataset.pkl', 'rb') as f:
    data_tuple = pickle.load(f)

# The pickle is unpacked as a two-element (data, labels) pair.
data, labels = data_tuple

# Build the dataset from the unpacked pair.
dataset = CTMaskDataset(ct_path=data, label=labels)

# Shuffled mini-batches of four items each.
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

## Define model structure

In [None]:
class CNNClassifier(nn.Module):
    """Small CNN classifier: two conv + ReLU + max-pool stages, then two linear layers.

    The first linear layer is sized for a square ``input_size`` x ``input_size``
    input after two 2x2 max-pools.  With the default arguments the layers are
    identical to the original fixed architecture (1 input channel, 512x512
    input, 2 output classes), so existing state dicts still load.

    Args:
        in_channels: Number of channels in the input tensor.
        num_classes: Number of output logits.
        input_size: Height and width (assumed equal) of the input; at least 4.

    Raises:
        ValueError: If ``input_size`` is smaller than 4, which would leave no
            spatial extent after the two pooling stages.
    """

    def __init__(self, in_channels=1, num_classes=2, input_size=512):
        super().__init__()
        if input_size < 4:
            raise ValueError(
                f"input_size must be at least 4, got {input_size}"
            )

        # Each 2x2, stride-2 max-pool floors the spatial size in half, and
        # forward() applies it twice.
        pooled_size = (input_size // 2) // 2

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, in_channels, input_size, input_size)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not match the
                ``input_size`` the model was built with (raised by ``fc1``).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten everything except the batch dimension.  Unlike
        # ``view(-1, features)``, a wrongly sized input now fails at fc1
        # instead of being silently folded into the batch dimension.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Train model

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")

# Initialize the model, loss function, and optimizer
model = CNNClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for a fixed number of epochs, reporting the mean batch loss per epoch.
n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    # NOTE(review): each batch also carries the CT image, but only the mask is
    # fed to the model (as in the original loop) - confirm this is intended.
    for _images, masks, labels in data_loader:
        # Move the mask to the model's device BEFORE building the input;
        # previously the CPU tensor was used, which fails when training on
        # CUDA.  unsqueeze(1) inserts the channel axis the conv layers expect.
        inputs = masks.to(device).unsqueeze(1).float()
        # CrossEntropyLoss expects class indices as int64.
        labels = labels.to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {running_loss / len(data_loader)}")

# Save only the learned parameters (state dict), not the whole module object.
model_save_path = "./model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")