In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.models as models
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score
from torch import nn
from torch.nn import Softmax
from torchvision import datasets
from torchvision import transforms
from torchvision import transforms as T
import torchvision.models as models
import os
import cv2

from preprocess import resize_input, train_test_split, read_raw
from ear_dataset import EarDataset



In [2]:
# Pick the compute device once for the whole notebook: use CUDA when PyTorch
# reports a usable GPU, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Root folder of the ear-image dataset (relative to the notebook's working directory).
# The loading cell lists it as one sub-folder per subject, each holding that
# subject's image files.
data_path = "../UERC"

In [4]:
# Values used to select which subject folders are loaded: the loading cell keeps a
# folder only when int(<folder name>) is found in this array.
# np.loadtxt parses the file as floats by default.
# NOTE(review): the file name says "mask" — confirm it lists subject IDs rather
# than per-subject 0/1 flags.
train_subjects = np.loadtxt("../train_subjects_mask.txt")

In [5]:
from PIL import Image

In [6]:
# Load every image of each selected subject into memory.
# Result: ear_imgs maps the subject folder name (str) to a list of image arrays.
# Entries are sorted because os.listdir returns them in arbitrary order, which
# would otherwise make the downstream split differ between machines/runs.
ear_data = sorted(os.listdir(data_path))

ear_imgs = {}
for person in ear_data:
    # Stray entries (e.g. ".DS_Store") are not subject folders; int() would raise
    # on them and abort the whole cell, so skip them instead.
    try:
        subject_id = int(person)
    except ValueError:
        continue

    if subject_id not in train_subjects:
        continue

    person_dir = os.path.join(data_path, person)
    subject_imgs = []
    try:
        for img in sorted(os.listdir(person_dir)):
            # Context manager closes the file handle as soon as the pixels are read.
            with Image.open(os.path.join(person_dir, img)) as pil_img:
                # NOTE(review): PIL normally decodes colour images as RGB, so this
                # BGR->RGB swap leaves the array in BGR order — confirm that is
                # what the downstream preprocessing/model expects.
                subject_imgs.append(
                    cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_BGR2RGB)
                )
    except Exception as e:
        # Best effort (as before): a subject with any unreadable image is dropped
        # entirely, but now the message says which subject was skipped.
        print(f"Skipping subject {person}: {e}")
        continue

    ear_imgs[person] = subject_imgs

In [7]:
# Split the loaded images into training and evaluation sets.
X_train, X_eval, y_train, y_eval = train_test_split(ear_imgs)

# Preprocess data.
# Both splits go through resize_input with the same target size; only `mode`
# differs (presumably it toggles train-only augmentation — see preprocess.py).
X_train = resize_input(X_train, tgt_size=64, mode="train")
# Fix: the preprocessed evaluation images used to be stored in X_test while
# eval_dataset was built from the raw, un-resized X_eval.
X_eval = resize_input(X_eval, tgt_size=64, mode="test")
X_test = X_eval  # alias kept in case later cells still refer to X_test

train_dataset = EarDataset(X_train, y_train)
eval_dataset = EarDataset(X_eval, y_eval)

# torch.save does not create missing directories.
os.makedirs("data", exist_ok=True)
torch.save(train_dataset, "data/train_dataset.pt")
torch.save(eval_dataset, "data/eval_dataset.pt")



In [None]:
def train(model, criterion, optimizer, train_loader, valid_loader, epochs, save_path="models/model"):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Every epoch performs one optimisation pass over ``train_loader`` followed by
    one gradient-free pass over ``valid_loader``. The mean validation loss of the
    epoch is fed to a ``ReduceLROnPlateau`` scheduler (factor 0.1, patience 10)
    and, when it improves on the best value seen so far, ``model.state_dict()``
    is written to ``save_path``.

    Args:
        model: ``torch.nn.Module`` whose output has class scores along dim 1
            (predictions are taken as the arg-max over that dimension).
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer already bound to the model's parameters.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        valid_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        epochs: Number of epochs to run.
        save_path: File the best ``state_dict`` is saved to; missing parent
            directories are created.

    Returns:
        None. Progress is printed and the best weights are saved to disk.

    Raises:
        ValueError: If either loader yields no batches during an epoch.
    """
    # Run on whatever device the model already lives on instead of relying on a
    # notebook-level global; a parameter-free model falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    best_val_loss = float("inf")

    # `verbose=True` is deprecated in recent PyTorch releases; learning-rate
    # changes are reported manually after each scheduler step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=10
    )

    for epoch in range(epochs):
        # Training loop
        model.train()
        train_loss = 0.0
        train_batches = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

        # Validation loop
        model.eval()
        val_loss = 0.0
        val_batches = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_batches += 1

                # Get accuracy
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()

        if train_batches == 0 or val_batches == 0:
            raise ValueError(
                "train_loader and valid_loader must each yield at least one batch"
            )

        # Average losses
        train_loss /= train_batches
        val_loss /= val_batches

        # The scheduler must see this epoch's real validation loss. Stepping it
        # before validation (with the freshly reset 0) made it observe a constant
        # metric and decay the learning rate on a fixed schedule.
        lrs_before = [group["lr"] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group["lr"] != old_lr:
                print(f"Reducing learning rate of group {idx} to {group['lr']:.4e}.")

        accuracy = 100 * correct / total if total else 0.0
        print("Validation accuracy: ", accuracy)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}')

        # Save the model if validation loss is improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # torch.save does not create missing directories.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print("Model saved with validation loss:", best_val_loss)


In [None]:
model = models.resnet50(pretrained=True)
n_features = model.fc.in_features
n_classes = 