In [1]:
# import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score
from torch import nn
from torchvision import transforms as T
import torchvision.models as models

from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

### 1. Load dataloaders

In [3]:
train_ds = torch.load("data/train_dataset_500.pt")
eval_ds = torch.load("data/eval_dataset_500.pt")


In [4]:
len(train_ds.data), len(train_ds.labels), len(eval_ds.data), len(eval_ds.labels)

(5174, 5174, 2389, 2389)

In [13]:

mapping = {}
rev_mapping = {}
for i, label in enumerate(train_ds.labels):
    mapping[i] = label
    rev_mapping[label] = i
    
train_mapped_labels = [rev_mapping[label] for label in train_ds.labels]
eval_mapped_labels = [rev_mapping[label] for label in eval_ds.labels]

print(len(train_mapped_labels))

train_ds.labels = train_mapped_labels
eval_ds.labels = eval_mapped_labels

BATCH_SIZE = 64

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
eval_dl = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=True)

5174


In [16]:
1798 in eval_ds.labels or 1798 in train_ds.labels

True

In [6]:
len(train_ds.labels), len(train_ds.data)

3616 in eval_ds.labels

False

In [7]:
# Check if shapes are ok
en1, en2 = train_ds[0], train_ds[5]

en1[0].shape, en2[0].shape

(torch.Size([3, 224, 224]), torch.Size([3, 224, 224]))

#### Set the train loop

In [8]:
loss_history = []
accuracy_history = []

def train(model, criterion, optimizer, train_loader, valid_loader, epochs, save_path="models/model"):
    best_val_loss = float('inf')  # Initialize with infinity
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)    
    
    for epoch in range(epochs):
        # Training loop
        model.train()
        train_loss = 0.0
        val_loss = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
        loss_history.append(loss.item())    
        
        scheduler.step(val_loss)
        
        # Validation loop
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        
                # Get accuracy
                _, pred = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
        
        # Average losses
        train_loss /= len(train_loader)
        val_loss /= len(valid_loader)
        
        val_acc = 100 * correct / total
        
        accuracy_history.append(val_acc)
        
        print("Validation accuracy: ", val_acc)
        
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}')
        
        # Save the model if validation loss is improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("Model saved with validation loss:", best_val_loss)


### 2. Get model and set it up

In [9]:
model = models.resnet50(pretrained=True)
n_features = model.fc.in_features
n_classes = train_ds.get_n_classes()

model.fc = nn.Linear(n_features, n_classes)



In [10]:
n_classes

362

### 3. Train the model

In [17]:
LR = 1e-4

n_epochs = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [18]:
train(model, criterion, optimizer, train_dl, eval_dl, n_epochs, "models/resnet50.pt")



: 

: 