In [7]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import random
import hydra
from omegaconf import OmegaConf
from ai.experiment_tracking.tracker import Tracker
from ai.datasets.cifar10 import Cifar10
from ai.models import models


In [8]:
def train(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, criterion, optimizer:torch.optim.Optimizer):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Network to optimise. It is switched to training mode.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor that supports ``backward()``.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        A detached 0-dim tensor holding the mean of the per-batch losses.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    # Send batches to wherever the model's parameters live, so the inputs and
    # the weights can never end up on different devices. A parameter-less
    # model falls back to "CUDA if available".
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Enable training-time behaviour of layers such as dropout / batch-norm.
    model.train()

    loss_sum = None
    num_batches = 0
    for images, labels in tqdm(data_loader, desc='Training'):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy so the autograd graph of each batch can
        # be freed instead of being kept alive through the returned value.
        batch_loss = loss.detach()
        loss_sum = batch_loss if loss_sum is None else loss_sum + batch_loss
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train() received an empty data_loader')

    return loss_sum / num_batches

In [9]:
def evaluate(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, criterion):
    """Measure accuracy and mean loss of ``model`` over ``data_loader``.

    The model is put in evaluation mode for the duration of the call and its
    previous training/evaluation mode is restored afterwards. No gradients
    are recorded.

    Args:
        model: Network to evaluate; expected to return per-class scores with
            the class axis at dimension 1.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.

    Returns:
        A tuple ``(accuracy, mean_loss, predictions)`` where ``accuracy`` is
        the fraction of correctly classified samples (float), ``mean_loss`` is
        the per-batch loss averaged over batches (0-dim tensor), and
        ``predictions`` is a list with one tensor of predicted class indices
        per batch.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    # Send batches to wherever the model's parameters live; a parameter-less
    # model falls back to "CUDA if available".
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    correct = 0
    total = 0
    loss_val = 0
    num_batches = 0
    labels_all = []

    # Dropout / batch-norm must run in inference mode while scoring. Remember
    # the current mode so the caller's training loop is not silently changed.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(data_loader, desc='Testing'):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss_val += criterion(outputs, labels)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                labels_all.append(predicted)
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0 or total == 0:
        raise ValueError('evaluate() received an empty data_loader')

    # Batches are counted while iterating, so the loader does not need __len__.
    return correct/total, loss_val/num_batches, labels_all

In [10]:
# Compose the Hydra configuration, build dataset / model / optimizer, optionally
# restore a checkpoint, then alternate training and evaluation while logging
# the metrics of every epoch through the experiment tracker.
with hydra.initialize(version_base=None, config_path="../config"):

    # The first compose is only used to read Hydra's own settings (the run
    # directory); the second one is the experiment configuration itself.
    cfg_hydra = hydra.compose(config_name='train',return_hydra_config=True)['hydra']
    cfg = hydra.compose(config_name='train')
    print(f'Running training with: {OmegaConf.to_yaml(cfg)}')

    # Seed every random number generator used here.
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    datasets = {
        'cifar10': Cifar10
    }
    dataset = datasets['cifar10'](shape=(cfg.height, cfg.width), batch_size=cfg.batch_size, data_path=cfg.data_path)
    
    experiment_path = cfg_hydra['run']['dir']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Selected device: [{device}]')

    model = models.load(cfg.model.name, shape_in=(cfg.height, cfg.width), shape_out=len(dataset.classes),config=OmegaConf.to_object(cfg.model))
    # Put the weights on the selected device before the optimizer is created,
    # so that batches moved to `device` meet parameters on the same device.
    # NOTE(review): assumes models.load returns a torch.nn.Module — confirm.
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)  

    # Epoch index of the first epoch to run; moved forward when resuming.
    start_epoch = 0

    # An empty (or missing) checkpoint setting means "start from scratch".
    if cfg.checkpoint:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on the device selected for this run.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(cfg.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        # Continue the epoch numbering after the restored one instead of
        # logging from 0 again.
        # NOTE(review): assumes checkpoint['epoch'] is the last *completed*
        # epoch — confirm against the code that writes checkpoints.
        start_epoch = epoch + 1

    experiment = Tracker('image_classification', 
                         experiment_config=OmegaConf.to_object(cfg),
                         experiment_name=f'{cfg.experiment_prefix}_{cfg.model.name}_{cfg.dataset_name}', 
                         data_path=cfg.data_path,
                         experiment_path=experiment_path,
                         dataset=dataset,
                         model=model,
                         optimizer=optimizer,
                         use_wandb=cfg.use_wandb)
        

    # Always run cfg.num_epochs epochs; only the logged index is offset when a
    # checkpoint was restored.
    for epoch in range(start_epoch, start_epoch + cfg.num_epochs):
        
        loss_train = train(model, dataset.train_loader, criterion, optimizer)

        accuracy_val, loss_val, labels_pred = evaluate(model, dataset.test_loader, criterion)

        experiment.log(epoch, scalars={'loss_val':loss_val,'loss_train':loss_train, 'accuracy_val':accuracy_val}, labels_pred=labels_pred)

Running training with: experiment_name: train_vit_cifar
batch_size: 128
learning_rate: 0.001
num_epochs: 50
model_name: vit
dataset_name: cifar10
height: 32
width: 32
checkpoint: ''
use_wandb: true
data_path: /mnt/dataset/pytorch
seed: 666

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Selected device: [cpu]


Training:  47%|████▋     | 184/391 [12:19<13:52,  4.02s/it] 


KeyboardInterrupt: 