In [None]:
!pip install timm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from google.colab import drive
import timm

# Mount Google Drive
# The dataset directories used further down are all addressed through this
# mount point (their paths start with /content/drive/My Drive/...).
drive.mount('/content/drive')

In [None]:

# Custom Dataset Class
class CustomDataset(Dataset):
    """Grayscale image dataset backed by one flat directory of image files.

    The label of a sample is derived from its full path: paths containing the
    substring ``'COVID'`` are class 0, everything else is class 1.

    Args:
        root_dir: Directory whose regular files are the samples.
        transform: Optional callable applied to the PIL image before it is
            returned.
        extensions: File-name suffixes (case-insensitive) accepted as images.
            Pass ``None`` to accept every regular file in ``root_dir``.

    Attributes:
        image_paths: Sorted list of the full paths of the accepted files.
    """

    # Suffixes treated as images when scanning ``root_dir``.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif', '.webp')

    def __init__(self, root_dir, transform=None, extensions=IMAGE_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform

        if extensions is not None:
            extensions = tuple(ext.lower() for ext in extensions)

        self.image_paths = []
        # sorted(): os.listdir returns entries in arbitrary order, which would
        # make sample indices differ from run to run.
        for name in sorted(os.listdir(root_dir)):
            path = os.path.join(root_dir, name)
            # Skip sub-directories and stray non-image files (e.g. .DS_Store);
            # they would otherwise fail later inside __getitem__.
            if not os.path.isfile(path):
                continue
            if extensions is not None and not name.lower().endswith(extensions):
                continue
            self.image_paths.append(path)

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied out by convert(), so handles do not pile up.
        with Image.open(img_path) as img:
            image = img.convert('L')  # Convert to grayscale
        label = 0 if 'COVID' in img_path else 1  # 'COVID' in path -> class 0, otherwise class 1

        if self.transform:
            image = self.transform(image)

        return image, label

# Define Transformations
# Preprocessing applied to every image, in order: resize to 224x224, force a
# single grayscale channel, then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

# Create Datasets
# Every split lives under one project folder on the mounted Drive; keeping the
# root and the batch size in one place means a relocation or a tuning change is
# a single edit instead of four (or two).
DATA_ROOT = "/content/drive/My Drive/Colab Notebooks/Image_Processing_Project"
BATCH_SIZE = 32

# NOTE: CustomDataset labels a sample by looking for 'COVID' in its full path,
# so the class folder names below are what determine the labels.
covid_train_dataset = CustomDataset(root_dir=os.path.join(DATA_ROOT, "COVID_FFT_linear", "train"), transform=transform)
pneumonia_train_dataset = CustomDataset(root_dir=os.path.join(DATA_ROOT, "Pneumonia_FFT_linear", "train"), transform=transform)
covid_test_dataset = CustomDataset(root_dir=os.path.join(DATA_ROOT, "COVID_FFT_linear", "test"), transform=transform)
pneumonia_test_dataset = CustomDataset(root_dir=os.path.join(DATA_ROOT, "Pneumonia_FFT_linear", "test"), transform=transform)

# Combine Datasets
train_dataset = torch.utils.data.ConcatDataset([covid_train_dataset, pneumonia_train_dataset])
test_dataset = torch.utils.data.ConcatDataset([covid_test_dataset, pneumonia_test_dataset])

# Create DataLoaders
# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:

# Define Inception-ResNet-v2 Model
class InceptionResNetV2CustomInput(nn.Module):
    """Inception-ResNet-v2 (timm) configured for single-channel, two-class input.

    The backbone is built through timm's own ``in_chans`` / ``num_classes``
    arguments instead of overwriting ``conv2d_1a`` by hand. Assigning a bare
    ``nn.Conv2d`` to that attribute replaces the whole stem block and throws
    away its pretrained weights, and leaving the head untouched keeps the
    original 1000-way classifier even though the labels are only 0 and 1.

    Args:
        num_classes: Size of the classification head (default 2: COVID vs
            Pneumonia).
        in_chans: Number of input image channels (default 1: grayscale).
        pretrained: Whether to load timm's pretrained weights. Per the timm
            documentation the first convolution is adapted to ``in_chans``
            and the classifier is re-created for ``num_classes``.
    """

    def __init__(self, num_classes=2, in_chans=1, pretrained=True):
        super().__init__()
        self.inception_resnet_v2 = timm.create_model(
            'inception_resnet_v2',
            pretrained=pretrained,
            in_chans=in_chans,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the class logits produced by the wrapped backbone for ``x``."""
        return self.inception_resnet_v2(x)

# Instantiate the model
inception_resnet_v2_custom_input = InceptionResNetV2CustomInput()


In [None]:

# Training configuration
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed from
# the parameters that will actually be trained.
inception_resnet_v2_custom_input.to(device)

# Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(inception_resnet_v2_custom_input.parameters(), lr=0.001, momentum=0.9)

# Training Loop with Progress Bar
for epoch in range(num_epochs):
    inception_resnet_v2_custom_input.train()
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = inception_resnet_v2_custom_input(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses; max(..., 1) keeps an empty loader from
    # raising ZeroDivisionError.
    num_batches = max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / num_batches}")


In [None]:

# Evaluation
# eval() switches layers such as batch-norm and dropout to inference behaviour.
inception_resnet_v2_custom_input.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Testing', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = inception_resnet_v2_custom_input(inputs)
        # argmax over the class dimension; no `.data` needed because gradients
        # are already disabled by the no_grad() context.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# An empty test set yields NaN instead of a ZeroDivisionError.
accuracy = correct / total if total else float('nan')
print(f"Test Accuracy: {accuracy * 100:.2f}%")
