In [1]:
from pathpretrain.predict import predict
from pathpretrain.train_model import train_model, generate_transformers
from pathpretrain.datasets import NPYDataset
import pandas as pd
from sklearn.metrics import roc_auc_score
from statistics import mean
from torch.utils.data import Dataset
from pathpretrain.utils import load_image
from PIL import Image
import os 
import torch

class CustomDataset(Dataset):
    """Patch dataset over one slide's image array and its patch table.

    Each item is ``(patch, extra)`` where ``patch`` is ``transform`` applied to
    the PIL image of the patch, and ``extra`` is a ``torch.LongTensor`` holding
    either the patch's ``[x, y]`` coordinates or, when ``predict_only=True``,
    the single integer value of ``target_col`` for that patch.

    Note the naming: ``predict_only=True`` is the mode that returns *labels*.
    It is kept that way for backward compatibility with callers that evaluate
    against ``target_col``.

    Args:
        patch_info: path to a pickled DataFrame with at least ``x``, ``y`` and
            ``patch_size`` columns (read with ``pd.read_pickle`` -- only load
            trusted files, unpickling can execute arbitrary code).
        npy_file: path passed to ``pathpretrain.utils.load_image``; its basename
            (minus trailing .npy/.tiff/.tif/.svs extensions) becomes ``self.ID``.
        transform: callable applied to the PIL image of each patch.
        image_stack: if True, item ``i`` is ``X[i]``; otherwise the patch is
            cropped as ``X[x:x+patch_size, y:y+patch_size]``.
        predict_only: if True, return the ``target_col`` value instead of the
            coordinates.
        target_col: column of ``patch_info`` holding the label; required when
            ``predict_only`` is True.

    Raises:
        ValueError: if ``predict_only`` is True and ``target_col`` is None.
        KeyError: if ``target_col`` is not a column of the patch table.
    """

    # Extensions stripped (repeatedly, from the end only) to form ``self.ID``.
    _ID_SUFFIXES = (".npy", ".tiff", ".tif", ".svs")

    def __init__(self, patch_info, npy_file, transform, image_stack=False, predict_only=False, target_col=None):
        # Fail fast, before the (potentially large) image is loaded.
        if predict_only and target_col is None:
            raise ValueError("target_col must be given when predict_only=True")
        self.X = load_image(npy_file)
        self.patch_info = pd.read_pickle(patch_info)
        self.xy = self.patch_info[['x', 'y']].values
        self.length = self.patch_info.shape[0]
        # All patches share one size; guard the empty table so iloc[0] cannot raise.
        self.patch_size = self.patch_info['patch_size'].iloc[0] if self.length else None
        self.transform = transform
        self.ID = self._strip_image_ext(npy_file)
        self.image_stack = image_stack
        self.predict_only = predict_only
        self.target_col = target_col
        # Positional array of labels, so __getitem__ avoids building a row Series per item.
        self.targets = self.patch_info[target_col].to_numpy() if predict_only else None

    @staticmethod
    def _strip_image_ext(path):
        """Return the basename of ``path`` with trailing known image/array extensions removed."""
        name = os.path.basename(path)
        while True:
            ext = next((e for e in CustomDataset._ID_SUFFIXES if name.endswith(e)), None)
            if ext is None:
                return name
            name = name[:-len(ext)]

    def to_pil(self, x):
        """Convert an array patch to a PIL image.

        A method rather than a lambda attribute so the dataset stays picklable
        (e.g. for DataLoader workers started with the 'spawn' method).
        """
        return Image.fromarray(x)

    def __getitem__(self, i):
        x, y = self.xy[i]
        if self.image_stack:
            patch = self.X[i]
        else:
            patch = self.X[x:(x + self.patch_size), y:(y + self.patch_size)]
        patch = self.transform(self.to_pil(patch))
        if self.predict_only:
            # Label mode: one-element LongTensor with the integer target value.
            return patch, torch.LongTensor([int(self.targets[i])])
        return patch, torch.LongTensor([x, y])

    def __len__(self):
        return self.length

In [None]:
# Evaluate a saved checkpoint on the patches of one slide and report ROC AUC.
# Configuration -- TODO: fill in the split-table path before running.
SPLITS_PKL = '.pkl'  # pkl file with train, val, test designations and associated npy stacks and pkl files
CHECKPOINT_PATH = '/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Zavras/checkpoints_cnn_tumor_liver/liver_cyclegan_baseline_DH_20220916/3.epoch.checkpoint.pth'
PREDICTIONS_PATH = '/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Zavras/experiments/predictions_liver_cyclegan_baseline_DH.pkl'

df = pd.read_pickle(SPLITS_PKL)
# NOTE(review): 224/256 are presumably crop and resize sizes -- confirm against generate_transformers.
transform = generate_transformers(224, 256)['test']
# TODO: choose the row of `df` to evaluate (was left blank, a SyntaxError); defaults to the first row.
i = df.index[0]
custom_dataset = CustomDataset(df.loc[i, "pkl"], df.loc[i, "npy"], transform, image_stack=True, predict_only=True, target_col="tumor")
# NOTE(review): gpu_id=-1 presumably means CPU; predictions_save_path is passed but
# save_predictions=False, so presumably nothing is written -- confirm intended.
Y_pred = train_model(inputs_dir='',
                        architecture='resnet50',
                        batch_size=128,
                        predict=True,
                        num_classes=2,
                        model_save_loc=CHECKPOINT_PATH,
                        predictions_save_path=PREDICTIONS_PATH,
                        predict_set='custom',
                        verbose=False,
                        class_balance=False,
                        gpu_id=-1,
                        tensor_dataset=False,
                        pickle_dataset=True,
                        semantic_segmentation=False,
                        custom_dataset=custom_dataset,
                        save_predictions=False)
# NOTE(review): assumes Y_pred['pred'] is (N, 2) class scores with column 1 the positive
# ("tumor") class, and Y_pred['true'] holds the matching labels -- confirm.
roc_auc = roc_auc_score(Y_pred['true'], Y_pred['pred'][:, 1])
roc_auc