In [None]:
# Fetch the EDM repository into ./edm.
!git clone https://github.com/NVlabs/edm
# Put the fid.py shipped with the input dataset into the checkout
# (replacing any fid.py that is already there).
!cp ../input/diffusion-2/fid.py edm/fid.py
# Import from inside the checkout - presumably so that `fid` and `dnnlib`
# resolve as top-level modules - then return to the original directory.
%cd edm
from fid import calculate_inception_stats, calculate_fid_from_inception_stats
from dnnlib.util import open_url
%cd ..

def calc_fid(image_path, ref_path, num_expected, batch):
    """Compute the FID between an image set and a reference set.

    Inception statistics are computed for both locations with
    ``calculate_inception_stats`` and then combined with
    ``calculate_fid_from_inception_stats`` (both imported from ``fid``).

    Args:
        image_path: Location of the images under evaluation; forwarded as
            ``image_path`` to ``calculate_inception_stats``.
        ref_path: Location of the reference images; forwarded the same way.
        num_expected: Forwarded as ``num_expected`` for both sets.
        batch: Forwarded as ``max_batch_size`` for both sets.

    Returns:
        The value returned by ``calculate_fid_from_inception_stats`` for the
        two (mu, sigma) pairs.
    """
    mu, sigma = calculate_inception_stats(
        image_path=image_path, num_expected=num_expected, max_batch_size=batch
    )
    mu_ref, sigma_ref = calculate_inception_stats(
        image_path=ref_path, num_expected=num_expected, max_batch_size=batch
    )
    # The intermediate statistics are intentionally not printed: dumping them
    # was leftover debugging that only flooded the notebook output.
    return calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref)

In [48]:
def plot_scheduling():
    """Draw the triangular beta schedule on [0, 1], save it as a PDF and show it.

    beta_t equals ``2.4e-4 * t`` for ``t <= 0.5`` and ``2.4e-4 * (1 - t)``
    afterwards, sampled on 200 evenly spaced points. The figure is written to
    "beta-scheduling.pdf".
    """
    time_grid = np.linspace(0, 1, 200)
    rising_branch = 2.4e-4 * time_grid
    falling_branch = 2.4e-4 * (1 - time_grid)
    beta_values = np.where(time_grid <= 0.5, rising_branch, falling_branch)

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(time_grid, beta_values, color='blue')

    ax.set_xlabel("time $t$")
    ax.set_title("$\\beta_t$")
    # Scientific notation on the y axis; the offset (exponent) text is repositioned.
    ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0), useMathText=True)
    ax.yaxis.get_offset_text().set_position((-0.1, 0))

    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0.6 * 1e-4, 1.2 * 1e-4])

    ax.grid(alpha=0.5)
    fig.tight_layout()
    fig.savefig("beta-scheduling.pdf", bbox_inches='tight')
    plt.show()

def plot_sigmas(): # TODO: unfinished placeholder
    """Unfinished placeholder; currently a near-copy of ``plot_scheduling``.

    Despite its name, the body builds a ``beta_t`` array, titles the figure
    with beta_t and writes "beta-scheduling.pdf" - the same file name that
    ``plot_scheduling`` writes, so running both overwrites that file.

    NOTE(review): ``np.piecewise`` receives a single condition (the literal
    ``True``) together with two functions. A lone boolean probably cannot be
    used to index the 200-element ``t``, so this call likely raises on current
    NumPy - confirm, and replace it with the intended sigma computation.
    """
    t = np.linspace(0, 1, 200) 
    # NOTE(review): one scalar condition, two functions - see docstring.
    beta_t = np.piecewise(t, [True], [lambda t: 2.4e-4 * t, lambda t: 2.4e-4 * (1 - t)]) 
     
    plt.figure(figsize=(4, 3)) 
    plt.plot(t, beta_t, color='blue') 
     
    plt.xlabel("time $t$") 
    # Title still says beta_t; presumably to be changed once sigmas are plotted.
    plt.title("$\\beta_t$") 
    # Scientific notation on the y axis; the offset (exponent) text is repositioned.
    plt.ticklabel_format(axis='y', style='sci', scilimits=(0, 0), useMathText=True) 
    plt.gca().yaxis.get_offset_text().set_position((-0.1, 0)) 

    plt.xticks([0, 0.5, 1]) 
    plt.yticks([0.6 * 1e-4, 1.2 * 1e-4])
     
    plt.grid(alpha=0.5) 
    plt.tight_layout() 
    # Same output file as plot_scheduling (and without bbox_inches='tight').
    plt.savefig("beta-scheduling.pdf")
    plt.show()

def plot_pixel_1d(model, batch, channel, x, y, ylabel='Blue channel',
                  save_path="pixel-trajectories-1d-blur-mnist.pdf"):
    """Plot the sampling trajectory of one pixel/channel for every batch item.

    ``model.sampling`` is run on ``batch`` over the steps
    ``np.arange(0, model.num_steps, 1)``. The returned trajectory is stacked
    and, for every item, the value at ``[channel, x, y]`` is drawn against the
    trajectory index. The figure is saved to ``save_path`` and shown.

    Args:
        model: Object providing ``device``, ``num_steps`` and ``sampling``.
        batch: Tensor handed to ``model.sampling`` after ``.to(model.device)``.
        channel: Index into the channel axis of each trajectory entry.
        x, y: Indices into the two trailing axes of each trajectory entry.
        ylabel: Text for the y axis. The default keeps the previous fixed
            "Blue channel" wording; pass a matching label when ``channel``
            does not refer to blue.
        save_path: File the figure is written to.
    """
    _, trajectory = model.sampling(batch.to(model.device), np.arange(0, model.num_steps, 1))
    # Result is (trajectory length, batch) - assumes every trajectory entry is
    # laid out as (batch, channel, ., .); TODO confirm against model.sampling.
    # detach()/numpy() so matplotlib never receives a tensor that requires grad.
    pixel_trajectories = torch.stack(trajectory)[:, :, channel, x, y].detach().cpu().numpy()
    colors = plt.cm.viridis(np.linspace(0, 1, pixel_trajectories.shape[1]))

    # One line per batch item.
    plt.figure(figsize=(10, 6))
    for obj in range(pixel_trajectories.shape[1]):
        plt.plot(pixel_trajectories[:, obj], color=colors[obj], linewidth=2, linestyle='-')

    plt.xlabel('Time', fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title('Trajectories of pixels by 1 channel', fontsize=16, fontweight='bold')
    plt.grid(alpha=0.5)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()

def plot_pixel_2d(model, batch, channel1, channel2, x, y,
                  xlabel='Green channel', ylabel='Red channel',
                  save_path="pixel-trajectories-2d-blur-mnist.pdf"):
    """Plot, in 3D, how two channels of one pixel evolve during sampling.

    ``model.sampling`` is run on ``batch`` over the steps
    ``np.arange(0, model.num_steps, 1)``. For every batch item the value at
    ``[channel1, x, y]`` goes on the x axis, ``[channel2, x, y]`` on the y axis
    and the trajectory index on the z axis. The figure is saved to
    ``save_path`` and shown.

    Args:
        model: Object providing ``device``, ``num_steps`` and ``sampling``.
        batch: Tensor handed to ``model.sampling`` after ``.to(model.device)``.
        channel1: Channel index plotted along the x axis.
        channel2: Channel index plotted along the y axis.
        x, y: Indices into the two trailing axes of each trajectory entry.
        xlabel: Label describing ``channel1``.
        ylabel: Label describing ``channel2``.
        save_path: File the figure is written to.

    NOTE(review): the label defaults keep the previous fixed wording. If the
    data is RGB-ordered and this is called with ``channel1=0, channel2=1``,
    these defaults would be swapped - confirm the channel order and pass
    explicit labels if needed.
    """
    _, trajectory = model.sampling(batch.to(model.device), np.arange(0, model.num_steps, 1))
    # Stack the trajectory once and reuse it for both channels.
    stacked = torch.stack(trajectory)
    pixel_trajectories1 = stacked[:, :, channel1, x, y].detach().cpu().numpy()
    pixel_trajectories2 = stacked[:, :, channel2, x, y].detach().cpu().numpy()
    # The time axis follows the actual trajectory length, so the plot still
    # works if sampling returns a different number of entries than num_steps.
    time_axis = np.arange(pixel_trajectories1.shape[0])
    colors = plt.cm.viridis(np.linspace(0, 1, pixel_trajectories1.shape[1]))

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111, projection='3d')

    # One curve per batch item.
    for obj in range(pixel_trajectories1.shape[1]):
        ax.plot(pixel_trajectories1[:, obj], pixel_trajectories2[:, obj], time_axis,
                color=colors[obj], linewidth=2, linestyle='-')

    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_zlabel('Time', fontsize=14, rotation=90)
    ax.set_title('Trajectories of pixels by 2 channels', fontsize=16, fontweight='bold')
    ax.set_box_aspect(None, zoom=0.93)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()

def calc_l2(model, dataloader, num_batches, steps, stochastic, visualize=True):
    """Average the mean squared reconstruction error over a few batches.

    For each batch the first element is rescaled with ``* 2 - 1``, corrupted
    with ``model.corruption`` and restored with ``model.sampling(..., steps)``.
    The per-batch error is ``mean((x0 - samples) ** 2)``, and the function
    returns the average of these values over the processed batches.

    Args:
        model: Object providing ``device``, ``corruption``, ``sampling`` and a
            writable ``stochastic`` attribute.
        dataloader: Iterable of batches; only ``batch[0]`` is used.
        num_batches: Iteration stops once this many batches were processed
            (at least one batch is always processed if available).
        steps: Passed unchanged as the second argument of ``model.sampling``.
        stochastic: Value assigned to ``model.stochastic`` before sampling.
            The attribute is left at this value when the function returns.
        visualize: When true (default), the first ten originals, corrupted
            inputs and samples of every batch are shown via ``visualize_batch``.

    Returns:
        The averaged error (a tensor, detached from any autograd graph).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.stochastic = stochastic

    total_error = 0
    count = 0
    for batch in dataloader:
        count += 1
        # Presumably maps [0, 1] images to [-1, 1]; the inverse is applied for display.
        x0 = (batch[0] * 2 - 1).to(model.device)
        corrupted = model.corruption(x0).to(model.device)
        samples, _ = model.sampling(corrupted, steps)
        # detach(): summing across batches must not keep autograd graphs alive.
        total_error += torch.mean(torch.square(x0 - samples)).detach()

        if visualize:
            visualize_batch(batch[0][0:10])
            visualize_batch(corrupted[0:10].cpu() * 0.5 + 0.5)
            visualize_batch(samples[0:10].cpu() * 0.5 + 0.5)

        if count >= num_batches:
            break

    if count == 0:
        raise ValueError("dataloader yielded no batches; cannot average the L2 error")
    return total_error / count
        

def plot_l2(model, errors, steps):
    """Plot two L2-error curves against ``steps``, save the figure and show it.

    Args:
        model: Not used by this function; kept in the signature for callers.
        errors: Indexable with two entries; ``errors[0]`` is drawn as the
            non-stochastic curve and ``errors[1]`` as the stochastic one.
        steps: x values shared by both curves.

    The figure is written to "l2-error-blur-mnist.pdf".
    """
    # (values, legend label, marker, line style) for each curve, in draw order.
    curve_specs = (
        (errors[0], "Non-stochastic sampling", "o", "-"),
        (errors[1], "Stochastic sampling", "s", "--"),
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    for values, label, marker, linestyle in curve_specs:
        ax.plot(steps, values, label=label, marker=marker, linestyle=linestyle,
                linewidth=1.5, markersize=4)

    # Titles, axis labels, grid and legend.
    ax.set_title("L2 Errors vs Iterations", fontsize=16)
    ax.set_xlabel("Number of Iterations", fontsize=14)
    ax.set_ylabel("L2 Error", fontsize=14)
    ax.grid(True, linestyle="--", alpha=0.6)
    ax.legend(fontsize=12)

    fig.tight_layout()
    fig.savefig("l2-error-blur-mnist.pdf", bbox_inches='tight')
    plt.show()

In [31]:
# Settings for the L2-error experiment.
# Number of batches to evaluate - presumably passed to calc_l2 as num_batches.
num_batches = 4
# Step settings to compare (x axis of plot_l2).
steps = [2, 5, 10, 15, 20, 30, 40, 50, 60, 70, 100]
# One row per sampling mode (plot_l2 draws row 0 as non-stochastic and row 1
# as stochastic), one column per entry of `steps`.
errors = np.zeros((2, len(steps)))

In [None]:
# Disabled calls kept from earlier experiments.
#print(errors)
#plot_l2(model, errors, steps)
#plot_pixel_2d(model, model.corruption(batch[0][:40] * 2 - 1), 0, 1, 30, 30)
#plot_pixel_1d(model, model.corruption(batch[0][:40] * 2 - 1), 2, 30, 30)
# NOTE(review): `batch` is not assigned in this cell (the line below is
# commented out), so it must already exist in the kernel from another cell.
#batch = next(iter(train_dataloader))
# Show the first ten items of the batch as they are.
visualize_batch(batch[0][:10], title='target-images')
# NOTE(review): here the corruption is applied to the raw batch and displayed
# without the `* 0.5 + 0.5` rescale, whereas the sampling below corrupts
# `batch[0] * 2 - 1` - confirm this difference is intended.
visualize_batch(model.corruption(batch[0][:10]), title='blurred-images')
# Corrupt the rescaled batch, keep the first ten items and run sampling over
# the steps 0 .. model.num_steps - 1; only the final samples are kept.
sampled,_ = model.sampling(model.corruption(batch[0] * 2 - 1)[:10].to(model.device), np.arange(0, model.num_steps, 1))
# Undo the `* 2 - 1` rescale for display.
visualize_batch(sampled.cpu() * 0.5 + 0.5, title='sampled-images')
#for i in range(10):
# Take the corrupted item at index 8 and add a leading axis of size 1.
# Judging by the title below, index 8 is presumably a digit "2" in this batch.
smpl = model.corruption(batch[0] * 2 - 1)[8].to(model.device).unsqueeze(0)
# Debug output of the shapes before and after repeating.
print(smpl.shape)
print(smpl.repeat(10, 1, 1, 1).shape)
# Sample ten times from the same corrupted input - presumably to show the
# variation between repeated stochastic samplings.
sampled,_ = model.sampling(smpl.repeat(10, 1, 1, 1), np.arange(0, model.num_steps, 1))
visualize_batch(sampled.cpu() * 0.5 + 0.5, title='stochastic-generation-twos')
#plot_scheduling()

In [33]:
import os, shutil

In [None]:
def save_model_samples(name, model, test_dataloader, apply_gen=True, **model_kwargs):
    """Write one PNG per image of ``test_dataloader`` into a fresh directory.

    Any existing ``name`` directory is removed and recreated. Each batch's
    first element is rescaled with ``* 2 - 1`` and moved to ``model.device``.
    With ``apply_gen`` it is corrupted by ``model.corruption`` and restored by
    ``model.sampling`` over the steps ``np.arange(0, model.num_steps, 1)``;
    otherwise the rescaled input itself is saved. Files are numbered
    consecutively with at least six digits ("000000.png", "000001.png", ...).

    Args:
        name: Output directory.
        model: Object providing ``device``, ``num_steps``, ``corruption`` and
            ``sampling``.
        test_dataloader: Sized iterable of batches; only ``batch[0]`` is used.
        apply_gen: Whether to save model samples (True) or the inputs (False).
        **model_kwargs: Accepted but not used.
    """
    # Start from an empty output directory.
    if os.path.exists(name):
        shutil.rmtree(name)
    os.makedirs(name, exist_ok=True)

    image_index = 0
    with tqdm(total=len(test_dataloader)) as progress:
        for batches_done, batch in enumerate(test_dataloader, start=1):
            x0 = (batch[0] * 2 - 1).to(model.device)
            if apply_gen:
                result, _ = model.sampling(model.corruption(x0), np.arange(0, model.num_steps, 1))
            else:
                result = x0

            # Scale by 127.5, shift by 128, clip to [0, 255], cast to uint8 and
            # move axis 1 last - presumably (N, C, H, W) -> (N, H, W, C).
            pixels = (result * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

            for image_array in pixels:
                file_name = str(image_index).zfill(6) + '.png'
                Image.fromarray(image_array).save(os.path.join(name, file_name))
                image_index += 1

            progress.update(1)
            progress.set_description('%d batches saved' % (batches_done,))

#print(len(test_dataloader) * 128)
#save_model_samples('mnist-samples', model, test_dataloader)
#save_model_samples('mnist-train-data', model, test_dataloader, False)
#fid = calc_fid('mnist-samples', "../input/diffusion-2/cmnist_train.npz", num_expected=10000, batch=256)
#print('Модель имеет FID = %.4g' % (fid,))

In [None]:
def save_model_samples_nonstoch(name, model, test_dataloader, apply_gen=True, steps=None, **model_kwargs):
    """Write one PNG per image into ``name`` using non-stochastic sampling.

    Any existing ``name`` directory is removed and recreated. While the export
    runs, ``model.stochastic`` is set to False; afterwards the previous value
    is restored (also when an exception occurs), so the call no longer leaves
    the model switched to non-stochastic mode. Each batch's first element is
    rescaled with ``* 2 - 1``. With ``apply_gen`` it is corrupted by
    ``model.corruption`` and restored by ``model.sampling(x1, steps)``;
    otherwise the rescaled input itself is saved. Files are numbered
    consecutively with at least six digits.

    Args:
        name: Output directory.
        model: Object providing ``device``, ``corruption``, ``sampling`` and a
            ``stochastic`` attribute.
        test_dataloader: Sized iterable of batches; only ``batch[0]`` is used.
        apply_gen: Whether to save model samples (True) or the inputs (False).
        steps: Step list passed to ``model.sampling``. Defaults to ``[0, 99]``,
            the previously hard-coded value - presumably the first and last
            step of a 100-step model; TODO confirm against ``model.num_steps``.
        **model_kwargs: Accepted but not used.
    """
    if steps is None:
        steps = [0, 99]

    # Start from an empty output directory.
    if os.path.exists(name):
        shutil.rmtree(name)
    os.makedirs(name, exist_ok=True)

    # Remember the current sampling mode so it can be put back afterwards.
    had_flag = hasattr(model, 'stochastic')
    previous_flag = getattr(model, 'stochastic', None)
    model.stochastic = False
    try:
        count = 0
        with tqdm(total=len(test_dataloader)) as pbar:
            for count_batches, batch in enumerate(test_dataloader, start=1):
                x0 = (batch[0] * 2 - 1).to(model.device)
                if apply_gen:
                    x1 = model.corruption(x0)
                    out, _ = model.sampling(x1, steps)
                else:
                    out = x0

                # Scale by 127.5, shift by 128, clip to [0, 255], cast to uint8
                # and move axis 1 last - presumably (N, C, H, W) -> (N, H, W, C).
                out = (out * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

                for image_array in out:
                    img_name = str(count).zfill(6) + '.png'
                    Image.fromarray(image_array).save(os.path.join(name, img_name))
                    count += 1

                pbar.update(1)
                pbar.set_description('%d batches saved' % (count_batches,))
    finally:
        # Only restore when the attribute existed before the call.
        if had_flag:
            model.stochastic = previous_flag

# Rough image count - assumes a batch size of 128 (TODO confirm; the last
# batch may be smaller).
print(len(test_dataloader) * 128)
# Export model samples produced with stochastic sampling switched off.
save_model_samples_nonstoch('mnist-samples-nonstoch', model, test_dataloader)
# apply_gen=False: export the rescaled input images themselves as the FID reference.
# NOTE(review): the directory is named 'mnist-train-data' but is filled from
# test_dataloader - confirm which split is intended.
save_model_samples('mnist-train-data', model, test_dataloader, False)
# num_expected=10000 is presumably the number of exported images; verify.
fid = calc_fid('mnist-samples-nonstoch', "mnist-train-data", num_expected=10000, batch=256)
# The message is Russian for "The model has FID = ...".
print('Модель имеет FID = %.4g' % (fid,))
# Make sure stochastic sampling is enabled for subsequent cells (the
# non-stochastic export above turns it off while it runs).
model.stochastic = True