In [1]:
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import sys
from torch.optim import Adam
from torchvision.utils import save_image
from torch.utils.data import DataLoader, random_split
from models import CentralModel
from utils import num_to_groups
sys.path.insert(1, '../datasets/')
from samplers import sample
from losses import p_losses
from dataset_wraper import DatasetWraper
from schedulers import Scheduler
import os
import time


def evaluate(data, model, scheduler, device, loss_type="huber"):
    """Compute the mean diffusion loss of ``model`` on ``data.TestDataloader()``.

    For each test batch one random timestep per image is drawn from
    ``[0, scheduler.timesteps)`` and ``p_losses`` is evaluated. Gradients are
    disabled during the pass, and the model's previous train/eval mode is
    restored afterwards (also when an exception is raised).

    Args:
        data: dataset wrapper exposing ``TestDataloader()``; each batch is a
            dict holding a ``"pixel_values"`` tensor.
        model: ``torch.nn.Module`` to evaluate.
        scheduler: noise scheduler with a ``timesteps`` attribute; forwarded
            to ``p_losses``.
        device: device the batches and sampled timesteps are placed on.
        loss_type: loss name forwarded to ``p_losses``; defaults to ``"huber"``.

    Returns:
        The mean per-batch loss as a float, or ``nan`` if the test loader
        yields no batches.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    start_eval_time = time.time()
    try:
        # Evaluation needs no gradients; without no_grad every batch would
        # build an autograd graph and keep its activations alive.
        with torch.no_grad():
            for batch in data.TestDataloader():
                images = batch["pixel_values"].to(device)
                batch_size = images.shape[0]
                t = torch.randint(0, scheduler.timesteps, (batch_size,), device=device).long()
                loss = p_losses(model, scheduler, images, t, loss_type=loss_type)
                batch_losses.append(loss.item())
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    # An empty test loader yields NaN instead of dividing by zero.
    mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    print(f'eval loss {mean_loss:.4f}, with time {time.time() - start_eval_time:.2f}s')
    return mean_loss

def train(data, model, optimizer, scheduler, cfg, device):
    """Train ``model`` with the diffusion loss for ``cfg.training.epochs`` epochs.

    Every step draws one random timestep per image from
    ``[0, cfg.scheduler.timesteps)`` and minimises ``p_losses`` with loss type
    ``cfg.training.loss``. After each epoch the mean training loss is printed
    and, whenever the epoch index is a multiple of the matching
    ``cfg.training`` setting, the model is evaluated (``eval_every``), its
    ``state_dict`` is written to ``../weights/<experiment_name>.pkl``
    (``save_every``), and ``data.visualize`` is called (``visualize_every``).
    Epoch 0 triggers all three.

    Args:
        data: dataset wrapper exposing ``DataLoader()``, ``TestDataloader()``
            (used by ``evaluate``) and ``visualize(...)``; each batch is a dict
            holding a ``"pixel_values"`` tensor.
        model: network being trained; ``str(model)`` is part of the
            experiment name.
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per batch.
        scheduler: noise scheduler forwarded to ``p_losses``/``evaluate``.
        cfg: config with ``training``, ``optimizer`` and ``scheduler`` sections.
        device: device batches and sampled timesteps are placed on.
    """
    model.train()
    for epoch in range(cfg.training.epochs):
        start_epoch_time = time.time()
        formatted_time = time.strftime('%H-%M-%b-%d-%Y')
        experiment_name = f"{data}_{model}_{cfg.optimizer.type}_{scheduler}_{cfg.training.loss}__ep{epoch}_{formatted_time}"
        print(f'conducting experiment {experiment_name} in directory {os.getcwd()}')

        # Build the loader once and reuse it for the progress message's step
        # count, rather than calling data.DataLoader() again (which may
        # construct a fresh loader every time).
        train_loader = data.DataLoader()
        num_steps = len(train_loader)

        epoch_loss = []
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()

            images = batch["pixel_values"].to(device)
            batch_size = images.shape[0]

            # One independently sampled diffusion timestep per image.
            t = torch.randint(0, cfg.scheduler.timesteps, (batch_size,), device=device).long()

            loss = p_losses(model, scheduler, images, t, loss_type=cfg.training.loss)
            loss_value = loss.item()
            epoch_loss.append(loss_value)

            if step % 100 == 0:
                print(f'loss: {loss_value:.4f} at step {step} out of {num_steps}')

            loss.backward()
            optimizer.step()

        # An empty loader yields NaN instead of dividing by zero.
        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
        print(f'Epoch {epoch}: loss={mean_loss:.4f}, with time {time.time() - start_epoch_time:.2f}s')

        if epoch % cfg.training.eval_every == 0:
            evaluate(data, model, scheduler, device)
        if epoch % cfg.training.save_every == 0:
            # Create the output directory on first use so torch.save cannot
            # fail on a missing folder.
            os.makedirs('../weights', exist_ok=True)
            torch.save(model.state_dict(), f'../weights/{experiment_name}.pkl')
        if epoch % cfg.training.visualize_every == 0:
            data.visualize(model, scheduler, experiment_name, mean_loss)

In [7]:
# Reload a checkpoint written by train() and recreate the scheduler and dataset
# wrapper used alongside it.
# NOTE(review): train() saves ``model.state_dict()``, so this torch.load most
# likely returns a state dict (parameter name -> tensor), not a model object.
# It probably needs to be loaded into a constructed model with
# ``model.load_state_dict(...)`` -- confirm before using ``model`` as a network.
# NOTE(review): ``cfg`` is not defined in any visible cell. The captured output
# (argparse "unrecognized arguments: -f") suggests a @hydra.main entry point was
# invoked under Jupyter; Hydra's compose API (hydra.initialize + hydra.compose)
# is the usual way to obtain a config in a notebook -- TODO confirm config path/name.
model = torch.load('../weights/FasionMNIST_UNet_Adam_cosine_huber__ep0_21-29-Jan-31-2023.pkl')
# 'cosine' matches the schedule named in the checkpoint file name; timesteps=200
# presumably matches the training config -- verify against cfg.scheduler.timesteps.
scheduler = Scheduler('cosine', timesteps=200)
data = DatasetWraper(cfg)

usage: ipykernel_launcher.py [--help] [--hydra-help] [--version]
                             [--cfg {job,hydra,all}] [--resolve]
                             [--package PACKAGE] [--run] [--multirun]
                             [--shell-completion] [--config-path CONFIG_PATH]
                             [--config-name CONFIG_NAME]
                             [--config-dir CONFIG_DIR]
                             [--experimental-rerun EXPERIMENTAL_RERUN]
                             [--info [{all,config,defaults,defaults-tree,plugins,searchpath}]]
                             [overrides ...]
ipykernel_launcher.py: error: unrecognized arguments: -f


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
