In [2]:
import numpy as np
import os
import tifffile
import cv2
from os.path import join, isfile, exists
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
import os
from PIL import Image

In [4]:
# input_dir = "../images"
# output_dir = "../images/png"
label_map = {"BMP4" :0, "CHIR": 1, "DS": 2, "DSandCHIR": 3,  "WT": 4}
# Define the input shape of the images
input_shape = (3, 224, 224)
# Define the number of classes
num_classes = 5
# Load the pre-trained GoogleNet model
model = models.googlenet(pretrained=True)
# Freeze the weights of all layers except the last fully connected layer
for param in model.parameters():
    param.requires_grad = False
model.fc.requires_grad = True
# Replace the last fully connected layer with a new layer that has `num_classes` output units
model.fc = nn.Linear(1024, num_classes)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

In [5]:
# Define the dataset class
class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        files = [f for f in os.listdir(root_dir) if isfile(join(root_dir, f))]
        for file in files:
            file_path = os.path.join(root_dir, file)
            print(file_path)
            if file_path.endswith(".png"):
                
                img = Image.open(file_path).convert("RGB")
                # self.images.append(img)
                name = file.split("_")[0]
                label = label_map.get(name, -1)
                if label != -1:
                    self.images.append(img)
                    self.labels.append(label)
                # self.labels.append(lable_map[name])
            elif file_path.endswith(".tif"):
                image_array = tifffile.imread(file_path)
                img_rescaled = 255 * (image_array - image_array.min()) / (image_array.max() - image_array.min())
                img_col = cv2.applyColorMap(img_rescaled.astype(np.uint8), cv2.COLORMAP_DEEPGREEN)
                img = Image.fromarray(img_col)
                img = img.convert("RGB")
                # self.images.append(img)

                name = file.split("_")[0]
                label = label_map.get(name, -1)
                if label != -1:
                    self.images.append(img)
                    self.labels.append(label)
                # self.labels.append(lable_map[name])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        # print(np.array(img))
        if self.transform is not None:
            img = self.transform(img)
        return img, label


In [6]:
# Define the data transformations
transform = transforms.Compose([
    transforms.Resize(input_shape[1:]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Load the dataset
dataset = MyDataset("../images/png", transform=transform)
image0, label0 = dataset.__getitem__(0)
# print(image0)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

../images/png/DS_26_04_t060_c002.png
../images/png/CHIR_22_40_t140_c002.png
../images/png/BMP4_22_20_t260_c002.png
../images/png/WT_22_02_t020_c002.png
../images/png/DSandCHIR_26_14_t150_c002.png


In [7]:
# Define the data transformations
transform = transforms.Compose([
    transforms.Resize(input_shape[1:]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Load the dataset
dataset = MyDataset("../images", transform=transform)
image0, label0 = dataset.__getitem__(0)
# print(image0)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

../images/BMP4_22_20_t260_c002.tif
../images/DS_26_04_t060_c002.tif
../images/DSandCHIR_26_14_t150_c002.tif
../images/CHIR_22_40_t140_c002.tif
../images/WT_22_02_t020_c002.tif


In [8]:
# Train the model
num_epochs = 50
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        # print(i)
        inputs, labels = data
        # print(inputs, labels)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        # if i % 10 == 9:
        print(f"[Epoch {epoch+1}, Batch {i+1}] Loss: {running_loss/10:.3f}")
        running_loss = 0.0
        
            

[Epoch 1, Batch 1] Loss: 0.172
[Epoch 2, Batch 1] Loss: 0.168
[Epoch 3, Batch 1] Loss: 0.162
[Epoch 4, Batch 1] Loss: 0.162
[Epoch 5, Batch 1] Loss: 0.151
[Epoch 6, Batch 1] Loss: 0.154
[Epoch 7, Batch 1] Loss: 0.149
[Epoch 8, Batch 1] Loss: 0.137
[Epoch 9, Batch 1] Loss: 0.131
[Epoch 10, Batch 1] Loss: 0.144
[Epoch 11, Batch 1] Loss: 0.125
[Epoch 12, Batch 1] Loss: 0.116
[Epoch 13, Batch 1] Loss: 0.115
[Epoch 14, Batch 1] Loss: 0.100
[Epoch 15, Batch 1] Loss: 0.093
[Epoch 16, Batch 1] Loss: 0.090
[Epoch 17, Batch 1] Loss: 0.093
[Epoch 18, Batch 1] Loss: 0.085
[Epoch 19, Batch 1] Loss: 0.067
[Epoch 20, Batch 1] Loss: 0.072
[Epoch 21, Batch 1] Loss: 0.063
[Epoch 22, Batch 1] Loss: 0.062
[Epoch 23, Batch 1] Loss: 0.056
[Epoch 24, Batch 1] Loss: 0.051
[Epoch 25, Batch 1] Loss: 0.049
[Epoch 26, Batch 1] Loss: 0.046
[Epoch 27, Batch 1] Loss: 0.045
[Epoch 28, Batch 1] Loss: 0.040
[Epoch 29, Batch 1] Loss: 0.037
[Epoch 30, Batch 1] Loss: 0.034
[Epoch 31, Batch 1] Loss: 0.032
[Epoch 32, Batch 