In [None]:
# Configuration: values read by the generation and FID cells of this notebook.

model_path = 'checkpoint.pt' # checkpoint file whose state dict is loaded into the model
n_samples = 2048 # total number of fake images to generate and feed to FID
generate_batch_size = 1024 # number of fakes to sample simultaneously
data_loader_batch_size = 128 # batch size for real images; also the chunk size used when feeding fakes to FID
fid_feature = 2048 # an integer will indicate the inceptionv3 feature layer to choose (64, 192, 768, or 2048)

In [None]:
#@title Installing dependencies
!pip3 install -q torchmetrics torch-fidelity

In [None]:
#@title Models

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


class Diffusion:
    '''
    the diffusion class: builds the noise schedule, noises data to a given
    timestep and samples images from a noise-prediction model

    noise_steps: number of diffusion noising steps
    betas: start and end for the variance schedule
    img_size: generated image size (samples are square with 3 channels)
    device: the device to use for training
    '''
    def __init__(self, noise_steps=1000, betas=(1e-4, 2e-2), img_size=64, device='cuda'):
        self.noise_steps = noise_steps
        self.betas = betas
        self.img_size = img_size
        self.device = device

        # prepare the noise variance schedule and various constants:
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)

    # computes a linear noise schedule: noise_steps values from betas[0] to betas[1]
    def prepare_noise_schedule(self):
        return torch.linspace(*self.betas, self.noise_steps)

    # noises the data x to timestep t
    # x: batch of 4-d data, t: one integer timestep per batch element
    # returns (noised x, the gaussian noise that was mixed in)
    def noise_data(self, x, t):
        # get constants, reshaped so they broadcast over the last three dims of x
        sqrt_alpha_hat = self.sqrt_alpha_hat[t][:, None, None, None]
        sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat[t][:, None, None, None]

        # generate noise
        e = torch.randn_like(x)

        # compute scaled signal and noise
        signal = sqrt_alpha_hat * x
        noise = sqrt_one_minus_alpha_hat * e
        return signal + noise, e

    # samples time steps (just returns n random ints in [1, noise_steps) as timesteps)
    def sample_timesteps(self, n):
        return torch.randint(1, self.noise_steps, size=(n,))

    # sample n datapoints from the model
    # labels: optional labels if the model was trained conditionally
    # cfg_scale: classifier-free guidance weight handed to torch.lerp
    #            (0 = unconditional prediction, 1 = conditional prediction,
    #            larger values extrapolate past the conditional prediction);
    #            only used when labels are given
    # returns a uint8 tensor of shape (n, 3, img_size, img_size)
    @torch.no_grad()
    def sample(self, model, n, labels=None, cfg_scale=3):
        # remember the caller's mode so sampling does not silently flip an
        # evaluation-mode model into training mode
        was_training = getattr(model, 'training', True)
        model.eval()

        # start with gaussian noise
        x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)

        # denoise data, walking the timesteps from noise_steps - 1 down to 1
        steps = range(1, self.noise_steps)
        for i in tqdm(reversed(steps), position=0, total=len(steps)):
            # create the timestep tensor that encodes the current timestep
            t = (torch.ones(n) * i).long().to(self.device)

            # predict the noise
            predicted_noise = model(x, t, labels)

            # classifier-free guidance: linearly interpolate between
            # unconditional and conditional (above) samples.
            # `is not None` is required here: the truth value of a label
            # tensor with several elements is ambiguous (raises), and a
            # single label equal to 0 would be treated as "no label".
            if labels is not None:
                uncond_predicted_noise = model(x, t)
                predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

            # compute scaling constants
            alpha = self.alpha[t][:, None, None, None]
            alpha_hat = self.alpha_hat[t][:, None, None, None]
            beta = self.beta[t][:, None, None, None]

            # remove a small bit of noise from x (no fresh noise on the last step)
            noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
            signal_scale = 1 / torch.sqrt(alpha)
            pred_noise_scale = (1 - alpha) / (torch.sqrt(1 - alpha_hat))
            scaled_noise = torch.sqrt(beta) * noise
            signal = x - (pred_noise_scale * predicted_noise)

            x = (signal_scale * signal) + scaled_noise

        # put the model back into training mode only if it was in it before
        if was_training:
            model.train()

        # clamp and rescale x values to [0, 1] (output was [-1, 1]):
        x = (x.clamp(-1, 1) + 1) / 2

        # convert x to valid pixel range:
        x = (x * 255).type(torch.uint8)

        return x


class PatchingLayer(nn.Module):
    '''
    Cuts a batch of images (b, c, h, w) into a sequence of flattened,
    non-overlapping square patches of shape (b, num_patches, c * p * p).

    patch_size: side length p of each patch (also used as the stride)
    num_channels: stored on the module; forward reads the channel count from x
    '''
    def __init__(self, patch_size, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.patch_size = patch_size

    def forward(self, x):
        batch, channels, _, _ = x.size()
        step = self.patch_size

        # window the rows, then the columns: (b, c, rows, cols, p, p)
        windows = x.unfold(2, step, step)
        windows = windows.unfold(3, step, step)

        # bring the patch grid in front of the channels: (b, rows, cols, c, p, p)
        windows = windows.permute(0, 2, 3, 1, 4, 5).contiguous()

        # one flat token per patch
        return windows.view(batch, -1, channels * step * step)


class PositionalEncoding(nn.Module):
    '''
    Parameter-free sinusoidal encoding of token positions.

    forward(x) takes tokens of shape (batch, tokens, embedding) and returns
    only the encoding, shaped like x; the caller adds it to the tokens.
    Sines fill the even embedding indices, cosines the odd ones.
    '''
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        _, num_tokens, embed_dim = x.size()

        # token indices as a [1, N, 1] float column
        positions = torch.arange(0, num_tokens).unsqueeze(0).unsqueeze(-1).to(x.device).float()

        # one frequency per sin/cos pair, shaped [1, 1, E//2]
        exponent = torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim)
        frequencies = torch.exp(exponent).to(x.device).unsqueeze(0).unsqueeze(0)

        # [1, N, E//2]; broadcast over the batch on assignment below
        angles = positions * frequencies

        encoding = torch.zeros_like(x)
        encoding[:, :, 0::2] = torch.sin(angles)
        encoding[:, :, 1::2] = torch.cos(angles)
        return encoding


class OutputLayer(nn.Module):
    '''
    Maps transformer tokens back to an image.

    Each token of width d_model goes through a two-layer MLP producing one
    flattened patch (num_channels * patch_size * patch_size values); the
    patches are then laid out on a square grid, so the number of tokens is
    expected to be a perfect square.
    '''
    def __init__(self, d_model, patch_size, num_channels):
        super().__init__()
        hidden_dim = d_model * 4
        patch_dim = num_channels * patch_size * patch_size
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, patch_dim),
        )
        self.patch_size = patch_size
        self.num_channels = num_channels

    def forward(self, x):
        channels, side = self.num_channels, self.patch_size

        # tokens -> flattened patches -> (b, n, c, p, p)
        patches = self.mlp(x)
        batch, num_tokens, _ = patches.size()
        patches = patches.view(batch, num_tokens, channels, side, side)

        # arrange the n patches on a square grid: (b, rows, cols, c, p, p)
        grid = int((num_tokens) ** 0.5)
        patches = patches.view(batch, grid, grid, channels, side, side)

        # interleave grid and in-patch axes: (b, c, rows, p, cols, p)
        patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()

        # merge into full image height and width
        return patches.view(batch, channels, grid * side, grid * side)


class ImageTransformer(nn.Module):
    '''
    Full model architecture: patchify the image, project patches to d_model,
    add timestep (and optional class) conditioning plus positional encoding,
    run a transformer encoder and map the tokens back to an image.

    d_model: token width
    nhead: attention heads per encoder layer
    num_layers: number of encoder layers
    patch_size: side length of the square patches
    num_classes: if truthy, a label embedding is created for class-conditioning
    num_channels: image channels
    dropout: dropout used inside the encoder layers
    window_size: accepted but not used by this class
    '''
    def __init__(self, d_model, nhead, num_layers, patch_size, num_classes=None, num_channels=3, dropout=0.05, window_size=4):
        super(ImageTransformer, self).__init__()
        self.d_model = d_model
        self.patching_layer = PatchingLayer(patch_size, num_channels)
        self.projection = nn.Linear(num_channels * patch_size * patch_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout),
            num_layers
        )
        self.output_layer = OutputLayer(d_model, patch_size, num_channels)

        if num_classes:
            self.label_emb = nn.Embedding(num_classes, d_model)

    def forward(self, x, t, label=None):
        '''
        x: image batch (b, c, h, w); t: one timestep per batch element (b,);
        label: optional class indices, one per batch element.
        Returns a tensor with the same layout as x.
        '''
        # sinusoidal embedding of the timestep, shape (b, 1, d_model)
        t = t.unsqueeze(-1).float()
        t = self._time_embedding(t)

        # class-conditioning.
        # `is not None` because the truth value of a multi-element tensor is
        # ambiguous and a lone label 0 is falsy. The embedding (b, d_model)
        # is reshaped to (b, 1, d_model) so it lines up with t instead of
        # broadcasting against it into (b, b, d_model).
        if label is not None:
            t = t + self.label_emb(label).reshape(t.shape[0], 1, -1)

        x = self.patching_layer(x)
        x = self.projection(x)
        residual = x
        # t broadcasts over the token dimension
        x = x + t + self.positional_encoding(x)
        x = self.encoder(x) + residual
        x = self.output_layer(x)
        return x

    def _time_embedding(self, t):
        '''Sinusoidal embedding of t (b, 1) -> (b, 1, d_model): sines first, then cosines.'''
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.d_model, 2, dtype=torch.float32).to(t.device) / self.d_model))

        # Create the sine and cosine encodings
        pos_enc_sin = torch.sin(t.unsqueeze(1).float() * inv_freq)
        pos_enc_cos = torch.cos(t.unsqueeze(1).float() * inv_freq)

        # Concatenate the sine and cosine encodings
        pos_enc = torch.cat([pos_enc_sin, pos_enc_cos], dim=-1)

        return pos_enc

    @property
    def n_params(self):
        '''Number of trainable parameters.'''
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class TBlock(nn.Module):
    '''
    One transformer stage: patchify the input, project patches to d_model,
    add timestep (and optional class) conditioning plus positional encoding,
    run a transformer encoder and map the tokens back to an image of the
    same layout as the input.

    d_model: token width
    nhead: attention heads per encoder layer
    num_layers: number of encoder layers
    patch_size: side length of the square patches
    num_classes: if truthy, a label embedding is created for class-conditioning
    num_channels: image channels
    dropout: dropout used inside the encoder layers
    '''
    def __init__(self, d_model, nhead, num_layers, patch_size, num_classes=None, num_channels=3, dropout=0.05):
        super(TBlock, self).__init__()
        self.d_model = d_model
        self.patching_layer = PatchingLayer(patch_size, num_channels)
        self.projection = nn.Linear(num_channels * patch_size * patch_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout),
            num_layers
        )
        self.output_layer = OutputLayer(d_model, patch_size, num_channels)

        if num_classes:
            self.label_emb = nn.Embedding(num_classes, d_model)

    def forward(self, x, t, label=None):
        '''
        x: image batch (b, c, h, w); t: one timestep per batch element (b,);
        label: optional class indices, one per batch element.
        Returns a tensor with the same layout as x.
        '''
        # sinusoidal embedding of the timestep, shape (b, 1, d_model)
        t = t.unsqueeze(-1).float()
        t = self._time_embedding(t)

        # class-conditioning.
        # `is not None` because the truth value of a multi-element tensor is
        # ambiguous and a lone label 0 is falsy. The embedding (b, d_model)
        # is reshaped to (b, 1, d_model) so it lines up with t instead of
        # broadcasting against it into (b, b, d_model).
        if label is not None:
            t = t + self.label_emb(label).reshape(t.shape[0], 1, -1)

        x = self.patching_layer(x)
        x = self.projection(x)
        residual = x
        # t broadcasts over the token dimension
        x = x + t + self.positional_encoding(x)
        x = self.encoder(x) + residual
        x = self.output_layer(x)
        return x

    def _time_embedding(self, t):
        '''Sinusoidal embedding of t (b, 1) -> (b, 1, d_model): sines first, then cosines.'''
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.d_model, 2, dtype=torch.float32).to(t.device) / self.d_model))

        # Create the sine and cosine encodings
        pos_enc_sin = torch.sin(t.unsqueeze(1).float() * inv_freq)
        pos_enc_cos = torch.cos(t.unsqueeze(1).float() * inv_freq)

        # Concatenate the sine and cosine encodings
        pos_enc = torch.cat([pos_enc_sin, pos_enc_cos], dim=-1)

        return pos_enc


class TNet(nn.Module):
    '''
    Stack of TBlocks arranged as an encoder/decoder with skip-connections.

    The first num_blocks blocks each run at the current resolution and are
    followed by 2x average pooling; the last num_blocks blocks each run
    after 2x bilinear upsampling and the addition of the matching encoder
    output.

    num_blocks: number of encoder blocks (and of decoder blocks)
    d_model, nhead, num_layers, patch_size, num_channels, dropout:
        forwarded unchanged to every TBlock
    num_classes: forwarded to every TBlock so that `label` can be used in
        forward; with the default None the blocks have no label embedding
        and forward must be called without a label
    '''
    def __init__(self, num_blocks, d_model, nhead, num_layers, patch_size, num_channels=3, dropout=0.05, num_classes=None):
        super(TNet, self).__init__()

        # Encoder and Decoder blocks
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks * 2):
            self.blocks.append(TBlock(d_model, nhead, num_layers, patch_size, num_classes=num_classes, num_channels=num_channels, dropout=dropout))

        self.pool = nn.AvgPool2d(2, 2)  # For downsampling
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')  # For upsampling

    def forward(self, x, t, label=None):
        '''
        x: image batch; t: one timestep per batch element; label: optional
        class indices passed through to every block.
        Returns a tensor at the resolution of the input.
        '''
        half = len(self.blocks) // 2

        # Encoding path with downsampling
        encoder_outs = []
        for i in range(half):
            x = self.blocks[i](x, t, label)
            encoder_outs.append(x)
            x = self.pool(x)

        # Decoding path with upsampling
        for i in range(half, len(self.blocks)):
            x = self.upsample(x)
            x = x + encoder_outs.pop()  # Skip-connection (most recent encoder output first)
            x = self.blocks[i](x, t, label)

        return x

    @property
    def n_params(self):
        '''Number of trainable parameters.'''
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
#@title Generate n samples

def generate_samples(diffusion, model, n: int, batch_size: int = 64) -> torch.Tensor:
    '''
    Generate n samples from the model in chunks of at most batch_size.

    diffusion: object exposing sample(model, n=...) and img_size
    model: the network handed to diffusion.sample
    n: total number of samples wanted
    batch_size: maximum number of samples generated per call to sample

    Returns all samples concatenated along dim 0, on the CPU. For n <= 0 an
    empty uint8 tensor of shape (0, 3, img_size, img_size) is returned.
    Raises ValueError if batch_size is not positive (the loop below would
    otherwise never make progress).
    '''
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    if n <= 0:
        # torch.cat rejects an empty list, so build the empty result directly
        side = diffusion.img_size
        return torch.empty((0, 3, side, side), dtype=torch.uint8)

    samples = []
    remaining = n
    while remaining > 0:
        # Generate the minimum of the remaining count or batch_size samples
        current_batch_size = min(remaining, batch_size)
        with torch.no_grad():
            # Sample images from the model; move them off the device right
            # away so only one batch lives there at a time
            sample = diffusion.sample(model, n=current_batch_size).cpu()
        samples.append(sample)
        remaining -= current_batch_size

    return torch.cat(samples, dim=0)

# Build the diffusion process and the network, load the trained weights and
# generate the fake images that are scored in the FID cell.
diffusion = Diffusion(img_size=32)

# Architecture switch. load_state_dict below expects the checkpoint at
# model_path to hold the parameters of the architecture selected here.
use_tnet = False

if use_tnet:
    model = TNet(
        num_blocks=3,
        d_model=128,
        nhead=8,
        num_layers=3,
        patch_size=4,
        num_channels=3
    ).cuda()
else:
    model = ImageTransformer(
        d_model=128,
        nhead=32,
        num_layers=4,
        patch_size=4,
        num_channels=3
    ).cuda()

model.load_state_dict(torch.load(model_path, map_location='cuda'))
model.eval()

# Generate samples
generated_samples = generate_samples(diffusion, model, n=n_samples, batch_size=generate_batch_size)

999it [03:29,  4.77it/s]
999it [03:32,  4.69it/s]


In [None]:
import os
import pickle

# Optional shortcut: reuse samples that were pickled by an earlier run instead
# of the ones produced by the generation cell. When the cache file is absent
# this cell is a no-op, so the notebook still runs top to bottom.
samples_cache_path = 'a.p'

if os.path.exists(samples_cache_path):
    # NOTE: pickle.load can execute arbitrary code; only load files you
    # created yourself.
    with open(samples_cache_path, 'rb') as f:
        generated_samples = pickle.load(f)
else:
    print(f'{samples_cache_path} not found; generated_samples left unchanged.')

In [None]:
#@title Compute FID
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchmetrics.image.fid import FrechetInceptionDistance as FID

def float2int(images):
    '''
    Convert an image tensor with values in [-1, 1] to uint8 pixels in
    [0, 255]. Values outside [-1, 1] are clamped first; the final cast
    drops the fractional part.
    '''
    unit_range = images.clamp(-1, 1).add(1).div(2)
    return unit_range.mul(255).to(torch.uint8)

# Real images are normalised to [-1, 1] here and converted back to uint8 by
# float2int before they are handed to the FID metric.
transforms = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Normalize(0.5, 0.5)
])

# NOTE: despite the name, this is the CIFAR-10 *training* split (train=True).
val_dataset = datasets.CIFAR10('.', train=True, transform=transforms, download=True)
val_dataloader = DataLoader(val_dataset, batch_size=data_loader_batch_size, shuffle=True)

fid = FID(feature=fid_feature).cuda()

# Feed every real image to the metric exactly once. Iterating the loader
# directly (instead of building a fresh iterator for each batch) makes one
# pass over the dataset and handles the smaller final batch.
left = len(val_dataset)
for real_batch, _ in val_dataloader:
    reals = float2int(real_batch).cuda()
    fid.update(reals, real=True)
    left -= len(reals)
    print(f'{left} real images left')
    del reals

# Feed the fakes in chunks; bounded by the samples actually available so no
# empty batch is passed to the metric.
n_fakes = min(n_samples, len(generated_samples))
for i in range(0, n_fakes, data_loader_batch_size):
    fakes = generated_samples[i:min(i + data_loader_batch_size, n_fakes)].cuda()
    fid.update(fakes, real=False)
    del fakes

score = fid.compute().item()
fid.reset()
print(f'FID: {score}')

Files already downloaded and verified
49872 real images left
49744 real images left
49616 real images left
49488 real images left
49360 real images left
49232 real images left
49104 real images left
48976 real images left
48848 real images left
48720 real images left
48592 real images left
48464 real images left
48336 real images left
48208 real images left
48080 real images left
47952 real images left
47824 real images left
47696 real images left
47568 real images left
47440 real images left
47312 real images left
47184 real images left
47056 real images left
46928 real images left
46800 real images left
46672 real images left
46544 real images left
46416 real images left
46288 real images left
46160 real images left
46032 real images left
45904 real images left
45776 real images left
45648 real images left
45520 real images left
45392 real images left
45264 real images left
45136 real images left
45008 real images left
44880 real images left
44752 real images left
44624 real images l

In [None]:
#@title Optional: save images

# File the grid of generated samples is written to by the save call in this cell.
output_filename = 'cifar10-2.jpg' # @param {type:"string"}

from torchvision.utils import make_grid
from PIL import Image

def save_images(images, path, **kwargs):
    '''
    Tile a batch of images into a single grid and write it to `path`.

    images: batch of images accepted by make_grid
    path: destination handed to PIL's Image.save
    kwargs: forwarded unchanged to make_grid
    '''
    grid_chw = make_grid(images, **kwargs)
    # PIL expects channels last
    grid_hwc = grid_chw.permute(1, 2, 0)
    Image.fromarray(grid_hwc.cpu().numpy()).save(path)

# Write all generated samples into one grid image.
save_images(images=generated_samples, path=output_filename)
print('Successfully saved images.')

0it [00:00, ?it/s]