In [1]:
import matplotlib.pyplot as plt

In [2]:
import torch
#for defining convolutional layer, pooling layers, fully connected layer,etc
import torch.nn as nn 
#for implementing optimizing algorithm like adam
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
#for image preprocessing and agumentation
from torchvision.transforms import v2 

In [3]:
import torch.fft as fft

In [4]:
import os
import cv2
from tqdm import tqdm

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
class pneumoniaDataset(Dataset):
    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        self.data = []
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((227,227)),
            v2.Normalize(mean=[0.5,], std=[0.5,]),
            v2.Grayscale(num_output_channels=3)
        ])
        self.load_data()

    def load_data(self):
        for idx, label in enumerate(os.listdir(os.path.join(self.image_path))):
            print(os.path.join(self.image_path, label))
            for img_file in tqdm(os.listdir(os.path.join(self.image_path, label))):
                img = cv2.imread(os.path.join(self.image_path, label, img_file), cv2.IMREAD_GRAYSCALE)

                img = self.transform(img)

                idx = torch.tensor([idx])
                self.data.append((img.to(self.device), idx.to(self.device)))
    
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        img,label= self.data[idx]
        label=label.squeeze()
        return img,label

In [7]:
# since i was struglling with automatically appearing of .DS_store in my folder
# which lead to unwanted classifying as class and  labeling it as 0 
# Remove .DS_Store files from the folder and subfolders
for root, dirs, files in os.walk('data'):
    if '.DS_Store' in files:
        os.remove(os.path.join(root, '.DS_Store'))

In [8]:
# Initialize the dataset
train_path=os.path.join('data/train_set')
validation_path=os.path.join('data/val_set')
# test_path=os.path.join('data/test_set')
train_dataset = pneumoniaDataset(image_path=train_path, device=device)
validation_dataset = pneumoniaDataset(image_path=validation_path, device=device)
# test_dataset = pneumoniaDataset(image_path=test_path, device=device)

data/train_set/bacterial


100%|██████████████████████████████████████| 2700/2700 [00:10<00:00, 251.07it/s]


data/train_set/normal


100%|██████████████████████████████████████| 2943/2943 [00:23<00:00, 124.05it/s]


data/train_set/viral


100%|██████████████████████████████████████| 1491/1491 [00:07<00:00, 203.09it/s]


data/val_set/bacterial


100%|████████████████████████████████████████| 301/301 [00:01<00:00, 204.78it/s]


data/val_set/normal


100%|█████████████████████████████████████████| 327/327 [00:04<00:00, 67.53it/s]


data/val_set/viral


100%|████████████████████████████████████████| 165/165 [00:00<00:00, 169.11it/s]


In [9]:
len(train_dataset), len(validation_dataset)

(7134, 793)

In [10]:
#dataloaders
batch_size = 8
train_dataloader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
validation_dataloader=DataLoader(validation_dataset,batch_size=batch_size,shuffle=False)
# test_dataloader=DataLoader(test_dataset,batch_size=16,shuffle=False)

In [11]:
# print(len(validation_dataset))
# image, label = train_dataset[6000] 
# print(image.shape)         
# print(label)  
# print(label.shape) 

In [12]:
class FFTConvNet(nn.Module):
    def __init__(self, conv_layer, fft_filter=None):
        super().__init__()
        self.conv_layer = conv_layer  # Original Conv2d layer
        self.fft_filter = fft_filter

    def fft_filter_def(self, fft_x, height, width):
        cht, cwt = height // 2, width // 2
        mask_radius = 30

        # Create a meshgrid for the mask
        fy, fx = torch.meshgrid(
            torch.arange(0, height, device=fft_x.device),
            torch.arange(0, width, device=fft_x.device),
            indexing='ij'
        )
        mask_area = torch.sqrt((fx - cwt) ** 2 + (fy - cht) ** 2)

        # Create the mask based on the filter type
        if self.fft_filter == 'high':
            mask = (mask_area > mask_radius).float()
        else:
            mask = (mask_area <= mask_radius).float()

        # Apply the mask to the FFT of the input
        filtered_fft = fft_x * mask
        return filtered_fft

    def forward(self, x):
        batch_size, in_channels, height, width = x.size()
        out_channels = self.conv_layer.out_channels  # Number of output channels

        # Apply FFT on input image
        fft_x = fft.fft2(x)  # Shape: [batch_size, in_channels, height, width]
        fft_x = fft.fftshift(fft_x)

        # Apply FFT on the convolutional kernel
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))  # Shape: [out_channels, in_channels, height, width]
        kernel_fft = fft.fftshift(kernel_fft)

        # Apply FFT filter (low-pass or high-pass)
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width)

        # Perform element-wise complex multiplication
        fft_output = fft_x.unsqueeze(1) * kernel_fft.unsqueeze(0)  # Broadcast multiplication
        fft_output = torch.sum(fft_output, dim=2)  # Sum over input channels

        # Apply inverse FFT
        fft_output = fft.ifftshift(fft_output, dim=(-2, -1))
        spatial_output = fft.ifft2(fft_output, dim=(-2, -1)).real  # Shape: [batch_size, out_channels, height, width]

        # Add bias (if applicable)
        if self.conv_layer.bias is not None:
            spatial_output += self.conv_layer.bias.view(1, -1, 1, 1)

        # Debug: Print shapes
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {spatial_output.shape}")

        # Return output
        return spatial_output

In [13]:
# Define the model
class AlexNet(nn.Module):
    def __init__(self,input_channels,number_of_classes):
        super(AlexNet, self).__init__() 
        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 96, kernel_size=11, stride=4, padding=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)

        # Fully connected layers
        self.fc1 = nn.Linear(186624, 4069)
        self.fc2 = nn.Linear(4069, 4069)
        self.fc3 = nn.Linear(4069, number_of_classes)

        # Other
        self.flatten = nn.Flatten()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.norm = nn.LocalResponseNorm(size=5, k=2)
        self.droput = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.maxpool(self.norm(self.relu(self.conv1(x))))  # (B, 96, 27, 27)
        x = self.maxpool(self.norm(self.relu(self.conv2(x))))  # (B, 256, 13, 13)
        x = self.relu(self.conv3(x))                           # (B, 384, 13, 13)
        x = self.relu(self.conv4(x))                           # (B, 384, 13, 13)
        x = self.maxpool(self.relu(self.conv5(x)))             # (B, 256, 6, 6)
        x = self.flatten(x)                                    # (B, 9216)
        x = self.droput(self.relu(self.fc1(x)))                # (B, 4096)
        x = self.droput(self.relu(self.fc2(x)))                # (B, 4096)
        x = self.logsoftmax(self.fc3(x))                                       # (B, num_classes)
        return x

In [14]:
model=AlexNet(3,3)

In [15]:
for name, module in model.named_children():
    if isinstance(module, nn.Conv2d):
        if module.kernel_size[0] > 3:
            fft_cnn = FFTConvNet(module, 'low')
            setattr(model, name, fft_cnn)

In [None]:
model

In [16]:
learning_rate = 0.001

#loss function
criterion=nn.CrossEntropyLoss()

#optimizer
optimizer=optim.Adam(model.parameters(),learning_rate)

In [17]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    best_acc = 0.0
    model.train()
    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        for images, labels in tqdm(train_dl):
            # print("Data loaded")

            # Move data to GPU
            labels = labels.squeeze().long()
            images = images.to(device)
            labels = labels.to(device)
            # print(f"Data moved to {device}")

            # Forward pass
            outputs = model(images)
            # print("Forward pass completed")

            # Compute loss
            loss = loss_fn(outputs, labels)
            # print("Loss computed")

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print("Backward pass completed")

            # Compute metrics
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total_samples += labels.size(0)
            
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        print(f"Epoch Loss: {epoch_loss:.4f}, Epoch Accuracy: {epoch_acc:.4f}")
    print('Training Complete.')

In [None]:
train(model, train_dataloader, criterion, optimizer, epochs=1)

Epoch [1/1]


  0%|                                                   | 0/892 [00:00<?, ?it/s]

Input shape: torch.Size([8, 3, 227, 227])
Output shape: torch.Size([8, 96, 227, 227])
Input shape: torch.Size([8, 96, 113, 113])
Output shape: torch.Size([8, 256, 113, 113])


In [None]:
# total_step = len(train_dataloader)

# model.to(device)

# for epoch in range(number_of_epochs):
#     model.train()
#     for images, labels in tqdm(train_dataloader):  
#         images = images.to(device)
#         labels = labels.to(device)
#         labels= labels.long()

#         # Forward pass
#         outputs = model(images)
#         loss = criterion(outputs, labels)

#         # Backward and optimize
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     model.eval()
#     val_loss = 0.0
#     val_correct = 0
#     val_total = 0
    
#     with torch.no_grad():
#         for images,labels in validation_dataloader:
#             images = images.to(device)
#             labels = labels.to(device)
#             labels= labels.long()
            
#             outputs = model(images)
#             loss = criterion(outputs,labels)
#             val_loss += loss.item()

#             _, predicted = torch.max(outputs.data, 1)
#             val_total +=labels.size(0)
#             val_correct += (predicted ==labels).sum().item()

#     # Calculate validation metrics
#     avg_val_loss = val_loss / len(validation_dataloader)
#     val_accuracy = 100 * val_correct / val_total

#     print(f'Epoch [{epoch+1}/{number_of_epochs}], '
#           f'Train Loss: {loss.item():.4f}, '
#           f'Val Loss: {avg_val_loss:.4f}, '
#           f'Val Acc: {val_accuracy:.2f}%')

# print('Finished training')

In [None]:
torch.save(model.state_dict(), './alexNetModel.pth')