In [2]:
import random
import argparse
import os
import torch
import torchvision
from tqdm import tqdm
import numpy as np
import torchvision.utils as vutils
from torch.autograd import Variable
from nets import VanillaNet, NonlocalNet
import torchvision.transforms as tr
import torchvision.datasets as datasets
import torch.optim as optim
import json

def requires_grad(parameters, flag=True):
    for p in parameters:
        p.requires_grad = flag

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

'''
load state dict for ebm0 and ebm1
'''
EBM = VanillaNet(n_c = 3, n_f = 48).to(device)
state_EBM = torch.load('./output_dir/ebm_0.pth', map_location=device)
EBM.load_state_dict(state_EBM)

D_new = VanillaNet(n_c = 3, n_f = 48).to(device)
state_EBM2 = torch.load('./output_dir/ebm_1.pth', map_location=device)
D_new.load_state_dict(state_EBM2)

'''
load config file
'''
CONFIG_FILE = './config_locker/cifar10_nonconvergent.json'

with open(CONFIG_FILE) as file:
    config = json.load(file)

transform = tr.Compose([tr.Resize(config['im_sz']),
                        tr.CenterCrop(config['im_sz']),
                        tr.ToTensor(),
                        tr.Normalize(tuple(0.5*torch.ones(config['im_ch'])), tuple(0.5*torch.ones(config['im_ch'])))])

dataset_fmnist_train = datasets.CIFAR10(root='./data/cifar', train=True, download=True, transform=transform)
    
loader = torch.utils.data.DataLoader(dataset_fmnist_train, batch_size=128, shuffle=True, drop_last = True, num_workers=int(8))

Files already downloaded and verified


In [3]:
def sample_from_single_ebm(bs, nc, model, device, step_size=0.01, sample_step=200):
    #inital with uniform(0,1)
    noise = torch.randn(bs, nc, 32,32, device=device)
    neg_sample = Variable(2*torch.rand(bs, nc, 32,32)-1).to(device)
    neg_sample.requires_grad = True
    parameters = model.parameters()

    requires_grad(parameters, False)
    model.eval()

    for k in range(sample_step):
        noise.normal_(0, 1)
        neg_sample.data.add_(step_size, noise.data)
        
        dvalue = model(neg_sample)
        dvalue.sum().backward()

        neg_sample.data.add_(-1, neg_sample.grad.data)

        neg_sample.grad.detach_()
        neg_sample.grad.zero_()
        neg_sample.data.clamp_(-1, 1)
            
    return neg_sample.detach()

def sample_from_stack_2_ebm(model,D_new,bs,device,step_size,sample_step):
        neg_sample = sample_from_single_ebm(bs, 3, model, device, step_size=0.01, sample_step=80)
        neg_sample = Variable(neg_sample).to(device)
        neg_sample.requires_grad = True
        requires_grad(D_new.parameters(), False)
        requires_grad(model.parameters(), False)
        model.eval()
        D_new.eval()
        
        for k in range(sample_step):
            noise = torch.randn(neg_sample.shape[0], 3, 32, 32, device=device)

            noise.normal_(0, 1)
            neg_sample.data.add_(step_size, noise.data)

            dvalue = model(neg_sample) + D_new(neg_sample)
            dvalue.sum().backward()
            #neg_sample.grad.data.clamp_(-0.01, 0.01)
            
            neg_sample.data.add_(-1, neg_sample.grad.data)

            neg_sample.grad.detach_()
            neg_sample.grad.zero_()
            neg_sample.data.clamp_(-1, 1)

        neg_sample = neg_sample.detach()
        return neg_sample

In [3]:
#neg_img = sample_stack(EBM,D_new,device,step_size = 0.01, sample_step = 100)
for i in range(20):        
    bs = 500
    rec2 = sample_stack(EBM, D_new, bs, device, step_size=0.01, sample_step=70)
    with torch.no_grad():
        nc,h,w = rec2.shape[1:] 
        for j in range(rec2.size(0)):
               torchvision.utils.save_image(rec2.view(bs,nc,h,w)[j, :, :, :],
                                            ('./save/sample/stack_ebm_test/{}.png').format(j+i*bs),normalize=True)
        print(i)

0
1
2


KeyboardInterrupt: 

In [None]:
python fid_score.py ../code/samples/cifar/from_dataset ../opt_agg_code/ebm-anatomy-master/save/sample/stack_ebm_test/ --gpu 0

In [36]:
def sample_tri(model1,model2,D_new,bs,device,step_size,sample_step):
    neg_sample = sample_stack(model1,model2,bs, device, step_size=0.01, sample_step=0)

    neg_sample = Variable(neg_sample).to(device)
    neg_sample.requires_grad = True

    requires_grad(D_new.parameters(), False)
    requires_grad(model1.parameters(), False)
    requires_grad(model2.parameters(), False)
    model1.eval()
    model2.eval()
    D_new.eval()

    for k in range(sample_step):
        noise = torch.randn(neg_sample.shape[0], 3, 32, 32, device=device)

        noise.normal_(0, 1)
        neg_sample.data.add_(step_size, noise.data)

        dvalue = model1(neg_sample) + model2(neg_sample) + D_new(neg_sample)
        dvalue.sum().backward()
        #neg_sample.grad.data.clamp_(-0.01, 0.01)

        neg_sample.data.add_(-1, neg_sample.grad.data)

        neg_sample.grad.detach_()
        neg_sample.grad.zero_()
        neg_sample.data.clamp_(-1, 1)

    neg_sample = neg_sample.detach()
    return neg_sample

In [8]:
def sample_data(loader):
    loader_iter = iter(loader)

    while True:
        try:
            yield next(loader_iter)

        except StopIteration:
            loader_iter = iter(loader)

            yield next(loader_iter)

def clip_grad(parameters, optimizer):
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]

                if 'step' not in state or state['step'] < 1:
                    continue

                step = state['step']
                exp_avg_sq = state['exp_avg_sq']
                _, beta2 = group['betas']

                bound = 3 * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1
                p.grad.data.copy_(torch.max(torch.min(p.grad.data, bound), -bound))
            
def train_tripple(model1,model2, loader,config, device, step_size=0.1, sample_step=100):
    D_new = VanillaNet(n_c = 3, n_f = 48).to(device)
    
    loader = tqdm(enumerate(sample_data(loader)))

    noise = torch.randn(128, 3, 32, 32, device=device)

    parameters = D_new.parameters()
    optimizer = optim.Adam(parameters, lr=0.00005,betas = (0.0,0.999))
    
    for i, (image) in loader:

        image = image[0]
        image = image.to(device)

        neg_sample = sample_from_stack_2_ebm(model1,model2,128, device, step_size=0.01, sample_step=0)
       
        neg_sample = Variable(neg_sample).to(device)
        neg_sample.requires_grad = True
        
        requires_grad(parameters, False)
        requires_grad(model1.parameters(), False)
        requires_grad(model2.parameters(), False)
        model1.eval()
        model2.eval()
        D_new.eval()
        
        for k in range(sample_step):
            if noise.shape[0] != neg_sample.shape[0]:
                noise = torch.randn(neg_sample.shape[0], 3, 32, 32, device=device)

            noise.normal_(0, 1)
            neg_sample.data.add_(step_size, noise.data)

            dvalue = model1(neg_sample) + model2(neg_sample) + D_new(neg_sample)
            dvalue.sum().backward()
            #neg_sample.grad.data.clamp_(-0.01, 0.01)
            
            neg_sample.data.add_(-1, neg_sample.grad.data)

            neg_sample.grad.detach_()
            neg_sample.grad.zero_()
            neg_sample.data.clamp_(-1, 1)

        neg_sample = neg_sample.detach()

        requires_grad(parameters, True)
        D_new.train()

        D_new.zero_grad()
        pos_out = D_new(image)
        neg_out = D_new(neg_sample)

        loss = pos_out - neg_out
        loss = loss.mean()

        loss.backward()

        clip_grad(parameters, optimizer)

        optimizer.step()

        loader.set_description(f'loss: {loss.item():.5f}')
                
        if i % 2000 == 0:      
            neg_img = neg_sample
            vutils.save_image(
                neg_img.detach().to('cpu'),
                './output_dir/stack3/samples/sample_iter_{}.png'.format(i),
                nrow=16,
                normalize=True
            )
        if i % 5000 == 0:
            torch.save(D_new.state_dict(), './output_dir/stack3/checkpoint/EBM_iter_{}.pth'.format(i))
            torch.save(optimizer.state_dict(), './output_dir/stack3/checkpoint/opt.pth')
        if i == 60000:
            break

In [None]:
train_tripple(EBM,D_new, loader,config, device=device, step_size=0.01, sample_step=80)

loss: 0.01524: : 23it [01:50,  4.93s/it] 

In [44]:
D_tri = VanillaNet(n_c = 3).to(device)
state_EBM3 = torch.load('./save/model/single_ebm/EBM_15001.pth', map_location=device)
D_tri.load_state_dict(state_EBM3)

<All keys matched successfully>