In [42]:
import pandas as pd
import numpy as np
import torch.cuda as cuda
from torch.cuda import is_available
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from torch import device
import os
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight

In [61]:
# Preprocessing pipeline applied to every image by the datasets below:
# resize to 64x64, convert to a tensor, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Constants
BATCH_SIZE = 50
NUM_EPOCHS = 200

# Label code -> class index; the index of a code is its position in the string
# (N=0, D=1, G=2, C=3, A=4, H=5, M=6, O=7).
LABEL_TO_CLASS: dict = {code: position for position, code in enumerate('NDGCAHMO')}

# Class index -> label code, restricted to the four codes D, H, M and O.
CLASS_TO_LABEL: dict = {LABEL_TO_CLASS[code]: code for code in ('D', 'H', 'M', 'O')}

# Number of model outputs: one per entry of LABEL_TO_CLASS.
NUM_CLASSES: int = len(LABEL_TO_CLASS)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = device("cuda") if is_available() else device("cpu")

LEARNING_RATE: float = 0.0002

In [24]:
class SyntheticDataset(Dataset):
    """Image dataset backed by a CSV label file and a directory of images.

    The CSV is read with pandas and reduced to the rows whose ``labels`` column
    is one of ``selected_labels``. Each item is a tuple
    ``(image, label, orientation)`` where ``label`` is the raw label value from
    the CSV (not a class index) and ``orientation`` is parsed from the file name.

    Args:
        root_dir: Directory that the file names from the CSV are joined onto.
        labels: Path to the CSV file. The first column is used as the image
            file name; a column named ``labels`` holds the label of each image.
        transform: Optional callable applied to the loaded PIL image.
        selected_labels: Label values to keep. The values are copied, so the
            caller's sequence (and the default) is never shared with the
            instance.
    """

    def __init__(self, root_dir, labels, transform=None, selected_labels=('D', 'O', 'M', 'H')):
        self.data = pd.read_csv(labels)
        self.root_dir = root_dir
        self.transform = transform
        # Copy into a fresh list: a shared mutable default would let one
        # instance's edits change the filter of every later instance.
        self.selected_labels = list(selected_labels)

        # Filter Data
        self.data = self.data[self.data['labels'].isin(self.selected_labels)]
        # Position of each label inside selected_labels. Note this is a
        # different numbering from any global label->class table.
        self.label_mapping = {label: i for i, label in enumerate(self.selected_labels)}

    def __len__(self):
        """Return the number of rows left after filtering."""
        return len(self.data)

    @staticmethod
    def _parse_orientation(path):
        """Return the text between the last ``_`` and the first following ``.`` of the file name.

        Only the base name is inspected so that underscores in the directory
        part of ``path`` cannot leak into the result.
        """
        return os.path.basename(path).split('_')[-1].split('.')[0]

    def __getitem__(self, idx):
        """Load the ``idx``-th remaining row as ``(image, label, orientation)``."""
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])

        # Read the label from the same named column the filter in __init__ used,
        # rather than relying on it being the second column.
        label = self.data['labels'].iloc[idx]
        orientation = self._parse_orientation(img_name)

        # Load image. convert() reads the pixel data and returns a new image,
        # so the file handle can be closed right away; forcing RGB gives every
        # sample three channels regardless of how the file was stored.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label, orientation

    
# Dataset and shuffled DataLoader over the repetition training images
repetition_dataset = SyntheticDataset(
    root_dir='../data/train_rep_imgs',
    labels='../data/train_rep_labels.csv',
    transform=transform,
)
repetition_dataloader = DataLoader(
    dataset=repetition_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Image classification model: ResNet-50 created with pretrained=True, whose
# final fully connected layer is replaced by a new NUM_CLASSES-output layer.
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=NUM_CLASSES)
model = model.to(device=DEVICE)

# Loss function, optimizer, and the list that collects one loss value per epoch
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
losses_per_epoch = list()

# Training loop on the repetition dataset.
# Put the model in training mode explicitly: other cells call model.eval(),
# and cells in a notebook can be re-run in any order.
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0      # loss since the last progress print (reset every 15 batches)
    epoch_loss_total = 0.0  # loss over the whole epoch (never reset inside the epoch)
    for i, (inputs, labels, orientation) in enumerate(repetition_dataloader):
        inputs = inputs.to(DEVICE)
        # The dataset yields label strings; map them to class indices.
        label_indices = [LABEL_TO_CLASS[label] for label in labels]
        labels = torch.tensor(label_indices, dtype=torch.long).to(DEVICE)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss_total += batch_loss

        if i % 15 == 14:  # Print every 15 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 15))
            running_loss = 0.0

    # Average loss for the epoch. This must use the epoch-wide total:
    # running_loss only holds the batches since the last print.
    epoch_loss = epoch_loss_total / len(repetition_dataloader)
    losses_per_epoch.append(epoch_loss)

print('Finished Training')



[1,    15] loss: 0.971
[1,    30] loss: 0.578
[1,    45] loss: 0.459
[1,    60] loss: 0.388
[1,    75] loss: 0.403
[1,    90] loss: 0.360
[1,   105] loss: 0.323
[1,   120] loss: 0.403
[1,   135] loss: 0.337
[1,   150] loss: 0.349
[1,   165] loss: 0.202
[1,   180] loss: 0.182
[2,    15] loss: 0.182
[2,    30] loss: 0.267
[2,    45] loss: 0.173
[2,    60] loss: 0.133
[2,    75] loss: 0.092
[2,    90] loss: 0.084
[2,   105] loss: 0.158
[2,   120] loss: 0.085
[2,   135] loss: 0.097
[2,   150] loss: 0.099
[2,   165] loss: 0.215
[2,   180] loss: 0.158
[3,    15] loss: 0.082
[3,    30] loss: 0.111
[3,    45] loss: 0.087
[3,    60] loss: 0.057
[3,    75] loss: 0.049
[3,    90] loss: 0.035
[3,   105] loss: 0.037
[3,   120] loss: 0.116
[3,   135] loss: 0.051
[3,   150] loss: 0.098
[3,   165] loss: 0.082
[3,   180] loss: 0.116
[4,    15] loss: 0.072
[4,    30] loss: 0.060
[4,    45] loss: 0.059
[4,    60] loss: 0.065
[4,    75] loss: 0.027
[4,    90] loss: 0.067
[4,   105] loss: 0.078
[4,   120] 

[29,   165] loss: 0.000
[29,   180] loss: 0.000
[30,    15] loss: 0.000
[30,    30] loss: 0.000
[30,    45] loss: 0.000
[30,    60] loss: 0.000
[30,    75] loss: 0.000
[30,    90] loss: 0.000
[30,   105] loss: 0.000
[30,   120] loss: 0.000
[30,   135] loss: 0.000
[30,   150] loss: 0.000
[30,   165] loss: 0.000
[30,   180] loss: 0.000
[31,    15] loss: 0.115
[31,    30] loss: 0.191
[31,    45] loss: 0.137
[31,    60] loss: 0.102
[31,    75] loss: 0.097
[31,    90] loss: 0.064
[31,   105] loss: 0.034
[31,   120] loss: 0.098
[31,   135] loss: 0.121
[31,   150] loss: 0.086
[31,   165] loss: 0.040
[31,   180] loss: 0.041
[32,    15] loss: 0.040
[32,    30] loss: 0.020
[32,    45] loss: 0.011
[32,    60] loss: 0.044
[32,    75] loss: 0.022
[32,    90] loss: 0.022
[32,   105] loss: 0.024
[32,   120] loss: 0.010
[32,   135] loss: 0.007
[32,   150] loss: 0.006
[32,   165] loss: 0.004
[32,   180] loss: 0.003
[33,    15] loss: 0.009
[33,    30] loss: 0.002
[33,    45] loss: 0.008
[33,    60] loss

[58,    75] loss: 0.000
[58,    90] loss: 0.000
[58,   105] loss: 0.007
[58,   120] loss: 0.003
[58,   135] loss: 0.001
[58,   150] loss: 0.008
[58,   165] loss: 0.003
[58,   180] loss: 0.002
[59,    15] loss: 0.001
[59,    30] loss: 0.000
[59,    45] loss: 0.000
[59,    60] loss: 0.000
[59,    75] loss: 0.000
[59,    90] loss: 0.003
[59,   105] loss: 0.001
[59,   120] loss: 0.002
[59,   135] loss: 0.002
[59,   150] loss: 0.014
[59,   165] loss: 0.019
[59,   180] loss: 0.054
[60,    15] loss: 0.001
[60,    30] loss: 0.030
[60,    45] loss: 0.005
[60,    60] loss: 0.002
[60,    75] loss: 0.002
[60,    90] loss: 0.001
[60,   105] loss: 0.003
[60,   120] loss: 0.001
[60,   135] loss: 0.001
[60,   150] loss: 0.000
[60,   165] loss: 0.000
[60,   180] loss: 0.020
[61,    15] loss: 0.004
[61,    30] loss: 0.001
[61,    45] loss: 0.006
[61,    60] loss: 0.002
[61,    75] loss: 0.001
[61,    90] loss: 0.034
[61,   105] loss: 0.002
[61,   120] loss: 0.012
[61,   135] loss: 0.005
[61,   150] loss

[86,   165] loss: 0.000
[86,   180] loss: 0.000
[87,    15] loss: 0.000
[87,    30] loss: 0.000
[87,    45] loss: 0.000
[87,    60] loss: 0.000
[87,    75] loss: 0.000
[87,    90] loss: 0.000
[87,   105] loss: 0.000
[87,   120] loss: 0.000
[87,   135] loss: 0.000
[87,   150] loss: 0.000
[87,   165] loss: 0.000
[87,   180] loss: 0.000
[88,    15] loss: 0.000
[88,    30] loss: 0.000
[88,    45] loss: 0.000
[88,    60] loss: 0.000
[88,    75] loss: 0.000
[88,    90] loss: 0.000
[88,   105] loss: 0.000
[88,   120] loss: 0.000
[88,   135] loss: 0.000
[88,   150] loss: 0.000
[88,   165] loss: 0.000
[88,   180] loss: 0.000
[89,    15] loss: 0.000
[89,    30] loss: 0.000
[89,    45] loss: 0.000
[89,    60] loss: 0.000
[89,    75] loss: 0.000
[89,    90] loss: 0.000
[89,   105] loss: 0.000
[89,   120] loss: 0.000
[89,   135] loss: 0.000
[89,   150] loss: 0.000
[89,   165] loss: 0.000
[89,   180] loss: 0.000
[90,    15] loss: 0.000
[90,    30] loss: 0.000
[90,    45] loss: 0.000
[90,    60] loss

[114,   150] loss: 0.005
[114,   165] loss: 0.013
[114,   180] loss: 0.005
[115,    15] loss: 0.048
[115,    30] loss: 0.014
[115,    45] loss: 0.025
[115,    60] loss: 0.012
[115,    75] loss: 0.006
[115,    90] loss: 0.004
[115,   105] loss: 0.007
[115,   120] loss: 0.016
[115,   135] loss: 0.001
[115,   150] loss: 0.007
[115,   165] loss: 0.001
[115,   180] loss: 0.001
[116,    15] loss: 0.001
[116,    30] loss: 0.000
[116,    45] loss: 0.000
[116,    60] loss: 0.000
[116,    75] loss: 0.000
[116,    90] loss: 0.000
[116,   105] loss: 0.000
[116,   120] loss: 0.000
[116,   135] loss: 0.000
[116,   150] loss: 0.000
[116,   165] loss: 0.000
[116,   180] loss: 0.000
[117,    15] loss: 0.143
[117,    30] loss: 0.055
[117,    45] loss: 0.130
[117,    60] loss: 0.034
[117,    75] loss: 0.039
[117,    90] loss: 0.018
[117,   105] loss: 0.009
[117,   120] loss: 0.005
[117,   135] loss: 0.003
[117,   150] loss: 0.002
[117,   165] loss: 0.002
[117,   180] loss: 0.001
[118,    15] loss: 0.002


[142,    30] loss: 0.000
[142,    45] loss: 0.000
[142,    60] loss: 0.000
[142,    75] loss: 0.000
[142,    90] loss: 0.000
[142,   105] loss: 0.000
[142,   120] loss: 0.000
[142,   135] loss: 0.000
[142,   150] loss: 0.000
[142,   165] loss: 0.000
[142,   180] loss: 0.000
[143,    15] loss: 0.000
[143,    30] loss: 0.000
[143,    45] loss: 0.000
[143,    60] loss: 0.000
[143,    75] loss: 0.000
[143,    90] loss: 0.000
[143,   105] loss: 0.000
[143,   120] loss: 0.000
[143,   135] loss: 0.000
[143,   150] loss: 0.000
[143,   165] loss: 0.000
[143,   180] loss: 0.000
[144,    15] loss: 0.000
[144,    30] loss: 0.000
[144,    45] loss: 0.000
[144,    60] loss: 0.000
[144,    75] loss: 0.000
[144,    90] loss: 0.000
[144,   105] loss: 0.000
[144,   120] loss: 0.000
[144,   135] loss: 0.000
[144,   150] loss: 0.000
[144,   165] loss: 0.000
[144,   180] loss: 0.000
[145,    15] loss: 0.000
[145,    30] loss: 0.000
[145,    45] loss: 0.000
[145,    60] loss: 0.000
[145,    75] loss: 0.000


[169,    90] loss: 0.000
[169,   105] loss: 0.000
[169,   120] loss: 0.000
[169,   135] loss: 0.000
[169,   150] loss: 0.000
[169,   165] loss: 0.000
[169,   180] loss: 0.000
[170,    15] loss: 0.000
[170,    30] loss: 0.000
[170,    45] loss: 0.000
[170,    60] loss: 0.000
[170,    75] loss: 0.000
[170,    90] loss: 0.000
[170,   105] loss: 0.000
[170,   120] loss: 0.000
[170,   135] loss: 0.000
[170,   150] loss: 0.000
[170,   165] loss: 0.000
[170,   180] loss: 0.000
[171,    15] loss: 0.000
[171,    30] loss: 0.000
[171,    45] loss: 0.000
[171,    60] loss: 0.000
[171,    75] loss: 0.000
[171,    90] loss: 0.000
[171,   105] loss: 0.000
[171,   120] loss: 0.000
[171,   135] loss: 0.000
[171,   150] loss: 0.000
[171,   165] loss: 0.000
[171,   180] loss: 0.000
[172,    15] loss: 0.000
[172,    30] loss: 0.000
[172,    45] loss: 0.000
[172,    60] loss: 0.000
[172,    75] loss: 0.000
[172,    90] loss: 0.000
[172,   105] loss: 0.000
[172,   120] loss: 0.000
[172,   135] loss: 0.000


[196,   150] loss: 0.000
[196,   165] loss: 0.000
[196,   180] loss: 0.000
[197,    15] loss: 0.120
[197,    30] loss: 0.091
[197,    45] loss: 0.068
[197,    60] loss: 0.022
[197,    75] loss: 0.020
[197,    90] loss: 0.033
[197,   105] loss: 0.008
[197,   120] loss: 0.006
[197,   135] loss: 0.001
[197,   150] loss: 0.024
[197,   165] loss: 0.015
[197,   180] loss: 0.003
[198,    15] loss: 0.002
[198,    30] loss: 0.005
[198,    45] loss: 0.001
[198,    60] loss: 0.000
[198,    75] loss: 0.000
[198,    90] loss: 0.000
[198,   105] loss: 0.010
[198,   120] loss: 0.001
[198,   135] loss: 0.001
[198,   150] loss: 0.005
[198,   165] loss: 0.000
[198,   180] loss: 0.001
[199,    15] loss: 0.000
[199,    30] loss: 0.001
[199,    45] loss: 0.000
[199,    60] loss: 0.000
[199,    75] loss: 0.000
[199,    90] loss: 0.000
[199,   105] loss: 0.000
[199,   120] loss: 0.000
[199,   135] loss: 0.000
[199,   150] loss: 0.000
[199,   165] loss: 0.000
[199,   180] loss: 0.000
[200,    15] loss: 0.000


In [55]:
# Image Classification: evaluate the model on the repetition dataset.
# Everything this cell reads is defined in this cell before use, so it runs on a
# fresh kernel (given the model and dataloader) and gives the same result on re-run.

# Labels to report on.
selected_labels = ['D', 'O', 'M', 'H']
# The model's outputs are indexed by LABEL_TO_CLASS, so the report has to use
# those indices (not positions inside selected_labels).
report_classes = [LABEL_TO_CLASS[label] for label in selected_labels]
# Inverse of LABEL_TO_CLASS, so any of the NUM_CLASSES outputs can be decoded.
index_to_label = {index: label for label, index in LABEL_TO_CLASS.items()}

# Lists to store predictions and ground truth labels
all_predictions = []
all_labels = []

# Set the model to evaluation mode
model.eval()

# Disable gradient tracking during inference
with torch.no_grad():
    for inputs, labels, _ in repetition_dataloader:
        inputs = inputs.to(DEVICE)

        # Forward pass
        outputs = model(inputs)

        probabilities = torch.softmax(outputs, dim=1)

        # Get predicted labels
        _, predicted = torch.max(probabilities, 1)

        # Store predictions (class indices) and ground truth (label strings)
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(list(labels))

# Check the distribution of labels in the dataset
label_counts = Counter(all_labels)
print("Label Distribution:", label_counts)

# Compute class weights manually, one per entry of selected_labels and in that
# order, so a weight can be matched to its label. A label with no samples gets
# weight 0.0 instead of dividing by zero.
total_samples = len(all_labels)
class_weights = [
    total_samples / (len(selected_labels) * label_counts[label]) if label_counts[label] else 0.0
    for label in selected_labels
]
class_weights = torch.FloatTensor(class_weights).to(DEVICE)

# Positions of the ground-truth labels inside selected_labels
filtered_labels = [label for label in all_labels if label in selected_labels]
label_indices = [selected_labels.index(label) for label in filtered_labels]

# Ground-truth labels as model class indices (same numbering as the predictions)
mapped_labels = [LABEL_TO_CLASS[label] for label in filtered_labels]
true_classes = [LABEL_TO_CLASS[label] for label in all_labels]
label_indices_tensor = torch.tensor(label_indices, dtype=torch.long, device=DEVICE)
predicted_labels = [index_to_label[int(index)] for index in all_predictions]

print("Label Indices:", label_indices_tensor)
print("Predictions:", predicted_labels)

# Print classification report; y_true and y_pred are both LABEL_TO_CLASS indices.
# all_labels is left as the list of label strings so the cell can be re-run.
print("Classification Report:")
print(classification_report(true_classes, all_predictions, labels=report_classes, target_names=selected_labels))

Label Distribution: Counter({'D': 463, 'H': 453, 'M': 434})
Label Indices: tensor([3, 0, 0,  ..., 3, 3, 3])
Predictions: ['H', 'D', 'D', 'H', 'H', 'M', 'D', 'H', 'H', 'M', 'D', 'D', 'H', 'M', 'M', 'H', 'M', 'D', 'M', 'D', 'D', 'H', 'H', 'M', 'H', 'D', 'D', 'D', 'D', 'M', 'H', 'M', 'D', 'H', 'M', 'H', 'M', 'M', 'M', 'H', 'H', 'D', 'H', 'D', 'M', 'D', 'D', 'M', 'M', 'M', 'D', 'D', 'M', 'H', 'D', 'D', 'D', 'M', 'D', 'M', 'M', 'H', 'H', 'H', 'H', 'M', 'H', 'M', 'M', 'D', 'M', 'M', 'M', 'M', 'M', 'H', 'M', 'D', 'D', 'H', 'M', 'H', 'M', 'H', 'M', 'D', 'M', 'M', 'M', 'H', 'M', 'D', 'M', 'M', 'H', 'H', 'M', 'D', 'H', 'D', 'M', 'M', 'M', 'H', 'D', 'M', 'H', 'D', 'M', 'M', 'D', 'M', 'M', 'D', 'D', 'M', 'H', 'M', 'H', 'D', 'D', 'H', 'H', 'H', 'M', 'D', 'D', 'M', 'M', 'D', 'H', 'M', 'D', 'D', 'M', 'H', 'D', 'D', 'M', 'D', 'M', 'D', 'D', 'D', 'D', 'H', 'M', 'H', 'M', 'H', 'H', 'H', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'H', 'D', 'D', 'H', 'M', 'H', 'D', 'H', 'D', 'H', 'H', 'D', 'M', 'D', 'M', 'H', 'D'

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Define DataLoader for synthetic dataset
synthetic_dataset = SyntheticDataset(root_dir='../data/train_synth_imgs', labels='../data/train_synth_labels.csv', transform=transform)
synthetic_dataloader = DataLoader(synthetic_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Training loop: continues training the existing model/optimizer on the synthetic data.
# An earlier cell leaves the model in eval mode via model.eval(); switch back to
# training mode before updating the weights.
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0      # loss since the last progress print (reset every 15 batches)
    epoch_loss_total = 0.0  # loss over the whole epoch (never reset inside the epoch)
    for i, (inputs, labels, _) in enumerate(synthetic_dataloader):
        inputs = inputs.to(DEVICE)
        # The dataset yields label strings; map them to class indices.
        labels = [LABEL_TO_CLASS[label] for label in labels]
        labels = torch.tensor(labels, dtype=torch.long).to(DEVICE)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss_total += batch_loss

        if i % 15 == 14:  # Print every 15 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 15))
            running_loss = 0.0

    # Average loss for the epoch. This must use the epoch-wide total:
    # running_loss only holds the batches since the last print.
    epoch_loss = epoch_loss_total / len(synthetic_dataloader)
    losses_per_epoch.append(epoch_loss)

print('Finished Training')

# Evaluation on the synthetic dataset.
# Predictions and ground truth are both kept as LABEL_TO_CLASS indices so the
# classification report compares like with like.
all_predictions = torch.tensor([], dtype=torch.long).to(DEVICE)
all_labels = torch.tensor([], dtype=torch.long).to(DEVICE)

# Labels to report on, taken from the dataset being evaluated rather than from
# a variable left behind by another cell.
selected_labels = list(synthetic_dataset.selected_labels)
report_classes = [LABEL_TO_CLASS[label] for label in selected_labels]

# Set the model to evaluation mode
model.eval()

# Disable gradient tracking during inference
with torch.no_grad():
    for inputs, labels, _ in synthetic_dataloader:
        inputs = inputs.to(DEVICE)
        # The dataset yields label strings, which torch.cat cannot take;
        # convert them to a tensor of class indices first.
        labels = torch.tensor([LABEL_TO_CLASS[label] for label in labels], dtype=torch.long).to(DEVICE)

        # Forward pass
        outputs = model(inputs)

        probabilities = torch.softmax(outputs, dim=1)

        # Get predicted labels
        _, predicted = torch.max(probabilities, 1)

        # Store predictions and ground truth labels
        all_predictions = torch.cat((all_predictions, predicted), dim=0)
        all_labels = torch.cat((all_labels, labels), dim=0)

# Convert tensors to numpy arrays (both hold class indices)
all_predictions = all_predictions.cpu().numpy()
all_labels = all_labels.cpu().numpy()

# Print classification report. `labels` selects the class indices to report and
# `target_names` supplies their display names in the same order.
print("Classification Report:")
print(classification_report(all_labels, all_predictions, labels=report_classes, target_names=selected_labels))