In [None]:
# Install einops; the cells below import its rearrange/repeat helpers.
!pip install einops



In [None]:
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import torch.nn.functional as F
from einops import repeat
from torchvision.transforms.functional import to_pil_image
import numpy as np
import torch.nn as nn
from einops import rearrange
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import os
import torchvision
import torch


class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens plus a class token.

    A strided convolution (kernel size == stride == ``patch_size``) embeds each
    non-overlapping patch, the spatial grid is flattened into a sequence, a
    learnable class token is prepended and a learnable position embedding is
    added.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of a patch; used as conv kernel size and stride.
        emb_size: Dimension of every output token.
        img_size: Side length used to size the position embedding, which holds
            ``(img_size // patch_size) ** 2 + 1`` rows. Inputs that produce a
            different number of patches fail at the position add.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 256):
        # nn.Module has to be initialised before attributes are assigned to it.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size,
                      kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.position = nn.Parameter(torch.randn(
            (img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x):
        """Embed ``x`` and return tokens of shape (batch, patches + 1, emb_size)."""
        batch_size = x.shape[0]
        x = self.projection(x)
        # One copy of the class token per sample, placed at sequence index 0.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        x = torch.cat([cls_tokens, x], dim=1)
        # Out-of-place add keeps the concatenated tensor untouched for autograd.
        return x + self.position


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        emb_size: Dimension of the input and output tokens.
        num_heads: Number of attention heads; must divide ``emb_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape (batch, seq_len, emb_size).

        Args:
            x: Input tokens.
            mask: Optional boolean tensor broadcastable to
                (batch, num_heads, query_len, key_len). Positions where it is
                ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Split into queries, keys and values per head:
        # 3 x batch x num_heads x seq_len x head_dim
        qkv = rearrange(
            self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]

        # Scaled dot product: the scale is applied to the logits *before* the
        # softmax so that every attention row still sums to one.
        head_dim = self.emb_size // self.num_heads
        # batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / head_dim ** 0.5

        if mask is not None:
            # masked_fill is out-of-place, so the result must be kept.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)


class ResidualBlock(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x)`` returns ``fn(x) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) applied to the input. Any
            keyword arguments given to ``forward`` are passed on to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add: when ``fn`` hands back its input object (e.g.
        # nn.Identity, or Dropout in eval mode) an in-place ``+=`` would
        # overwrite the caller's tensor and fails on leaf tensors that
        # require grad.
        return self.fn(x, **kwargs) + x


class FeedForward(nn.Sequential):
    """Position-wise MLP: expand, GELU, dropout, project back, dropout.

    Args:
        emb_size: Input and output feature dimension.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.1):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
            nn.Dropout(drop_p),
        ]
        super().__init__(*layers)


class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm transformer layer: attention and MLP, each with a skip.

    Args:
        emb_size: Token dimension.
        drop_p: Dropout probability applied after each sub-layer.
        forward_expansion: Hidden-width multiplier of the feed-forward MLP.
        forward_drop_p: Dropout probability inside the feed-forward MLP.
        **kwargs: Forwarded to ``MultiHeadAttention`` (e.g. ``num_heads``).
    """

    def __init__(self, emb_size=768, drop_p: float = 0.1, forward_expansion: int = 4, forward_drop_p: float = 0.1, ** kwargs):
        super().__init__(
            ResidualBlock(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualBlock(nn.Sequential(
                nn.LayerNorm(emb_size),
                # These two settings were previously accepted but silently
                # ignored; they now actually configure the MLP.
                FeedForward(emb_size,
                            expansion=forward_expansion,
                            drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )),
        )


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` identically configured ``TransformerEncoderBlock``s.

    Args:
        depth: Number of blocks in the stack.
        **kwargs: Passed unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth: int = 6, **kwargs):
        layers = (TransformerEncoderBlock(**kwargs) for _ in range(depth))
        super().__init__(*layers)


class ViT(nn.Sequential):
    """Vision Transformer backbone: patch embedding followed by the encoder.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each patch.
        emb_size: Token dimension shared by the embedding and the encoder.
        img_size: Image side length used to size the position embedding.
        depth: Number of encoder blocks.
        **kwargs: Passed on to ``TransformerEncoder``.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 256,
                 depth: int = 6,
                 **kwargs):
        embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        super().__init__(embedding, encoder)


# model = ViT()
# print(model(torch.randn([1, 3, 256, 256])).shape)  # input side must equal img_size (default 256)

In [None]:
import torch
from einops import rearrange
import torch.nn as nn


class DecoderLinear(nn.Module):
    """Linear segmentation head: per-patch class scores laid out as a 2-D map.

    Args:
        n_cls: Number of output classes.
        patch_size: Side length of a patch, used to derive the grid height.
        embedd_dim: Dimension of the incoming patch tokens.
    """

    def __init__(self, n_cls, patch_size, embedd_dim):
        super().__init__()
        self.n_cls, self.patch_size, self.embedd_dim = n_cls, patch_size, embedd_dim
        self.head = nn.Linear(embedd_dim, n_cls)

    def forward(self, x, img_size: int = 256):
        """Map tokens (batch, patches, embedd_dim) to (batch, n_cls, h, w).

        ``h`` is ``img_size // patch_size``; ``w`` is inferred from the
        sequence length.
        """
        grid_size = img_size // self.patch_size
        logits = self.head(x)
        return rearrange(logits, "b (h w) c -> b c h w", h=grid_size)


class MaskDecoder(nn.Module):
    """Mask-transformer decoder producing one score map per class.

    Patch tokens are processed jointly with ``n_cls`` learnable class
    embeddings by a transformer; each mask value is the cosine similarity
    between a projected patch token and a projected class token.

    Args:
        scale: Factor applied to the random init of the two projection matrices.
        depth: Number of transformer blocks in the decoder.
        patch_size: Side length of a patch, used to derive the output grid.
        n_cls: Number of classes (one class token / output channel each).
        dec_embdd: Token dimension inside the decoder. Incoming tokens must
            already have this dimension.
    """

    def __init__(self, scale, depth, patch_size, n_cls, dec_embdd):
        super().__init__()
        self.n_cls = n_cls
        self.dec_embdd = dec_embdd
        self.scale = scale
        self.patch_size = patch_size

        self.cls_emb = nn.Parameter(
            torch.randn([1, n_cls, dec_embdd]))
        self.dec_proj = nn.Linear(dec_embdd, dec_embdd)

        self.proj_patch = nn.Parameter(
            self.scale * torch.randn(dec_embdd, dec_embdd))
        self.proj_classes = nn.Parameter(
            self.scale * torch.randn(dec_embdd, dec_embdd))

        self.decoder_norm = nn.LayerNorm(dec_embdd)
        # LayerNorm over a single element always yields (x - x) / std == 0, so
        # with one class every mask would collapse to the constant bias and no
        # gradient would reach the rest of the model. Skip it in that case.
        # NOTE: checkpoints saved with n_cls == 1 before this change contain
        # ``mask_norm.weight`` / ``mask_norm.bias`` keys that no longer exist.
        self.mask_norm = nn.LayerNorm(n_cls) if n_cls > 1 else nn.Identity()

        # The blocks must operate at the decoder width, not the 768 default.
        self.blocks = TransformerEncoder(depth=depth, emb_size=dec_embdd)

    def forward(self, x, img_size):
        """Return class masks of shape (batch, n_cls, grid_h, grid_w).

        Args:
            x: Patch tokens of shape (batch, patches, dec_embdd).
            img_size: ``(H, W)`` of the input image; only ``H`` is used to
                fix the grid height, the width is inferred from the sequence.
        """
        height, _ = img_size
        grid_size = height // self.patch_size
        x = self.dec_proj(x)
        # Append one class token per segmentation class to every sample.
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        x = self.decoder_norm(self.blocks(x))

        # The last n_cls tokens are the class tokens, the rest are patches.
        patches, cls_seg_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]
        patches = patches @ self.proj_patch  # batch x patches x dec_embdd
        cls_seg_feat = cls_seg_feat @ self.proj_classes  # batch x n_cls x dec_embdd
        # L2-normalise so the product below is a cosine similarity.
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        masks = patches @ cls_seg_feat.transpose(1, 2)  # batch x patches x n_cls
        masks = self.mask_norm(masks)
        masks = rearrange(masks, "b (h w) n -> b n h w", h=int(grid_size))
        return masks

In [None]:

import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import torch

class Segmenter(nn.Module):
    """Segmentation model: ViT encoder + mask decoder + bilinear upsampling.

    Args:
        in_channels: Number of channels of the input images.
        scale: Init scale forwarded to the decoder's projection matrices.
        patch_size: Patch side length shared by encoder and decoder.
        image_size: Image side length used to size the encoder's position
            embedding.
        enc_depth: Number of encoder blocks.
        dec_depth: Number of decoder blocks.
        enc_embdd: Encoder token dimension.
        dec_embdd: Decoder token dimension.
        n_cls: Number of output classes / mask channels.
    """

    def __init__(self,
                 in_channels,
                 scale,
                 patch_size,
                 image_size,
                 enc_depth,
                 dec_depth,
                 enc_embdd,
                 dec_embdd,
                 n_cls):
        super().__init__()
        self.encoder = ViT(in_channels=in_channels,
                           patch_size=patch_size,
                           emb_size=enc_embdd,
                           img_size=image_size,
                           depth=enc_depth)
        self.decoder = MaskDecoder(scale=scale,
                                   depth=dec_depth,
                                   patch_size=patch_size,
                                   n_cls=n_cls,
                                   dec_embdd=dec_embdd)

    def forward(self, img):
        """Return mask logits with the same spatial size as ``img``."""
        height, width = img.shape[2], img.shape[3]
        tokens = self.encoder(img)
        # Index 0 holds the class token; the decoder only consumes patches.
        patch_tokens = tokens[:, 1:]
        coarse_masks = self.decoder(patch_tokens, (height, width))
        return F.interpolate(coarse_masks, size=(height, width), mode="bilinear")

# Smoke test: push one random batch through the model and print the output shape.
model = Segmenter(in_channels=3, scale=0.05, patch_size=16, image_size=256,
                  enc_depth=12, dec_depth=6, enc_embdd=768, dec_embdd=768, n_cls=1)
dummy_batch = torch.randn([16, 3, 256, 256])
print(model(dummy_batch).shape)

torch.Size([16, 1, 256, 256])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from google.colab import drive
import os  # Add this line

# Mount Google Drive so the dataset folders under /content/drive are readable.
drive.mount('/content/drive')
# Custom dataset class
class CustomDataset(Dataset):
    """Paired image / mask dataset read from two flat folders.

    Sample ``i`` pairs the ``i``-th file name of each folder in sorted order,
    so both folders are expected to use matching (or identically sorting)
    file names.

    Args:
        image_folder: Directory holding the original images.
        mask_folder: Directory holding the corresponding masks.
        transform: Optional callable applied separately to the image and to
            the mask. NOTE: a transform with random behaviour would therefore
            draw different parameters for the image and its mask.

    Raises:
        ValueError: If the two folders contain a different number of entries.
    """

    def __init__(self, image_folder, mask_folder, transform=None):
        self.root_folder_original = image_folder
        self.root_folder_masked = mask_folder
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i refers to the same sample in both folders.
        self.original_images = sorted(os.listdir(image_folder))
        self.masked_images = sorted(os.listdir(mask_folder))

        if len(self.original_images) != len(self.masked_images):
            raise ValueError(
                f"Found {len(self.original_images)} images in {image_folder!r} but "
                f"{len(self.masked_images)} masks in {mask_folder!r}; "
                "every image needs exactly one mask."
            )

    def __len__(self):
        return len(self.original_images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``, both converted to RGB."""
        original_image_path = os.path.join(self.root_folder_original, self.original_images[idx])
        masked_image_path = os.path.join(self.root_folder_masked, self.masked_images[idx])

        # Context managers release the file handles as soon as the pixel data
        # has been copied by convert().
        with Image.open(original_image_path) as img:
            original_image = img.convert('RGB')
        # The mask is returned with three channels as well; consumers that
        # need a single-channel target have to reduce it themselves.
        with Image.open(masked_image_path) as img:
            masked_image = img.convert('RGB')

        if self.transform is not None:
            original_image = self.transform(original_image)
            masked_image = self.transform(masked_image)

        return original_image, masked_image

# Preprocessing applied to images and masks alike: resize, then convert to tensor
# (no random augmentation is involved).
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Dataset folders on the mounted drive
image_folder_path = '/content/drive/My Drive/Project/Data_Full_Image/Original'
mask_folder_path = '/content/drive/My Drive/Project/Data_Full_Image/Masked'

# Paired image / mask dataset
dataset = CustomDataset(
    image_folder=image_folder_path,
    mask_folder=mask_folder_path,
    transform=transform,
)

# 80 % of the samples go to training, the remainder to validation
num_samples = len(dataset)
train_size = int(0.8 * num_samples)
val_size = num_samples - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

# Batched loaders; only the training data is shuffled
loader_kwargs = {"batch_size": 16, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Initialize the model, optimizer, and loss function
model = Segmenter(in_channels=3, scale=0.05, patch_size=16, image_size=256, enc_depth=12, dec_depth=6, enc_embdd=768, dec_embdd=768, n_cls=1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss for binary segmentation

# Training loop
num_epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # BCEWithLogitsLoss needs a target shaped exactly like the output
        # (batch, 1, H, W). The dataset delivers masks as 3-channel tensors,
        # so add a channel axis if it is missing and average the channels
        # down to one when the counts differ.
        masks = masks.float()
        if masks.dim() == outputs.dim() - 1:
            masks = masks.unsqueeze(1)
        if masks.size(1) != outputs.size(1):
            masks = masks.mean(dim=1, keepdim=True)

        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch instead of only the last batch;
    # max() keeps an empty loader from dividing by zero.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

# Validation loop
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        sigmoid_outputs = torch.sigmoid(outputs)

        # Convert to binary predictions (0 or 1) based on a threshold (e.g., 0.5)
        preds = (sigmoid_outputs > 0.5).float()

        # Bring the masks to the prediction layout (batch, 1, H, W): add a
        # channel axis if missing and average multi-channel masks down to one.
        masks = masks.float()
        if masks.dim() == outputs.dim() - 1:
            masks = masks.unsqueeze(1)
        if masks.size(1) != outputs.size(1):
            masks = masks.mean(dim=1, keepdim=True)
        # Resizing leaves intermediate values along mask edges; the sklearn
        # classification metrics need discrete labels, so threshold them.
        masks = (masks > 0.5).float()

        # Flatten predictions and masks to 1D arrays of equal length
        preds_flat = preds.reshape(-1)
        masks_flat = masks.reshape(-1)

        all_preds.append(preds_flat.cpu().numpy())
        all_labels.append(masks_flat.cpu().numpy())

# Concatenate results from all batches
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Calculate pixel-level metrics over the whole validation set
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)

# `epoch` still holds the index of the last training epoch run above.
print(f"Epoch {epoch+1}/{num_epochs}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
