In [1]:
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import models
from PIL import Image
import torchvision.transforms as transforms
from scipy import ndimage
import cv2
# (Keep your other imports, e.g., for transforms, if needed)

class type_dataset(Dataset):
    """In-memory image dataset with optional crop, noise and CutMix augmentation.

    Every file found in the five class folders under ``root_dir`` is decoded
    once in ``__init__`` and kept as a NumPy array, paired with an integer
    label equal to the folder's position in ``pathways_folder``
    (0 = 0_Apoptosis, 1 = 0_Necroptosis, 2 = 0_Necrosis, 3 = 0_Control,
    4 = 0_TreatedAlive).

    Args:
        root_dir: Directory containing the five class folders.
        transform: Optional callable applied to the (3, H, W) float tensor.
        augmentation: When True, ``__getitem__`` applies random flips, Gaussian
            noise and CutMix, and returns the label as ``(label_a, label_b, lam)``.
        crop_augmentation: When True, a random square crop is taken before
            ``transform`` is applied.
        noise: When True, Gaussian noise is added after ``transform``.
    """

    # Standard deviation of the additive Gaussian noise.
    NOISE_STD = 0.1

    def __init__(self, root_dir, transform=None, augmentation=False, crop_augmentation=False, noise=False):
        self.data = []
        self.transform = transform
        self.augmentation = augmentation
        self.crop_augmentation = crop_augmentation
        self.noise = noise

        # The index of a folder in this list is the integer class label.
        pathways_folder = ['0_Apoptosis', '0_Necroptosis', '0_Necrosis', '0_Control', '0_TreatedAlive']
        for pathway_label, folder in enumerate(pathways_folder):
            pathway_dir = os.path.join(root_dir, folder)
            # Sorted so that sample indices are stable across runs/platforms
            # (os.listdir returns entries in arbitrary order).
            for fname in sorted(os.listdir(pathway_dir)):
                image_path = os.path.join(pathway_dir, fname)
                if not os.path.isfile(image_path):
                    continue
                # Context manager closes the file handle as soon as the pixels
                # have been copied into the array.
                with Image.open(image_path) as handle:
                    image = np.array(handle)
                self.data.append((image, pathway_label))

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)

    @staticmethod
    def _to_tensor(image):
        """Stack the array three times along a new leading axis and cast to float32."""
        # assumes `image` is a 2-D (H, W) array so the result is (3, H, W) — TODO confirm
        return torch.from_numpy(image).repeat(3, 1, 1).float()

    def _add_noise(self, tensor):
        """Return ``tensor`` plus zero-mean Gaussian noise with std ``NOISE_STD``."""
        return tensor + torch.randn_like(tensor) * self.NOISE_STD

    @staticmethod
    def _crop(tensor, window):
        """Slice the square ``(top, left, size)`` window out of a (C, H, W) tensor."""
        top, left, size = window
        return tensor[:, top:top + size, left:left + size]

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        ``label`` is the integer class label, or the tuple
        ``(label_a, label_b, lam)`` when ``augmentation`` is enabled, where
        ``lam`` is the fraction of pixels that still belong to ``label_a``.
        """
        image, label = self.data[idx]
        image_tensor = self._to_tensor(image)

        crop_window = None
        if self.crop_augmentation:
            # Random square crop with a side between half and all of the
            # shorter image side (160..320 for a 320x320 image).
            _, height, width = image_tensor.shape
            short_side = min(height, width)
            crop_size = random.randint(short_side // 2, short_side)
            top = random.randint(0, height - crop_size)
            left = random.randint(0, width - crop_size)
            crop_window = (top, left, crop_size)
            image_tensor = self._crop(image_tensor, crop_window)

        if self.transform:
            image_tensor = self.transform(image_tensor)

        if self.noise:
            image_tensor = self._add_noise(image_tensor)

        if self.augmentation:
            # ---- CutMix Augmentation Start ----
            # Built only when needed; each call re-draws the flip decisions.
            flips = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ])
            image_tensor = flips(image_tensor)
            # NOTE(review): with noise=True the primary image receives noise a
            # second time here, while the partner image below is noised at most
            # once. Kept as-is to preserve training behaviour — confirm intent.
            image_tensor = self._add_noise(image_tensor)

            # Partner image for CutMix (may coincide with idx).
            rand_idx = random.randint(0, len(self.data) - 1)
            image2, label2 = self.data[rand_idx]
            image_tensor2 = self._to_tensor(image2)
            if crop_window is not None:
                # Same window as the primary image.
                image_tensor2 = self._crop(image_tensor2, crop_window)

            if self.transform:
                image_tensor2 = self.transform(image_tensor2)
            image_tensor2 = flips(image_tensor2)

            if self.noise:
                image_tensor2 = self._add_noise(image_tensor2)

            # Mixing ratio sampled from Beta(1, 1).
            lam = np.random.beta(1.0, 1.0)
            # assumes both tensors have the same spatial size at this point
            _, H, W = image_tensor.size()
            r = np.sqrt(1 - lam)
            cut_w = int(W * r)
            cut_h = int(H * r)

            # Random patch centre; the box is clipped to the image bounds.
            cx = np.random.randint(W)
            cy = np.random.randint(H)
            x1 = int(np.clip(cx - cut_w // 2, 0, W))
            y1 = int(np.clip(cy - cut_h // 2, 0, H))
            x2 = int(np.clip(cx + cut_w // 2, 0, W))
            y2 = int(np.clip(cy + cut_h // 2, 0, H))

            # Paste the partner's patch into the primary image.
            image_tensor[:, y1:y2, x1:x2] = image_tensor2[:, y1:y2, x1:x2]

            # Recompute lambda from the clipped box so it equals the actual
            # fraction of pixels kept from the primary image.
            lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
            label = (label, label2, lam)
            # ---- CutMix Augmentation End ----

        return image_tensor, label


In [2]:
from utils.AugmentedDatasetWrapper import AugmentedDatasetWrapper
from torch.utils.data import DataLoader
from torchvision import models

# NOTE(review): absolute, machine-specific paths — consider a configurable DATA_DIR.
train_dir = "C:/rkka_Projects/cell_death_v2/Data/model_training/mip/train"
val_dir = "C:/rkka_Projects/cell_death_v2/Data/model_training/mip/test"

# Preprocessing pipeline bundled with the ImageNet ResNet-50 V2 weights.
transform = models.ResNet50_Weights.IMAGENET1K_V2.transforms()

# Training data gets crop, noise and CutMix augmentation; validation data gets none.
train_dataset = type_dataset(train_dir, transform=transform, augmentation=True, crop_augmentation=True, noise=True)
val_dataset = type_dataset(val_dir, transform=transform, augmentation=False, crop_augmentation=False)

augmented_train_dataset = AugmentedDatasetWrapper(dataset=train_dataset, num_repeats=6)
# Wrap the validation dataset here (this previously wrapped train_dataset by mistake).
augmented_val_dataset = AugmentedDatasetWrapper(dataset=val_dataset, num_repeats=1)

train_loader = DataLoader(dataset=augmented_train_dataset, shuffle=True, batch_size=64)
# Evaluation does not depend on sample order, so the validation loader is not shuffled.
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=64)

In [3]:
import utils
import torch

# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# weight set that `pretrained=True` selected.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_features = model.fc.in_features
# Replace the ImageNet head with dropout followed by a 5-way linear classifier.
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, 5)
)
# Resume from a previously saved checkpoint. map_location="cpu" lets the file
# load regardless of the device it was saved from; the model is moved to the
# GPU later.
model.load_state_dict(
    torch.load(
        r"C:\rkka_Projects\cell_death_v2\trained_models\ai_epoch_33_24.242485_0.9728.pth",
        map_location="cpu",
    )
)

# Fine-tune only layer4 and the last block of layer3; freeze everything else.
# NOTE(review): the new `fc` head matches neither pattern, so it stays frozen
# too — confirm this is intended.
for name, params in model.named_parameters():
    params.requires_grad = 'layer4' in name or 'layer3.5' in name

utils.print_trainable_parameters(model)



Total Parameters: 23,518,277
Trainable Parameters: 16,081,920


In [4]:
# Unweighted cross-entropy. To up-weight specific classes (e.g. 3 and 4), pass a
# per-class `weight=` tensor here.
criterion = torch.nn.CrossEntropyLoss()

# Move the model to the GPU before constructing the optimizer, as the PyTorch
# optimizer documentation recommends.
model = model.cuda()

# Frozen parameters never receive gradients, so Adam leaves them untouched.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.00005)

In [None]:
import torch
from tqdm import tqdm
import datetime

# Checkpoints continue the numbering of the run that produced the loaded weights.
START_EPOCH = 33
CHECKPOINT_DIR = "trained_models"
# Make sure the checkpoint directory exists so torch.save cannot fail on it.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in tqdm(range(150)):
    try:
        # ---- train ----
        model.train()
        train_loss = 0.0  # sum of per-batch losses over the epoch
        for images, labels in train_loader:
            images = images.cuda()

            # With CutMix the dataset yields (label_a, label_b, lam) instead of a
            # single label tensor.
            is_cutmix = isinstance(labels, (list, tuple))
            if is_cutmix:
                label_a, label_b, lam = labels
                label_a = label_a.cuda()
                label_b = label_b.cuda()
                lam = lam.cuda()
            else:
                labels = labels.cuda()

            outputs = model(images)

            if is_cutmix:
                # Per-sample losses so each sample is weighted by its own lambda.
                loss_a = torch.nn.functional.cross_entropy(outputs, label_a, reduction='none')
                loss_b = torch.nn.functional.cross_entropy(outputs, label_b, reduction='none')
                loss = (lam * loss_a + (1 - lam) * loss_b).mean()
            else:
                loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # ---- validation ----
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.cuda(), val_labels.cuda()

                outputs = model(val_images)
                # .item() keeps the running sum a Python float instead of a CUDA tensor.
                val_loss += criterion(outputs, val_labels).item()

                _, preds = torch.max(outputs, 1)
                val_correct += (preds == val_labels).sum().item()
                val_total += len(val_labels)

        # Guard against an empty validation loader.
        val_acc = val_correct / val_total if val_total else 0.0

        torch.save(
            model.state_dict(),
            f'{CHECKPOINT_DIR}/ai_epoch_{epoch+START_EPOCH}_{train_loss:.6f}_{val_acc:.4f}.pth',
        )

        print(f"Epoch : {epoch}")
        print(f"train loss : {train_loss:.6f}")
        print(f"val loss : {val_loss:.6f} || val_acc : {val_acc:.4f}")
    except Exception as exc:
        # Keep the run going after a failed epoch, but report the failure instead of
        # silently dropping it. KeyboardInterrupt is not caught, so the run can
        # still be stopped manually.
        tqdm.write(f"Epoch {epoch} failed and was skipped: {exc!r}")
            

  1%|          | 1/150 [00:13<33:04, 13.32s/it]

Epoch : 0
train loss : 24.125369
val loss : 0.519867 || val_acc : 0.9388


  1%|▏         | 2/150 [00:27<33:26, 13.56s/it]

Epoch : 1
train loss : 23.536839
val loss : 0.495792 || val_acc : 0.9660


  2%|▏         | 3/150 [00:40<32:55, 13.44s/it]

Epoch : 2
train loss : 23.657306
val loss : 0.529723 || val_acc : 0.9592


  3%|▎         | 4/150 [00:54<32:55, 13.53s/it]

Epoch : 3
train loss : 23.608289
val loss : 0.513320 || val_acc : 0.9660


  3%|▎         | 5/150 [01:07<32:31, 13.46s/it]

Epoch : 4
train loss : 23.375250
val loss : 0.591015 || val_acc : 0.9660


  4%|▍         | 6/150 [01:20<32:21, 13.48s/it]

Epoch : 5
train loss : 23.725262
val loss : 0.549549 || val_acc : 0.9660


  5%|▍         | 7/150 [01:34<32:14, 13.53s/it]

Epoch : 6
train loss : 24.070198
val loss : 0.528996 || val_acc : 0.9592


  5%|▌         | 8/150 [01:47<31:48, 13.44s/it]

Epoch : 7
train loss : 23.477241
val loss : 0.492418 || val_acc : 0.9524


  6%|▌         | 9/150 [02:01<31:38, 13.46s/it]

Epoch : 8
train loss : 23.896513
val loss : 0.472503 || val_acc : 0.9524


  7%|▋         | 10/150 [02:15<31:38, 13.56s/it]

Epoch : 9
train loss : 23.042517
val loss : 0.614250 || val_acc : 0.9524


  7%|▋         | 11/150 [02:28<31:14, 13.49s/it]

Epoch : 10
train loss : 22.959663
val loss : 0.534966 || val_acc : 0.9456


  8%|▊         | 12/150 [02:41<30:58, 13.47s/it]

Epoch : 11
train loss : 23.662053
val loss : 0.698275 || val_acc : 0.9456


  9%|▊         | 13/150 [02:55<30:48, 13.49s/it]

Epoch : 12
train loss : 23.417795
val loss : 0.610718 || val_acc : 0.9524


  9%|▉         | 14/150 [03:08<30:36, 13.51s/it]

Epoch : 13
train loss : 23.606480
val loss : 0.563357 || val_acc : 0.9388


 11%|█         | 16/150 [03:26<25:42, 11.51s/it]

Epoch : 15
train loss : 23.337025
val loss : 0.458064 || val_acc : 0.9592
