In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed so that edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import importlib
import transform
import train
import vgg

In [8]:
# Smoke-test ImageTransformNet: push a constant dummy batch through the
# network and print the shape of what comes out.

# Prefer Apple's MPS backend when it is available, but fall back to the CPU so
# the cell also runs on machines without it.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
importlib.reload(transform)
# Keep the model on the same device as its input; a device mismatch between
# weights and input makes the forward pass fail.
# NOTE(review): assumes ImageTransformNet is an nn.Module (it is called like one) — confirm.
model = transform.ImageTransformNet().to(device)
x = torch.ones((10, 3, 256, 256)).to(device)
y = model(x)
print(y.shape)

torch.Size([10, 32, 256, 256])
torch.Size([10, 64, 128, 128])
torch.Size([10, 128, 64, 64])
torch.Size([10, 128, 64, 64])
torch.Size([10, 128, 64, 64])
torch.Size([10, 128, 64, 64])
torch.Size([10, 64, 128, 128])
torch.Size([10, 32, 256, 256])
torch.Size([10, 3, 256, 256])
torch.Size([10, 3, 256, 256])


In [4]:
# Shape check for train.gram: reload the module, feed it a batch of ones and
# print the shape of the result.
importlib.reload(train)
x = torch.ones(10, 3, 256, 256)
y = train.gram(x)
print(y.size())

torch.Size([10, 3, 65536])
torch.Size([10, 65536, 3])
torch.Size([10, 3, 3])
torch.Size([10, 3, 3])


In [18]:
# Shape check for the local vgg16 wrapper: reload the module, run a batch of
# ones through it and print the shape of each of the four returned outputs.
importlib.reload(vgg)
model = vgg.vgg16()
x = torch.ones(10, 3, 256, 256)
y1, y2, y3, y4 = model(x)
print(*(feature_map.shape for feature_map in (y1, y2, y3, y4)))



torch.Size([10, 64, 256, 256]) torch.Size([10, 128, 128, 128]) torch.Size([10, 256, 64, 64]) torch.Size([10, 512, 32, 32])


In [None]:

def _render_grid_image(batch):
    """Render a batch of image tensors as one PIL image grid, entirely in memory."""
    buffer = BytesIO()
    tvutils.save_image(batch, buffer, 'png')
    return Image.open(buffer)


def train_scorenet(_):
    """Train a ScoreNet on zero-padded MNIST, logging losses and sample grids.

    Builds the backbone selected by ``FLAGS.model_type``, wraps it in
    ``ScoreNet`` and minimizes ``scorenet.get_loss`` with Adam. Every
    ``FLAGS.log_every`` iterations the mean loss since the previous report is
    sent to the summary writer and to ``logger``. Every ``FLAGS.sample_every``
    iterations a grid of generated samples is saved under ``FLAGS.output_dir``
    and written to the summary writer together with a grid built from the
    current training batch.

    NOTE(review): this function uses names (``FLAGS``, ``setup_logging``,
    ``SummaryWriter``, ``UNet``, ``ScoreNet``, ``BytesIO``, ``tvutils``,
    ``Image``, ...) that none of the visible notebook cells import or define —
    confirm they are provided elsewhere before running it.

    Args:
        _: Unused positional argument (presumably the argv list handed over by
            an ``app.run``-style launcher — TODO confirm).

    Raises:
        ValueError: If ``FLAGS.model_type`` is neither ``"unet"`` nor
            ``"simple_fc"``.
    """
    setup_logging()
    torch.set_num_threads(4)
    torch.manual_seed(FLAGS.seed)

    # Validate the model type up front: without the else-branch an unknown
    # value would only surface later as an UnboundLocalError on `net`.
    if FLAGS.model_type == "unet":
        net = UNet()
    elif FLAGS.model_type == "simple_fc":
        net = torch.nn.Sequential(
            SimpleEncoder(input_size=1024, hidden_size=128, latent_size=16),
            SimpleDecoder(latent_size=16, hidden_size=128, output_size=1024))
    else:
        raise ValueError(
            f'Unknown model_type {FLAGS.model_type!r}; '
            'expected "unet" or "simple_fc"')

    writer = SummaryWriter(FLAGS.output_dir, max_queue=1000, flush_secs=120)
    try:
        scorenet = ScoreNet(net, FLAGS.sigma_begin, FLAGS.sigma_end,
                            FLAGS.noise_level, FLAGS.sigma_type)
        logging.info(f'Number of parameters in ScoreNet: {count_parameters(scorenet)}')
        scorenet.train()

        # Pad(2) adds a 2-pixel border on every side, so 28x28 MNIST digits
        # become 32x32 (1024 values once flattened below).
        mnist_transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
        dataset = datasets.MNIST(FLAGS.mnist_data_dir, train=True, download=True,
                                 transform=mnist_transform)
        dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=True)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        scorenet.to(device)
        optimizer = optim.Adam(scorenet.parameters(), lr=FLAGS.lr)
        iterations = 0

        # Losses accumulated since the last report; reset after each report.
        train_loss = []
        for epoch in range(1, FLAGS.num_epochs + 1):
            for data, _labels in dataloader:
                # Flatten each image to a vector before handing it to the model.
                data = data.reshape(data.shape[0], -1)
                data = data.to(device)
                optimizer.zero_grad()
                loss = scorenet.get_loss(data)
                loss.backward()
                optimizer.step()

                train_loss.append(loss.item())
                iterations += 1

                if iterations % FLAGS.log_every == 0:
                    mean_loss = np.mean(train_loss)
                    writer.add_scalar('loss', mean_loss, iterations)
                    logger('loss', mean_loss, iterations)
                    train_loss = []

                if iterations % FLAGS.sample_every == 0:
                    scorenet.eval()
                    with torch.no_grad():
                        # [-1, -1] presumably selects the final step of the last
                        # noise level from the sampler's trajectory — TODO confirm.
                        X_gen = scorenet.sample(64, 1024, step_lr=FLAGS.step_lr)[-1, -1].view(-1, 1, 32, 32)

                        samples_image = _render_grid_image(X_gen)
                        file_name = f'{FLAGS.output_dir}/samples_{iterations:08d}.png'
                        samples_image.save(file_name)
                        # add_image is fed channel-first data here, so move the
                        # channel axis of the HWC array to the front.
                        writer.add_image('samples', np.transpose(np.array(samples_image), [2, 0, 1]), iterations)

                        # Reference grid built from (at most) the first 64
                        # images of the current training batch.
                        X_gt = data.view(-1, 1, 32, 32)[:64]
                        gt_image = _render_grid_image(X_gt)
                        writer.add_image('gt', np.transpose(np.array(gt_image), [2, 0, 1]), iterations)
                    scorenet.train()
    finally:
        # Flush queued events and release the event file even when training
        # stops early or raises.
        writer.close()