In [1]:
import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models import efficientnet_b0
import torch.nn as nn
import torch.optim as optim

In [3]:
class CellsDataset(Dataset):
    """Binary image dataset read from a four-level directory tree.

    The expected layout is ``<root_dir>/<quality>/<cell_type>/<age_group>/<image>``
    where ``quality`` is ``good`` (label 1) or ``bad`` (label 0). The cell-type
    and age-group folder names are only traversed; they are not stored.

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) and entries of the wrong kind (a file where a folder
    is expected, or a folder where an image is expected) are skipped instead of
    crashing the scan. Listings are sorted so the sample order is reproducible.

    Args:
        root_dir: Directory that contains the ``good`` and/or ``bad`` folders.
        transform: Optional callable applied to each PIL image in ``__getitem__``.

    Attributes:
        samples: List of ``{'image_path': str, 'label': int}`` dicts.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        for quality in ['good', 'bad']:
            quality_dir = os.path.join(root_dir, quality)
            if not os.path.isdir(quality_dir):
                print(f"Warning: Directory {quality_dir} does not exist")
                continue

            label = 1 if quality == 'good' else 0
            for cell_type_dir in self._visible_subdirs(quality_dir):
                for age_group_dir in self._visible_subdirs(cell_type_dir):
                    for img_filename in sorted(os.listdir(age_group_dir)):
                        image_path = os.path.join(age_group_dir, img_filename)
                        # Only regular, non-hidden files can be images.
                        if img_filename.startswith('.') or not os.path.isfile(image_path):
                            continue
                        self.samples.append({
                            'image_path': image_path,
                            'label': label
                        })

    @staticmethod
    def _visible_subdirs(parent):
        """Return the non-hidden sub-directories of ``parent`` in sorted order."""
        candidates = (
            os.path.join(parent, name)
            for name in sorted(os.listdir(parent))
            if not name.startswith('.')
        )
        return [path for path in candidates if os.path.isdir(path)]

    def __len__(self):
        """Return the number of images found during the directory scan."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it before being returned.
        """
        sample = self.samples[idx]
        # convert() yields an independent image, so the file can be closed
        # as soon as the with-block exits.
        with Image.open(sample['image_path']) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, sample['label']


In [6]:
# Data-pipeline configuration: tune these in one place instead of editing
# every dataset / loader call below.
DATA_ROOT = '../OOC_image_dataset'
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_WORKERS = 4
# Per-channel statistics handed to Normalize.
# NOTE(review): these look like the usual ImageNet mean/std -- confirm they
# are the right choice for this image set.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# The same deterministic preprocessing is used for every split.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
])

# Create datasets
train_dataset = CellsDataset(root_dir=f'{DATA_ROOT}/train', transform=transform)
val_dataset = CellsDataset(root_dir=f'{DATA_ROOT}/val', transform=transform)
test_dataset = CellsDataset(root_dir=f'{DATA_ROOT}/test', transform=transform)

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


In [7]:
def initialize_model(num_classes=2):
    """Build an EfficientNet-B0 with pretrained weights and a new output layer.

    Args:
        num_classes: Number of outputs of the replacement classification layer.

    Returns:
        The torchvision EfficientNet-B0 module whose ``classifier[1]`` has been
        swapped for a freshly initialised ``nn.Linear`` with ``num_classes``
        outputs.
    """
    # `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
    # `weights` argument; "IMAGENET1K_V1" names the checkpoint the old flag
    # resolved to, so the loaded weights are unchanged.
    model = efficientnet_b0(weights="IMAGENET1K_V1")
    # Keep the backbone; replace only the final linear layer.
    num_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_features, num_classes)
    return model



In [8]:
# Everything in this notebook runs on the CPU.
device = torch.device("cpu")

# Two-class network together with the loss and optimiser used to train it.
model = initialize_model(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Kept as the last expression so the cell displays the model architecture.
model.to(device)




EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [9]:
# def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
#     for epoch in range(num_epochs):
#         print(f"epoch {epoch + 1}")
#         model.train()
#         running_loss = 0.0
#         for images, labels in train_loader:
#             images, labels = images.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(images)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             running_loss += loss.item()

#         # Validation step
#         model.eval()
#         val_loss = 0.0
#         with torch.no_grad():
#             for images, labels in val_loader:
#                 images, labels = images.to(device), labels.to(device)
#                 outputs = model(images)
#                 loss = criterion(outputs, labels)
#                 val_loss += loss.item()

#         print(f"Epoch {epoch+1}, Train Loss: {running_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}")

# train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and report validation metrics after every epoch.

    Args:
        model: The ``nn.Module`` to optimise.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one dict per epoch holding ``train_loss`` (mean batch loss
        over the epoch), ``val_loss`` (mean batch loss on validation) and
        ``accuracy`` (validation accuracy in percent). A metric is ``nan``
        when the corresponding loader yielded no data.
    """
    # Move each batch to wherever the model's weights live so the function
    # keeps working if the model is placed on another device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    undefined = float('nan')
    history = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss_sum = 0.0
        train_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            train_batches += 1

        # Validation phase
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss_sum += criterion(outputs, labels).item()
                val_batches += 1
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Average over the epoch rather than reporting the last batch only,
        # and avoid dividing by zero when a loader is empty.
        train_loss = train_loss_sum / train_batches if train_batches else undefined
        val_loss = val_loss_sum / val_batches if val_batches else undefined
        accuracy = 100 * correct / total if total else undefined
        print(f'Epoch {epoch+1}: Train Loss: {train_loss}, Val Loss: {val_loss}, Accuracy: {accuracy}%')
        history.append({'train_loss': train_loss, 'val_loss': val_loss, 'accuracy': accuracy})

    return history

In [14]:
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the test images: {accuracy}%')
# Train for ten epochs, then score the model on the test loader.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)


test(model=model, test_loader=test_loader)