<a href="https://colab.research.google.com/github/rachit2005/UNET-/blob/main/unet_using_snn_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("orvile/carotid-ultrasound-images")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'carotid-ultrasound-images' dataset.
Path to dataset files: /kaggle/input/carotid-ultrasound-images
Using Colab cache for faster access to the 'carotid-ultrasound-images' dataset.
Path to dataset files: /kaggle/input/carotid-ultrasound-images


In [24]:
! pip install kaggle
!pip install snntorch
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

mkdir: cannot create directory ‘/root/.kaggle’: File exists
mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [17]:
!kaggle datasets download -d orvile/carotid-ultrasound-images

Dataset URL: https://www.kaggle.com/datasets/orvile/carotid-ultrasound-images
License(s): Attribution 4.0 International (CC BY 4.0)
carotid-ultrasound-images.zip: Skipping, found more recently modified local copy (use --force to force download)


In [None]:
! unzip carotid-ultrasound-images.zip

Archive:  carotid-ultrasound-images.zip
replace Common Carotid Artery Ultrasound Images/Common Carotid Artery Ultrasound Images/Expert mask images/202201121748100022VAS_slice_1069.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
import matplotlib.pyplot as plt
import os

base_path = "Common Carotid Artery Ultrasound Images/Common Carotid Artery Ultrasound Images/"
mask_path = base_path + "Expert mask images"
image_path = base_path + "US images"

print(f"total len: {len(os.listdir(image_path))}")
print(f"sample images: {os.listdir(image_path)[:5]}")

In [None]:
from torch import nn
from torch.utils.data import DataLoader , Dataset
import numpy as np
import pandas as pd

In [None]:
images_files = sorted([os.path.join(image_path, file) for file in os.listdir(image_path)])
masks_files = sorted([os.path.join(mask_path, file) for file in os.listdir(mask_path)])

In [None]:
df = pd.DataFrame({"image":images_files , "mask":masks_files})

print(len(df))
print(df.shape)
print(df.head())

In [None]:
from PIL import Image
from torchvision import transforms

image_size = 256

transform = transforms.Compose([
    transforms.Resize((image_size , image_size)),
    transforms.ToTensor(),
])

class UNETDataset(Dataset):
    def __init__(self , df , transform=transform):
      self.df = df
      self.transform = transform

    def __len__(self):
      return len(self.df)

    def __getitem__(self, index):
      image = self.df.iloc[index , 0]
      mask = self.df.iloc[index , 0]

      image = Image.open(image).convert("RGB")
      mask = Image.open(mask).convert("L")

      if self.transform:
        image = self.transform(image)
        mask = self.transform(mask)

      return image , mask

In [None]:
dataset = UNETDataset(df)

print(dataset[0])

# Unet Model

In [None]:
import torch
from snntorch import surrogate
import snntorch as snn


class conv_block(nn.Module):
  def __init__(self , in_channels , out_channels , beta=0.9):
    super().__init__()

    spike_grad = surrogate.atan()

    self.conv_layer = nn.Sequential(
        nn.Conv2d(in_channels=in_channels , out_channels=out_channels , kernel_size=3 , padding=1),
        snn.Leaky(beta=beta, spike_grad=spike_grad , init_hidden=True),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=out_channels , out_channels=out_channels , kernel_size=3 , padding=1),
        snn.Leaky(beta=beta, spike_grad=spike_grad , init_hidden=True),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

  def forward(self , x):
    return self.conv_layer(x)

  def reset_state(self):
        """Resets the membrane potential (state) of the internal neurons."""
        for layer in self.conv_layer:
            if isinstance(layer, snn.Leaky):
                layer.reset_hidden()


class Encoder(nn.Module):
  def __init__(self , in_channels , out_channels):
    super().__init__()
    self.conv_block = conv_block(in_channels , out_channels)
    self.max_pool = nn.MaxPool2d(kernel_size=2 , stride=2)

  def forward(self , x):
    x_skip = self.conv_block(x)
    p = self.max_pool(x_skip)
    return x_skip , p

  def reset_state(self):
    self.conv_block.reset_state()

# we will add the x to the decoder output to improve the accuracy and to follow the unet paper

class Decoder(nn.Module):
  def __init__(self , in_channels , out_channels):
    super().__init__()

    self.up_conv = nn.ConvTranspose2d(in_channels=in_channels , out_channels=out_channels , kernel_size=2 , stride=2)
    self.conv_block = conv_block(out_channels*2 , out_channels)

  def forward(self , x , skip_connections):
    x = self.up_conv(x)
    x = torch.cat([x , skip_connections] , dim=1)
    return self.conv_block(x)

  def reset_state(self):
    self.conv_block.reset_state()


class UNET(nn.Module):
  def __init__(self):
    super().__init__()
    # Encoder
    self.encoder1 = Encoder(3 , 64)
    self.encoder2 = Encoder(64 , 128)
    self.encoder3 = Encoder(128 , 256)
    self.encoder4 = Encoder(256 , 512)

    # bottle neck
    self.bottle_neck = conv_block(512 , 1024)

    # Decoder
    self.decoder1 = Decoder(1024 , 512)
    self.decoder2 = Decoder(512 , 256)
    self.decoder3 = Decoder(256, 128)
    self.decoder4 = Decoder(128 , 64)

    self.final = nn.Conv2d(64 , 1 , kernel_size=1)


  def forward(self , x):
    x1 , p1 = self.encoder1(x)
    x2 , p2 = self.encoder2(p1)
    x3 , p3 = self.encoder3(p2)
    x4 , p4 = self.encoder4(p3)

    x = self.bottle_neck(p4)

    d1 = self.decoder1(x , x4)
    d2 = self.decoder2(d1,x3)
    d3 = self.decoder3(d2,x2)
    d4 = self.decoder4(d3,x1)

    return self.final(d4)


  def reset_state(self):
      """
      Resets all membrane potentials across the entire network.
      CRITICAL for SNNs, typically called at the start of a new batch/sequence.
      """
      self.encoder1.reset_state()
      self.encoder2.reset_state()
      self.encoder3.reset_state()
      self.encoder4.reset_state()
      self.bottle_neck.reset_state()
      self.decoder4.reset_state()
      self.decoder3.reset_state()
      self.decoder2.reset_state()
      self.decoder1.reset_state()

In [None]:
model = UNET()
x = torch.randn((1,3,512,512))
y = model(x)
print(y.shape)

In [None]:
epochs = 5
batch_size = 32
learning_rate = 3e-4
device = "cuda" if torch.cuda.is_available() else "cpu"

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters() , lr=learning_rate)

In [None]:
train_loader = DataLoader(dataset , batch_size , shuffle=True)
images , masks = next(iter(train_loader))
print(images.shape , masks.shape)

In [None]:
loss_hist = []

model.to(device)

for epoch in range(epochs):
  model.train()

  for batch , (images , masks) in enumerate(train_loader):
    images = images.to(device)
    masks = masks.to(device)

    model.reset_state()
    output_masks = model(images)

    loss = loss_fn(output_masks , masks)

    loss_hist.append(loss)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"epoch: {epoch} | loss: {loss} | batch: {batch}")