In [None]:
# Mount Google Drive so that the dataset paths under /content/drive used by the
# later cells are accessible from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import xml.etree.ElementTree as ET
import numpy as np
import skimage.draw
import os
import glob
import cv2
from PIL import Image
from scipy.io import savemat


def binary_mask_from_xml_file(xml_file_path, image_shape=(1000, 1000)):
    """Rasterise the polygon annotations of one XML file into a binary mask.

    Every ``Region`` element found anywhere in the XML tree is treated as a
    closed polygon whose ``Vertex`` descendants carry an ``X`` (column) and a
    ``Y`` (row) attribute. The interior of each polygon is set to 1.

    Args:
        xml_file_path: Path of the annotation XML file to parse.
        image_shape: ``(rows, cols)`` of the mask to create. Vertex
            coordinates at or beyond these bounds are clamped to the last
            valid row/column.

    Returns:
        ``np.uint8`` array of shape ``image_shape`` holding 1 inside any
        annotated region and 0 elsewhere.
    """
    tree = ET.parse(xml_file_path)
    root = tree.getroot()
    n_rows, n_cols = image_shape[0], image_shape[1]

    def vertex_element_to_tuple(vertex_element):
        # XML stores X = column and Y = row; the mask is indexed (row, col).
        col = float(vertex_element.get('X'))
        row = float(vertex_element.get('Y'))
        return round(row), round(col)

    mask = np.zeros(image_shape, dtype=np.uint8)
    for region in root.iter('Region'):
        vertices = [vertex_element_to_tuple(v) for v in region.iter('Vertex')]
        if not vertices:
            # A region without vertices cannot be unpacked into rows/cols.
            continue
        rows, cols = np.array(list(zip(*vertices)))

        # Clamp to the actual mask bounds instead of a hard-coded 1000.
        rows[rows >= n_rows] = n_rows - 1
        cols[cols >= n_cols] = n_cols - 1
        rr, cc = skimage.draw.polygon(rows, cols, mask.shape)

        # To add the nuclear boundary as a separate class, use instead:
        #     mask[rr, cc] = 2
        #     mask[rows, cols] = 1
        mask[rr, cc] = 1

    return mask

# TRAIN SET
# Convert every training annotation (XML) and its tissue image (TIF) into an
# indexed JPEG image / JPEG mask pair and collect all masks into one .mat file.
folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/*'  # Annotations Folder
folder2 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Tissue Images'  # Images Folder

# Train Output Folders
labelled_images = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/images'
os.makedirs(labelled_images, exist_ok=True)
labels = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/masks'
os.makedirs(labels, exist_ok=True)
# Folder receiving the .mat file; created up front because savemat does not
# create missing directories.
mat_folder = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data'
os.makedirs(mat_folder, exist_ok=True)

labels_list = []
# glob returns files in arbitrary order; sorting makes the image_<i>/mask_<i>
# numbering (and the order of masks in the .mat file) reproducible.
IMGS = sorted(glob.glob(folder1))
indexing = 0
for mask_path in IMGS:
    print(mask_path)
    # The tissue image shares the annotation's base name with a .tif extension.
    img_path = os.path.join(folder2, os.path.basename(mask_path).split('.xml')[0] + '.tif')

    # Context manager closes the TIF file handle once the pixels are copied.
    with Image.open(img_path) as pil_image:
        image = np.array(pil_image)
    save_path1 = os.path.join(labelled_images, 'image_' + str(indexing) + '.jpg')
    cv2.imwrite(save_path1, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    image_label = binary_mask_from_xml_file(mask_path)
    # NOTE(review): the mask holds 0/1 values and JPEG is lossy, so this file
    # is near-black and not exact; the .mat file below keeps the exact masks.
    save_path2 = os.path.join(labels, 'mask_' + str(indexing) + '.jpg')
    cv2.imwrite(save_path2, image_label)
    labels_list.append(image_label)

    indexing += 1


# Stack all masks (same order as the image_<i>.jpg files) and store them.
masks = np.array(labels_list)
path = os.path.join(mat_folder, 'train_masks.mat')
mdic = {"data": masks, "label": "labels"}
savemat(path, mdic)


# TEST SET
# Convert every test annotation (XML) and its tissue image (TIF) into an
# indexed JPEG image / JPEG mask pair and collect all masks into one .mat file.
folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_test/Annotations/*'  # Annotations Folder
folder2 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_test/Tissue Images'  # Images Folder

# Test Output Folders
labelled_images = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/images'
os.makedirs(labelled_images, exist_ok=True)
labels = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/masks'
os.makedirs(labels, exist_ok=True)
# Folder receiving the .mat file; created up front because savemat does not
# create missing directories.
mat_folder = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data'
os.makedirs(mat_folder, exist_ok=True)

labels_list = []
# glob returns files in arbitrary order; sorting makes the image_<i>/mask_<i>
# numbering (and the order of masks in the .mat file) reproducible.
IMGS = sorted(glob.glob(folder1))
indexing = 0
for mask_path in IMGS:
    print(mask_path)
    # The tissue image shares the annotation's base name with a .tif extension.
    img_path = os.path.join(folder2, os.path.basename(mask_path).split('.xml')[0] + '.tif')

    # Context manager closes the TIF file handle once the pixels are copied.
    with Image.open(img_path) as pil_image:
        image = np.array(pil_image)
    save_path1 = os.path.join(labelled_images, 'image_' + str(indexing) + '.jpg')
    cv2.imwrite(save_path1, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    image_label = binary_mask_from_xml_file(mask_path)
    # NOTE(review): the mask holds 0/1 values and JPEG is lossy, so this file
    # is near-black and not exact; the .mat file below keeps the exact masks.
    save_path2 = os.path.join(labels, 'mask_' + str(indexing) + '.jpg')
    cv2.imwrite(save_path2, image_label)
    labels_list.append(image_label)

    indexing += 1

# Stack all masks (same order as the image_<i>.jpg files) and store them.
masks = np.array(labels_list)
path = os.path.join(mat_folder, 'test_masks.mat')
mdic = {"data": masks, "label": "labels"}
savemat(path, mdic)



/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-A7-A13E-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-AY-A8YK-01A-01-TS1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-HE-7128-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-G9-6356-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-NH-A8F7-01A-01-TS1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-HE-7129-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-HE-7130-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/MoNuSeg_train/Annotations/TCGA-38-6178-01Z-00-DX1.xml
/content/drive/MyDrive/converted_notebooks/GenSelfDiff-H

In [None]:
# Install torchvision; the patch-extraction cells below import
# torchvision.transforms and use its ToTensor transform.
!pip install torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->torchvision)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86

In [None]:
# This script is used to generate the unlabeled image patches for self-supervision from the official train set of MoNuSeg

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys

folder = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/train_images/*'

# directories for images
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/unlabelled_img_patches'
os.makedirs(OUT_FOLDER, exist_ok=True)

size = 256  # patch size
stride = 64  # patch stride

indexing = 0  # running index over all patches written so far
IMGS = glob.glob(folder)
# Process images in numeric order of the index that follows the last "_" in
# the path (e.g. image_12.jpg -> 12).
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split('.jpg')[0])):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        raise FileNotFoundError('Could not read image: {}'.format(img_path))
    # ToTensor yields a float (C, H, W) tensor scaled to [0, 1].
    ximg = transforms.ToTensor()(img)

    # Sliding windows along H then W: (C, nH, nW, size, size).
    patches = ximg.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(img.shape[2], -1, size, size)
    # Reorder to (N, size, size, C) so each patch is an image array.
    patches = torch.permute(patches, (1, 2, 3, 0))
    patches = patches.numpy()

    for i in range(patches.shape[0]):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        # Channels stay in cv2.imread's BGR order, which is what imwrite
        # expects; multiply by 255 to undo the ToTensor scaling.
        cv2.imwrite(save_path, np.uint8(255*patches[i, :, :, :]))
        indexing += 1
    print('{} patches are created'.format(indexing))


45 patches are created
90 patches are created
135 patches are created
180 patches are created
225 patches are created
270 patches are created
294 patches are created
339 patches are created
384 patches are created
429 patches are created
474 patches are created
489 patches are created
534 patches are created
579 patches are created
624 patches are created
669 patches are created
714 patches are created
759 patches are created
804 patches are created
849 patches are created
894 patches are created
939 patches are created
984 patches are created
1029 patches are created
1074 patches are created
1119 patches are created
1164 patches are created
1209 patches are created
1254 patches are created
1299 patches are created
1344 patches are created
1389 patches are created
1434 patches are created
1479 patches are created
1524 patches are created
1548 patches are created
1593 patches are created
1638 patches are created
1683 patches are created
1728 patches are created
1743 patches are created


In [None]:
# This script is used to split the test data of the official MoNuSeg into train-test split for our downstream
# histopathological image segmentation

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


def split_data(images, labels, ratio):
    """Split images and labels into two aligned subsets using a fixed shuffle.

    Only the first ``min(images.shape[0], labels.shape[0])`` positions are
    considered. They are shuffled with a ``RandomState`` seeded with 2023,
    and the first ``int(count * ratio)`` shuffled positions form the first
    subset; the remainder form the second.

    Args:
        images: Array indexed along its first axis.
        labels: Array indexed along its first axis.
        ratio: Fraction of the usable samples assigned to the first subset.

    Returns:
        Tuple ``(images_1, images_2, labels_1, labels_2)``.
    """
    # Restrict to the common length so the same indices are valid for both.
    usable = min(images.shape[0], labels.shape[0])
    order = np.random.RandomState(2023).permutation(usable)

    cut = int(usable * ratio)
    first_part, second_part = order[:cut], order[cut:]

    # Same index sets applied to both arrays keeps image/label pairs aligned.
    return tuple(
        array[part]
        for array in (images, labels)
        for part in (first_part, second_part)
    )


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_images/*'

# directories for images and labels
TRAIN_OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_images'
os.makedirs(TRAIN_OUT_FOLDER, exist_ok=True)
# Test images go to full_size_images (mirroring the train folder); this is the
# folder the test patch-extraction cells read their images from.
TEST_OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_images'
os.makedirs(TEST_OUT_FOLDER, exist_ok=True)

print('===========================================================================')
print('                          IMAGES and LABELS                                    ')
print('===========================================================================')

# Load the images in numeric order of the index after the last "_" in the path.
IMAGES = []
IMGS = glob.glob(folder1)
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split(".jpg")[0])):
    print(img_path.split("_")[-1].split(".jpg")[0])
    img = cv2.imread(img_path)
    # NOTE(review): images are resized to 256x256 here while the labels loaded
    # below keep their stored resolution - confirm this is intended.
    img = cv2.resize(img, (256, 256))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    IMAGES.append(img)
IMAGES = np.array(IMAGES)
print(f'number of images: {IMAGES.shape[0]}')

# Load the masks in the order they were stored in the .mat file.
label_file = loadmat('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_masks.mat')
all_labels = label_file['data']
LABELS = []
for i in range(all_labels.shape[0]):
    mask = all_labels[i]
    print(np.unique(mask))
    LABELS.append(mask)
LABELS = np.array(LABELS)
print(f'number of labels: {LABELS.shape[0]}')

if IMAGES.shape[0] != LABELS.shape[0]:
    # split_data only uses the first min(#images, #labels) entries, so a
    # mismatch silently drops data and may pair images with the wrong labels.
    print(f'WARNING: {IMAGES.shape[0]} images but {LABELS.shape[0]} labels; '
          'only the common prefix will be split')

train_imgs, test_imgs, train_lbls, test_lbls = split_data(IMAGES, LABELS, 0.8)


for i in range(train_imgs.shape[0]):
    TRAIN_OUT_IMAGE_PATH = os.path.join(TRAIN_OUT_FOLDER, 'image_' + str(i) + '.jpg')
    cv2.imwrite(TRAIN_OUT_IMAGE_PATH, cv2.cvtColor(train_imgs[i], cv2.COLOR_RGB2BGR))

print(f'Number of training images: {train_imgs.shape[0]}')

path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train', 'full_size_labels.mat')
mdic = {"data": train_lbls, "label": "train_labels"}
savemat(path, mdic)
print(f'Number of training labels: {train_lbls.shape[0]}')

for i in range(test_imgs.shape[0]):
    TEST_OUT_IMAGE_PATH = os.path.join(TEST_OUT_FOLDER, 'image_' + str(i) + '.jpg')
    cv2.imwrite(TEST_OUT_IMAGE_PATH, cv2.cvtColor(test_imgs[i], cv2.COLOR_RGB2BGR))

print(f'Number of testing images: {test_imgs.shape[0]}')

path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test', 'full_size_labels.mat')
mdic = {"data": test_lbls, "label": "test_labels"}
savemat(path, mdic)
print(f'Number of testing labels: {test_lbls.shape[0]}')



                          IMAGES and LABELS                                    
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
number of images: 80
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
number of images: 14
Number of training images: 11
Number of training labels: 11
Number of testing images: 3
Number of testing labels: 3


In [None]:
# This script is used to obtain the test image-label patches from the full-scale image-labels

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_images/*'

# directories for images and labels
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/images'
os.makedirs(OUT_FOLDER, exist_ok=True)

size = 256  # patch size
stride = 64  # patch stride

print('===========================================================================')
print('                          IMAGE PATCHES                                    ')
print('===========================================================================')
indexing = 0
IMGS = glob.glob(folder1)
# Process images in numeric order of the index after the last "_" in the path.
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split(".jpg")[0])):
    print(img_path.split("_")[-1].split(".jpg")[0])
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # ToTensor yields a float (C, H, W) tensor scaled to [0, 1].
    ximg = transforms.ToTensor()(img)

    # Sliding windows along H then W: (C, nH, nW, size, size).
    patches = ximg.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(img.shape[2], -1, size, size)
    # Reorder to (N, size, size, C) so each patch is an image array.
    patches = torch.permute(patches, (1, 2, 3, 0))
    patches = patches.numpy()

    for patch_idx in range(patches.shape[0]):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        cv2.imwrite(save_path, cv2.cvtColor(np.uint8(255 * patches[patch_idx, :, :, :]), cv2.COLOR_RGB2BGR))
        indexing += 1
print('Number of image patches in total: {}'.format(indexing))
num_image_patches = indexing

print('===========================================================================')
print('                          LABEL PATCHES                                    ')
print('===========================================================================')
label_file = loadmat('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_labels.mat')
all_labels = label_file['data']
LABEL_PATCHES = []
indexing = 0
for label_idx in range(all_labels.shape[0]):
    xlabel = transforms.ToTensor()(all_labels[label_idx, :, :])

    patches = xlabel.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(-1, size, size)
    # Multiplying by 255 undoes ToTensor's [0, 1] scaling (assumes the stored
    # labels are uint8 - TODO confirm).
    patches = np.uint8(255 * patches.numpy())

    # Distinct index name: the original inner loop reused the outer loop's `i`.
    for patch_idx in range(patches.shape[0]):
        LABEL_PATCHES.append(patches[patch_idx, :, :])
        indexing += 1
    print('{} label patches are created'.format(indexing))
LABEL_PATCHES = np.array(LABEL_PATCHES)
print('Number of label patches in total: {}'.format(LABEL_PATCHES.shape[0]))

if LABEL_PATCHES.shape[0] != num_image_patches:
    # Image patch i and label patch i are only a valid pair when both
    # sequences have the same length.
    print('WARNING: {} image patches but {} label patches; images and labels are not aligned'.format(
        num_image_patches, LABEL_PATCHES.shape[0]))

print('===========================================================================')
print('                            Now Saving the patches                         ')
print('===========================================================================')

path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test', 'label_patches.mat')
mdic = {"data": LABEL_PATCHES, "label": "test_labels"}
savemat(path, mdic)
print(f'Number of testing labels: {LABEL_PATCHES.shape[0]}')




                          IMAGE PATCHES                                    
0
1
2
Number of image patches in total: 135
                          LABEL PATCHES                                    
144 label patches are created
288 label patches are created
432 label patches are created
Number of label patches in total: 432
                            Now Saving the patches                         
Number of training labels: 432


In [None]:
# This script is used to obtain the train image patches from the full scale train images

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train//full_size_images/*'

# directories for images and labels
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/images'
os.makedirs(OUT_FOLDER, exist_ok=True)

size = 256  # patch size
stride = 64  # patch stride

print('===========================================================================')
print('                          IMAGE PATCHES                                    ')
print('===========================================================================')
indexing = 0
IMGS = glob.glob(folder1)
# Process images in numeric order of the index after the last "_" in the path.
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split(".jpg")[0])):
    print(img_path.split("_")[-1].split(".jpg")[0])
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # ToTensor yields a float (C, H, W) tensor scaled to [0, 1].
    ximg = transforms.ToTensor()(img)

    # Sliding windows along H then W: (C, nH, nW, size, size).
    patches = ximg.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(img.shape[2], -1, size, size)
    # Reorder to (N, size, size, C) so each patch is an image array.
    patches = torch.permute(patches, (1, 2, 3, 0))
    patches = patches.numpy()

    for patch_idx in range(patches.shape[0]):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        cv2.imwrite(save_path, cv2.cvtColor(np.uint8(255 * patches[patch_idx, :, :, :]), cv2.COLOR_RGB2BGR))
        indexing += 1
print('Number of image patches in total: {}'.format(indexing))
num_image_patches = indexing

print('===========================================================================')
print('                          LABEL PATCHES                                    ')
print('===========================================================================')
label_file = loadmat('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_labels.mat')
all_labels = label_file['data']
LABEL_PATCHES = []
indexing = 0
for label_idx in range(all_labels.shape[0]):
    xlabel = transforms.ToTensor()(all_labels[label_idx, :, :])

    patches = xlabel.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(-1, size, size)
    # Multiplying by 255 undoes ToTensor's [0, 1] scaling (assumes the stored
    # labels are uint8 - TODO confirm).
    patches = np.uint8(255 * patches.numpy())

    # Distinct index name: the original inner loop reused the outer loop's `i`.
    for patch_idx in range(patches.shape[0]):
        LABEL_PATCHES.append(patches[patch_idx, :, :])
        indexing += 1
    print('{} label patches are created'.format(indexing))
LABEL_PATCHES = np.array(LABEL_PATCHES)
print('Number of label patches in total: {}'.format(LABEL_PATCHES.shape[0]))

if LABEL_PATCHES.shape[0] != num_image_patches:
    # Image patch i and label patch i are only a valid pair when both
    # sequences have the same length.
    print('WARNING: {} image patches but {} label patches; images and labels are not aligned'.format(
        num_image_patches, LABEL_PATCHES.shape[0]))

print('===========================================================================')
print('                            Now Saving the patches                         ')
print('===========================================================================')


path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train', 'label_patches.mat')
mdic = {"data": LABEL_PATCHES, "label": "train_labels"}
savemat(path, mdic)
print(f'Number of training labels: {LABEL_PATCHES.shape[0]}')




                          IMAGE PATCHES                                    
0
1
2
3
4
5
6
7
8
9
10
Number of image patches in total: 11
                          LABEL PATCHES                                    
144 label patches are created
288 label patches are created
432 label patches are created
576 label patches are created
720 label patches are created
864 label patches are created
1008 label patches are created
1152 label patches are created
1296 label patches are created
1440 label patches are created
1584 label patches are created
Number of label patches in total: 1584
                            Now Saving the patches                         
Number of training labels: 1584


In [None]:
import os
from PIL import Image
import glob
import numpy as np
import cv2
from scipy.io import savemat

folder = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GLaS_data/*' # path for the original GlaS dataset


def _export_glas_subset(paths, subset, images_dir, labels_dir, offset):
    """Write the images and binarised annotations of one subset as JPEG files.

    A path is selected when the first five characters of its file name equal
    ``subset``. A selected file whose last ``_``-separated token is
    ``anno.bmp`` is treated as an annotation; any other selected file is
    treated as an image. The number found after the first ``_`` in the file
    name, plus ``offset``, becomes the output index (it is also printed).

    Args:
        paths: Iterable of file paths to scan.
        subset: Five-character file-name prefix to select.
        images_dir: Folder receiving ``image_<index>.jpg`` files.
        labels_dir: Folder receiving ``label_<index>.jpg`` files.
        offset: Value added to the number parsed from the file name.

    Returns:
        Number of (non-annotation) images written.
    """
    n_images = 0
    for img_path in paths:
        parts = os.path.basename(img_path).split('_')
        if parts[0][:5] != subset:
            continue
        # "<subset>_<n>.bmp" and "<subset>_<n>_anno.bmp" both yield <n> here.
        out_index = int(parts[1].split('.bmp')[0]) + offset
        print(out_index)

        pixels = np.array(Image.open(img_path))
        if parts[-1].split('.bmp')[0] != 'anno':
            save_path = os.path.join(images_dir, 'image_' + str(out_index) + '.jpg')
            cv2.imwrite(save_path, cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR))
            n_images += 1
        else:
            # Collapse every non-zero annotation value to a single foreground
            # class, stored as 255 in the written file.
            pixels[pixels != 0] = 1
            save_path = os.path.join(labels_dir, 'label_' + str(out_index) + '.jpg')
            cv2.imwrite(save_path, cv2.cvtColor(255 * pixels, cv2.COLOR_RGB2BGR))
    return n_images


train_images = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/train_images'
os.makedirs(train_images, exist_ok=True)
train_labels = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/train_masks'
os.makedirs(train_labels, exist_ok=True)

IMGS = glob.glob(folder)
# File numbers start at 1 (offset -1 makes the output indices zero-based).
_export_glas_subset(IMGS, 'train', train_images, train_labels, -1)
print('='*30)

test_images = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_images'
os.makedirs(test_images, exist_ok=True)
# Absolute Drive path; the original './content/...' was relative to the
# working directory, so the test masks were not written next to the images.
test_labels = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_masks'
os.makedirs(test_labels, exist_ok=True)

IMGS = glob.glob(folder)
# testA keeps its own zero-based numbering; count its images so that testB
# can be numbered directly after it.
indexing = _export_glas_subset(IMGS, 'testA', test_images, test_labels, -1)


ref_index = indexing-1
# testB file number n is written as index n + ref_index, i.e. continuing
# after the last testA index.
_export_glas_subset(IMGS, 'testB', test_images, test_labels, ref_index)

print('='*30)


11
0
15
13
10
10
14
12
11
9
12
15
13
9
14
16
19
18
27
26
23
24
27
21
16
24
17
20
25
18
0
20
25
22
28
17
19
26
23
1
21
22
30
33
31
39
35
34
29
37
32
31
38
40
34
39
36
2
30
38
29
33
2
40
32
35
3
37
1
36
28
46
43
41
49
48
45
52
3
47
48
50
42
44
51
41
45
46
53
52
50
43
4
49
47
44
42
51
66
4
61
64
55
65
63
53
57
54
57
62
65
60
54
60
56
58
61
59
55
58
56
5
62
63
59
64
5
69
73
68
73
75
66
71
72
74
6
77
78
76
69
67
74
77
67
78
70
68
6
72
71
70
76
75
80
84
8
79
7
82
84
81
8
81
83
79
83
80
7
82
11
11
10
9
0
10
9
24
1
23
16
16
12
18
12
21
20
18
17
13
25
13
20
22
23
22
14
14
17
24
19
21
15
19
15
0
29
29
27
36
28
32
38
2
1
33
33
31
34
27
25
2
37
38
26
35
34
36
26
31
35
30
28
30
37
32
49
43
50
41
46
40
47
42
39
49
4
45
44
48
3
40
41
47
42
44
3
48
45
43
39
50
46
53
8
6
52
53
57
4
7
51
52
54
58
56
5
55
59
54
56
57
51
5
7
58
8
59
55
6
78
69
60
74
61
73
77
74
70
77
71
60
75
79
62
73
75
71
72
78
76
61
79
69
70
76
72
68
63
66
64
65
64
66
65
67
68
63
67
62


In [None]:
# This script is used to generate the unlabeled image patches for self-supervision from the official train set of GlaS

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys

folder = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/train_images/*'

# directories for images
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/unlabelled_img_patches'
os.makedirs(OUT_FOLDER, exist_ok=True)

size = 256  # patch size
stride = 64  # patch stride

indexing = 0  # running index over all patches written so far
IMGS = glob.glob(folder)
# Process images in numeric order of the index that follows the last "_" in
# the path (e.g. image_12.jpg -> 12).
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split('.jpg')[0])):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        raise FileNotFoundError('Could not read image: {}'.format(img_path))
    # ToTensor yields a float (C, H, W) tensor scaled to [0, 1].
    ximg = transforms.ToTensor()(img)

    # Sliding windows along H then W: (C, nH, nW, size, size).
    patches = ximg.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(img.shape[2], -1, size, size)
    # Reorder to (N, size, size, C) so each patch is an image array.
    patches = torch.permute(patches, (1, 2, 3, 0))
    patches = patches.numpy()

    for i in range(patches.shape[0]):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        # Channels stay in cv2.imread's BGR order, which is what imwrite
        # expects; multiply by 255 to undo the ToTensor scaling.
        cv2.imwrite(save_path, np.uint8(255*patches[i, :, :, :]))
        indexing += 1
    print('{} patches are created'.format(indexing))


45 patches are created
90 patches are created
135 patches are created
180 patches are created
225 patches are created
270 patches are created
294 patches are created
339 patches are created
384 patches are created
429 patches are created
474 patches are created
489 patches are created
534 patches are created
579 patches are created
624 patches are created
669 patches are created
714 patches are created
759 patches are created
804 patches are created
849 patches are created
894 patches are created
939 patches are created
984 patches are created
1029 patches are created
1074 patches are created
1119 patches are created
1164 patches are created
1209 patches are created
1254 patches are created
1299 patches are created
1344 patches are created
1389 patches are created
1434 patches are created
1479 patches are created
1524 patches are created
1548 patches are created
1593 patches are created
1638 patches are created
1683 patches are created
1728 patches are created
1743 patches are created


In [None]:
# This script is used to obtain the test image-label patches from the full-scale image-labels

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_images/*'
folder2 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_masks/*'

# directories for images and labels
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/images'
os.makedirs(OUT_FOLDER, exist_ok=True)

size = 256  # patch size
stride = 64  # patch stride

print('===========================================================================')
print('                          IMAGE PATCHES                                    ')
print('===========================================================================')
indexing = 0
IMGS = glob.glob(folder1)
# Process images in numeric order of the index after the last "_" in the path.
for img_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split(".jpg")[0])):
    print(img_path.split("_")[-1].split(".jpg")[0])
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # ToTensor yields a float (C, H, W) tensor scaled to [0, 1].
    ximg = transforms.ToTensor()(img)

    # Sliding windows along H then W: (C, nH, nW, size, size).
    patches = ximg.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(img.shape[2], -1, size, size)
    # Reorder to (N, size, size, C) so each patch is an image array.
    patches = torch.permute(patches, (1, 2, 3, 0))
    patches = patches.numpy()

    for i in range(patches.shape[0]):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        cv2.imwrite(save_path, cv2.cvtColor(np.uint8(255 * patches[i, :, :, :]), cv2.COLOR_RGB2BGR))
        indexing += 1
print('Number of image patches in total: {}'.format(indexing))
num_image_patches = indexing

print('===========================================================================')
print('                          LABEL PATCHES                                    ')
print('===========================================================================')

LABEL_PATCHES = []
indexing = 0
IMGS = glob.glob(folder2)
for label_path in sorted(IMGS, key=lambda x: int(x.split("_")[-1].split(".jpg")[0])):
    print(label_path.split("_")[-1].split(".jpg")[0])
    mask = cv2.imread(label_path)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # ToTensor rescales the 8-bit grey mask to floats in [0, 1].
    xlabel = transforms.ToTensor()(mask)

    patches = xlabel.unfold(1, size, stride).unfold(2, size, stride)
    patches = patches.reshape(-1, size, size)
    # Binarise at mid-range. A plain uint8 cast truncates every value below
    # 1.0 to 0, which drops foreground pixels that JPEG compression moved
    # slightly below 255.
    patches = (patches.numpy() > 0.5).astype(np.uint8)

    for i in range(patches.shape[0]):
        LABEL_PATCHES.append(patches[i, :, :])
        indexing += 1
LABEL_PATCHES = np.array(LABEL_PATCHES)
print('Number of label patches in total: {}'.format(LABEL_PATCHES.shape[0]))
print(np.unique(LABEL_PATCHES))

if LABEL_PATCHES.shape[0] != num_image_patches:
    # Image patch i and label patch i are only a valid pair when both
    # sequences have the same length.
    print('WARNING: {} image patches but {} label patches; images and labels are not aligned'.format(
        num_image_patches, LABEL_PATCHES.shape[0]))

print('===========================================================================')
print('                            Now Saving the patches                         ')
print('===========================================================================')


path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test', 'label_patches.mat')
mdic = {"data": LABEL_PATCHES, "label": "test_labels"}
savemat(path, mdic)
print(f'Number of testing labels: {LABEL_PATCHES.shape[0]}')


                          IMAGE PATCHES                                    
0
1
2
Number of image patches in total: 135
                          LABEL PATCHES                                    
0
0
1
1
2
2
Number of label patches in total: 435
[0 1]
                            Now Saving the patches                         
Number of training labels: 435


In [None]:
# This script is used to obtain the train image patches from the full scale train images

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_images/*'
folder2 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_masks/*'

# directories for images and labels
OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/images'
os.makedirs(OUT_FOLDER, exist_ok=True)

PATCH_SIZE = 256   # side length of the square patches, in pixels
PATCH_STRIDE = 64  # step between neighbouring patches, in pixels


def _index_token(path):
    """Return the text between the last '_' and '.jpg' of ``path`` (e.g. '7' for '..._7.jpg')."""
    return path.split("_")[-1].split(".jpg")[0]


def _read_bgr(path):
    """Read an image with OpenCV and fail loudly when it cannot be decoded."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(path))
    return img


def _extract_patches(array, size=PATCH_SIZE, stride=PATCH_STRIDE):
    """Cut an (H, W) or (H, W, C) array into sliding square patches.

    Returns an array of shape (P, size, size) or (P, size, size, C), ordered
    row by row. The dtype is preserved, so no float round-trip can alter
    pixel values. Rows/columns that do not fill a whole window are dropped.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(array))
    # unfold appends the window axis last: (nH, nW, [C,] size, size)
    windows = tensor.unfold(0, size, stride).unfold(1, size, stride)
    if tensor.dim() == 3:
        windows = windows.permute(0, 1, 3, 4, 2)  # -> (nH, nW, size, size, C)
        return windows.reshape(-1, size, size, tensor.shape[2]).numpy()
    return windows.reshape(-1, size, size).numpy()


print('===========================================================================')
print('                          IMAGE PATCHES                                    ')
print('===========================================================================')
indexing = 0
IMGS = glob.glob(folder1)
for img_path in sorted(IMGS, key=lambda p: int(_index_token(p))):
    print(_index_token(img_path))
    # Patches stay in OpenCV's BGR order from read to write, so the previous
    # BGR -> RGB -> BGR round-trip is unnecessary. Keeping uint8 end to end
    # also avoids np.uint8(255 * float) truncating some values by one.
    for patch in _extract_patches(_read_bgr(img_path)):
        save_path = os.path.join(OUT_FOLDER, 'image_' + str(indexing) + '.jpg')
        cv2.imwrite(save_path, patch)
        indexing += 1
print('Number of image patches in total: {}'.format(indexing))

print('===========================================================================')
print('                          LABEL PATCHES                                    ')
print('===========================================================================')

LABEL_PATCHES = []
IMGS = glob.glob(folder2)
for label_path in sorted(IMGS, key=lambda p: int(_index_token(p))):
    print(_index_token(label_path))
    mask = cv2.cvtColor(_read_bgr(label_path), cv2.COLOR_BGR2GRAY)

    # The previous ToTensor() + np.uint8() pipeline floored every pixel below
    # 255 to 0. Threshold at mid-scale instead so foreground pixels softened by
    # JPEG compression still map to 1.
    # NOTE(review): the recorded run printed `[0]` for np.unique(LABEL_PATCHES).
    # If the masks encode foreground with small integers (e.g. 1 or instance
    # ids) rather than 255, this threshold must be lowered - confirm the mask
    # encoding before training on these labels.
    patches = np.uint8(_extract_patches(mask) >= 128)
    LABEL_PATCHES.extend(patches)
LABEL_PATCHES = np.array(LABEL_PATCHES)
print('Number of label patches in total: {}'.format(LABEL_PATCHES.shape[0]))
print(np.unique(LABEL_PATCHES))

print('===========================================================================')
print('                            Now Saving the patches                         ')
print('===========================================================================')


path = os.path.join('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train', 'label_patches.mat')
mdic = {"data": LABEL_PATCHES, "label": "train_labels"}
savemat(path, mdic)
print(f'Number of training labels: {LABEL_PATCHES.shape[0]}')


                          IMAGE PATCHES                                    
0
1
2
3
4
5
6
7
8
9
10
Number of image patches in total: 11
                          LABEL PATCHES                                    
0
1
2
3
4
5
6
7
8
9
10
Number of label patches in total: 1584
[0]
                            Now Saving the patches                         
Number of training labels: 1584


In [None]:
# This script is used to split the test set of the official GlaS into train-test split for our downstream
# histopathological image segmentation

from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


def split_data(images, labels, ratio):
    """Split two parallel arrays into two parts using one fixed random order.

    A permutation of ``images.shape[0]`` indices is drawn from
    ``RandomState(2023)``; the first ``int(n * ratio)`` indices form the first
    part and the rest form the second. The same indices are applied to both
    arrays.

    Returns:
        (images_part1, images_part2, labels_part1, labels_part2)
    """
    n_items = images.shape[0]
    order = np.random.RandomState(2023).permutation(n_items)
    cut = int(n_items * ratio)
    first, second = order[:cut], order[cut:]
    return images[first], images[second], labels[first], labels[second]


folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_images'
folder2 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/data/test_masks'

# directories for images and labels
TRAIN_OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_images'
os.makedirs(TRAIN_OUT_FOLDER, exist_ok=True)
TEST_OUT_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_images'
os.makedirs(TEST_OUT_FOLDER, exist_ok=True)

TRAIN_LABEL_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/full_size_masks'
os.makedirs(TRAIN_LABEL_FOLDER, exist_ok=True)
# This path used to start with './content/...', which is relative to the working
# directory, so the test masks were written away from the other three outputs.
TEST_LABEL_FOLDER = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/full_size_masks'
os.makedirs(TEST_LABEL_FOLDER, exist_ok=True)


def _numeric_suffix(name):
    """Sort key: the integer N of a file name ending in '_N.jpg'."""
    return int(name.split('_')[-1].split('.jpg')[0])


def _copy_renumbered(names, src_dir, dst_dir, prefix, grayscale=False):
    """Re-save ``names`` from ``src_dir`` into ``dst_dir`` as '<prefix>_<i>.jpg'.

    ``i`` is the position of the file in ``names``. With ``grayscale=True`` the
    image is converted to a single channel before writing. Each source file
    name is printed after it has been written.
    """
    for i, name in enumerate(names):
        src_path = os.path.join(src_dir, name)
        img = cv2.imread(src_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('Could not read image: {}'.format(src_path))
        if grayscale:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        cv2.imwrite(os.path.join(dst_dir, prefix + '_' + str(i) + '.jpg'), img)
        print(name)


print('===========================================================================')
print('                          IMAGES and LABELS                                    ')
print('===========================================================================')

total_list1 = np.array(sorted(os.listdir(folder1), key=_numeric_suffix))
total_list2 = np.array(sorted(os.listdir(folder2), key=_numeric_suffix))

ratio = 0.8
# Images and masks are paired by position after sorting, so only the first
# min_len entries of each list can be paired.
min_len = min(len(total_list1), len(total_list2))
if len(total_list1) != len(total_list2):
    print('WARNING: {} images but {} masks; using the first {} of each'.format(
        len(total_list1), len(total_list2), min_len))

# split_data draws one seeded permutation and applies it to both lists, which
# keeps every image aligned with its mask.
train_images, test_images, train_labels, test_labels = split_data(
    total_list1[:min_len], total_list2[:min_len], ratio)

_copy_renumbered(train_images, folder1, TRAIN_OUT_FOLDER, 'image')

print('='*30)

_copy_renumbered(train_labels, folder2, TRAIN_LABEL_FOLDER, 'label', grayscale=True)

print('='*30)

_copy_renumbered(test_images, folder1, TEST_OUT_FOLDER, 'image')

print('='*30)

_copy_renumbered(test_labels, folder2, TEST_LABEL_FOLDER, 'label', grayscale=True)


                          IMAGES and LABELS                                    
image_11.jpg
image_8.jpg
image_12.jpg
image_2.jpg
image_0.jpg
image_5.jpg
image_10.jpg
image_4.jpg
image_3.jpg
image_1.jpg
image_13.jpg
mask_11.jpg
mask_8.jpg
mask_12.jpg
mask_2.jpg
mask_0.jpg
mask_5.jpg
mask_10.jpg
mask_4.jpg
mask_3.jpg
mask_1.jpg
mask_13.jpg
image_6.jpg
image_9.jpg
image_7.jpg
mask_6.jpg
mask_9.jpg
mask_7.jpg


In [None]:
# Install the import-ipynb package (imported as `import_ipynb` in the next cell).
!pip install import-ipynb

Collecting import-ipynb
  Downloading import_ipynb-0.2-py3-none-any.whl.metadata (2.3 kB)
Collecting jedi>=0.16 (from IPython->import-ipynb)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading import_ipynb-0.2-py3-none-any.whl (4.0 kB)
Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jedi, import-ipynb
Successfully installed import-ipynb-0.2 jedi-0.19.2


In [None]:
# NOTE(review): the preceding cell already installs this package, so this
# second install looks redundant - confirm and drop one of the two.
!pip install import_ipynb
import import_ipynb



In [None]:
import sys
# Append the directory to your python path using sys.
# sys.path entries must be directories (or archives); the previous entry pointed
# at the utils.ipynb file itself, which the import system never searches.
sys.path.append('/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/pretrain')

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.nn import Conv2d, ConvTranspose2d, Linear, Embedding
from torch.nn import MaxPool2d, BatchNorm2d
from torch.nn import LeakyReLU, Tanh, ReLU, Sigmoid
from torch.nn import Module
from torch.nn import MSELoss
from torch import flatten
from functools import partial
import numpy as np
import random
import math
import os, os.path
from inspect import isfunction
from einops import rearrange
from tqdm import tqdm


def exists(x):
    """Return True for any value other than ``None`` (falsy values such as 0 count as existing)."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.

    When ``d`` is a plain Python function it is called (with no arguments) and
    its result is returned, so expensive defaults are only built when needed.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


class Residual(nn.Module):
    """Wrap a callable so that its output is added to its first input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to the wrapped callable.
        branch = self.fn(x, *args, **kwargs)
        return branch + x


def Upsample(dim):
    """Build a transposed convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.ConvTranspose2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )


def Downsample(dim):
    """Build a convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.Conv2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )


class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to sine/cosine features.

    For input of shape (B,) the output has shape (B, 2 * (dim // 2)): the
    first half holds sines and the second half cosines of ``time * freq``,
    with frequencies spaced geometrically from 1 down to 1/10000.
    Note that an odd ``dim`` therefore yields ``dim - 1`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # With dim 2 or 3, half_dim - 1 is 0 and the plain division raised
        # ZeroDivisionError. The single frequency is exp(0) = 1 either way,
        # so clamping the denominator does not change any other case.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Block(nn.Module):
    """3x3 convolution -> GroupNorm -> SiLU, with an optional scale/shift after the norm."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        normed = self.norm(self.proj(x))

        if scale_shift is not None:
            # scale_shift is a (scale, shift) pair applied as x * (scale + 1) + shift.
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)


class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and an optional time-embedding bias.

    When ``time_emb_dim`` is given, a SiLU + Linear projection maps the time
    embedding to ``dim_out`` channels; it is added to the output of the first
    block. The skip path uses a 1x1 convolution when the channel count changes
    and an identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        hidden = self.block1(x)

        # The time bias is applied only when both the projection and an embedding exist.
        if exists(self.mlp) and exists(time_emb):
            projected = self.mlp(time_emb)
            hidden = rearrange(projected, "b c -> b c 1 1") + hidden

        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)


class Attention(nn.Module):
    """Multi-head softmax attention across the spatial positions of a feature map.

    Input and output are both (batch, dim, height, width).
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        # One 1x1 convolution yields q, k and v; each is split into heads and
        # has its spatial grid flattened to a single axis.
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        scores = torch.einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the row maximum before the softmax (detached, so it carries no gradient).
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        mixed = torch.einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(mixed)


class LinearAttention(nn.Module):
    """Multi-head attention that first contracts keys with values, then applies queries.

    Queries are softmax-normalised over the per-head channel axis and keys
    over the flattened spatial axis. Input and output are both
    (batch, dim, height, width); the output projection ends in a GroupNorm.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # (d x e) summary of keys and values per head, independent of the number of positions.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        mixed = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        mixed = rearrange(mixed, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width)
        return self.to_out(mixed)


class PreNorm(nn.Module):
    """Apply a single-group GroupNorm over ``dim`` channels before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class Unet(nn.Module):
    """U-Net built from ResnetBlock, LinearAttention and Attention layers.

    Args:
        dim: base channel width; stage widths are ``dim * m`` for each ``m`` in ``dim_mults``.
        init_dim: channels produced by the first 7x7 convolution (defaults to ``dim // 3 * 2``).
        out_dim: output channels of the final 1x1 convolution (defaults to ``channels``).
        dim_mults: per-stage channel multipliers; one encoder stage is built per entry.
        channels: number of input channels.
        with_time_emb: when True, build an MLP that embeds the timestep and feed
            it to every ResnetBlock; when False, blocks receive ``None``.
        resnet_block_groups: GroupNorm group count passed to every ResnetBlock.
        convnext_mult: accepted but not used anywhere in this class.
        encoder_only: when True, no decoder or final convolution is built and
            ``forward`` returns the bottleneck features.
    """

    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            resnet_block_groups=8,
            convnext_mult=2,
            encoder_only=False
    ):
        super().__init__()
        self.encoder_only = encoder_only
        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        # Consecutive (in, out) channel pairs, one per encoder stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The last encoder stage keeps its resolution (Identity instead of Downsample).
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        if self.encoder_only is False:
            print('decoder')
            # The decoder mirrors all encoder stages except the first, so it has
            # num_resolutions - 1 stages.
            for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
                # ind never reaches num_resolutions - 1 in this loop, so is_last is
                # always False and every decoder stage ends with an Upsample.
                is_last = ind >= (num_resolutions - 1)

                self.ups.append(
                    nn.ModuleList(
                        [
                            # dim_out * 2: the input is concatenated with the matching skip tensor.
                            block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                            block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                            Upsample(dim_in) if not is_last else nn.Identity(),
                        ]
                    )
                )

            out_dim = default(out_dim, channels)
            # The ResnetBlock here is built without time_emb_dim, so it takes no time input.
            self.final_conv = nn.Sequential(
                block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
            )

    def forward(self, x, time):
        """Run the network on ``x`` with timesteps ``time``.

        Returns the output of ``final_conv``, or the bottleneck features when
        ``encoder_only`` is True.
        """
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Skip tensors, one per encoder stage, stored before that stage's downsample.
        h = []

        # down sample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # up sample
        if self.encoder_only is False:
            # Only len(self.ups) skip tensors are popped, so the first
            # (full-resolution) entry of h is never consumed.
            for block1, block2, attn, upsample in self.ups:
                x = torch.cat((x, h.pop()), dim=1)
                x = block1(x, t)
                x = block2(x, t)
                x = attn(x)
                x = upsample(x)

            x = self.final_conv(x)

        return x


class DiffusionNet(nn.Module):
    """Thin wrapper around a ``Unet`` with channel multipliers (1, 2, 4, 8)."""

    def __init__(self, dim, channels):
        super().__init__()
        self.net = Unet(dim=dim, channels=channels, dim_mults=(1, 2, 4, 8))

    def forward(self, x, time_stamps):
        """Return the wrapped network's output for input ``x`` at ``time_stamps``."""
        return self.net(x, time_stamps)


In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import os
from skimage import io
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import random

IMG_SIZE = 256


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672

    Returns a tensor of ``timesteps`` betas clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    cumulative = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    # Each beta is one minus the ratio of consecutive cumulative products.
    step_betas = 1 - (cumulative[1:] / cumulative[:-1])
    return torch.clip(step_betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` betas evenly spaced from ``start`` to ``end`` (inclusive)."""
    return torch.linspace(start, end, steps=timesteps)


def quadratic_beta_schedule(timesteps):
    """Return ``timesteps`` betas from 0.0001 to 0.02, evenly spaced in square-root space."""
    sqrt_start, sqrt_end = 0.0001 ** 0.5, 0.02 ** 0.5
    return torch.linspace(sqrt_start, sqrt_end, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    """Return ``timesteps`` betas following a sigmoid over [-6, 6], rescaled to span 0.0001 to 0.02."""
    low, high = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (high - low) + low


def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals[t]`` for every entry of the batch index tensor ``t``.

    The result is reshaped to (batch, 1, ..., 1) with as many dimensions as
    ``x_shape`` and moved to the device of ``t``. The lookup itself uses a CPU
    copy of ``t``.
    """
    picked = vals.gather(-1, t.cpu())
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *trailing_ones).to(t.device)


def forward_diffusion_sample(x_0, t, betas_schedule, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it

    Returns:
        (noisy, noise), both on ``device``, where
        noisy = sqrt_alphas_cumprod[t] * x_0 + sqrt_one_minus_alphas_cumprod[t] * noise
        and ``noise`` is standard normal with the shape of ``x_0``.
    """
    noise = torch.randn_like(x_0)
    signal_scale = get_index_from_list(betas_schedule['sqrt_alphas_cumprod'], t, x_0.shape)
    noise_scale = get_index_from_list(
        betas_schedule['sqrt_one_minus_alphas_cumprod'], t, x_0.shape
    )
    noise = noise.to(device)
    # mean + variance
    noisy = signal_scale.to(device) * x_0.to(device) + noise_scale.to(device) * noise
    return noisy, noise


def _my_normalization(x):
    return (x * 2) - 1


def get_images_list(path1, k=None):
    """List the files in ``path1`` ordered by the integer N in their '..._N.jpg' names.

    Returns a NumPy array of file names; when ``k`` is given only the first
    ``k`` names are kept.
    """
    names = os.listdir(path1)
    names.sort(key=lambda name: int(name.split('_')[-1].split('.jpg')[0]))
    return np.array(names if k is None else names[:k])


class Histo_Dataset(Dataset):
    """Dataset that reads the files named in ``image1_list`` from ``image1_dir``.

    Each item is the image loaded with ``io.imread``, passed through
    ``transform`` when one is given. No labels are returned.
    """

    def __init__(self, image1_dir, image1_list, transform=None):
        self.image1_dir = image1_dir
        self.image1_list = image1_list
        self.transform = transform

    def __len__(self):
        return len(self.image1_list)

    def __getitem__(self, index):
        image1 = io.imread(os.path.join(self.image1_dir, self.image1_list[index]))
        return image1 if self.transform is None else self.transform(image1)


def load_transformed_dataset():
    """Build the train and validation datasets from './img_patches'.

    Files are listed with ``get_images_list`` and split 90/10 using a
    permutation drawn from ``RandomState(2023)``. Both datasets share the same
    transform pipeline, including the random flips.

    Returns:
        (train_dataset, eval_dataset)
    """
    data_transform = transforms.Compose([
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Lambda(_my_normalization),  # Scale between [-1, 1]
    ])

    TRAIN_IMAGE_DIR = './img_patches'  # Directory for the unlabeled images used for pretraining
    data_size = None  # None keeps every file found in the directory
    img1_list = get_images_list(TRAIN_IMAGE_DIR, k=data_size)

    ratio = 0.9
    n_images = img1_list.shape[0]
    order = np.random.RandomState(2023).permutation(n_images)
    n_train = int(n_images * ratio)

    def make_subset(indices):
        return Histo_Dataset(TRAIN_IMAGE_DIR, img1_list[indices], transform=data_transform)

    return make_subset(order[:n_train]), make_subset(order[n_train:])


def reverse_transforms_image(image):
    """Convert a tensor scaled to [-1, 1] back into a PIL image.

    Args:
        image: CHW tensor, or a 4-D batch of which only the first element is used.

    Returns:
        A PIL image with 8-bit channels.
    """
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        # Sampled images can leave [-1, 1]; without this clamp the uint8 cast
        # below wraps out-of-range values around instead of saturating them.
        # In-range inputs are unaffected.
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    return reverse_transforms(image)


def get_beta_schedule(betas):
    """Precompute the per-timestep quantities derived from ``betas``.

    Returns a dict with the keys 'alphas', 'alphas_cumprod',
    'alphas_cumprod_prev', 'sqrt_recip_alphas', 'sqrt_alphas_cumprod',
    'sqrt_one_minus_alphas_cumprod' and 'posterior_variance', each a tensor
    with one entry per timestep.
    """
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift right by one step; the value before the first step is 1.0.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    return {
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'alphas_cumprod_prev': alphas_cumprod_prev,
        'sqrt_recip_alphas': torch.sqrt(1.0 / alphas),
        'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
        'sqrt_one_minus_alphas_cumprod': torch.sqrt(1. - alphas_cumprod),
        'posterior_variance': betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod),
    }


def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
    """Timestep-weighted squared error between true and predicted noise.

    Each sample is weighted by 1 / (k + snr) ** gamma with k = gamma = 1, where
    snr = 1 / (1 - alphas_cumprod[t]) - 1. The weighted squared errors are
    summed over the whole batch and divided by the element count of one sample
    (C * H * W), so the value is not averaged over the batch dimension.
    """
    steps = time_stamps.cpu()
    snr = 1.0 / (1 - betas_schedule['alphas_cumprod'][steps]) - 1
    k, gamma = 1.0, 1.0
    # One weight per sample, broadcast over the channel and spatial axes.
    weights = (1.0 / ((k + snr) ** gamma))[:, None, None, None].to(gpu)

    squared_error = F.mse_loss(noise, noise_pred, reduction='none')
    per_sample_size = noise.shape[1] * noise.shape[2] * noise.shape[3]
    return torch.sum(weights * squared_error) / per_sample_size


# def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
#     t = time_stamps.cpu()
#     n = noise.shape[1] * noise.shape[2] * noise.shape[3]

#     snr = 1.0 / (1 - betas_schedule['alphas_cumprod'][t]) - 1
#     k = 1.0
#     gamma = 1.0
#     lambda_t = 1.0/((k+snr)**gamma)
#     lambda_t = lambda_t.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(gpu)
#     loss1 = torch.sum(lambda_t * F.mse_loss(noise, noise_pred, reduction='none'))/n

#     scale_factor = (1.0 - betas_schedule['alphas'][t]) / (betas_schedule['alphas'][t] * (1.0 - betas_schedule['alphas_cumprod'][t]))
#     scale_factor = scale_factor.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(gpu)
#     loss2 = torch.sum(scale_factor * F.mse_loss(noise, noise_pred, reduction='none'))/n

#     c = 0.001
#     loss = loss1 + c*loss2
#     return loss


# def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
#     t = time_stamps.cpu()
#     n = noise.shape[1] * noise.shape[2] * noise.shape[3]

#     loss = torch.sum(F.mse_loss(noise, noise_pred, reduction='none'))/n

#     return loss


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=10, verbose=False, delta=0,
                 path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 10
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch=None, ddp=False):
        """Record one validation loss; save a checkpoint on improvement, else count towards patience."""
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch, ddp):
        '''Saves model when validation loss decrease.

        With an ``epoch`` the file name is the configured path with its last
        four characters replaced by '_<epoch>_<loss>.pth' (loss text cut to
        7 characters); without one the configured path is used as is.
        '''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if epoch is not None:
            weight_path = self.path[:-4] + '_' + str(epoch) + '_' + str(val_loss)[:7] + '.pth'
        else:
            weight_path = self.path

        # DDP / DataParallel wrappers expose the real network as `.module`; a
        # plain nn.Module has no such attribute and used to crash here.
        network = model.module if hasattr(model, 'module') else model

        torch.save({
            'epoch': epoch,
            'loss': val_loss,
            'model_state_dict': network.state_dict(),
        }, weight_path)

        self.val_loss_min = val_loss



  from scipy.ndimage.interpolation import map_coordinates
  from scipy.ndimage.filters import gaussian_filter


In [None]:
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/pretrain/utils.ipynb'
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/pretrain/model.ipynb'
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
#import utils
#from model import DiffusionNet

import os
import shutil
from matplotlib import pyplot as plt
from tqdm import tqdm
import time
import numpy as np
from scipy.io import savemat

# Training configuration shared by the functions in this cell.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # first CUDA device when available, else CPU
IMG_SIZE = 256  # side length of the square images generated by sample_plot_image
EPOCHS = 100  # presumably the number of training epochs - its use is outside this excerpt, TODO confirm
BATCH_SIZE = 8  # batch size of the train and eval DataLoaders built in main
LEARNING_RATE = 0.0001  # learning rate of the Adam optimizer built in main
T = 1000  # number of diffusion timesteps
DATA_TYPE = 'diff_quadratic'  # sub-folder of ./images/ that receives the sample grids

# Precompute the quadratic beta schedule and the quantities derived from it.
betas = quadratic_beta_schedule(timesteps=T)
betas_schedule = get_beta_schedule(betas)


def cleanup():
    """Destroy the torch.distributed process group."""
    dist.destroy_process_group()


@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.

    Args:
        model: callable ``model(x, t)`` returning the predicted noise.
        x: current noisy batch.
        t: 1-D tensor of timestep indices, one per batch element.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        betas_schedule['sqrt_one_minus_alphas_cumprod'], t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(betas_schedule['sqrt_recip_alphas'], t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(betas_schedule['posterior_variance'], t, x.shape)

    # `if t == 0` is only defined for a single-element tensor and raised for any
    # larger batch. A per-sample mask gives the same result for batch size 1
    # and also handles batches that mix t == 0 with t > 0.
    needs_noise = t != 0
    if not bool(needs_noise.any()):
        return model_mean

    noise = torch.randn_like(x)
    mask = needs_noise.reshape(-1, *((1,) * (x.dim() - 1))).to(dtype=x.dtype, device=x.device)
    return model_mean + mask * torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image(model, gpu, epoch):
    """Sample one image from pure noise and save a 10x10 grid of intermediate steps.

    The reverse process runs from step T - 1 down to 0; a snapshot is kept
    every ``T // 100`` steps (at least every step). The grid is written to
    './images/<DATA_TYPE>/image_<epoch>.jpg'.
    """
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=gpu)
    num_images = 100
    # max(1, ...) avoids a modulo-by-zero below when T is smaller than num_images.
    stepsize = max(1, T // num_images)

    all_images = []
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=gpu, dtype=torch.long)
        img = sample_timestep(model, img, t)
        if i % stepsize == 0:
            all_images.append(img)

    # The extra plt.figure() that used to be opened here was never drawn on or
    # closed, so one empty figure leaked on every call.
    fig, axs = plt.subplots(10, 10)
    for ax in axs.flat:
        ax.axis('off')
    # zip stops at the shorter sequence: surplus snapshots are ignored and
    # missing ones leave blank cells instead of raising IndexError.
    for ax, snapshot in zip(axs.flat, all_images):
        ax.imshow(reverse_transforms_image(snapshot.detach().cpu()))
    fig.savefig('./images/'+DATA_TYPE+'/image_' + str(epoch) + '.jpg', dpi=300)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def initialize_weights(model):
    """Re-draw the ``weight`` of every conv and norm layer from N(0, 0.01), in place.

    Affected layer types: Conv2d, ConvTranspose2d, BatchNorm2d and GroupNorm.
    Biases and all other layers (e.g. Linear) are left untouched.
    """
    # NOTE(review): this also sets the norm-layer scale parameters to values near 0
    # (their default is 1) - confirm that this is intended.
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm)
    for layer in filter(lambda m: isinstance(m, target_types), model.modules()):
        nn.init.normal_(layer.weight.data, 0.0, 0.01)


def train_epoch(train_dataloader, model, optimizer, gpu, epoch, args):
    """Run one optimisation pass over ``train_dataloader`` and return the mean batch loss.

    For every batch a random timestep in [0, T) is drawn per sample, the batch
    is noised with ``forward_diffusion_sample`` and the model is trained to
    predict that noise. ``args`` is not used in this function.
    """
    model.train()
    batch_losses = []
    progress = tqdm(train_dataloader)

    for batch in progress:
        optimizer.zero_grad()

        batch = batch.to(gpu, non_blocking=False)
        steps = torch.randint(0, T, (batch.shape[0],)).long().to(gpu, non_blocking=False)
        noisy_batch, true_noise = forward_diffusion_sample(batch, steps, betas_schedule, gpu)
        predicted_noise = model(noisy_batch, steps)
        loss = get_loss(true_noise, predicted_noise, steps, betas_schedule, gpu)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        progress.set_description('Epoch {}'.format(epoch))
        progress.set_postfix(loss=loss_value)

    mean_loss = np.mean(batch_losses)
    print('Epoch: {}\ttotal_loss {:.4f}'.format(epoch, mean_loss))

    return mean_loss


def eval_epoch(eval_dataloader, model, gpu, epoch, args, early_stopping=None):
    """Compute the mean batch loss over ``eval_dataloader`` without gradient tracking.

    Uses the same noising and loss as training, with freshly drawn random
    timesteps. ``args`` and ``early_stopping`` are not used in this function.
    """
    batch_losses = []

    with torch.no_grad():
        model.eval()
        progress = tqdm(eval_dataloader)

        for batch in progress:
            batch = batch.to(gpu, non_blocking=False)
            steps = torch.randint(0, T, (batch.shape[0],)).long().to(gpu, non_blocking=False)
            noisy_batch, true_noise = forward_diffusion_sample(batch, steps, betas_schedule, gpu)
            predicted_noise = model(noisy_batch, steps)
            loss_value = get_loss(true_noise, predicted_noise, steps, betas_schedule, gpu).item()

            batch_losses.append(loss_value)

            progress.set_description('Epoch {}'.format(epoch))
            progress.set_postfix(loss=loss_value)

    mean_loss = np.mean(batch_losses)
    print('Epoch: {}\ttotal_loss {:.4f}'.format(epoch, mean_loss))

    return mean_loss


def main(gpu, args):
    """Per-process worker that pre-trains ``DiffusionNet`` under DDP.

    Args:
        gpu: Local device index of this process. It is used as the CUDA
            device and as the offset when computing the global rank.
        args: Dict with the keys ``'nr'``, ``'gpus'``, ``'world_size'``,
            ``'checkpoints_path'`` and ``'load_from_chkpt'`` (``None`` to
            start from freshly initialised weights).

    Side effects:
        Joins an NCCL process group, creates the checkpoint directory,
        ``./images/<DATA_TYPE>`` and ``./plots/diff``, writes a loss-curve
        image, and calls ``cleanup()`` before returning.
    """
    rank = args['nr'] * args['gpus'] + gpu
    dist.init_process_group('nccl', rank=rank, world_size=args['world_size'])
    torch.cuda.set_device(gpu)

    data_size = None

    # data loaders
    train_dataset, eval_dataset = load_transformed_dataset()

    # DistributedSampler lives under torch.utils.data.distributed; there is no
    # ``torch.data`` namespace, so the previous spelling raised AttributeError.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=args['world_size'],
                                                                    rank=rank)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  drop_last=True, num_workers=4, pin_memory=True,
                                  sampler=train_sampler)

    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset,
                                                                   num_replicas=args['world_size'],
                                                                   rank=rank, shuffle=False)
    eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE,
                                 shuffle=False, drop_last=False, num_workers=4, pin_memory=True,
                                 sampler=eval_sampler)

    model = DiffusionNet(dim=64, channels=3).to(gpu)
    initialize_weights(model)
    print("Num params: ", sum(p.numel() for p in model.parameters()))
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    checkpoint_path = args['checkpoints_path']
    os.makedirs(checkpoint_path, exist_ok=True)
    os.makedirs('./images/' + DATA_TYPE, exist_ok=True)

    if args['load_from_chkpt'] is not None:
        chkpt_file = args['load_from_chkpt']
        print('Loading checkpoint from:', chkpt_file)
        # Deserialise onto the CPU first so every rank does not materialise the
        # tensors on whichever device they were saved from; load_state_dict
        # then copies them into this rank's parameters.
        checkpoint = torch.load(chkpt_file, map_location=torch.device('cpu'))
        model.module.load_state_dict(checkpoint['model_state_dict'])

    epoch_start = 0

    # Only the process with gpu == 0 tracks the best loss and writes checkpoints.
    if gpu == 0:
        early_stopping = EarlyStopping(patience=15, verbose=True,
                                       path=checkpoint_path + '{}_{}.pth'.format(BATCH_SIZE, LEARNING_RATE))
    else:
        early_stopping = None

    train_losses = []
    eval_losses = []
    start_time = time.process_time()
    for epoch in range(epoch_start, EPOCHS):
        print('epoch {}/{}'.format(epoch + 1, EPOCHS))
        # Re-seed the sampler's shuffle so each epoch sees a different order.
        train_sampler.set_epoch(epoch)
        train_loss = train_epoch(train_dataloader, model, optimizer, gpu, epoch + 1, args)
        # Pass ``args`` in its own slot; previously ``early_stopping`` was
        # handed over positionally as ``args``.
        eval_loss = eval_epoch(eval_dataloader, model, gpu, epoch + 1, args, early_stopping)

        # Pre-divide by the process count so the all_reduce below yields a mean.
        mean_train_loss = torch.tensor(train_loss / args['gpus']).to(gpu)
        mean_eval_loss = torch.tensor(eval_loss / args['gpus']).to(gpu)

        dist.barrier()
        dist.all_reduce(mean_train_loss)
        dist.all_reduce(mean_eval_loss)
        print('gpu {} eval_loss:{}, mean_loss:{}'.format(gpu, eval_loss,
                                                         mean_eval_loss.cpu().numpy()))

        # if optim_name.split('-')[-1] == 'step':
        #     scheduler.step(mean_eval_loss.cpu().numpy())
        # elif optim_name.split('-')[-1] == 'cosine':
        #     scheduler.step()
        # elif optim_name.split('-')[-1] == 'no':
        #     pass

        if (epoch+1) % 20 == 0:
            sample_plot_image(model, gpu, epoch+1)

        if gpu == 0:
            early_stopping(mean_eval_loss.cpu().numpy(), model, epoch + 1)

        train_losses.append(mean_train_loss.cpu().numpy())
        eval_losses.append(mean_eval_loss.cpu().numpy())

    current_time = time.process_time()
    print("Total Time Elapsed={:12.5} seconds".format(str(current_time - start_time)))

    # saving the plots
    plots_path = './plots/diff'
    os.makedirs(plots_path, exist_ok=True)
    epochs = np.arange(EPOCHS)
    train_losses = np.array(train_losses)
    eval_losses = np.array(eval_losses)
    fig, axes = plt.subplots(1, 1, figsize=(8, 5))
    axes.plot(epochs, train_losses, 'tab:blue', epochs, eval_losses, 'tab:orange')
    axes.set_title(f'Training and Validation Loss (pretrained model = None, loss = MSE Loss, '
                   f'data size = {data_size})',
                   weight='bold', fontsize=7)
    axes.set_xlabel('Epochs', weight='bold', fontsize=9)
    axes.set_ylabel('Loss', weight='bold', fontsize=9)
    # Label the two curves so the saved figure is readable on its own.
    axes.legend(['training loss', 'validation loss'], loc='best')
    plt.savefig(plots_path + '/'+DATA_TYPE+'loss_' + str(data_size) + '.jpg', dpi=300)

    # os.makedirs('./loss', exist_ok=True)
    # path = os.path.join('./loss', DATA_TYPE+'_train_loss.mat')
    # mdic = {"data": train_losses, "label": "epochs"}
    # savemat(path, mdic)

    cleanup()


if __name__ == '__main__':
    # Single-node, single-GPU launch settings handed to main().
    args = {'gpus': 1, 'nr': 0}
    args.update(
        world_size=args['gpus'],
        checkpoints_path='./snapshots/' + DATA_TYPE + '/',
        load_from_chkpt=None,
    )

    # Rendezvous address/port read by torch.distributed's default init method.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12345'})

    print(args['gpus'])
    # mp.spawn(main, args=(args,), nprocs=args['gpus'])


  from scipy.ndimage.interpolation import map_coordinates
  from scipy.ndimage.filters import gaussian_filter


1


In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score
import os
import cv2


class FLoss(nn.Module):
    """Focal-style loss on per-pixel class probabilities.

    Computes ``-weight * (1 - p)**gamma * y * log(p)``, summed over all
    elements and divided by ``N * H * W`` (batch size times spatial size).

    Args:
        gamma: Exponent of the modulating factor ``(1 - p)``.
        weight: Multiplier applied to every term (default 1.0).
    """

    def __init__(self, gamma, weight=1.0):
        super(FLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, y_pred, y_true):
        '''
        y_true shape: NxCxHxW
        y_pred shape: NxCxHxW

        ``y_pred`` is passed straight to ``torch.log``, so it is expected to
        hold strictly positive probabilities.
        '''

        N = y_true.shape[0] * y_true.shape[2] * y_true.shape[3]
        # Use (1 - p) rather than (p - 1): the base must be non-negative,
        # otherwise odd gammas flip the sign of the loss and fractional gammas
        # produce NaN. For even gammas the value is unchanged.
        modulating_factor = torch.pow(1 - y_pred, self.gamma)
        t2 = -self.weight * modulating_factor * y_true * torch.log(y_pred)
        loss = torch.sum(t2) / N

        return loss


class CELoss(nn.Module):
    """Cross-entropy between one-hot targets and predicted probabilities.

    The ``weight`` argument is stored as ``self.weights`` but is not used by
    ``forward``.
    """

    def __init__(self, weight=None):
        super().__init__()
        self.weights = weight

    def forward(self, y_pred, y_true):
        '''
        y_true shape: NxCxHxW
        y_pred shape: NxCxHxW

        Returns ``-sum(y_true * log(y_pred)) / (N * H * W)``.
        '''
        n_pixels = y_true.shape[0] * y_true.shape[2] * y_true.shape[3]
        log_likelihood = (y_true * torch.log(y_pred)).sum()
        return -log_likelihood / n_pixels


class SSLoss(nn.Module):
    """Structure-weighted cross-entropy.

    Both tensors are standardised per sample and channel over the spatial
    axes; the absolute difference of the standardised maps selects and weights
    the pixels that contribute to the loss.

    Args:
        beta: Fraction of the per-channel maximum difference a pixel must
            exceed to be kept.
        C: Constant added to the numerator and denominator of the
            standardisation.
    """

    def __init__(self, beta=0.1, C=0.01):
        super().__init__()
        self.beta = beta
        self.C = C
        self.LCE = CELoss()

    def forward(self, y_pred, y_true):
        '''
        y_true shape: NxCxHxW
        y_pred shape: NxCxHxW
        '''
        spatial_axes = (2, 3)

        def _standardise(x):
            # (x - mean + C) / (std + C), statistics taken over H and W.
            mu = torch.mean(x, spatial_axes, keepdim=True)
            sigma = torch.std(x, spatial_axes, keepdim=True)
            return (x - mu + self.C) / (sigma + self.C)

        e = torch.abs(_standardise(y_true) - _standardise(y_pred))

        # Per (sample, channel) maximum of e, shaped N x C x 1 x 1 for broadcasting.
        e_max = e.flatten(start_dim=2).max(dim=2, keepdim=True).values.unsqueeze(2)
        selected = (e > (self.beta * e_max)).float()

        cross_entropy = self.LCE(y_pred, y_true)
        weighted = e * selected * cross_entropy

        # Normalise by the number of selected pixels.
        return torch.sum(weighted) / torch.sum(selected)


def tversky_coefficient(y_true, y_predict, smooth=1.0, beta=0.3):
    """Tversky index between a target and a prediction.

    Returns ``(TP + smooth) / (TP + beta * FP + (1 - beta) * FN + smooth)``
    where TP, FP and FN are the soft counts ``sum(y * p)``,
    ``sum((1 - y) * p)`` and ``sum(y * (1 - p))`` over all elements.
    """
    true_pos = (y_true * y_predict).sum()
    weighted_false_pos = beta * ((1 - y_true) * y_predict).sum()
    weighted_false_neg = (1 - beta) * (y_true * (1 - y_predict)).sum()

    numerator = true_pos + smooth
    denominator = true_pos + weighted_false_pos + weighted_false_neg + smooth
    return numerator / denominator


class Tversky_Loss(nn.Module):
    """Mean over the batch of ``1 - tversky_coefficient`` per sample.

    Args:
        beta: Forwarded to ``tversky_coefficient`` as its ``beta`` argument.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def forward(self, y_predict, y_true):
        batch_size = y_true.shape[0]
        # Start the running sum at 0.0 so the accumulation matches a plain
        # float accumulator that tensors are added to.
        total = sum(
            (1 - tversky_coefficient(y_true[idx], y_predict[idx], beta=self.beta)
             for idx in range(batch_size)),
            0.0,
        )
        return total / batch_size


class CosineLoss(nn.Module):
    """Loss equal to the cosine of (approximately) pi/2 times the per-pixel agreement.

    The agreement is ``sum_c(y_true * y_predict)`` over the channel axis; the
    cosines are summed and divided by ``N * H * W``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_predict, y_true):
        n_pixels = y_true.shape[0] * y_true.shape[2] * y_true.shape[3]
        agreement = (y_true * y_predict).sum(dim=1)
        # 3.1416/2 is the literal half-pi approximation used by this loss.
        return torch.cos((3.1416/2) * agreement).sum() / n_pixels


class FocalLogLoss(nn.Module):
    """Element-wise loss with a quadratic penalty on negatives and a focal term on positives.

    Where ``y_true == 0`` an element contributes ``15 * p**2``; elsewhere it
    contributes ``-15 * (1 - p)**gamma * log(p)``. The sum is divided by the
    total number of elements ``N * C * H * W``.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
    """

    def __init__(self, gamma):
        super(FocalLogLoss, self).__init__()
        self.gamma = gamma

    def forward(self, y_predict, y_true):
        N = y_true.shape[0] * y_true.shape[1] * y_true.shape[2] * y_true.shape[3]

        # Use one pair of complementary masks for both reading and writing so
        # the selected predictions always match the slots they are written to.
        negative = y_true == 0
        positive = ~negative

        loss = torch.ones_like(y_true)
        loss[negative] = -15 * torch.pow(y_predict[negative], 2)

        right_predictions = y_predict[positive]
        # (1 - p) keeps the base non-negative: (p - 1) flipped the sign for odd
        # gammas and gave NaN for fractional ones. Even gammas are unaffected.
        loss[positive] = 15 * torch.pow(1 - right_predictions, self.gamma) * torch.log(right_predictions)

        loss = -torch.sum(loss)/N

        return loss


class LogMaxLoss(nn.Module):
    """Per-pixel maximum of a wrong-class penalty and a focal log term, averaged.

    ``loss1 = 5 * (sum_c(p on wrong classes))**2`` and
    ``loss2 = -5 * sum_c((1 - p)**gamma * y * log(p))``; the result is the
    mean of ``max(loss1, loss2)``.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
    """

    def __init__(self, gamma):
        super(LogMaxLoss, self).__init__()
        self.gamma = gamma

    def forward(self, y_predict, y_true):

        # 1.0 where the target is zero, 0.0 elsewhere.
        y_false = 1.0 * torch.logical_not(y_true)
        loss1 = 5 * torch.pow(torch.sum(y_false * y_predict, dim=1), 2)
        # (1 - p) keeps the base non-negative: (p - 1) flipped the sign for odd
        # gammas and gave NaN for fractional ones. Even gammas are unaffected.
        loss2 = -5 * torch.sum(torch.pow(1 - y_predict, self.gamma) * y_true * torch.log(y_predict), dim=1)

        log_max_loss = torch.mean(torch.maximum(loss1, loss2))

        return log_max_loss


class PolyLogLoss(nn.Module):
    """Mean absolute gap between a focal term on the right class and one on the wrong classes.

    With ``r = sum_c(y * p)`` and ``w = sum_c((1 - y) * p)`` per pixel:
    ``loss1 = -(1 - r)**gamma * weight * log(r)`` and
    ``loss2 = -w**gamma * log(w)``; the result is ``mean(|loss1 - loss2|)``.

    Args:
        gamma: Exponent used in both terms.
        weight: Multiplier on ``loss1`` (stored as ``self.weights``).
    """

    def __init__(self, gamma, weight=1.0):
        super(PolyLogLoss, self).__init__()
        self.gamma = gamma
        self.weights = weight

    def forward(self, y_predict, y_true):

        right_predictions = torch.sum(y_true * y_predict, 1)
        # (1 - r) keeps the base non-negative: (r - 1) flipped the sign for odd
        # gammas and gave NaN for fractional ones. Even gammas are unaffected.
        loss1 = -torch.pow(1 - right_predictions, self.gamma) * self.weights * torch.log(right_predictions)

        # 1.0 where the target is zero, 0.0 elsewhere.
        y_false = 1.0 * torch.logical_not(y_true)
        wrong_predictions = torch.sum(y_false * y_predict, 1)
        loss2 = -torch.pow(wrong_predictions, self.gamma) * torch.log(wrong_predictions)

        poly_log_loss = torch.mean(torch.abs(loss1 - loss2))

        return poly_log_loss


In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from scipy.io import loadmat, savemat
from torchsummary import summary
from torch.nn import functional as F
from skimage import io
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision.models import resnet18, resnet50
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import random
import time
import sys
from PIL import Image
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/downstream_train/losses.ipynb'
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/downstream_train/model.ipynb'
#from losses import CELoss, Tversky_Loss, FLoss, SSLoss, PolyLogLoss
#from model import SegNet


# CUDA environment settings, applied before any CUDA work in this cell.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Optional seeding, currently disabled (runs are not seeded here).
# seed = 0  # You can choose any value as the seed
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)

# Hyper Parameters
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_CHANNELS = 3
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EPOCHS = 150
NUM_CLASSES = 4 # number of classes
# Directory of labelled training images and the .mat file holding their masks.
TRAIN_IMAGE_DIR = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/images'
TRAIN_LABEL_PATH = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/train/label_patches.mat'
# MODEL_TYPE names the snapshot and plot sub-directories; LOSS_TYPE goes into the plot file name.
MODEL_TYPE = 'diff_quadratic_SSFL'
LOSS_TYPE = 'SSFL' + '_' + str(BATCH_SIZE)

# Base transform applied to both images and masks by Histo_Dataset.
transformations = transforms.Compose([
    transforms.ToTensor(),
])


def cleanup():
    """Tear down the torch.distributed process group joined in ``main``."""
    dist.destroy_process_group()

# Can use other image augmentations also
def my_transforms(image1, mask):
    """Randomly augment an image together with its segmentation mask.

    Flips are applied to both inputs so they stay aligned; blur and colour
    jitter are applied to the image only.

    Args:
        image1: Image accepted by the torchvision functional transforms.
        mask: Mask accepted by ``TF.vflip`` / ``TF.hflip``.

    Returns:
        Tuple ``(image1, mask)`` after augmentation.
    """
    # Geometric flips (vertical, then horizontal), each with probability 0.5.
    for flip in (TF.vflip, TF.hflip):
        if random.random() > 0.5:
            image1, mask = flip(image1), flip(mask)

    # Photometric changes, each with probability 0.3, on the image only.
    if random.random() > 0.7:
        image1 = TF.gaussian_blur(image1, [3, 3], [1.0, 2.0])

    # Sharpness adjustment is intentionally left disabled:
    # if random.random() > 0.7:
    #     image1 = TF.adjust_sharpness(image1, 2.0)

    if random.random() > 0.7:
        image1 = transforms.ColorJitter(brightness=.5, contrast=.4)(image1)

    return image1, mask


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=10, verbose=False, delta=0,
                 path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 10
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch=None, ddp=False):
        """Record one validation result.

        Saves a checkpoint when ``val_loss`` improves on the best value seen
        so far; otherwise increments the patience counter and sets
        ``self.early_stop`` once it reaches ``self.patience``.

        Args:
            val_loss: Validation loss for this epoch (lower is better).
            model: Model to checkpoint; may be a DDP wrapper or a bare module.
            epoch: Optional epoch number embedded in the checkpoint file name.
            ddp: Forwarded to ``save_checkpoint``; kept for interface
                compatibility.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch, ddp):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if epoch is not None:
            # Split on the real extension instead of assuming it is 4 characters.
            stem, ext = os.path.splitext(self.path)
            weight_path = stem + '_' + str(epoch) + '_' + str(val_loss)[:7] + (ext or '.pth')
        else:
            weight_path = self.path

        # Unwrap DataParallel/DDP containers when present so the saved keys
        # carry no 'module.' prefix; bare modules are saved as they are.
        bare_model = getattr(model, 'module', model)
        torch.save({
            'epoch': epoch,
            'loss': val_loss,
            'model_state_dict': bare_model.state_dict(),
        }, weight_path)
        self.val_loss_min = val_loss


def get_images_list(path1, k=None):
    """List the file names in ``path1`` ordered by their numeric suffix.

    Names are expected to look like ``<anything>_<int>.jpg``; the integer
    after the last underscore is the sort key.

    Args:
        path1: Directory to list.
        k: If given, keep only the first ``k`` names after sorting.

    Returns:
        ``np.ndarray`` of file names (not full paths).
    """
    def _numeric_suffix(filename):
        return int(filename.split('_')[-1].split('.jpg')[0])

    names = os.listdir(path1)
    names.sort(key=_numeric_suffix)

    selected = names if k is None else names[:k]
    return np.array(selected)


class Histo_Dataset(Dataset):
    """Dataset pairing image files on disk with in-memory label masks.

    Args:
        image1_dir: Directory containing the image files.
        image1_list: Sequence of file names inside ``image1_dir``.
        label_list: Sequence of masks, index-aligned with ``image1_list``.
        transform: Optional callable applied to both the image and the mask.
    """

    def __init__(self, image1_dir, image1_list, label_list, transform=None):
        self.image1_dir, self.image1_list = image1_dir, image1_list
        self.label_list, self.transform = label_list, transform

    def __len__(self):
        return len(self.image1_list)

    def __getitem__(self, index):
        image1 = io.imread(os.path.join(self.image1_dir, self.image1_list[index]))
        mask = self.label_list[index]

        transform = self.transform
        if transform is not None:
            image1 = transform(image1)
            # The transformed mask is multiplied by 255; presumably this undoes
            # the 1/255 scaling of ToTensor on uint8 masks (confirm mask dtype).
            mask = 255 * transform(mask)

        # Random augmentation is applied to every sample returned.
        return my_transforms(image1, mask)


def initialize_weights(model):
    """Re-draw the weights of convolution and normalisation layers.

    Every sub-module of ``model`` that is an ``nn.Conv2d``,
    ``nn.ConvTranspose2d``, ``nn.BatchNorm2d`` or ``nn.GroupNorm`` has its
    ``weight`` overwritten in place with samples from a normal distribution
    (mean 0.0, std 0.01). Biases and all other layer types are untouched.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm)
    for layer in model.modules():
        if not isinstance(layer, target_layers):
            continue
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.01)


def F1_score(y_true, y_pred):
    """Per-class binary F1 scores for a single sample.

    Both inputs are reduced to label maps by taking the index of the maximum
    along axis 0; each class in ``range(NUM_CLASSES)`` is then scored
    one-vs-rest with sklearn's ``f1_score`` (``zero_division=1``).

    Returns:
        List of ``NUM_CLASSES`` F1 values.
    """
    true_labels = y_true.max(dim=0).indices
    pred_labels = y_pred.max(dim=0).indices

    def _class_f1(cls):
        is_true = (true_labels == cls).reshape(-1).cpu().numpy()
        is_pred = (pred_labels == cls).reshape(-1).cpu().numpy()
        return f1_score(is_true, is_pred, zero_division=1)

    return [_class_f1(cls) for cls in range(NUM_CLASSES)]


def image_weights(y_true):
    """Build a per-sample weight tensor from the class mix of each target.

    For every sample the label map is the index of the maximum along the
    channel axis; the weight is ``1 + mean(fraction of pixels per present
    label)``, which equals ``1 + 1 / (number of distinct labels present)``.
    That scalar is broadcast over the whole sample.

    Args:
        y_true: Target tensor whose axis 0 is the batch and axis 1 the classes.

    Returns:
        Tensor shaped like ``y_true`` holding each sample's weight.
    """
    _, target_label = torch.max(y_true, dim=1)
    img_weights = torch.zeros_like(y_true)
    for i in range(y_true.shape[0]):
        _, label_counts = torch.unique(target_label[i], return_counts=True)
        label_counts_avg = torch.mean(label_counts / torch.sum(label_counts))
        img_weights[i] = 1.0 + label_counts_avg

    # Return after the loop: returning inside it weighted only sample 0 and
    # yielded None for an empty batch.
    return img_weights


def class_weights(y_predict, y_true):
    """Sum the per-class F1 scores of every sample in the batch.

    Despite the name, the returned values are accumulated (not averaged) F1
    scores; the caller divides by the number of samples.

    Returns:
        ``np.ndarray`` of length ``NUM_CLASSES``.
    """
    f1_totals = np.zeros(NUM_CLASSES)
    batch_size = y_true.shape[0]
    for sample in range(batch_size):
        f1_totals += F1_score(y_true[sample], y_predict[sample])
    # Conversion to weights is left disabled:
    # fscores = fscores/y_true.shape[0]
    # weights = (1 - fscores + 0.001) / (fscores + 0.001)

    return f1_totals


def train_epoch(train_loader, model, optimizer, gpu, epoch):
    """Train the segmentation model for one pass over ``train_loader``.

    Each batch's integer mask is one-hot encoded to ``N x C x H x W`` and the
    model is optimised on ``SSLoss + FLoss(gamma=2.0)``. The model is always
    called with a timestep tensor of zeros.

    Args:
        train_loader: Iterable yielding ``(image, mask)`` batches.
        model: Network called as ``model(image, t)``.
        optimizer: Optimizer stepped once per batch.
        gpu: Device the tensors are moved to.
        epoch: Epoch number, used only for progress/log output.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()

    # The criteria hold no per-batch state, so build them once.
    ss_loss = SSLoss()
    f_loss = FLoss(2.0)

    batch_losses = []
    progress = tqdm(train_loader)

    for h, true_label in progress:
        h = h.to(gpu, non_blocking=False)
        true_label = true_label.squeeze(1).to(gpu, non_blocking=False)

        # N x H x W labels -> N x C x H x W one-hot floats.
        target_label = F.one_hot(true_label.long(), NUM_CLASSES).permute(0, 3, 1, 2).float()

        t = torch.zeros(h.shape[0], dtype=torch.long).to(gpu, non_blocking=False)
        predicted_label = model(h, t)

        loss = ss_loss(predicted_label, target_label) + f_loss(predicted_label, target_label)
        step_loss = loss.item()
        batch_losses.append(step_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.set_description('Epoch {}'.format(epoch))
        progress.set_postfix(loss=step_loss)

    epoch_loss = np.mean(batch_losses)
    print('Epoch: {}\ttotal_loss {:.4f}'.format(epoch, epoch_loss))
    return epoch_loss


def eval_epoch(eval_loader, model, gpu, epoch, early_stopping=None):
    """Evaluate the segmentation model on ``eval_loader``.

    Runs under ``torch.no_grad()`` in eval mode, accumulating the
    ``SSLoss + FLoss(gamma=2.0)`` value per batch and the per-class F1 scores
    per sample. Prints the mean F1 per class, the mean F1 over classes
    ``1..NUM_CLASSES-1`` and the mean validation loss.

    Args:
        eval_loader: Iterable yielding ``(image, mask)`` batches.
        model: Network called as ``model(image, t)``.
        gpu: Device the tensors are moved to.
        epoch: Epoch number, used only for progress/log output.
        early_stopping: Accepted for interface compatibility; not read here.

    Returns:
        Mean of the per-batch loss values.
    """
    ss_loss = SSLoss()
    f_loss = FLoss(2.0)

    val_loss = []
    total_items = 0.0
    f_scores = np.zeros(NUM_CLASSES)

    with torch.no_grad():
        model.eval()
        progress = tqdm(eval_loader)

        for h, true_label in progress:
            h = h.to(gpu, non_blocking=False)
            true_label = true_label.squeeze(1).to(gpu, non_blocking=False)

            # N x H x W labels -> N x C x H x W one-hot floats.
            target_label = F.one_hot(true_label.long(), NUM_CLASSES).permute(0, 3, 1, 2).float()

            t = torch.zeros(h.shape[0], dtype=torch.long).to(gpu, non_blocking=False)
            predicted_label = model(h, t)

            f_scores += class_weights(predicted_label, target_label)
            total_items += target_label.shape[0]

            loss = ss_loss(predicted_label, target_label) + f_loss(predicted_label, target_label)
            step_loss = loss.item()
            val_loss.append(step_loss)

            progress.set_description('Epoch {}'.format(epoch))
            progress.set_postfix(loss=step_loss)

    # Average the summed F1 scores over all evaluated samples.
    f_scores = f_scores / total_items
    print(f_scores)
    print(np.mean(f_scores[1:]))

    mean_val_loss = np.mean(val_loss)
    print('Epoch: {}\tval_loss {:.4f}'.format(epoch, mean_val_loss))

    return mean_val_loss


def main(gpu, args):
    """Per-process worker that fine-tunes ``SegNet`` for segmentation under DDP.

    Args:
        gpu: Local device index of this process; used as the CUDA device and
            as the offset when computing the global rank.
        args: Dict with the keys ``'nr'``, ``'gpus'``, ``'world_size'``,
            ``'checkpoints_path'`` and ``'load_from_chkpt'`` (``None`` to skip
            loading pretrained weights).

    Side effects:
        Joins an NCCL process group, creates the checkpoint directory and
        ``./plots/<MODEL_TYPE>``, writes a loss-curve image, and calls
        ``cleanup()`` before returning.
    """
    rank = args['nr'] * args['gpus'] + gpu
    dist.init_process_group('nccl', rank=rank, world_size=args['world_size'])
    torch.cuda.set_device(gpu)

    # data_size=None means "use every image"; backbone only names the plot file.
    data_size = None
    backbone = 'Unet'

    # Loading the data
    img1_list = get_images_list(TRAIN_IMAGE_DIR, k=data_size)
    label_file = loadmat(TRAIN_LABEL_PATH)
    train_labels = label_file['data']

    # Fixed-seed 90/10 train/validation split so every process agrees on it.
    ratio = 0.9
    idxs = np.random.RandomState(2023).permutation(img1_list.shape[0])
    split = int(img1_list.shape[0] * ratio)
    train_index = idxs[:split]
    valid_index = idxs[split:]

    train_dataset = Histo_Dataset(TRAIN_IMAGE_DIR,
                                  img1_list[train_index], train_labels[train_index],
                                  transform=transformations)
    eval_dataset = Histo_Dataset(TRAIN_IMAGE_DIR,
                                 img1_list[valid_index], train_labels[valid_index],
                                 transform=transformations)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=args['world_size'],
                                                                    rank=rank)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  drop_last=True, num_workers=4, pin_memory=True,
                                  sampler=train_sampler)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset,
                                                                   num_replicas=args['world_size'],
                                                                   rank=rank, shuffle=False)
    eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE,
                                 shuffle=False, drop_last=False, num_workers=4, pin_memory=True,
                                 sampler=eval_sampler)

    model = SegNet(dim=64, channels=3, num_classes=4).to(gpu)
    initialize_weights(model)

    start_epoch = 0
    if args['load_from_chkpt'] is not None:
        chkpt_file = args['load_from_chkpt']
        print('Loading checkpoint from:', chkpt_file)
        checkpoint = torch.load(chkpt_file, map_location=torch.device('cpu'))
        pretrained_dict = checkpoint['model_state_dict']
        # Drop the pretrained 'net.final_conv.' entries and load the rest
        # non-strictly, so that layer keeps the weights set by initialize_weights.
        new_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k[:15] != 'net.final_conv.'}
        model.load_state_dict(new_pretrained_dict, strict=False)
        # model.load_state_dict(pretrained_dict)

    print("Num params: ", sum(p.numel() for p in model.parameters()))
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    checkpoint_path = args['checkpoints_path']
    os.makedirs(checkpoint_path, exist_ok=True)

    # Only the process with gpu == 0 tracks the best loss and writes checkpoints.
    if gpu == 0:
        early_stopping = EarlyStopping(patience=15, verbose=True,
                                       path=checkpoint_path + '{}_{}.pth'.format(BATCH_SIZE, LEARNING_RATE))
    else:
        early_stopping = None

    train_losses = []
    eval_losses = []
    start_time = time.process_time()
    for epoch in range(start_epoch, EPOCHS):
        print('epoch {}/{}'.format(epoch + 1, EPOCHS))
        # Re-seed the sampler's shuffle so each epoch sees a different order.
        train_sampler.set_epoch(epoch)
        train_loss = train_epoch(train_dataloader, model, optimizer, gpu, epoch + 1)
        eval_loss = eval_epoch(eval_dataloader, model, gpu, epoch + 1, early_stopping)

        # Pre-divide by the process count so the all_reduce below yields a mean.
        mean_eval_loss = torch.tensor(eval_loss / args['gpus']).to(gpu)
        mean_train_loss = torch.tensor(train_loss / args['gpus']).to(gpu)
        dist.barrier()
        dist.all_reduce(mean_eval_loss)
        dist.all_reduce(mean_train_loss)
        print('gpu {} eval_loss:{}, mean_loss:{}'.format(gpu, eval_loss,
                                                         mean_eval_loss.cpu().numpy()))

        # if optim_name.split('-')[-1] == 'step':
        #     scheduler.step(mean_eval_loss.cpu().numpy())
        # elif optim_name.split('-')[-1] == 'cosine':
        #     scheduler.step()
        # elif optim_name.split('-')[-1] == 'no':
        #     pass

        # EarlyStopping is used here only to checkpoint improvements; the
        # actual early exit below is disabled (it sits inside a string literal).
        if gpu == 0:
            early_stopping(mean_eval_loss.cpu().numpy(), model, epoch + 1)
        '''
        if early_stopping.early_stop:
            print('Early stop!')
            break
        '''
        train_losses.append(mean_train_loss.cpu().numpy())
        eval_losses.append(mean_eval_loss.cpu().numpy())

    current_time = time.process_time()
    print("Total Time Elapsed={:12.5} seconds".format(str(current_time - start_time)))

    # saving the plots
    plots_path = './plots/' + MODEL_TYPE
    os.makedirs(plots_path, exist_ok=True)
    epochs = np.arange(start_epoch, EPOCHS)
    train_losses = np.array(train_losses)
    eval_losses = np.array(eval_losses)
    fig, axes = plt.subplots(1, 1, figsize=(8, 5))
    axes.plot(epochs, train_losses, 'tab:blue', epochs, eval_losses, 'tab:orange')
    axes.set_title(f'Training and Validation Loss (pretrained model = diffusion, loss = (SS + Focal) Loss, '
                   f'data size = {data_size})',
                   weight='bold', fontsize=7)
    axes.set_xlabel('Epochs', weight='bold', fontsize=9)
    axes.set_ylabel('Loss', weight='bold', fontsize=9)
    axes.legend(['training loss', 'validation loss'], loc='best')
    plt.savefig(plots_path + '/' + backbone + '_' + LOSS_TYPE + 'loss_' + str(data_size) + '.jpg', dpi=300)

    cleanup()


if __name__ == '__main__':
    # Single-node, single-GPU launch settings handed to main().
    args = {'gpus': 1, 'nr': 0}
    args.update(
        world_size=args['gpus'],
        checkpoints_path='./snapshots/' + MODEL_TYPE + '/',
        # Placeholder: replace with the real diffusion checkpoint path before training.
        load_from_chkpt='path for the diffusion pretrained checkpoint',
    )

    # Rendezvous address/port read by torch.distributed's default init method.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12345'})

    print(args['gpus'])
    #mp.spawn(main, args=(args,), nprocs=args['gpus'])


  from scipy.ndimage.interpolation import map_coordinates
  from scipy.ndimage.filters import gaussian_filter


1


In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import os
from skimage import io

# Side length (pixels) that load_transformed_dataset resizes every image to.
IMG_SIZE = 256


def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule, as proposed in https://arxiv.org/abs/2102.09672.

    Builds the cumulative alpha curve from a squared cosine, normalises it to
    start at 1, converts consecutive ratios to betas and clips them to
    [0.0001, 0.9999].

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        s: Small offset applied inside the cosine.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]

    step_ratio = cumulative[1:] / cumulative[:-1]
    betas = 1 - step_ratio
    return betas.clamp(min=0.0001, max=0.9999)


def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` betas evenly spaced from ``start`` to ``end`` inclusive."""
    betas = torch.linspace(start=start, end=end, steps=timesteps)
    return betas


def quadratic_beta_schedule(timesteps):
    """Return betas from 0.0001 to 0.02 that grow quadratically.

    The square roots of the endpoints are interpolated linearly and the result
    is squared.
    """
    sqrt_start, sqrt_end = 0.0001 ** 0.5, 0.02 ** 0.5
    return torch.linspace(sqrt_start, sqrt_end, timesteps).pow(2)


def sigmoid_beta_schedule(timesteps):
    """Return betas between 0.0001 and 0.02 that follow a sigmoid ramp.

    A sigmoid over ``timesteps`` points evenly spaced on [-6, 6] is rescaled
    from (0, 1) to the beta range.
    """
    low, high = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return ramp * (high - low) + low


def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals[t]`` for a batch of timesteps and shape it for broadcasting.

    The lookup is done against a CPU copy of ``t``; the result is reshaped to
    ``(batch, 1, ..., 1)`` with as many dimensions as ``x_shape`` and moved to
    ``t``'s device.
    """
    gathered = vals.gather(-1, t.cpu())
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *trailing_ones).to(t.device)


def forward_diffusion_sample(x_0, t, betas_schedule, device="cpu"):
    """
    Produce the noised version of ``x_0`` at timesteps ``t``.

    Args:
        x_0: Clean input batch.
        t: Tensor holding one timestep index per sample.
        betas_schedule: Dict providing ``'sqrt_alphas_cumprod'`` and
            ``'sqrt_one_minus_alphas_cumprod'``.
        device: Device the outputs are placed on.

    Returns:
        Tuple ``(x_t, noise)`` where
        ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
    """
    noise = torch.randn_like(x_0)

    signal_scale = get_index_from_list(betas_schedule['sqrt_alphas_cumprod'], t, x_0.shape)
    noise_scale = get_index_from_list(
        betas_schedule['sqrt_one_minus_alphas_cumprod'], t, x_0.shape
    )

    noise_on_device = noise.to(device)
    # mean + variance
    x_t = signal_scale.to(device) * x_0.to(device) + noise_scale.to(device) * noise_on_device
    return x_t, noise_on_device


def _my_normalization(x):
    return (x * 2) - 1


def get_images_list(folder1, folder2, folder3, k=None):
    """List the image paths of three folders, each ordered by numeric suffix.

    File names are expected to look like ``<anything>_<int>.jpg``; the integer
    after the last underscore is the sort key.

    Args:
        folder1, folder2, folder3: Directories to list.
        k: If given, keep only the first ``k`` paths of each folder.

    Returns:
        Tuple of three lists of full paths, one per folder.
    """
    def _sorted_paths(folder):
        # Full paths of one folder's entries, ordered by trailing integer.
        names = sorted(os.listdir(folder),
                       key=lambda name: int(name.split('_')[-1].split('.jpg')[0]))
        return [os.path.join(folder, name) for name in names]

    listings = tuple(_sorted_paths(folder) for folder in (folder1, folder2, folder3))
    if k is None:
        return listings
    return tuple(paths[:k] for paths in listings)


class AIIMS_Dataset(Dataset):
    """Dataset of unlabelled images read from a list of file paths.

    Args:
        images_list: Sequence of image file paths.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, images_list, transform=None):
        self.images_list = images_list
        self.transform = transform

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, index):
        image = io.imread(self.images_list[index])
        if self.transform is None:
            return image
        return self.transform(image)


def load_transformed_dataset():
    """Build the train/eval datasets of unlabelled patches for diffusion training.

    Image paths from the HN, MoNuSeg and GlaS patch folders are concatenated
    and split 90/10 with a fixed-seed permutation. Both datasets share the
    same transform pipeline, including the random flips.

    Returns:
        Tuple ``(train_dataset, eval_dataset)`` of ``AIIMS_Dataset`` objects.
    """
    data_transform = transforms.Compose([
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Lambda(_my_normalization),  # Scale between [-1, 1]
    ])

    data_size = None  # None keeps every image of each folder
    patch_dirs = (
        '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/pre_process/HN/unlabelled_img_patches',
        '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/pre_process/MoNuSeg/unlabelled_img_patches',
        '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/pre_process/GlaS/unlabelled_img_patches',
    )
    img1_list, img2_list, img3_list = get_images_list(*patch_dirs, k=data_size)
    img_list = np.array(img1_list + img2_list + img3_list)

    # Deterministic shuffle, then a 90/10 cut.
    n_images = img_list.shape[0]
    shuffled = np.random.RandomState(2023).permutation(n_images)
    split = int(n_images * 0.9)

    train_dataset = AIIMS_Dataset(img_list[shuffled[:split]], transform=data_transform)
    eval_dataset = AIIMS_Dataset(img_list[shuffled[split:]], transform=data_transform)

    return train_dataset, eval_dataset


def reverse_transforms_image(image):
    """Convert a [-1, 1] image tensor back into a PIL image.

    A 4-D input is treated as a batch and only its first element is converted.
    The tensor is rescaled to [0, 255], reordered from CHW to HWC and cast to
    ``uint8`` before being handed to ``ToPILImage``.
    """
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    def _to_unit_range(t):
        return (t + 1) / 2

    def _chw_to_hwc(t):
        return t.permute(1, 2, 0)

    def _to_byte_scale(t):
        return t * 255.

    def _to_uint8_array(t):
        return t.numpy().astype(np.uint8)

    to_pil = transforms.Compose([
        transforms.Lambda(_to_unit_range),
        transforms.Lambda(_chw_to_hwc),  # CHW to HWC
        transforms.Lambda(_to_byte_scale),
        transforms.Lambda(_to_uint8_array),
        transforms.ToPILImage(),
    ])
    return to_pil(image)

def get_beta_schedule(betas):
    """Derive the per-timestep diffusion constants from a beta schedule.

    Args:
        betas: 1-D tensor of noise variances, one entry per timestep.

    Returns:
        dict: tensors keyed by ``'alphas'``, ``'alphas_cumprod'``,
        ``'alphas_cumprod_prev'``, ``'sqrt_recip_alphas'``,
        ``'sqrt_alphas_cumprod'``, ``'sqrt_one_minus_alphas_cumprod'`` and
        ``'posterior_variance'``.
    """
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift right by one step; the product "before step 0" is defined as 1.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

    return {
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'alphas_cumprod_prev': alphas_cumprod_prev,
        'sqrt_recip_alphas': torch.sqrt(1.0 / alphas),
        'sqrt_alphas_cumprod': torch.sqrt(alphas_cumprod),
        'sqrt_one_minus_alphas_cumprod': torch.sqrt(1. - alphas_cumprod),
        'posterior_variance': betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod),
    }


def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
    """Timestep-weighted squared-error loss between true and predicted noise.

    Each sample is weighted by ``1 / (k + snr) ** gamma`` with ``k = gamma = 1``,
    where ``snr = 1 / (1 - alphas_cumprod[t]) - 1``. The weighted squared errors
    are summed over the whole batch and divided by the number of elements in a
    single sample (the product of dimensions 1, 2 and 3 of ``noise``).

    Args:
        noise: target noise tensor (indexed as a 4-D tensor).
        noise_pred: predicted noise, same shape as ``noise``.
        time_stamps: tensor of timestep indices, one per sample.
        betas_schedule: dict holding at least ``'alphas_cumprod'``.
        gpu: device the weights are moved to before multiplying.

    Returns:
        Scalar loss tensor.
    """
    k, gamma = 1.0, 1.0

    # The schedule is indexed with a CPU copy of the timesteps.
    steps = time_stamps.cpu()
    signal_to_noise = 1.0 / (1 - betas_schedule['alphas_cumprod'][steps]) - 1

    # Trailing singleton axes let the per-sample weight broadcast over the sample.
    weights = (1.0 / ((k + signal_to_noise) ** gamma))[:, None, None, None].to(gpu)

    squared_error = F.mse_loss(noise, noise_pred, reduction='none')
    elements_per_sample = noise.shape[1] * noise.shape[2] * noise.shape[3]
    return torch.sum(weights * squared_error) / elements_per_sample


# def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
#     t = time_stamps.cpu()
#     n = noise.shape[1] * noise.shape[2] * noise.shape[3]

#     snr = 1.0 / (1 - betas_schedule['alphas_cumprod'][t]) - 1
#     k = 1.0
#     gamma = 1.0
#     lambda_t = 1.0/((k+snr)**gamma)
#     lambda_t = lambda_t.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(gpu)
#     loss1 = torch.sum(lambda_t * F.mse_loss(noise, noise_pred, reduction='none'))/n

#     scale_factor = (1.0 - betas_schedule['alphas'][t]) / (betas_schedule['alphas'][t] * (1.0 - betas_schedule['alphas_cumprod'][t]))
#     scale_factor = scale_factor.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(gpu)
#     loss2 = torch.sum(scale_factor * F.mse_loss(noise, noise_pred, reduction='none'))/n

#     c = 0.001
#     loss = loss1 + c*loss2
#     return loss


# def get_loss(noise, noise_pred, time_stamps, betas_schedule, gpu):
#     t = time_stamps.cpu()
#     n = noise.shape[1] * noise.shape[2] * noise.shape[3]

#     loss = torch.sum(F.mse_loss(noise, noise_pred, reduction='none'))/n

#     return loss


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=10, verbose=False, delta=0,
                 path='checkpoint.pth'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 10
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once `patience` is exhausted
        # np.inf instead of the np.Inf alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, epoch=None, ddp=False):
        """Record one validation result and checkpoint the model on improvement.

        Args:
            val_loss: validation loss of the current epoch (lower is better).
            model: network to checkpoint; may be wrapped in DataParallel/DDP.
            epoch: optional epoch number embedded in the checkpoint file name.
            ddp (bool): force unwrapping through ``model.module``.
        """
        # Work with the negated loss so that "larger is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch, ddp)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch, ddp):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if epoch is not None:
            # Drops the trailing 4-character extension (e.g. '.pth') before
            # inserting the epoch number and the first 7 characters of the loss.
            weight_path = self.path[:-4] + '_' + str(epoch) + '_' + str(val_loss)[:7] + '.pth'
        else:
            weight_path = self.path

        # Unwrap parallel wrappers so the saved keys carry no 'module.' prefix.
        # Plain models are saved directly instead of failing on a missing
        # `.module` attribute.
        is_wrapped = ddp or isinstance(
            model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
        network = model.module if is_wrapped else model

        torch.save({
            'epoch': epoch,
            'loss': val_loss,
            'model_state_dict': network.state_dict(),
        }, weight_path)

        self.val_loss_min = val_loss

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/utils_general.ipynb'
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/downstream_train/model.ipynb'
#import utils_general as utils
#from GenSelfDiff.downstream_train.model import DiffusionNet

import os
import shutil
from matplotlib import pyplot as plt
from tqdm import tqdm
import time
import numpy as np
from scipy.io import savemat

# Run configuration for the diffusion pre-training cell.
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
IMG_SIZE = 256  # side length of the square images generated by sample_plot_image
EPOCHS = 100  # upper bound of the epoch loop in main
BATCH_SIZE = 8  # per-process batch size of both data loaders in main
LEARNING_RATE = 1e-4  # Adam learning rate used in main
T = 1000  # number of diffusion timesteps (sampling range for t and length of the beta schedule)
DATA_TYPE = 'diff_quadratic_general'  # tag embedded in the image, plot and snapshot paths

# quadratic_beta_schedule is not defined in this cell; presumably it comes from
# the utils_general notebook executed with %run above — TODO confirm.
betas = quadratic_beta_schedule(timesteps=T)
betas_schedule = get_beta_schedule(betas)

def cleanup():
    """Tear down the default distributed process group, if one exists.

    Guarded so that calling it before ``init_process_group`` succeeded (or a
    second time) is a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.

    Args:
        model: noise-prediction network, called as ``model(x, t)``.
        x: current noisy sample(s); the first dimension is the batch.
        t: tensor of timestep indices, one per sample (a single value is
            broadcast over the batch).

    Returns:
        The sample for the previous timestep. Entries whose timestep is 0
        receive the posterior mean without added noise.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        betas_schedule['sqrt_one_minus_alphas_cumprod'], t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(betas_schedule['sqrt_recip_alphas'], t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(betas_schedule['posterior_variance'], t, x.shape)

    # `if t == 0` is ambiguous once t holds more than one timestep, so the
    # decision is made per sample with a mask instead.
    if torch.all(t == 0):
        return model_mean

    noise = torch.randn_like(x)
    # 1 where noise must be added (t > 0), 0 at the final step; shaped to
    # broadcast over every non-batch dimension of x.
    nonzero_mask = (t != 0).to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image(model, gpu, epoch):
    """Run the full reverse diffusion once and save a 10x10 grid of snapshots.

    Starting from Gaussian noise, all ``T`` timesteps are sampled in reverse
    order and a snapshot is kept every ``T // 100`` steps. The grid is written
    to ``./images/<DATA_TYPE>/image_<epoch>.jpg``.

    Args:
        model: noise-prediction network passed on to ``sample_timestep``.
        gpu: device on which the sample is generated.
        epoch: value embedded in the output file name.
    """
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=gpu)
    num_images = 100
    # max(1, ...) keeps the modulo below valid when T < num_images.
    stepsize = max(1, T // num_images)

    all_images = []
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=gpu, dtype=torch.long)
        img = sample_timestep(model, img, t)
        if i % stepsize == 0:
            all_images.append(img)

    fig, axs = plt.subplots(10, 10)
    for ax in axs.flat:
        ax.axis('off')
    # zip stops at the shorter sequence: at most 100 snapshots are drawn and a
    # shorter run simply leaves the remaining cells empty.
    for ax, snapshot in zip(axs.flat, all_images):
        ax.imshow(reverse_transforms_image(snapshot.detach().cpu()))
    fig.savefig('./images/'+DATA_TYPE+'/image_' + str(epoch) + '.jpg', dpi=300)
    # Release the figure; otherwise one figure per call stays alive in pyplot.
    plt.close(fig)


def initialize_weights(model):
    """Re-initialise convolution and normalisation weights in place.

    Every ``Conv2d``, ``ConvTranspose2d``, ``BatchNorm2d`` and ``GroupNorm``
    module gets its ``weight`` drawn from a normal distribution with mean 0.0
    and standard deviation 0.01. Biases and all other module types are left
    untouched.

    Args:
        model: module whose sub-modules are visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.GroupNorm)):
            # Normalisation layers created with affine=False have no weight.
            if m.weight is None:
                continue
            nn.init.normal_(m.weight.data, 0.0, 0.01)

def train_epoch(train_dataloader, model, optimizer, gpu, epoch, args):
    """Train the noise-prediction model for one pass over ``train_dataloader``.

    For every batch a random timestep per image is drawn, the images are
    noised with ``forward_diffusion_sample`` and the model is optimised to
    predict that noise under ``get_loss``.

    Args:
        train_dataloader: iterable yielding image batches.
        model: network called as ``model(x_noisy, t)``.
        optimizer: optimiser stepping the model parameters.
        gpu: device the batches and timesteps are moved to.
        epoch: epoch number, used only for progress/log output.
        args: not used by this function.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    progress = tqdm(train_dataloader)
    batch_losses = []

    for batch in progress:
        optimizer.zero_grad()

        batch = batch.to(gpu, non_blocking=False)
        # One timestep in [0, T) per image of the batch.
        steps = torch.randint(0, T, (batch.shape[0],)).long()
        steps = steps.to(gpu, non_blocking=False)

        noisy_batch, true_noise = forward_diffusion_sample(batch, steps, betas_schedule, gpu)
        predicted_noise = model(noisy_batch, steps)
        loss = get_loss(true_noise, predicted_noise, steps, betas_schedule, gpu)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        progress.set_description('Epoch {}'.format(epoch))
        progress.set_postfix(loss=loss_value)

    epoch_loss = np.mean(batch_losses)
    print('Epoch: {}\ttotal_loss {:.4f}'.format(epoch, epoch_loss))
    return epoch_loss


def eval_epoch(eval_dataloader, model, gpu, epoch, args, early_stopping=None):
    """Compute the mean diffusion loss over ``eval_dataloader`` without gradients.

    Mirrors ``train_epoch`` minus the optimisation step: random timesteps are
    drawn per image, the batch is noised and the loss of the predicted noise
    is recorded.

    Args:
        eval_dataloader: iterable yielding image batches.
        model: network called as ``model(x_noisy, t)``; switched to eval mode.
        gpu: device the batches and timesteps are moved to.
        epoch: epoch number, used only for progress/log output.
        args: not used by this function.
        early_stopping: not used by this function.

    Returns:
        Mean of the per-batch losses.
    """
    batch_losses = []

    with torch.no_grad():
        model.eval()
        progress = tqdm(eval_dataloader)

        for batch in progress:
            batch = batch.to(gpu, non_blocking=False)
            # One timestep in [0, T) per image of the batch.
            steps = torch.randint(0, T, (batch.shape[0],)).long()
            steps = steps.to(gpu, non_blocking=False)

            noisy_batch, true_noise = forward_diffusion_sample(batch, steps, betas_schedule, gpu)
            predicted_noise = model(noisy_batch, steps)
            loss_value = get_loss(true_noise, predicted_noise, steps, betas_schedule, gpu).item()

            batch_losses.append(loss_value)
            progress.set_description('Epoch {}'.format(epoch))
            progress.set_postfix(loss=loss_value)

    epoch_loss = np.mean(batch_losses)
    print('Epoch: {}\ttotal_loss {:.4f}'.format(epoch, epoch_loss))
    return epoch_loss



def main(gpu, args):
    """Per-process entry point for distributed diffusion pre-training.

    Intended to be launched once per GPU (e.g. through ``mp.spawn``). Sets up
    the NCCL process group, builds sharded data loaders, trains for up to
    ``EPOCHS`` epochs with early stopping driven by the loss averaged over all
    processes, and finally saves a loss plot.

    Args:
        gpu (int): local GPU index of this process.
        args (dict): expects the keys ``'nr'`` (node rank), ``'gpus'`` (GPUs per
            node), ``'world_size'``, ``'checkpoints_path'`` and
            ``'load_from_chkpt'`` (checkpoint file or ``None``).
    """
    rank = args['nr'] * args['gpus'] + gpu
    dist.init_process_group('nccl', rank=rank, world_size=args['world_size'])
    torch.cuda.set_device(gpu)

    # Only used to label the loss plot; the dataset itself is loaded in full.
    data_size = None

    # data loaders
    train_dataset, eval_dataset = load_transformed_dataset()

    # DistributedSampler lives in torch.utils.data.distributed
    # (`torch.data` does not exist).
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=args['world_size'],
                                                                    rank=rank)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  drop_last=True, num_workers=4, pin_memory=True,
                                  sampler=train_sampler)

    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset,
                                                                   num_replicas=args['world_size'],
                                                                   rank=rank, shuffle=False)
    eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE,
                                 shuffle=False, drop_last=False, num_workers=4, pin_memory=True,
                                 sampler=eval_sampler)

    model = DiffusionNet(dim=64, channels=3).to(gpu)
    initialize_weights(model)
    print("Num params: ", sum(p.numel() for p in model.parameters()))
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    checkpoint_path = args['checkpoints_path']
    os.makedirs(checkpoint_path, exist_ok=True)
    os.makedirs('./images/' + DATA_TYPE, exist_ok=True)

    epoch_start = 0
    if args['load_from_chkpt'] is not None:
        chkpt_file = args['load_from_chkpt']
        print('Loading checkpoint from:', chkpt_file)
        # Map onto this process's GPU; without map_location every rank would
        # restore the tensors onto the device they were saved from.
        checkpoint = torch.load(chkpt_file, map_location=torch.device('cuda', gpu))
        model.module.load_state_dict(checkpoint['model_state_dict'])

    if gpu == 0:
        early_stopping = EarlyStopping(patience=15, verbose=True,
                                       path=checkpoint_path + '{}_{}.pth'.format(BATCH_SIZE, LEARNING_RATE))
    else:
        early_stopping = None

    train_losses = []
    eval_losses = []
    # Wall-clock timer: process_time() only counts CPU time of this process.
    start_time = time.perf_counter()
    for epoch in range(epoch_start, EPOCHS):
        print('epoch {}/{}'.format(epoch + 1, EPOCHS))
        # Reshuffle the training shards differently every epoch.
        train_sampler.set_epoch(epoch)
        train_loss = train_epoch(train_dataloader, model, optimizer, gpu, epoch + 1, args)
        eval_loss = eval_epoch(eval_dataloader, model, gpu, epoch + 1, args, early_stopping)

        # Pre-divide by the number of processes so that the (summing)
        # all_reduce below yields the mean over all of them.
        mean_train_loss = torch.tensor(train_loss / args['world_size']).to(gpu)
        mean_eval_loss = torch.tensor(eval_loss / args['world_size']).to(gpu)

        dist.barrier()
        dist.all_reduce(mean_train_loss)
        dist.all_reduce(mean_eval_loss)
        print('gpu {} eval_loss:{}, mean_loss:{}'.format(gpu, eval_loss,
                                                         mean_eval_loss.cpu().numpy()))

        # if optim_name.split('-')[-1] == 'step':
        #     scheduler.step(mean_eval_loss.cpu().numpy())
        # elif optim_name.split('-')[-1] == 'cosine':
        #     scheduler.step()
        # elif optim_name.split('-')[-1] == 'no':
        #     pass

        # if (epoch+1) % 10 == 0:
        #     sample_plot_image(model, gpu, epoch+1)

        if gpu == 0:
            early_stopping(mean_eval_loss.cpu().numpy(), model, epoch + 1)

        train_losses.append(mean_train_loss.cpu().numpy())
        eval_losses.append(mean_eval_loss.cpu().numpy())

        # Rank 0 decides whether to stop; the decision is broadcast so that all
        # processes leave the loop together (a lone exit would deadlock the
        # collectives of the remaining ranks).
        stop_flag = torch.zeros(1, device=gpu)
        if rank == 0 and early_stopping.early_stop:
            stop_flag += 1
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item() > 0:
            print('Early stopping triggered after epoch {}'.format(epoch + 1))
            break

    current_time = time.perf_counter()
    # Format the float directly: '{:12.5}' applied to str(...) truncated the
    # value to its first five characters.
    print("Total Time Elapsed={:12.5f} seconds".format(current_time - start_time))

    # saving the plots (once per node; every process would write the same file)
    if gpu == 0:
        plots_path = './plots/diff'
        os.makedirs(plots_path, exist_ok=True)
        # Derive the x-axis from the epochs actually run (early stop aware).
        epochs = np.arange(epoch_start, epoch_start + len(train_losses))
        train_losses = np.array(train_losses)
        eval_losses = np.array(eval_losses)
        fig, axes = plt.subplots(1, 1, figsize=(8, 5))
        axes.plot(epochs, train_losses, 'tab:blue', epochs, eval_losses, 'tab:orange')
        axes.set_title(f'Training and Validation Loss (pretrained model = None, loss = MSE Loss, '
                       f'data size = {data_size})',
                       weight='bold', fontsize=7)
        axes.set_xlabel('Epochs', weight='bold', fontsize=9)
        axes.set_ylabel('Loss', weight='bold', fontsize=9)
        fig.savefig(plots_path + '/'+DATA_TYPE+'loss_' + str(data_size) + '.jpg', dpi=300)
        plt.close(fig)

    # os.makedirs('./loss', exist_ok=True)
    # path = os.path.join('./loss', DATA_TYPE+'_train_loss.mat')
    # mdic = {"data": train_losses, "label": "epochs"}
    # savemat(path, mdic)

    cleanup()


# Script entry point: assembles the launch configuration for main().
if __name__ == '__main__':

    args = {}

    args['gpus'] = 4  # number of processes / GPUs on this node
    args['nr'] = 0  # rank of this node; main() computes rank = nr * gpus + gpu
    args['world_size'] = args['gpus']  # single-node setup: one process per GPU
    args['checkpoints_path'] =  './snapshots/' + DATA_TYPE + '/'
    args['load_from_chkpt'] = None  # set to a checkpoint file to initialise the weights from it

    # Rendezvous address read by torch.distributed during init_process_group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    print(args['gpus'])
    # The actual launch is disabled; as written this cell only prints the GPU count.
    #mp.spawn(main, args=(args,), nprocs=args['gpus'])

4


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
import torch
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score, jaccard_score
from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff

start = 1 # First class index evaluated by the metrics below: set to 0 if class 0 (the background in our case) is to be included as a class in the metric computation, otherwise set to 1


def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient of two binary arrays.

    Computes ``(2 * |A ∩ B| + 1) / (|A| + |B| + 1)``; the additive 1 keeps the
    ratio defined (and equal to 1) when both inputs are empty.

    Args:
        y_true: reference mask (0/1 or boolean array).
        y_pred: predicted mask with the same shape.

    Returns:
        The Dice coefficient as a float.
    """
    smooth = 1.0
    overlap = np.sum(y_true * y_pred)
    numerator = (2. * overlap) + smooth
    denominator = np.sum(y_true) + np.sum(y_pred) + smooth
    return numerator / denominator


def dice_score(y_true, y_pred, num_classes):
    """Per-class Dice coefficients for a batch of class-score maps.

    Both inputs are reduced to label maps by taking the arg-max over
    dimension 1; each class from ``start`` up to ``num_classes - 1`` is then
    scored as a flattened binary mask with ``dice_coef``.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: one Dice value per evaluated class.
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        dice_coef(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls))
        for cls in range(start, num_classes)
    ]


def conf_matrix(y_true, y_pred, num_classes):
    """Pixel-wise confusion matrix over all ``num_classes`` classes.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: number of classes; fixes the matrix size and label order.

    Returns:
        The ``num_classes x num_classes`` matrix from sklearn's
        ``confusion_matrix`` (rows: true label, columns: predicted label).
    """
    # Arg-max over the class dimension, then flatten everything to 1-D.
    true_flat = torch.max(y_true, 1).indices.cpu().numpy().reshape(-1)
    pred_flat = torch.max(y_pred, 1).indices.cpu().numpy().reshape(-1)
    return confusion_matrix(true_flat, pred_flat, labels=np.arange(num_classes))


def precision(y_true, y_pred, num_classes):
    """Per-class precision, each class scored as a one-vs-rest binary mask.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: precision for every class from ``start`` to ``num_classes - 1``;
        a class without any predicted pixel scores 1 (``zero_division=1``).
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        precision_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls), zero_division=1)
        for cls in range(start, num_classes)
    ]


def sensitivity(y_true, y_pred, num_classes):
    """Per-class sensitivity (recall), each class scored one-vs-rest.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: recall for every class from ``start`` to ``num_classes - 1``;
        a class absent from the reference scores 1 (``zero_division=1``).
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        recall_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls), zero_division=1)
        for cls in range(start, num_classes)
    ]


def specificity(y_true, y_pred, num_classes):
    """Per-class specificity, each class scored one-vs-rest.

    Specificity is obtained as the recall of the negative label
    (``pos_label=0``) of each binary class mask.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: specificity for every class from ``start`` to ``num_classes - 1``.
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        recall_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls),
                     pos_label=0, zero_division=1)
        for cls in range(start, num_classes)
    ]


def accuracy(y_true, y_pred, num_classes):
    """Per-class pixel accuracy, each class scored one-vs-rest.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: accuracy of the binary class mask for every class from ``start``
        to ``num_classes - 1``.
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        accuracy_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls))
        for cls in range(start, num_classes)
    ]


def F1_score(y_true, y_pred, num_classes):
    """Per-class F1 score, each class scored one-vs-rest.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: F1 for every class from ``start`` to ``num_classes - 1``
        (``zero_division=1`` is passed to sklearn).
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        f1_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls), zero_division=1)
        for cls in range(start, num_classes)
    ]


def Jaccard_score(y_true, y_pred, num_classes):
    """Per-class Jaccard index (IoU), each class scored one-vs-rest.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: Jaccard index for every class from ``start`` to
        ``num_classes - 1`` (``zero_division=1`` is passed to sklearn).
    """
    true_labels = torch.max(y_true, 1).indices
    pred_labels = torch.max(y_pred, 1).indices

    def _flat_mask(labels, cls):
        # Boolean membership mask of one class, flattened to 1-D numpy.
        return (labels == cls).reshape(-1).cpu().numpy()

    return [
        jaccard_score(_flat_mask(true_labels, cls), _flat_mask(pred_labels, cls), zero_division=1)
        for cls in range(start, num_classes)
    ]


def Hausdorff_distance(y_true, y_pred, num_classes):
    """Per-class symmetric Hausdorff distance between label maps, in pixels.

    The inputs are reduced to label maps with an arg-max over dimension 1.
    For each class the distance is measured between the *pixel coordinates*
    of the reference and predicted masks. (``directed_hausdorff`` interprets
    every row of its input as one point, so passing the boolean mask itself
    would compare mask rows rather than pixel positions.)

    Conventions for degenerate cases:
        * class absent from both maps -> distance 0.0
        * class absent from exactly one map -> the image diagonal, used as a
          finite worst-case penalty so that averages stay meaningful.

    Args:
        y_true: reference scores/one-hot tensor, classes along dimension 1,
            followed by the two spatial dimensions.
        y_pred: predicted scores, same layout.
        num_classes: total number of classes.

    Returns:
        list: one distance per class from ``start`` to ``num_classes - 1``,
        averaged over the samples of the batch.
    """
    _, y_true = torch.max(y_true, 1)
    _, y_pred = torch.max(y_pred, 1)
    true_maps = y_true.cpu().numpy()
    pred_maps = y_pred.cpu().numpy()

    # Penalty used when a class exists in only one of the two maps.
    height, width = true_maps.shape[-2:]
    worst_case = float(np.hypot(height, width))

    class_hd = []
    for i in range(start, num_classes):
        per_sample = []
        for true_map, pred_map in zip(true_maps, pred_maps):
            # (row, col) coordinates of every pixel belonging to class i.
            true_pts = np.argwhere(true_map == i)
            pred_pts = np.argwhere(pred_map == i)

            if len(true_pts) == 0 and len(pred_pts) == 0:
                hd = 0.0
            elif len(true_pts) == 0 or len(pred_pts) == 0:
                hd = worst_case
            else:
                hd1 = directed_hausdorff(true_pts, pred_pts)[0]
                hd2 = directed_hausdorff(pred_pts, true_pts)[0]
                hd = max(hd1, hd2)
            per_sample.append(hd)
        class_hd.append(float(np.mean(per_sample)))

    return class_hd

# This script is adopted from the official repository of AJI score
def Aggregated_jaccard_index(gt_map, predicted_map, gpu):
    """Aggregated Jaccard Index between a reference and a predicted label map.

    Both inputs are first reduced to label maps with an arg-max over
    dimension 1. Every reference label is paired with the overlapping
    predicted label of highest Jaccard index; intersections and unions of the
    pairs are accumulated, and the area of every predicted label that was
    never chosen is added to the union.

    NOTE(review): after the arg-max the maps hold channel indices; whether
    these correspond to the per-object ids AJI is normally computed on depends
    on the caller — confirm.

    Args:
        gt_map: reference scores/one-hot tensor, classes along dimension 1.
        predicted_map: predicted scores, same layout.
        gpu: device on which the bookkeeping column of ``pr_list`` is created.

    Returns:
        The Python float ``1.0`` when both accumulated counts are zero,
        otherwise the ratio as a 0-d numpy array.
    """
    _, gt_map = torch.max(gt_map, 1)
    _, predicted_map = torch.max(predicted_map, 1)

    # Distinct labels present in each map.
    gt_list = torch.unique(gt_map)
    pr_list = torch.unique(predicted_map)

    # With start != 0, label 0 (background) is excluded from both lists.
    if start != 0:
        gt_list = gt_list[gt_list != 0]
        pr_list = pr_list[pr_list != 0]

    # Column 0: predicted label; column 1: how often it was chosen as best match.
    pr_list = torch.cat((pr_list.view(-1, 1), torch.zeros(pr_list.size(0), 1).to(gpu)), dim=1)

    overall_correct_count = 0.0
    union_pixel_count = 0.0

    i = len(gt_list)

    # Reference labels are consumed from the end of gt_list, one per iteration.
    while len(gt_list) > 0:
        # print(f'Processing object # {i}')

        gt = (gt_map == gt_list[i - 1]).float()

        # Predicted labels restricted to the area of the current reference label.
        predicted_match = gt * predicted_map.float()

        if predicted_match.sum() == 0:
            # Only predicted label 0 lies under this reference label: its whole
            # area counts towards the union.
            union_pixel_count += gt.sum()
            gt_list = gt_list[:-1]
            i = len(gt_list)
        else:
            predicted_nuc_index = torch.unique(predicted_match)
            if start != 0:
                predicted_nuc_index = predicted_nuc_index[predicted_nuc_index != 0]

            JI = 0
            best_match = None

            # Keep the overlapping predicted label with the highest Jaccard index.
            for j in range(len(predicted_nuc_index)):
                matched = (predicted_map == predicted_nuc_index[j]).float()
                nJI = matched.logical_and(gt).sum() / matched.logical_or(gt).sum()

                if nJI > JI:
                    best_match = predicted_nuc_index[j]
                    JI = nJI

            predicted_nuclei = (predicted_map == best_match).float()

            overall_correct_count += (gt.logical_and(predicted_nuclei)).sum()
            union_pixel_count += (gt.logical_or(predicted_nuclei)).sum()

            gt_list = gt_list[:-1]
            i = len(gt_list)

            # Mark the chosen predicted label as used.
            best_match_idx = (pr_list[:, 0] == best_match).nonzero().item()
            pr_list[best_match_idx, 1] += 1

    # Predicted labels that were never selected as a best match.
    unused_nuclei_list = (pr_list[:, 1] == 0).nonzero().view(-1)

    for k in range(len(unused_nuclei_list)):
        # Prints the label of each unused prediction (diagnostic output).
        print(pr_list[unused_nuclei_list[k], 0])
        unused_nuclei = (predicted_map == pr_list[unused_nuclei_list[k], 0]).float()
        union_pixel_count += unused_nuclei.sum()

    # Nothing to compare on either side: treated as a perfect score.
    if overall_correct_count == 0 and union_pixel_count == 0:
        return 1.0
    aji = overall_correct_count / union_pixel_count

    return aji.cpu().numpy()

In [None]:
from PIL import Image
from torchvision.transforms import transforms
import torch
import glob
import json
import numpy as np
import os
import cv2
from skimage import io
import sys
from scipy.io import loadmat, savemat


# Copy a fixed random subset of the unlabelled patches into ./samp_images,
# renumbering them consecutively from 0.
folder1 = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/unlabelled_img_patches/'
num_images = 1000
# Seeded permutation of the 5328 patch indices: the same subset on every run.
idxs = np.random.RandomState(2023).permutation(5328)

OUT_FOLDER = './samp_images'
os.makedirs(OUT_FOLDER, exist_ok=True)

for i in range(num_images):
    img_path = folder1 + 'image_' + str(idxs[i]) + '.jpg'
    img = cv2.imread(img_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail with the offending path rather than inside cvtColor.
    if img is None:
        raise FileNotFoundError('Could not read image: ' + img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    OUT_IMAGE_PATH = os.path.join(OUT_FOLDER, 'image_' + str(i) + '.jpg')
    # Converted back to BGR because cv2.imwrite expects that channel order.
    cv2.imwrite(OUT_IMAGE_PATH, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

In [None]:
import os
import numpy as np
from skimage import io
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, resnet50
import torchvision.transforms as transforms
from torch.nn import functional as F
from scipy.io import loadmat
import cv2
import shutil
from matplotlib import pyplot as plt
# from simple_colors import *
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/test/metrics.ipynb'
%run '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/downstream_train/model.ipynb'

from tqdm import tqdm
#from metrics import Aggregated_jaccard_index, Hausdorff_distance, Jaccard_score, precision, sensitivity, accuracy, F1_score, conf_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import sys

sys.path.append('../downstream_train')
#from GenSelfDiff.downstream_train.model import SegNet


# CUDA environment: order devices by PCI bus id and make kernel launches
# synchronous (CUDA_LAUNCH_BLOCKING=1), which eases debugging of device errors.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Hyper Parameters
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
IMG_CHANNELS = 3
BATCH_SIZE = 1  # eval_epoch builds its timestep tensor with this size
NUM_CLASSES = 4  # number of segmentation classes (one-hot depth, model output channels)
TEST_IMAGE_DIR = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/images'
TEST_LABEL_PATH = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/test/label_patches.mat'

# ANSI escape sequences: bold on / reset.
# NOTE(review): `start` is also the name the metric functions read as their
# first class index (`range(start, num_classes)`); if those definitions share
# this namespace, this string assignment would break them — confirm how the
# metrics notebook is loaded.
start = "\033[1m"
end = "\033[0;0m"

# transformations to be performed on the data points
transformations = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)


def class_weights(target_labl):
    """Compute per-class weights as one minus the class's pixel frequency.

    The input is reduced to a label map with an arg-max over dimension 0.
    Classes that do not occur keep a weight of 1; if only a single class is
    present it also gets weight 1. The present labels and their counts are
    printed.

    Args:
        target_labl: score/one-hot tensor with the classes along dimension 0.

    Returns:
        numpy array of length ``NUM_CLASSES``.
    """
    flat_labels = torch.max(target_labl, dim=0).indices.reshape(-1).cpu().numpy()
    labels, label_counts = np.unique(flat_labels, return_counts=True)

    if len(labels) == 1:
        w = 1.0
    else:
        # Frequencies are rounded to 4 decimals before being inverted.
        w = 1 - np.round(label_counts / np.sum(label_counts), 4)

    weights = np.ones(NUM_CLASSES)
    weights[labels] = w

    print(labels)
    print(label_counts)

    return weights


def get_images_list(path1, k=None):
    """List the file names in ``path1`` ordered by their numeric suffix.

    File names are expected to end in ``_<number>.<ext>``; the integer between
    the last underscore and the following dot is the sort key.

    Args:
        path1: directory to list.
        k: optional cap on the number of names returned (``None`` = all).

    Returns:
        numpy array of file names (not full paths).
    """
    def _numeric_suffix(file_name):
        return int(file_name.split('_')[-1].split('.')[0])

    ordered = sorted(os.listdir(path1), key=_numeric_suffix)
    # Slicing with k=None keeps the full list.
    return np.array(ordered[:k])


class Histo_Dataset(Dataset):
    """Dataset pairing image files with entries of a label collection.

    Item ``i`` is the image ``image1_dir/image1_list[i]`` (read with
    ``skimage.io.imread`` and optionally transformed) together with
    ``label_list[i]``; the label is returned untouched.
    """

    def __init__(self, image1_dir, image1_list, label_list, transform=None):
        """
        Args:
            image1_dir: directory containing the image files.
            image1_list: file names, aligned index-wise with ``label_list``.
            label_list: indexable collection of labels/masks.
            transform: optional callable applied to the image only.
        """
        self.transform = transform
        self.label_list = label_list
        self.image1_list = image1_list
        self.image1_dir = image1_dir

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image1_list)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``."""
        image1 = io.imread(os.path.join(self.image1_dir, self.image1_list[index]))
        mask = self.label_list[index]

        if self.transform is None:
            return image1, mask
        return self.transform(image1), mask


def eval_epoch(test_loader, model, num_classes, device):
    """Evaluate a segmentation model on the test set and dump visualisations.

    For every batch the model is called with timestep 0, the metrics (AJI,
    Jaccard, F1, Hausdorff distance, accuracy, sensitivity, precision) are
    computed and printed, and the colour-coded true labels, predicted labels
    and input images are written to ``./true_labels``, ``./predicted_labels``
    and ``./images``. These three folders are deleted and recreated on every
    call. At the end the averaged metrics are printed and the confusion
    matrix plot is saved.

    Args:
        test_loader: loader yielding ``(image, label)`` batches.
        model: segmentation network called as ``model(img, t)``.
        num_classes: number of classes used for one-hot encoding and metrics.
        device: device the inputs are moved to.
    """
    # Only used to build the file name of the confusion-matrix figure.
    model_type = 'diffusion'
    loss_type = 'SSFL_A'
    with torch.no_grad():
        model.eval()
        aji_scores = 0.0
        # Accumulators hold num_classes - 1 entries, i.e. one class fewer than
        # num_classes; this assumes the metric functions skip class 0
        # (their `start` index being 1) — TODO confirm.
        jaccard_scores = np.zeros(num_classes-1)
        fscores = np.zeros(num_classes-1)
        hds = np.zeros(num_classes-1)
        accuracies = np.zeros(num_classes-1)
        senstivities = np.zeros(num_classes-1)
        precisions = np.zeros(num_classes-1)

        # jaccard_scores = np.zeros(num_classes)
        # fscores = np.zeros(num_classes)
        # hds = np.zeros(num_classes)
        # accuracies = np.zeros(num_classes)
        # senstivities = np.zeros(num_classes)
        # precisions = np.zeros(num_classes)

        cm = np.zeros((num_classes, num_classes))
        total_items = 0
        p_bar = tqdm(test_loader)

        save_path1 = './true_labels'
        save_path2 = './predicted_labels'
        save_path3 = './images'

        # Start from empty output folders: previous results are removed.
        if os.path.exists(save_path1):
            shutil.rmtree(save_path1)
        os.makedirs(save_path1, exist_ok=True)

        if os.path.exists(save_path2):
            shutil.rmtree(save_path2)
        os.makedirs(save_path2, exist_ok=True)

        if os.path.exists(save_path3):
            shutil.rmtree(save_path3)
        os.makedirs(save_path3, exist_ok=True)

        indexing = 0
        for img, target_label in p_bar:
            img = img.to(device)
            # Integer label map -> one-hot with the classes moved to dimension 1.
            target_label = target_label.squeeze(1).to(device)
            target_label1 = target_label.long()
            target_label = F.one_hot(target_label1, num_classes)
            target_label = torch.permute(target_label, (0, 3, 1, 2))

            # The model is always queried at timestep 0. The tensor is sized by
            # the global BATCH_SIZE, not by the actual batch.
            t = torch.full((BATCH_SIZE,), 0, dtype=torch.long)
            t = t.to(device)

            predicted_label = model(img, t)

            aji_score = Aggregated_jaccard_index(target_label, predicted_label, device)
            jaccard_score = Jaccard_score(target_label, predicted_label, num_classes)
            fscore = F1_score(target_label, predicted_label, num_classes)
            hd = Hausdorff_distance(target_label, predicted_label, num_classes)
            acc = accuracy(target_label, predicted_label, num_classes)
            senstvty = sensitivity(target_label, predicted_label, num_classes)
            prec = precision(target_label, predicted_label, num_classes)


            print('===============================')
            print('           IMAGE_' + str(total_items))
            print('===============================')
            print('')
            print('AJI: ', np.round(aji_score, 4))
            print('Jaccard Score: ', np.round(jaccard_score, 4), ' Mean: ',
                  np.round(np.mean(jaccard_score), 4))
            print('F1 Score: ', np.round(fscore, 4), ' Mean: ', np.round(np.mean(fscore), 4))
            print('Hausdorff Distance: ', np.round(hd, 4), ' Mean: ', np.round(np.mean(hd), 4))
            print('Accuracy: ', np.round(acc, 4), ' Mean: ', np.round(np.mean(acc), 4))
            print('Sensitivity: ', np.round(senstvty, 4), ' Mean: ', np.round(np.mean(senstvty), 4))
            print('Precision: ', np.round(prec, 4), ' Mean: ', np.round(np.mean(prec), 4))
            print('')

            # Running sums; divided by total_items after the loop.
            aji_scores += aji_score
            jaccard_scores += jaccard_score
            fscores += fscore
            hds += hd
            accuracies += acc
            senstivities += senstvty
            precisions += prec

            cm += conf_matrix(target_label, predicted_label, num_classes)

            # Each loader batch counts as one item, whatever its size.
            batch = 1  # predicted_label.shape[0]
            total_items += 1  # batch

            # True label mappings
            # RGB colour per class; class 0 stays black.
            labels_t = np.zeros((target_label1.shape[0], target_label1.shape[1], target_label1.shape[2], 3),
                                dtype=np.uint8)
            # labels_t[target_label1.cpu() == 1] = [127, 255, 0]
            labels_t[target_label1.cpu() == 1] = [255, 69, 0]
            labels_t[target_label1.cpu() == 2] = [127, 255, 0]
            labels_t[target_label1.cpu() == 3] = [135, 206, 250]

            # Predicted label mappings
            _, pred_labels = torch.max(predicted_label, 1)
            labels_p = np.zeros((pred_labels.shape[0], pred_labels.shape[1], pred_labels.shape[2], 3),
                                dtype=np.uint8)
            # labels_p[pred_labels.cpu() == 1] = [127, 255, 0]
            labels_p[pred_labels.cpu() == 1] = [255, 69, 0]
            labels_p[pred_labels.cpu() == 2] = [127, 255, 0]
            labels_p[pred_labels.cpu() == 3] = [135, 206, 250]

            # One set of output files per image; file names embed the running
            # index plus the batch's mean sensitivity and mean precision.
            for i in range(target_label.shape[0]):
                image_label_t = labels_t[i]
                image_label_p = labels_p[i]
                out_label_path1 = os.path.join(save_path1, 'label_' + str(indexing) + '_' +
                                               str(np.round(np.mean(senstvty) / batch, 2)) + '_' +
                                               str(np.round(np.mean(prec) / batch, 2)) + '.jpg')
                cv2.imwrite(out_label_path1, cv2.cvtColor(image_label_t, cv2.COLOR_RGB2BGR))

                out_label_path2 = os.path.join(save_path2, 'label_' + str(indexing) + '_' +
                                               str(np.round(np.mean(senstvty) / batch, 2)) + '_' +
                                               str(np.round(np.mean(prec) / batch, 2)) + '.jpg')
                cv2.imwrite(out_label_path2, cv2.cvtColor(image_label_p, cv2.COLOR_RGB2BGR))

                # CHW tensor -> HWC uint8 image (input scaled by 255).
                image = torch.permute(img[i], (1, 2, 0))
                image = image.cpu().numpy()
                image = np.uint8(255 * image)
                out_image_path = os.path.join(save_path3, 'image_' + str(indexing) + '_' +
                                              str(np.round(np.mean(senstvty) / batch, 2)) + '_' +
                                              str(np.round(np.mean(prec) / batch, 2)) + '.jpg')
                cv2.imwrite(out_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
                # The returned weights are not used afterwards; the calls print
                # the labels present and their pixel counts.
                wt1 = class_weights(target_label[i])
                wt2 = class_weights(predicted_label[i])
                # Model output restricted to the pixels of each true class;
                # its per-class min/max is printed for inspection.
                temp_t = target_label[i] * predicted_label[i]
                # temp_t[temp_t == 0] = 0.5
                for cls in range(NUM_CLASSES):
                    print([torch.min(temp_t[cls]), torch.max(temp_t[cls])])

                indexing += 1

            p_bar.set_description()
            p_bar.set_postfix(f1_score=np.mean(fscore) / batch, prec=np.mean(prec) / batch,
                              acc=np.mean(acc) / batch, max_label=torch.max(target_label1))

        # Average every accumulated metric over the processed batches.
        # NOTE(review): an empty loader leaves total_items at 0 and these
        # divisions fail.
        print('total items: {}'.format(total_items))
        aji_scores = aji_scores/total_items
        jaccard_scores = jaccard_scores / total_items
        fscores = fscores / total_items
        hds = hds/total_items
        accuracies = accuracies / total_items
        senstivities = senstivities / total_items
        precisions = precisions / total_items
        # Mean confusion matrix per batch, rounded to integer counts.
        cm = np.round(cm / total_items).astype(int)
        print('Average AJI: {}'.format(aji_scores))
        print('class jaccard scores: {} and Average jaccard score: {}'.format(jaccard_scores,
                                                                              np.mean(jaccard_scores)))
        print('class F1 scores: {} and Average F1 score: {}'.format(fscores, np.mean(fscores)))
        print('class Hausdorff Distances: {} and Average Hausdorff Distance: {}'.format(hds, np.mean(hds)))
        print('class accuracies: {} and Average accuracy: {}'.format(accuracies, np.mean(accuracies)))
        print('class sensitivities: {} and Average sensitivity: {}'.format(senstivities, np.mean(senstivities)))
        print('class precisions: {} and Average precision: {}'.format(precisions, np.mean(precisions)))
        disp = ConfusionMatrixDisplay(cm)
        disp.plot(cmap='Blues', values_format='')
        plt.savefig('./cm_' + model_type + '_' + loss_type + '_' + str(BATCH_SIZE) + '.jpg', dpi=300)


def main():
    """Load the test data and the trained SegNet checkpoint, then evaluate.

    Reads the image names from ``TEST_IMAGE_DIR`` and the labels from the
    ``'data'`` entry of the ``.mat`` file at ``TEST_LABEL_PATH``, restores the
    model weights from the downstream-training checkpoint and runs
    ``eval_epoch`` on the resulting loader.
    """
    # Loading the data
    img1_list = get_images_list(TEST_IMAGE_DIR)
    label_file = loadmat(TEST_LABEL_PATH)
    test_labels = label_file['data']

    test_dataset = Histo_Dataset(TEST_IMAGE_DIR, img1_list, test_labels,
                                 transform=transformations)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, drop_last=False, num_workers=2)

    path_train = '/content/drive/MyDrive/converted_notebooks/GenSelfDiff-HIS-main/GenSelfDiff/downstream_train/checkpoint.pth'
    # map_location lets a checkpoint saved on a GPU be restored when DEVICE
    # falls back to 'cpu' (or when the GPU numbering differs).
    snapshot = torch.load(path_train, map_location=DEVICE)
    print(DEVICE)
    print(path_train)

    model = SegNet(dim=64, channels=3, num_classes=NUM_CLASSES).to(DEVICE)
    model.load_state_dict(snapshot['model_state_dict'])

    eval_epoch(test_loader, model, NUM_CLASSES, DEVICE)


# Run the evaluation when this cell/script is executed directly.
if __name__ == '__main__':
    main()


EOFError: 