In [None]:
import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
import numpy as np
from torch import nn
from PIL import Image, ImageDraw
from torchvision import transforms
import torchvision.models as models
import time
from matplotlib import pyplot as plt
import os
import glob
import random
import pandas as pd
from histolab.slide import Slide
import csv
import re
from sklearn.model_selection import train_test_split, cross_validate
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, confusion_matrix, make_scorer, auc
from sklearn.preprocessing import LabelBinarizer
from itertools import cycle
from collections import Counter

In [None]:
# Build the train/validation datasets from their split files and wrap them in loaders.
# `wsi_dataset` and `data_transforms` are defined outside this view; run their cells first.
train_data = wsi_dataset(file_path = 'train.txt',transform = data_transforms['train'])
valid_data = wsi_dataset(file_path = 'valid.txt',transform = data_transforms['valid'])

BATCH_SIZE = 64
# Only the training set needs shuffling. A fixed validation order makes evaluation
# runs reproducible and directly comparable from epoch to epoch.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Train the stage-1 model for `epochs` epochs, saving a checkpoint after the
# training pass and after the validation pass of every epoch.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
# `resnet_model`, `params_to_update`, `train` and `test` come from cells outside this view.
model = resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update,lr=0.001)

CHECKPOINT_DIR = "stage1"
# torch.save does not create missing parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

best_acc = 0
epochs = 20
# NOTE(review): nothing in this cell appends to acc_s / loss_s, yet the plotting
# cell expects them to hold per-epoch values. Confirm that `test` fills them.
acc_s = []
loss_s = []
for t in range(epochs):
    start_time = time.time()
    print(f"Epoch {t+1}\n--------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    # Raw f-strings: in a normal literal "\m" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The resulting path text is unchanged
    # (Windows-style separator).
    torch.save(model.state_dict(), rf"{CHECKPOINT_DIR}\model_stage1_epochs_{t}_train.pth")
    test(valid_dataloader, model, loss_fn)
    torch.save(model.state_dict(), rf"{CHECKPOINT_DIR}\model_stage1_epochs_{t}_test.pth")
    end_time = time.time()
    time_diff = end_time - start_time
    print("time_diff：", time_diff)
print()

In [None]:
# Plot the per-epoch accuracy and loss values recorded in acc_s / loss_s
# (presumably filled by `test` during training; confirm).
# The x-axis comes from each series' own length rather than `epochs`. With the
# old range(epochs) axis, matplotlib raised a length-mismatch error whenever
# fewer values had been recorded.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2)
ax_acc.plot(range(len(acc_s)), acc_s)
ax_acc.set_xlabel('epoch')
ax_acc.set_ylabel('accuracy')
ax_loss.plot(range(len(loss_s)), loss_s)
ax_loss.set_xlabel('epoch')
ax_loss.set_ylabel('loss')
plt.show()
print("Done!")

In [None]:
# Slide-level prediction and visualization with the epoch-10 stage-1 model.
# NOTE(review): this cell fails on a fresh kernel.
#   - `model_epoch_10` is only created in a later cell.
#   - `df_filtered` is not defined anywhere in the visible notebook.
#   - The `index_df` read here is never used below.
# Confirm which index frame is intended.
index_df = pd.read_csv(r'E:\THT\HE_DTC_txt\index.csv')
make_preds_by_slide_and_visualize(
    model_epoch_10,
    df_filtered,
    device,
    r"E:\THT\HE_DTC_output",
    visualize=True,
    out_all=False,
    transform=data_transforms['valid'],
    batch_size=64,
    num_workers=0,
    n_classes=4,
    scale_factor=80,
)

In [None]:
# Rebuild the stage-2 network (dimension=3, presumably the number of output
# classes) and load one saved epoch checkpoint for evaluation.
resnet_model,params_to_update = model_initialization(dimension=3,freeze=False)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
EVAL_EPOCH = 13
# Raw f-string: in a normal literal "\m" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The resulting path text is unchanged.
model = model_load(resnet_model,device,rf"stage2_best_54\model_stage2_epochs_{EVAL_EPOCH}.pth")
model.eval()

# Deterministic order for evaluation. `protein_valid_data` is defined outside this view.
test_loader = DataLoader(protein_valid_data, batch_size=64, shuffle=False)

# Collect softmax class probabilities and ground-truth labels for every batch of test_loader.
y_scores = []
y_labels = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        # Only the inputs must live on the model's device. Labels are just gathered
        # on the host, so skip the host->device->host round trip.
        outputs = model(X_batch.to(device))
        probs = torch.softmax(outputs, dim=1)
        y_scores.append(probs.cpu().numpy())
        y_labels.append(y_batch.cpu().numpy())

# np.concatenate([]) fails with an unhelpful message; report the real cause instead.
if not y_scores:
    raise ValueError("test_loader yielded no batches; nothing to evaluate")
y_probs = np.concatenate(y_scores)
y_true = np.concatenate(y_labels)
# Per-class one-vs-rest ROC curves and AUCs.
n_classes = 3
# Guard against a model/label-count mismatch before indexing y_probs[:, i].
assert y_probs.shape[1] == n_classes, f"expected {n_classes} probability columns, got {y_probs.shape[1]}"
# Derive the binarizer's classes from n_classes so the one-hot targets and the
# per-class loop below cannot drift apart.
y_onehot_true = LabelBinarizer().fit(range(n_classes)).transform(y_true)

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true == i, y_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

In [None]:
# scikit-learn's multi-class AUC summaries: one-vs-rest macro, one-vs-one macro,
# and one-vs-rest micro. Then print the label and probability array shapes.
for strategy, averaging in (('ovr', 'macro'), ('ovo', 'macro'), ('ovr', 'micro')):
    print(roc_auc_score(y_true, y_probs, multi_class=strategy, average=averaging))
print(y_true.shape, y_probs.shape)

In [None]:
# Micro- and macro-averaged ROC curves. They are stored in the same fpr / tpr /
# roc_auc dicts that already hold the per-class (integer-keyed) entries.
# Use n_classes rather than a separate hard-coded 3, so the class count is set in one place.
y_onehot_true = LabelBinarizer().fit(range(n_classes)).transform(y_true)

# Micro-average: pool every (sample, class) decision into one binary problem.
fpr["micro"], tpr["micro"], _ = roc_curve(y_onehot_true.ravel(), y_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Macro-average: interpolate each class curve onto the union of all FPR points,
# then average the TPRs so every class carries equal weight.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.mean([np.interp(all_fpr, fpr[i], tpr[i]) for i in range(n_classes)], axis=0)

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves: micro/macro averages (dotted), each class, and the chance diagonal.
lw=2
fig, ax = plt.subplots()

ax.plot(fpr["micro"], tpr["micro"], color='deeppink', linestyle=':', linewidth=4,
        label=f'micro-average ROC curve (area = {roc_auc["micro"]:0.4f})')
ax.plot(fpr["macro"], tpr["macro"], color='navy', linestyle=':', linewidth=4,
        label=f'macro-average ROC curve (area = {roc_auc["macro"]:0.4f})')

colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
    ax.plot(fpr[i], tpr[i], color=color, lw=lw,
            label=f'ROC curve of class {i} (area = {roc_auc[i]:0.4f})')

ax.plot([0, 1], [0, 1], 'k--', lw=lw)
ax.set_xlim([-0.02, 1.0])
ax.set_ylim([0.0, 1.02])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Extension of Receiver operating characteristic to multi-class')
ax.legend(loc="lower right")
plt.show()

In [None]:
# Run the epoch-10 stage-1 checkpoint (n_classes=4) over the slides listed in
# txt_123/index.csv (.mrxs files), writing outputs under model1_out\123.
index_df_123 = pd.read_csv(r'E:\THT\HE_DTC_second\second\txt_123\index.csv')
resnet_model_epoch_10,params_no_use = model_initialization(4)
# Raw string: in a normal literal "\m" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The resulting path text is unchanged.
model_epoch_10 = model_load(resnet_model_epoch_10,device,r"stage1\model_stage1_epochs_10_train.pth")
model_epoch_10.eval()
make_preds_by_slide_and_visualize(model_epoch_10, index_df_123, device, r"E:\THT\HE_DTC_second\second\output\model1_out\123",transform=data_transforms['valid'],
                                  visualize=True,out_all=True,all_shuffle=False,batch_size=64, num_workers=0, n_classes=4, scale_factor=80,
                                  wsi_type=".mrxs",sum_calc=True,pick_label=0)

In [None]:
# For every saved stage-2 epoch checkpoint:
#   1. re-run slide-level prediction into <output_dir>/<epoch>;
#   2. collect the (wsi_name, out_label) pairs encoded in the
#      "<wsi>_sum_out_label=<label>.txt" filenames in that directory;
#   3. write them to output_result.csv and print the label counts.
pick_index_df = pd.read_csv(r'E:\THT\HE_DTC_TCGA\output\model1_out\txt\pick_index.csv')
output_dir = r'E:\THT\HE_DTC_TCGA\output\model2_2out'
N_STAGE2_EPOCHS = 21  # checkpoints model_stage2_epochs_0 .. model_stage2_epochs_20
# Compiled once, outside the loop. It is applied with fullmatch so names such as
# "x_sum_out_label=1.txt.bak" are not mistaken for result files.
pattern = re.compile(r"(.+)_sum_out_label=(.+)\.txt")
for t in range(N_STAGE2_EPOCHS):
    resnet_model,params_to_update = model_initialization(dimension=3,freeze=False)
    # Raw f-string: "\m" in a normal literal is an invalid escape; the path text is unchanged.
    model = model_load(resnet_model,device,rf"stage2_best_54\model_stage2_epochs_{t}.pth")
    model.eval()
    epoch_dir = os.path.join(output_dir, str(t))
    make_preds_by_slide_and_visualize(model, pick_index_df, device, epoch_dir, transform=data_transforms['valid'],
                                      visualize=True,out_all=True,all_shuffle=False,batch_size=64, num_workers=0, n_classes=3, scale_factor=80,
                                      wsi_type=".svs",sum_calc=True)
    output_result = os.path.join(epoch_dir, 'output_result.csv')
    data = []
    # Sorted so the CSV row order is deterministic (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(epoch_dir)):
        match = pattern.fullmatch(filename)
        if match:
            wsi_name, out_label = match.groups()
            data.append([wsi_name, out_label])
    df = pd.DataFrame(data, columns=["wsi_name", "out_label"])
    df.to_csv(output_result, index=False)
    print(f"result: {output_result}")
    print(df['out_label'].value_counts())