Basic PyTorch implementation of EfficientNet-B3, framed as a classification task rather than regression.

In [None]:
#Imports
import os

import numpy as np
import torch
import torchmetrics
import tqdm
from efficientnet_pytorch import EfficientNet
from torch.utils.data import DataLoader

In [None]:
#Setup Model
MODEL_TITLE = "test_mk0"
# Size of the classification head. 1000 matches the ImageNet head of the
# pretrained weights; set this to the number of target classes of the dataset.
NUM_CLASSES = 1000

model = EfficientNet.from_pretrained("efficientnet-b3", in_channels=3, num_classes=NUM_CLASSES)

if torch.cuda.is_available():
    print("GPU Will Be Used")
    device = "cuda"
else:
    print("CPU Will Be Used")
    device = "cpu"

model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
# T_0=30 counts scheduler.step() calls: epochs if stepped once per epoch,
# batches if stepped once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30)
# Criterion, called through the global name `loss` by the training/validation
# functions; do not rebind this name to a loss *value* elsewhere.
loss = torch.nn.CrossEntropyLoss()
# Mixed-precision loss scaling only applies on CUDA. Enabling it only there
# avoids the "CUDA is not available. Disabling." warning on CPU (where the
# scaler would disable itself anyway).
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

In [None]:
# Data loading -- TODO: the data still has to be re-organized into the expected format.
class Dataset:
    """Work-in-progress dataset wrapper.

    Keeps ``np.array(dataset_name)`` on ``self.dataset``. It defines neither
    ``__len__`` nor ``__getitem__``, so torch's DataLoader cannot index it yet.
    """

    def __init__(self, dataset_name):
        wrapped = np.array(dataset_name)
        self.dataset = wrapped
# NOTE(review): this cell appears to be lifted from a trainer-class method.
# `self.cfg`, `self.ttfms`, `self.vtfms` and `ClassifierDataset` are not defined
# anywhere in this notebook, and `Dataset` only takes `dataset_name`, so these
# calls fail until the data-loading TODO is resolved. Batches must be mappings
# with 'image' and 'class' tensors, since that is what the train/val loops read.
dataset_train = Dataset(self.cfg, dataset_name="train_set", transform=self.ttfms)
train_loader = DataLoader(
    dataset_train,
    batch_size=self.cfg['TRAIN']['BATCH_SIZE'],
    # Reshuffle every epoch so the model does not see batches in a fixed order.
    shuffle=True,
    num_workers=self.cfg['DATASET']['NUM_WORKERS'],
)

dataset_val = ClassifierDataset(self.cfg, dataset_name="val_set", transform=self.vtfms)
val_loader = DataLoader(
    dataset_val,
    batch_size=self.cfg['TEST']['BATCH_SIZE'],
    # Validation order does not affect the metrics, so it stays unshuffled.
    num_workers=self.cfg['DATASET']['NUM_WORKERS'],
)

In [None]:
#Training Functions
def train_fn(train_loader):
    """Run one training epoch over ``train_loader``; return the mean batch loss.

    Uses the notebook-level ``model``, ``device``, ``loss`` (the criterion),
    ``optimizer``, ``scaler`` and ``scheduler``. Each batch must be a mapping
    with an ``'image'`` tensor and a ``'class'`` tensor; ``'class'`` is squeezed
    along dim 1 (presumably shape ``(batch, 1)`` -- confirm against the dataset).

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    model.train()
    loop = tqdm.tqdm(train_loader)
    num_batches = len(loop)
    if num_batches == 0:
        raise ValueError("train_loader is empty")
    total_loss = 0.0
    # autocast/GradScaler only perform mixed precision on CUDA.
    use_amp = device == "cuda"

    for batch_idx, data in enumerate(loop):
        imgs = data['image'].to(device)
        labels = data['class'].squeeze(1).to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(imgs)
            # Separate name on purpose: assigning to `loss` here would make it a
            # local and shadow the global criterion (UnboundLocalError).
            batch_loss = loss(outputs, labels)

        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # NOTE(review): stepped once per batch, so CosineAnnealingWarmRestarts'
        # T_0 counts batches rather than epochs -- confirm this is intended.
        scheduler.step()

        total_loss += batch_loss.item()
        loop.set_postfix(loss=total_loss / (batch_idx + 1))

    return total_loss / num_batches

def _macro_metrics(confusion):
    """Return ``(precision, recall, accuracy, f1)`` as Python floats.

    ``confusion[t, p]`` counts samples whose true class is ``t`` and predicted
    class is ``p``. Precision, recall and F1 are macro-averaged over the classes
    that occur in the targets or the predictions; a class with an empty
    denominator scores 0. An all-zero matrix yields all zeros.
    """
    counts = confusion.to(torch.float64)
    true_pos = counts.diagonal()
    predicted = counts.sum(dim=0)
    actual = counts.sum(dim=1)
    total = counts.sum()
    if total == 0:
        return 0.0, 0.0, 0.0, 0.0

    precision = true_pos / predicted.clamp(min=1)
    recall = true_pos / actual.clamp(min=1)
    # Per-class F1 = 2PR / (P + R) = 2TP / (predicted + actual).
    f1 = 2 * true_pos / (predicted + actual).clamp(min=1)
    # Ignore classes that never appear, otherwise a 1000-way head would be
    # dominated by absent classes scoring 0.
    present = (predicted + actual) > 0
    return (
        precision[present].mean().item(),
        recall[present].mean().item(),
        (true_pos.sum() / total).item(),
        f1[present].mean().item(),
    )


def val_fn(val_loader):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Returns ``(precision, recall, accuracy, f1, mean_loss)`` as floats. The
    classification metrics are computed once from a confusion matrix
    accumulated over the whole validation set (precision/recall/F1 are
    macro-averaged); ``mean_loss`` is the mean of the per-batch criterion values.
    Each batch must be a mapping with ``'image'`` and ``'class'`` tensors.

    Raises:
        ValueError: if ``val_loader`` has no batches.
    """
    # Metrics are computed in plain torch: the previous torchmetrics calls used
    # the removed `f1` function and the pre-`task=` signature, whose default
    # micro averaging made precision, recall and F1 all equal to accuracy.
    model.eval()
    loop = tqdm.tqdm(val_loader)
    num_batches = len(loop)
    if num_batches == 0:
        raise ValueError("val_loader is empty")
    total_loss = 0.0
    confusion = None  # allocated on the first batch, once the class count is known

    with torch.no_grad():
        for batch_idx, data in enumerate(loop):
            imgs = data['image'].to(device)
            labels = data['class'].squeeze(1).to(device)

            # forward (separate name so the global criterion `loss` is not shadowed)
            outputs = model(imgs)
            batch_loss = loss(outputs, labels)
            total_loss += batch_loss.item()

            # accumulate the confusion matrix (rows: true class, cols: predicted)
            num_classes = outputs.shape[-1]
            if confusion is None:
                confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
            preds = outputs.argmax(dim=-1).cpu()
            targets = labels.cpu().long()
            confusion += torch.bincount(
                targets * num_classes + preds, minlength=num_classes * num_classes
            ).view(num_classes, num_classes)

            # update tqdm loop
            loop.set_postfix(loss=total_loss / (batch_idx + 1))

    pr, re, ac, f1 = _macro_metrics(confusion)
    return pr, re, ac, f1, total_loss / num_batches

def save_checkpoint(model, f1, loss, model_name):
    """Log the validation metrics, then ``torch.save`` ``model`` to ``model_name``.

    ``model`` is whatever object should be persisted (the training loop passes a
    checkpoint dict); ``f1`` and ``loss`` only appear in the log line.
    """
    message = f"=> Saving checkpoint with validation f1-score: {f1}, and validation loss: {loss}"
    print(message)
    torch.save(model, model_name)

In [None]:
#Training Loop
history = {'loss': [], 'val_loss': [], 'val_f1': [], 'val_pr': [], 'val_acc': [], 'val_re': []}
comp = np.inf  # best (lowest) validation loss seen so far
max_epoch = 1000
checkpoint_path = os.path.join(os.getcwd(), 'model_output.pt')

for epoch in range(max_epoch):
    # `train_loss`, not `loss`: the global `loss` is the criterion the training
    # and validation functions call; rebinding it to a float breaks epoch 2.
    train_loss = train_fn(train_loader)
    pr, re, ac, f1, val_loss = val_fn(val_loader)

    history['loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_pr'].append(pr)
    history['val_re'].append(re)
    history['val_acc'].append(ac)
    history['val_f1'].append(f1)

    if val_loss < comp:
        comp = val_loss
        # Include everything needed to resume training, not just the weights.
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "history": history,
        }
        save_checkpoint(checkpoint, f1, val_loss, checkpoint_path)
    else:
        print(f"=> The validation loss did not improve from: {comp}")
