In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


In [None]:
class ImageDataset(Dataset):
    """Image-classification dataset that reads image files from disk with OpenCV.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``numpy`` image
            (e.g. a torchvision ``Compose`` that starts with ``ToPILImage``).

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f'image_paths ({len(image_paths)}) and labels ({len(labels)}) '
                'must have the same length')
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.image_paths[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        # OpenCV decodes to BGR; the ImageNet statistics used downstream are RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        return image, label


In [None]:
def load_dataset(data_dir):
    """Collect image paths and integer labels from a class-per-folder layout.

    Expects ``data_dir/<class_name>/<image files>``. Class indices follow the
    sorted order of the (non-hidden) sub-directory names, and files inside each
    class are sorted, so the returned order is the same on every filesystem.

    Args:
        data_dir: Root directory containing one sub-directory per class.

    Returns:
        ``(image_paths, labels)``: parallel lists of file paths and the integer
        class index of each file.
    """
    image_exts = ('.png', '.jpg', '.jpeg')

    # Only visible sub-directories are classes: stray files (.DS_Store) would
    # crash os.listdir below, and hidden dirs such as .ipynb_checkpoints would
    # otherwise become class 0 and shift every real label.
    classes = sorted(
        name for name in os.listdir(data_dir)
        if not name.startswith('.') and os.path.isdir(os.path.join(data_dir, name))
    )

    image_paths = []
    labels = []
    for idx, cls in enumerate(classes):
        cls_dir = os.path.join(data_dir, cls)
        # Sorted so the seeded train/val split is reproducible; extensions are
        # matched case-insensitively (.JPG, .PNG); hidden files (e.g. macOS
        # '._x.jpg' resource forks) are not images and are skipped.
        for img_name in sorted(os.listdir(cls_dir)):
            if img_name.startswith('.'):
                continue
            if img_name.lower().endswith(image_exts):
                image_paths.append(os.path.join(cls_dir, img_name))
                labels.append(idx)
    return image_paths, labels

# Gather every image path with its integer class label, then hold out a
# stratified 20% validation split (fixed seed, so the split is repeatable).
data_dir = 'dataset/'  # Replace with your dataset path
image_paths, labels = load_dataset(data_dir)

(train_paths, val_paths,
 train_labels, val_labels) = train_test_split(image_paths, labels,
                                              test_size=0.2,
                                              stratify=labels,
                                              random_state=42)


In [None]:
# Classifier preprocessing: resize to the 224x224 ResNet input and normalise
# with the ImageNet statistics of the pretrained backbone.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Using ImageNet means
                         std=[0.229, 0.224, 0.225])
])

# Datasets and loaders consumed by the classifier training cell
# (`train_loader` / `val_loader` were previously never defined, so a fresh
# "Restart & Run All" failed with a NameError).
# NOTE: on platforms that spawn DataLoader workers (Windows, macOS), classes
# defined in a notebook may fail to pickle; set num_workers=0 if that happens.
train_dataset = ImageDataset(train_paths, train_labels, transform=transform)
val_dataset = ImageDataset(val_paths, val_labels, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [None]:
class Classifier(nn.Module):
    """ResNet-18 backbone with a fresh ``num_classes``-way linear head.

    Args:
        num_classes: Number of output logits.
        pretrained: Initialise the backbone with ImageNet weights (downloaded
            by torchvision on first use). Pass ``False`` when a saved
            ``state_dict`` is about to be loaded anyway.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super(Classifier, self).__init__()
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: `pretrained=` is deprecated in favour of
            # `weights=`; DEFAULT is the same ImageNet checkpoint.
            weights = weights_enum.DEFAULT if pretrained else None
            self.model = models.resnet18(weights=weights)
        else:
            # Older torchvision only understands the boolean flag.
            self.model = models.resnet18(pretrained=pretrained)
        num_features = self.model.fc.in_features
        # Replace the 1000-way ImageNet head with our task head.
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits for an input image batch."""
        return self.model(x)


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Classifier(num_classes=4).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10

# Batch variables use `batch_` names so this loop does not overwrite the
# module-level `labels` list returned by `load_dataset` (a hidden-state hazard
# when cells are re-run out of order).
for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    running_loss = 0.0
    for batch_images, batch_targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Training'):
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch; re-weight by batch size so the
        # epoch figure is an exact per-sample mean.
        running_loss += loss.item() * batch_images.size(0)

    n_train = len(train_loader.dataset)
    epoch_loss = running_loss / n_train if n_train else float('nan')
    print(f'Training Loss: {epoch_loss:.4f}')

    # --- Validation pass (no gradients) ---
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_images, batch_targets in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Validation'):
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)

            outputs = model(batch_images)
            loss = criterion(outputs, batch_targets)
            val_loss += loss.item() * batch_images.size(0)

            preds = outputs.argmax(dim=1)
            correct += (preds == batch_targets).sum().item()
            total += batch_targets.size(0)

    # Report NaN for an empty validation split instead of raising ZeroDivisionError.
    val_loss = val_loss / total if total else float('nan')
    accuracy = correct / total if total else float('nan')
    print(f'Validation Loss: {val_loss:.4f} - Accuracy: {accuracy:.4f}')


In [None]:
# Persist only the learned parameters; re-create `Classifier(num_classes=4)`
# and call `load_state_dict` on it to restore them.
torch.save(obj=model.state_dict(), f='classifier.pth')


In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Grad-CAM hooks the output of `layer4`, the last residual stage of the
# ResNet-18 backbone (a stack of blocks, not a single conv layer).
target_layer = model.model.layer4
# `use_cuda` is deliberately not passed: recent pytorch-grad-cam releases
# removed that argument (TypeError) and take the device from the model, and on
# older releases it only moved model/input to CUDA, which is unnecessary here
# because the model is already on `device` and inputs are moved there before
# calling `cam(...)`.
cam = GradCAM(model=model.model, target_layers=[target_layer])


In [None]:
def generate_cam(image_path, model, cam, device, class_idx):
    """Return a Grad-CAM overlay for one image and one target class.

    Args:
        image_path: Path of the image to explain.
        model: Not used directly (``cam`` already wraps the network); kept so
            existing calls keep working.
        cam: A pytorch-grad-cam ``GradCAM`` instance.
        device: Device the input tensor is moved to (must match the model).
        class_idx: Class index whose logit the CAM explains.

    Returns:
        RGB overlay image (as produced by ``show_cam_on_image``) at the
        network input resolution, i.e. the size of the CAM.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_tensor = transform(image_rgb).unsqueeze(0).to(device)

    target = ClassifierOutputTarget(class_idx)
    # CAM for the single image in the batch.
    grayscale_cam = cam(input_tensor=input_tensor, targets=[target])[0, :]

    # The CAM is sized like the network input (224x224 after `transform`), not
    # like the file on disk; blending it onto the full-resolution image fails
    # with a shape mismatch, so bring the image to the CAM's size first.
    cam_h, cam_w = grayscale_cam.shape
    image_small = cv2.resize(image_rgb, (cam_w, cam_h))

    # show_cam_on_image expects a float image in [0, 1].
    cam_image = show_cam_on_image(np.float32(image_small) / 255.0, grayscale_cam, use_rgb=True)
    return cam_image


In [None]:
# Sanity check: Grad-CAM overlay for the first validation image, explaining
# its ground-truth class.
example_img_path, example_label = val_paths[0], val_labels[0]

cam_image = generate_cam(example_img_path, model, cam, device, example_label)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(cam_image)
ax.set_title(f'CAM for Class {example_label}')
ax.axis('off')
plt.show()


In [None]:
def create_pseudo_mask(image_path, model, cam, device, num_classes=4, threshold=0.2):
    """Build a per-pixel pseudo label mask from Grad-CAM activations.

    One CAM is computed per class on the preprocessed image. A pixel is
    labelled with the class whose CAM is strongest there, provided that
    activation exceeds ``threshold``; all other pixels stay 0.

    Args:
        image_path: Path of the image.
        model: Not used directly (``cam`` already wraps the network); kept for
            interface compatibility.
        cam: A pytorch-grad-cam ``GradCAM`` instance.
        device: Device the input tensor is moved to.
        num_classes: Number of classes to compute CAMs for (ids ``0..num_classes-1``).
        threshold: Activation cut-off on the raw CAM. The CAMs are assumed to be
            scaled to [0, 1] (pytorch-grad-cam min-max scales them).

    Returns:
        ``uint8`` array at the network input resolution with class ids.
        NOTE(review): background and class 0 share id 0, as in the original.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    # Read and preprocess once; the original re-read the file for every class.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_tensor = transform(image_rgb).unsqueeze(0).to(device)

    h, w = input_tensor.shape[-2:]
    mask = np.zeros((h, w), dtype=np.uint8)
    if num_classes < 1:
        return mask

    # Threshold the *raw* CAMs. The previous version thresholded the grayscale
    # of the jet-colormap overlay (heat map blended with the photo), whose
    # brightness does not track activation strength.
    cams = np.stack([
        cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(cls)])[0]
        for cls in range(num_classes)
    ])  # (num_classes, h, w)

    # Strongest class per pixel instead of "last class processed wins".
    best_cls = cams.argmax(axis=0)
    active = cams.max(axis=0) > threshold
    mask[active] = best_cls[active].astype(np.uint8)
    return mask


In [None]:
# Pseudo-mask for every validation image, keyed by its file path.
pseudo_masks = {
    img_path: create_pseudo_mask(img_path, model, cam, device, num_classes=4, threshold=0.2)
    for img_path in tqdm(val_paths, desc='Generating Pseudo-Masks')
}


In [None]:
example_mask = pseudo_masks[example_img_path]
example_rgb = cv2.cvtColor(cv2.imread(example_img_path), cv2.COLOR_BGR2RGB)

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(8, 4))
ax_img.imshow(example_rgb)
ax_img.set_title('Original Image')
ax_img.axis('off')

# The pseudo-mask is at network resolution, so overlay it on the image resized
# to the same size. Fixed vmin/vmax (class ids 0..3, the 4 classes used when
# generating the masks) keep each class the same colour across images.
mask_h, mask_w = example_mask.shape
ax_mask.imshow(cv2.resize(example_rgb, (mask_w, mask_h)))
ax_mask.imshow(example_mask, cmap='jet', alpha=0.5, vmin=0, vmax=3)
ax_mask.set_title('Pseudo-Mask')
ax_mask.axis('off')
plt.show()


In [None]:
class UNet(nn.Module):
    """Four-level U-Net producing per-pixel class logits.

    Every encoder/decoder stage is a single 3x3 Conv -> BatchNorm -> ReLU ->
    Dropout unit; decoder stages concatenate the matching encoder features.

    Args:
        num_classes: Number of output channels (logits per pixel).
        in_channels: Number of input image channels.
        dropout_rate: Probability of the element-wise ``nn.Dropout`` applied
            after every conv unit.

    Input ``(N, in_channels, H, W)`` with any ``H``/``W``; output
    ``(N, num_classes, H, W)``.
    """

    # Four 2x max-poolings: spatial dims must be multiples of 2**4 for the
    # skip concatenations to line up.
    _size_multiple = 16

    def __init__(self, num_classes=4, in_channels=3, dropout_rate=0.5):
        super(UNet, self).__init__()

        def CBR(in_ch, out_ch):
            # Conv (same padding) -> BatchNorm -> ReLU -> Dropout.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate)
            )

        # Encoder
        self.enc1 = CBR(in_channels, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = CBR(512, 1024)

        # Decoder: each upconv halves channels and doubles resolution; the
        # following CBR sees upconv output + skip features (hence 2x channels).
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = CBR(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = CBR(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = CBR(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = CBR(128, 64)

        # Final 1x1 convolution to class logits
        self.conv_final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        h, w = x.shape[-2:]
        # Zero-pad bottom/right up to the next multiple of 16 so any input size
        # works; the padding is cropped off the logits before returning.
        # Inputs that are already multiples of 16 (e.g. 224x224) are untouched.
        pad_h = (-h) % self._size_multiple
        pad_w = (-w) % self._size_multiple
        if pad_h or pad_w:
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Encoder
        e1 = self.enc1(x)
        p1 = self.pool(e1)

        e2 = self.enc2(p1)
        p2 = self.pool(e2)

        e3 = self.enc3(p2)
        p3 = self.pool(e3)

        e4 = self.enc4(p3)
        p4 = self.pool(e4)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder
        up4 = self.upconv4(b)
        up4 = torch.cat([up4, e4], dim=1)
        d4 = self.dec4(up4)

        up3 = self.upconv3(d4)
        up3 = torch.cat([up3, e3], dim=1)
        d3 = self.dec3(up3)

        up2 = self.upconv2(d3)
        up2 = torch.cat([up2, e2], dim=1)
        d2 = self.dec2(up2)

        up1 = self.upconv1(d2)
        up1 = torch.cat([up1, e1], dim=1)
        d1 = self.dec1(up1)

        out = self.conv_final(d1)
        return out[..., :h, :w]


In [None]:
class SegmentationDataset(Dataset):
    """Pairs image files on disk with in-memory (pseudo) segmentation masks.

    Args:
        image_paths: Image file paths to serve.
        masks: Mapping ``image_path -> 2-D integer mask`` (numpy array).
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'``. If ``None``, the image is resized to
            224x224 and ImageNet-normalised, and the mask is used as is.

    Each item is ``(image_tensor, mask_tensor)`` with the mask as int64 class
    indices, as required by ``nn.CrossEntropyLoss``.
    """

    def __init__(self, image_paths, masks, transform=None):
        self.image_paths = image_paths
        self.masks = masks
        self.transform = transform
        # Fallback pipeline, built on first use rather than on every item.
        self._default_transform = None

    def __len__(self):
        return len(self.image_paths)

    def _get_default_transform(self):
        """Return (building once) the torchvision pipeline used when ``transform`` is None."""
        if self._default_transform is None:
            self._default_transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        return self._default_transform

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for index ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.image_paths[idx]
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = self.masks[img_path]

        # Pseudo-masks are produced at network resolution while files on disk
        # can be any size; albumentations requires image and mask to share
        # height/width, so bring the image to the mask's size first.
        mask_h, mask_w = mask.shape[:2]
        if image.shape[:2] != (mask_h, mask_w):
            image = cv2.resize(image, (mask_w, mask_h), interpolation=cv2.INTER_LINEAR)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            # ToTensorV2 keeps the mask's integer dtype (uint8 here), but
            # CrossEntropyLoss / one_hot need int64 class indices.
            mask = torch.as_tensor(augmented['mask']).long()
        else:
            image = self._get_default_transform()(image)
            mask = torch.from_numpy(mask).long()

        return image, mask


In [None]:
# Install albumentations (used for the joint image/mask augmentation below)
# into this kernel's environment.
# NOTE(review): consider pinning a version for reproducibility,
# e.g. `%pip install albumentations==<version>`.
%pip install albumentations


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_segmentation_transforms():
    """Build the albumentations training pipeline for (image, mask) pairs.

    Flips and rotation are applied jointly to image and mask; brightness /
    contrast jitter and ImageNet normalisation act on the image; ToTensorV2
    converts the outputs to torch tensors.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps, additional_targets={'mask': 'mask'})


In [None]:
# Augmentation pipeline for the UNet training split.
seg_transforms = get_segmentation_transforms()

# Split the *paths*, not a Dataset: train_test_split on a Dataset indexes every
# item eagerly, which freezes a single random augmentation per sample and also
# augments the validation split. Same seed and length -> same partition.
unet_train_paths, unet_val_paths = train_test_split(val_paths, test_size=0.2, random_state=42)

train_unet = SegmentationDataset(unet_train_paths, pseudo_masks, transform=seg_transforms)
# Validation uses the dataset's deterministic resize + normalise (no augmentation).
val_unet = SegmentationDataset(unet_val_paths, pseudo_masks, transform=None)

# Create DataLoaders
unet_train_loader = DataLoader(train_unet, batch_size=16, shuffle=True, num_workers=4)
unet_val_loader = DataLoader(val_unet, batch_size=16, shuffle=False, num_workers=4)


In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed on softmax probabilities.

    Args:
        smooth: Additive term in numerator and denominator; keeps the ratio
            defined (equal to 1) for classes absent from both prediction and target.
    """

    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - mean Dice`` over batch items and classes.

        Args:
            inputs: Raw logits of shape (N, C, H, W).
            targets: Class indices of shape (N, H, W). Any integer dtype is
                accepted (e.g. uint8 masks); they are cast to int64.
        """
        probs = torch.softmax(inputs, dim=1)
        # F.one_hot only accepts int64 index tensors.
        targets_one_hot = (
            nn.functional.one_hot(targets.long(), num_classes=probs.shape[1])
            .permute(0, 3, 1, 2)
            .to(probs.dtype)
        )

        # Per-(sample, class) sums over the spatial dims.
        intersection = (probs * targets_one_hot).sum(dim=(2, 3))
        denominator = probs.sum(dim=(2, 3)) + targets_one_hot.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (denominator + self.smooth)
        return 1 - dice.mean()
    
class CombinedLoss(nn.Module):
    """Sum of pixel-wise cross-entropy and soft Dice loss.

    Args:
        weight: Optional per-class weights forwarded to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, weight=None):
        super(CombinedLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss(weight=weight)
        self.dice = DiceLoss()

    def forward(self, inputs, targets):
        """Return CE + Dice for logits (N, C, H, W) and class-index targets (N, H, W)."""
        # CrossEntropyLoss needs int64 class indices; masks may arrive as uint8.
        targets = targets.long()
        ce_loss = self.ce(inputs, targets)
        dice_loss = self.dice(inputs, targets)
        return ce_loss + dice_loss


In [None]:
# UNet training configuration. NOTE: this rebinds `criterion`, `optimizer` and
# `num_epochs`, which the classifier cells above also use.
num_classes = 4
num_epochs = 20
unet_model = UNet(in_channels=3, num_classes=num_classes, dropout_rate=0.5).to(device)
optimizer = optim.Adam(params=unet_model.parameters(), lr=1e-4)
criterion = CombinedLoss()


In [None]:
# Track the epoch with the lowest validation loss (in-memory checkpoint) so
# that, after training, `unet_model` holds the best weights, not the last ones.
best_val_loss = float('inf')
best_state = None

for epoch in range(num_epochs):
    # --- Training pass ---
    unet_model.train()
    running_loss = 0.0
    for images, masks in tqdm(unet_train_loader, desc=f'UNet Epoch {epoch+1}/{num_epochs} - Training'):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = unet_model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(unet_train_loader.dataset)
    print(f'UNet Training Loss: {epoch_loss:.4f}')

    # --- Validation pass (no gradients) ---
    unet_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(unet_val_loader, desc=f'UNet Epoch {epoch+1}/{num_epochs} - Validation'):
            images = images.to(device)
            masks = masks.to(device)

            outputs = unet_model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item() * images.size(0)

    val_loss /= len(unet_val_loader.dataset)
    print(f'UNet Validation Loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Copy to CPU so the snapshot does not hold a second set of GPU weights.
        best_state = {k: v.detach().cpu().clone() for k, v in unet_model.state_dict().items()}

if best_state is not None:
    unet_model.load_state_dict(best_state)
    print(f'Restored UNet weights with best validation loss {best_val_loss:.4f}')


In [None]:
# Persist only the UNet's learned parameters; re-create
# `UNet(num_classes=num_classes, in_channels=3)` and `load_state_dict` to restore.
torch.save(obj=unet_model.state_dict(), f='unet.pth')


In [None]:
def predict_segmentation(model, image_path, device):
    """Predict a per-pixel class mask for a single image file.

    The image is resized to 224x224 and ImageNet-normalised (the same
    statistics used for training), run through ``model``, and the arg-max
    class map is resized back to the original image size with
    nearest-neighbour interpolation so class ids are never blended.
    Puts ``model`` into eval mode as a side effect.

    Args:
        model: Segmentation network returning logits of shape (1, C, h, w).
        image_path: Path of the image to segment.
        device: Device ``model`` lives on.

    Returns:
        ``uint8`` array of shape (H, W) (original image size) with class ids.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    model.eval()

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising on unreadable files.
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transform_ops = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform_ops(image_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        preds = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Back to the original resolution; cv2.resize takes (width, height).
    mask = cv2.resize(preds.astype(np.uint8), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    return mask


In [None]:
new_image_path = 'path_to_new_image.jpg'  # Replace with your image path
if not os.path.exists(new_image_path):
    # Keep "Restart & Run All" working while the placeholder is unset: fall
    # back to a held-out validation image instead of crashing.
    print(f'{new_image_path!r} not found; using validation image {val_paths[0]!r} instead.')
    new_image_path = val_paths[0]

predicted_mask = predict_segmentation(unet_model, new_image_path, device)
original_image = cv2.cvtColor(cv2.imread(new_image_path), cv2.COLOR_BGR2RGB)

fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
ax_img.imshow(original_image)
ax_img.set_title('Original Image')
ax_img.axis('off')

# predict_segmentation returns the mask at the original image size, so it can
# be overlaid directly; fixed vmin/vmax keep each class id's colour stable.
ax_mask.imshow(original_image)
ax_mask.imshow(predicted_mask, cmap='jet', alpha=0.5, vmin=0, vmax=num_classes - 1)
ax_mask.set_title('Predicted Segmentation Mask')
ax_mask.axis('off')

plt.show()
