In [1]:
#| default_exp fid

# FID

In [2]:
#|export
import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from scipy import linalg

from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from fastcore.test import test_close
from torch import distributions

# Compact tensor printing: 2 decimals, 140-char lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# Default image colormap: reversed grayscale.
mpl.rcParams['image.cmap'] = 'gray_r'

# Suppress every logging call of severity WARNING and below.
import logging
logging.disable(logging.WARNING)

# NOTE(review): set_seed comes from a star import; presumably it reseeds the RNGs
# and so supersedes the manual_seed(1) call above — confirm.
set_seed(42)
# Cap fastcore's CPU count at 8; it is passed as num_workers to the DataLoaders below.
if fc.defaults.cpus>8: fc.defaults.cpus=8

## Classifier

In [5]:
# Dataset column keys for the input images and the target labels.
xl = 'image'
yl = 'label'
# Dataset identifier passed to load_dataset, and the batch size used by every loader.
name = 'fashion_mnist'
bs = 512

@inplace
def transformi(b):
    """Replace the images under `b[xl]` with padded, rescaled tensors.

    Each image is converted with `TF.to_tensor`, zero-padded by 2 pixels on
    every side of its last two dimensions, then mapped through `x*2 - 1`.
    The batch is modified in place; nothing is returned from this function
    (the `@inplace` decorator is presumably responsible for handing the
    batch back to the caller — confirm).
    """
    converted = []
    for img in b[xl]:
        padded = F.pad(TF.to_tensor(img), (2,2,2,2))
        converted.append(padded * 2 - 1)
    b[xl] = converted

# Load the dataset identified by `name`.
dsd = load_dataset(name)
# Attach transformi as the dataset's transform.
tds = dsd.with_transform(transformi)
# Build the DataLoaders from the transformed dataset with batch size `bs`.
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

In [6]:
# Pull the first training batch; keep it whole as `b` and unpacked as `xb`, `yb`.
b = xb, yb = next(iter(dls.train))

In [13]:
# Callbacks handed to the Learner (presumably device placement and mixed precision — confirm).
cbs = [DeviceCB(), MixedPrecision()]
# Load the pickled classifier; remap storages to CPU when CUDA is not available.
# NOTE: torch.load unpickles arbitrary objects — only load files you trust.
model = torch.load(
    'models/data_aug2.pkl',
    map_location=None if torch.cuda.is_available() else torch.device('cpu'),
)
# opt_func=None: no optimizer factory is supplied for this Learner.
learn = Learner(
    model,
    dls,
    F.cross_entropy,
    cbs=cbs,
    opt_func=None,
)

In [14]:
def append_outp(hook, mod, inp, outp):
    """Hook function that accumulates module outputs on the hook object.

    The output is passed through `to_cpu` and appended to `hook.outp`;
    the list is created the first time the hook fires. `mod` and `inp`
    are accepted but unused.
    """
    try:
        store = hook.outp
    except AttributeError:
        # First call for this hook: start an empty accumulator.
        store = hook.outp = []
    store.append(to_cpu(outp))

In [15]:
# Hook module 6 of the model with append_outp.
# NOTE(review): on_valid=True presumably makes the hook fire during validation — confirm against HooksCallback.
hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True)

In [17]:
# Single pass with train=False (presumably validation only — confirm against Learner.fit),
# attaching the hook callback for this call.
learn.fit(1, train=False, cbs=[hcb])





In [24]:
# Take the first stored output of the first hook, cast to float32, keep the first 64 rows.
feats = hcb.hooks[0].outp[0].float()[:64]
feats.shape

torch.Size([64, 512])

In [25]:
# Remove modules 8 and 7 from the model, highest index first so the second
# index still refers to the intended module.
# NOTE: not idempotent — re-running this cell deletes whatever now sits at those indices.
del learn.model[8]
del learn.model[7]

In [26]:
# Run capture_preds on the truncated model; its two results are unpacked as features and targets.
feats, y = learn.capture_preds()
# Cast the features to float32.
feats = feats.float()
feats.shape, y



(torch.Size([10000, 512]), tensor([9, 2, 1,  ..., 8, 1, 5]))

# Calc FID

In [None]:
# Linear beta schedule with n_steps entries from betamin to betamax.
n_steps = 1000
betamin, betamax = 0.0001, 0.02
beta = torch.linspace(betamin, betamax, n_steps)
# alpha_t = 1 - beta_t, and alphabar_t = product of alpha_1..alpha_t.
alpha = 1. - beta
alphabar = torch.cumprod(alpha, dim=0)
# sigma_t = sqrt(beta_t).
sigma = torch.sqrt(beta)

In [27]:
def noisify(x0, ᾱ):
    """Add noise to a batch at independently sampled random timesteps.

    Args:
        x0: batch tensor; the first dimension is the batch dimension.
        ᾱ: 1-D tensor holding the cumulative alpha product for each timestep.

    Returns:
        `((xt, t), ε)` where `t` holds one timestep index per batch item
        drawn uniformly from `[0, len(ᾱ))`, `ε` is standard normal noise
        shaped like `x0`, and `xt = sqrt(ᾱ[t])*x0 + sqrt(1-ᾱ[t])*ε`.
        `xt`, `t` and `ε` are on `x0`'s device.
    """
    device = x0.device
    n = len(x0)
    # Range comes from the schedule that was passed in rather than a global,
    # so any schedule length indexes safely.
    t = torch.randint(0, len(ᾱ), (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    # One trailing singleton axis per non-batch dimension of x0, so the
    # per-item coefficient broadcasts for any input rank; using `n` instead
    # of -1 also keeps the reshape valid for an empty batch.
    ᾱ_t = ᾱ[t].reshape(n, *([1]*(x0.ndim-1))).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return (xt, t.to(device)), ε

def collate_ddpm(b):
    """Collate a list of samples, then noisify the `xl` entry using `alphabar`."""
    images = default_collate(b)[xl]
    return noisify(images, alphabar)
def dl_ddpm(ds):
    """Wrap `ds` in a DataLoader that batches by `bs` and collates with `collate_ddpm`."""
    loader_kwargs = dict(batch_size=bs, collate_fn=collate_ddpm, num_workers=fc.defaults.cpus)
    return DataLoader(ds, **loader_kwargs)

In [29]:
# Pair the noisifying loaders built from the 'train' and 'test' splits, in that order.
dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

In [31]:
from diffusers import UNet2DModel

class Unet(UNet2DModel):
    """UNet2DModel variant whose forward takes one packed input and returns `.sample`."""

    def forward(self, x):
        # `x` is unpacked into the parent's positional arguments; only the
        # `.sample` attribute of the parent's result is returned.
        result = super().forward(*x)
        return result.sample

In [None]:
# Load the pickled model; remap storages to CPU when CUDA is not available.
# NOTE: torch.load unpickles arbitrary objects — only load files you trust.
smodel = torch.load(
    'models/fashion_ddpm_mp[].pkl',
    map_location=None if torch.cuda.is_available() else torch.device('cpu'),
)