In [1]:
!nvidia-smi -L
import torch, torchvision
print("Torch:", torch.__version__, "CUDA:", torch.version.cuda, "Available:", torch.cuda.is_available())

!pip install -q torch-fidelity==0.3.0

GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-21484c64-3745-f39d-e1d4-e2e5fb24c475)
Torch: 2.1.2 CUDA: 12.1 Available: True
[0m

In [2]:
from pathlib import Path
import os, random, numpy as np
from PIL import Image

def seed_all(s=42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
seed_all()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_TRAIN = Path("data/raw/train/train")
DATA_VAL   = Path("data/raw/test/test")

OUTDIR = Path("./outputs_spade_finetune"); OUTDIR.mkdir(exist_ok=True)
CKPT_DIR   = OUTDIR/"checkpoints"; CKPT_DIR.mkdir(exist_ok=True)
SAMPLE_DIR = OUTDIR/"samples"; SAMPLE_DIR.mkdir(exist_ok=True)
PRED_DIR   = OUTDIR/"pred"; PRED_DIR.mkdir(exist_ok=True)
GT_DIR     = OUTDIR/"gt_photos"; GT_DIR.mkdir(exist_ok=True)

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch

class RandomJitter256:
    def __init__(self, jitter=True): self.jitter=jitter
    def __call__(self, img):
        img = img.resize((512,256), Image.BICUBIC)
        if self.jitter and random.random()<0.5: img = img.transpose(Image.FLIP_LEFT_RIGHT)
        return img

to_tensor_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])

class SideBySideDataset(Dataset):
    def __init__(self, root, train=True):
        self.root = Path(root); self.files = sorted(self.root.glob("*.jpg"))
        self.rj = RandomJitter256(jitter=train)
    def __len__(self): return len(self.files)
    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        img = self.rj(img)
        w,h = img.size; w2=w//2
        left = img.crop((0,0,w2,h)).resize((256,256),Image.BICUBIC)
        right= img.crop((w2,0,w,h)).resize((256,256),Image.BICUBIC)
        return to_tensor_norm(right), to_tensor_norm(left), self.files[idx].name

BATCH=8
train_loader = DataLoader(SideBySideDataset(DATA_TRAIN,True), batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader   = DataLoader(SideBySideDataset(DATA_VAL,False),  batch_size=8,   shuffle=False, num_workers=4, pin_memory=True)

cond,target,_=next(iter(train_loader))
print("Shapes -> cond:",cond.shape," target:",target.shape)
assert cond.shape[-1]==256 and cond.shape[-2]==256

Shapes -> cond: torch.Size([8, 3, 256, 256])  target: torch.Size([8, 3, 256, 256])


In [9]:
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  SPADE
class SPADE(nn.Module):
    def __init__(self, ch, seg_nc=3, hidden=128):
        super().__init__()
        self.norm = nn.InstanceNorm2d(ch, affine=False)
        self.mlp  = nn.Sequential(
            nn.Conv2d(seg_nc, hidden, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(hidden, ch*2, 3, padding=1),
        )
    def forward(self, x, seg):
        seg = F.interpolate(seg, size=x.shape[-2:], mode="nearest")
        h   = self.norm(x)
        gamma, beta = torch.chunk(self.mlp(seg), 2, dim=1)
        return h*(1+gamma) + beta

class SPADEConv(nn.Module):
    def __init__(self, in_c, out_c, seg_nc=3, k=4, s=2, p=1, dropout=False):
        super().__init__()
        self.conv  = nn.Conv2d(in_c, out_c, k, s, p, bias=False)
        self.spade = SPADE(out_c, seg_nc)
        self.act   = nn.LeakyReLU(0.2, True)
        self.dp    = nn.Dropout(0.5) if dropout else nn.Identity()
    def forward(self, x, seg):
        return self.dp(self.act(self.spade(self.conv(x), seg)))

class SPADEResUp(nn.Module):
    def __init__(self, in_c, out_c, seg_nc=3, dropout=False):
        super().__init__()
        self.up   = nn.Upsample(scale_factor=2, mode="nearest")  # <- important
        self.conv = SPADEConv(in_c, out_c, seg_nc, k=3, s=1, p=1, dropout=dropout)
    def forward(self, x, seg):
        return self.conv(self.up(x), seg)

class UNetGenerator_SPADE(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, seg_nc=3):
        super().__init__()
        #  Encoder
        self.d1 = SPADEConv(in_ch,   64,  seg_nc)  # 256→128
        self.d2 = SPADEConv(64,     128, seg_nc)   # 128→64
        self.d3 = SPADEConv(128,    256, seg_nc)   # 64→32
        self.d4 = SPADEConv(256,    512, seg_nc)   # 32→16
        self.d5 = SPADEConv(512,    512, seg_nc)   # 16→8
        self.d6 = SPADEConv(512,    512, seg_nc)   # 8→4
        self.d7 = SPADEConv(512,    512, seg_nc)   # 4→2  ← bottleneck

        # Decoder
        # Each upsample doubles spatial size and halves channels, then concatenates skip features
        self.u1 = SPADEResUp(512, 512, seg_nc, dropout=True)   # 2→4,  concat d6 (512) => 1024
        self.u2 = SPADEResUp(1024,512, seg_nc, dropout=True)   # 4→8,  concat d5 => 1024
        self.u3 = SPADEResUp(1024,512, seg_nc, dropout=True)   # 8→16, concat d4 => 1024
        self.u4 = SPADEResUp(1024,512, seg_nc)                 # 16→32, concat d3 => 768
        self.u5 = SPADEResUp(768,256, seg_nc)                  # 32→64, concat d2 => 384
        self.u6 = SPADEResUp(384,128, seg_nc)                  # 64→128, concat d1 => 192
        self.u7 = SPADEResUp(192,64,  seg_nc)                  # 128→256

        self.outc = nn.Conv2d(64, out_ch, 3, 1, 1)
        self.tanh = nn.Tanh()

    def forward(self, seg):
        #  Encoder
        d1 = self.d1(seg, seg)   # 128
        d2 = self.d2(d1, seg)    # 64
        d3 = self.d3(d2, seg)    # 32
        d4 = self.d4(d3, seg)    # 16
        d5 = self.d5(d4, seg)    # 8
        d6 = self.d6(d5, seg)    # 4
        d7 = self.d7(d6, seg)    # 2

        # Decoder
        u1 = self.u1(d7, seg);  u1 = torch.cat([u1, d6], 1)
        u2 = self.u2(u1, seg);  u2 = torch.cat([u2, d5], 1)
        u3 = self.u3(u2, seg);  u3 = torch.cat([u3, d4], 1)
        u4 = self.u4(u3, seg);  u4 = torch.cat([u4, d3], 1)
        u5 = self.u5(u4, seg);  u5 = torch.cat([u5, d2], 1)
        u6 = self.u6(u5, seg);  u6 = torch.cat([u6, d1], 1)
        u7 = self.u7(u6, seg)
        return self.tanh(self.outc(u7))

#  Discriminator (SN PatchGAN)
def snconv(ic, oc, k, s, p, bias=True):
    return nn.utils.spectral_norm(nn.Conv2d(ic, oc, k, s, p, bias=bias))

class PatchDiscriminatorSN(nn.Module):
    def __init__(self, in_ch=6):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(snconv(in_ch,   64, 4,2,1),              nn.LeakyReLU(0.2,True)),
            nn.Sequential(snconv(64,     128, 4,2,1, False), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2,True)),
            nn.Sequential(snconv(128,    256, 4,2,1, False), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2,True)),
            nn.Sequential(snconv(256,    512, 4,1,1, False), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2,True)),  # stop downsampling here
        ])
        self.head = snconv(512, 1, 4,1,1)

    def forward(self, x):
        feats = []
        h = x
        for b in self.blocks:
            h = b(h); feats.append(h)
        return self.head(h), feats

# instantiate
G = UNetGenerator_SPADE().to(device)
D = PatchDiscriminatorSN().to(device)
print(f"Params G: {sum(p.numel() for p in G.parameters())/1e6:.2f}M, D: {sum(p.numel() for p in D.parameters())/1e6:.2f}M")

Params G: 45.74M, D: 2.77M


In [14]:
from torchvision.models import vgg16, VGG16_Weights
import torch.nn.functional as F
from torch_fidelity import calculate_metrics

def d_hinge(r,f): return (F.relu(1.-r).mean()+F.relu(1.+f).mean())
def g_hinge(f): return -f.mean()
l1 = nn.L1Loss()

vgg=vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES).features[:16].eval().to(device)
for p in vgg.parameters(): p.requires_grad=False
def perceptual(x,y): return l1(vgg(x),vgg(y))
def denorm(x): return (x*0.5+0.5).clamp(0,1)

# fine-tuned weights
LAMBDA_L1, LAMBDA_FM, LAMBDA_VGG = 35.0, 5.0, 5.0
LR_G, LR_D = 2e-4, 2.5e-4
EPOCHS, SAVE_EVERY, EVAL_EVERY = 250, 50, 10

opt_G=torch.optim.Adam(G.parameters(),lr=LR_G,betas=(0.5,0.999))
opt_D=torch.optim.Adam(D.parameters(),lr=LR_D,betas=(0.5,0.999))
scheduler_G=torch.optim.lr_scheduler.StepLR(opt_G,step_size=400,gamma=0.5)
scheduler_D=torch.optim.lr_scheduler.StepLR(opt_D,step_size=400,gamma=0.5)

def compute_fid(gt_dir, pred_dir, cuda=True):
    m = calculate_metrics(
        input1=str(gt_dir),
        input2=str(pred_dir),
        cuda=cuda and torch.cuda.is_available(),
        isc=False, fid=True, kid=False, verbose=False
    )
    return float(m['frechet_inception_distance'])

In [15]:
from torchvision.utils import make_grid
from torchvision import transforms
from tqdm.auto import tqdm

def train_spade(train_loader,val_loader):
    best_fid=float("inf")
    for epoch in range(1,EPOCHS+1):
        G.train(); D.train()
        epG=epD=epL1=0.
        pbar=tqdm(train_loader,desc=f"Epoch {epoch}/{EPOCHS}")
        for cond,target,_ in pbar:
            cond,target=cond.to(device),target.to(device)

            # --- tiny contrast jitter ---
            if random.random()<0.25:
                cond=cond*(0.9+0.2*torch.rand(1,device=cond.device))

            # ----- D -----
            with torch.no_grad(): 
                fake=G(cond)
            rp=torch.cat([cond,target],1); 
            fp=torch.cat([cond,fake],1)
            
            D.zero_grad(set_to_none=True)
            rlog,_=D(rp);
            flog,_=D(fp)
            loss_D=d_hinge(rlog,flog)
            loss_D.backward(); 
            opt_D.step()

            # ----- G -----
            G.zero_grad(set_to_none=True)
            gen=G(cond); 
            gp=torch.cat([cond,gen],1)
            glog,gf=D(gp)
            loss_G_adv=g_hinge(glog)
            loss_G_L1=l1(gen,target)*LAMBDA_L1
            
            with torch.no_grad():
                _,rf=D(rp)
            loss_G_FM=sum(l1(g,r) for g,r in zip(gf,rf))*LAMBDA_FM
            loss_G_VGG=perceptual(denorm(gen),denorm(target))*LAMBDA_VGG
            
            loss_G=loss_G_adv+loss_G_L1+loss_G_FM+loss_G_VGG
            loss_G.backward(); 
            opt_G.step()

            epG+=loss_G.item(); 
            epD+=loss_D.item(); 
            epL1+=loss_G_L1.item()
            pbar.set_postfix(D=f"{loss_D.item():.3f}",G=f"{loss_G.item():.2f}",L1=f"{loss_G_L1.item():.2f}")

        scheduler_G.step(); scheduler_D.step()
        avgG,avgD,avgL1=epG/len(train_loader),epD/len(train_loader),epL1/len(train_loader)
        print(f"Epoch {epoch}: G={avgG:.3f} D={avgD:.3f} L1={avgL1:.3f}")

        # save sample grid every 10
        if epoch%10==0:
            G.eval()
            with torch.no_grad():
                c,t,_=next(iter(val_loader)); c=c.to(device)
                g=G(c)
                grid=make_grid(torch.cat([denorm(c.cpu()),denorm(g.cpu()),denorm(t)],0),nrow=c.size(0))
                img=transforms.ToPILImage()(grid)
                img.save(SAMPLE_DIR/f"epoch_{epoch:03d}.jpg",quality=95)

        # checkpoints
        if epoch%SAVE_EVERY==0:
            ckpt=CKPT_DIR/f"epoch_{epoch}.pt"
            torch.save({"G":G.state_dict(),"D":D.state_dict(),"epoch":epoch},ckpt)
            print("Checkpoint saved:",ckpt)

        # generate preds + FID 
        if epoch%EVAL_EVERY==0:
            for p in PRED_DIR.glob("*.jpg"): p.unlink()
            G.eval()
            with torch.no_grad():
                for c,_,names in tqdm(val_loader,desc="Predict val"):
                    c=c.to(device)
                    g=denorm(G(c))
                    for i,n in enumerate(names):
                        transforms.ToPILImage()(g[i].cpu()).save(PRED_DIR/n,quality=95)
            fid=compute_fid(GT_DIR,PRED_DIR)
            print(f"→ FID @ epoch {epoch}: {fid:.2f}")
            if fid<best_fid:
                best_fid=fid
                torch.save({"G":G.state_dict(),"D":D.state_dict(),
                            "epoch":epoch,"fid":fid},CKPT_DIR/"best_model.pt")
                print(f"New best FID {fid:.2f} @ epoch {epoch}")

In [12]:
from tqdm.auto import tqdm
for f in tqdm(sorted(DATA_VAL.glob("*.jpg")),desc="Extract GT photos"):
    im=Image.open(f).convert("RGB").resize((512,256),Image.BICUBIC)
    w,h=im.size; photo=im.crop((0,0,w//2,h)).resize((256,256),Image.BICUBIC)
    photo.save(GT_DIR/f.name,quality=95)
print("GT count:",len(list(GT_DIR.glob("*.jpg"))))

Extract GT photos: 100%|██████████| 500/500 [00:09<00:00, 55.16it/s]

GT count: 500





In [16]:
# load from last or start fresh
ckpt = CKPT_DIR/"epoch_200.pt"
if ckpt.exists():
    c=torch.load(ckpt,map_location=device)
    G.load_state_dict(c["G"]); D.load_state_dict(c["D"])
    print("Resumed from",ckpt)
else:
    print("Training from scratch")

train_spade(train_loader,val_loader)

Training from scratch


Epoch 1/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.675, G=16.90, L1=4.93]


Epoch 1: G=18.474 D=1.073 L1=5.213


Epoch 2/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=1.142, G=18.22, L1=5.16]


Epoch 2: G=18.343 D=1.077 L1=5.164


Epoch 3/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=1.103, G=20.44, L1=6.19]


Epoch 3: G=18.154 D=1.050 L1=5.032


Epoch 4/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=1.286, G=17.60, L1=4.63]


Epoch 4: G=18.087 D=1.031 L1=4.977


Epoch 5/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.925, G=18.65, L1=5.32]


Epoch 5: G=17.920 D=1.045 L1=4.870


Epoch 6/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.778, G=14.00, L1=3.41]


Epoch 6: G=17.601 D=1.043 L1=4.717


Epoch 7/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.880, G=19.38, L1=4.98]


Epoch 7: G=17.422 D=1.032 L1=4.603


Epoch 8/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.754, G=21.22, L1=5.83]


Epoch 8: G=17.316 D=1.035 L1=4.536


Epoch 9/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.786, G=15.29, L1=3.17]


Epoch 9: G=16.893 D=1.041 L1=4.345


Epoch 10/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=1.347, G=16.75, L1=4.51]

Epoch 10: G=17.066 D=1.032 L1=4.385



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.10it/s]
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:07<00:00, 13.4MB/s]
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)


→ FID @ epoch 10: 120.44
New best FID 120.44 @ epoch 10


Epoch 11/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=1.146, G=14.97, L1=3.68]


Epoch 11: G=16.835 D=1.027 L1=4.268


Epoch 12/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=1.320, G=19.62, L1=5.39]


Epoch 12: G=16.685 D=1.050 L1=4.210


Epoch 13/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.862, G=16.67, L1=4.35]


Epoch 13: G=16.537 D=1.041 L1=4.117


Epoch 14/250: 100%|██████████| 371/371 [00:51<00:00,  7.20it/s, D=1.233, G=19.50, L1=5.28]


Epoch 14: G=16.298 D=1.040 L1=4.027


Epoch 15/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=1.172, G=15.24, L1=4.03]


Epoch 15: G=16.095 D=1.041 L1=3.922


Epoch 16/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.908, G=15.91, L1=3.83]


Epoch 16: G=16.052 D=1.053 L1=3.908


Epoch 17/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=1.193, G=14.44, L1=3.38]


Epoch 17: G=15.989 D=1.052 L1=3.869


Epoch 18/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=1.142, G=12.96, L1=2.78]


Epoch 18: G=15.855 D=1.042 L1=3.809


Epoch 19/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.674, G=17.38, L1=3.97]


Epoch 19: G=15.903 D=1.024 L1=3.796


Epoch 20/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=1.250, G=17.44, L1=4.29]

Epoch 20: G=15.789 D=1.025 L1=3.741



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.23it/s]


→ FID @ epoch 20: 98.35
New best FID 98.35 @ epoch 20


Epoch 21/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=1.074, G=16.52, L1=4.07]


Epoch 21: G=15.766 D=1.011 L1=3.726


Epoch 22/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.879, G=18.35, L1=4.45]


Epoch 22: G=15.574 D=1.017 L1=3.661


Epoch 23/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=1.149, G=15.21, L1=3.13]


Epoch 23: G=15.545 D=1.014 L1=3.627


Epoch 24/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=1.196, G=13.56, L1=2.96]


Epoch 24: G=15.482 D=0.999 L1=3.596


Epoch 25/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.898, G=16.73, L1=4.09]


Epoch 25: G=15.470 D=1.002 L1=3.571


Epoch 26/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.834, G=17.34, L1=4.25]


Epoch 26: G=15.437 D=0.973 L1=3.548


Epoch 27/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.949, G=14.41, L1=3.07]


Epoch 27: G=15.306 D=0.958 L1=3.488


Epoch 28/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.763, G=15.58, L1=3.22]


Epoch 28: G=15.403 D=0.900 L1=3.482


Epoch 29/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.764, G=15.73, L1=3.56]


Epoch 29: G=15.439 D=0.869 L1=3.465


Epoch 30/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.545, G=15.20, L1=3.33]

Epoch 30: G=15.538 D=0.802 L1=3.453



Predict val: 100%|██████████| 63/63 [00:04<00:00, 12.95it/s]


→ FID @ epoch 30: 88.81
New best FID 88.81 @ epoch 30


Epoch 31/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.507, G=20.13, L1=4.94]


Epoch 31: G=15.486 D=0.729 L1=3.384


Epoch 32/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.780, G=15.21, L1=3.63]


Epoch 32: G=15.651 D=0.664 L1=3.389


Epoch 33/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=1.156, G=16.24, L1=3.96]


Epoch 33: G=15.830 D=0.652 L1=3.417


Epoch 34/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=1.031, G=17.48, L1=4.15]


Epoch 34: G=15.786 D=0.615 L1=3.394


Epoch 35/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.545, G=16.07, L1=3.75]


Epoch 35: G=15.763 D=0.573 L1=3.356


Epoch 36/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.464, G=14.93, L1=3.44]


Epoch 36: G=15.642 D=0.581 L1=3.308


Epoch 37/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.487, G=15.34, L1=2.90]


Epoch 37: G=15.721 D=0.548 L1=3.310


Epoch 38/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.476, G=14.66, L1=2.79]


Epoch 38: G=15.642 D=0.541 L1=3.284


Epoch 39/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.525, G=16.69, L1=3.35]


Epoch 39: G=15.653 D=0.528 L1=3.267


Epoch 40/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.212, G=14.17, L1=2.66]

Epoch 40: G=15.653 D=0.504 L1=3.259



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.07it/s]


→ FID @ epoch 40: 85.48
New best FID 85.48 @ epoch 40


Epoch 41/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.588, G=14.96, L1=3.06]


Epoch 41: G=15.707 D=0.502 L1=3.274


Epoch 42/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.266, G=12.99, L1=2.73]


Epoch 42: G=15.695 D=0.472 L1=3.251


Epoch 43/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.156, G=15.96, L1=2.81]


Epoch 43: G=15.612 D=0.471 L1=3.211


Epoch 44/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.113, G=14.34, L1=2.73]


Epoch 44: G=15.724 D=0.448 L1=3.240


Epoch 45/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.395, G=14.30, L1=3.10]


Epoch 45: G=15.730 D=0.475 L1=3.243


Epoch 46/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.376, G=15.14, L1=2.96]


Epoch 46: G=15.629 D=0.444 L1=3.199


Epoch 47/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.744, G=15.05, L1=3.32]


Epoch 47: G=15.655 D=0.424 L1=3.183


Epoch 48/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.172, G=14.37, L1=2.59]


Epoch 48: G=15.535 D=0.407 L1=3.140


Epoch 49/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.211, G=15.75, L1=3.34]


Epoch 49: G=15.695 D=0.403 L1=3.185


Epoch 50/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.165, G=14.47, L1=2.83]

Epoch 50: G=15.542 D=0.412 L1=3.137





Checkpoint saved: outputs_spade_finetune/checkpoints/epoch_50.pt


Predict val: 100%|██████████| 63/63 [00:05<00:00, 12.59it/s]


→ FID @ epoch 50: 89.02


Epoch 51/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.317, G=16.80, L1=3.46]


Epoch 51: G=15.591 D=0.381 L1=3.135


Epoch 52/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.072, G=13.32, L1=2.25]


Epoch 52: G=15.668 D=0.380 L1=3.168


Epoch 53/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.702, G=17.86, L1=3.77]


Epoch 53: G=15.602 D=0.391 L1=3.145


Epoch 54/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.172, G=16.23, L1=3.23]


Epoch 54: G=15.532 D=0.375 L1=3.102


Epoch 55/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.180, G=18.36, L1=3.82]


Epoch 55: G=15.572 D=0.394 L1=3.129


Epoch 56/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.245, G=15.31, L1=3.00]


Epoch 56: G=15.542 D=0.387 L1=3.110


Epoch 57/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.513, G=14.93, L1=3.13]


Epoch 57: G=15.589 D=0.329 L1=3.094


Epoch 58/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.143, G=14.23, L1=2.54]


Epoch 58: G=15.635 D=0.353 L1=3.128


Epoch 59/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.141, G=16.08, L1=2.84]


Epoch 59: G=15.556 D=0.334 L1=3.070


Epoch 60/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.177, G=14.89, L1=2.87]

Epoch 60: G=15.592 D=0.346 L1=3.105



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.08it/s]


→ FID @ epoch 60: 90.32


Epoch 61/250: 100%|██████████| 371/371 [00:51<00:00,  7.24it/s, D=0.190, G=15.44, L1=2.98]


Epoch 61: G=15.713 D=0.301 L1=3.123


Epoch 62/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.285, G=17.09, L1=3.42]


Epoch 62: G=15.572 D=0.337 L1=3.094


Epoch 63/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.135, G=16.39, L1=3.32]


Epoch 63: G=15.605 D=0.297 L1=3.072


Epoch 64/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.402, G=14.62, L1=3.04]


Epoch 64: G=15.588 D=0.322 L1=3.069


Epoch 65/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.125, G=15.45, L1=2.99]


Epoch 65: G=15.568 D=0.290 L1=3.060


Epoch 66/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.438, G=14.20, L1=2.31]


Epoch 66: G=15.553 D=0.288 L1=3.055


Epoch 67/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.246, G=14.87, L1=2.66]


Epoch 67: G=15.561 D=0.261 L1=3.028


Epoch 68/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.039, G=13.73, L1=2.46]


Epoch 68: G=15.457 D=0.285 L1=3.012


Epoch 69/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.207, G=16.82, L1=3.60]


Epoch 69: G=15.558 D=0.270 L1=3.047


Epoch 70/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.133, G=15.25, L1=2.98]

Epoch 70: G=15.480 D=0.266 L1=3.020



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.02it/s]


→ FID @ epoch 70: 85.45
New best FID 85.45 @ epoch 70


Epoch 71/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.372, G=14.32, L1=2.51]


Epoch 71: G=15.387 D=0.277 L1=2.991


Epoch 72/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.113, G=11.97, L1=1.95]


Epoch 72: G=15.463 D=0.252 L1=3.005


Epoch 73/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.346, G=15.97, L1=3.03]


Epoch 73: G=15.518 D=0.220 L1=3.002


Epoch 74/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.243, G=12.95, L1=1.96]


Epoch 74: G=15.505 D=0.213 L1=2.983


Epoch 75/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.428, G=16.37, L1=3.37]


Epoch 75: G=15.606 D=0.221 L1=3.025


Epoch 76/250: 100%|██████████| 371/371 [00:51<00:00,  7.26it/s, D=0.376, G=16.63, L1=2.94]


Epoch 76: G=15.371 D=0.271 L1=2.968


Epoch 77/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.081, G=15.47, L1=3.05]


Epoch 77: G=15.400 D=0.205 L1=2.958


Epoch 78/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.200, G=13.69, L1=2.60]


Epoch 78: G=15.359 D=0.229 L1=2.958


Epoch 79/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.079, G=15.59, L1=3.05]


Epoch 79: G=15.444 D=0.213 L1=2.982


Epoch 80/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.201, G=16.03, L1=2.97]

Epoch 80: G=15.374 D=0.258 L1=2.964



Predict val: 100%|██████████| 63/63 [00:04<00:00, 12.99it/s]


→ FID @ epoch 80: 89.49


Epoch 81/250: 100%|██████████| 371/371 [00:51<00:00,  7.24it/s, D=0.181, G=15.54, L1=2.52]


Epoch 81: G=15.412 D=0.183 L1=2.958


Epoch 82/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.344, G=16.85, L1=3.25]


Epoch 82: G=15.359 D=0.200 L1=2.952


Epoch 83/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.150, G=15.39, L1=3.10]


Epoch 83: G=15.499 D=0.234 L1=3.000


Epoch 84/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.057, G=14.17, L1=2.41]


Epoch 84: G=15.331 D=0.190 L1=2.925


Epoch 85/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.066, G=15.52, L1=3.01]


Epoch 85: G=15.436 D=0.156 L1=2.946


Epoch 86/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.044, G=15.85, L1=2.86]


Epoch 86: G=15.212 D=0.232 L1=2.910


Epoch 87/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.123, G=14.02, L1=2.66]


Epoch 87: G=15.324 D=0.180 L1=2.920


Epoch 88/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.245, G=16.65, L1=3.25]


Epoch 88: G=15.373 D=0.174 L1=2.927


Epoch 89/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.145, G=15.70, L1=3.00]


Epoch 89: G=15.382 D=0.191 L1=2.939


Epoch 90/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.258, G=13.42, L1=2.75]

Epoch 90: G=15.212 D=0.219 L1=2.907



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.12it/s]


→ FID @ epoch 90: 91.65


Epoch 91/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.344, G=18.99, L1=3.83]


Epoch 91: G=15.330 D=0.171 L1=2.916


Epoch 92/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.008, G=12.81, L1=2.01]


Epoch 92: G=15.318 D=0.149 L1=2.891


Epoch 93/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.133, G=15.90, L1=3.15]


Epoch 93: G=15.326 D=0.190 L1=2.917


Epoch 94/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.257, G=13.15, L1=2.16]


Epoch 94: G=15.323 D=0.152 L1=2.901


Epoch 95/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.159, G=15.30, L1=3.04]


Epoch 95: G=15.328 D=0.157 L1=2.902


Epoch 96/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.194, G=13.63, L1=2.93]


Epoch 96: G=15.266 D=0.174 L1=2.886


Epoch 97/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.192, G=16.00, L1=3.16]


Epoch 97: G=15.297 D=0.149 L1=2.883


Epoch 98/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.121, G=11.73, L1=1.99]


Epoch 98: G=15.176 D=0.193 L1=2.864


Epoch 99/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.114, G=13.08, L1=2.51]


Epoch 99: G=15.257 D=0.128 L1=2.865


Epoch 100/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.168, G=17.15, L1=3.25]

Epoch 100: G=15.233 D=0.147 L1=2.854





Checkpoint saved: outputs_spade_finetune/checkpoints/epoch_100.pt


Predict val: 100%|██████████| 63/63 [00:05<00:00, 12.47it/s]


→ FID @ epoch 100: 91.05


Epoch 101/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.407, G=14.06, L1=2.69]


Epoch 101: G=15.282 D=0.174 L1=2.876


Epoch 102/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.035, G=14.51, L1=2.68]


Epoch 102: G=15.227 D=0.136 L1=2.849


Epoch 103/250: 100%|██████████| 371/371 [00:51<00:00,  7.26it/s, D=0.055, G=12.38, L1=2.09]


Epoch 103: G=15.190 D=0.141 L1=2.848


Epoch 104/250: 100%|██████████| 371/371 [00:51<00:00,  7.25it/s, D=0.079, G=16.82, L1=3.25]


Epoch 104: G=15.223 D=0.144 L1=2.848


Epoch 105/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.032, G=15.89, L1=2.88]


Epoch 105: G=15.208 D=0.158 L1=2.848


Epoch 106/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.043, G=12.87, L1=2.04]


Epoch 106: G=15.135 D=0.162 L1=2.844


Epoch 107/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.121, G=13.75, L1=2.39]


Epoch 107: G=15.189 D=0.129 L1=2.835


Epoch 108/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.009, G=16.37, L1=3.02]


Epoch 108: G=15.110 D=0.174 L1=2.827


Epoch 109/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.350, G=16.70, L1=3.58]


Epoch 109: G=15.280 D=0.116 L1=2.859


Epoch 110/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.089, G=16.35, L1=3.14]

Epoch 110: G=15.171 D=0.164 L1=2.842



Predict val: 100%|██████████| 63/63 [00:04<00:00, 12.88it/s]


→ FID @ epoch 110: 96.98


Epoch 111/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.038, G=17.45, L1=3.57]


Epoch 111: G=15.143 D=0.137 L1=2.812


Epoch 112/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.073, G=16.70, L1=3.01]


Epoch 112: G=15.163 D=0.093 L1=2.811


Epoch 113/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.014, G=15.72, L1=2.97]


Epoch 113: G=15.140 D=0.164 L1=2.826


Epoch 114/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.340, G=16.36, L1=2.90]


Epoch 114: G=15.094 D=0.124 L1=2.802


Epoch 115/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.100, G=16.68, L1=3.16]


Epoch 115: G=15.056 D=0.143 L1=2.797


Epoch 116/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.482, G=13.30, L1=2.72]


Epoch 116: G=15.095 D=0.111 L1=2.799


Epoch 117/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.097, G=17.83, L1=3.46]


Epoch 117: G=15.195 D=0.140 L1=2.835


Epoch 118/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.073, G=14.51, L1=2.63]


Epoch 118: G=15.171 D=0.104 L1=2.807


Epoch 119/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.073, G=14.15, L1=2.45]


Epoch 119: G=15.095 D=0.136 L1=2.792


Epoch 120/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.048, G=15.07, L1=2.79]

Epoch 120: G=15.062 D=0.134 L1=2.796



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.42it/s]


→ FID @ epoch 120: 102.39


Epoch 121/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.005, G=12.56, L1=1.94]


Epoch 121: G=15.188 D=0.134 L1=2.825


Epoch 122/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.014, G=17.67, L1=3.58]


Epoch 122: G=15.103 D=0.090 L1=2.787


Epoch 123/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.194, G=15.94, L1=3.39]


Epoch 123: G=15.030 D=0.125 L1=2.772


Epoch 124/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.022, G=14.70, L1=2.50]


Epoch 124: G=15.069 D=0.113 L1=2.789


Epoch 125/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.011, G=14.49, L1=2.38]


Epoch 125: G=14.883 D=0.168 L1=2.751


Epoch 126/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.026, G=12.61, L1=2.18]


Epoch 126: G=15.201 D=0.103 L1=2.815


Epoch 127/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.358, G=14.21, L1=2.32]


Epoch 127: G=15.143 D=0.106 L1=2.792


Epoch 128/250: 100%|██████████| 371/371 [00:51<00:00,  7.26it/s, D=0.010, G=12.89, L1=1.97]


Epoch 128: G=15.046 D=0.108 L1=2.768


Epoch 129/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.049, G=17.13, L1=3.36]


Epoch 129: G=15.076 D=0.124 L1=2.782


Epoch 130/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.122, G=15.81, L1=2.85]

Epoch 130: G=15.000 D=0.112 L1=2.754



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.26it/s]


→ FID @ epoch 130: 98.82


Epoch 131/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.046, G=15.66, L1=3.05]


Epoch 131: G=15.038 D=0.094 L1=2.749


Epoch 132/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.046, G=14.55, L1=2.86]


Epoch 132: G=15.013 D=0.123 L1=2.771


Epoch 133/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.041, G=15.99, L1=3.06]


Epoch 133: G=14.955 D=0.128 L1=2.753


Epoch 134/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.206, G=15.04, L1=2.60]


Epoch 134: G=14.893 D=0.112 L1=2.718


Epoch 135/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.008, G=14.50, L1=2.36]


Epoch 135: G=14.963 D=0.099 L1=2.737


Epoch 136/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.035, G=13.95, L1=2.52]


Epoch 136: G=14.914 D=0.097 L1=2.720


Epoch 137/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.095, G=17.24, L1=3.22]


Epoch 137: G=14.979 D=0.100 L1=2.734


Epoch 138/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.069, G=14.09, L1=2.58]


Epoch 138: G=14.892 D=0.147 L1=2.746


Epoch 139/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.055, G=16.63, L1=3.09]


Epoch 139: G=14.979 D=0.082 L1=2.729


Epoch 140/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.086, G=14.93, L1=2.86]

Epoch 140: G=14.811 D=0.127 L1=2.713



Predict val: 100%|██████████| 63/63 [00:04<00:00, 12.88it/s]


→ FID @ epoch 140: 97.35


Epoch 141/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.139, G=15.83, L1=2.71]


Epoch 141: G=14.949 D=0.086 L1=2.729


Epoch 142/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.041, G=16.38, L1=3.35]


Epoch 142: G=14.877 D=0.128 L1=2.729


Epoch 143/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.012, G=12.43, L1=1.83]


Epoch 143: G=14.984 D=0.077 L1=2.727


Epoch 144/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.016, G=13.88, L1=2.32]


Epoch 144: G=14.886 D=0.128 L1=2.727


Epoch 145/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.015, G=13.92, L1=2.27]


Epoch 145: G=14.951 D=0.109 L1=2.737


Epoch 146/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.120, G=14.78, L1=2.86]


Epoch 146: G=14.983 D=0.097 L1=2.747


Epoch 147/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.591, G=14.72, L1=3.13]


Epoch 147: G=14.777 D=0.097 L1=2.677


Epoch 148/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.096, G=15.84, L1=3.04]


Epoch 148: G=14.917 D=0.099 L1=2.724


Epoch 149/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.029, G=15.04, L1=2.77]


Epoch 149: G=14.956 D=0.095 L1=2.733


Epoch 150/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.207, G=14.86, L1=3.09]

Epoch 150: G=14.771 D=0.102 L1=2.680





Checkpoint saved: outputs_spade_finetune/checkpoints/epoch_150.pt


Predict val: 100%|██████████| 63/63 [00:05<00:00, 12.49it/s]


→ FID @ epoch 150: 99.02


Epoch 151/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.193, G=13.30, L1=1.95]


Epoch 151: G=14.851 D=0.121 L1=2.707


Epoch 152/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.005, G=15.92, L1=2.74]


Epoch 152: G=14.980 D=0.083 L1=2.734


Epoch 153/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.035, G=16.53, L1=3.28]


Epoch 153: G=14.895 D=0.083 L1=2.696


Epoch 154/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.149, G=14.46, L1=2.37]


Epoch 154: G=14.866 D=0.118 L1=2.711


Epoch 155/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.097, G=12.94, L1=1.98]


Epoch 155: G=14.788 D=0.096 L1=2.688


Epoch 156/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.061, G=12.94, L1=2.24]


Epoch 156: G=14.798 D=0.095 L1=2.687


Epoch 157/250: 100%|██████████| 371/371 [00:51<00:00,  7.26it/s, D=0.030, G=14.82, L1=2.77]


Epoch 157: G=14.896 D=0.068 L1=2.703


Epoch 158/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.147, G=14.78, L1=2.37]


Epoch 158: G=14.707 D=0.113 L1=2.676


Epoch 159/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.091, G=13.77, L1=2.54]


Epoch 159: G=14.830 D=0.113 L1=2.715


Epoch 160/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.035, G=16.40, L1=3.04]

Epoch 160: G=14.674 D=0.106 L1=2.647



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.22it/s]


→ FID @ epoch 160: 100.25


Epoch 161/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.060, G=15.26, L1=2.81]


Epoch 161: G=14.916 D=0.076 L1=2.709


Epoch 162/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.140, G=14.55, L1=2.38]


Epoch 162: G=14.731 D=0.115 L1=2.668


Epoch 163/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.002, G=13.55, L1=2.04]


Epoch 163: G=14.782 D=0.086 L1=2.675


Epoch 164/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.013, G=15.34, L1=2.75]


Epoch 164: G=14.733 D=0.089 L1=2.664


Epoch 165/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.096, G=15.86, L1=3.03]


Epoch 165: G=14.769 D=0.078 L1=2.660


Epoch 166/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.055, G=14.38, L1=2.79]


Epoch 166: G=14.784 D=0.073 L1=2.658


Epoch 167/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.142, G=14.24, L1=2.24]


Epoch 167: G=14.711 D=0.151 L1=2.692


Epoch 168/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.058, G=13.02, L1=2.35]


Epoch 168: G=14.759 D=0.087 L1=2.671


Epoch 169/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.020, G=15.80, L1=3.11]


Epoch 169: G=14.833 D=0.069 L1=2.672


Epoch 170/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.038, G=16.29, L1=3.09]

Epoch 170: G=14.771 D=0.087 L1=2.664



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.14it/s]


→ FID @ epoch 170: 102.35


Epoch 171/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.042, G=14.80, L1=2.57]


Epoch 171: G=14.722 D=0.119 L1=2.664


Epoch 172/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.008, G=12.33, L1=1.90]


Epoch 172: G=14.889 D=0.062 L1=2.691


Epoch 173/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.275, G=14.35, L1=2.28]


Epoch 173: G=14.686 D=0.090 L1=2.636


Epoch 174/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.039, G=16.05, L1=3.08]


Epoch 174: G=14.666 D=0.108 L1=2.651


Epoch 175/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.067, G=15.74, L1=2.80]


Epoch 175: G=14.706 D=0.072 L1=2.637


Epoch 176/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.034, G=14.05, L1=2.39]


Epoch 176: G=14.654 D=0.101 L1=2.631


Epoch 177/250: 100%|██████████| 371/371 [00:51<00:00,  7.22it/s, D=0.097, G=14.27, L1=2.44]


Epoch 177: G=14.699 D=0.083 L1=2.629


Epoch 178/250: 100%|██████████| 371/371 [00:51<00:00,  7.26it/s, D=0.240, G=14.05, L1=2.86]


Epoch 178: G=14.749 D=0.090 L1=2.647


Epoch 179/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.009, G=13.44, L1=2.29]


Epoch 179: G=14.713 D=0.067 L1=2.638


Epoch 180/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.160, G=15.64, L1=2.85]

Epoch 180: G=14.511 D=0.158 L1=2.620



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.02it/s]


→ FID @ epoch 180: 109.77


Epoch 181/250: 100%|██████████| 371/371 [00:51<00:00,  7.21it/s, D=0.008, G=14.88, L1=2.69]


Epoch 181: G=14.640 D=0.059 L1=2.634


Epoch 182/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.083, G=15.84, L1=2.98]


Epoch 182: G=14.602 D=0.107 L1=2.635


Epoch 183/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.062, G=15.67, L1=2.87]


Epoch 183: G=14.730 D=0.046 L1=2.632


Epoch 184/250: 100%|██████████| 371/371 [00:50<00:00,  7.29it/s, D=0.062, G=12.45, L1=2.08]


Epoch 184: G=14.620 D=0.105 L1=2.627


Epoch 185/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.139, G=15.58, L1=2.80]


Epoch 185: G=14.711 D=0.055 L1=2.644


Epoch 186/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.320, G=15.97, L1=2.72]


Epoch 186: G=14.577 D=0.114 L1=2.624


Epoch 187/250: 100%|██████████| 371/371 [00:51<00:00,  7.19it/s, D=0.178, G=12.77, L1=2.18]


Epoch 187: G=14.609 D=0.079 L1=2.617


Epoch 188/250: 100%|██████████| 371/371 [00:51<00:00,  7.24it/s, D=0.010, G=15.13, L1=2.85]


Epoch 188: G=14.664 D=0.077 L1=2.627


Epoch 189/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.055, G=12.89, L1=2.26]


Epoch 189: G=14.598 D=0.085 L1=2.619


Epoch 190/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.033, G=14.11, L1=2.56]

Epoch 190: G=14.703 D=0.071 L1=2.643



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.10it/s]


→ FID @ epoch 190: 103.19


Epoch 191/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.024, G=12.73, L1=1.87]


Epoch 191: G=14.569 D=0.083 L1=2.603


Epoch 192/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.105, G=14.70, L1=2.50]


Epoch 192: G=14.561 D=0.079 L1=2.596


Epoch 193/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.007, G=12.01, L1=1.83]


Epoch 193: G=14.509 D=0.110 L1=2.609


Epoch 194/250: 100%|██████████| 371/371 [00:50<00:00,  7.31it/s, D=0.002, G=15.38, L1=2.65]


Epoch 194: G=14.536 D=0.068 L1=2.595


Epoch 195/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.009, G=15.81, L1=3.01]


Epoch 195: G=14.646 D=0.054 L1=2.614


Epoch 196/250: 100%|██████████| 371/371 [00:50<00:00,  7.37it/s, D=0.074, G=14.82, L1=2.91]


Epoch 196: G=14.456 D=0.111 L1=2.599


Epoch 197/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.014, G=14.82, L1=2.79]


Epoch 197: G=14.584 D=0.068 L1=2.601


Epoch 198/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.043, G=13.86, L1=2.36]


Epoch 198: G=14.685 D=0.076 L1=2.650


Epoch 199/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.035, G=13.57, L1=2.28]


Epoch 199: G=14.567 D=0.082 L1=2.609


Epoch 200/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.018, G=13.62, L1=2.26]

Epoch 200: G=14.552 D=0.069 L1=2.592





Checkpoint saved: outputs_spade_finetune/checkpoints/epoch_200.pt


Predict val: 100%|██████████| 63/63 [00:05<00:00, 12.34it/s]


→ FID @ epoch 200: 97.85


Epoch 201/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.016, G=15.13, L1=2.91]


Epoch 201: G=14.644 D=0.063 L1=2.612


Epoch 202/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.058, G=15.57, L1=2.90]


Epoch 202: G=14.533 D=0.101 L1=2.605


Epoch 203/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.057, G=13.36, L1=2.35]


Epoch 203: G=14.555 D=0.078 L1=2.597


Epoch 204/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.001, G=13.48, L1=2.24]


Epoch 204: G=14.657 D=0.061 L1=2.623


Epoch 205/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.012, G=15.41, L1=2.82]


Epoch 205: G=14.482 D=0.118 L1=2.594


Epoch 206/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=1.529, G=13.50, L1=2.50]


Epoch 206: G=14.509 D=0.087 L1=2.579


Epoch 207/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.028, G=12.51, L1=2.08]


Epoch 207: G=14.414 D=0.116 L1=2.604


Epoch 208/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.024, G=13.24, L1=2.07]


Epoch 208: G=14.443 D=0.036 L1=2.552


Epoch 209/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.012, G=12.95, L1=2.11]


Epoch 209: G=14.460 D=0.096 L1=2.583


Epoch 210/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.019, G=14.87, L1=2.84]

Epoch 210: G=14.598 D=0.074 L1=2.605



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.02it/s]


→ FID @ epoch 210: 105.23


Epoch 211/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.035, G=17.04, L1=3.31]


Epoch 211: G=14.463 D=0.075 L1=2.574


Epoch 212/250: 100%|██████████| 371/371 [00:50<00:00,  7.36it/s, D=0.012, G=13.58, L1=2.30]


Epoch 212: G=14.535 D=0.048 L1=2.581


Epoch 213/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.044, G=12.24, L1=1.81]


Epoch 213: G=14.402 D=0.120 L1=2.581


Epoch 214/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.007, G=14.51, L1=2.27]


Epoch 214: G=14.447 D=0.069 L1=2.571


Epoch 215/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.003, G=14.79, L1=2.70]


Epoch 215: G=14.486 D=0.052 L1=2.573


Epoch 216/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.064, G=15.64, L1=2.96]


Epoch 216: G=14.403 D=0.105 L1=2.573


Epoch 217/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.028, G=13.03, L1=2.05]


Epoch 217: G=14.462 D=0.093 L1=2.586


Epoch 218/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.085, G=15.80, L1=3.10]


Epoch 218: G=14.517 D=0.044 L1=2.579


Epoch 219/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.386, G=13.07, L1=2.35]


Epoch 219: G=14.447 D=0.082 L1=2.580


Epoch 220/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.026, G=13.25, L1=2.31]

Epoch 220: G=14.512 D=0.083 L1=2.596



Predict val: 100%|██████████| 63/63 [00:04<00:00, 12.97it/s]


→ FID @ epoch 220: 107.87


Epoch 221/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.238, G=16.63, L1=3.10]


Epoch 221: G=14.724 D=0.052 L1=2.633


Epoch 222/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.015, G=14.38, L1=2.68]


Epoch 222: G=14.476 D=0.068 L1=2.572


Epoch 223/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.010, G=14.28, L1=2.53]


Epoch 223: G=14.007 D=0.292 L1=2.555


Epoch 224/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.002, G=13.78, L1=2.32]


Epoch 224: G=14.370 D=0.078 L1=2.552


Epoch 225/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.001, G=15.16, L1=2.72]


Epoch 225: G=14.584 D=0.068 L1=2.612


Epoch 226/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.078, G=12.62, L1=2.13]


Epoch 226: G=14.321 D=0.104 L1=2.539


Epoch 227/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.001, G=16.66, L1=3.13]


Epoch 227: G=14.347 D=0.032 L1=2.534


Epoch 228/250: 100%|██████████| 371/371 [00:50<00:00,  7.34it/s, D=0.051, G=13.77, L1=2.27]


Epoch 228: G=14.227 D=0.094 L1=2.524


Epoch 229/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.032, G=12.72, L1=1.96]


Epoch 229: G=14.542 D=0.051 L1=2.584


Epoch 230/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.010, G=17.32, L1=3.47]

Epoch 230: G=14.397 D=0.062 L1=2.541



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.05it/s]


→ FID @ epoch 230: 109.24


Epoch 231/250: 100%|██████████| 371/371 [00:50<00:00,  7.32it/s, D=0.014, G=15.21, L1=2.80]


Epoch 231: G=14.262 D=0.113 L1=2.538


Epoch 232/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.003, G=13.77, L1=2.39]


Epoch 232: G=14.398 D=0.089 L1=2.563


Epoch 233/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.036, G=13.12, L1=2.53]


Epoch 233: G=14.330 D=0.056 L1=2.522


Epoch 234/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.022, G=14.82, L1=2.67]


Epoch 234: G=14.422 D=0.045 L1=2.546


Epoch 235/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.141, G=13.74, L1=2.52]


Epoch 235: G=14.265 D=0.096 L1=2.529


Epoch 236/250: 100%|██████████| 371/371 [00:50<00:00,  7.38it/s, D=0.013, G=14.03, L1=2.36]


Epoch 236: G=14.395 D=0.077 L1=2.563


Epoch 237/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.034, G=13.40, L1=2.15]


Epoch 237: G=14.377 D=0.083 L1=2.560


Epoch 238/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.001, G=13.51, L1=2.12]


Epoch 238: G=14.352 D=0.049 L1=2.528


Epoch 239/250: 100%|██████████| 371/371 [00:50<00:00,  7.40it/s, D=0.007, G=15.51, L1=3.02]


Epoch 239: G=14.396 D=0.066 L1=2.556


Epoch 240/250: 100%|██████████| 371/371 [00:50<00:00,  7.39it/s, D=0.017, G=16.74, L1=3.23]

Epoch 240: G=14.380 D=0.062 L1=2.542



Predict val: 100%|██████████| 63/63 [00:04<00:00, 13.25it/s]


→ FID @ epoch 240: 113.24


Epoch 241/250: 100%|██████████| 371/371 [00:51<00:00,  7.25it/s, D=0.047, G=12.04, L1=1.96]


Epoch 241: G=14.412 D=0.039 L1=2.539


Epoch 242/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.034, G=13.82, L1=2.11]


Epoch 242: G=14.202 D=0.136 L1=2.527


Epoch 243/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.035, G=14.65, L1=2.66]


Epoch 243: G=14.306 D=0.058 L1=2.530


Epoch 244/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.023, G=14.82, L1=2.90]


Epoch 244: G=14.260 D=0.092 L1=2.525


Epoch 245/250: 100%|██████████| 371/371 [00:50<00:00,  7.30it/s, D=0.029, G=13.33, L1=2.28]


Epoch 245: G=14.378 D=0.038 L1=2.533


Epoch 246/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.003, G=13.98, L1=2.41]


Epoch 246: G=14.299 D=0.080 L1=2.535


Epoch 247/250: 100%|██████████| 371/371 [00:51<00:00,  7.27it/s, D=0.010, G=11.50, L1=1.64]


Epoch 247: G=14.334 D=0.059 L1=2.534


Epoch 248/250: 100%|██████████| 371/371 [00:50<00:00,  7.28it/s, D=0.004, G=15.80, L1=2.79]


Epoch 248: G=14.261 D=0.080 L1=2.527


Epoch 249/250: 100%|██████████| 371/371 [00:50<00:00,  7.33it/s, D=0.014, G=14.00, L1=2.29]


Epoch 249: G=14.405 D=0.074 L1=2.564


Epoch 250/250: 100%|██████████| 371/371 [00:50<00:00,  7.35it/s, D=0.006, G=14.77, L1=2.50]

Epoch 250: G=14.340 D=0.033 L1=2.525





Checkpoint saved: outputs_spade_finetune/checkpoints/epoch_250.pt


Predict val: 100%|██████████| 63/63 [00:05<00:00, 12.52it/s]


→ FID @ epoch 250: 107.63


In [18]:
best=torch.load(CKPT_DIR/"best_model.pt",map_location=device)
G.load_state_dict(best["G"])
print(f"Loaded best model epoch {best['epoch']} FID {best['fid']:.2f}")

# regenerate predictions + compute FID
for p in PRED_DIR.glob("*.jpg"): p.unlink()
G.eval()
with torch.no_grad():
    for c,_,names in tqdm(val_loader,desc="Predict final"):
        c=c.to(device)
        g=denorm(G(c))
        for i,n in enumerate(names):
            transforms.ToPILImage()(g[i].cpu()).save(PRED_DIR/n,quality=95)
fid=compute_fid(GT_DIR,PRED_DIR)
print("Final FID:",round(fid,2))

Loaded best model epoch 70 FID 85.45


Predict final: 100%|██████████| 63/63 [00:04<00:00, 12.61it/s]


Final FID: 85.45
