In [12]:
import os
import pandas as pd

def load_dataset_as_dataframe(subdir):
    # Define the path to the dataset
    data_dir = f'/kaggle/input/knee-osteoarthritis-dataset-with-severity/{subdir}'
    print(f'Load dataset from `{subdir}` subdirectory')

    # Create lists to store image paths and labels
    image_paths = []
    labels = []

    # Get the list of class directories
    classes = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    classes.sort()
    print('Classes:', classes)

    # Map class names to labels
    class_to_idx = {class_name: idx for idx, class_name in enumerate(classes)}

    # Loop over each class directory
    for class_name in classes:
        class_dir = os.path.join(data_dir, class_name)
        label = class_to_idx[class_name]
        # Get all image files in the class directory
        for img_name in os.listdir(class_dir):
            img_path = os.path.join(class_dir, img_name)
            if os.path.isfile(img_path):
                image_paths.append(img_path)
                labels.append(label)

    # Create a DataFrame
    data = pd.DataFrame({
        'image_path': image_paths,
        'label': labels
    })
    
    # Show the distribution of labels in the dataset
    dataset_distribution_dict = {}
    for i in range(5):
        dataset_distribution_dict[i] = len(data[data['label'] == i])
    print(dataset_distribution_dict)
    print()

    return data

In [13]:
# Prepare training and test dataset as dataframes
classes = ['0', '1', '2', '3', '4']
train_df = load_dataset_as_dataframe('train')
test_df = load_dataset_as_dataframe('test')

Load dataset from `train` subdirectory
Classes: ['0', '1', '2', '3', '4']
{0: 2286, 1: 1046, 2: 1516, 3: 757, 4: 173}

Load dataset from `test` subdirectory
Classes: ['0', '1', '2', '3', '4']
{0: 639, 1: 296, 2: 447, 3: 223, 4: 51}



In [14]:
from torchvision import transforms

# Define image transformations
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet standards
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


In [15]:
from torch.utils.data import Dataset
from PIL import Image

# Create custom dataset class
class KneeDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'image_path']
        label = self.df.loc[idx, 'label']
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


In [16]:
from torch.utils.data import DataLoader

# Create dataset instances
train_dataset = KneeDataset(train_df, transform=train_transform)
test_dataset = KneeDataset(test_df, transform=test_transform)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)


In [18]:
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn

# Load pre-trained ResNet18 model
# model = models.resnet18(pretrained=True)
model = resnet18(weights = ResNet18_Weights.IMAGENET1K_V1)

# Modify the final layer
num_ftrs = model.fc.in_features
num_classes = len(classes)
model.fc = nn.Linear(num_ftrs, num_classes)


In [19]:
import torch
import torch.optim as optim

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [20]:
from tqdm import tqdm

# Number of epochs
num_epochs = 25

# Training loop with progress bars
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    # Wrap the train_loader with tqdm for a progress bar
    pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Update running loss
        running_loss += loss.item() * images.size(0)
        
        # Calculate accuracy within the batch
        _, preds = torch.max(outputs, 1)
        total_correct += torch.sum(preds == labels.data)
        total_samples += labels.size(0)
        
        # Update progress bar description
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{(total_correct/total_samples*100):.2f}%'
        })
    
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = total_correct.double() / len(train_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

Epoch [1/25]: 100%|██████████| 181/181 [00:20<00:00,  9.05it/s, loss=1.0849, acc=46.78%]


Epoch [1/25], Loss: 1.2338, Accuracy: 0.4678


Epoch [2/25]: 100%|██████████| 181/181 [00:20<00:00,  8.96it/s, loss=1.0986, acc=56.27%]


Epoch [2/25], Loss: 1.0229, Accuracy: 0.5627


Epoch [3/25]: 100%|██████████| 181/181 [00:20<00:00,  8.78it/s, loss=1.0743, acc=59.17%]


Epoch [3/25], Loss: 0.9605, Accuracy: 0.5917


Epoch [4/25]: 100%|██████████| 181/181 [00:20<00:00,  8.67it/s, loss=0.9613, acc=62.62%]


Epoch [4/25], Loss: 0.8914, Accuracy: 0.6262


Epoch [5/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.9851, acc=63.88%]


Epoch [5/25], Loss: 0.8589, Accuracy: 0.6388


Epoch [6/25]: 100%|██████████| 181/181 [00:20<00:00,  8.85it/s, loss=0.7812, acc=66.22%]


Epoch [6/25], Loss: 0.8031, Accuracy: 0.6622


Epoch [7/25]: 100%|██████████| 181/181 [00:20<00:00,  8.74it/s, loss=0.9899, acc=67.08%]


Epoch [7/25], Loss: 0.7809, Accuracy: 0.6708


Epoch [8/25]: 100%|██████████| 181/181 [00:20<00:00,  8.75it/s, loss=0.9542, acc=67.55%]


Epoch [8/25], Loss: 0.7756, Accuracy: 0.6755


Epoch [9/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.6152, acc=69.02%]


Epoch [9/25], Loss: 0.7384, Accuracy: 0.6902


Epoch [10/25]: 100%|██████████| 181/181 [00:20<00:00,  8.76it/s, loss=0.9049, acc=71.27%]


Epoch [10/25], Loss: 0.6912, Accuracy: 0.7127


Epoch [11/25]: 100%|██████████| 181/181 [00:20<00:00,  8.74it/s, loss=1.0247, acc=72.34%]


Epoch [11/25], Loss: 0.6569, Accuracy: 0.7234


Epoch [12/25]: 100%|██████████| 181/181 [00:20<00:00,  8.80it/s, loss=0.3515, acc=74.39%]


Epoch [12/25], Loss: 0.6177, Accuracy: 0.7439


Epoch [13/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.8677, acc=74.70%]


Epoch [13/25], Loss: 0.6111, Accuracy: 0.7470


Epoch [14/25]: 100%|██████████| 181/181 [00:20<00:00,  8.79it/s, loss=0.3538, acc=77.10%]


Epoch [14/25], Loss: 0.5573, Accuracy: 0.7710


Epoch [15/25]: 100%|██████████| 181/181 [00:20<00:00,  8.80it/s, loss=0.3941, acc=79.02%]


Epoch [15/25], Loss: 0.5146, Accuracy: 0.7902


Epoch [16/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.8810, acc=79.98%]


Epoch [16/25], Loss: 0.4890, Accuracy: 0.7998


Epoch [17/25]: 100%|██████████| 181/181 [00:20<00:00,  8.81it/s, loss=0.1925, acc=82.97%]


Epoch [17/25], Loss: 0.4170, Accuracy: 0.8297


Epoch [18/25]: 100%|██████████| 181/181 [00:20<00:00,  8.80it/s, loss=0.4823, acc=84.39%]


Epoch [18/25], Loss: 0.3900, Accuracy: 0.8439


Epoch [19/25]: 100%|██████████| 181/181 [00:20<00:00,  8.81it/s, loss=0.3348, acc=86.38%]


Epoch [19/25], Loss: 0.3469, Accuracy: 0.8638


Epoch [20/25]: 100%|██████████| 181/181 [00:20<00:00,  8.81it/s, loss=0.1229, acc=89.20%]


Epoch [20/25], Loss: 0.2830, Accuracy: 0.8920


Epoch [21/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.3132, acc=89.67%]


Epoch [21/25], Loss: 0.2635, Accuracy: 0.8967


Epoch [22/25]: 100%|██████████| 181/181 [00:20<00:00,  8.83it/s, loss=0.4054, acc=92.35%]


Epoch [22/25], Loss: 0.2113, Accuracy: 0.9235


Epoch [23/25]: 100%|██████████| 181/181 [00:20<00:00,  8.80it/s, loss=0.0462, acc=92.35%]


Epoch [23/25], Loss: 0.2046, Accuracy: 0.9235


Epoch [24/25]: 100%|██████████| 181/181 [00:20<00:00,  8.81it/s, loss=0.0781, acc=94.29%]


Epoch [24/25], Loss: 0.1595, Accuracy: 0.9429


Epoch [25/25]: 100%|██████████| 181/181 [00:20<00:00,  8.82it/s, loss=0.2233, acc=94.96%]

Epoch [25/25], Loss: 0.1428, Accuracy: 0.9496





In [21]:
torch.save(model.state_dict(), "model.pth")